using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;

namespace NATS.Server;

/// <summary>
/// Accepts TCP connections, wraps each one in a <see cref="NatsClient"/>, and routes
/// published messages to the subscriptions that match their subject.
/// </summary>
public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
{
    /// <summary>Backlog passed to <see cref="Socket.Listen(int)"/>.</summary>
    private const int ListenBacklog = 128;

    private readonly NatsOptions _options;
    private readonly ConcurrentDictionary<ulong, NatsClient> _clients = new();
    private readonly ServerInfo _serverInfo;
    private readonly ILogger<NatsServer> _logger;
    private readonly ILoggerFactory _loggerFactory;
    private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
    private readonly AuthService _authService;
    private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
    private readonly Account _globalAccount;
    private Socket? _listener;
    private ulong _nextClientId;

    /// <summary>Subscription list of the global account.</summary>
    public SubList SubList => _globalAccount.SubList;

    /// <summary>
    /// Returns a task that completes once <see cref="StartAsync"/> has bound and started
    /// listening, or faults with the underlying exception if the listener could not be opened.
    /// </summary>
    public Task WaitForReadyAsync() => _listeningStarted.Task;

    /// <summary>
    /// Creates a server from the given options. No socket is opened until
    /// <see cref="StartAsync"/> is called.
    /// </summary>
    /// <param name="options">Server configuration (host, port, limits, auth settings).</param>
    /// <param name="loggerFactory">Factory used for the server logger and one logger per client.</param>
    public NatsServer(NatsOptions options, ILoggerFactory loggerFactory)
    {
        _options = options;
        _loggerFactory = loggerFactory;
        _logger = loggerFactory.CreateLogger<NatsServer>();
        _authService = AuthService.Build(options);

        _globalAccount = new Account(Account.GlobalAccountName);
        _accounts[Account.GlobalAccountName] = _globalAccount;

        _serverInfo = new ServerInfo
        {
            ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(),
            ServerName = options.ServerName ?? $"nats-dotnet-{Environment.MachineName}",
            Version = NatsProtocol.Version,
            Host = options.Host,
            Port = options.Port,
            MaxPayload = options.MaxPayload,
            AuthRequired = _authService.IsAuthRequired,
        };
    }

    /// <summary>
    /// Binds the listening socket and runs the accept loop until <paramref name="ct"/> is cancelled.
    /// </summary>
    /// <param name="ct">Token that stops the accept loop.</param>
    /// <remarks>
    /// If the listener cannot be opened, the exception is rethrown to the caller and is also
    /// published through <see cref="WaitForReadyAsync"/> so that waiters do not hang.
    /// </remarks>
    public async Task StartAsync(CancellationToken ct)
    {
        Socket listener;
        try
        {
            listener = OpenListener();
        }
        catch (Exception ex)
        {
            // Surface the failure to anyone awaiting readiness instead of leaving them pending.
            _listeningStarted.TrySetException(ex);
            throw;
        }

        _listener = listener;
        _listeningStarted.TrySetResult();
        _logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port);

        try
        {
            while (!ct.IsCancellationRequested)
            {
                var socket = await listener.AcceptAsync(ct);

                // Check MaxConnections before creating the client
                if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
                {
                    await RejectConnectionAsync(socket, ct);
                    continue;
                }

                var clientId = Interlocked.Increment(ref _nextClientId);
                _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);

                var clientInfo = CreateClientInfo(out var nonce);
                var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
                var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
                client.Router = this;
                _clients[clientId] = client;

                // Each client runs on its own task; RunClientAsync handles its own failures.
                _ = RunClientAsync(client, ct);
            }
        }
        catch (OperationCanceledException)
        {
            _logger.LogDebug("Accept loop cancelled, server shutting down");
        }
    }

    /// <summary>
    /// Creates, binds and starts the listening socket for the configured host and port.
    /// The socket is disposed again if any step fails.
    /// </summary>
    private Socket OpenListener()
    {
        Socket? listener = null;
        try
        {
            var address = _options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host);

            // Use the address's own family so an IPv6 host binds as well as an IPv4 one.
            listener = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
            listener.Bind(new IPEndPoint(address, _options.Port));
            listener.Listen(ListenBacklog);
            return listener;
        }
        catch
        {
            listener?.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Best-effort: sends the max-connections -ERR line to a connection that is over the limit,
    /// then closes its socket. Send failures are logged and otherwise ignored.
    /// </summary>
    private async Task RejectConnectionAsync(Socket socket, CancellationToken ct)
    {
        _logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded", _options.MaxConnections);
        try
        {
            await using var stream = new NetworkStream(socket, ownsSocket: false);
            var errBytes = Encoding.ASCII.GetBytes(
                $"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n");
            await stream.WriteAsync(errBytes, ct);
            await stream.FlushAsync(ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Failed to send -ERR to rejected client");
        }
        finally
        {
            socket.Dispose();
        }
    }

    /// <summary>
    /// Returns the <see cref="ServerInfo"/> to hand to a new client. When the auth service
    /// requires a nonce, a copy carrying a fresh nonce is returned; otherwise the shared
    /// instance is returned and <paramref name="nonce"/> is null.
    /// </summary>
    /// <param name="nonce">ASCII bytes of the encoded nonce string, or null when none is required.</param>
    private ServerInfo CreateClientInfo(out byte[]? nonce)
    {
        nonce = null;
        if (!_authService.NonceRequired)
        {
            return _serverInfo;
        }

        var rawNonce = _authService.GenerateNonce();
        var nonceStr = _authService.EncodeNonce(rawNonce);

        // The client signs the nonce string (ASCII), not the raw bytes
        nonce = Encoding.ASCII.GetBytes(nonceStr);

        return new ServerInfo
        {
            ServerId = _serverInfo.ServerId,
            ServerName = _serverInfo.ServerName,
            Version = _serverInfo.Version,
            Host = _serverInfo.Host,
            Port = _serverInfo.Port,
            MaxPayload = _serverInfo.MaxPayload,
            AuthRequired = _serverInfo.AuthRequired,
            Nonce = nonceStr,
        };
    }

    /// <summary>
    /// Runs a client until it finishes, logs any exception it throws, and always
    /// removes it from the server afterwards.
    /// </summary>
    private async Task RunClientAsync(NatsClient client, CancellationToken ct)
    {
        try
        {
            await client.RunAsync(ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Client {ClientId} disconnected with error", client.Id);
        }
        finally
        {
            _logger.LogDebug("Client {ClientId} disconnected", client.Id);
            RemoveClient(client);
        }
    }

    /// <summary>
    /// Delivers a published message to every matching plain subscription and to one
    /// member of every matching queue group.
    /// </summary>
    /// <param name="subject">Subject the message was published on.</param>
    /// <param name="replyTo">Optional reply subject, passed through to receivers.</param>
    /// <param name="headers">Message headers, passed through to receivers.</param>
    /// <param name="payload">Message payload, passed through to receivers.</param>
    /// <param name="sender">
    /// Publishing client. Its account's subscription list is used for matching (the global
    /// account's when it has none), and its echo option decides whether it may receive its
    /// own message.
    /// </param>
    public void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload, NatsClient sender)
    {
        var subList = sender.Account?.SubList ?? _globalAccount.SubList;
        var result = subList.Match(subject);

        // Echo defaults to on when the client has not supplied options.
        var echo = sender.ClientOpts?.Echo ?? true;

        // Deliver to plain subscribers
        foreach (var sub in result.PlainSubs)
        {
            if (CanDeliverTo(sub, sender, echo))
            {
                DeliverMessage(sub, subject, replyTo, headers, payload);
            }
        }

        // Deliver to one member of each queue group
        foreach (var queueGroup in result.QueueSubs)
        {
            if (queueGroup.Length == 0)
            {
                continue;
            }

            // A random start index spreads deliveries across the group without touching
            // any per-client statistics counter and cannot overflow.
            var start = Random.Shared.Next(queueGroup.Length);

            // Walk the group from the start index until an eligible member is found.
            for (var attempt = 0; attempt < queueGroup.Length; attempt++)
            {
                var sub = queueGroup[(start + attempt) % queueGroup.Length];
                if (CanDeliverTo(sub, sender, echo))
                {
                    DeliverMessage(sub, subject, replyTo, headers, payload);
                    break;
                }
            }
        }
    }

    /// <summary>
    /// True when the subscription still has a client and that client is either not the
    /// sender or the sender has echo enabled.
    /// </summary>
    private static bool CanDeliverTo(Subscription sub, NatsClient sender, bool echo)
        => sub.Client != null && (echo || sub.Client != sender);

    /// <summary>
    /// Counts the delivery against the subscription and, unless its message limit has been
    /// exceeded, starts an asynchronous send to the subscribing client without awaiting it.
    /// </summary>
    private static void DeliverMessage(Subscription sub, string subject, string? replyTo, ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
    {
        var client = sub.Client;
        if (client == null)
        {
            return;
        }

        // Check auto-unsub: a MaxMessages of zero or less means no limit.
        var count = Interlocked.Increment(ref sub.MessageCount);
        if (sub.MaxMessages > 0 && count > sub.MaxMessages)
        {
            return;
        }

        // Fire and forget -- deliver asynchronously
        _ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None);
    }

    /// <summary>Returns the account with the given name, creating it if it does not exist yet.</summary>
    /// <param name="name">Account name; compared ordinally.</param>
    public Account GetOrCreateAccount(string name)
    {
        return _accounts.GetOrAdd(name, n => new Account(n));
    }

    /// <summary>
    /// Stops tracking a client, removes all of its subscriptions from its account's
    /// subscription list (the global one when it has no account), and detaches it from
    /// its account.
    /// </summary>
    /// <param name="client">Client to remove.</param>
    public void RemoveClient(NatsClient client)
    {
        _clients.TryRemove(client.Id, out _);
        _logger.LogDebug("Removed client {ClientId}", client.Id);

        var subList = client.Account?.SubList ?? _globalAccount.SubList;
        client.RemoveAllSubscriptions(subList);
        client.Account?.RemoveClient(client.Id);
    }

    /// <summary>Disposes the listening socket, then every tracked client and every account.</summary>
    public void Dispose()
    {
        _listener?.Dispose();

        foreach (var client in _clients.Values)
        {
            client.Dispose();
        }

        foreach (var account in _accounts.Values)
        {
            account.Dispose();
        }
    }
}