feat: replace inline writes with channel-based write loop and batch flush

This commit is contained in:
Joseph Doherty
2026-02-22 23:41:44 -05:00
parent ad6a02b9a2
commit 31660a4187
3 changed files with 118 additions and 71 deletions

View File

@@ -5,6 +5,7 @@ using System.Net.Sockets;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Threading.Channels;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using NATS.Server.Auth; using NATS.Server.Auth;
using NATS.Server.Protocol; using NATS.Server.Protocol;
@@ -34,7 +35,9 @@ public sealed class NatsClient : IDisposable
private readonly AuthService _authService; private readonly AuthService _authService;
private readonly byte[]? _nonce; private readonly byte[]? _nonce;
private readonly NatsParser _parser; private readonly NatsParser _parser;
private readonly SemaphoreSlim _writeLock = new(1, 1); private readonly Channel<ReadOnlyMemory<byte>> _outbound = Channel.CreateBounded<ReadOnlyMemory<byte>>(
new BoundedChannelOptions(8192) { SingleReader = true, FullMode = BoundedChannelFullMode.Wait });
private long _pendingBytes;
private CancellationTokenSource? _clientCts; private CancellationTokenSource? _clientCts;
private readonly Dictionary<string, Subscription> _subs = new(); private readonly Dictionary<string, Subscription> _subs = new();
private readonly ILogger _logger; private readonly ILogger _logger;
@@ -93,6 +96,37 @@ public sealed class NatsClient : IDisposable
} }
} }
/// <summary>
/// Queues a buffer for the write loop without blocking the caller.
/// </summary>
/// <remarks>
/// The buffer is rejected when adding it would push the pending-byte counter above
/// <c>_options.MaxPending</c>, or when the outbound channel refuses the write. The first
/// rejection marks the client as a slow consumer, bumps the server slow-consumer counters and
/// starts closing the connection; later rejections only drop the data.
/// </remarks>
/// <param name="data">Bytes to queue for the write loop.</param>
/// <returns>
/// <c>true</c> when the buffer was queued; <c>false</c> when the connection is already closing
/// or the buffer was rejected.
/// </returns>
public bool QueueOutbound(ReadOnlyMemory<byte> data)
{
    if (_flags.HasFlag(ClientFlags.CloseConnection))
        return false;

    // Reserve the bytes first so concurrent producers see each other's pending data.
    var pending = Interlocked.Add(ref _pendingBytes, data.Length);

    // The channel is only attempted when the byte budget allows it; TryWrite never blocks and
    // reports false when the bounded channel cannot accept the item.
    if (pending <= _options.MaxPending && _outbound.Writer.TryWrite(data))
        return true;

    // Nothing was queued, so give the reservation back.
    Interlocked.Add(ref _pendingBytes, -data.Length);

    // CloseWithReasonAsync queues its -ERR line through this method before it sets
    // CloseConnection. If that line is rejected as well, triggering the close again would
    // recurse synchronously without bound and inflate the slow-consumer counters, so a client
    // already marked as a slow consumer is simply refused.
    if (_flags.HasFlag(ClientFlags.IsSlowConsumer))
        return false;

    _flags.SetFlag(ClientFlags.IsSlowConsumer);
    Interlocked.Increment(ref _serverStats.SlowConsumers);
    Interlocked.Increment(ref _serverStats.SlowConsumerClients);
    _ = CloseWithReasonAsync(ClientClosedReason.SlowConsumerPendingBytes, NatsProtocol.ErrSlowConsumer);
    return false;
}
/// <summary>
/// Current value of the outbound pending-byte counter, read atomically.
/// </summary>
public long PendingBytes
{
    get { return Interlocked.Read(ref _pendingBytes); }
}
public async Task RunAsync(CancellationToken ct) public async Task RunAsync(CancellationToken ct)
{ {
_clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct); _clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
@@ -102,7 +136,7 @@ public sealed class NatsClient : IDisposable
{ {
// Send INFO (skip if already sent during TLS negotiation) // Send INFO (skip if already sent during TLS negotiation)
if (!InfoAlreadySent) if (!InfoAlreadySent)
await SendInfoAsync(_clientCts.Token); SendInfo();
// Start auth timeout if auth is required // Start auth timeout if auth is required
Task? authTimeoutTask = null; Task? authTimeoutTask = null;
@@ -126,15 +160,16 @@ public sealed class NatsClient : IDisposable
}, _clientCts.Token); }, _clientCts.Token);
} }
// Start read pump, command processing, and ping timer in parallel // Start read pump, command processing, write loop, and ping timer in parallel
var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token); var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token);
var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token); var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token);
var pingTask = RunPingTimerAsync(_clientCts.Token); var pingTask = RunPingTimerAsync(_clientCts.Token);
var writeTask = RunWriteLoopAsync(_clientCts.Token);
if (authTimeoutTask != null) if (authTimeoutTask != null)
await Task.WhenAny(fillTask, processTask, pingTask, authTimeoutTask); await Task.WhenAny(fillTask, processTask, pingTask, writeTask, authTimeoutTask);
else else
await Task.WhenAny(fillTask, processTask, pingTask); await Task.WhenAny(fillTask, processTask, pingTask, writeTask);
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
@@ -146,6 +181,7 @@ public sealed class NatsClient : IDisposable
} }
finally finally
{ {
_outbound.Writer.TryComplete();
try { _socket.Shutdown(SocketShutdown.Both); } try { _socket.Shutdown(SocketShutdown.Both); }
catch (SocketException) { } catch (SocketException) { }
catch (ObjectDisposedException) { } catch (ObjectDisposedException) { }
@@ -217,7 +253,7 @@ public sealed class NatsClient : IDisposable
await ProcessConnectAsync(cmd); await ProcessConnectAsync(cmd);
return; return;
case CommandType.Ping: case CommandType.Ping:
await WriteAsync(NatsProtocol.PongBytes, ct); WriteProtocol(NatsProtocol.PongBytes);
return; return;
default: default:
// Ignore all other commands until authenticated // Ignore all other commands until authenticated
@@ -232,7 +268,7 @@ public sealed class NatsClient : IDisposable
break; break;
case CommandType.Ping: case CommandType.Ping:
await WriteAsync(NatsProtocol.PongBytes, ct); WriteProtocol(NatsProtocol.PongBytes);
break; break;
case CommandType.Pong: case CommandType.Pong:
@@ -240,7 +276,7 @@ public sealed class NatsClient : IDisposable
break; break;
case CommandType.Sub: case CommandType.Sub:
await ProcessSubAsync(cmd); ProcessSub(cmd);
break; break;
case CommandType.Unsub: case CommandType.Unsub:
@@ -249,7 +285,7 @@ public sealed class NatsClient : IDisposable
case CommandType.Pub: case CommandType.Pub:
case CommandType.HPub: case CommandType.HPub:
await ProcessPubAsync(cmd); ProcessPub(cmd);
break; break;
} }
} }
@@ -306,13 +342,13 @@ public sealed class NatsClient : IDisposable
_logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);
} }
private async ValueTask ProcessSubAsync(ParsedCommand cmd) private void ProcessSub(ParsedCommand cmd)
{ {
// Permission check for subscribe // Permission check for subscribe
if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue)) if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue))
{ {
_logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject); _logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject);
await SendErrAsync(NatsProtocol.ErrPermissionsSubscribe); SendErr(NatsProtocol.ErrPermissionsSubscribe);
return; return;
} }
@@ -350,7 +386,7 @@ public sealed class NatsClient : IDisposable
Account?.SubList.Remove(sub); Account?.SubList.Remove(sub);
} }
private async ValueTask ProcessPubAsync(ParsedCommand cmd) private void ProcessPub(ParsedCommand cmd)
{ {
Interlocked.Increment(ref InMsgs); Interlocked.Increment(ref InMsgs);
Interlocked.Add(ref InBytes, cmd.Payload.Length); Interlocked.Add(ref InBytes, cmd.Payload.Length);
@@ -362,7 +398,7 @@ public sealed class NatsClient : IDisposable
{ {
_logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}", _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}",
Id, cmd.Payload.Length, _options.MaxPayload); Id, cmd.Payload.Length, _options.MaxPayload);
await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded); _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded);
return; return;
} }
@@ -370,7 +406,7 @@ public sealed class NatsClient : IDisposable
if (ClientOpts?.Pedantic == true && !SubjectMatch.IsValidPublishSubject(cmd.Subject!)) if (ClientOpts?.Pedantic == true && !SubjectMatch.IsValidPublishSubject(cmd.Subject!))
{ {
_logger.LogDebug("Client {ClientId} invalid publish subject: {Subject}", Id, cmd.Subject); _logger.LogDebug("Client {ClientId} invalid publish subject: {Subject}", Id, cmd.Subject);
await SendErrAsync(NatsProtocol.ErrInvalidPublishSubject); SendErr(NatsProtocol.ErrInvalidPublishSubject);
return; return;
} }
@@ -378,7 +414,7 @@ public sealed class NatsClient : IDisposable
if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!)) if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!))
{ {
_logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject); _logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject);
await SendErrAsync(NatsProtocol.ErrPermissionsPublish); SendErr(NatsProtocol.ErrPermissionsPublish);
return; return;
} }
@@ -394,15 +430,15 @@ public sealed class NatsClient : IDisposable
Router?.ProcessMessage(cmd.Subject!, cmd.ReplyTo, headers, payload, this); Router?.ProcessMessage(cmd.Subject!, cmd.ReplyTo, headers, payload, this);
} }
private async Task SendInfoAsync(CancellationToken ct) private void SendInfo()
{ {
var infoJson = JsonSerializer.Serialize(_serverInfo); var infoJson = JsonSerializer.Serialize(_serverInfo);
var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n"); var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
await WriteAsync(infoLine, ct); QueueOutbound(infoLine);
} }
public async Task SendMessageAsync(string subject, string sid, string? replyTo, public void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload, CancellationToken ct) ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{ {
Interlocked.Increment(ref OutMsgs); Interlocked.Increment(ref OutMsgs);
Interlocked.Add(ref OutBytes, payload.Length + headers.Length); Interlocked.Add(ref OutBytes, payload.Length + headers.Length);
@@ -420,55 +456,68 @@ public sealed class NatsClient : IDisposable
line = Encoding.ASCII.GetBytes($"MSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{payload.Length}\r\n"); line = Encoding.ASCII.GetBytes($"MSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{payload.Length}\r\n");
} }
await _writeLock.WaitAsync(ct); var totalLen = line.Length + headers.Length + payload.Length + NatsProtocol.CrLf.Length;
try var msg = new byte[totalLen];
{ var offset = 0;
await _stream.WriteAsync(line, ct); line.CopyTo(msg.AsSpan(offset)); offset += line.Length;
if (headers.Length > 0) if (headers.Length > 0) { headers.Span.CopyTo(msg.AsSpan(offset)); offset += headers.Length; }
await _stream.WriteAsync(headers, ct); if (payload.Length > 0) { payload.Span.CopyTo(msg.AsSpan(offset)); offset += payload.Length; }
if (payload.Length > 0) NatsProtocol.CrLf.CopyTo(msg.AsSpan(offset));
await _stream.WriteAsync(payload, ct);
await _stream.WriteAsync(NatsProtocol.CrLf, ct); QueueOutbound(msg);
await _stream.FlushAsync(ct);
}
finally
{
_writeLock.Release();
}
} }
private async Task WriteAsync(byte[] data, CancellationToken ct) private void WriteProtocol(byte[] data)
{ {
await _writeLock.WaitAsync(ct); QueueOutbound(data);
try
{
await _stream.WriteAsync(data, ct);
await _stream.FlushAsync(ct);
}
finally
{
_writeLock.Release();
}
} }
public async Task SendErrAsync(string message) public void SendErr(string message)
{ {
var errLine = Encoding.ASCII.GetBytes($"-ERR '{message}'\r\n"); var errLine = Encoding.ASCII.GetBytes($"-ERR '{message}'\r\n");
QueueOutbound(errLine);
}
private async Task RunWriteLoopAsync(CancellationToken ct)
{
_flags.SetFlag(ClientFlags.WriteLoopStarted);
var reader = _outbound.Reader;
try try
{ {
await WriteAsync(errLine, _clientCts?.Token ?? CancellationToken.None); while (await reader.WaitToReadAsync(ct))
{
long batchBytes = 0;
while (reader.TryRead(out var data))
{
await _stream.WriteAsync(data, ct);
batchBytes += data.Length;
}
using var flushCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
flushCts.CancelAfter(_options.WriteDeadline);
try
{
await _stream.FlushAsync(flushCts.Token);
}
catch (OperationCanceledException) when (!ct.IsCancellationRequested)
{
_flags.SetFlag(ClientFlags.IsSlowConsumer);
Interlocked.Increment(ref _serverStats.SlowConsumers);
Interlocked.Increment(ref _serverStats.SlowConsumerClients);
await CloseWithReasonAsync(ClientClosedReason.SlowConsumerWriteDeadline, NatsProtocol.ErrSlowConsumer);
return;
}
Interlocked.Add(ref _pendingBytes, -batchBytes);
}
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
// Expected during shutdown // Normal shutdown
} }
catch (IOException ex) catch (IOException)
{ {
_logger.LogDebug(ex, "Client {ClientId} failed to send -ERR", Id); await CloseWithReasonAsync(ClientClosedReason.WriteError);
}
catch (ObjectDisposedException ex)
{
_logger.LogDebug(ex, "Client {ClientId} failed to send -ERR (disposed)", Id);
} }
} }
@@ -480,9 +529,16 @@ public sealed class NatsClient : IDisposable
private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null) private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null)
{ {
CloseReason = reason; CloseReason = reason;
_flags.SetFlag(ClientFlags.CloseConnection);
if (errMessage != null) if (errMessage != null)
await SendErrAsync(errMessage); SendErr(errMessage);
_flags.SetFlag(ClientFlags.CloseConnection);
// Complete the outbound channel so the write loop drains remaining data
_outbound.Writer.TryComplete();
// Give the write loop a short window to flush the final batch before canceling
await Task.Delay(50);
if (_clientCts is { } cts) if (_clientCts is { } cts)
await cts.CancelAsync(); await cts.CancelAsync();
else else
@@ -514,15 +570,7 @@ public sealed class NatsClient : IDisposable
var currentPingsOut = Interlocked.Increment(ref _pingsOut); var currentPingsOut = Interlocked.Increment(ref _pingsOut);
_logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})", _logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})",
Id, currentPingsOut, _options.MaxPingsOut); Id, currentPingsOut, _options.MaxPingsOut);
try WriteProtocol(NatsProtocol.PingBytes);
{
await WriteAsync(NatsProtocol.PingBytes, ct);
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Client {ClientId} failed to send PING", Id);
return;
}
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)
@@ -541,9 +589,9 @@ public sealed class NatsClient : IDisposable
public void Dispose() public void Dispose()
{ {
_permissions?.Dispose(); _permissions?.Dispose();
_outbound.Writer.TryComplete();
_clientCts?.Dispose(); _clientCts?.Dispose();
_stream.Dispose(); _stream.Dispose();
_socket.Dispose(); _socket.Dispose();
_writeLock.Dispose();
} }
} }

View File

@@ -244,7 +244,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
// Simple round-robin -- pick based on total delivered across group // Simple round-robin -- pick based on total delivered across group
var idx = Math.Abs((int)Interlocked.Increment(ref sender.OutMsgs)) % queueGroup.Length; var idx = Math.Abs((int)Interlocked.Increment(ref sender.OutMsgs)) % queueGroup.Length;
// Undo the OutMsgs increment -- it will be incremented properly in SendMessageAsync // Undo the OutMsgs increment -- it will be incremented properly in SendMessage
Interlocked.Decrement(ref sender.OutMsgs); Interlocked.Decrement(ref sender.OutMsgs);
for (int attempt = 0; attempt < queueGroup.Length; attempt++) for (int attempt = 0; attempt < queueGroup.Length; attempt++)
@@ -270,8 +270,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
if (sub.MaxMessages > 0 && count > sub.MaxMessages) if (sub.MaxMessages > 0 && count > sub.MaxMessages)
return; return;
// Fire and forget -- deliver asynchronously client.SendMessage(subject, sub.Sid, replyTo, headers, payload);
_ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None);
} }
public Account GetOrCreateAccount(string name) public Account GetOrCreateAccount(string name)

View File

@@ -99,8 +99,8 @@ public class ClientTests : IAsyncDisposable
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
// Trigger SendErrAsync // Trigger SendErr
await _natsClient.SendErrAsync("Invalid Subject"); _natsClient.SendErr("Invalid Subject");
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
var response = Encoding.ASCII.GetString(buf, 0, n); var response = Encoding.ASCII.GetString(buf, 0, n);