From 2980a343c17bfb59f77e407ccfed5d805f685d08 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 22:55:50 -0500 Subject: [PATCH] feat: integrate authentication into server accept loop and client CONNECT processing Wire AuthService into NatsServer and NatsClient to enforce authentication on incoming connections. The server builds an AuthService from NatsOptions, sets auth_required in ServerInfo, and generates per-client nonces when NKey auth is configured. NatsClient validates credentials in ProcessConnect, enforces publish/subscribe permissions, and implements an auth timeout that closes connections that don't send CONNECT in time. Existing tests without auth continue to work since AuthService.IsAuthRequired is false by default. --- src/NATS.Server/NatsClient.cs | 106 +++++++- src/NATS.Server/NatsServer.cs | 37 ++- .../NATS.Server.Tests/AuthIntegrationTests.cs | 256 ++++++++++++++++++ tests/NATS.Server.Tests/ClientTests.cs | 4 +- 4 files changed, 396 insertions(+), 7 deletions(-) create mode 100644 tests/NATS.Server.Tests/AuthIntegrationTests.cs diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 149a399..4e62313 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -4,6 +4,7 @@ using System.Net.Sockets; using System.Text; using System.Text.Json; using Microsoft.Extensions.Logging; +using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; @@ -27,16 +28,20 @@ public sealed class NatsClient : IDisposable private readonly NetworkStream _stream; private readonly NatsOptions _options; private readonly ServerInfo _serverInfo; + private readonly AuthService _authService; + private readonly byte[]? _nonce; private readonly NatsParser _parser; private readonly SemaphoreSlim _writeLock = new(1, 1); private CancellationTokenSource? _clientCts; private readonly Dictionary _subs = new(); private readonly ILogger _logger; + private ClientPermissions? _permissions; public ulong Id { get; } public ClientOptions? ClientOpts { get; private set; } public IMessageRouter? Router { get; set; } public bool ConnectReceived { get; private set; } + public Account? Account { get; private set; } // Stats public long InMsgs; @@ -50,13 +55,16 @@ public sealed class NatsClient : IDisposable public IReadOnlyDictionary Subscriptions => _subs; - public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, ILogger logger) + public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, + AuthService authService, byte[]? nonce, ILogger logger) { Id = id; _socket = socket; _stream = new NetworkStream(socket, ownsSocket: false); _options = options; _serverInfo = serverInfo; + _authService = authService; + _nonce = nonce; _logger = logger; _parser = new NatsParser(options.MaxPayload); } @@ -71,6 +79,28 @@ public sealed class NatsClient : IDisposable // Send INFO await SendInfoAsync(_clientCts.Token); + // Start auth timeout if auth is required + Task? authTimeoutTask = null; + if (_authService.IsAuthRequired) + { + authTimeoutTask = Task.Run(async () => + { + try + { + await Task.Delay(_options.AuthTimeout, _clientCts!.Token); + if (!ConnectReceived) + { + _logger.LogDebug("Client {ClientId} auth timeout", Id); + await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout); + } + } + catch (OperationCanceledException) + { + // Normal — client connected or was cancelled + } + }, _clientCts.Token); + } + // Start read pump, command processing, and ping timer in parallel var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token); var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token); @@ -147,10 +177,28 @@ public sealed class NatsClient : IDisposable private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct) { + // If auth is required and CONNECT hasn't been received yet, + // only allow CONNECT and PING commands + if (_authService.IsAuthRequired && !ConnectReceived) + { + switch (cmd.Type) + { + case CommandType.Connect: + await ProcessConnectAsync(cmd); + return; + case CommandType.Ping: + await WriteAsync(NatsProtocol.PongBytes, ct); + return; + default: + // Ignore all other commands until authenticated + return; + } + } + switch (cmd.Type) { case CommandType.Connect: - ProcessConnect(cmd); + await ProcessConnectAsync(cmd); break; case CommandType.Ping: @@ -162,7 +210,7 @@ public sealed class NatsClient : IDisposable break; case CommandType.Sub: - ProcessSub(cmd); + await ProcessSubAsync(cmd); break; case CommandType.Unsub: @@ -176,16 +224,56 @@ public sealed class NatsClient : IDisposable } } - private void ProcessConnect(ParsedCommand cmd) + private async ValueTask ProcessConnectAsync(ParsedCommand cmd) { ClientOpts = JsonSerializer.Deserialize(cmd.Payload.Span) ?? new ClientOptions(); + + // Authenticate if auth is required + if (_authService.IsAuthRequired) + { + var context = new ClientAuthContext + { + Opts = ClientOpts, + Nonce = _nonce ?? [], + }; + + var result = _authService.Authenticate(context); + if (result == null) + { + _logger.LogWarning("Client {ClientId} authentication failed", Id); + await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation); + return; + } + + // Build permissions from auth result + _permissions = ClientPermissions.Build(result.Permissions); + + // Resolve account + if (Router is NatsServer server) + { + var accountName = result.AccountName ?? Account.GlobalAccountName; + Account = server.GetOrCreateAccount(accountName); + Account.AddClient(Id); + } + + _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity); + } + ConnectReceived = true; _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); } - private void ProcessSub(ParsedCommand cmd) + private async ValueTask ProcessSubAsync(ParsedCommand cmd) { + // Permission check for subscribe + if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue)) + { + _logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject); + await SendErrAsync(NatsProtocol.ErrPermissionsSubscribe); + return; + } + var sub = new Subscription { Subject = cmd.Subject!, @@ -244,6 +332,14 @@ public sealed class NatsClient : IDisposable return; } + // Permission check for publish + if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!)) + { + _logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject); + await SendErrAsync(NatsProtocol.ErrPermissionsPublish); + return; + } + ReadOnlyMemory headers = default; ReadOnlyMemory payload = cmd.Payload; diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 5900aef..ae1637e 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -3,6 +3,7 @@ using System.Net; using System.Net.Sockets; using System.Text; using Microsoft.Extensions.Logging; +using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; @@ -17,6 +18,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly ILogger _logger; private readonly ILoggerFactory _loggerFactory; private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly AuthService _authService; + private readonly ConcurrentDictionary _accounts = new(StringComparer.Ordinal); + private readonly Account _globalAccount; private Socket? _listener; private ulong _nextClientId; @@ -29,6 +33,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _options = options; _loggerFactory = loggerFactory; _logger = loggerFactory.CreateLogger(); + _authService = AuthService.Build(options); + _globalAccount = new Account(Account.GlobalAccountName); + _accounts[Account.GlobalAccountName] = _globalAccount; _serverInfo = new ServerInfo { ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(), @@ -37,6 +44,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable Host = options.Host, Port = options.Port, MaxPayload = options.MaxPayload, + AuthRequired = _authService.IsAuthRequired, }; } @@ -87,8 +95,27 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); + // Build per-client ServerInfo with nonce if NKey auth is configured + byte[]? nonce = null; + var clientInfo = _serverInfo; + if (_authService.NonceRequired) + { + nonce = _authService.GenerateNonce(); + clientInfo = new ServerInfo + { + ServerId = _serverInfo.ServerId, + ServerName = _serverInfo.ServerName, + Version = _serverInfo.Version, + Host = _serverInfo.Host, + Port = _serverInfo.Port, + MaxPayload = _serverInfo.MaxPayload, + AuthRequired = _serverInfo.AuthRequired, + Nonce = _authService.EncodeNonce(nonce), + }; + } + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); - var client = new NatsClient(clientId, socket, _options, _serverInfo, clientLogger); + var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger); client.Router = this; _clients[clientId] = client; @@ -169,11 +196,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None); } + public Account GetOrCreateAccount(string name) + { + return _accounts.GetOrAdd(name, n => new Account(n)); + } + public void RemoveClient(NatsClient client) { _clients.TryRemove(client.Id, out _); _logger.LogDebug("Removed client {ClientId}", client.Id); client.RemoveAllSubscriptions(_subList); + client.Account?.RemoveClient(client.Id); } public void Dispose() @@ -182,5 +215,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable foreach (var client in _clients.Values) client.Dispose(); _subList.Dispose(); + foreach (var account in _accounts.Values) + account.Dispose(); } } diff --git a/tests/NATS.Server.Tests/AuthIntegrationTests.cs b/tests/NATS.Server.Tests/AuthIntegrationTests.cs new file mode 100644 index 0000000..e8151dd --- /dev/null +++ b/tests/NATS.Server.Tests/AuthIntegrationTests.cs @@ -0,0 +1,256 @@ +using System.Net; +using System.Net.Sockets; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Client.Core; +using NATS.Server; +using NATS.Server.Auth; + +namespace NATS.Server.Tests; + +public class AuthIntegrationTests +{ + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + /// + /// Checks whether any exception in the chain contains the given substring. + /// The NATS client wraps server errors in outer NatsException messages, + /// so the actual "Authorization Violation" may be in an inner exception. + /// + private static bool ExceptionChainContains(Exception ex, string substring) + { + Exception? current = ex; + while (current != null) + { + if (current.Message.Contains(substring, StringComparison.OrdinalIgnoreCase)) + return true; + current = current.InnerException; + } + + return false; + } + + private static (NatsServer server, int port, CancellationTokenSource cts) StartServer(NatsOptions options) + { + var port = GetFreePort(); + options.Port = port; + var server = new NatsServer(options, NullLoggerFactory.Instance); + var cts = new CancellationTokenSource(); + _ = server.StartAsync(cts.Token); + return (server, port, cts); + } + + private static async Task<(NatsServer server, int port, CancellationTokenSource cts)> StartServerAsync(NatsOptions options) + { + var (server, port, cts) = StartServer(options); + await server.WaitForReadyAsync(); + return (server, port, cts); + } + + [Fact] + public async Task Token_auth_success() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Authorization = "s3cr3t", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://s3cr3t@127.0.0.1:{port}", + }); + await client.ConnectAsync(); + await client.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task Token_auth_failure_disconnects() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Authorization = "s3cr3t", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://wrongtoken@127.0.0.1:{port}", + MaxReconnectRetry = 0, + }); + + var ex = await Should.ThrowAsync(async () => + { + await client.ConnectAsync(); + await client.PingAsync(); + }); + + ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue( + $"Expected 'Authorization Violation' in exception chain, but got: {ex}"); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task UserPassword_auth_success() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Username = "admin", + Password = "secret", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://admin:secret@127.0.0.1:{port}", + }); + await client.ConnectAsync(); + await client.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task UserPassword_auth_failure_disconnects() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Username = "admin", + Password = "secret", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://admin:wrong@127.0.0.1:{port}", + MaxReconnectRetry = 0, + }); + + var ex = await Should.ThrowAsync(async () => + { + await client.ConnectAsync(); + await client.PingAsync(); + }); + + ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue( + $"Expected 'Authorization Violation' in exception chain, but got: {ex}"); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task MultiUser_auth_success() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Users = + [ + new User { Username = "alice", Password = "pass1" }, + new User { Username = "bob", Password = "pass2" }, + ], + }); + + try + { + await using var alice = new NatsConnection(new NatsOpts + { + Url = $"nats://alice:pass1@127.0.0.1:{port}", + }); + await using var bob = new NatsConnection(new NatsOpts + { + Url = $"nats://bob:pass2@127.0.0.1:{port}", + }); + + await alice.ConnectAsync(); + await alice.PingAsync(); + + await bob.ConnectAsync(); + await bob.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task No_credentials_when_auth_required_disconnects() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Authorization = "s3cr3t", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://127.0.0.1:{port}", + MaxReconnectRetry = 0, + }); + + var ex = await Should.ThrowAsync(async () => + { + await client.ConnectAsync(); + await client.PingAsync(); + }); + + ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue( + $"Expected 'Authorization Violation' in exception chain, but got: {ex}"); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task No_auth_configured_allows_all() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions()); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://127.0.0.1:{port}", + }); + await client.ConnectAsync(); + await client.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } +} diff --git a/tests/NATS.Server.Tests/ClientTests.cs b/tests/NATS.Server.Tests/ClientTests.cs index 9e6b9a8..096877a 100644 --- a/tests/NATS.Server.Tests/ClientTests.cs +++ b/tests/NATS.Server.Tests/ClientTests.cs @@ -6,6 +6,7 @@ using System.Text.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using NATS.Server; +using NATS.Server.Auth; using NATS.Server.Protocol; namespace NATS.Server.Tests; @@ -39,7 +40,8 @@ public class ClientTests : IAsyncDisposable Port = 4222, }; - _natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, NullLogger.Instance); + var authService = AuthService.Build(new NatsOptions()); + _natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance); } public async ValueTask DisposeAsync()