From 45de110a84888b0d4a5dad92a8b3d58e0b73c3f1 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 23:45:26 -0500 Subject: [PATCH] feat: add flush-before-close for graceful client shutdown --- src/NATS.Server/NatsClient.cs | 25 ++++++++++ src/NATS.Server/NatsServer.cs | 5 +- tests/NATS.Server.Tests/ServerTests.cs | 69 ++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 0e58d82..43f2e28 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -563,6 +563,31 @@ public sealed class NatsClient : IDisposable _logger.LogDebug("Client {ClientId} connection closed: {CloseReason}", Id, reason); } + /// + /// Flushes pending data (unless skip-flush is set) and closes the connection. + /// + public async Task FlushAndCloseAsync(bool minimalFlush = false) + { + if (!ShouldSkipFlush) + { + try + { + using var flushCts = new CancellationTokenSource(minimalFlush + ? TimeSpan.FromMilliseconds(100) + : TimeSpan.FromSeconds(1)); + await _stream.FlushAsync(flushCts.Token); + } + catch (Exception) + { + // Best effort flush — don't let it prevent close + } + } + + try { _socket.Shutdown(SocketShutdown.Both); } + catch (SocketException) { } + catch (ObjectDisposedException) { } + } + public void RemoveAllSubscriptions(SubList subList) { foreach (var sub in _subs.Values) diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 2a33091..199f3cd 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -86,11 +86,14 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable // Wait for accept loop to exit await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); - // Close all client connections + // Close all client connections — flush first, then mark closed + var flushTasks = new List(); foreach (var client in _clients.Values) { client.MarkClosed(ClosedState.ServerShutdown); + flushTasks.Add(client.FlushAndCloseAsync(minimalFlush: true)); } + await Task.WhenAll(flushTasks).WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); // Wait for active client tasks to drain (with timeout) if (Volatile.Read(ref _activeClientCount) > 0) diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs index caa4d36..8f39a8c 100644 --- a/tests/NATS.Server.Tests/ServerTests.cs +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -548,6 +548,75 @@ public class ServerIdentityTests } } +public class FlushBeforeCloseTests +{ + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + private static async Task ReadUntilAsync(Socket sock, string expected, int timeoutMs = 5000) + { + using var cts = new CancellationTokenSource(timeoutMs); + var sb = new StringBuilder(); + var buf = new byte[4096]; + while (!sb.ToString().Contains(expected)) + { + var n = await sock.ReceiveAsync(buf, SocketFlags.None, cts.Token); + if (n == 0) break; + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + } + return sb.ToString(); + } + + [Fact] + public async Task Shutdown_flushes_pending_data_to_clients() + { + var port = GetFreePort(); + var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance); + _ = server.StartAsync(CancellationToken.None); + await server.WaitForReadyAsync(); + + try + { + // Connect a subscriber via raw socket + using var sub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sub.ConnectAsync(IPAddress.Loopback, port); + + // Read INFO + var buf = new byte[4096]; + await sub.ReceiveAsync(buf, SocketFlags.None); + + // Subscribe to "foo" + await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo 1\r\nPING\r\n")); + var pong = await ReadUntilAsync(sub, "PONG"); + pong.ShouldContain("PONG"); + + // Connect a publisher + using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await pub.ConnectAsync(IPAddress.Loopback, port); + await pub.ReceiveAsync(buf, SocketFlags.None); // INFO + + // Publish "Hello" to "foo" + await pub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPUB foo 5\r\nHello\r\n")); + + // Wait briefly for delivery + await Task.Delay(200); + + // Read from subscriber to verify MSG was received + var msg = await ReadUntilAsync(sub, "Hello\r\n"); + msg.ShouldContain("MSG foo 1 5\r\nHello\r\n"); + } + finally + { + await server.ShutdownAsync(); + server.Dispose(); + } + } +} + public class GracefulShutdownTests { private static int GetFreePort()