using System.Net;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Auth;
using NATS.Server.Protocol;

namespace NATS.Server.Core.Tests;

/// <summary>
/// Tests for the <see cref="NatsClient"/> outbound path over a real loopback TCP
/// connection: delivery of data passed to <c>QueueOutbound</c>, <c>PendingBytes</c>
/// accounting, and slow-consumer handling when <c>MaxPending</c> is exceeded.
/// </summary>
public class WriteLoopTests : IAsyncDisposable
{
    private readonly Socket _serverSocket;
    private readonly Socket _clientSocket;
    private readonly CancellationTokenSource _cts = new();
    private readonly ServerInfo _serverInfo = new()
    {
        ServerId = "test",
        ServerName = "test",
        Version = "0.1.0",
        Host = "127.0.0.1",
        Port = 4222,
    };

    /// <summary>
    /// Creates a connected socket pair via loopback and returns both sockets.
    /// The temporary listener is always disposed; the client socket is disposed
    /// if the connect/accept handshake fails so nothing leaks on error.
    /// </summary>
    private static (Socket serverSocket, Socket clientSocket) CreateSocketPair()
    {
        using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        listener.Listen(1);
        var port = ((IPEndPoint)listener.LocalEndPoint!).Port;

        var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            clientSocket.Connect(IPAddress.Loopback, port);
            var serverSocket = listener.Accept();
            return (serverSocket, clientSocket);
        }
        catch
        {
            clientSocket.Dispose();
            throw;
        }
    }

    public WriteLoopTests()
    {
        (_serverSocket, _clientSocket) = CreateSocketPair();
    }

    public async ValueTask DisposeAsync()
    {
        await _cts.CancelAsync();
        _serverSocket.Dispose();
        _clientSocket.Dispose();
    }

    /// <summary>
    /// Builds a <see cref="NatsClient"/> bound to the server side of the socket pair.
    /// </summary>
    /// <param name="options">Server options; defaults to a fresh <see cref="NatsOptions"/>.</param>
    /// <param name="stats">Stats sink to observe; defaults to a fresh <see cref="ServerStats"/>.</param>
    private NatsClient CreateClient(NatsOptions? options = null, ServerStats? stats = null)
    {
        options ??= new NatsOptions();
        var authService = AuthService.Build(options);
        return new NatsClient(
            1,
            new NetworkStream(_serverSocket, ownsSocket: false),
            _serverSocket,
            options,
            _serverInfo,
            authService,
            null,
            NullLogger.Instance,
            stats ?? new ServerStats());
    }

    /// <summary>
    /// Performs a single receive on the client socket and decodes the bytes as ASCII.
    /// Returns an empty string if the peer has closed the connection.
    /// </summary>
    /// <exception cref="OperationCanceledException">No data arrived within <paramref name="timeoutMs"/>.</exception>
    private async Task<string> ReadFromClientSocketAsync(int bufferSize = 8192, int timeoutMs = 5000)
    {
        var buf = new byte[bufferSize];
        using var readCts = new CancellationTokenSource(timeoutMs);
        var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
        return Encoding.ASCII.GetString(buf, 0, n);
    }

    /// <summary>
    /// Reads and discards the INFO line that is sent on client startup, asserting it is
    /// an INFO message. Keeps reading until the terminating CRLF arrives so a line split
    /// across TCP segments is never left half-consumed for the next read in a test.
    /// </summary>
    private async Task ConsumeInfoAsync()
    {
        var info = new StringBuilder();
        while (!info.ToString().Contains("\r\n", StringComparison.Ordinal))
        {
            var chunk = await ReadFromClientSocketAsync();
            if (chunk.Length == 0)
            {
                // Peer closed before a full line arrived; the assertion below reports it.
                break;
            }

            info.Append(chunk);
        }

        info.ToString().ShouldStartWith("INFO ");
    }

    /// <summary>
    /// Polls <paramref name="condition"/> until it holds or <paramref name="timeoutMs"/>
    /// elapses. Deliberately does not throw on timeout: callers follow up with a Shouldly
    /// assertion so a failure reports the actual observed value.
    /// </summary>
    private static async Task WaitUntilAsync(Func<bool> condition, int timeoutMs = 5000, int pollMs = 10)
    {
        using var timeout = new CancellationTokenSource(timeoutMs);
        while (!condition() && !timeout.IsCancellationRequested)
        {
            await Task.Delay(pollMs);
        }
    }

    [Fact]
    public async Task QueueOutbound_writes_data_to_client()
    {
        using var natsClient = CreateClient();
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        var result = natsClient.QueueOutbound("PING\r\n"u8.ToArray());

        result.ShouldBeTrue();
        var received = await ReadFromClientSocketAsync();
        received.ShouldBe("PING\r\n");

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task QueueOutbound_writes_multiple_messages_to_client()
    {
        const string expectedFirst = "MSG foo 1 5\r\nhello\r\n";
        const string expectedSecond = "MSG bar 2 5\r\nworld\r\n";

        using var natsClient = CreateClient();
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        natsClient.QueueOutbound("MSG foo 1 5\r\nhello\r\n"u8.ToArray());
        natsClient.QueueOutbound("MSG bar 2 5\r\nworld\r\n"u8.ToArray());

        // The two messages may arrive in one or several reads.
        var expectedLength = expectedFirst.Length + expectedSecond.Length;
        var sb = new StringBuilder();
        var buf = new byte[8192];
        using var readCts = new CancellationTokenSource(5000);
        while (sb.Length < expectedLength)
        {
            var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
            if (n == 0)
            {
                // Server closed the connection; without this break the loop would spin
                // forever because a closed socket keeps returning 0 immediately.
                break;
            }

            sb.Append(Encoding.ASCII.GetString(buf, 0, n));
        }

        var received = sb.ToString();
        received.ShouldContain(expectedFirst);
        received.ShouldContain(expectedSecond);

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task SlowConsumer_closes_when_pending_exceeds_max()
    {
        // Use a very small MaxPending so a single payload exceeds it.
        var options = new NatsOptions { MaxPending = 1024 };
        using var natsClient = CreateClient(options);
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        // One 2048-byte payload is larger than the 1024-byte limit on its own.
        var largeData = new byte[2048];
        Array.Fill(largeData, (byte)'X');
        var queued = natsClient.QueueOutbound(largeData);

        queued.ShouldBeFalse();
        // The close may be applied asynchronously; poll instead of sleeping a fixed time.
        await WaitUntilAsync(() => natsClient.CloseReason == ClientClosedReason.SlowConsumerPendingBytes);
        natsClient.CloseReason.ShouldBe(ClientClosedReason.SlowConsumerPendingBytes);

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task SlowConsumer_sets_reason_before_closing_connection()
    {
        var options = new NatsOptions { MaxPending = 512 };
        using var natsClient = CreateClient(options);
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        var largeData = new byte[1024];
        Array.Fill(largeData, (byte)'Y');
        natsClient.QueueOutbound(largeData);

        await WaitUntilAsync(() => natsClient.CloseReason == ClientClosedReason.SlowConsumerPendingBytes);
        natsClient.CloseReason.ShouldBe(ClientClosedReason.SlowConsumerPendingBytes);

        // The server must then drop the connection. Drain anything still buffered
        // (e.g. the -ERR 'Slow Consumer' line) until EOF. If the server never closes,
        // the read times out and the OperationCanceledException fails the test.
        var buf = new byte[4096];
        using var readCts = new CancellationTokenSource(5000);
        try
        {
            int n;
            do
            {
                n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
            }
            while (n > 0);
        }
        catch (SocketException)
        {
            // A connection reset is an equally acceptable way for the server to drop us.
        }

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task PendingBytes_tracks_queued_data()
    {
        using var natsClient = CreateClient();
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        // Once INFO has reached the socket the write loop should release its pending-byte
        // accounting; poll because the decrement may trail the write slightly.
        await WaitUntilAsync(() => natsClient.PendingBytes == 0);
        natsClient.PendingBytes.ShouldBe(0);

        natsClient.QueueOutbound(new byte[100]);
        natsClient.QueueOutbound(new byte[200]);

        // The write loop runs concurrently, so anything from 0 (fully drained) to
        // 300 (nothing written yet) is legitimate at this instant.
        natsClient.PendingBytes.ShouldBeInRange(0L, 300L);

        // Once the write loop drains everything, the counter must return to 0.
        await WaitUntilAsync(() => natsClient.PendingBytes == 0);
        natsClient.PendingBytes.ShouldBe(0);

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task PendingBytes_increases_with_queued_data_before_drain()
    {
        using var natsClient = CreateClient();
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        // Start from a clean counter so INFO bytes cannot skew the check below.
        await WaitUntilAsync(() => natsClient.PendingBytes == 0);

        natsClient.QueueOutbound(new byte[500]);

        // Either nothing has been written yet (500) or the loop already drained it (0).
        var pending = natsClient.PendingBytes;
        pending.ShouldBeOneOf(0L, 500L);

        await WaitUntilAsync(() => natsClient.PendingBytes == 0);
        natsClient.PendingBytes.ShouldBe(0);

        // Read and discard the written payload so it does not linger in the socket buffer.
        var buf = new byte[8192];
        using var readCts = new CancellationTokenSource(2000);
        try
        {
            await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
        }
        catch (OperationCanceledException)
        {
            // OK -- nothing left to read.
        }

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task QueueOutbound_returns_false_when_client_is_closed()
    {
        var options = new NatsOptions { MaxPending = 512 };
        using var natsClient = CreateClient(options);
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        // Trigger the slow-consumer close by exceeding MaxPending.
        natsClient.QueueOutbound(new byte[1024]);

        // Wait until the slow-consumer path has fired, then allow the original settle
        // time for the close itself to propagate.
        await WaitUntilAsync(() => natsClient.CloseReason == ClientClosedReason.SlowConsumerPendingBytes);
        await Task.Delay(200);

        var result = natsClient.QueueOutbound("PING\r\n"u8.ToArray());
        result.ShouldBeFalse();

        await _cts.CancelAsync();
    }

    [Fact]
    public async Task SlowConsumer_increments_server_stats()
    {
        var options = new NatsOptions { MaxPending = 512 };
        var stats = new ServerStats();
        using var natsClient = CreateClient(options, stats);
        _ = natsClient.RunAsync(_cts.Token);
        await ConsumeInfoAsync();

        Interlocked.Read(ref stats.SlowConsumers).ShouldBe(0);
        Interlocked.Read(ref stats.SlowConsumerClients).ShouldBe(0);

        natsClient.QueueOutbound(new byte[1024]);

        await WaitUntilAsync(() =>
            Interlocked.Read(ref stats.SlowConsumers) > 0
            && Interlocked.Read(ref stats.SlowConsumerClients) > 0);
        Interlocked.Read(ref stats.SlowConsumers).ShouldBeGreaterThan(0);
        Interlocked.Read(ref stats.SlowConsumerClients).ShouldBeGreaterThan(0);

        await _cts.CancelAsync();
    }
}