using System.Buffers;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
using NATS.Server.WebSocket;

namespace NATS.Server;

/// <summary>
/// Receives messages published by clients and is notified when a client's connection loop ends.
/// </summary>
public interface IMessageRouter
{
    /// <summary>
    /// Handles one PUB/HPUB from <paramref name="sender"/>. <paramref name="headers"/> is empty
    /// unless the command was an HPUB with a non-zero header block.
    /// </summary>
    void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload, NatsClient sender);

    /// <summary>Called by <see cref="NatsClient.RunAsync"/> when the connection loop has ended.</summary>
    void RemoveClient(NatsClient client);
}

/// <summary>Exposes a <see cref="Subscriptions.SubList"/>.</summary>
public interface ISubListAccess
{
    SubList SubList { get; }
}

/// <summary>
/// A single client connection. <see cref="RunAsync"/> runs four concurrent loops: socket read into a
/// pipe, command parsing/dispatch, outbound writes, and PING keepalive. It tears the connection down
/// when the first of them finishes.
/// </summary>
public sealed class NatsClient : IDisposable
{
    // Room for the size field(s) of a MSG/HMSG control line: up to two int32 values
    // (10 digits each) separated by one space.
    private const int MaxSizeFieldLength = 21;

    private readonly Socket _socket;
    private readonly Stream _stream;
    private readonly NatsOptions _options;
    private readonly ServerInfo _serverInfo;
    private readonly AuthService _authService;
    private readonly byte[]? _nonce;
    private readonly NatsParser _parser;
    private readonly Channel<ReadOnlyMemory<byte>> _outbound = Channel.CreateBounded<ReadOnlyMemory<byte>>(
        new BoundedChannelOptions(8192) { SingleReader = true, FullMode = BoundedChannelFullMode.Wait });
    private long _pendingBytes;
    private CancellationTokenSource? _clientCts;
    private readonly Dictionary<string, Subscription> _subs = new();
    private readonly ILogger _logger;
    private ClientPermissions? _permissions;
    private readonly ServerStats _serverStats;

    // 0 while open; set to 1 by the first party that starts closing the connection.
    // Guards the close path so it runs (and counts slow consumers) only once.
    private int _closeStarted;

    public ulong Id { get; }
    public ClientOptions? ClientOpts { get; private set; }
    public IMessageRouter? Router { get; set; }
    public Account? Account { get; private set; }
    public ClientPermissions? Permissions => _permissions;

    private readonly ClientFlagHolder _flags = new();
    public bool ConnectReceived => _flags.HasFlag(ClientFlags.ConnectReceived);
    public ClientClosedReason CloseReason { get; private set; }

    /// <summary>
    /// Enables or disables per-client protocol tracing. When disabled, the parser falls back to
    /// the server-wide <c>Trace</c> option.
    /// </summary>
    public void SetTraceMode(bool enabled)
    {
        if (enabled)
        {
            _flags.SetFlag(ClientFlags.TraceMode);
            _parser.Logger = _logger;
        }
        else
        {
            _flags.ClearFlag(ClientFlags.TraceMode);
            _parser.Logger = _options.Trace ? _logger : null;
        }
    }

    public DateTime StartTime { get; }
    private long _lastActivityTicks;
    public DateTime LastActivity => new(Interlocked.Read(ref _lastActivityTicks), DateTimeKind.Utc);
    public string? RemoteIp { get; }
    public int RemotePort { get; }

    // Stats
    public long InMsgs;
    public long OutMsgs;
    public long InBytes;
    public long OutBytes;

    // Close reason tracking
    private int _skipFlushOnClose;
    public bool ShouldSkipFlush => Volatile.Read(ref _skipFlushOnClose) != 0;

    // PING keepalive state
    private int _pingsOut;
    private long _lastIn;

    // RTT tracking
    private long _rttStartTicks;
    private long _rtt;
    public TimeSpan Rtt => new(Interlocked.Read(ref _rtt));

    public bool IsWebSocket { get; set; }
    public WsUpgradeResult? WsInfo { get; set; }
    public TlsConnectionState? TlsState { get; set; }
    public bool InfoAlreadySent { get; set; }

    public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;

    public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,
        AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats)
    {
        Id = id;
        _socket = socket;
        _stream = stream;
        _options = options;
        _serverInfo = serverInfo;
        _authService = authService;
        _nonce = nonce;
        _logger = logger;
        _serverStats = serverStats;
        _parser = new NatsParser(options.MaxPayload, options.Trace ? logger : null);
        StartTime = DateTime.UtcNow;
        _lastActivityTicks = StartTime.Ticks;
        if (socket.RemoteEndPoint is IPEndPoint ep)
        {
            RemoteIp = ep.Address.ToString();
            RemotePort = ep.Port;
        }
    }

    /// <summary>
    /// Queues <paramref name="data"/> for the write loop.
    /// </summary>
    /// <returns>
    /// <c>false</c> when the connection is closing or when the client is a slow consumer. A slow
    /// consumer is one whose pending bytes would exceed <c>MaxPending</c>, or whose outbound channel
    /// is full. A slow consumer is closed with <see cref="ClientClosedReason.SlowConsumerPendingBytes"/>.
    /// </returns>
    public bool QueueOutbound(ReadOnlyMemory<byte> data)
    {
        if (_flags.HasFlag(ClientFlags.CloseConnection))
            return false;

        var pending = Interlocked.Add(ref _pendingBytes, data.Length);
        if (pending > _options.MaxPending || !_outbound.Writer.TryWrite(data))
        {
            Interlocked.Add(ref _pendingBytes, -data.Length);
            _ = CloseAsSlowConsumerAsync(ClientClosedReason.SlowConsumerPendingBytes);
            return false;
        }

        return true;
    }

    public long PendingBytes => Interlocked.Read(ref _pendingBytes);

    /// <summary>
    /// Runs the connection until the first of the read, process, ping or write loops ends, the
    /// connection is closed, or <paramref name="ct"/> is cancelled. Then it cancels the remaining
    /// loops, shuts down the socket and notifies <see cref="Router"/>.
    /// </summary>
    public async Task RunAsync(CancellationToken ct)
    {
        var cts = _clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        var token = cts.Token;
        Interlocked.Exchange(ref _lastIn, Environment.TickCount64);
        var pipe = new Pipe();
        try
        {
            // Send INFO (skip if already sent during TLS negotiation)
            if (!InfoAlreadySent)
                SendInfo();

            // The auth timeout closes the connection itself when it fires. It must NOT take part in
            // the WhenAny below: it completes normally after AuthTimeout even for authenticated
            // clients, and that would tear down every healthy connection.
            if (_authService.IsAuthRequired)
                _ = RunAuthTimeoutAsync(token);

            var fillTask = FillPipeAsync(pipe.Writer, token);
            var processTask = ProcessCommandsAsync(pipe.Reader, token);
            var pingTask = RunPingTimerAsync(token);
            var writeTask = RunWriteLoopAsync(token);

            var completed = await Task.WhenAny(fillTask, processTask, pingTask, writeTask);

            // WhenAny never rethrows. Await the finished loop so its failure reaches the catch
            // clauses below and is recorded as the close reason.
            await completed;

            if (ct.IsCancellationRequested)
                MarkClosed(ClientClosedReason.ServerShutdown);
        }
        catch (OperationCanceledException)
        {
            _logger.LogDebug("Client {ClientId} operation cancelled", Id);
            MarkClosed(ClientClosedReason.ServerShutdown);
        }
        catch (IOException)
        {
            MarkClosed(ClientClosedReason.ReadError);
        }
        catch (SocketException)
        {
            MarkClosed(ClientClosedReason.ReadError);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Client {ClientId} connection error", Id);
            MarkClosed(ClientClosedReason.ReadError);
        }
        finally
        {
            MarkClosed(ClientClosedReason.ClientClosed);

            // Stop everything that could still queue output. Without this, the ping timer (and any
            // auth-expiry timer) outlive the connection and their writes fail against the completed
            // channel.
            Interlocked.Exchange(ref _closeStarted, 1);
            _flags.SetFlag(ClientFlags.CloseConnection);
            _outbound.Writer.TryComplete();
            try
            {
                await cts.CancelAsync();
            }
            catch (ObjectDisposedException)
            {
                // Disposed concurrently; the loops are already unblocked.
            }

            try { _socket.Shutdown(SocketShutdown.Both); }
            catch (SocketException) { }
            catch (ObjectDisposedException) { }

            Router?.RemoveClient(this);
        }
    }

    /// <summary>
    /// Waits <c>AuthTimeout</c>. If no CONNECT has been accepted by then, sends the auth-timeout
    /// error and closes the connection.
    /// </summary>
    private async Task RunAuthTimeoutAsync(CancellationToken ct)
    {
        try
        {
            await Task.Delay(_options.AuthTimeout, ct);
            if (!ConnectReceived)
            {
                _logger.LogDebug("Client {ClientId} auth timeout", Id);
                await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout, ClientClosedReason.AuthenticationTimeout);
            }
        }
        catch (OperationCanceledException)
        {
            // Normal -- the connection closed before the timeout elapsed.
        }
    }

    /// <summary>Copies bytes from the stream into the pipe until EOF, cancellation, or reader completion.</summary>
    private async Task FillPipeAsync(PipeWriter writer, CancellationToken ct)
    {
        try
        {
            while (!ct.IsCancellationRequested)
            {
                var memory = writer.GetMemory(4096);
                int bytesRead = await _stream.ReadAsync(memory, ct);
                if (bytesRead == 0)
                    break;
                writer.Advance(bytesRead);
                var result = await writer.FlushAsync(ct);
                if (result.IsCompleted)
                    break;
            }
        }
        finally
        {
            await writer.CompleteAsync();
        }
    }

    /// <summary>
    /// Parses commands from the pipe and dispatches them. PUB/HPUB is handled inline so that
    /// per-buffer message/byte counts can be batched into the shared counters.
    /// </summary>
    private async Task ProcessCommandsAsync(PipeReader reader, CancellationToken ct)
    {
        try
        {
            while (!ct.IsCancellationRequested)
            {
                var result = await reader.ReadAsync(ct);
                var buffer = result.Buffer;
                long localInMsgs = 0;
                long localInBytes = 0;

                while (_parser.TryParse(ref buffer, out var cmd))
                {
                    Interlocked.Exchange(ref _lastIn, Environment.TickCount64);

                    // Handle Pub/HPub inline to allow ref parameter passing for stat batching.
                    // DispatchCommandAsync is async and cannot accept ref parameters.
                    if (cmd.Type is CommandType.Pub or CommandType.HPub
                        && (!_authService.IsAuthRequired || ConnectReceived))
                    {
                        Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
                        ProcessPub(cmd, ref localInMsgs, ref localInBytes);
                        if (ClientOpts?.Verbose == true)
                            WriteProtocol(NatsProtocol.OkBytes);
                    }
                    else
                    {
                        await DispatchCommandAsync(cmd, ct);
                    }
                }

                if (localInMsgs > 0)
                {
                    Interlocked.Add(ref InMsgs, localInMsgs);
                    Interlocked.Add(ref _serverStats.InMsgs, localInMsgs);
                }

                if (localInBytes > 0)
                {
                    Interlocked.Add(ref InBytes, localInBytes);
                    Interlocked.Add(ref _serverStats.InBytes, localInBytes);
                }

                reader.AdvanceTo(buffer.Start, buffer.End);
                if (result.IsCompleted)
                    break;
            }
        }
        finally
        {
            await reader.CompleteAsync();
        }
    }

    /// <summary>Dispatches every non-PUB command. Before CONNECT (when auth is required) only CONNECT and PING are honoured.</summary>
    private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
    {
        Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);

        // If auth is required and CONNECT hasn't been received yet,
        // only allow CONNECT and PING commands
        if (_authService.IsAuthRequired && !ConnectReceived)
        {
            switch (cmd.Type)
            {
                case CommandType.Connect:
                    await ProcessConnectAsync(cmd);
                    return;
                case CommandType.Ping:
                    WriteProtocol(NatsProtocol.PongBytes);
                    return;
                default:
                    // Ignore all other commands until authenticated
                    return;
            }
        }

        switch (cmd.Type)
        {
            case CommandType.Connect:
                await ProcessConnectAsync(cmd);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;

            case CommandType.Ping:
                WriteProtocol(NatsProtocol.PongBytes);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;

            case CommandType.Pong:
                Interlocked.Exchange(ref _pingsOut, 0);
                var rttStart = Interlocked.Read(ref _rttStartTicks);
                if (rttStart > 0)
                {
                    var elapsed = DateTime.UtcNow.Ticks - rttStart;
                    if (elapsed <= 0)
                        elapsed = 1; // min 1 tick for Windows granularity
                    Interlocked.Exchange(ref _rtt, elapsed);
                }
                _flags.SetFlag(ClientFlags.FirstPongSent);
                break;

            case CommandType.Sub:
                ProcessSub(cmd);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;

            case CommandType.Unsub:
                ProcessUnsub(cmd);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;

            case CommandType.Pub:
            case CommandType.HPub:
                // Pub/HPub is handled inline in ProcessCommandsAsync for stat batching
                break;
        }
    }

    /// <summary>
    /// Parses CONNECT options, authenticates (when required), binds the client to an account,
    /// validates option combinations, and arms the auth-expiry timer.
    /// </summary>
    private async ValueTask ProcessConnectAsync(ParsedCommand cmd)
    {
        ClientOpts = JsonSerializer.Deserialize<ClientOptions>(cmd.Payload.Span) ?? new ClientOptions();

        // Authenticate if auth is required
        AuthResult? authResult = null;
        if (_authService.IsAuthRequired)
        {
            var context = new ClientAuthContext
            {
                Opts = ClientOpts,
                Nonce = _nonce ?? [],
                ClientCertificate = TlsState?.PeerCert,
            };

            authResult = _authService.Authenticate(context);
            if (authResult == null)
            {
                _logger.LogWarning("Client {ClientId} authentication failed", Id);
                await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation, ClientClosedReason.AuthenticationViolation);
                return;
            }

            // Build permissions from auth result
            _permissions = ClientPermissions.Build(authResult.Permissions);

            // Resolve account
            if (Router is NatsServer server)
            {
                var accountName = authResult.AccountName ?? Account.GlobalAccountName;
                Account = server.GetOrCreateAccount(accountName);
                if (!Account.AddClient(Id))
                {
                    Account = null;
                    await SendErrAndCloseAsync("maximum connections for account exceeded", ClientClosedReason.AuthenticationViolation);
                    return;
                }
            }

            _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, authResult.Identity);

            // Clear nonce after use -- defense-in-depth against memory dumps
            if (_nonce != null)
                CryptographicOperations.ZeroMemory(_nonce);
        }

        // If no account was assigned by auth, assign global account
        if (Account == null && Router is NatsServer server2)
        {
            Account = server2.GetOrCreateAccount(Account.GlobalAccountName);
            if (!Account.AddClient(Id))
            {
                Account = null;
                await SendErrAndCloseAsync("maximum connections for account exceeded", ClientClosedReason.AuthenticationViolation);
                return;
            }
        }

        // Validate no_responders requires headers
        if (ClientOpts.NoResponders && !ClientOpts.Headers)
        {
            _logger.LogDebug("Client {ClientId} no_responders requires headers", Id);
            await CloseWithReasonAsync(ClientClosedReason.NoRespondersRequiresHeaders, NatsProtocol.ErrNoRespondersRequiresHeaders);
            return;
        }

        _flags.SetFlag(ClientFlags.ConnectReceived);
        _flags.SetFlag(ClientFlags.ConnectProcessFinished);
        _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);

        // Start auth expiry timer if needed
        if (_authService.IsAuthRequired && authResult?.Expiry is { } expiry)
        {
            var remaining = expiry - DateTimeOffset.UtcNow;
            if (remaining > TimeSpan.Zero)
            {
                _ = Task.Run(async () =>
                {
                    try
                    {
                        await Task.Delay(remaining, _clientCts!.Token);
                        _logger.LogDebug("Client {ClientId} authentication expired", Id);
                        await SendErrAndCloseAsync("Authentication Expired", ClientClosedReason.AuthenticationExpired);
                    }
                    catch (OperationCanceledException)
                    {
                    }
                }, _clientCts!.Token);
            }
            else
            {
                await SendErrAndCloseAsync("Authentication Expired", ClientClosedReason.AuthenticationExpired);
                return;
            }
        }
    }

    /// <summary>
    /// Handles SUB: checks permissions and the per-connection limit, ignores a SID that is already
    /// active, checks the per-account limit, then registers the subscription in the account's SubList.
    /// </summary>
    private void ProcessSub(ParsedCommand cmd)
    {
        var subject = cmd.Subject!;
        var sid = cmd.Sid!;

        // Permission check for subscribe
        if (_permissions != null && !_permissions.IsSubscribeAllowed(subject, cmd.Queue))
        {
            _logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, subject);
            SendErr(NatsProtocol.ErrPermissionsSubscribe);
            return;
        }

        // Per-connection subscription limit
        if (_options.MaxSubs > 0 && _subs.Count >= _options.MaxSubs)
        {
            _logger.LogDebug("Client {ClientId} max subscriptions exceeded", Id);
            _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxSubscriptionsExceeded, ClientClosedReason.MaxSubscriptionsExceeded);
            return;
        }

        // A SUB that reuses an active SID is a no-op. Overwriting the entry would orphan the previous
        // subscription in the SubList (duplicate deliveries under the same SID) and would count the
        // account subscription limit twice.
        if (_subs.ContainsKey(sid))
            return;

        // Per-account subscription limit
        if (Account != null && !Account.IncrementSubscriptions())
        {
            _logger.LogDebug("Client {ClientId} account subscription limit exceeded", Id);
            _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxSubscriptionsExceeded, ClientClosedReason.MaxSubscriptionsExceeded);
            return;
        }

        var sub = new Subscription
        {
            Subject = subject,
            Queue = cmd.Queue,
            Sid = sid,
        };

        _subs[sid] = sub;
        sub.Client = this;

        _logger.LogDebug("SUB {Subject} {Sid} from client {ClientId}", subject, sid, Id);
        Account?.SubList.Insert(sub);
    }

    /// <summary>
    /// Handles UNSUB. With a positive max-message count it only records the limit; otherwise it
    /// removes the subscription immediately.
    /// </summary>
    private void ProcessUnsub(ParsedCommand cmd)
    {
        _logger.LogDebug("UNSUB {Sid} from client {ClientId}", cmd.Sid, Id);
        if (!_subs.TryGetValue(cmd.Sid!, out var sub))
            return;

        if (cmd.MaxMessages > 0)
        {
            sub.MaxMessages = cmd.MaxMessages;
            // Will be cleaned up when MessageCount reaches MaxMessages
            return;
        }

        _subs.Remove(cmd.Sid!);
        Account?.DecrementSubscriptions();
        Account?.SubList.Remove(sub);
    }

    /// <summary>
    /// Validates a PUB/HPUB (payload size, pedantic subject check, publish permission), splits the
    /// headers from the payload, and hands the message to <see cref="Router"/>.
    /// </summary>
    private void ProcessPub(ParsedCommand cmd, ref long localInMsgs, ref long localInBytes)
    {
        localInMsgs++;
        localInBytes += cmd.Payload.Length;

        // Max payload validation (always, hard close)
        if (cmd.Payload.Length > _options.MaxPayload)
        {
            _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}",
                Id, cmd.Payload.Length, _options.MaxPayload);
            _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded);
            return;
        }

        // Pedantic mode: validate publish subject
        if (ClientOpts?.Pedantic == true && !SubjectMatch.IsValidPublishSubject(cmd.Subject!))
        {
            _logger.LogDebug("Client {ClientId} invalid publish subject: {Subject}", Id, cmd.Subject);
            SendErr(NatsProtocol.ErrInvalidPublishSubject);
            return;
        }

        // Permission check for publish
        if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!))
        {
            _logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject);
            SendErr(NatsProtocol.ErrPermissionsPublish);
            return;
        }

        ReadOnlyMemory<byte> headers = default;
        ReadOnlyMemory<byte> payload = cmd.Payload;

        if (cmd.Type == CommandType.HPub && cmd.HeaderSize > 0)
        {
            headers = cmd.Payload[..cmd.HeaderSize];
            payload = cmd.Payload[cmd.HeaderSize..];
        }

        Router?.ProcessMessage(cmd.Subject!, cmd.ReplyTo, headers, payload, this);
    }

    private void SendInfo()
    {
        // Use the cached INFO bytes from the server when there is no per-connection
        // nonce (i.e. NKey auth is not active for this connection). When a nonce is
        // present the _serverInfo was already cloned with the nonce embedded, so we
        // must serialise it individually.
        if (_nonce == null && Router is NatsServer server)
        {
            QueueOutbound(server.CachedInfoLine);
        }
        else
        {
            var infoJson = JsonSerializer.Serialize(_serverInfo);
            var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
            QueueOutbound(infoLine);
        }
    }

    /// <summary>
    /// Encodes a MSG (or HMSG when <paramref name="headers"/> is non-empty) frame and queues it.
    /// </summary>
    public void SendMessage(string subject, string sid, string? replyTo,
        ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
    {
        Interlocked.Increment(ref OutMsgs);
        Interlocked.Add(ref OutBytes, payload.Length + headers.Length);
        Interlocked.Increment(ref _serverStats.OutMsgs);
        Interlocked.Add(ref _serverStats.OutBytes, payload.Length + headers.Length);

        // Control line upper bound: "HMSG " + subject + ' ' + sid + ' ' + [reply + ' '] + sizes + CRLF.
        var estimatedLineSize = 5 + subject.Length + 1 + sid.Length + 1
            + (replyTo != null ? replyTo.Length + 1 : 0)
            + MaxSizeFieldLength + 2;

        var totalPayloadLen = headers.Length + payload.Length;
        var totalLen = estimatedLineSize + totalPayloadLen + 2;
        var buffer = new byte[totalLen];
        var span = buffer.AsSpan();
        int pos = 0;

        // Write prefix
        if (headers.Length > 0)
        {
            "HMSG "u8.CopyTo(span);
            pos = 5;
        }
        else
        {
            "MSG "u8.CopyTo(span);
            pos = 4;
        }

        // Subject
        pos += Encoding.ASCII.GetBytes(subject, span[pos..]);
        span[pos++] = (byte)' ';

        // SID
        pos += Encoding.ASCII.GetBytes(sid, span[pos..]);
        span[pos++] = (byte)' ';

        // Reply-to
        if (replyTo != null)
        {
            pos += Encoding.ASCII.GetBytes(replyTo, span[pos..]);
            span[pos++] = (byte)' ';
        }

        // Sizes
        if (headers.Length > 0)
        {
            int totalSize = headers.Length + payload.Length;
            headers.Length.TryFormat(span[pos..], out int written);
            pos += written;
            span[pos++] = (byte)' ';
            totalSize.TryFormat(span[pos..], out written);
            pos += written;
        }
        else
        {
            payload.Length.TryFormat(span[pos..], out int written);
            pos += written;
        }

        // CRLF
        span[pos++] = (byte)'\r';
        span[pos++] = (byte)'\n';

        // Headers + payload + trailing CRLF
        if (headers.Length > 0)
        {
            headers.Span.CopyTo(span[pos..]);
            pos += headers.Length;
        }

        if (payload.Length > 0)
        {
            payload.Span.CopyTo(span[pos..]);
            pos += payload.Length;
        }

        span[pos++] = (byte)'\r';
        span[pos++] = (byte)'\n';

        QueueOutbound(buffer.AsMemory(0, pos));
    }

    private void WriteProtocol(byte[] data)
    {
        QueueOutbound(data);
    }

    /// <summary>Queues a <c>-ERR '&lt;message&gt;'</c> line, subject to the normal outbound limits.</summary>
    public void SendErr(string message)
    {
        QueueOutbound(FormatErr(message));
    }

    private static byte[] FormatErr(string message) => Encoding.ASCII.GetBytes($"-ERR '{message}'\r\n");

    /// <summary>
    /// Drains the outbound channel to the stream in batches. The whole batch (writes and flush) must
    /// finish within <c>WriteDeadline</c>, or the client is closed as a slow consumer.
    /// </summary>
    private async Task RunWriteLoopAsync(CancellationToken ct)
    {
        _flags.SetFlag(ClientFlags.WriteLoopStarted);
        var reader = _outbound.Reader;
        try
        {
            while (await reader.WaitToReadAsync(ct))
            {
                long batchBytes = 0;

                // The deadline must cover the writes, not just the flush. NetworkStream is unbuffered,
                // so FlushAsync is a no-op there, and a stalled peer blocks inside WriteAsync.
                using var deadlineCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
                deadlineCts.CancelAfter(_options.WriteDeadline);
                try
                {
                    while (reader.TryRead(out var data))
                    {
                        await _stream.WriteAsync(data, deadlineCts.Token);
                        batchBytes += data.Length;
                    }

                    await _stream.FlushAsync(deadlineCts.Token);
                }
                catch (OperationCanceledException) when (!ct.IsCancellationRequested)
                {
                    await CloseAsSlowConsumerAsync(ClientClosedReason.SlowConsumerWriteDeadline);
                    return;
                }

                Interlocked.Add(ref _pendingBytes, -batchBytes);
            }
        }
        catch (OperationCanceledException)
        {
            // Normal shutdown
        }
        catch (IOException)
        {
            await CloseWithReasonAsync(ClientClosedReason.WriteError);
        }
    }

    public async Task SendErrAndCloseAsync(string message, ClientClosedReason reason = ClientClosedReason.ProtocolViolation)
    {
        await CloseWithReasonAsync(reason, message);
    }

    /// <summary>Atomically claims the close path. Returns <c>true</c> only for the first caller.</summary>
    private bool TryBeginClose() => Interlocked.Exchange(ref _closeStarted, 1) == 0;

    /// <summary>
    /// Closes the connection with <paramref name="reason"/>, optionally sending
    /// <paramref name="errMessage"/> first. Only the first close request has any effect.
    /// </summary>
    private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null)
    {
        if (!TryBeginClose())
            return;

        await CompleteCloseAsync(reason, errMessage);
    }

    /// <summary>
    /// Records a slow consumer (flag + server stats) and closes the connection. This happens only
    /// if no close is already in progress, so repeated failures are counted once per client.
    /// </summary>
    private async Task CloseAsSlowConsumerAsync(ClientClosedReason reason)
    {
        if (!TryBeginClose())
            return;

        _flags.SetFlag(ClientFlags.IsSlowConsumer);
        Interlocked.Increment(ref _serverStats.SlowConsumers);
        Interlocked.Increment(ref _serverStats.SlowConsumerClients);
        await CompleteCloseAsync(reason, NatsProtocol.ErrSlowConsumer);
    }

    /// <summary>
    /// Shared tail of the close path. It records the reason, queues the final error (bypassing the
    /// outbound limits), stops accepting output, then cancels the connection after a short drain window.
    /// </summary>
    private async Task CompleteCloseAsync(ClientClosedReason reason, string? errMessage)
    {
        // Goes through MarkClosed so that skip-flush is set for error reasons.
        MarkClosed(reason);

        // Stop regular output first. The final -ERR is written straight to the channel: sending it
        // through QueueOutbound could fail again (channel full / over MaxPending) and re-enter this path.
        _flags.SetFlag(ClientFlags.CloseConnection);
        if (errMessage != null)
            EnqueueFinal(FormatErr(errMessage));

        // Complete the outbound channel so the write loop drains remaining data
        _outbound.Writer.TryComplete();

        // Give the write loop a short window to flush the final batch before canceling
        await Task.Delay(50);

        try
        {
            if (_clientCts is { } cts)
                await cts.CancelAsync();
            else
                _socket.Close();
        }
        catch (ObjectDisposedException)
        {
            // The client was disposed during the drain window; nothing left to cancel.
        }
    }

    /// <summary>Best-effort enqueue that ignores MaxPending and drops the data if the channel is full or completed.</summary>
    private void EnqueueFinal(ReadOnlyMemory<byte> data)
    {
        Interlocked.Add(ref _pendingBytes, data.Length);
        if (!_outbound.Writer.TryWrite(data))
            Interlocked.Add(ref _pendingBytes, -data.Length);
    }

    /// <summary>
    /// Sends PINGs every <c>PingInterval</c> while the client is idle, and closes the connection as
    /// stale once more than <c>MaxPingsOut</c> PINGs are unanswered.
    /// </summary>
    private async Task RunPingTimerAsync(CancellationToken ct)
    {
        using var timer = new PeriodicTimer(_options.PingInterval);
        try
        {
            while (await timer.WaitForNextTickAsync(ct))
            {
                // Delay first PING until client has responded with PONG or 2 seconds elapsed
                if (!_flags.HasFlag(ClientFlags.FirstPongSent)
                    && (DateTime.UtcNow - StartTime).TotalSeconds < 2)
                {
                    continue;
                }

                var elapsed = Environment.TickCount64 - Interlocked.Read(ref _lastIn);
                if (elapsed < (long)_options.PingInterval.TotalMilliseconds)
                {
                    // Client was recently active, skip ping
                    Interlocked.Exchange(ref _pingsOut, 0);
                    continue;
                }

                if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut)
                {
                    _logger.LogDebug("Client {ClientId} stale connection -- closing", Id);
                    Interlocked.Increment(ref _serverStats.StaleConnections);
                    Interlocked.Increment(ref _serverStats.StaleConnectionClients);
                    await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection, ClientClosedReason.StaleConnection);
                    return;
                }

                var currentPingsOut = Interlocked.Increment(ref _pingsOut);
                _logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})",
                    Id, currentPingsOut, _options.MaxPingsOut);
                Interlocked.Exchange(ref _rttStartTicks, DateTime.UtcNow.Ticks);
                WriteProtocol(NatsProtocol.PingBytes);
            }
        }
        catch (OperationCanceledException)
        {
            // Normal shutdown
        }
    }

    /// <summary>
    /// Marks this connection as closed with the given reason.
    /// Sets skip-flush flag for error-related reasons.
    /// Only the first call sets the reason (subsequent calls are no-ops).
    /// </summary>
    public void MarkClosed(ClientClosedReason reason)
    {
        if (CloseReason != ClientClosedReason.None)
            return;

        CloseReason = reason;

        switch (reason)
        {
            case ClientClosedReason.ReadError:
            case ClientClosedReason.WriteError:
            case ClientClosedReason.SlowConsumerPendingBytes:
            case ClientClosedReason.SlowConsumerWriteDeadline:
            case ClientClosedReason.TlsHandshakeError:
                Volatile.Write(ref _skipFlushOnClose, 1);
                break;
        }

        _logger.LogDebug("Client {ClientId} connection closed: {CloseReason}", Id, reason);
    }

    /// <summary>
    /// Flushes pending data (unless skip-flush is set) and closes the connection.
    /// </summary>
    public async Task FlushAndCloseAsync(bool minimalFlush = false)
    {
        if (!ShouldSkipFlush)
        {
            try
            {
                using var flushCts = new CancellationTokenSource(minimalFlush
                    ? TimeSpan.FromMilliseconds(100)
                    : TimeSpan.FromSeconds(1));
                await _stream.FlushAsync(flushCts.Token);
            }
            catch (Exception)
            {
                // Best effort flush — don't let it prevent close
            }
        }

        try { _socket.Shutdown(SocketShutdown.Both); }
        catch (SocketException) { }
        catch (ObjectDisposedException) { }
    }

    public void RemoveSubscription(string sid)
    {
        if (_subs.Remove(sid))
            Account?.DecrementSubscriptions();
    }

    /// <summary>Removes every subscription of this client from <paramref name="subList"/> and clears the local map.</summary>
    public void RemoveAllSubscriptions(SubList subList)
    {
        // NOTE(review): unlike RemoveSubscription/ProcessUnsub this does not call
        // Account?.DecrementSubscriptions(). Confirm the caller releases the account's
        // subscription count; otherwise it leaks on every disconnect.
        foreach (var sub in _subs.Values)
            subList.Remove(sub);
        _subs.Clear();
    }

    public void Dispose()
    {
        // Block any later close attempt from touching the disposed CTS or counting a slow consumer.
        Interlocked.Exchange(ref _closeStarted, 1);
        _flags.SetFlag(ClientFlags.CloseConnection);
        _permissions?.Dispose();
        _outbound.Writer.TryComplete();
        _clientCts?.Dispose();
        _stream.Dispose();
        _socket.Dispose();
    }
}