refactor: replace _connectReceived with ClientFlagHolder and add CloseReason tracking

This commit is contained in:
Joseph Doherty
2026-02-22 23:35:35 -05:00
parent 61c6b832e5
commit ad6a02b9a2

View File

@@ -46,9 +46,9 @@ public sealed class NatsClient : IDisposable
public IMessageRouter? Router { get; set; } public IMessageRouter? Router { get; set; }
public Account? Account { get; private set; } public Account? Account { get; private set; }
// Thread-safe: read from auth timeout task on threadpool, written from command pipeline private readonly ClientFlagHolder _flags = new();
private int _connectReceived; public bool ConnectReceived => _flags.HasFlag(ClientFlags.ConnectReceived);
public bool ConnectReceived => Volatile.Read(ref _connectReceived) != 0; public ClientClosedReason CloseReason { get; private set; }
public DateTime StartTime { get; } public DateTime StartTime { get; }
private long _lastActivityTicks; private long _lastActivityTicks;
@@ -116,7 +116,7 @@ public sealed class NatsClient : IDisposable
if (!ConnectReceived) if (!ConnectReceived)
{ {
_logger.LogDebug("Client {ClientId} auth timeout", Id); _logger.LogDebug("Client {ClientId} auth timeout", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout); await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout, ClientClosedReason.AuthenticationTimeout);
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)
@@ -272,7 +272,7 @@ public sealed class NatsClient : IDisposable
if (result == null) if (result == null)
{ {
_logger.LogWarning("Client {ClientId} authentication failed", Id); _logger.LogWarning("Client {ClientId} authentication failed", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation); await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation, ClientClosedReason.AuthenticationViolation);
return; return;
} }
@@ -301,7 +301,8 @@ public sealed class NatsClient : IDisposable
Account.AddClient(Id); Account.AddClient(Id);
} }
Volatile.Write(ref _connectReceived, 1); _flags.SetFlag(ClientFlags.ConnectReceived);
_flags.SetFlag(ClientFlags.ConnectProcessFinished);
_logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);
} }
@@ -361,7 +362,7 @@ public sealed class NatsClient : IDisposable
{ {
_logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}", _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}",
Id, cmd.Payload.Length, _options.MaxPayload); Id, cmd.Payload.Length, _options.MaxPayload);
await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation); await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded);
return; return;
} }
@@ -471,9 +472,17 @@ public sealed class NatsClient : IDisposable
} }
} }
public async Task SendErrAndCloseAsync(string message) public async Task SendErrAndCloseAsync(string message, ClientClosedReason reason = ClientClosedReason.ProtocolViolation)
{ {
await SendErrAsync(message); await CloseWithReasonAsync(reason, message);
}
private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null)
{
CloseReason = reason;
_flags.SetFlag(ClientFlags.CloseConnection);
if (errMessage != null)
await SendErrAsync(errMessage);
if (_clientCts is { } cts) if (_clientCts is { } cts)
await cts.CancelAsync(); await cts.CancelAsync();
else else
@@ -498,7 +507,7 @@ public sealed class NatsClient : IDisposable
if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut) if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut)
{ {
_logger.LogDebug("Client {ClientId} stale connection -- closing", Id); _logger.LogDebug("Client {ClientId} stale connection -- closing", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection); await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection, ClientClosedReason.StaleConnection);
return; return;
} }