From 34067f2b9b4353a99df07f11c44e86273fc44bd0 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 23:48:06 -0500 Subject: [PATCH] feat: add lame duck mode with staggered client shutdown --- src/NATS.Server/NatsServer.cs | 78 ++++++++++++++++- tests/NATS.Server.Tests/ServerTests.cs | 114 +++++++++++++++++++++++++ 2 files changed, 189 insertions(+), 3 deletions(-) diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 199f3cd..88ae76c 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -40,10 +40,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private int _shutdown; private int _activeClientCount; - // Used by future lame duck mode implementation -#pragma warning disable CS0649 // Field is never assigned to private int _lameDuck; -#pragma warning restore CS0649 // Used by future ports file implementation #pragma warning disable CS0169 // Field is never used @@ -115,6 +112,81 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _shutdownComplete.TrySetResult(); } + public async Task LameDuckShutdownAsync() + { + if (IsShuttingDown || Interlocked.CompareExchange(ref _lameDuck, 1, 0) != 0) + return; + + _logger.LogInformation("Entering lame duck mode, stop accepting new clients"); + + // Close listener to stop accepting new connections + _listener?.Close(); + + // Wait for accept loop to exit + await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + + var gracePeriod = _options.LameDuckGracePeriod; + if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod; + + // If no clients, go straight to shutdown + if (_clients.IsEmpty) + { + await ShutdownAsync(); + return; + } + + // Wait grace period for clients to drain naturally + _logger.LogInformation("Waiting {GracePeriod}ms grace period", gracePeriod.TotalMilliseconds); + try + { + await Task.Delay(gracePeriod, _quitCts.Token); + } + catch (OperationCanceledException) { return; } + + if (_clients.IsEmpty) + { + await ShutdownAsync(); + return; + } + + // Stagger-close remaining clients + var dur = _options.LameDuckDuration - gracePeriod; + if (dur <= TimeSpan.Zero) dur = TimeSpan.FromSeconds(1); + + var clients = _clients.Values.ToList(); + var numClients = clients.Count; + + if (numClients > 0) + { + _logger.LogInformation("Closing {Count} existing clients over {Duration}ms", + numClients, dur.TotalMilliseconds); + + var sleepInterval = dur.Ticks / numClients; + if (sleepInterval < TimeSpan.TicksPerMillisecond) + sleepInterval = TimeSpan.TicksPerMillisecond; + if (sleepInterval > TimeSpan.TicksPerSecond) + sleepInterval = TimeSpan.TicksPerSecond; + + for (int i = 0; i < clients.Count; i++) + { + clients[i].MarkClosed(ClosedState.ServerShutdown); + await clients[i].FlushAndCloseAsync(minimalFlush: true); + + if (i < clients.Count - 1) + { + var jitter = Random.Shared.NextInt64(sleepInterval / 2, sleepInterval); + try + { + await Task.Delay(TimeSpan.FromTicks(jitter), _quitCts.Token); + } + catch (OperationCanceledException) { break; } + } + } + } + + await ShutdownAsync(); + } + public NatsServer(NatsOptions options, ILoggerFactory loggerFactory) { _options = options; diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs index 8f39a8c..1897fb1 100644 --- a/tests/NATS.Server.Tests/ServerTests.cs +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -728,3 +728,117 @@ public class GracefulShutdownTests server.Dispose(); } } + +public class LameDuckTests +{ + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + [Fact] + public async Task LameDuckShutdown_stops_accepting_new_connections() + { + var port = GetFreePort(); + var server = new NatsServer( + new NatsOptions + { + Port = port, + LameDuckDuration = TimeSpan.FromSeconds(3), + LameDuckGracePeriod = TimeSpan.FromMilliseconds(500), + }, + NullLoggerFactory.Instance); + + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + try + { + // Connect 1 client + using var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client1.ConnectAsync(IPAddress.Loopback, port); + var buf = new byte[4096]; + await client1.ReceiveAsync(buf, SocketFlags.None); // INFO + await client1.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); + using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + await client1.ReceiveAsync(buf, SocketFlags.None, readCts.Token); // PONG + + // Start lame duck (don't await yet) + var lameDuckTask = server.LameDuckShutdownAsync(); + + // Wait briefly for listener to close + await Task.Delay(300); + + // Verify lame duck mode is active + server.IsLameDuckMode.ShouldBeTrue(); + + // Try connecting a new client -- should fail (connection refused) + using var client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var connectAction = async () => + { + await client2.ConnectAsync(IPAddress.Loopback, port); + }; + await connectAction.ShouldThrowAsync(); + + // Await the lame duck task with timeout + var completed = await Task.WhenAny(lameDuckTask, Task.Delay(TimeSpan.FromSeconds(15))); + completed.ShouldBe(lameDuckTask); + } + finally + { + server.Dispose(); + } + } + + [Fact] + public async Task LameDuckShutdown_eventually_closes_all_clients() + { + var port = GetFreePort(); + var server = new NatsServer( + new NatsOptions + { + Port = port, + LameDuckDuration = TimeSpan.FromSeconds(2), + LameDuckGracePeriod = TimeSpan.FromMilliseconds(200), + }, + NullLoggerFactory.Instance); + + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + try + { + // Connect 3 clients via raw sockets + var clients = new List(); + var buf = new byte[4096]; + for (int i = 0; i < 3; i++) + { + var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(IPAddress.Loopback, port); + await sock.ReceiveAsync(buf, SocketFlags.None); // INFO + await sock.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); + using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + await sock.ReceiveAsync(buf, SocketFlags.None, readCts.Token); // PONG + clients.Add(sock); + } + + server.ClientCount.ShouldBe(3); + + // Await LameDuckShutdownAsync + var lameDuckTask = server.LameDuckShutdownAsync(); + var completed = await Task.WhenAny(lameDuckTask, Task.Delay(TimeSpan.FromSeconds(15))); + completed.ShouldBe(lameDuckTask); + + server.ClientCount.ShouldBe(0); + + foreach (var sock in clients) + sock.Dispose(); + } + finally + { + server.Dispose(); + } + } +}