diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 149a399..4e62313 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -4,6 +4,7 @@ using System.Net.Sockets; using System.Text; using System.Text.Json; using Microsoft.Extensions.Logging; +using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; @@ -27,16 +28,20 @@ public sealed class NatsClient : IDisposable private readonly NetworkStream _stream; private readonly NatsOptions _options; private readonly ServerInfo _serverInfo; + private readonly AuthService _authService; + private readonly byte[]? _nonce; private readonly NatsParser _parser; private readonly SemaphoreSlim _writeLock = new(1, 1); private CancellationTokenSource? _clientCts; private readonly Dictionary _subs = new(); private readonly ILogger _logger; + private ClientPermissions? _permissions; public ulong Id { get; } public ClientOptions? ClientOpts { get; private set; } public IMessageRouter? Router { get; set; } public bool ConnectReceived { get; private set; } + public Account? Account { get; private set; } // Stats public long InMsgs; @@ -50,13 +55,16 @@ public sealed class NatsClient : IDisposable public IReadOnlyDictionary Subscriptions => _subs; - public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, ILogger logger) + public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, + AuthService authService, byte[]? nonce, ILogger logger) { Id = id; _socket = socket; _stream = new NetworkStream(socket, ownsSocket: false); _options = options; _serverInfo = serverInfo; + _authService = authService; + _nonce = nonce; _logger = logger; _parser = new NatsParser(options.MaxPayload); } @@ -71,6 +79,28 @@ public sealed class NatsClient : IDisposable // Send INFO await SendInfoAsync(_clientCts.Token); + // Start auth timeout if auth is required + Task? authTimeoutTask = null; + if (_authService.IsAuthRequired) + { + authTimeoutTask = Task.Run(async () => + { + try + { + await Task.Delay(_options.AuthTimeout, _clientCts!.Token); + if (!ConnectReceived) + { + _logger.LogDebug("Client {ClientId} auth timeout", Id); + await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout); + } + } + catch (OperationCanceledException) + { + // Normal — client connected or was cancelled + } + }, _clientCts.Token); + } + // Start read pump, command processing, and ping timer in parallel var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token); var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token); @@ -147,10 +177,28 @@ public sealed class NatsClient : IDisposable private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct) { + // If auth is required and CONNECT hasn't been received yet, + // only allow CONNECT and PING commands + if (_authService.IsAuthRequired && !ConnectReceived) + { + switch (cmd.Type) + { + case CommandType.Connect: + await ProcessConnectAsync(cmd); + return; + case CommandType.Ping: + await WriteAsync(NatsProtocol.PongBytes, ct); + return; + default: + // Ignore all other commands until authenticated + return; + } + } + switch (cmd.Type) { case CommandType.Connect: - ProcessConnect(cmd); + await ProcessConnectAsync(cmd); break; case CommandType.Ping: @@ -162,7 +210,7 @@ public sealed class NatsClient : IDisposable break; case CommandType.Sub: - ProcessSub(cmd); + await ProcessSubAsync(cmd); break; case CommandType.Unsub: @@ -176,16 +224,56 @@ public sealed class NatsClient : IDisposable } } - private void ProcessConnect(ParsedCommand cmd) + private async ValueTask ProcessConnectAsync(ParsedCommand cmd) { ClientOpts = JsonSerializer.Deserialize(cmd.Payload.Span) ?? new ClientOptions(); + + // Authenticate if auth is required + if (_authService.IsAuthRequired) + { + var context = new ClientAuthContext + { + Opts = ClientOpts, + Nonce = _nonce ?? [], + }; + + var result = _authService.Authenticate(context); + if (result == null) + { + _logger.LogWarning("Client {ClientId} authentication failed", Id); + await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation); + return; + } + + // Build permissions from auth result + _permissions = ClientPermissions.Build(result.Permissions); + + // Resolve account + if (Router is NatsServer server) + { + var accountName = result.AccountName ?? Account.GlobalAccountName; + Account = server.GetOrCreateAccount(accountName); + Account.AddClient(Id); + } + + _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity); + } + ConnectReceived = true; _logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name); } - private void ProcessSub(ParsedCommand cmd) + private async ValueTask ProcessSubAsync(ParsedCommand cmd) { + // Permission check for subscribe + if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue)) + { + _logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject); + await SendErrAsync(NatsProtocol.ErrPermissionsSubscribe); + return; + } + var sub = new Subscription { Subject = cmd.Subject!, @@ -244,6 +332,14 @@ public sealed class NatsClient : IDisposable return; } + // Permission check for publish + if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!)) + { + _logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject); + await SendErrAsync(NatsProtocol.ErrPermissionsPublish); + return; + } + ReadOnlyMemory headers = default; ReadOnlyMemory payload = cmd.Payload; diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 5900aef..ae1637e 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -3,6 +3,7 @@ using System.Net; using System.Net.Sockets; using System.Text; using Microsoft.Extensions.Logging; +using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; @@ -17,6 +18,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly ILogger _logger; private readonly ILoggerFactory _loggerFactory; private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly AuthService _authService; + private readonly ConcurrentDictionary _accounts = new(StringComparer.Ordinal); + private readonly Account _globalAccount; private Socket? _listener; private ulong _nextClientId; @@ -29,6 +33,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _options = options; _loggerFactory = loggerFactory; _logger = loggerFactory.CreateLogger(); + _authService = AuthService.Build(options); + _globalAccount = new Account(Account.GlobalAccountName); + _accounts[Account.GlobalAccountName] = _globalAccount; _serverInfo = new ServerInfo { ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(), @@ -37,6 +44,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable Host = options.Host, Port = options.Port, MaxPayload = options.MaxPayload, + AuthRequired = _authService.IsAuthRequired, }; } @@ -87,8 +95,27 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); + // Build per-client ServerInfo with nonce if NKey auth is configured + byte[]? nonce = null; + var clientInfo = _serverInfo; + if (_authService.NonceRequired) + { + nonce = _authService.GenerateNonce(); + clientInfo = new ServerInfo + { + ServerId = _serverInfo.ServerId, + ServerName = _serverInfo.ServerName, + Version = _serverInfo.Version, + Host = _serverInfo.Host, + Port = _serverInfo.Port, + MaxPayload = _serverInfo.MaxPayload, + AuthRequired = _serverInfo.AuthRequired, + Nonce = _authService.EncodeNonce(nonce), + }; + } + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); - var client = new NatsClient(clientId, socket, _options, _serverInfo, clientLogger); + var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger); client.Router = this; _clients[clientId] = client; @@ -169,11 +196,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None); } + public Account GetOrCreateAccount(string name) + { + return _accounts.GetOrAdd(name, n => new Account(n)); + } + public void RemoveClient(NatsClient client) { _clients.TryRemove(client.Id, out _); _logger.LogDebug("Removed client {ClientId}", client.Id); client.RemoveAllSubscriptions(_subList); + client.Account?.RemoveClient(client.Id); } public void Dispose() @@ -182,5 +215,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable foreach (var client in _clients.Values) client.Dispose(); _subList.Dispose(); + foreach (var account in _accounts.Values) + account.Dispose(); } } diff --git a/tests/NATS.Server.Tests/AuthIntegrationTests.cs b/tests/NATS.Server.Tests/AuthIntegrationTests.cs new file mode 100644 index 0000000..e8151dd --- /dev/null +++ b/tests/NATS.Server.Tests/AuthIntegrationTests.cs @@ -0,0 +1,256 @@ +using System.Net; +using System.Net.Sockets; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Client.Core; +using NATS.Server; +using NATS.Server.Auth; + +namespace NATS.Server.Tests; + +public class AuthIntegrationTests +{ + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + /// + /// Checks whether any exception in the chain contains the given substring. + /// The NATS client wraps server errors in outer NatsException messages, + /// so the actual "Authorization Violation" may be in an inner exception. + /// + private static bool ExceptionChainContains(Exception ex, string substring) + { + Exception? current = ex; + while (current != null) + { + if (current.Message.Contains(substring, StringComparison.OrdinalIgnoreCase)) + return true; + current = current.InnerException; + } + + return false; + } + + private static (NatsServer server, int port, CancellationTokenSource cts) StartServer(NatsOptions options) + { + var port = GetFreePort(); + options.Port = port; + var server = new NatsServer(options, NullLoggerFactory.Instance); + var cts = new CancellationTokenSource(); + _ = server.StartAsync(cts.Token); + return (server, port, cts); + } + + private static async Task<(NatsServer server, int port, CancellationTokenSource cts)> StartServerAsync(NatsOptions options) + { + var (server, port, cts) = StartServer(options); + await server.WaitForReadyAsync(); + return (server, port, cts); + } + + [Fact] + public async Task Token_auth_success() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Authorization = "s3cr3t", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://s3cr3t@127.0.0.1:{port}", + }); + await client.ConnectAsync(); + await client.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task Token_auth_failure_disconnects() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Authorization = "s3cr3t", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://wrongtoken@127.0.0.1:{port}", + MaxReconnectRetry = 0, + }); + + var ex = await Should.ThrowAsync(async () => + { + await client.ConnectAsync(); + await client.PingAsync(); + }); + + ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue( + $"Expected 'Authorization Violation' in exception chain, but got: {ex}"); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task UserPassword_auth_success() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Username = "admin", + Password = "secret", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://admin:secret@127.0.0.1:{port}", + }); + await client.ConnectAsync(); + await client.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task UserPassword_auth_failure_disconnects() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Username = "admin", + Password = "secret", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://admin:wrong@127.0.0.1:{port}", + MaxReconnectRetry = 0, + }); + + var ex = await Should.ThrowAsync(async () => + { + await client.ConnectAsync(); + await client.PingAsync(); + }); + + ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue( + $"Expected 'Authorization Violation' in exception chain, but got: {ex}"); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task MultiUser_auth_success() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Users = + [ + new User { Username = "alice", Password = "pass1" }, + new User { Username = "bob", Password = "pass2" }, + ], + }); + + try + { + await using var alice = new NatsConnection(new NatsOpts + { + Url = $"nats://alice:pass1@127.0.0.1:{port}", + }); + await using var bob = new NatsConnection(new NatsOpts + { + Url = $"nats://bob:pass2@127.0.0.1:{port}", + }); + + await alice.ConnectAsync(); + await alice.PingAsync(); + + await bob.ConnectAsync(); + await bob.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task No_credentials_when_auth_required_disconnects() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions + { + Authorization = "s3cr3t", + }); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://127.0.0.1:{port}", + MaxReconnectRetry = 0, + }); + + var ex = await Should.ThrowAsync(async () => + { + await client.ConnectAsync(); + await client.PingAsync(); + }); + + ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue( + $"Expected 'Authorization Violation' in exception chain, but got: {ex}"); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } + + [Fact] + public async Task No_auth_configured_allows_all() + { + var (server, port, cts) = await StartServerAsync(new NatsOptions()); + + try + { + await using var client = new NatsConnection(new NatsOpts + { + Url = $"nats://127.0.0.1:{port}", + }); + await client.ConnectAsync(); + await client.PingAsync(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } +} diff --git a/tests/NATS.Server.Tests/ClientTests.cs b/tests/NATS.Server.Tests/ClientTests.cs index 9e6b9a8..096877a 100644 --- a/tests/NATS.Server.Tests/ClientTests.cs +++ b/tests/NATS.Server.Tests/ClientTests.cs @@ -6,6 +6,7 @@ using System.Text.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using NATS.Server; +using NATS.Server.Auth; using NATS.Server.Protocol; namespace NATS.Server.Tests; @@ -39,7 +40,8 @@ public class ClientTests : IAsyncDisposable Port = 4222, }; - _natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, NullLogger.Instance); + var authService = AuthService.Build(new NatsOptions()); + _natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance); } public async ValueTask DisposeAsync()