diff --git a/tests/NATS.Server.Tests/WriteLoopTests.cs b/tests/NATS.Server.Tests/WriteLoopTests.cs new file mode 100644 index 0000000..c27b1e4 --- /dev/null +++ b/tests/NATS.Server.Tests/WriteLoopTests.cs @@ -0,0 +1,363 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server; +using NATS.Server.Auth; +using NATS.Server.Protocol; + +namespace NATS.Server.Tests; + +public class WriteLoopTests : IAsyncDisposable +{ + private readonly Socket _serverSocket; + private readonly Socket _clientSocket; + private readonly CancellationTokenSource _cts = new(); + + private readonly ServerInfo _serverInfo = new() + { + ServerId = "test", + ServerName = "test", + Version = "0.1.0", + Host = "127.0.0.1", + Port = 4222, + }; + + /// + /// Creates a connected socket pair via loopback and returns both sockets. + /// + private static (Socket serverSocket, Socket clientSocket) CreateSocketPair() + { + var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + var port = ((IPEndPoint)listener.LocalEndPoint!).Port; + + var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + clientSocket.Connect(IPAddress.Loopback, port); + var serverSocket = listener.Accept(); + listener.Dispose(); + + return (serverSocket, clientSocket); + } + + public WriteLoopTests() + { + (_serverSocket, _clientSocket) = CreateSocketPair(); + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + _serverSocket.Dispose(); + _clientSocket.Dispose(); + } + + private NatsClient CreateClient(NatsOptions? options = null) + { + options ??= new NatsOptions(); + var authService = AuthService.Build(options); + return new NatsClient( + 1, + new NetworkStream(_serverSocket, ownsSocket: false), + _serverSocket, + options, + _serverInfo, + authService, + null, + NullLogger.Instance, + new ServerStats()); + } + + /// + /// Reads all available data from the client socket until timeout, accumulating it into a string. + /// + private async Task ReadFromClientSocketAsync(int bufferSize = 8192, int timeoutMs = 5000) + { + var buf = new byte[bufferSize]; + using var readCts = new CancellationTokenSource(timeoutMs); + var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + return Encoding.ASCII.GetString(buf, 0, n); + } + + /// + /// Reads and discards the INFO line that is sent on client startup. + /// + private async Task ConsumeInfoAsync() + { + var response = await ReadFromClientSocketAsync(); + response.ShouldStartWith("INFO "); + } + + [Fact] + public async Task QueueOutbound_writes_data_to_client() + { + using var natsClient = CreateClient(); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read and discard the initial INFO message + await ConsumeInfoAsync(); + + // Queue some data via QueueOutbound + var testData = "PING\r\n"u8.ToArray(); + var result = natsClient.QueueOutbound(testData); + result.ShouldBeTrue(); + + // Read the data from the client socket + var received = await ReadFromClientSocketAsync(); + received.ShouldBe("PING\r\n"); + + await _cts.CancelAsync(); + } + + [Fact] + public async Task QueueOutbound_writes_multiple_messages_to_client() + { + using var natsClient = CreateClient(); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read and discard the initial INFO message + await ConsumeInfoAsync(); + + // Queue multiple messages + natsClient.QueueOutbound("MSG foo 1 5\r\nhello\r\n"u8.ToArray()); + natsClient.QueueOutbound("MSG bar 2 5\r\nworld\r\n"u8.ToArray()); + + // Read all data from the socket -- may arrive in one or two reads + var sb = new StringBuilder(); + var buf = new byte[8192]; + using var readCts = new CancellationTokenSource(5000); + while (sb.Length < "MSG foo 1 5\r\nhello\r\nMSG bar 2 5\r\nworld\r\n".Length) + { + var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + } + + var received = sb.ToString(); + received.ShouldContain("MSG foo 1 5\r\nhello\r\n"); + received.ShouldContain("MSG bar 2 5\r\nworld\r\n"); + + await _cts.CancelAsync(); + } + + [Fact] + public async Task SlowConsumer_closes_when_pending_exceeds_max() + { + // Use a very small MaxPending so we can easily exceed it + var options = new NatsOptions { MaxPending = 1024 }; + using var natsClient = CreateClient(options); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read and discard the initial INFO message + await ConsumeInfoAsync(); + + // Queue data that exceeds MaxPending. The first call adds to pending, then + // subsequent calls should eventually trigger slow consumer detection. + var largeData = new byte[2048]; + Array.Fill(largeData, (byte)'X'); + + var queued = natsClient.QueueOutbound(largeData); + + // The QueueOutbound should return false since data exceeds MaxPending + queued.ShouldBeFalse(); + + // CloseReason should be set to SlowConsumerPendingBytes + // Give a small delay for the async close to propagate + await Task.Delay(200); + natsClient.CloseReason.ShouldBe(ClientClosedReason.SlowConsumerPendingBytes); + + await _cts.CancelAsync(); + } + + [Fact] + public async Task SlowConsumer_sets_reason_before_closing_connection() + { + // Use very small MaxPending + var options = new NatsOptions { MaxPending = 512 }; + using var natsClient = CreateClient(options); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read and discard the initial INFO message + await ConsumeInfoAsync(); + + // Try to queue data that exceeds MaxPending + var largeData = new byte[1024]; + Array.Fill(largeData, (byte)'Y'); + natsClient.QueueOutbound(largeData); + + // Wait for the close to propagate + await Task.Delay(200); + + natsClient.CloseReason.ShouldBe(ClientClosedReason.SlowConsumerPendingBytes); + + // The connection should eventually be closed -- ReceiveAsync returns 0 + var buf = new byte[4096]; + using var readCts = new CancellationTokenSource(5000); + try + { + // Read any remaining data (the -ERR 'Slow Consumer' message) then expect 0 + var total = 0; + while (true) + { + var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + if (n == 0) + break; + total += n; + } + + // We should get here -- the socket was closed by the server + true.ShouldBeTrue(); + } + catch (SocketException) + { + // Also acceptable -- connection reset + true.ShouldBeTrue(); + } + + await _cts.CancelAsync(); + } + + [Fact] + public async Task PendingBytes_tracks_queued_data() + { + using var natsClient = CreateClient(); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read and discard the initial INFO message -- note that the write loop + // processes the INFO so by the time we've read it, PendingBytes for INFO + // should have been decremented. + await ConsumeInfoAsync(); + + // After INFO has been written and flushed, pending bytes should be back to 0 + // (give a tiny delay for the write loop to decrement) + await Task.Delay(100); + natsClient.PendingBytes.ShouldBe(0); + + // Now queue known data + var data1 = new byte[100]; + var data2 = new byte[200]; + natsClient.QueueOutbound(data1); + + // PendingBytes should reflect the queued but not-yet-written data. + // However, the write loop is running concurrently so it may drain quickly. + // At a minimum, after queueing data1 PendingBytes should be >= 0 + // (it could have been written already). Let's queue both in quick succession + // and check the combined pending. + natsClient.QueueOutbound(data2); + + // PendingBytes should be at least 0 (write loop may have already drained) + natsClient.PendingBytes.ShouldBeGreaterThanOrEqualTo(0); + + // Now let the write loop drain everything and verify it returns to 0 + await Task.Delay(500); + natsClient.PendingBytes.ShouldBe(0); + + await _cts.CancelAsync(); + } + + [Fact] + public async Task PendingBytes_increases_with_queued_data_before_drain() + { + // To observe PendingBytes before the write loop drains it, we create + // a scenario where we can check the value atomically after queuing. + // We use a fresh client and check PendingBytes right after QueueOutbound. + using var natsClient = CreateClient(); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read the INFO to avoid backpressure + await ConsumeInfoAsync(); + await Task.Delay(100); + + // Queue a message and check that PendingBytes is non-negative + // (the write loop may drain it very quickly, so we can't guarantee a specific value, + // but we can verify the property works and eventually returns to 0) + var data = new byte[500]; + natsClient.QueueOutbound(data); + + // PendingBytes is either 500 (not yet drained) or 0 (already drained) -- both valid + var pending = natsClient.PendingBytes; + pending.ShouldBeOneOf(0L, 500L); + + // After draining, it should be 0 + await Task.Delay(500); + natsClient.PendingBytes.ShouldBe(0); + + // Read and discard the written data from the socket so it doesn't block + var buf = new byte[8192]; + using var readCts = new CancellationTokenSource(2000); + try + { + await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + } + catch (OperationCanceledException) + { + // OK -- may have already been consumed + } + + await _cts.CancelAsync(); + } + + [Fact] + public async Task QueueOutbound_returns_false_when_client_is_closed() + { + var options = new NatsOptions { MaxPending = 512 }; + using var natsClient = CreateClient(options); + var runTask = natsClient.RunAsync(_cts.Token); + + // Read INFO + await ConsumeInfoAsync(); + + // Trigger slow consumer close by exceeding MaxPending + var largeData = new byte[1024]; + natsClient.QueueOutbound(largeData); + + // Wait for close to propagate + await Task.Delay(200); + + // Subsequent QueueOutbound calls should return false + var result = natsClient.QueueOutbound("PING\r\n"u8.ToArray()); + result.ShouldBeFalse(); + + await _cts.CancelAsync(); + } + + [Fact] + public async Task SlowConsumer_increments_server_stats() + { + var options = new NatsOptions { MaxPending = 512 }; + var stats = new ServerStats(); + var authService = AuthService.Build(options); + using var natsClient = new NatsClient( + 1, + new NetworkStream(_serverSocket, ownsSocket: false), + _serverSocket, + options, + _serverInfo, + authService, + null, + NullLogger.Instance, + stats); + + var runTask = natsClient.RunAsync(_cts.Token); + + // Read INFO + await ConsumeInfoAsync(); + + // Initial stats should be 0 + Interlocked.Read(ref stats.SlowConsumers).ShouldBe(0); + Interlocked.Read(ref stats.SlowConsumerClients).ShouldBe(0); + + // Trigger slow consumer + var largeData = new byte[1024]; + natsClient.QueueOutbound(largeData); + + // Wait for propagation + await Task.Delay(200); + + // Stats should be incremented + Interlocked.Read(ref stats.SlowConsumers).ShouldBeGreaterThan(0); + Interlocked.Read(ref stats.SlowConsumerClients).ShouldBeGreaterThan(0); + + await _cts.CancelAsync(); + } +}