diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 69cde4c..e81e450 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -46,9 +46,9 @@ public sealed class NatsClient : IDisposable public IMessageRouter? Router { get; set; } public Account? Account { get; private set; } - // Thread-safe: read from auth timeout task on threadpool, written from command pipeline - private int _connectReceived; - public bool ConnectReceived => Volatile.Read(ref _connectReceived) != 0; + private readonly ClientFlagHolder _flags = new(); + public bool ConnectReceived => _flags.HasFlag(ClientFlags.ConnectReceived); + public ClientClosedReason CloseReason { get; private set; } public DateTime StartTime { get; } private long _lastActivityTicks; @@ -116,7 +116,7 @@ public sealed class NatsClient : IDisposable if (!ConnectReceived) { _logger.LogDebug("Client {ClientId} auth timeout", Id); - await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout); + await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout, ClientClosedReason.AuthenticationTimeout); } } catch (OperationCanceledException) @@ -272,7 +272,7 @@ public sealed class NatsClient : IDisposable if (result == null) { _logger.LogWarning("Client {ClientId} authentication failed", Id); - await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation); + await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation, ClientClosedReason.AuthenticationViolation); return; } @@ -301,7 +301,8 @@ public sealed class NatsClient : IDisposable Account.AddClient(Id); } - Volatile.Write(ref _connectReceived, 1); + _flags.SetFlag(ClientFlags.ConnectReceived); + _flags.SetFlag(ClientFlags.ConnectProcessFinished); _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); } @@ -361,7 +362,7 @@ public sealed class NatsClient : IDisposable { _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}", Id, cmd.Payload.Length, _options.MaxPayload); - await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation); + await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded); return; } @@ -471,9 +472,17 @@ public sealed class NatsClient : IDisposable } } - public async Task SendErrAndCloseAsync(string message) + public async Task SendErrAndCloseAsync(string message, ClientClosedReason reason = ClientClosedReason.ProtocolViolation) { - await SendErrAsync(message); + await CloseWithReasonAsync(reason, message); + } + + private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null) + { + CloseReason = reason; + _flags.SetFlag(ClientFlags.CloseConnection); + if (errMessage != null) + await SendErrAsync(errMessage); if (_clientCts is { } cts) await cts.CancelAsync(); else @@ -498,7 +507,7 @@ public sealed class NatsClient : IDisposable if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut) { _logger.LogDebug("Client {ClientId} stale connection -- closing", Id); - await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection); + await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection, ClientClosedReason.StaleConnection); return; }