diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 31b609a..df04363 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -44,6 +44,10 @@ public sealed class NatsClient : IDisposable public long InBytes; public long OutBytes; + // PING keepalive state + private int _pingsOut; + private long _lastIn; + public IReadOnlyDictionary Subscriptions => _subs; public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, ILogger logger) @@ -60,17 +64,19 @@ public sealed class NatsClient : IDisposable public async Task RunAsync(CancellationToken ct) { _clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + Interlocked.Exchange(ref _lastIn, Environment.TickCount64); var pipe = new Pipe(); try { // Send INFO await SendInfoAsync(_clientCts.Token); - // Start read pump and command processing in parallel + // Start read pump, command processing, and ping timer in parallel var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token); var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token); + var pingTask = RunPingTimerAsync(_clientCts.Token); - await Task.WhenAny(fillTask, processTask); + await Task.WhenAny(fillTask, processTask, pingTask); } catch (OperationCanceledException) { @@ -82,15 +88,10 @@ public sealed class NatsClient : IDisposable } finally { + try { _socket.Shutdown(SocketShutdown.Both); } + catch (SocketException) { } + catch (ObjectDisposedException) { } Router?.RemoveClient(this); - try - { - _socket.Shutdown(SocketShutdown.Both); - } - catch (Exception ex) - { - _logger.LogDebug(ex, "Client {ClientId} socket shutdown error", Id); - } } } @@ -128,6 +129,7 @@ public sealed class NatsClient : IDisposable while (_parser.TryParse(ref buffer, out var cmd)) { + Interlocked.Exchange(ref _lastIn, Environment.TickCount64); await DispatchCommandAsync(cmd, ct); } @@ -156,7 +158,7 @@ public sealed class NatsClient : IDisposable break; case CommandType.Pong: - // Update RTT tracking (placeholder) + Interlocked.Exchange(ref _pingsOut, 0); break; case CommandType.Sub: @@ -339,6 +341,48 @@ public sealed class NatsClient : IDisposable _socket.Close(); } + private async Task RunPingTimerAsync(CancellationToken ct) + { + using var timer = new PeriodicTimer(_options.PingInterval); + try + { + while (await timer.WaitForNextTickAsync(ct)) + { + var elapsed = Environment.TickCount64 - Interlocked.Read(ref _lastIn); + if (elapsed < (long)_options.PingInterval.TotalMilliseconds) + { + // Client was recently active, skip ping + Interlocked.Exchange(ref _pingsOut, 0); + continue; + } + + var currentPingsOut = Interlocked.Increment(ref _pingsOut); + if (currentPingsOut > _options.MaxPingsOut) + { + _logger.LogDebug("Client {ClientId} stale connection — closing", Id); + await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection); + return; + } + + _logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})", + Id, currentPingsOut, _options.MaxPingsOut); + try + { + await WriteAsync(NatsProtocol.PingBytes, ct); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Client {ClientId} failed to send PING", Id); + return; + } + } + } + catch (OperationCanceledException) + { + // Normal shutdown + } + } + public void RemoveAllSubscriptions(SubList subList) { foreach (var sub in _subs.Values) diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs index 476e0c6..3b49d3f 100644 --- a/tests/NATS.Server.Tests/ServerTests.cs +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -278,3 +278,148 @@ public class MaxConnectionsTests : IAsyncLifetime client3.Dispose(); } } + +public class PingKeepaliveTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + + public PingKeepaliveTests() + { + _port = GetFreePort(); + // Short intervals for testing: 500ms ping interval, 2 max pings out + _server = new NatsServer( + new NatsOptions + { + Port = _port, + PingInterval = TimeSpan.FromMilliseconds(500), + MaxPingsOut = 2, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + private static async Task ReadUntilAsync(Socket sock, string expected, int timeoutMs = 5000) + { + using var cts = new CancellationTokenSource(timeoutMs); + var sb = new StringBuilder(); + var buf = new byte[4096]; + while (!sb.ToString().Contains(expected)) + { + var n = await sock.ReceiveAsync(buf, SocketFlags.None, cts.Token); + if (n == 0) break; + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + } + return sb.ToString(); + } + + [Fact] + public async Task Server_sends_PING_after_inactivity() + { + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(IPAddress.Loopback, _port); + + // Read INFO + var buf = new byte[4096]; + await client.ReceiveAsync(buf, SocketFlags.None); + + // Send CONNECT to start keepalive + await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n")); + + // Wait for server to send PING (should come within ~500ms) + var response = await ReadUntilAsync(client, "PING", timeoutMs: 3000); + response.ShouldContain("PING"); + + client.Dispose(); + } + + [Fact] + public async Task Server_pong_resets_ping_counter() + { + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(IPAddress.Loopback, _port); + + var buf = new byte[4096]; + await client.ReceiveAsync(buf, SocketFlags.None); // INFO + + await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n")); + + // Wait for first PING + var response = await ReadUntilAsync(client, "PING", timeoutMs: 3000); + response.ShouldContain("PING"); + + // Respond with PONG — this resets the counter + await client.SendAsync(Encoding.ASCII.GetBytes("PONG\r\n")); + + // Wait for next PING (counter reset, so we should get another one) + response = await ReadUntilAsync(client, "PING", timeoutMs: 3000); + response.ShouldContain("PING"); + + // Respond again to keep alive + await client.SendAsync(Encoding.ASCII.GetBytes("PONG\r\n")); + + // Client should still be alive — send a PING and expect PONG back + await client.SendAsync(Encoding.ASCII.GetBytes("PING\r\n")); + response = await ReadUntilAsync(client, "PONG", timeoutMs: 3000); + response.ShouldContain("PONG"); + + client.Dispose(); + } + + [Fact] + public async Task Server_disconnects_stale_client() + { + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(IPAddress.Loopback, _port); + + var buf = new byte[4096]; + await client.ReceiveAsync(buf, SocketFlags.None); // INFO + + await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n")); + + // Don't respond to PINGs — wait for stale disconnect + // With 500ms interval and MaxPingsOut=2: + // t=500ms: PING #1, pingsOut=1 + // t=1000ms: PING #2, pingsOut=2 + // t=1500ms: pingsOut+1 > MaxPingsOut → -ERR 'Stale Connection' + close + var sb = new StringBuilder(); + try + { + using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + while (true) + { + var n = await client.ReceiveAsync(buf, SocketFlags.None, timeout.Token); + if (n == 0) break; + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + } + } + catch (OperationCanceledException) + { + // Timeout is acceptable — check what we got + } + + var allData = sb.ToString(); + allData.ShouldContain("-ERR 'Stale Connection'"); + + client.Dispose(); + } +}