// Source: natsdotnet/src/NATS.Server/NatsClient.cs (720 lines, 25 KiB, C#)

using System.Buffers;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
namespace NATS.Server;
/// <summary>
/// Receives the messages a <see cref="NatsClient"/> publishes and is told when that client goes away.
/// </summary>
public interface IMessageRouter
{
    /// <summary>
    /// Routes one published message. <see cref="NatsClient"/> calls this for each PUB/HPUB that passed
    /// its payload-size, pedantic-subject and publish-permission checks.
    /// </summary>
    /// <param name="subject">Subject the message was published on.</param>
    /// <param name="replyTo">Optional reply subject, or <c>null</c>.</param>
    /// <param name="headers">Header bytes for HPUB; empty for a plain PUB.</param>
    /// <param name="payload">Message body (headers excluded).</param>
    /// <param name="sender">The client that published the message.</param>
    void ProcessMessage(
        string subject,
        string? replyTo,
        ReadOnlyMemory<byte> headers,
        ReadOnlyMemory<byte> payload,
        NatsClient sender);

    /// <summary>Called by <see cref="NatsClient"/> when its connection loop has finished.</summary>
    void RemoveClient(NatsClient client);
}

/// <summary>
/// Exposes a <see cref="SubList"/>. (Its consumers are not visible in this file.)
/// </summary>
public interface ISubListAccess
{
    /// <summary>The subscription list exposed by the implementer.</summary>
    SubList SubList { get; }
}
public sealed class NatsClient : IDisposable
{
    // Upper bound on how long RunAsync waits for its background loops to observe cancellation after
    // the connection has been torn down; guards against a stream that ignores cancellation.
    private static readonly TimeSpan BackgroundShutdownTimeout = TimeSpan.FromSeconds(5);

    private readonly Socket _socket;
    private readonly Stream _stream;
    private readonly NatsOptions _options;
    private readonly ServerInfo _serverInfo;
    private readonly AuthService _authService;
    private readonly byte[]? _nonce;
    private readonly NatsParser _parser;

    // Outbound protocol frames drained by RunWriteLoopAsync. Producers use TryWrite, so a full
    // channel is reported back to QueueOutbound instead of being waited on.
    private readonly Channel<ReadOnlyMemory<byte>> _outbound = Channel.CreateBounded<ReadOnlyMemory<byte>>(
        new BoundedChannelOptions(8192) { SingleReader = true, FullMode = BoundedChannelFullMode.Wait });

    // Bytes accepted by QueueOutbound that the write loop has not yet flushed.
    private long _pendingBytes;
    private CancellationTokenSource? _clientCts;
    private readonly Dictionary<string, Subscription> _subs = new();
    private readonly ILogger _logger;
    private ClientPermissions? _permissions;
    private readonly ServerStats _serverStats;

    public ulong Id { get; }
    public ClientOptions? ClientOpts { get; private set; }
    public IMessageRouter? Router { get; set; }
    public Account? Account { get; private set; }

    private readonly ClientFlagHolder _flags = new();
    public bool ConnectReceived => _flags.HasFlag(ClientFlags.ConnectReceived);
    public ClientClosedReason CloseReason { get; private set; }
    public DateTime StartTime { get; }
    private long _lastActivityTicks;
    public DateTime LastActivity => new(Interlocked.Read(ref _lastActivityTicks), DateTimeKind.Utc);
    public string? RemoteIp { get; }
    public int RemotePort { get; }

    // Stats
    public long InMsgs;
    public long OutMsgs;
    public long InBytes;
    public long OutBytes;

    // Close reason tracking
    private int _skipFlushOnClose;
    public bool ShouldSkipFlush => Volatile.Read(ref _skipFlushOnClose) != 0;

    // Flipped 0 -> 1 by the first CloseWithReasonAsync call; every later call is a no-op.
    private int _closeStarted;

    // PING keepalive state
    private int _pingsOut;
    private long _lastIn;

    public TlsConnectionState? TlsState { get; set; }
    public bool InfoAlreadySent { get; set; }
    public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;

    public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,
        AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats)
    {
        Id = id;
        _socket = socket;
        _stream = stream;
        _options = options;
        _serverInfo = serverInfo;
        _authService = authService;
        _nonce = nonce;
        _logger = logger;
        _serverStats = serverStats;
        _parser = new NatsParser(options.MaxPayload);
        StartTime = DateTime.UtcNow;
        _lastActivityTicks = StartTime.Ticks;
        if (socket.RemoteEndPoint is IPEndPoint ep)
        {
            RemoteIp = ep.Address.ToString();
            RemotePort = ep.Port;
        }
    }

    /// <summary>
    /// Queues <paramref name="data"/> for the write loop.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the data was queued. <c>false</c> if the connection is closing, or if the data would
    /// push pending bytes past <c>MaxPending</c> or the outbound channel is full. In those last two
    /// cases the client is flagged as a slow consumer and closed.
    /// </returns>
    public bool QueueOutbound(ReadOnlyMemory<byte> data)
    {
        if (_flags.HasFlag(ClientFlags.CloseConnection))
            return false;

        var pending = Interlocked.Add(ref _pendingBytes, data.Length);
        if (pending <= _options.MaxPending && _outbound.Writer.TryWrite(data))
            return true;

        // Not queued: release the reservation before reporting the slow consumer.
        Interlocked.Add(ref _pendingBytes, -data.Length);
        HandleSlowConsumerPending();
        return false;
    }

    public long PendingBytes => Interlocked.Read(ref _pendingBytes);

    /// <summary>
    /// Runs the connection until it ends. Sends INFO (unless already sent during TLS negotiation),
    /// starts the auth timeout when auth is required, and runs the read, command-processing, PING and
    /// write loops. As soon as one of those loops finishes, the connection is torn down, every
    /// background task is cancelled and awaited (bounded), and the client is removed from its router.
    /// </summary>
    public async Task RunAsync(CancellationToken ct)
    {
        var clientCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        _clientCts = clientCts;
        var token = clientCts.Token;
        Interlocked.Exchange(ref _lastIn, Environment.TickCount64);
        var pipe = new Pipe();

        // Every task started below; all are cancelled and awaited before RunAsync returns.
        var background = new List<Task>(5);
        try
        {
            // Send INFO (skip if already sent during TLS negotiation)
            if (!InfoAlreadySent)
                SendInfo();

            // The auth timeout is kept out of the set that decides "connection ended". Once the timeout
            // elapses it completes normally even for an authenticated client, and treating that as the
            // end of the connection would disconnect every authenticated client.
            if (_authService.IsAuthRequired)
                background.Add(RunAuthTimeoutAsync(token));

            Task[] loops =
            [
                FillPipeAsync(pipe.Writer, token),
                ProcessCommandsAsync(pipe.Reader, token),
                RunPingTimerAsync(token),
                RunWriteLoopAsync(token),
            ];
            background.AddRange(loops);

            // WhenAny never throws for a faulted task; awaiting the finished loop rethrows its
            // exception so the handlers below can record the right close reason.
            var finished = await Task.WhenAny(loops);
            await finished;
        }
        catch (OperationCanceledException)
        {
            _logger.LogDebug("Client {ClientId} operation cancelled", Id);
            MarkClosed(ClientClosedReason.ServerShutdown);
        }
        catch (IOException)
        {
            MarkClosed(ClientClosedReason.ReadError);
        }
        catch (SocketException)
        {
            MarkClosed(ClientClosedReason.ReadError);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Client {ClientId} connection error", Id);
            MarkClosed(ClientClosedReason.ReadError);
        }
        finally
        {
            MarkClosed(ClientClosedReason.ClientClosed);
            // Reject further writes (router deliveries, PINGs) so they are dropped quietly instead of
            // failing TryWrite on the completed channel and being counted as a slow consumer.
            _flags.SetFlag(ClientFlags.CloseConnection);
            _outbound.Writer.TryComplete();
            try { await clientCts.CancelAsync(); }
            catch (ObjectDisposedException) { }
            try { _socket.Shutdown(SocketShutdown.Both); }
            catch (SocketException) { }
            catch (ObjectDisposedException) { }
            await AwaitBackgroundTasksAsync(background);
            Router?.RemoveClient(this);
        }
    }

    /// <summary>
    /// Closes the client with <see cref="ClientClosedReason.AuthenticationTimeout"/> if no CONNECT has
    /// been accepted once <c>AuthTimeout</c> has elapsed. Completes quietly when cancelled.
    /// </summary>
    private async Task RunAuthTimeoutAsync(CancellationToken ct)
    {
        try
        {
            await Task.Delay(_options.AuthTimeout, ct);
            if (!ConnectReceived)
            {
                _logger.LogDebug("Client {ClientId} auth timeout", Id);
                await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout, ClientClosedReason.AuthenticationTimeout);
            }
        }
        catch (OperationCanceledException)
        {
            // Normal -- client connected or was cancelled
        }
    }

    /// <summary>
    /// Waits, for at most <see cref="BackgroundShutdownTimeout"/>, until the client's background tasks
    /// have finished after cancellation. Faults are logged rather than rethrown: they are expected
    /// fallout of the teardown, and the close reason has already been recorded.
    /// </summary>
    private async Task AwaitBackgroundTasksAsync(List<Task> tasks)
    {
        if (tasks.Count == 0)
            return;
        try
        {
            await Task.WhenAll(tasks).WaitAsync(BackgroundShutdownTimeout);
        }
        catch (TimeoutException)
        {
            _logger.LogDebug("Client {ClientId} background tasks did not stop within {Timeout}",
                Id, BackgroundShutdownTimeout);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Client {ClientId} background task ended with an error during shutdown", Id);
        }
    }

    /// <summary>
    /// Copies bytes from the network stream into <paramref name="writer"/> until EOF, cancellation, or
    /// the reader completes. The writer is always completed on exit.
    /// </summary>
    private async Task FillPipeAsync(PipeWriter writer, CancellationToken ct)
    {
        try
        {
            while (!ct.IsCancellationRequested)
            {
                var memory = writer.GetMemory(4096);
                int bytesRead = await _stream.ReadAsync(memory, ct);
                if (bytesRead == 0)
                    break;
                writer.Advance(bytesRead);
                var result = await writer.FlushAsync(ct);
                if (result.IsCompleted)
                    break;
            }
        }
        finally
        {
            await writer.CompleteAsync();
        }
    }

    /// <summary>
    /// Parses protocol commands from <paramref name="reader"/> and executes them. PUB/HPUB are handled
    /// inline so their message and byte counts can be batched once per read. All other commands go
    /// through <see cref="DispatchCommandAsync"/>.
    /// </summary>
    private async Task ProcessCommandsAsync(PipeReader reader, CancellationToken ct)
    {
        try
        {
            while (!ct.IsCancellationRequested)
            {
                var result = await reader.ReadAsync(ct);
                var buffer = result.Buffer;
                long localInMsgs = 0;
                long localInBytes = 0;
                while (_parser.TryParse(ref buffer, out var cmd))
                {
                    Interlocked.Exchange(ref _lastIn, Environment.TickCount64);

                    // Once a close has been initiated (e.g. a max-payload violation), keep consuming the
                    // buffered commands but do not act on them. The loop itself keeps running so the
                    // write loop still gets its window to flush the final -ERR.
                    if (_flags.HasFlag(ClientFlags.CloseConnection))
                        continue;

                    // Handle Pub/HPub inline to allow ref parameter passing for stat batching.
                    // DispatchCommandAsync is async and cannot accept ref parameters.
                    if (cmd.Type is CommandType.Pub or CommandType.HPub
                        && (!_authService.IsAuthRequired || ConnectReceived))
                    {
                        Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
                        ProcessPub(cmd, ref localInMsgs, ref localInBytes);
                        if (ClientOpts?.Verbose == true)
                            WriteProtocol(NatsProtocol.OkBytes);
                    }
                    else
                    {
                        await DispatchCommandAsync(cmd, ct);
                    }
                }

                if (localInMsgs > 0)
                {
                    Interlocked.Add(ref InMsgs, localInMsgs);
                    Interlocked.Add(ref _serverStats.InMsgs, localInMsgs);
                }
                if (localInBytes > 0)
                {
                    Interlocked.Add(ref InBytes, localInBytes);
                    Interlocked.Add(ref _serverStats.InBytes, localInBytes);
                }

                reader.AdvanceTo(buffer.Start, buffer.End);
                if (result.IsCompleted)
                    break;
            }
        }
        finally
        {
            await reader.CompleteAsync();
        }
    }

    /// <summary>
    /// Executes one non-PUB command. CONNECT is accepted in every state. Until a CONNECT has been
    /// accepted on an auth-required server, only PING is answered and everything else is ignored.
    /// </summary>
    private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
    {
        Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);

        // Acknowledge a CONNECT only if it was accepted, both with and without auth. This means a
        // verbose authenticated client gets its +OK, and a rejected CONNECT is not followed by +OK.
        if (cmd.Type == CommandType.Connect)
        {
            if (await ProcessConnectAsync(cmd) && ClientOpts?.Verbose == true)
                WriteProtocol(NatsProtocol.OkBytes);
            return;
        }

        if (_authService.IsAuthRequired && !ConnectReceived)
        {
            if (cmd.Type == CommandType.Ping)
                WriteProtocol(NatsProtocol.PongBytes);
            // Ignore all other commands until authenticated
            return;
        }

        switch (cmd.Type)
        {
            case CommandType.Ping:
                WriteProtocol(NatsProtocol.PongBytes);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;
            case CommandType.Pong:
                Interlocked.Exchange(ref _pingsOut, 0);
                break;
            case CommandType.Sub:
                ProcessSub(cmd);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;
            case CommandType.Unsub:
                ProcessUnsub(cmd);
                if (ClientOpts?.Verbose == true)
                    WriteProtocol(NatsProtocol.OkBytes);
                break;
            case CommandType.Pub:
            case CommandType.HPub:
                // Pub/HPub is handled inline in ProcessCommandsAsync for stat batching
                break;
        }
    }

    /// <summary>
    /// Applies a CONNECT. Deserializes the client options, authenticates when auth is required, attaches
    /// the client to an account when the router is a <see cref="NatsServer"/>, enforces that
    /// no_responders requires headers, and finally marks the connection as connected.
    /// </summary>
    /// <returns><c>true</c> if the CONNECT was accepted; <c>false</c> if the client is being closed.</returns>
    private async ValueTask<bool> ProcessConnectAsync(ParsedCommand cmd)
    {
        var opts = JsonSerializer.Deserialize<ClientOptions>(cmd.Payload.Span) ?? new ClientOptions();
        ClientOpts = opts;

        if (_authService.IsAuthRequired)
        {
            var result = _authService.Authenticate(new ClientAuthContext
            {
                Opts = opts,
                Nonce = _nonce ?? [],
            });
            if (result == null)
            {
                _logger.LogWarning("Client {ClientId} authentication failed", Id);
                await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation, ClientClosedReason.AuthenticationViolation);
                return false;
            }

            // Build permissions from auth result
            _permissions = ClientPermissions.Build(result.Permissions);

            // Resolve the account named by the auth result (global account when none)
            if (Router is NatsServer server
                && !await TryJoinAccountAsync(server, result.AccountName ?? Account.GlobalAccountName))
                return false;

            _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity);

            // Clear nonce after use -- defense-in-depth against memory dumps
            if (_nonce != null)
                CryptographicOperations.ZeroMemory(_nonce);
        }

        // If no account was assigned by auth, assign global account
        if (Account == null && Router is NatsServer fallbackServer
            && !await TryJoinAccountAsync(fallbackServer, Account.GlobalAccountName))
            return false;

        // Validate no_responders requires headers
        if (opts.NoResponders && !opts.Headers)
        {
            _logger.LogDebug("Client {ClientId} no_responders requires headers", Id);
            await CloseWithReasonAsync(ClientClosedReason.NoRespondersRequiresHeaders,
                NatsProtocol.ErrNoRespondersRequiresHeaders);
            return false;
        }

        _flags.SetFlag(ClientFlags.ConnectReceived);
        _flags.SetFlag(ClientFlags.ConnectProcessFinished);
        _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, opts.Name);
        return true;
    }

    /// <summary>
    /// Registers this client with the account named <paramref name="accountName"/> on
    /// <paramref name="server"/>. If the account rejects the client (<c>AddClient</c> returns false),
    /// <see cref="Account"/> is cleared and the client is closed with an error.
    /// </summary>
    /// <returns><c>true</c> if the client joined the account.</returns>
    private async ValueTask<bool> TryJoinAccountAsync(NatsServer server, string accountName)
    {
        var account = server.GetOrCreateAccount(accountName);
        if (account.AddClient(Id))
        {
            Account = account;
            return true;
        }

        Account = null;
        await SendErrAndCloseAsync("maximum connections for account exceeded",
            ClientClosedReason.AuthenticationViolation);
        return false;
    }

    /// <summary>
    /// Handles SUB. It checks subscribe permission, then records the subscription and inserts it into
    /// the account's SubList. If the sid is already in use, the existing subscription is kept (as
    /// nats-server does). Overwriting only the map entry would leave the old subscription orphaned in
    /// the SubList, still receiving messages and impossible to UNSUB.
    /// </summary>
    private void ProcessSub(ParsedCommand cmd)
    {
        var subject = cmd.Subject!;
        var sid = cmd.Sid!;

        // Permission check for subscribe
        if (_permissions != null && !_permissions.IsSubscribeAllowed(subject, cmd.Queue))
        {
            _logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject);
            SendErr(NatsProtocol.ErrPermissionsSubscribe);
            return;
        }

        var sub = new Subscription
        {
            Subject = subject,
            Queue = cmd.Queue,
            Sid = sid,
        };
        sub.Client = this;

        if (!_subs.TryAdd(sid, sub))
        {
            _logger.LogDebug("Client {ClientId} SUB reused sid {Sid}; keeping existing subscription", Id, sid);
            return;
        }

        _logger.LogDebug("SUB {Subject} {Sid} from client {ClientId}", cmd.Subject, cmd.Sid, Id);
        Account?.SubList.Insert(sub);
    }

    /// <summary>
    /// Handles UNSUB. With a positive max-message count, only the limit is recorded here; the removal
    /// itself is not performed in this method (presumably done by the delivery path when the limit is
    /// reached -- verify). Otherwise the subscription is removed immediately.
    /// </summary>
    private void ProcessUnsub(ParsedCommand cmd)
    {
        _logger.LogDebug("UNSUB {Sid} from client {ClientId}", cmd.Sid, Id);
        if (!_subs.TryGetValue(cmd.Sid!, out var sub))
            return;

        if (cmd.MaxMessages > 0)
        {
            sub.MaxMessages = cmd.MaxMessages;
            // Will be cleaned up when MessageCount reaches MaxMessages
            return;
        }

        _subs.Remove(cmd.Sid!);
        Account?.SubList.Remove(sub);
    }

    /// <summary>
    /// Validates one PUB/HPUB (max payload, pedantic subject, publish permission), splits HPUB headers
    /// from the body, and hands the message to the router. Message and byte counts are added to the
    /// caller's batch counters before validation.
    /// </summary>
    private void ProcessPub(ParsedCommand cmd, ref long localInMsgs, ref long localInBytes)
    {
        localInMsgs++;
        localInBytes += cmd.Payload.Length;

        // Max payload validation (always, hard close)
        if (cmd.Payload.Length > _options.MaxPayload)
        {
            _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}",
                Id, cmd.Payload.Length, _options.MaxPayload);
            _ = SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation, ClientClosedReason.MaxPayloadExceeded);
            return;
        }

        // Pedantic mode: validate publish subject
        if (ClientOpts?.Pedantic == true && !SubjectMatch.IsValidPublishSubject(cmd.Subject!))
        {
            _logger.LogDebug("Client {ClientId} invalid publish subject: {Subject}", Id, cmd.Subject);
            SendErr(NatsProtocol.ErrInvalidPublishSubject);
            return;
        }

        // Permission check for publish
        if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!))
        {
            _logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject);
            SendErr(NatsProtocol.ErrPermissionsPublish);
            return;
        }

        ReadOnlyMemory<byte> headers = default;
        ReadOnlyMemory<byte> payload = cmd.Payload;
        if (cmd.Type == CommandType.HPub && cmd.HeaderSize > 0)
        {
            headers = cmd.Payload[..cmd.HeaderSize];
            payload = cmd.Payload[cmd.HeaderSize..];
        }

        Router?.ProcessMessage(cmd.Subject!, cmd.ReplyTo, headers, payload, this);
    }

    private void SendInfo()
    {
        var infoJson = JsonSerializer.Serialize(_serverInfo);
        var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
        QueueOutbound(infoLine);
    }

    /// <summary>
    /// Frames a MSG (or HMSG when <paramref name="headers"/> is non-empty) for subscription
    /// <paramref name="sid"/>, queues it, and updates the client's and the server's out-stats.
    /// </summary>
    public void SendMessage(string subject, string sid, string? replyTo,
        ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
    {
        Interlocked.Increment(ref OutMsgs);
        Interlocked.Add(ref OutBytes, payload.Length + headers.Length);
        Interlocked.Increment(ref _serverStats.OutMsgs);
        Interlocked.Add(ref _serverStats.OutBytes, payload.Length + headers.Length);

        byte[] line;
        if (headers.Length > 0)
        {
            int totalSize = headers.Length + payload.Length;
            line = Encoding.ASCII.GetBytes($"HMSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{headers.Length} {totalSize}\r\n");
        }
        else
        {
            line = Encoding.ASCII.GetBytes($"MSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{payload.Length}\r\n");
        }

        var totalLen = line.Length + headers.Length + payload.Length + NatsProtocol.CrLf.Length;
        var msg = new byte[totalLen];
        var offset = 0;
        line.CopyTo(msg.AsSpan(offset)); offset += line.Length;
        if (headers.Length > 0) { headers.Span.CopyTo(msg.AsSpan(offset)); offset += headers.Length; }
        if (payload.Length > 0) { payload.Span.CopyTo(msg.AsSpan(offset)); offset += payload.Length; }
        NatsProtocol.CrLf.CopyTo(msg.AsSpan(offset));
        QueueOutbound(msg);
    }

    private void WriteProtocol(byte[] data)
    {
        QueueOutbound(data);
    }

    public void SendErr(string message)
    {
        var errLine = Encoding.ASCII.GetBytes($"-ERR '{message}'\r\n");
        QueueOutbound(errLine);
    }

    /// <summary>
    /// Drains the outbound channel to the stream, flushing once per batch. A flush that misses
    /// <c>WriteDeadline</c> closes the client as a slow consumer. Pending bytes are released only after
    /// a successful flush.
    /// </summary>
    private async Task RunWriteLoopAsync(CancellationToken ct)
    {
        _flags.SetFlag(ClientFlags.WriteLoopStarted);
        var reader = _outbound.Reader;
        try
        {
            while (await reader.WaitToReadAsync(ct))
            {
                long batchBytes = 0;
                while (reader.TryRead(out var data))
                {
                    await _stream.WriteAsync(data, ct);
                    batchBytes += data.Length;
                }

                using var flushCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
                flushCts.CancelAfter(_options.WriteDeadline);
                try
                {
                    await _stream.FlushAsync(flushCts.Token);
                }
                catch (OperationCanceledException) when (!ct.IsCancellationRequested)
                {
                    RecordSlowConsumer();
                    await CloseWithReasonAsync(ClientClosedReason.SlowConsumerWriteDeadline, NatsProtocol.ErrSlowConsumer);
                    return;
                }

                Interlocked.Add(ref _pendingBytes, -batchBytes);
            }
        }
        catch (OperationCanceledException)
        {
            // Normal shutdown
        }
        catch (IOException)
        {
            await CloseWithReasonAsync(ClientClosedReason.WriteError);
        }
    }

    public async Task SendErrAndCloseAsync(string message, ClientClosedReason reason = ClientClosedReason.ProtocolViolation)
    {
        await CloseWithReasonAsync(reason, message);
    }

    /// <summary>
    /// Sets the slow-consumer flag and bumps the server's slow-consumer counters.
    /// </summary>
    private void RecordSlowConsumer()
    {
        _flags.SetFlag(ClientFlags.IsSlowConsumer);
        Interlocked.Increment(ref _serverStats.SlowConsumers);
        Interlocked.Increment(ref _serverStats.SlowConsumerClients);
    }

    /// <summary>
    /// Reacts to a write that <see cref="QueueOutbound"/> could not accept. The client is recorded as a
    /// slow consumer and closed, unless a close is already under way.
    /// </summary>
    private void HandleSlowConsumerPending()
    {
        // If a close has already started (including the one whose -ERR write landed here), do not
        // re-enter CloseWithReasonAsync. Re-entering would recurse SendErr -> QueueOutbound -> close
        // without bound and overflow the stack.
        if (Volatile.Read(ref _closeStarted) != 0)
            return;

        RecordSlowConsumer();
        _ = CloseWithReasonAsync(ClientClosedReason.SlowConsumerPendingBytes, NatsProtocol.ErrSlowConsumer);
    }

    /// <summary>
    /// Starts closing the client, once. Records <paramref name="reason"/> and applies the same skip-flush
    /// policy as <see cref="MarkClosed"/>. Then queues <paramref name="errMessage"/> as -ERR (if given),
    /// stops accepting writes, and completes the outbound channel. After a short grace period for the
    /// write loop, the client's token is cancelled (or the socket closed if <see cref="RunAsync"/> never
    /// started). Calls after the first return immediately.
    /// </summary>
    private async Task CloseWithReasonAsync(ClientClosedReason reason, string? errMessage = null)
    {
        if (Interlocked.Exchange(ref _closeStarted, 1) != 0)
            return;

        CloseReason = reason;
        if (SkipsFlushOnClose(reason))
            Volatile.Write(ref _skipFlushOnClose, 1);

        if (errMessage != null)
            SendErr(errMessage);
        _flags.SetFlag(ClientFlags.CloseConnection);

        // Complete the outbound channel so the write loop drains remaining data
        _outbound.Writer.TryComplete();

        // Give the write loop a short window to flush the final batch before canceling
        await Task.Delay(50);

        if (_clientCts is { } cts)
        {
            try { await cts.CancelAsync(); }
            catch (ObjectDisposedException) { /* client already disposed */ }
        }
        else
        {
            _socket.Close();
        }
    }

    /// <summary>
    /// Sends PING every <c>PingInterval</c> unless input arrived within the last interval. Closes the
    /// connection as stale once another PING would exceed <c>MaxPingsOut</c> unanswered PINGs.
    /// </summary>
    private async Task RunPingTimerAsync(CancellationToken ct)
    {
        using var timer = new PeriodicTimer(_options.PingInterval);
        try
        {
            while (await timer.WaitForNextTickAsync(ct))
            {
                var elapsed = Environment.TickCount64 - Interlocked.Read(ref _lastIn);
                if (elapsed < (long)_options.PingInterval.TotalMilliseconds)
                {
                    // Client was recently active, skip ping
                    Interlocked.Exchange(ref _pingsOut, 0);
                    continue;
                }

                if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut)
                {
                    _logger.LogDebug("Client {ClientId} stale connection -- closing", Id);
                    await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection, ClientClosedReason.StaleConnection);
                    return;
                }

                var currentPingsOut = Interlocked.Increment(ref _pingsOut);
                _logger.LogDebug("Client {ClientId} sending PING ({PingsOut}/{MaxPingsOut})",
                    Id, currentPingsOut, _options.MaxPingsOut);
                WriteProtocol(NatsProtocol.PingBytes);
            }
        }
        catch (OperationCanceledException)
        {
            // Normal shutdown
        }
    }

    /// <summary>
    /// Whether closing for <paramref name="reason"/> should skip the final flush. These are the
    /// read/write errors, slow consumers, and TLS handshake failures, where flushing would likely
    /// block or fail.
    /// </summary>
    private static bool SkipsFlushOnClose(ClientClosedReason reason) => reason is
        ClientClosedReason.ReadError or
        ClientClosedReason.WriteError or
        ClientClosedReason.SlowConsumerPendingBytes or
        ClientClosedReason.SlowConsumerWriteDeadline or
        ClientClosedReason.TlsHandshakeError;

    /// <summary>
    /// Marks this connection as closed with the given reason.
    /// Sets skip-flush flag for error-related reasons.
    /// Only the first call sets the reason (subsequent calls are no-ops).
    /// </summary>
    public void MarkClosed(ClientClosedReason reason)
    {
        if (CloseReason != ClientClosedReason.None)
            return;

        CloseReason = reason;
        if (SkipsFlushOnClose(reason))
            Volatile.Write(ref _skipFlushOnClose, 1);

        _logger.LogDebug("Client {ClientId} connection closed: {CloseReason}", Id, reason);
    }

    /// <summary>
    /// Flushes pending data (unless skip-flush is set) and closes the connection.
    /// </summary>
    public async Task FlushAndCloseAsync(bool minimalFlush = false)
    {
        if (!ShouldSkipFlush)
        {
            try
            {
                using var flushCts = new CancellationTokenSource(minimalFlush
                    ? TimeSpan.FromMilliseconds(100)
                    : TimeSpan.FromSeconds(1));
                await _stream.FlushAsync(flushCts.Token);
            }
            catch (Exception)
            {
                // Best effort flush — don't let it prevent close
            }
        }

        try { _socket.Shutdown(SocketShutdown.Both); }
        catch (SocketException) { }
        catch (ObjectDisposedException) { }
    }

    public void RemoveAllSubscriptions(SubList subList)
    {
        foreach (var sub in _subs.Values)
            subList.Remove(sub);
        _subs.Clear();
    }

    public void Dispose()
    {
        // Stop accepting writes first, so deliveries that arrive after disposal are dropped quietly
        // instead of failing TryWrite on the completed channel and tripping slow-consumer handling.
        _flags.SetFlag(ClientFlags.CloseConnection);
        _permissions?.Dispose();
        _outbound.Writer.TryComplete();
        _clientCts?.Dispose();
        _stream.Dispose();
        _socket.Dispose();
    }
}