Files
natsdotnet/src/NATS.Server/NatsServer.cs
Joseph Doherty c40c2cd994 test: add permission enforcement and NKey integration tests
Fix NKey nonce verification: the NATS client signs the nonce string
(ASCII bytes of the base64url-encoded nonce), not the raw nonce bytes.
Pass the encoded nonce string bytes to the authenticator for verification.
2026-02-22 23:03:41 -05:00

225 lines
8.6 KiB
C#

using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
namespace NATS.Server;
/// <summary>
/// TCP front end of the server: accepts client connections, tracks connected clients by id,
/// and routes published messages to matching plain and queue-group subscriptions.
/// </summary>
public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
{
    // Backlog handed to Socket.Listen for connections not yet accepted.
    private const int ListenBacklog = 128;

    private readonly NatsOptions _options;
    private readonly ConcurrentDictionary<ulong, NatsClient> _clients = new();
    private readonly ServerInfo _serverInfo;
    private readonly ILogger<NatsServer> _logger;
    private readonly ILoggerFactory _loggerFactory;
    private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
    private readonly AuthService _authService;
    private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
    private readonly Account _globalAccount;
    private Socket? _listener;
    private ulong _nextClientId;

    // Server-wide round-robin cursor for queue groups. Unsigned so that the modulo used to
    // pick a member can never become negative, even after the counter wraps around.
    private ulong _queueCursor;

    /// <summary>The subscription list of the global account.</summary>
    public SubList SubList => _globalAccount.SubList;

    /// <summary>
    /// Completes once <see cref="StartAsync"/> has bound and started listening.
    /// Faults with the bind/listen exception if the listener could not be created.
    /// </summary>
    public Task WaitForReadyAsync() => _listeningStarted.Task;

    /// <summary>
    /// Creates the server: builds the auth service from <paramref name="options"/>, registers the
    /// global account and prepares the INFO data advertised to clients.
    /// </summary>
    /// <param name="options">Server configuration (host, port, limits, auth).</param>
    /// <param name="loggerFactory">Factory used for the server logger and per-client loggers.</param>
    public NatsServer(NatsOptions options, ILoggerFactory loggerFactory)
    {
        _options = options;
        _loggerFactory = loggerFactory;
        _logger = loggerFactory.CreateLogger<NatsServer>();
        _authService = AuthService.Build(options);
        _globalAccount = new Account(Account.GlobalAccountName);
        _accounts[Account.GlobalAccountName] = _globalAccount;
        _serverInfo = new ServerInfo
        {
            // First 20 hex characters of a fresh GUID, upper-cased.
            ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(),
            ServerName = options.ServerName ?? $"nats-dotnet-{Environment.MachineName}",
            Version = NatsProtocol.Version,
            Host = options.Host,
            Port = options.Port,
            MaxPayload = options.MaxPayload,
            AuthRequired = _authService.IsAuthRequired,
        };
    }

    /// <summary>
    /// Binds the listener and accepts clients until <paramref name="ct"/> is cancelled.
    /// Connections beyond <c>MaxConnections</c> (when it is positive) are sent an -ERR and closed.
    /// </summary>
    /// <param name="ct">Stops the accept loop and is passed to every client's run loop.</param>
    public async Task StartAsync(CancellationToken ct)
    {
        var listener = CreateListener();
        _listener = listener;
        _listeningStarted.TrySetResult();
        _logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port);

        try
        {
            while (!ct.IsCancellationRequested)
            {
                var socket = await listener.AcceptAsync(ct);

                // Check MaxConnections before creating the client
                if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
                {
                    _logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded",
                        _options.MaxConnections);
                    await RejectClientAsync(socket, ct);
                    continue;
                }

                AcceptClient(socket, ct);
            }
        }
        catch (OperationCanceledException)
        {
            _logger.LogDebug("Accept loop cancelled, server shutting down");
        }
    }

    // Creates, binds and starts the listening socket. On failure the socket is disposed and the
    // readiness task is faulted, so WaitForReadyAsync callers observe the error instead of hanging.
    private Socket CreateListener()
    {
        Socket? listener = null;
        try
        {
            var address = _options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host);
            // Use the address's own family so IPv6 hosts (e.g. "::") can be bound as well.
            listener = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
            listener.Bind(new IPEndPoint(address, _options.Port));
            listener.Listen(ListenBacklog);
            return listener;
        }
        catch (Exception ex)
        {
            listener?.Dispose();
            _listeningStarted.TrySetException(ex);
            throw;
        }
    }

    // Best-effort: tells a connection refused for exceeding MaxConnections why, then closes it.
    private async Task RejectClientAsync(Socket socket, CancellationToken ct)
    {
        try
        {
            await using var stream = new NetworkStream(socket, ownsSocket: false);
            var errBytes = Encoding.ASCII.GetBytes(
                $"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n");
            await stream.WriteAsync(errBytes, ct);
            await stream.FlushAsync(ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Failed to send -ERR to rejected client");
        }
        finally
        {
            socket.Dispose();
        }
    }

    // Assigns an id to a newly accepted socket, registers the client and starts its run loop
    // without awaiting it.
    private void AcceptClient(Socket socket, CancellationToken ct)
    {
        var clientId = Interlocked.Increment(ref _nextClientId);
        _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);

        var (clientInfo, nonce) = BuildClientInfo();

        var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
        var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
        client.Router = this;
        _clients[clientId] = client;
        _ = RunClientAsync(client, ct);
    }

    // Returns the INFO data for one client. When the auth service requires a nonce, the result is
    // a copy of the shared ServerInfo carrying a fresh encoded nonce, together with the bytes the
    // client is expected to sign. Otherwise it is the shared ServerInfo and a null nonce.
    private (ServerInfo Info, byte[]? Nonce) BuildClientInfo()
    {
        if (!_authService.NonceRequired)
            return (_serverInfo, null);

        var rawNonce = _authService.GenerateNonce();
        var nonceStr = _authService.EncodeNonce(rawNonce);
        // The client signs the nonce string (ASCII), not the raw bytes
        var signedNonce = Encoding.ASCII.GetBytes(nonceStr);

        var info = new ServerInfo
        {
            ServerId = _serverInfo.ServerId,
            ServerName = _serverInfo.ServerName,
            Version = _serverInfo.Version,
            Host = _serverInfo.Host,
            Port = _serverInfo.Port,
            MaxPayload = _serverInfo.MaxPayload,
            AuthRequired = _serverInfo.AuthRequired,
            Nonce = nonceStr,
        };
        return (info, signedNonce);
    }

    // Runs a client to completion and always unregisters it afterwards.
    private async Task RunClientAsync(NatsClient client, CancellationToken ct)
    {
        try
        {
            await client.RunAsync(ct);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Client {ClientId} disconnected with error", client.Id);
        }
        finally
        {
            _logger.LogDebug("Client {ClientId} disconnected", client.Id);
            RemoveClient(client);
        }
    }

    /// <summary>
    /// Routes a published message within the sender's account (the global account if it has none).
    /// Every matching plain subscription receives it. Each matching queue group delivers it to one member.
    /// The sender's own subscriptions are skipped when its CONNECT options disable echo.
    /// </summary>
    /// <param name="subject">Subject the message was published on.</param>
    /// <param name="replyTo">Optional reply subject forwarded to receivers.</param>
    /// <param name="headers">Header bytes forwarded unchanged.</param>
    /// <param name="payload">Payload bytes forwarded unchanged.</param>
    /// <param name="sender">The publishing client.</param>
    public void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory<byte> headers,
        ReadOnlyMemory<byte> payload, NatsClient sender)
    {
        var subList = sender.Account?.SubList ?? _globalAccount.SubList;
        var result = subList.Match(subject);
        var echo = sender.ClientOpts?.Echo ?? true;

        // Deliver to plain subscribers
        foreach (var sub in result.PlainSubs)
        {
            if (!IsEligible(sub, sender, echo))
                continue;
            DeliverMessage(sub, subject, replyTo, headers, payload);
        }

        // Deliver to one member of each queue group (round-robin)
        foreach (var queueGroup in result.QueueSubs)
        {
            if (queueGroup.Length == 0) continue;

            // Unsigned modulo: always in [0, Length), no Math.Abs overflow on wrap-around, and no
            // temporary mutation of the sender's statistics.
            var start = (int)(Interlocked.Increment(ref _queueCursor) % (ulong)queueGroup.Length);
            for (int attempt = 0; attempt < queueGroup.Length; attempt++)
            {
                var sub = queueGroup[(start + attempt) % queueGroup.Length];
                // Move on to the next member if this one is ineligible or has already hit its
                // auto-unsubscribe limit, so the message is not silently dropped for the group.
                if (IsEligible(sub, sender, echo) && DeliverMessage(sub, subject, replyTo, headers, payload))
                    break;
            }
        }
    }

    // A subscription can receive a message if it is still bound to a client and it either belongs
    // to a different client or the sender allows echo of its own messages.
    private static bool IsEligible(Subscription sub, NatsClient sender, bool echo)
        => sub.Client != null && (sub.Client != sender || echo);

    // Queues the message to the subscription's client without awaiting the send.
    // Returns false when nothing was sent: no client, or the auto-unsubscribe limit
    // (MaxMessages > 0) has been exceeded.
    private static bool DeliverMessage(Subscription sub, string subject, string? replyTo,
        ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
    {
        var client = sub.Client;
        if (client == null) return false;

        // Check auto-unsub
        var count = Interlocked.Increment(ref sub.MessageCount);
        if (sub.MaxMessages > 0 && count > sub.MaxMessages)
            return false;

        // Fire and forget -- deliver asynchronously
        _ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None);
        return true;
    }

    /// <summary>Returns the account named <paramref name="name"/>, creating it on first use.</summary>
    public Account GetOrCreateAccount(string name)
    {
        return _accounts.GetOrAdd(name, n => new Account(n));
    }

    /// <summary>
    /// Unregisters <paramref name="client"/>. It also removes the client's subscriptions from its
    /// account's SubList (the global one if it has no account) and detaches it from that account.
    /// </summary>
    public void RemoveClient(NatsClient client)
    {
        _clients.TryRemove(client.Id, out _);
        _logger.LogDebug("Removed client {ClientId}", client.Id);
        var subList = client.Account?.SubList ?? _globalAccount.SubList;
        client.RemoveAllSubscriptions(subList);
        client.Account?.RemoveClient(client.Id);
    }

    /// <summary>Disposes the listener, every connected client and every account.</summary>
    public void Dispose()
    {
        _listener?.Dispose();
        foreach (var client in _clients.Values)
            client.Dispose();
        foreach (var account in _accounts.Values)
            account.Dispose();
    }
}