From b68f898fa0eafa9acb248c983d61c48e579471cf Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 23:43:25 -0500 Subject: [PATCH] feat: add graceful shutdown, accept loop backoff, and task tracking --- src/NATS.Server/NatsServer.cs | 133 ++++++++++++++++++++++--- tests/NATS.Server.Tests/ServerTests.cs | 112 +++++++++++++++++++++ 2 files changed, 232 insertions(+), 13 deletions(-) diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 3b509aa..2a33091 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -34,6 +34,25 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private ulong _nextClientId; private long _startTimeTicks; + private readonly CancellationTokenSource _quitCts = new(); + private readonly TaskCompletionSource _shutdownComplete = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _acceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _shutdown; + private int _activeClientCount; + + // Used by future lame duck mode implementation +#pragma warning disable CS0649 // Field is never assigned to + private int _lameDuck; +#pragma warning restore CS0649 + + // Used by future ports file implementation +#pragma warning disable CS0169 // Field is never used + private string? _portsFilePath; +#pragma warning restore CS0169 + + private static readonly TimeSpan AcceptMinSleep = TimeSpan.FromMilliseconds(10); + private static readonly TimeSpan AcceptMaxSleep = TimeSpan.FromSeconds(1); + public SubList SubList => _globalAccount.SubList; public ServerStats Stats => _stats; public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc); @@ -43,10 +62,56 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable public int Port => _options.Port; public Account SystemAccount => _systemAccount; public string ServerNKey { get; } + public bool IsShuttingDown => Volatile.Read(ref _shutdown) != 0; + public bool IsLameDuckMode => Volatile.Read(ref _lameDuck) != 0; public IEnumerable GetClients() => _clients.Values; public Task WaitForReadyAsync() => _listeningStarted.Task; + public void WaitForShutdown() => _shutdownComplete.Task.GetAwaiter().GetResult(); + + public async Task ShutdownAsync() + { + if (Interlocked.CompareExchange(ref _shutdown, 1, 0) != 0) + return; // Already shutting down + + _logger.LogInformation("Initiating Shutdown..."); + + // Signal all internal loops to stop + await _quitCts.CancelAsync(); + + // Close listener to stop accept loop + _listener?.Close(); + + // Wait for accept loop to exit + await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + + // Close all client connections + foreach (var client in _clients.Values) + { + client.MarkClosed(ClosedState.ServerShutdown); + } + + // Wait for active client tasks to drain (with timeout) + if (Volatile.Read(ref _activeClientCount) > 0) + { + using var drainCts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + try + { + while (Volatile.Read(ref _activeClientCount) > 0 && !drainCts.IsCancellationRequested) + await Task.Delay(50, drainCts.Token); + } + catch (OperationCanceledException) { } + } + + // Stop monitor server + if (_monitorServer != null) + await _monitorServer.DisposeAsync(); + + _logger.LogInformation("Server Exiting.."); + _shutdownComplete.TrySetResult(); + } + public NatsServer(NatsOptions options, ILoggerFactory loggerFactory) { _options = options; @@ -89,6 +154,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable public async Task StartAsync(CancellationToken ct) { + using var linked = CancellationTokenSource.CreateLinkedTokenSource(ct, _quitCts.Token); + _listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); _listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); _listener.Bind(new IPEndPoint( @@ -107,21 +174,54 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _listeningStarted.TrySetResult(); - _logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port); + _logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port); + + // Warn about stub features + if (_options.ConfigFile != null) + _logger.LogWarning("Config file parsing not yet supported (file: {ConfigFile})", _options.ConfigFile); + if (_options.ProfPort > 0) + _logger.LogWarning("Profiling endpoint not yet supported (port: {ProfPort})", _options.ProfPort); if (_options.MonitorPort > 0) { _monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory); - await _monitorServer.StartAsync(ct); + await _monitorServer.StartAsync(linked.Token); } + var tmpDelay = AcceptMinSleep; + try { - while (!ct.IsCancellationRequested) + while (!linked.Token.IsCancellationRequested) { - var socket = await _listener.AcceptAsync(ct); + Socket socket; + try + { + socket = await _listener.AcceptAsync(linked.Token); + tmpDelay = AcceptMinSleep; // Reset on success + } + catch (OperationCanceledException) + { + break; + } + catch (ObjectDisposedException) + { + break; + } + catch (SocketException ex) + { + if (IsShuttingDown || IsLameDuckMode) + break; - // Check MaxConnections before creating the client + _logger.LogError(ex, "Temporary accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds); + try { await Task.Delay(tmpDelay, linked.Token); } + catch (OperationCanceledException) { break; } + + tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks)); + continue; + } + + // Check MaxConnections if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections) { _logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded", @@ -131,13 +231,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable var stream = new NetworkStream(socket, ownsSocket: false); var errBytes = Encoding.ASCII.GetBytes( $"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n"); - await stream.WriteAsync(errBytes, ct); - await stream.FlushAsync(ct); + await stream.WriteAsync(errBytes, linked.Token); + await stream.FlushAsync(linked.Token); stream.Dispose(); } - catch (Exception ex) + catch (Exception ex2) { - _logger.LogDebug(ex, "Failed to send -ERR to rejected client"); + _logger.LogDebug(ex2, "Failed to send -ERR to rejected client"); } finally { @@ -148,16 +248,21 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable var clientId = Interlocked.Increment(ref _nextClientId); Interlocked.Increment(ref _stats.TotalConnections); + Interlocked.Increment(ref _activeClientCount); _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); - _ = AcceptClientAsync(socket, clientId, ct); + _ = AcceptClientAsync(socket, clientId, linked.Token); } } catch (OperationCanceledException) { _logger.LogDebug("Accept loop cancelled, server shutting down"); } + finally + { + _acceptLoopExited.TrySetResult(); + } } private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct) @@ -240,8 +345,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } finally { - _logger.LogDebug("Client {ClientId} disconnected", client.Id); + _logger.LogDebug("Client {ClientId} disconnected (reason: {CloseReason})", client.Id, client.CloseReason); RemoveClient(client); + Interlocked.Decrement(ref _activeClientCount); } } @@ -313,8 +419,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable public void Dispose() { - if (_monitorServer != null) - _monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult(); + if (!IsShuttingDown) + ShutdownAsync().GetAwaiter().GetResult(); + _quitCts.Dispose(); _tlsRateLimiter?.Dispose(); _listener?.Dispose(); foreach (var client in _clients.Values) diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs index b4f3618..caa4d36 100644 --- a/tests/NATS.Server.Tests/ServerTests.cs +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -547,3 +547,115 @@ public class ServerIdentityTests server.Dispose(); } } + +public class GracefulShutdownTests +{ + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + [Fact] + public async Task ShutdownAsync_disconnects_all_clients() + { + var port = GetFreePort(); + var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance); + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + // Connect 2 raw TCP clients + using var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client1.ConnectAsync(IPAddress.Loopback, port); + var buf = new byte[4096]; + await client1.ReceiveAsync(buf, SocketFlags.None); // INFO + + using var client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client2.ConnectAsync(IPAddress.Loopback, port); + await client2.ReceiveAsync(buf, SocketFlags.None); // INFO + + // Send CONNECT so both are registered + await client1.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); + await client2.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); + + // Wait for PONG from both (confirming they are registered) + using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + await client1.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + await client2.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + + server.ClientCount.ShouldBe(2); + + await server.ShutdownAsync(); + + server.ClientCount.ShouldBe(0); + server.Dispose(); + } + + [Fact] + public async Task WaitForShutdown_blocks_until_shutdown() + { + var port = GetFreePort(); + var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance); + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + // Start WaitForShutdown in background + var waitTask = Task.Run(() => server.WaitForShutdown()); + + // Give it a moment -- it should NOT complete yet + await Task.Delay(200); + waitTask.IsCompleted.ShouldBeFalse(); + + // Trigger shutdown + await server.ShutdownAsync(); + + // WaitForShutdown should complete within 5 seconds + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5))); + completed.ShouldBe(waitTask); + + server.Dispose(); + } + + [Fact] + public async Task ShutdownAsync_is_idempotent() + { + var port = GetFreePort(); + var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance); + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + // Call ShutdownAsync 3 times -- should not throw + await server.ShutdownAsync(); + await server.ShutdownAsync(); + await server.ShutdownAsync(); + + server.IsShuttingDown.ShouldBeTrue(); + server.Dispose(); + } + + [Fact] + public async Task Accept_loop_waits_for_active_clients() + { + var port = GetFreePort(); + var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance); + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + // Connect a client + using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(IPAddress.Loopback, port); + var buf = new byte[4096]; + await client.ReceiveAsync(buf, SocketFlags.None); // INFO + await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); + using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + await client.ReceiveAsync(buf, SocketFlags.None, readCts.Token); // PONG + + // ShutdownAsync should complete within 10 seconds (doesn't hang) + var shutdownTask = server.ShutdownAsync(); + var completed = await Task.WhenAny(shutdownTask, Task.Delay(TimeSpan.FromSeconds(10))); + completed.ShouldBe(shutdownTask); + + server.Dispose(); + } +}