// Source: natsdotnet/src/NATS.Server/NatsServer.cs
// Snapshot: 2026-02-23 06:19:41 -05:00 (1183 lines, 45 KiB, C#)

using System.Collections.Concurrent;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Microsoft.Extensions.Logging;
using NATS.NKeys;
using NATS.Server.Auth;
using NATS.Server.Configuration;
using NATS.Server.Gateways;
using NATS.Server.JetStream;
using NATS.Server.JetStream.Api;
using NATS.Server.JetStream.Publish;
using NATS.Server.LeafNodes;
using NATS.Server.Monitoring;
using NATS.Server.Protocol;
using NATS.Server.Routes;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
using NATS.Server.WebSocket;
namespace NATS.Server;
/// <summary>
/// The NATS server: accepts client connections over TCP (and optionally WebSocket),
/// routes published messages to matching local subscribers, and hosts the optional
/// route, gateway, leaf-node, JetStream and HTTP monitoring subsystems.
/// </summary>
public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
{
    private readonly NatsOptions _options;
    private readonly ConcurrentDictionary<ulong, NatsClient> _clients = new();
    private readonly ConcurrentQueue<ClosedClient> _closedClients = new();
    private readonly ServerInfo _serverInfo;
    private readonly ILogger<NatsServer> _logger;
    private readonly ILoggerFactory _loggerFactory;
    private readonly ServerStats _stats = new();
    private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);

    // Not readonly: rebuilt by ApplyConfigChanges when authorization settings are reloaded.
    private AuthService _authService;
    private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);

    // Config reload state
    private NatsOptions? _cliSnapshot;
    private HashSet<string> _cliFlags = [];
    private string? _configDigest;

    private readonly Account _globalAccount;
    private readonly Account _systemAccount;
    private readonly SslServerAuthenticationOptions? _sslOptions;
    private readonly TlsRateLimiter? _tlsRateLimiter;
    private readonly SubjectTransform[] _subjectTransforms;
    private readonly RouteManager? _routeManager;
    private readonly GatewayManager? _gatewayManager;
    private readonly LeafNodeManager? _leafNodeManager;
    private readonly JetStreamService? _jetStreamService;
    private readonly JetStreamApiRouter? _jetStreamApiRouter;
    private readonly StreamManager? _jetStreamStreamManager;
    private readonly ConsumerManager? _jetStreamConsumerManager;
    private readonly JetStreamPublisher? _jetStreamPublisher;
    private Socket? _listener;
    private Socket? _wsListener;
    private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
    private MonitorServer? _monitorServer;
    private ulong _nextClientId;
    private long _startTimeTicks;
    private readonly CancellationTokenSource _quitCts = new();
    private readonly TaskCompletionSource _shutdownComplete = new(TaskCreationOptions.RunContinuationsAsynchronously);
    private readonly TaskCompletionSource _acceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
    private int _shutdown;

    // Incremented when a connection is accepted; decremented exactly once when it is
    // either abandoned during accept or finishes RunClientAsync. Shutdown drains on it.
    private int _activeClientCount;
    private int _lameDuck;

    // Server-wide cursor that rotates which queue-group member is tried first.
    private long _queueCursor;

    private byte[] _cachedInfoLine = [];
    private readonly List<PosixSignalRegistration> _signalRegistrations = [];
    private string? _portsFilePath;

    private static readonly TimeSpan AcceptMinSleep = TimeSpan.FromMilliseconds(10);
    private static readonly TimeSpan AcceptMaxSleep = TimeSpan.FromSeconds(1);
    private static readonly TimeSpan AcceptLoopExitTimeout = TimeSpan.FromSeconds(5);

    /// <summary>Subscription index of the global account.</summary>
    public SubList SubList => _globalAccount.SubList;

    /// <summary>Pre-serialized <c>INFO {json}\r\n</c> line for the shared server info.</summary>
    public byte[] CachedInfoLine => _cachedInfoLine;

    public ServerStats Stats => _stats;

    /// <summary>UTC time recorded when the client listener was bound in <see cref="StartAsync"/>.</summary>
    public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc);

    public string ServerId => _serverInfo.ServerId;
    public string ServerName => _serverInfo.ServerName;
    public int ClientCount => _clients.Count;

    /// <summary>Client port; updated to the actual bound port when configured as 0.</summary>
    public int Port => _options.Port;

    public Account SystemAccount => _systemAccount;

    /// <summary>Public key of the Ed25519 server NKey generated at construction.</summary>
    public string ServerNKey { get; }

    public bool IsShuttingDown => Volatile.Read(ref _shutdown) != 0;
    public bool IsLameDuckMode => Volatile.Read(ref _lameDuck) != 0;
    public string? ClusterListen => _routeManager?.ListenEndpoint;
    public JetStreamApiRouter? JetStreamApiRouter => _jetStreamApiRouter;
    public int JetStreamStreams => _jetStreamStreamManager?.StreamNames.Count ?? 0;
    public int JetStreamConsumers => _jetStreamConsumerManager?.ConsumerCount ?? 0;

    /// <summary>Callback invoked on SIGUSR1 and when logging options are reloaded.</summary>
    public Action? ReOpenLogFile { get; set; }

    public IEnumerable<NatsClient> GetClients() => _clients.Values;
    public IEnumerable<ClosedClient> GetClosedClients() => _closedClients;
    public IEnumerable<Auth.Account> GetAccounts() => _accounts.Values;
    public bool HasRemoteInterest(string subject) => _globalAccount.SubList.HasRemoteInterest(subject);

    /// <summary>
    /// Offers a publish to JetStream. When the publisher captures it and the ack carries
    /// no error, the stored message is loaded and handed to the consumer manager.
    /// </summary>
    /// <param name="subject">Publish subject.</param>
    /// <param name="payload">Message payload.</param>
    /// <param name="ack">The publish ack when captured; a default <see cref="PubAck"/> otherwise.</param>
    /// <returns><c>true</c> if JetStream captured the publish.</returns>
    public bool TryCaptureJetStreamPublish(string subject, ReadOnlyMemory<byte> payload, out PubAck ack)
    {
        if (_jetStreamPublisher != null && _jetStreamPublisher.TryCapture(subject, payload, out ack))
        {
            if (ack.ErrorCode == null
                && _jetStreamConsumerManager != null
                && _jetStreamStreamManager != null
                && _jetStreamStreamManager.TryGet(ack.Stream, out var streamHandle))
            {
                // NOTE(review): synchronous wait on an async store load, on the publishing
                // client's path. Fine for an in-memory store; confirm for file-backed stores.
                var stored = streamHandle.Store.LoadAsync(ack.Seq, default).GetAwaiter().GetResult();
                if (stored != null)
                    _jetStreamConsumerManager.OnPublished(ack.Stream, stored);
            }
            return true;
        }

        ack = new PubAck();
        return false;
    }

    /// <summary>Completes once the client listener is accepting; faults if startup fails.</summary>
    public Task WaitForReadyAsync() => _listeningStarted.Task;

    /// <summary>Blocks the calling thread until <see cref="ShutdownAsync"/> has finished.</summary>
    public void WaitForShutdown() => _shutdownComplete.Task.GetAwaiter().GetResult();

    /// <summary>
    /// Stops the server: cancels internal loops, closes listeners, stops subsystems,
    /// closes all clients and waits (bounded) for them to drain. Only the first call
    /// does work; later calls return immediately. Errors from individual subsystems are
    /// logged and do not prevent the remaining steps or completion of
    /// <see cref="WaitForShutdown"/>.
    /// </summary>
    public async Task ShutdownAsync()
    {
        if (Interlocked.CompareExchange(ref _shutdown, 1, 0) != 0)
            return; // Already shutting down

        try
        {
            _logger.LogInformation("Initiating Shutdown...");

            // Signal all internal loops to stop
            await _quitCts.CancelAsync();

            // Close listeners to stop accept loops
            _listener?.Close();
            _wsListener?.Close();

            if (_routeManager is { } routeManager)
                await RunShutdownStepAsync("route manager", async () => await routeManager.DisposeAsync());
            if (_gatewayManager is { } gatewayManager)
                await RunShutdownStepAsync("gateway manager", async () => await gatewayManager.DisposeAsync());
            if (_leafNodeManager is { } leafNodeManager)
                await RunShutdownStepAsync("leaf node manager", async () => await leafNodeManager.DisposeAsync());
            if (_jetStreamService is { } jetStreamService)
                await RunShutdownStepAsync("JetStream service", async () => await jetStreamService.DisposeAsync());
            _stats.JetStreamEnabled = false;

            await WaitForAcceptLoopsAsync();

            // Close all client connections: mark each closed first (so its close reason
            // is ServerShutdown), then flush and close them concurrently.
            var flushTasks = new List<Task>();
            foreach (var client in _clients.Values)
            {
                client.MarkClosed(ClientClosedReason.ServerShutdown);
                flushTasks.Add(client.FlushAndCloseAsync(minimalFlush: true));
            }
            await Task.WhenAll(flushTasks).WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

            // Wait for active client tasks to drain (with timeout)
            if (Volatile.Read(ref _activeClientCount) > 0)
            {
                using var drainCts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
                try
                {
                    while (Volatile.Read(ref _activeClientCount) > 0 && !drainCts.IsCancellationRequested)
                        await Task.Delay(50, drainCts.Token);
                }
                catch (OperationCanceledException) { }
            }

            // Stop monitor server
            if (_monitorServer is { } monitorServer)
                await RunShutdownStepAsync("monitor server", async () => await monitorServer.DisposeAsync());

            DeletePidFile();
            DeletePortsFile();
            _logger.LogInformation("Server Exiting..");
        }
        finally
        {
            // Always release WaitForShutdown, even if a step above threw.
            _shutdownComplete.TrySetResult();
        }
    }

    /// <summary>Runs one shutdown step, logging (not propagating) any failure.</summary>
    private async Task RunShutdownStepAsync(string component, Func<Task> step)
    {
        try
        {
            await step();
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Error stopping {Component} during shutdown", component);
        }
    }

    /// <summary>
    /// Waits (bounded) for the accept loops to exit, but only for listeners that were
    /// actually created; a loop that never started would otherwise cost the full timeout.
    /// </summary>
    private async Task WaitForAcceptLoopsAsync()
    {
        if (_listener != null)
            await _acceptLoopExited.Task.WaitAsync(AcceptLoopExitTimeout).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
        if (_wsListener != null)
            await _wsAcceptLoopExited.Task.WaitAsync(AcceptLoopExitTimeout).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
    }

    /// <summary>
    /// Enters lame-duck mode: stops accepting new connections, waits the configured
    /// grace period, then closes remaining clients spread over
    /// <c>LameDuckDuration - LameDuckGracePeriod</c> before a full shutdown.
    /// Does nothing if already shutting down or already in lame-duck mode.
    /// </summary>
    public async Task LameDuckShutdownAsync()
    {
        if (IsShuttingDown || Interlocked.CompareExchange(ref _lameDuck, 1, 0) != 0)
            return;

        _logger.LogInformation("Entering lame duck mode, stop accepting new clients");

        // Close listeners to stop accepting new connections
        _listener?.Close();
        _wsListener?.Close();

        await WaitForAcceptLoopsAsync();

        var gracePeriod = _options.LameDuckGracePeriod;
        if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod;

        // If no clients, go straight to shutdown
        if (_clients.IsEmpty)
        {
            await ShutdownAsync();
            return;
        }

        // Wait grace period for clients to drain naturally
        _logger.LogInformation("Waiting {GracePeriod}ms grace period", gracePeriod.TotalMilliseconds);
        try
        {
            await Task.Delay(gracePeriod, _quitCts.Token);
        }
        catch (OperationCanceledException) { return; }

        if (_clients.IsEmpty)
        {
            await ShutdownAsync();
            return;
        }

        // Stagger-close remaining clients
        var dur = _options.LameDuckDuration - gracePeriod;
        if (dur <= TimeSpan.Zero) dur = TimeSpan.FromSeconds(1);

        var clients = _clients.Values.ToList();
        var numClients = clients.Count;

        if (numClients > 0)
        {
            _logger.LogInformation("Closing {Count} existing clients over {Duration}ms",
                numClients, dur.TotalMilliseconds);

            // Spread closes evenly over the window, clamped to [1ms, 1s] per client.
            var sleepInterval = dur.Ticks / numClients;
            if (sleepInterval < TimeSpan.TicksPerMillisecond)
                sleepInterval = TimeSpan.TicksPerMillisecond;
            if (sleepInterval > TimeSpan.TicksPerSecond)
                sleepInterval = TimeSpan.TicksPerSecond;

            for (int i = 0; i < clients.Count; i++)
            {
                var client = clients[i];
                client.MarkClosed(ClientClosedReason.ServerShutdown);
                try
                {
                    await client.FlushAndCloseAsync(minimalFlush: true);
                }
                catch (Exception ex)
                {
                    // One misbehaving connection must not abort the lame-duck sequence
                    // (and with it the final ShutdownAsync below).
                    _logger.LogDebug(ex, "Error closing client {ClientId} during lame duck", client.Id);
                }

                if (i < clients.Count - 1)
                {
                    // Randomize within [interval/2, interval) to avoid synchronized reconnects.
                    var jitter = Random.Shared.NextInt64(sleepInterval / 2, sleepInterval);
                    try
                    {
                        await Task.Delay(TimeSpan.FromTicks(jitter), _quitCts.Token);
                    }
                    catch (OperationCanceledException) { break; }
                }
            }
        }

        await ShutdownAsync();
    }

    /// <summary>
    /// Registers POSIX signal handlers:
    /// SIGTERM and SIGQUIT → shutdown, SIGHUP → config reload,
    /// SIGUSR1 → log reopen, SIGUSR2 → lame duck (the last two only on non-Windows).
    /// </summary>
    public void HandleSignals()
    {
        _signalRegistrations.Add(PosixSignalRegistration.Create(PosixSignal.SIGTERM, ctx =>
        {
            ctx.Cancel = true;
            _logger.LogInformation("Trapped SIGTERM signal");
            _ = Task.Run(async () => await ShutdownAsync());
        }));

        _signalRegistrations.Add(PosixSignalRegistration.Create(PosixSignal.SIGQUIT, ctx =>
        {
            ctx.Cancel = true;
            _logger.LogInformation("Trapped SIGQUIT signal");
            _ = Task.Run(async () => await ShutdownAsync());
        }));

        _signalRegistrations.Add(PosixSignalRegistration.Create(PosixSignal.SIGHUP, ctx =>
        {
            ctx.Cancel = true;
            _logger.LogInformation("Trapped SIGHUP signal — reloading configuration");
            _ = Task.Run(() => ReloadConfig());
        }));

        // SIGUSR1 and SIGUSR2 only on non-Windows. PosixSignal has no named members for
        // them, so raw signal numbers are used — and those numbers are platform-specific:
        // Linux uses 10/12, while macOS and FreeBSD use 30/31 (10/12 there are SIGBUS/SIGSYS).
        if (!OperatingSystem.IsWindows())
        {
            var bsdNumbering = OperatingSystem.IsMacOS() || OperatingSystem.IsFreeBSD();
            var sigUsr1 = (PosixSignal)(bsdNumbering ? 30 : 10);
            var sigUsr2 = (PosixSignal)(bsdNumbering ? 31 : 12);

            _signalRegistrations.Add(PosixSignalRegistration.Create(sigUsr1, ctx =>
            {
                ctx.Cancel = true;
                _logger.LogInformation("Trapped SIGUSR1 signal — reopening log file");
                ReOpenLogFile?.Invoke();
            }));

            _signalRegistrations.Add(PosixSignalRegistration.Create(sigUsr2, ctx =>
            {
                ctx.Cancel = true;
                _logger.LogInformation("Trapped SIGUSR2 signal — entering lame duck mode");
                _ = Task.Run(async () => await LameDuckShutdownAsync());
            }));
        }
    }

    /// <summary>
    /// Builds the server from <paramref name="options"/>: accounts, server identity and
    /// INFO, TLS settings, subject transforms and any configured subsystems. No sockets
    /// are opened until <see cref="StartAsync"/>.
    /// </summary>
    public NatsServer(NatsOptions options, ILoggerFactory loggerFactory)
    {
        _options = options;
        _loggerFactory = loggerFactory;
        _logger = loggerFactory.CreateLogger<NatsServer>();
        _authService = AuthService.Build(options);
        _globalAccount = new Account(Account.GlobalAccountName);
        _accounts[Account.GlobalAccountName] = _globalAccount;

        // Create $SYS system account (stub -- no internal subscriptions yet)
        _systemAccount = new Account("$SYS");
        _accounts["$SYS"] = _systemAccount;

        // Generate Ed25519 server NKey identity
        using var serverKeyPair = KeyPair.CreatePair(PrefixByte.Server);
        ServerNKey = serverKeyPair.GetPublicKey();

        _serverInfo = new ServerInfo
        {
            ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(),
            ServerName = options.ServerName ?? $"nats-dotnet-{Environment.MachineName}",
            Version = NatsProtocol.Version,
            Host = options.Host,
            Port = options.Port,
            MaxPayload = options.MaxPayload,
            AuthRequired = _authService.IsAuthRequired,
        };

        if (options.Cluster != null)
        {
            _routeManager = new RouteManager(options.Cluster, _stats, _serverInfo.ServerId, ApplyRemoteSubscription,
                _loggerFactory.CreateLogger<RouteManager>());
        }

        if (options.Gateway != null)
        {
            _gatewayManager = new GatewayManager(options.Gateway, _stats,
                _loggerFactory.CreateLogger<GatewayManager>());
        }

        if (options.LeafNode != null)
        {
            _leafNodeManager = new LeafNodeManager(options.LeafNode, _stats,
                _loggerFactory.CreateLogger<LeafNodeManager>());
        }

        if (options.JetStream != null)
        {
            _jetStreamStreamManager = new StreamManager();
            _jetStreamConsumerManager = new ConsumerManager();
            _jetStreamService = new JetStreamService(options.JetStream);
            _jetStreamApiRouter = new JetStreamApiRouter(_jetStreamStreamManager, _jetStreamConsumerManager);
            _jetStreamPublisher = new JetStreamPublisher(_jetStreamStreamManager);
        }

        if (options.HasTls)
        {
            _sslOptions = TlsHelper.BuildServerAuthOptions(options);

            // OCSP stapling: build a certificate context so the runtime can
            // fetch and cache a fresh OCSP response and staple it during the
            // TLS handshake. offline:false tells the runtime to contact the
            // OCSP responder; if the responder is unreachable we fall back to
            // no stapling rather than refusing all connections.
            var certContext = TlsHelper.BuildCertificateContext(options, offline: false);
            if (certContext != null)
            {
                _sslOptions.ServerCertificateContext = certContext;
                _logger.LogInformation("OCSP stapling enabled (mode: {OcspMode})", options.OcspConfig!.Mode);
            }

            _serverInfo.TlsRequired = !options.AllowNonTls;
            _serverInfo.TlsAvailable = options.AllowNonTls;
            _serverInfo.TlsVerify = options.TlsVerify;
            if (options.TlsRateLimit > 0)
                _tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit);
        }

        // Compile subject transforms
        if (options.SubjectMappings is { Count: > 0 })
        {
            var transforms = new List<SubjectTransform>();
            foreach (var (source, dest) in options.SubjectMappings)
            {
                var t = SubjectTransform.Create(source, dest);
                if (t != null)
                    transforms.Add(t);
                else
                    _logger.LogWarning("Invalid subject mapping: {Source} -> {Dest}", source, dest);
            }
            _subjectTransforms = transforms.ToArray();
            if (_subjectTransforms.Length > 0)
                _logger.LogInformation("Compiled {Count} subject transform(s)", _subjectTransforms.Length);
        }
        else
        {
            _subjectTransforms = [];
        }

        BuildCachedInfo();

        // Store initial config digest for reload change detection
        if (options.ConfigFile != null)
        {
            try
            {
                var (_, digest) = NatsConfParser.ParseFileWithDigest(options.ConfigFile);
                _configDigest = digest;
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Could not compute initial config digest for {ConfigFile}", options.ConfigFile);
            }
        }
    }

    /// <summary>Re-serializes the shared <see cref="ServerInfo"/> into <see cref="CachedInfoLine"/>.</summary>
    private void BuildCachedInfo()
    {
        var infoJson = System.Text.Json.JsonSerializer.Serialize(_serverInfo);
        _cachedInfoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
    }

    /// <summary>
    /// Binds the listeners, starts the configured subsystems, signals readiness, then
    /// runs the client accept loop until cancelled or the listener is closed.
    /// If startup fails, <see cref="WaitForReadyAsync"/> faults with the same exception.
    /// </summary>
    public async Task StartAsync(CancellationToken ct)
    {
        using var linked = CancellationTokenSource.CreateLinkedTokenSource(ct, _quitCts.Token);

        Socket listener;
        try
        {
            listener = CreateBoundListener(_options.Host, _options.Port);
            _listener = listener;
            Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks);

            // Resolve ephemeral port if port=0
            if (_options.Port == 0)
            {
                var actualPort = ((IPEndPoint)listener.LocalEndPoint!).Port;
                _options.Port = actualPort;
                _serverInfo.Port = actualPort;
                BuildCachedInfo();
            }

            _logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port);

            // Warn about stub features
            if (_options.ProfPort > 0)
                _logger.LogWarning("Profiling endpoint not yet supported (port: {ProfPort})", _options.ProfPort);

            if (_options.MonitorPort > 0)
            {
                _monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory);
                await _monitorServer.StartAsync(linked.Token);
            }

            WritePidFile();
            WritePortsFile();

            if (_options.WebSocket.Port >= 0)
            {
                var wsListener = CreateBoundListener(_options.WebSocket.Host, _options.WebSocket.Port);
                _wsListener = wsListener;
                if (_options.WebSocket.Port == 0)
                {
                    _options.WebSocket.Port = ((IPEndPoint)wsListener.LocalEndPoint!).Port;
                }

                _logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}",
                    _options.WebSocket.Host, _options.WebSocket.Port);

                if (_options.WebSocket.NoTls)
                    _logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!");

                _ = RunWebSocketAcceptLoopAsync(linked.Token);
            }

            if (_routeManager != null)
                await _routeManager.StartAsync(linked.Token);
            if (_gatewayManager != null)
                await _gatewayManager.StartAsync(linked.Token);
            if (_leafNodeManager != null)
                await _leafNodeManager.StartAsync(linked.Token);
            if (_jetStreamService != null)
            {
                await _jetStreamService.StartAsync(linked.Token);
                _stats.JetStreamEnabled = true;
            }
        }
        catch (Exception ex)
        {
            // Startup failed before the accept loop began: unblock readiness waiters and
            // shutdown's accept-loop wait instead of letting them hang or time out.
            _listeningStarted.TrySetException(ex);
            _acceptLoopExited.TrySetResult();
            throw;
        }

        _listeningStarted.TrySetResult();

        await RunClientAcceptLoopAsync(listener, linked.Token);
    }

    /// <summary>
    /// Creates a TCP listener bound to <paramref name="host"/>:<paramref name="port"/>.
    /// "0.0.0.0" means all IPv4 interfaces; otherwise <paramref name="host"/> must be an
    /// IP literal (IPv4 or IPv6 — the socket family follows the address).
    /// The socket is disposed if binding or listening fails.
    /// </summary>
    private static Socket CreateBoundListener(string host, int port)
    {
        var address = host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(host);
        var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
            socket.Bind(new IPEndPoint(address, port));
            socket.Listen(128);
            return socket;
        }
        catch
        {
            socket.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Accepts TCP client connections until cancelled or the listener is closed.
    /// Transient accept errors back off exponentially between AcceptMinSleep and
    /// AcceptMaxSleep. Connections over MaxConnections get an -ERR and are closed.
    /// </summary>
    private async Task RunClientAcceptLoopAsync(Socket listener, CancellationToken ct)
    {
        var tmpDelay = AcceptMinSleep;
        try
        {
            while (!ct.IsCancellationRequested)
            {
                Socket socket;
                try
                {
                    socket = await listener.AcceptAsync(ct);
                    tmpDelay = AcceptMinSleep; // Reset on success
                }
                catch (OperationCanceledException)
                {
                    break;
                }
                catch (ObjectDisposedException)
                {
                    break;
                }
                catch (SocketException ex)
                {
                    if (IsShuttingDown || IsLameDuckMode)
                        break;

                    _logger.LogError(ex, "Temporary accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
                    try { await Task.Delay(tmpDelay, ct); }
                    catch (OperationCanceledException) { break; }

                    tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
                    continue;
                }

                // Check MaxConnections
                if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
                {
                    _logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded",
                        _options.MaxConnections);
                    await RejectMaxConnectionsAsync(socket, ct);
                    continue;
                }

                var clientId = Interlocked.Increment(ref _nextClientId);
                Interlocked.Increment(ref _stats.TotalConnections);
                Interlocked.Increment(ref _activeClientCount);

                _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);

                _ = AcceptClientAsync(socket, clientId, ct);
            }
        }
        catch (OperationCanceledException)
        {
            _logger.LogDebug("Accept loop cancelled, server shutting down");
        }
        finally
        {
            _acceptLoopExited.TrySetResult();
        }
    }

    /// <summary>Best-effort send of the max-connections -ERR, then closes the socket.</summary>
    private async Task RejectMaxConnectionsAsync(Socket socket, CancellationToken ct)
    {
        try
        {
            await using var stream = new NetworkStream(socket, ownsSocket: false);
            var errBytes = Encoding.ASCII.GetBytes(
                $"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n");
            await stream.WriteAsync(errBytes, ct);
            await stream.FlushAsync(ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Failed to send -ERR to rejected client");
        }
        finally
        {
            socket.Dispose();
        }
    }

    /// <summary>
    /// Completes setup of an accepted TCP connection (TLS rate limit, TLS negotiation,
    /// per-client INFO) and runs it. If setup fails, the socket is closed and the
    /// active-client count taken by the accept loop is released.
    /// </summary>
    private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct)
    {
        // Once handed to RunClientAsync, that method owns cleanup and the count decrement.
        var handedOff = false;
        try
        {
            // Rate limit TLS handshakes
            if (_tlsRateLimiter != null)
                await _tlsRateLimiter.WaitAsync(ct);

            var networkStream = new NetworkStream(socket, ownsSocket: false);

            // TLS negotiation (no-op if not configured)
            var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync(
                socket, networkStream, _options, _sslOptions, _serverInfo,
                _loggerFactory.CreateLogger("NATS.Server.Tls"), ct);

            // Snapshot once: a concurrent config reload may swap _authService.
            var authService = _authService;
            var clientInfo = BuildClientInfo(authService, out var nonce);

            // NOTE(review): a distinct logger category per client id may accumulate in the
            // logger factory's per-category cache for the process lifetime — confirm.
            var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
            var client = new NatsClient(clientId, stream, socket, _options, clientInfo,
                authService, nonce, clientLogger, _stats);
            client.Router = this;
            client.TlsState = CaptureTlsState(stream);
            client.InfoAlreadySent = infoAlreadySent;

            _clients[clientId] = client;

            handedOff = true;
            await RunClientAsync(client, ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId);
        }
        finally
        {
            if (!handedOff)
                AbandonAcceptedSocket(socket);
        }
    }

    /// <summary>
    /// Returns the INFO to send to a new client. When the auth service requires a nonce,
    /// a per-client copy carrying a fresh nonce is returned and <paramref name="nonce"/>
    /// holds its ASCII bytes; otherwise the shared server info is returned and
    /// <paramref name="nonce"/> is <c>null</c>.
    /// </summary>
    private ServerInfo BuildClientInfo(AuthService authService, out byte[]? nonce)
    {
        if (!authService.NonceRequired)
        {
            nonce = null;
            return _serverInfo;
        }

        var rawNonce = authService.GenerateNonce();
        var nonceStr = authService.EncodeNonce(rawNonce);

        // The client signs the nonce string (ASCII), not the raw bytes
        nonce = Encoding.ASCII.GetBytes(nonceStr);

        return new ServerInfo
        {
            ServerId = _serverInfo.ServerId,
            ServerName = _serverInfo.ServerName,
            Version = _serverInfo.Version,
            Host = _serverInfo.Host,
            Port = _serverInfo.Port,
            MaxPayload = _serverInfo.MaxPayload,
            AuthRequired = _serverInfo.AuthRequired,
            TlsRequired = _serverInfo.TlsRequired,
            TlsAvailable = _serverInfo.TlsAvailable,
            TlsVerify = _serverInfo.TlsVerify,
            Nonce = nonceStr,
        };
    }

    /// <summary>Extracts protocol, cipher suite and peer certificate when the stream is TLS.</summary>
    private static TlsConnectionState? CaptureTlsState(Stream stream) =>
        stream is SslStream ssl
            ? new TlsConnectionState(
                ssl.SslProtocol.ToString(),
                ssl.NegotiatedCipherSuite.ToString(),
                ssl.RemoteCertificate as X509Certificate2)
            : null;

    /// <summary>
    /// Closes a connection that failed before being handed to <see cref="RunClientAsync"/>
    /// and releases the active-client slot taken for it by the accept loop.
    /// </summary>
    private void AbandonAcceptedSocket(Socket socket)
    {
        try { socket.Shutdown(SocketShutdown.Both); }
        catch (SocketException) { }
        catch (ObjectDisposedException) { }

        socket.Dispose();
        Interlocked.Decrement(ref _activeClientCount);
    }

    /// <summary>
    /// Accepts WebSocket connections until cancelled or the listener is closed, with the
    /// same transient-error backoff as the TCP loop. Over-limit connections are closed.
    /// </summary>
    private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct)
    {
        var tmpDelay = AcceptMinSleep;
        try
        {
            while (!ct.IsCancellationRequested)
            {
                Socket socket;
                try
                {
                    socket = await _wsListener!.AcceptAsync(ct);
                    tmpDelay = AcceptMinSleep;
                }
                catch (OperationCanceledException) { break; }
                catch (ObjectDisposedException) { break; }
                catch (SocketException ex)
                {
                    if (IsShuttingDown || IsLameDuckMode) break;
                    _logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
                    try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; }
                    tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
                    continue;
                }

                if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
                {
                    _logger.LogWarning("WebSocket connection rejected: maximum connections ({MaxConnections}) exceeded",
                        _options.MaxConnections);
                    socket.Dispose();
                    continue;
                }

                var clientId = Interlocked.Increment(ref _nextClientId);
                Interlocked.Increment(ref _stats.TotalConnections);
                Interlocked.Increment(ref _activeClientCount);

                _ = AcceptWebSocketClientAsync(socket, clientId, ct);
            }
        }
        finally
        {
            _wsAcceptLoopExited.TrySetResult();
        }
    }

    /// <summary>
    /// Completes setup of an accepted WebSocket connection (optional TLS, HTTP upgrade,
    /// per-client INFO) and runs it. If setup or the upgrade fails, the socket is closed
    /// and the active-client count taken by the accept loop is released.
    /// </summary>
    private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct)
    {
        // Once handed to RunClientAsync, that method owns cleanup and the count decrement.
        var handedOff = false;
        try
        {
            var networkStream = new NetworkStream(socket, ownsSocket: false);
            Stream stream = networkStream;

            // TLS negotiation if configured
            if (_sslOptions != null && !_options.WebSocket.NoTls)
            {
                var (tlsStream, _) = await TlsConnectionWrapper.NegotiateAsync(
                    socket, networkStream, _options, _sslOptions, _serverInfo,
                    _loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
                stream = tlsStream;
            }

            // HTTP upgrade handshake
            var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct);
            if (!upgradeResult.Success)
            {
                _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
                return; // finally abandons the socket
            }

            // Create WsConnection wrapper
            var wsConn = new WsConnection(stream,
                compress: upgradeResult.Compress,
                maskRead: upgradeResult.MaskRead,
                maskWrite: upgradeResult.MaskWrite,
                browser: upgradeResult.Browser,
                noCompFrag: upgradeResult.NoCompFrag);

            // Same INFO/nonce handling as TCP clients so NKey auth also works over WebSocket.
            var authService = _authService;
            var clientInfo = BuildClientInfo(authService, out var nonce);

            var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
            var client = new NatsClient(clientId, wsConn, socket, _options, clientInfo,
                authService, nonce, clientLogger, _stats);
            client.Router = this;
            client.IsWebSocket = true;
            client.WsInfo = upgradeResult;
            client.TlsState = CaptureTlsState(stream);

            _clients[clientId] = client;

            handedOff = true;
            await RunClientAsync(client, ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId);
        }
        finally
        {
            if (!handedOff)
                AbandonAcceptedSocket(socket);
        }
    }

    /// <summary>
    /// Runs a registered client to completion, then removes it and releases its
    /// active-client slot regardless of how it ended.
    /// </summary>
    private async Task RunClientAsync(NatsClient client, CancellationToken ct)
    {
        try
        {
            await client.RunAsync(ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Client {ClientId} disconnected with error", client.Id);
        }
        finally
        {
            _logger.LogDebug("Client {ClientId} disconnected (reason: {CloseReason})", client.Id, client.CloseReason);
            RemoveClient(client);
            Interlocked.Decrement(ref _activeClientCount);
        }
    }

    /// <summary>Forwards a new local subscription to the route manager, if clustering is enabled.</summary>
    public void OnLocalSubscription(string subject, string? queue)
    {
        _routeManager?.PropagateLocalSubscription(subject, queue);
    }

    private void ApplyRemoteSubscription(RemoteSubscription sub)
    {
        _globalAccount.SubList.ApplyRemoteSub(sub);
    }

    /// <summary>
    /// Routes a message published by <paramref name="sender"/>. The subject is mapped by
    /// the first matching subject transform, offered to JetStream, then delivered to every
    /// eligible plain subscriber and to one eligible member of each queue group. When
    /// nothing was delivered and the sender opted into no-responders, a 503 status is
    /// sent back on <paramref name="replyTo"/>.
    /// </summary>
    public void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory<byte> headers,
        ReadOnlyMemory<byte> payload, NatsClient sender)
    {
        // Map first, so JetStream capture sees the same subject that local delivery uses.
        subject = ApplySubjectTransforms(subject);

        if (TryCaptureJetStreamPublish(subject, payload, out var pubAck))
            sender.RecordJetStreamPubAck(pubAck);

        var echo = sender.ClientOpts?.Echo ?? true;
        var subList = sender.Account?.SubList ?? _globalAccount.SubList;
        var result = subList.Match(subject);
        var delivered = false;

        // Deliver to plain subscribers
        foreach (var sub in result.PlainSubs)
        {
            if (sub.Client == null || (sub.Client == sender && !echo))
                continue;

            if (DeliverMessage(sub, subject, replyTo, headers, payload))
                delivered = true;
        }

        // Deliver to one member of each queue group. The start position rotates via a
        // server-wide cursor; members that are ineligible or refuse delivery are skipped.
        foreach (var queueGroup in result.QueueSubs)
        {
            if (queueGroup.Length == 0) continue;

            var start = (int)(unchecked((ulong)Interlocked.Increment(ref _queueCursor)) % (ulong)queueGroup.Length);

            for (int attempt = 0; attempt < queueGroup.Length; attempt++)
            {
                var sub = queueGroup[(start + attempt) % queueGroup.Length];
                if (sub.Client == null || (sub.Client == sender && !echo))
                    continue;

                if (DeliverMessage(sub, subject, replyTo, headers, payload))
                {
                    delivered = true;
                    break;
                }
            }
        }

        // No-responders: if nobody received the message and the publisher
        // opted in, send back a 503 status HMSG on the reply subject.
        if (!delivered && replyTo != null && sender.ClientOpts?.NoResponders == true)
        {
            SendNoResponders(sender, replyTo);
        }
    }

    /// <summary>Returns the result of the first subject transform that matches, or the subject unchanged.</summary>
    private string ApplySubjectTransforms(string subject)
    {
        foreach (var transform in _subjectTransforms)
        {
            var mapped = transform.Apply(subject);
            if (mapped != null)
                return mapped; // First matching transform wins
        }
        return subject;
    }

    /// <summary>
    /// Delivers one message to <paramref name="sub"/>, enforcing auto-unsubscribe limits
    /// and the client's delivery permissions.
    /// </summary>
    /// <returns><c>true</c> if the message was handed to the subscriber's client.</returns>
    private bool DeliverMessage(Subscription sub, string subject, string? replyTo,
        ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
    {
        var client = sub.Client;
        if (client == null) return false;

        // Auto-unsub accounting: every delivery attempt consumes one from the budget.
        var count = Interlocked.Increment(ref sub.MessageCount);
        if (sub.MaxMessages > 0 && count > sub.MaxMessages)
        {
            // Over the limit (e.g. lost a race with the delivery that hit it): ensure removal.
            RemoveExhaustedSubscription(client, sub);
            return false;
        }

        var sent = false;

        // Deny-list delivery filter
        if (client.Permissions?.IsDeliveryAllowed(subject) != false)
        {
            client.SendMessage(subject, sub.Sid, replyTo, headers, payload);
            sent = true;

            // Track reply subject for response permissions
            if (replyTo != null && client.Permissions?.ResponseTracker != null)
            {
                if (client.Permissions.IsPublishAllowed(replyTo) == false)
                    client.Permissions.ResponseTracker.RegisterReply(replyTo);
            }
        }

        // Remove as soon as the limit is reached, so the exhausted subscription does not
        // linger as phantom interest until another message happens to arrive.
        if (sub.MaxMessages > 0 && count == sub.MaxMessages)
            RemoveExhaustedSubscription(client, sub);

        return sent;
    }

    /// <summary>Removes an auto-unsubscribed subscription from the trie and the client.</summary>
    private void RemoveExhaustedSubscription(NatsClient client, Subscription sub)
    {
        var subList = client.Account?.SubList ?? _globalAccount.SubList;
        subList.Remove(sub);
        client.RemoveSubscription(sub.Sid);
    }

    /// <summary>
    /// Sends a <c>NATS/1.0 503</c> status HMSG to <paramref name="sender"/> on the first of
    /// its subscriptions that matches <paramref name="replyTo"/>. Sends nothing if none does.
    /// </summary>
    private static void SendNoResponders(NatsClient sender, string replyTo)
    {
        // Find the sid for a subscription matching the reply subject
        string? sid = null;
        foreach (var sub in sender.Subscriptions.Values)
        {
            if (SubjectMatch.MatchLiteral(replyTo, sub.Subject))
            {
                sid = sub.Sid;
                break;
            }
        }

        // Without a matching subscription there is no sid to address; an HMSG with an
        // empty sid would be a malformed protocol line.
        if (sid == null)
            return;

        // Build: HMSG {replyTo} {sid} {hdrLen} {hdrLen}\r\n{headers}\r\n
        var headerBlock = "NATS/1.0 503\r\n\r\n"u8;
        var hdrLen = headerBlock.Length;
        var controlLine = Encoding.ASCII.GetBytes($"HMSG {replyTo} {sid} {hdrLen} {hdrLen}\r\n");

        var totalLen = controlLine.Length + hdrLen + NatsProtocol.CrLf.Length;
        var msg = new byte[totalLen];
        var offset = 0;
        controlLine.CopyTo(msg.AsSpan(offset));
        offset += controlLine.Length;
        headerBlock.CopyTo(msg.AsSpan(offset));
        offset += hdrLen;
        NatsProtocol.CrLf.CopyTo(msg.AsSpan(offset));

        sender.QueueOutbound(msg);
    }

    /// <summary>
    /// Returns the named account, creating it on first use and applying any limits and
    /// default permissions configured for that name in <c>options.Accounts</c>.
    /// </summary>
    public Account GetOrCreateAccount(string name)
    {
        return _accounts.GetOrAdd(name, n =>
        {
            var acc = new Account(n);
            if (_options.Accounts != null && _options.Accounts.TryGetValue(n, out var config))
            {
                acc.MaxConnections = config.MaxConnections;
                acc.MaxSubscriptions = config.MaxSubscriptions;
                acc.DefaultPermissions = config.DefaultPermissions;
            }
            return acc;
        });
    }

    /// <summary>
    /// Unregisters a client, records a closed-connection snapshot (trimmed to
    /// MaxClosedClients), and drops its subscriptions and account membership.
    /// </summary>
    public void RemoveClient(NatsClient client)
    {
        _clients.TryRemove(client.Id, out _);
        _logger.LogDebug("Removed client {ClientId}", client.Id);

        // Snapshot for closed-connections tracking
        _closedClients.Enqueue(new ClosedClient
        {
            Cid = client.Id,
            Ip = client.RemoteIp ?? "",
            Port = client.RemotePort,
            Start = client.StartTime,
            Stop = DateTime.UtcNow,
            Reason = client.CloseReason.ToReasonString(),
            Name = client.ClientOpts?.Name ?? "",
            Lang = client.ClientOpts?.Lang ?? "",
            Version = client.ClientOpts?.Version ?? "",
            InMsgs = Interlocked.Read(ref client.InMsgs),
            OutMsgs = Interlocked.Read(ref client.OutMsgs),
            InBytes = Interlocked.Read(ref client.InBytes),
            OutBytes = Interlocked.Read(ref client.OutBytes),
            NumSubs = (uint)client.Subscriptions.Count,
            Rtt = client.Rtt,
            TlsVersion = client.TlsState?.TlsVersion ?? "",
            TlsCipherSuite = client.TlsState?.CipherSuite ?? "",
        });

        // Cap closed clients list
        while (_closedClients.Count > _options.MaxClosedClients)
            _closedClients.TryDequeue(out _);

        var subList = client.Account?.SubList ?? _globalAccount.SubList;
        client.RemoveAllSubscriptions(subList);
        client.Account?.RemoveClient(client.Id);
    }

    /// <summary>Writes the process id to options.PidFile, if configured; errors are logged.</summary>
    private void WritePidFile()
    {
        if (string.IsNullOrEmpty(_options.PidFile)) return;
        try
        {
            File.WriteAllText(_options.PidFile, Environment.ProcessId.ToString());
            _logger.LogDebug("Wrote PID file {PidFile}", _options.PidFile);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error writing PID file {PidFile}", _options.PidFile);
        }
    }

    /// <summary>Deletes options.PidFile if it exists; errors are logged.</summary>
    private void DeletePidFile()
    {
        if (string.IsNullOrEmpty(_options.PidFile)) return;
        try
        {
            if (File.Exists(_options.PidFile))
                File.Delete(_options.PidFile);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error deleting PID file {PidFile}", _options.PidFile);
        }
    }

    /// <summary>
    /// Writes <c>{exe}_{pid}.ports</c> (JSON with the client and monitor ports) into
    /// options.PortsFileDir, if configured; errors are logged.
    /// </summary>
    private void WritePortsFile()
    {
        if (string.IsNullOrEmpty(_options.PortsFileDir)) return;
        try
        {
            var exeName = Path.GetFileNameWithoutExtension(Environment.ProcessPath ?? "nats-server");
            var fileName = $"{exeName}_{Environment.ProcessId}.ports";
            _portsFilePath = Path.Combine(_options.PortsFileDir, fileName);

            var ports = new { client = _options.Port, monitor = _options.MonitorPort > 0 ? _options.MonitorPort : (int?)null };
            var json = System.Text.Json.JsonSerializer.Serialize(ports);
            File.WriteAllText(_portsFilePath, json);
            _logger.LogDebug("Wrote ports file {PortsFile}", _portsFilePath);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error writing ports file to {PortsFileDir}", _options.PortsFileDir);
        }
    }

    /// <summary>Deletes the ports file written by WritePortsFile, if any; errors are logged.</summary>
    private void DeletePortsFile()
    {
        if (_portsFilePath == null) return;
        try
        {
            if (File.Exists(_portsFilePath))
                File.Delete(_portsFilePath);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error deleting ports file {PortsFile}", _portsFilePath);
        }
    }

    /// <summary>
    /// Stores the CLI snapshot and flags so that command-line overrides
    /// always take precedence during config reload.
    /// </summary>
    public void SetCliSnapshot(NatsOptions cliSnapshot, HashSet<string> cliFlags)
    {
        _cliSnapshot = cliSnapshot;
        _cliFlags = cliFlags;
    }

    /// <summary>
    /// Reloads the configuration file, diffs against current options, validates
    /// the changes, and applies reloadable settings. CLI overrides are preserved.
    /// Does nothing if the file's digest is unchanged; errors are logged, not thrown.
    /// </summary>
    public void ReloadConfig()
    {
        if (_options.ConfigFile == null)
        {
            _logger.LogWarning("No config file specified, cannot reload");
            return;
        }

        try
        {
            var (newConfig, digest) = NatsConfParser.ParseFileWithDigest(_options.ConfigFile);
            if (digest == _configDigest)
            {
                _logger.LogInformation("Config file unchanged, no reload needed");
                return;
            }

            var newOpts = new NatsOptions { ConfigFile = _options.ConfigFile };
            ConfigProcessor.ApplyConfig(newConfig, newOpts);

            // CLI flags override config
            if (_cliSnapshot != null)
                ConfigReloader.MergeCliOverrides(newOpts, _cliSnapshot, _cliFlags);

            var changes = ConfigReloader.Diff(_options, newOpts);
            var errors = ConfigReloader.Validate(changes);
            if (errors.Count > 0)
            {
                foreach (var err in errors)
                    _logger.LogError("Config reload error: {Error}", err);
                return;
            }

            // Apply changes to running options
            ApplyConfigChanges(changes, newOpts);
            _configDigest = digest;

            _logger.LogInformation("Config reloaded successfully ({Count} changes applied)", changes.Count);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to reload config file: {ConfigFile}", _options.ConfigFile);
        }
    }

    /// <summary>
    /// Copies reloadable values into the live options, then reopens logs and/or rebuilds
    /// the auth service when any change is flagged as a logging or auth change.
    /// </summary>
    private void ApplyConfigChanges(List<IConfigChange> changes, NatsOptions newOpts)
    {
        bool hasLoggingChanges = false;
        bool hasAuthChanges = false;

        foreach (var change in changes)
        {
            if (change.IsLoggingChange) hasLoggingChanges = true;
            if (change.IsAuthChange) hasAuthChanges = true;
        }

        // Copy reloadable values from newOpts to _options
        CopyReloadableOptions(newOpts);

        // Trigger side effects
        if (hasLoggingChanges)
        {
            ReOpenLogFile?.Invoke();
            _logger.LogInformation("Logging configuration reloaded");
        }

        if (hasAuthChanges)
        {
            // Rebuild auth service with new options
            _authService = AuthService.Build(_options);
            _logger.LogInformation("Authorization configuration reloaded");
        }

        // NOTE(review): the shared _serverInfo / CachedInfoLine are not refreshed here, so
        // INFO still advertises the pre-reload MaxPayload and AuthRequired — confirm intended.
    }

    /// <summary>Copies the logging, auth, limits, TLS and misc options that may change at runtime.</summary>
    private void CopyReloadableOptions(NatsOptions newOpts)
    {
        // Logging
        _options.Debug = newOpts.Debug;
        _options.Trace = newOpts.Trace;
        _options.TraceVerbose = newOpts.TraceVerbose;
        _options.Logtime = newOpts.Logtime;
        _options.LogtimeUTC = newOpts.LogtimeUTC;
        _options.LogFile = newOpts.LogFile;
        _options.LogSizeLimit = newOpts.LogSizeLimit;
        _options.LogMaxFiles = newOpts.LogMaxFiles;
        _options.Syslog = newOpts.Syslog;
        _options.RemoteSyslog = newOpts.RemoteSyslog;

        // Auth
        _options.Username = newOpts.Username;
        _options.Password = newOpts.Password;
        _options.Authorization = newOpts.Authorization;
        _options.Users = newOpts.Users;
        _options.NKeys = newOpts.NKeys;
        _options.NoAuthUser = newOpts.NoAuthUser;
        _options.AuthTimeout = newOpts.AuthTimeout;

        // Limits
        _options.MaxConnections = newOpts.MaxConnections;
        _options.MaxPayload = newOpts.MaxPayload;
        _options.MaxPending = newOpts.MaxPending;
        _options.WriteDeadline = newOpts.WriteDeadline;
        _options.PingInterval = newOpts.PingInterval;
        _options.MaxPingsOut = newOpts.MaxPingsOut;
        _options.MaxControlLine = newOpts.MaxControlLine;
        _options.MaxSubs = newOpts.MaxSubs;
        _options.MaxSubTokens = newOpts.MaxSubTokens;
        _options.MaxTracedMsgLen = newOpts.MaxTracedMsgLen;
        _options.MaxClosedClients = newOpts.MaxClosedClients;

        // TLS
        _options.TlsCert = newOpts.TlsCert;
        _options.TlsKey = newOpts.TlsKey;
        _options.TlsCaCert = newOpts.TlsCaCert;
        _options.TlsVerify = newOpts.TlsVerify;
        _options.TlsMap = newOpts.TlsMap;
        _options.TlsTimeout = newOpts.TlsTimeout;
        _options.TlsHandshakeFirst = newOpts.TlsHandshakeFirst;
        _options.TlsHandshakeFirstFallback = newOpts.TlsHandshakeFirstFallback;
        _options.AllowNonTls = newOpts.AllowNonTls;
        _options.TlsRateLimit = newOpts.TlsRateLimit;
        _options.TlsPinnedCerts = newOpts.TlsPinnedCerts;

        // Misc
        _options.Tags = newOpts.Tags;
        _options.LameDuckDuration = newOpts.LameDuckDuration;
        _options.LameDuckGracePeriod = newOpts.LameDuckGracePeriod;
        _options.ClientAdvertise = newOpts.ClientAdvertise;
        _options.DisableSublistCache = newOpts.DisableSublistCache;
        _options.ConnectErrorReports = newOpts.ConnectErrorReports;
        _options.ReconnectErrorReports = newOpts.ReconnectErrorReports;
        _options.NoHeaderSupport = newOpts.NoHeaderSupport;
        _options.NoSystemAccount = newOpts.NoSystemAccount;
        _options.SystemAccount = newOpts.SystemAccount;
    }

    /// <summary>
    /// Shuts the server down (synchronously) if that has not started yet, then releases
    /// signal registrations, sockets, subsystems, clients and accounts.
    /// </summary>
    public void Dispose()
    {
        if (!IsShuttingDown)
            ShutdownAsync().GetAwaiter().GetResult();

        foreach (var reg in _signalRegistrations)
            reg.Dispose();

        _quitCts.Dispose();
        _tlsRateLimiter?.Dispose();
        _listener?.Dispose();
        _wsListener?.Dispose();
        _routeManager?.DisposeAsync().AsTask().GetAwaiter().GetResult();
        _gatewayManager?.DisposeAsync().AsTask().GetAwaiter().GetResult();
        _leafNodeManager?.DisposeAsync().AsTask().GetAwaiter().GetResult();
        _jetStreamService?.DisposeAsync().AsTask().GetAwaiter().GetResult();
        _stats.JetStreamEnabled = false;

        foreach (var client in _clients.Values)
            client.Dispose();

        foreach (var account in _accounts.Values)
            account.Dispose();
    }
}