diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs
index 0e58d82..43f2e28 100644
--- a/src/NATS.Server/NatsClient.cs
+++ b/src/NATS.Server/NatsClient.cs
@@ -563,6 +563,31 @@ public sealed class NatsClient : IDisposable
_logger.LogDebug("Client {ClientId} connection closed: {CloseReason}", Id, reason);
}
+ ///
+ /// Flushes pending data (unless skip-flush is set) and closes the connection.
+ ///
+ public async Task FlushAndCloseAsync(bool minimalFlush = false)
+ {
+ if (!ShouldSkipFlush)
+ {
+ try
+ {
+ using var flushCts = new CancellationTokenSource(minimalFlush
+ ? TimeSpan.FromMilliseconds(100)
+ : TimeSpan.FromSeconds(1));
+ await _stream.FlushAsync(flushCts.Token);
+ }
+ catch (Exception)
+ {
+ // Best effort flush — don't let it prevent close
+ }
+ }
+
+ try { _socket.Shutdown(SocketShutdown.Both); }
+ catch (SocketException) { }
+ catch (ObjectDisposedException) { }
+ }
+
public void RemoveAllSubscriptions(SubList subList)
{
foreach (var sub in _subs.Values)
diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs
index 2a33091..199f3cd 100644
--- a/src/NATS.Server/NatsServer.cs
+++ b/src/NATS.Server/NatsServer.cs
@@ -86,11 +86,14 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
// Wait for accept loop to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
- // Close all client connections
+ // Close all client connections — flush first, then mark closed
+ var flushTasks = new List();
foreach (var client in _clients.Values)
{
client.MarkClosed(ClosedState.ServerShutdown);
+ flushTasks.Add(client.FlushAndCloseAsync(minimalFlush: true));
}
+ await Task.WhenAll(flushTasks).WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
// Wait for active client tasks to drain (with timeout)
if (Volatile.Read(ref _activeClientCount) > 0)
diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs
index caa4d36..8f39a8c 100644
--- a/tests/NATS.Server.Tests/ServerTests.cs
+++ b/tests/NATS.Server.Tests/ServerTests.cs
@@ -548,6 +548,75 @@ public class ServerIdentityTests
}
}
+public class FlushBeforeCloseTests
+{
+ private static int GetFreePort()
+ {
+ using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+ return ((IPEndPoint)sock.LocalEndPoint!).Port;
+ }
+
+ private static async Task ReadUntilAsync(Socket sock, string expected, int timeoutMs = 5000)
+ {
+ using var cts = new CancellationTokenSource(timeoutMs);
+ var sb = new StringBuilder();
+ var buf = new byte[4096];
+ while (!sb.ToString().Contains(expected))
+ {
+ var n = await sock.ReceiveAsync(buf, SocketFlags.None, cts.Token);
+ if (n == 0) break;
+ sb.Append(Encoding.ASCII.GetString(buf, 0, n));
+ }
+ return sb.ToString();
+ }
+
+ [Fact]
+ public async Task Shutdown_flushes_pending_data_to_clients()
+ {
+ var port = GetFreePort();
+ var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
+ _ = server.StartAsync(CancellationToken.None);
+ await server.WaitForReadyAsync();
+
+ try
+ {
+ // Connect a subscriber via raw socket
+ using var sub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ await sub.ConnectAsync(IPAddress.Loopback, port);
+
+ // Read INFO
+ var buf = new byte[4096];
+ await sub.ReceiveAsync(buf, SocketFlags.None);
+
+ // Subscribe to "foo"
+ await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo 1\r\nPING\r\n"));
+ var pong = await ReadUntilAsync(sub, "PONG");
+ pong.ShouldContain("PONG");
+
+ // Connect a publisher
+ using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ await pub.ConnectAsync(IPAddress.Loopback, port);
+ await pub.ReceiveAsync(buf, SocketFlags.None); // INFO
+
+ // Publish "Hello" to "foo"
+ await pub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPUB foo 5\r\nHello\r\n"));
+
+ // Wait briefly for delivery
+ await Task.Delay(200);
+
+ // Read from subscriber to verify MSG was received
+ var msg = await ReadUntilAsync(sub, "Hello\r\n");
+ msg.ShouldContain("MSG foo 1 5\r\nHello\r\n");
+ }
+ finally
+ {
+ await server.ShutdownAsync();
+ server.Dispose();
+ }
+ }
+}
+
public class GracefulShutdownTests
{
private static int GetFreePort()