diff --git a/src/NATS.Server/JetStream/Api/Handlers/ConsumerApiHandlers.cs b/src/NATS.Server/JetStream/Api/Handlers/ConsumerApiHandlers.cs index 8d38c16..4a8a11c 100644 --- a/src/NATS.Server/JetStream/Api/Handlers/ConsumerApiHandlers.cs +++ b/src/NATS.Server/JetStream/Api/Handlers/ConsumerApiHandlers.cs @@ -68,6 +68,16 @@ public static class ConsumerApiHandlers if (root.TryGetProperty("heartbeat_ms", out var hbEl) && hbEl.TryGetInt32(out var hbMs)) config.HeartbeatMs = hbMs; + if (root.TryGetProperty("ack_wait_ms", out var ackWaitEl) && ackWaitEl.TryGetInt32(out var ackWait)) + config.AckWaitMs = ackWait; + + if (root.TryGetProperty("ack_policy", out var ackPolicyEl)) + { + var ackPolicy = ackPolicyEl.GetString(); + if (string.Equals(ackPolicy, "explicit", StringComparison.OrdinalIgnoreCase)) + config.AckPolicy = AckPolicy.Explicit; + } + return config; } catch (JsonException) diff --git a/src/NATS.Server/JetStream/ConsumerManager.cs b/src/NATS.Server/JetStream/ConsumerManager.cs index a235db0..36d778e 100644 --- a/src/NATS.Server/JetStream/ConsumerManager.cs +++ b/src/NATS.Server/JetStream/ConsumerManager.cs @@ -66,7 +66,7 @@ public sealed class ConsumerManager public void OnPublished(string stream, StoredMessage message) { foreach (var handle in _consumers.Values.Where(c => c.Stream == stream && c.Config.Push)) - _pushConsumerEngine.Enqueue(handle.PushFrames, message, handle.Config); + _pushConsumerEngine.Enqueue(handle, message); } public PushFrame? ReadPushFrame(string stream, string durableName) @@ -86,4 +86,5 @@ public sealed record ConsumerHandle(string Stream, ConsumerConfig Config) public ulong NextSequence { get; set; } = 1; public Queue Pending { get; } = new(); public Queue PushFrames { get; } = new(); + public AckProcessor AckProcessor { get; } = new(); } diff --git a/src/NATS.Server/JetStream/Consumers/AckProcessor.cs b/src/NATS.Server/JetStream/Consumers/AckProcessor.cs new file mode 100644 index 0000000..ee44837 --- /dev/null +++ b/src/NATS.Server/JetStream/Consumers/AckProcessor.cs @@ -0,0 +1,24 @@ +namespace NATS.Server.JetStream.Consumers; + +public sealed class AckProcessor +{ + private readonly Dictionary _pending = new(); + + public void Register(ulong sequence, int ackWaitMs) + { + _pending[sequence] = DateTime.UtcNow.AddMilliseconds(Math.Max(ackWaitMs, 1)); + } + + public ulong? NextExpired() + { + foreach (var (seq, deadline) in _pending) + { + if (DateTime.UtcNow >= deadline) + return seq; + } + + return null; + } + + public bool HasPending => _pending.Count > 0; +} diff --git a/src/NATS.Server/JetStream/Consumers/PullConsumerEngine.cs b/src/NATS.Server/JetStream/Consumers/PullConsumerEngine.cs index 47e41f2..210f7d8 100644 --- a/src/NATS.Server/JetStream/Consumers/PullConsumerEngine.cs +++ b/src/NATS.Server/JetStream/Consumers/PullConsumerEngine.cs @@ -1,4 +1,5 @@ using NATS.Server.JetStream.Storage; +using NATS.Server.JetStream.Models; namespace NATS.Server.JetStream.Consumers; @@ -7,6 +8,31 @@ public sealed class PullConsumerEngine public async ValueTask FetchAsync(StreamHandle stream, ConsumerHandle consumer, int batch, CancellationToken ct) { var messages = new List(batch); + + if (consumer.Config.AckPolicy == AckPolicy.Explicit) + { + var expired = consumer.AckProcessor.NextExpired(); + if (expired is { } expiredSequence) + { + var redelivery = await stream.Store.LoadAsync(expiredSequence, ct); + if (redelivery != null) + { + messages.Add(new StoredMessage + { + Sequence = redelivery.Sequence, + Subject = redelivery.Subject, + Payload = redelivery.Payload, + Redelivered = true, + }); + } + + return new PullFetchBatch(messages); + } + + if (consumer.AckProcessor.HasPending) + return new PullFetchBatch(messages); + } + var sequence = consumer.NextSequence; for (var i = 0; i < batch; i++) @@ -16,6 +42,8 @@ public sealed class PullConsumerEngine break; messages.Add(message); + if (consumer.Config.AckPolicy == AckPolicy.Explicit) + consumer.AckProcessor.Register(message.Sequence, consumer.Config.AckWaitMs); sequence++; } diff --git a/src/NATS.Server/JetStream/Consumers/PushConsumerEngine.cs b/src/NATS.Server/JetStream/Consumers/PushConsumerEngine.cs index 52bfbc2..effef2d 100644 --- a/src/NATS.Server/JetStream/Consumers/PushConsumerEngine.cs +++ b/src/NATS.Server/JetStream/Consumers/PushConsumerEngine.cs @@ -5,17 +5,20 @@ namespace NATS.Server.JetStream.Consumers; public sealed class PushConsumerEngine { - public void Enqueue(Queue queue, StoredMessage message, ConsumerConfig config) + public void Enqueue(ConsumerHandle consumer, StoredMessage message) { - queue.Enqueue(new PushFrame + consumer.PushFrames.Enqueue(new PushFrame { IsData = true, Message = message, }); - if (config.HeartbeatMs > 0) + if (consumer.Config.AckPolicy == AckPolicy.Explicit) + consumer.AckProcessor.Register(message.Sequence, consumer.Config.AckWaitMs); + + if (consumer.Config.HeartbeatMs > 0) { - queue.Enqueue(new PushFrame + consumer.PushFrames.Enqueue(new PushFrame { IsHeartbeat = true, }); diff --git a/tests/NATS.Server.Tests/JetStreamAckRedeliveryTests.cs b/tests/NATS.Server.Tests/JetStreamAckRedeliveryTests.cs new file mode 100644 index 0000000..c65dfbc --- /dev/null +++ b/tests/NATS.Server.Tests/JetStreamAckRedeliveryTests.cs @@ -0,0 +1,17 @@ +namespace NATS.Server.Tests; + +public class JetStreamAckRedeliveryTests +{ + [Fact] + public async Task Unacked_message_is_redelivered_after_ack_wait() + { + await using var fixture = await JetStreamApiFixture.StartWithAckExplicitConsumerAsync(ackWaitMs: 50); + await fixture.PublishAndGetAckAsync("orders.created", "1"); + + var first = await fixture.FetchAsync("ORDERS", "PULL", batch: 1); + var second = await fixture.FetchAfterDelayAsync("ORDERS", "PULL", delayMs: 75, batch: 1); + + second.Messages.Single().Sequence.ShouldBe(first.Messages.Single().Sequence); + second.Messages.Single().Redelivered.ShouldBeTrue(); + } +} diff --git a/tests/NATS.Server.Tests/JetStreamApiFixture.cs b/tests/NATS.Server.Tests/JetStreamApiFixture.cs index ca321e2..2561048 100644 --- a/tests/NATS.Server.Tests/JetStreamApiFixture.cs +++ b/tests/NATS.Server.Tests/JetStreamApiFixture.cs @@ -53,6 +53,14 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable return fixture; } + public static async Task StartWithAckExplicitConsumerAsync(int ackWaitMs) + { + var fixture = await StartWithStreamAsync("ORDERS", "orders.*"); + _ = await fixture.CreateConsumerAsync("ORDERS", "PULL", "orders.created", + ackPolicy: AckPolicy.Explicit, ackWaitMs: ackWaitMs); + return fixture; + } + public Task PublishAndGetAckAsync(string subject, string payload, string? msgId = null, bool expectError = false) { if (_publisher.TryCapture(subject, Encoding.UTF8.GetBytes(payload), msgId, out var ack)) @@ -83,9 +91,9 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable return _streamManager.GetStateAsync(streamName, default).AsTask(); } - public Task CreateConsumerAsync(string stream, string durableName, string filterSubject, bool push = false, int heartbeatMs = 0) + public Task CreateConsumerAsync(string stream, string durableName, string filterSubject, bool push = false, int heartbeatMs = 0, AckPolicy ackPolicy = AckPolicy.None, int ackWaitMs = 30_000) { - var payload = $@"{{""durable_name"":""{durableName}"",""filter_subject"":""{filterSubject}"",""push"":{push.ToString().ToLowerInvariant()},""heartbeat_ms"":{heartbeatMs}}}"; + var payload = $@"{{""durable_name"":""{durableName}"",""filter_subject"":""{filterSubject}"",""push"":{push.ToString().ToLowerInvariant()},""heartbeat_ms"":{heartbeatMs},""ack_policy"":""{ackPolicy.ToString().ToLowerInvariant()}"",""ack_wait_ms"":{ackWaitMs}}}"; return RequestLocalAsync($"$JS.API.CONSUMER.CREATE.{stream}.{durableName}", payload); } @@ -100,6 +108,12 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable return _consumerManager.FetchAsync(stream, durableName, batch, _streamManager, default).AsTask(); } + public async Task FetchAfterDelayAsync(string stream, string durableName, int delayMs, int batch) + { + await Task.Delay(delayMs); + return await FetchAsync(stream, durableName, batch); + } + public Task ReadPushFrameAsync(string stream = "ORDERS", string durableName = "PUSH") { var frame = _consumerManager.ReadPushFrame(stream, durableName);