feat: add server-side PING keepalive with stale connection detection
This commit is contained in:
@@ -44,6 +44,10 @@ public sealed class NatsClient : IDisposable
|
||||
public long InBytes;
|
||||
public long OutBytes;
|
||||
|
||||
// PING keepalive state
|
||||
private int _pingsOut;
|
||||
private long _lastIn;
|
||||
|
||||
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
|
||||
|
||||
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, ILogger logger)
|
||||
@@ -60,17 +64,19 @@ public sealed class NatsClient : IDisposable
|
||||
public async Task RunAsync(CancellationToken ct)
|
||||
{
|
||||
_clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
Interlocked.Exchange(ref _lastIn, Environment.TickCount64);
|
||||
var pipe = new Pipe();
|
||||
try
|
||||
{
|
||||
// Send INFO
|
||||
await SendInfoAsync(_clientCts.Token);
|
||||
|
||||
// Start read pump and command processing in parallel
|
||||
// Start read pump, command processing, and ping timer in parallel
|
||||
var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token);
|
||||
var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token);
|
||||
var pingTask = RunPingTimerAsync(_clientCts.Token);
|
||||
|
||||
await Task.WhenAny(fillTask, processTask);
|
||||
await Task.WhenAny(fillTask, processTask, pingTask);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
@@ -82,15 +88,10 @@ public sealed class NatsClient : IDisposable
|
||||
}
|
||||
finally
|
||||
{
|
||||
try { _socket.Shutdown(SocketShutdown.Both); }
|
||||
catch (SocketException) { }
|
||||
catch (ObjectDisposedException) { }
|
||||
Router?.RemoveClient(this);
|
||||
try
|
||||
{
|
||||
_socket.Shutdown(SocketShutdown.Both);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Client {ClientId} socket shutdown error", Id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,6 +129,7 @@ public sealed class NatsClient : IDisposable
|
||||
|
||||
while (_parser.TryParse(ref buffer, out var cmd))
|
||||
{
|
||||
Interlocked.Exchange(ref _lastIn, Environment.TickCount64);
|
||||
await DispatchCommandAsync(cmd, ct);
|
||||
}
|
||||
|
||||
@@ -156,7 +158,7 @@ public sealed class NatsClient : IDisposable
|
||||
break;
|
||||
|
||||
case CommandType.Pong:
|
||||
// Update RTT tracking (placeholder)
|
||||
Interlocked.Exchange(ref _pingsOut, 0);
|
||||
break;
|
||||
|
||||
case CommandType.Sub:
|
||||
@@ -339,6 +341,48 @@ public sealed class NatsClient : IDisposable
|
||||
_socket.Close();
|
||||
}
|
||||
|
||||
/// <summary>
/// Server-side keepalive loop. On every <c>_options.PingInterval</c> tick it either
/// skips the ping (a command was parsed within the last interval), sends a PING, or,
/// once more than <c>_options.MaxPingsOut</c> pings are outstanding, reports a stale
/// connection to the client and closes it.
/// </summary>
/// <param name="ct">Token that stops the timer; cancellation ends the loop silently.</param>
/// <returns>
/// A task that completes when the token is cancelled, the connection is declared stale,
/// or a PING could not be written.
/// </returns>
private async Task RunPingTimerAsync(CancellationToken ct)
{
    using var timer = new PeriodicTimer(_options.PingInterval);

    try
    {
        while (await timer.WaitForNextTickAsync(ct))
        {
            // _lastIn is refreshed whenever a command is parsed, so a small elapsed
            // value means the peer is demonstrably alive and needs no probe.
            var elapsed = Environment.TickCount64 - Interlocked.Read(ref _lastIn);
            if (elapsed < (long)_options.PingInterval.TotalMilliseconds)
            {
                // Client was recently active, skip ping
                Interlocked.Exchange(ref _pingsOut, 0);
                continue;
            }

            // Count this ping as outstanding before deciding; a PONG resets the counter.
            var currentPingsOut = Interlocked.Increment(ref _pingsOut);
            if (currentPingsOut > _options.MaxPingsOut)
            {
                _logger.LogDebug("Client {ClientId} stale connection — closing", Id);
                await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection);
                return;
            }

            _logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})",
                Id, currentPingsOut, _options.MaxPingsOut);

            try
            {
                await WriteAsync(NatsProtocol.PingBytes, ct);
            }
            catch (Exception ex) when (ex is not OperationCanceledException)
            {
                // Genuine write failure: stop pinging. Cancellation is deliberately
                // excluded so shutdown is not misreported as a failed PING; it falls
                // through to the outer handler instead.
                _logger.LogDebug(ex, "Client {ClientId} failed to send PING", Id);
                return;
            }
        }
    }
    catch (OperationCanceledException)
    {
        // Normal shutdown
    }
}
|
||||
|
||||
public void RemoveAllSubscriptions(SubList subList)
|
||||
{
|
||||
foreach (var sub in _subs.Values)
|
||||
|
||||
@@ -278,3 +278,148 @@ public class MaxConnectionsTests : IAsyncLifetime
|
||||
client3.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
/// Integration tests for the server-initiated PING keepalive. Each test talks to a
/// <see cref="NatsServer"/> over a raw loopback TCP socket; the server is configured
/// with a 500ms ping interval and <c>MaxPingsOut = 2</c> so the tests finish quickly.
/// </summary>
public class PingKeepaliveTests : IAsyncLifetime
{
    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();

    public PingKeepaliveTests()
    {
        _port = GetFreePort();
        // Short intervals for testing: 500ms ping interval, 2 max pings out
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                PingInterval = TimeSpan.FromMilliseconds(500),
                MaxPingsOut = 2,
            },
            NullLoggerFactory.Instance);
    }

    /// <summary>Starts the server in the background and waits until it reports ready.</summary>
    public async Task InitializeAsync()
    {
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    /// <summary>Cancels the server token, then releases the server and the token source.</summary>
    public async Task DisposeAsync()
    {
        await _cts.CancelAsync();
        _server.Dispose();
        _cts.Dispose();
    }

    /// <summary>
    /// Asks the OS for an ephemeral loopback port by binding to port 0. The probe socket
    /// is released before returning, so the port is only probably free afterwards.
    /// </summary>
    private static int GetFreePort()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        return ((IPEndPoint)sock.LocalEndPoint!).Port;
    }

    /// <summary>
    /// Reads ASCII data from <paramref name="sock"/> until the accumulated text contains
    /// <paramref name="expected"/> or the peer closes the connection.
    /// </summary>
    /// <param name="sock">Connected socket to read from.</param>
    /// <param name="expected">Substring that ends the read once it has been received.</param>
    /// <param name="timeoutMs">Overall budget; exceeding it surfaces as <see cref="OperationCanceledException"/>.</param>
    /// <returns>Everything received, which may lack <paramref name="expected"/> if the peer closed first.</returns>
    private static async Task<string> ReadUntilAsync(Socket sock, string expected, int timeoutMs = 5000)
    {
        using var cts = new CancellationTokenSource(timeoutMs);
        var sb = new StringBuilder();
        var buf = new byte[4096];
        while (!sb.ToString().Contains(expected, StringComparison.Ordinal))
        {
            var n = await sock.ReceiveAsync(buf, SocketFlags.None, cts.Token);
            if (n == 0) break; // peer closed the connection
            sb.Append(Encoding.ASCII.GetString(buf, 0, n));
        }
        return sb.ToString();
    }

    [Fact]
    public async Task Server_sends_PING_after_inactivity()
    {
        // `using` guarantees the socket is released even when an assertion or timeout throws.
        using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await client.ConnectAsync(IPAddress.Loopback, _port);

        // Read INFO
        var buf = new byte[4096];
        await client.ReceiveAsync(buf, SocketFlags.None);

        // Send CONNECT to start keepalive
        await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n"));

        // Wait for the server to send PING.
        // NOTE(review): the server skips a tick when a command was parsed within the last
        // interval, so the tick right after CONNECT is likely skipped and the first PING
        // tends to arrive after about two intervals (~1s) — hence the generous timeout.
        var response = await ReadUntilAsync(client, "PING", timeoutMs: 3000);
        response.ShouldContain("PING");
    }

    [Fact]
    public async Task Server_pong_resets_ping_counter()
    {
        using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await client.ConnectAsync(IPAddress.Loopback, _port);

        var buf = new byte[4096];
        await client.ReceiveAsync(buf, SocketFlags.None); // INFO

        await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n"));

        // Wait for first PING
        var response = await ReadUntilAsync(client, "PING", timeoutMs: 3000);
        response.ShouldContain("PING");

        // Respond with PONG — this resets the counter
        await client.SendAsync(Encoding.ASCII.GetBytes("PONG\r\n"));

        // Wait for next PING (counter reset, so we should get another one)
        response = await ReadUntilAsync(client, "PING", timeoutMs: 3000);
        response.ShouldContain("PING");

        // Respond again to keep alive
        await client.SendAsync(Encoding.ASCII.GetBytes("PONG\r\n"));

        // Client should still be alive — send a PING and expect PONG back
        await client.SendAsync(Encoding.ASCII.GetBytes("PING\r\n"));
        response = await ReadUntilAsync(client, "PONG", timeoutMs: 3000);
        response.ShouldContain("PONG");
    }

    [Fact]
    public async Task Server_disconnects_stale_client()
    {
        using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await client.ConnectAsync(IPAddress.Loopback, _port);

        var buf = new byte[4096];
        await client.ReceiveAsync(buf, SocketFlags.None); // INFO

        await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n"));

        // Don't respond to PINGs — wait for stale disconnect.
        // With a 500ms interval and MaxPingsOut=2 the server sends PING #1 (pingsOut=1),
        // then PING #2 (pingsOut=2), and on the following tick pingsOut+1 > MaxPingsOut
        // → -ERR 'Stale Connection' + close.
        // NOTE(review): the tick right after CONNECT is likely skipped as "recently
        // active", so expect roughly t≈1.0s / 1.5s / 2.0s rather than 0.5s / 1.0s / 1.5s;
        // the 5s budget below covers either timeline.
        var sb = new StringBuilder();
        try
        {
            using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5));
            while (true)
            {
                var n = await client.ReceiveAsync(buf, SocketFlags.None, timeout.Token);
                if (n == 0) break; // server closed the connection
                sb.Append(Encoding.ASCII.GetString(buf, 0, n));
            }
        }
        catch (OperationCanceledException)
        {
            // Timeout is acceptable — check what we got
        }

        var allData = sb.ToString();
        allData.ShouldContain("-ERR 'Stale Connection'");
    }
}
|
||||
|
||||
Reference in New Issue
Block a user