From 31660a41874e55ce2cc56b238a49202986e35b5b Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 23:41:44 -0500 Subject: [PATCH] feat: replace inline writes with channel-based write loop and batch flush --- src/NATS.Server/NatsClient.cs | 180 ++++++++++++++++--------- src/NATS.Server/NatsServer.cs | 5 +- tests/NATS.Server.Tests/ClientTests.cs | 4 +- 3 files changed, 118 insertions(+), 71 deletions(-) diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index e81e450..1ca6c97 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -5,6 +5,7 @@ using System.Net.Sockets; using System.Security.Cryptography; using System.Text; using System.Text.Json; +using System.Threading.Channels; using Microsoft.Extensions.Logging; using NATS.Server.Auth; using NATS.Server.Protocol; @@ -34,7 +35,9 @@ public sealed class NatsClient : IDisposable private readonly AuthService _authService; private readonly byte[]? _nonce; private readonly NatsParser _parser; - private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly Channel> _outbound = Channel.CreateBounded>( + new BoundedChannelOptions(8192) { SingleReader = true, FullMode = BoundedChannelFullMode.Wait }); + private long _pendingBytes; private CancellationTokenSource? _clientCts; private readonly Dictionary _subs = new(); private readonly ILogger _logger; @@ -93,6 +96,37 @@ public sealed class NatsClient : IDisposable } } + public bool QueueOutbound(ReadOnlyMemory data) + { + if (_flags.HasFlag(ClientFlags.CloseConnection)) + return false; + + var pending = Interlocked.Add(ref _pendingBytes, data.Length); + if (pending > _options.MaxPending) + { + Interlocked.Add(ref _pendingBytes, -data.Length); + _flags.SetFlag(ClientFlags.IsSlowConsumer); + Interlocked.Increment(ref _serverStats.SlowConsumers); + Interlocked.Increment(ref _serverStats.SlowConsumerClients); + _ = CloseWithReasonAsync(ClientClosedReason.SlowConsumerPendingBytes, NatsProtocol.ErrSlowConsumer); + return false; + } + + if (!_outbound.Writer.TryWrite(data)) + { + Interlocked.Add(ref _pendingBytes, -data.Length); + _flags.SetFlag(ClientFlags.IsSlowConsumer); + Interlocked.Increment(ref _serverStats.SlowConsumers); + Interlocked.Increment(ref _serverStats.SlowConsumerClients); + _ = CloseWithReasonAsync(ClientClosedReason.SlowConsumerPendingBytes, NatsProtocol.ErrSlowConsumer); + return false; + } + + return true; + } + + public long PendingBytes => Interlocked.Read(ref _pendingBytes); + public async Task RunAsync(CancellationToken ct) { _clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct); @@ -102,7 +136,7 @@ public sealed class NatsClient : IDisposable { // Send INFO (skip if already sent during TLS negotiation) if (!InfoAlreadySent) - await SendInfoAsync(_clientCts.Token); + SendInfo(); // Start auth timeout if auth is required Task? authTimeoutTask = null; @@ -126,15 +160,16 @@ public sealed class NatsClient : IDisposable }, _clientCts.Token); } - // Start read pump, command processing, and ping timer in parallel + // Start read pump, command processing, write loop, and ping timer in parallel var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token); var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token); var pingTask = RunPingTimerAsync(_clientCts.Token); + var writeTask = RunWriteLoopAsync(_clientCts.Token); if (authTimeoutTask != null) - await Task.WhenAny(fillTask, processTask, pingTask, authTimeoutTask); + await Task.WhenAny(fillTask, processTask, pingTask, writeTask, authTimeoutTask); else - await Task.WhenAny(fillTask, processTask, pingTask); + await Task.WhenAny(fillTask, processTask, pingTask, writeTask); } catch (OperationCanceledException) { @@ -146,6 +181,7 @@ public sealed class NatsClient : IDisposable } finally { + _outbound.Writer.TryComplete(); try { _socket.Shutdown(SocketShutdown.Both); } catch (SocketException) { } catch (ObjectDisposedException) { } @@ -217,7 +253,7 @@ public sealed class NatsClient : IDisposable await ProcessConnectAsync(cmd); return; case CommandType.Ping: - await WriteAsync(NatsProtocol.PongBytes, ct); + WriteProtocol(NatsProtocol.PongBytes); return; default: // Ignore all other commands until authenticated @@ -232,7 +268,7 @@ public sealed class NatsClient : IDisposable break; case CommandType.Ping: - await WriteAsync(NatsProtocol.PongBytes, ct); + WriteProtocol(NatsProtocol.PongBytes); break; case CommandType.Pong: @@ -240,7 +276,7 @@ public sealed class NatsClient : IDisposable break; case CommandType.Sub: - await ProcessSubAsync(cmd); + ProcessSub(cmd); break; case CommandType.Unsub: @@ -249,7 +285,7 @@ public sealed class NatsClient : IDisposable case CommandType.Pub: case CommandType.HPub: - await ProcessPubAsync(cmd); + ProcessPub(cmd); break; } } @@ -306,13 +342,13 @@ public sealed class NatsClient : IDisposable _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); } - private async ValueTask ProcessSubAsync(ParsedCommand cmd) + private void ProcessSub(ParsedCommand cmd) { // Permission check for subscribe if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue)) { _logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject); - await SendErrAsync(NatsProtocol.ErrPermissionsSubscribe); + SendErr(NatsProtocol.ErrPermissionsSubscribe); return; } @@ -350,7 +386,7 @@ public sealed class NatsClient : IDisposable Account?.SubList.Remove(sub); } - private async ValueTask ProcessPubAsync(ParsedCommand cmd) + private void ProcessPub(ParsedCommand cmd) { Interlocked.Increment(ref InMsgs); Interlocked.Add(ref InBytes, cmd.Payload.Length); @@ -362,7 +398,7 @@ public sealed class NatsClient : IDisposable { _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}", Id, cmd.Payload.Length, _options.MaxPayload); - await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded); + _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded); return; } @@ -370,7 +406,7 @@ public sealed class NatsClient : IDisposable if (ClientOpts?.Pedantic == true && !SubjectMatch.IsValidPublishSubject(cmd.Subject!)) { _logger.LogDebug("Client {ClientId} invalid publish subject: {Subject}", Id, cmd.Subject); - await SendErrAsync(NatsProtocol.ErrInvalidPublishSubject); + SendErr(NatsProtocol.ErrInvalidPublishSubject); return; } @@ -378,7 +414,7 @@ public sealed class NatsClient : IDisposable if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!)) { _logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject); - await SendErrAsync(NatsProtocol.ErrPermissionsPublish); + SendErr(NatsProtocol.ErrPermissionsPublish); return; } @@ -394,15 +430,15 @@ public sealed class NatsClient : IDisposable Router?.ProcessMessage(cmd.Subject!, cmd.ReplyTo, headers, payload, this); } - private async Task SendInfoAsync(CancellationToken ct) + private void SendInfo() { var infoJson = JsonSerializer.Serialize(_serverInfo); var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n"); - await WriteAsync(infoLine, ct); + QueueOutbound(infoLine); } - public async Task SendMessageAsync(string subject, string sid, string? replyTo, - ReadOnlyMemory headers, ReadOnlyMemory payload, CancellationToken ct) + public void SendMessage(string subject, string sid, string? replyTo, + ReadOnlyMemory headers, ReadOnlyMemory payload) { Interlocked.Increment(ref OutMsgs); Interlocked.Add(ref OutBytes, payload.Length + headers.Length); @@ -420,55 +456,68 @@ public sealed class NatsClient : IDisposable line = Encoding.ASCII.GetBytes($"MSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{payload.Length}\r\n"); } - await _writeLock.WaitAsync(ct); - try - { - await _stream.WriteAsync(line, ct); - if (headers.Length > 0) - await _stream.WriteAsync(headers, ct); - if (payload.Length > 0) - await _stream.WriteAsync(payload, ct); - await _stream.WriteAsync(NatsProtocol.CrLf, ct); - await _stream.FlushAsync(ct); - } - finally - { - _writeLock.Release(); - } + var totalLen = line.Length + headers.Length + payload.Length + NatsProtocol.CrLf.Length; + var msg = new byte[totalLen]; + var offset = 0; + line.CopyTo(msg.AsSpan(offset)); offset += line.Length; + if (headers.Length > 0) { headers.Span.CopyTo(msg.AsSpan(offset)); offset += headers.Length; } + if (payload.Length > 0) { payload.Span.CopyTo(msg.AsSpan(offset)); offset += payload.Length; } + NatsProtocol.CrLf.CopyTo(msg.AsSpan(offset)); + + QueueOutbound(msg); } - private async Task WriteAsync(byte[] data, CancellationToken ct) + private void WriteProtocol(byte[] data) { - await _writeLock.WaitAsync(ct); - try - { - await _stream.WriteAsync(data, ct); - await _stream.FlushAsync(ct); - } - finally - { - _writeLock.Release(); - } + QueueOutbound(data); } - public async Task SendErrAsync(string message) + public void SendErr(string message) { var errLine = Encoding.ASCII.GetBytes($"-ERR '{message}'\r\n"); + QueueOutbound(errLine); + } + + private async Task RunWriteLoopAsync(CancellationToken ct) + { + _flags.SetFlag(ClientFlags.WriteLoopStarted); + var reader = _outbound.Reader; try { - await WriteAsync(errLine, _clientCts?.Token ?? CancellationToken.None); + while (await reader.WaitToReadAsync(ct)) + { + long batchBytes = 0; + while (reader.TryRead(out var data)) + { + await _stream.WriteAsync(data, ct); + batchBytes += data.Length; + } + + using var flushCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + flushCts.CancelAfter(_options.WriteDeadline); + try + { + await _stream.FlushAsync(flushCts.Token); + } + catch (OperationCanceledException) when (!ct.IsCancellationRequested) + { + _flags.SetFlag(ClientFlags.IsSlowConsumer); + Interlocked.Increment(ref _serverStats.SlowConsumers); + Interlocked.Increment(ref _serverStats.SlowConsumerClients); + await CloseWithReasonAsync(ClientClosedReason.SlowConsumerWriteDeadline, NatsProtocol.ErrSlowConsumer); + return; + } + + Interlocked.Add(ref _pendingBytes, -batchBytes); + } } catch (OperationCanceledException) { - // Expected during shutdown + // Normal shutdown } - catch (IOException ex) + catch (IOException) { - _logger.LogDebug(ex, "Client {ClientId} failed to send -ERR", Id); - } - catch (ObjectDisposedException ex) - { - _logger.LogDebug(ex, "Client {ClientId} failed to send -ERR (disposed)", Id); + await CloseWithReasonAsync(ClientClosedReason.WriteError); } } @@ -480,9 +529,16 @@ public sealed class NatsClient : IDisposable private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null) { CloseReason = reason; - _flags.SetFlag(ClientFlags.CloseConnection); if (errMessage != null) - await SendErrAsync(errMessage); + SendErr(errMessage); + _flags.SetFlag(ClientFlags.CloseConnection); + + // Complete the outbound channel so the write loop drains remaining data + _outbound.Writer.TryComplete(); + + // Give the write loop a short window to flush the final batch before canceling + await Task.Delay(50); + if (_clientCts is { } cts) await cts.CancelAsync(); else @@ -514,15 +570,7 @@ public sealed class NatsClient : IDisposable var currentPingsOut = Interlocked.Increment(ref _pingsOut); _logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})", Id, currentPingsOut, _options.MaxPingsOut); - try - { - await WriteAsync(NatsProtocol.PingBytes, ct); - } - catch (Exception ex) - { - _logger.LogDebug(ex, "Client {ClientId} failed to send PING", Id); - return; - } + WriteProtocol(NatsProtocol.PingBytes); } } catch (OperationCanceledException) @@ -541,9 +589,9 @@ public sealed class NatsClient : IDisposable public void Dispose() { _permissions?.Dispose(); + _outbound.Writer.TryComplete(); _clientCts?.Dispose(); _stream.Dispose(); _socket.Dispose(); - _writeLock.Dispose(); } } diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 87f6b97..0e7a8e9 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -244,7 +244,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable // Simple round-robin -- pick based on total delivered across group var idx = Math.Abs((int)Interlocked.Increment(ref sender.OutMsgs)) % queueGroup.Length; - // Undo the OutMsgs increment -- it will be incremented properly in SendMessageAsync + // Undo the OutMsgs increment -- it will be incremented properly in SendMessage Interlocked.Decrement(ref sender.OutMsgs); for (int attempt = 0; attempt < queueGroup.Length; attempt++) @@ -270,8 +270,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable if (sub.MaxMessages > 0 && count > sub.MaxMessages) return; - // Fire and forget -- deliver asynchronously - _ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None); + client.SendMessage(subject, sub.Sid, replyTo, headers, payload); } public Account GetOrCreateAccount(string name) diff --git a/tests/NATS.Server.Tests/ClientTests.cs b/tests/NATS.Server.Tests/ClientTests.cs index 92bcdd8..d018f1d 100644 --- a/tests/NATS.Server.Tests/ClientTests.cs +++ b/tests/NATS.Server.Tests/ClientTests.cs @@ -99,8 +99,8 @@ public class ClientTests : IAsyncDisposable using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); - // Trigger SendErrAsync - await _natsClient.SendErrAsync("Invalid Subject"); + // Trigger SendErr + _natsClient.SendErr("Invalid Subject"); var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); var response = Encoding.ASCII.GetString(buf, 0, n);