diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 1ecbc56..fba081f 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -48,6 +48,7 @@ public sealed class NatsClient : IDisposable public ClientOptions? ClientOpts { get; private set; } public IMessageRouter? Router { get; set; } public Account? Account { get; private set; } + public ClientPermissions? Permissions => _permissions; private readonly ClientFlagHolder _flags = new(); public bool ConnectReceived => _flags.HasFlag(ClientFlags.ConnectReceived); @@ -90,7 +91,7 @@ public sealed class NatsClient : IDisposable _nonce = nonce; _logger = logger; _serverStats = serverStats; - _parser = new NatsParser(options.MaxPayload); + _parser = new NatsParser(options.MaxPayload, options.Trace ? logger : null); StartTime = DateTime.UtcNow; _lastActivityTicks = StartTime.Ticks; if (socket.RemoteEndPoint is IPEndPoint ep) @@ -348,6 +349,7 @@ public sealed class NatsClient : IDisposable ?? new ClientOptions(); // Authenticate if auth is required + AuthResult? authResult = null; if (_authService.IsAuthRequired) { var context = new ClientAuthContext @@ -356,8 +358,8 @@ public sealed class NatsClient : IDisposable Nonce = _nonce ?? [], }; - var result = _authService.Authenticate(context); - if (result == null) + authResult = _authService.Authenticate(context); + if (authResult == null) { _logger.LogWarning("Client {ClientId} authentication failed", Id); await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation, ClientClosedReason.AuthenticationViolation); @@ -365,12 +367,12 @@ public sealed class NatsClient : IDisposable } // Build permissions from auth result - _permissions = ClientPermissions.Build(result.Permissions); + _permissions = ClientPermissions.Build(authResult.Permissions); // Resolve account if (Router is NatsServer server) { - var accountName = result.AccountName ?? Account.GlobalAccountName; + var accountName = authResult.AccountName ?? Account.GlobalAccountName; Account = server.GetOrCreateAccount(accountName); if (!Account.AddClient(Id)) { @@ -381,7 +383,7 @@ public sealed class NatsClient : IDisposable } } - _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity); + _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, authResult.Identity); // Clear nonce after use -- defense-in-depth against memory dumps if (_nonce != null) @@ -413,6 +415,32 @@ public sealed class NatsClient : IDisposable _flags.SetFlag(ClientFlags.ConnectReceived); _flags.SetFlag(ClientFlags.ConnectProcessFinished); _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); + + // Start auth expiry timer if needed + if (_authService.IsAuthRequired && authResult?.Expiry is { } expiry) + { + var remaining = expiry - DateTimeOffset.UtcNow; + if (remaining > TimeSpan.Zero) + { + _ = Task.Run(async () => + { + try + { + await Task.Delay(remaining, _clientCts!.Token); + _logger.LogDebug("Client {ClientId} authentication expired", Id); + await SendErrAndCloseAsync("Authentication Expired", + ClientClosedReason.AuthenticationExpired); + } + catch (OperationCanceledException) { } + }, _clientCts!.Token); + } + else + { + await SendErrAndCloseAsync("Authentication Expired", + ClientClosedReason.AuthenticationExpired); + return; + } + } } private void ProcessSub(ParsedCommand cmd) @@ -425,6 +453,24 @@ public sealed class NatsClient : IDisposable return; } + // Per-connection subscription limit + if (_options.MaxSubs > 0 && _subs.Count >= _options.MaxSubs) + { + _logger.LogDebug("Client {ClientId} max subscriptions exceeded", Id); + _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxSubscriptionsExceeded, + ClientClosedReason.MaxSubscriptionsExceeded); + return; + } + + // Per-account subscription limit + if (Account != null && !Account.IncrementSubscriptions()) + { + _logger.LogDebug("Client {ClientId} account subscription limit exceeded", Id); + _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxSubscriptionsExceeded, + ClientClosedReason.MaxSubscriptionsExceeded); + return; + } + var sub = new Subscription { Subject = cmd.Subject!, @@ -455,6 +501,7 @@ public sealed class NatsClient : IDisposable } _subs.Remove(cmd.Sid!); + Account?.DecrementSubscriptions(); Account?.SubList.Remove(sub); } @@ -701,6 +748,12 @@ public sealed class NatsClient : IDisposable catch (ObjectDisposedException) { } } + public void RemoveSubscription(string sid) + { + if (_subs.Remove(sid)) + Account?.DecrementSubscriptions(); + } + public void RemoveAllSubscriptions(SubList subList) { foreach (var sub in _subs.Values) diff --git a/src/NATS.Server/Protocol/NatsParser.cs b/src/NATS.Server/Protocol/NatsParser.cs index 2689ec0..b9debdd 100644 --- a/src/NATS.Server/Protocol/NatsParser.cs +++ b/src/NATS.Server/Protocol/NatsParser.cs @@ -1,5 +1,6 @@ using System.Buffers; using System.Text; +using Microsoft.Extensions.Logging; namespace NATS.Server.Protocol; @@ -35,6 +36,7 @@ public sealed class NatsParser { private static readonly byte[] CrLfBytes = "\r\n"u8.ToArray(); private readonly int _maxPayload; + private readonly ILogger? _logger; // State for split-packet payload reading private bool _awaitingPayload; @@ -44,9 +46,20 @@ public sealed class NatsParser private int _pendingHeaderSize; private CommandType _pendingType; - public NatsParser(int maxPayload = NatsProtocol.MaxPayloadSize) + public NatsParser(int maxPayload = NatsProtocol.MaxPayloadSize, ILogger? logger = null) { _maxPayload = maxPayload; + _logger = logger; + } + + private void TraceInOp(string op, ReadOnlySpan arg = default) + { + if (_logger == null || !_logger.IsEnabled(LogLevel.Trace)) + return; + if (arg.IsEmpty) + _logger.LogTrace("<<- {Op}", op); + else + _logger.LogTrace("<<- {Op} {Arg}", op, Encoding.ASCII.GetString(arg)); } public bool TryParse(ref ReadOnlySequence buffer, out ParsedCommand command) @@ -91,6 +104,7 @@ public sealed class NatsParser { command = ParsedCommand.Simple(CommandType.Ping); buffer = buffer.Slice(reader.Position); + TraceInOp("PING"); return true; } @@ -98,6 +112,7 @@ public sealed class NatsParser { command = ParsedCommand.Simple(CommandType.Pong); buffer = buffer.Slice(reader.Position); + TraceInOp("PONG"); return true; } @@ -121,6 +136,7 @@ public sealed class NatsParser { command = ParseSub(lineSpan); buffer = buffer.Slice(reader.Position); + TraceInOp("SUB", lineSpan[4..]); return true; } @@ -131,6 +147,7 @@ public sealed class NatsParser { command = ParseUnsub(lineSpan); buffer = buffer.Slice(reader.Position); + TraceInOp("UNSUB", lineSpan[6..]); return true; } @@ -141,6 +158,7 @@ public sealed class NatsParser { command = ParseConnect(lineSpan); buffer = buffer.Slice(reader.Position); + TraceInOp("CONNECT"); return true; } @@ -151,6 +169,7 @@ public sealed class NatsParser { command = ParseInfo(lineSpan); buffer = buffer.Slice(reader.Position); + TraceInOp("INFO"); return true; } @@ -159,11 +178,13 @@ public sealed class NatsParser case (byte)'+': // +OK command = ParsedCommand.Simple(CommandType.Ok); buffer = buffer.Slice(reader.Position); + TraceInOp("+OK"); return true; case (byte)'-': // -ERR command = ParsedCommand.Simple(CommandType.Err); buffer = buffer.Slice(reader.Position); + TraceInOp("-ERR"); return true; } @@ -215,6 +236,7 @@ public sealed class NatsParser _pendingHeaderSize = -1; _pendingType = CommandType.Pub; + TraceInOp("PUB", argsSpan); return TryReadPayload(ref buffer, out command); } @@ -264,6 +286,7 @@ public sealed class NatsParser _pendingHeaderSize = hdrSize; _pendingType = CommandType.HPub; + TraceInOp("HPUB", argsSpan); return TryReadPayload(ref buffer, out command); }