feat: integrate authentication into server accept loop and client CONNECT processing

Wire AuthService into NatsServer and NatsClient to enforce authentication
on incoming connections. The server builds an AuthService from NatsOptions,
sets auth_required in ServerInfo, and generates per-client nonces when
NKey auth is configured. NatsClient validates credentials in ProcessConnect,
enforces publish/subscribe permissions, and implements an auth timeout that
closes connections that don't send CONNECT in time. Existing tests without
auth continue to work since AuthService.IsAuthRequired is false by default.
This commit is contained in:
Joseph Doherty
2026-02-22 22:55:50 -05:00
parent 2a2cc6f0a2
commit 2980a343c1
4 changed files with 396 additions and 7 deletions

View File

@@ -4,6 +4,7 @@ using System.Net.Sockets;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
@@ -27,16 +28,20 @@ public sealed class NatsClient : IDisposable
private readonly NetworkStream _stream;
private readonly NatsOptions _options;
private readonly ServerInfo _serverInfo;
private readonly AuthService _authService;
private readonly byte[]? _nonce;
private readonly NatsParser _parser;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private CancellationTokenSource? _clientCts;
private readonly Dictionary<string, Subscription> _subs = new();
private readonly ILogger _logger;
private ClientPermissions? _permissions;
public ulong Id { get; }
public ClientOptions? ClientOpts { get; private set; }
public IMessageRouter? Router { get; set; }
public bool ConnectReceived { get; private set; }
public Account? Account { get; private set; }
// Stats
public long InMsgs;
@@ -50,13 +55,16 @@ public sealed class NatsClient : IDisposable
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, ILogger logger)
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo,
AuthService authService, byte[]? nonce, ILogger logger)
{
Id = id;
_socket = socket;
_stream = new NetworkStream(socket, ownsSocket: false);
_options = options;
_serverInfo = serverInfo;
_authService = authService;
_nonce = nonce;
_logger = logger;
_parser = new NatsParser(options.MaxPayload);
}
@@ -71,6 +79,28 @@ public sealed class NatsClient : IDisposable
// Send INFO
await SendInfoAsync(_clientCts.Token);
// Start auth timeout if auth is required
Task? authTimeoutTask = null;
if (_authService.IsAuthRequired)
{
authTimeoutTask = Task.Run(async () =>
{
try
{
await Task.Delay(_options.AuthTimeout, _clientCts!.Token);
if (!ConnectReceived)
{
_logger.LogDebug("Client {ClientId} auth timeout", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout);
}
}
catch (OperationCanceledException)
{
// Normal — client connected or was cancelled
}
}, _clientCts.Token);
}
// Start read pump, command processing, and ping timer in parallel
var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token);
var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token);
@@ -147,10 +177,28 @@ public sealed class NatsClient : IDisposable
private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
{
// If auth is required and CONNECT hasn't been received yet,
// only allow CONNECT and PING commands
if (_authService.IsAuthRequired && !ConnectReceived)
{
switch (cmd.Type)
{
case CommandType.Connect:
await ProcessConnectAsync(cmd);
return;
case CommandType.Ping:
await WriteAsync(NatsProtocol.PongBytes, ct);
return;
default:
// Ignore all other commands until authenticated
return;
}
}
switch (cmd.Type)
{
case CommandType.Connect:
ProcessConnect(cmd);
await ProcessConnectAsync(cmd);
break;
case CommandType.Ping:
@@ -162,7 +210,7 @@ public sealed class NatsClient : IDisposable
break;
case CommandType.Sub:
ProcessSub(cmd);
await ProcessSubAsync(cmd);
break;
case CommandType.Unsub:
@@ -176,16 +224,56 @@ public sealed class NatsClient : IDisposable
}
}
private void ProcessConnect(ParsedCommand cmd)
private async ValueTask ProcessConnectAsync(ParsedCommand cmd)
{
ClientOpts = JsonSerializer.Deserialize<ClientOptions>(cmd.Payload.Span)
?? new ClientOptions();
// Authenticate if auth is required
if (_authService.IsAuthRequired)
{
var context = new ClientAuthContext
{
Opts = ClientOpts,
Nonce = _nonce ?? [],
};
var result = _authService.Authenticate(context);
if (result == null)
{
_logger.LogWarning("Client {ClientId} authentication failed", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation);
return;
}
// Build permissions from auth result
_permissions = ClientPermissions.Build(result.Permissions);
// Resolve account
if (Router is NatsServer server)
{
var accountName = result.AccountName ?? Account.GlobalAccountName;
Account = server.GetOrCreateAccount(accountName);
Account.AddClient(Id);
}
_logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity);
}
ConnectReceived = true;
_logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);
}
private void ProcessSub(ParsedCommand cmd)
private async ValueTask ProcessSubAsync(ParsedCommand cmd)
{
// Permission check for subscribe
if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue))
{
_logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject);
await SendErrAsync(NatsProtocol.ErrPermissionsSubscribe);
return;
}
var sub = new Subscription
{
Subject = cmd.Subject!,
@@ -244,6 +332,14 @@ public sealed class NatsClient : IDisposable
return;
}
// Permission check for publish
if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!))
{
_logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject);
await SendErrAsync(NatsProtocol.ErrPermissionsPublish);
return;
}
ReadOnlyMemory<byte> headers = default;
ReadOnlyMemory<byte> payload = cmd.Payload;

View File

@@ -3,6 +3,7 @@ using System.Net;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
@@ -17,6 +18,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private readonly ILogger<NatsServer> _logger;
private readonly ILoggerFactory _loggerFactory;
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly AuthService _authService;
private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
private readonly Account _globalAccount;
private Socket? _listener;
private ulong _nextClientId;
@@ -29,6 +33,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_options = options;
_loggerFactory = loggerFactory;
_logger = loggerFactory.CreateLogger<NatsServer>();
_authService = AuthService.Build(options);
_globalAccount = new Account(Account.GlobalAccountName);
_accounts[Account.GlobalAccountName] = _globalAccount;
_serverInfo = new ServerInfo
{
ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(),
@@ -37,6 +44,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
Host = options.Host,
Port = options.Port,
MaxPayload = options.MaxPayload,
AuthRequired = _authService.IsAuthRequired,
};
}
@@ -87,8 +95,27 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
// Build per-client ServerInfo with nonce if NKey auth is configured
byte[]? nonce = null;
var clientInfo = _serverInfo;
if (_authService.NonceRequired)
{
nonce = _authService.GenerateNonce();
clientInfo = new ServerInfo
{
ServerId = _serverInfo.ServerId,
ServerName = _serverInfo.ServerName,
Version = _serverInfo.Version,
Host = _serverInfo.Host,
Port = _serverInfo.Port,
MaxPayload = _serverInfo.MaxPayload,
AuthRequired = _serverInfo.AuthRequired,
Nonce = _authService.EncodeNonce(nonce),
};
}
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, socket, _options, _serverInfo, clientLogger);
var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
client.Router = this;
_clients[clientId] = client;
@@ -169,11 +196,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None);
}
public Account GetOrCreateAccount(string name)
{
return _accounts.GetOrAdd(name, n => new Account(n));
}
public void RemoveClient(NatsClient client)
{
_clients.TryRemove(client.Id, out _);
_logger.LogDebug("Removed client {ClientId}", client.Id);
client.RemoveAllSubscriptions(_subList);
client.Account?.RemoveClient(client.Id);
}
public void Dispose()
@@ -182,5 +215,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
foreach (var client in _clients.Values)
client.Dispose();
_subList.Dispose();
foreach (var account in _accounts.Values)
account.Dispose();
}
}

View File

@@ -0,0 +1,256 @@
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Client.Core;
using NATS.Server;
using NATS.Server.Auth;
namespace NATS.Server.Tests;
public class AuthIntegrationTests
{
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
/// <summary>
/// Checks whether any exception in the chain contains the given substring.
/// The NATS client wraps server errors in outer NatsException messages,
/// so the actual "Authorization Violation" may be in an inner exception.
/// </summary>
private static bool ExceptionChainContains(Exception ex, string substring)
{
Exception? current = ex;
while (current != null)
{
if (current.Message.Contains(substring, StringComparison.OrdinalIgnoreCase))
return true;
current = current.InnerException;
}
return false;
}
private static (NatsServer server, int port, CancellationTokenSource cts) StartServer(NatsOptions options)
{
var port = GetFreePort();
options.Port = port;
var server = new NatsServer(options, NullLoggerFactory.Instance);
var cts = new CancellationTokenSource();
_ = server.StartAsync(cts.Token);
return (server, port, cts);
}
private static async Task<(NatsServer server, int port, CancellationTokenSource cts)> StartServerAsync(NatsOptions options)
{
var (server, port, cts) = StartServer(options);
await server.WaitForReadyAsync();
return (server, port, cts);
}
[Fact]
public async Task Token_auth_success()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions
{
Authorization = "s3cr3t",
});
try
{
await using var client = new NatsConnection(new NatsOpts
{
Url = $"nats://s3cr3t@127.0.0.1:{port}",
});
await client.ConnectAsync();
await client.PingAsync();
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
[Fact]
public async Task Token_auth_failure_disconnects()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions
{
Authorization = "s3cr3t",
});
try
{
await using var client = new NatsConnection(new NatsOpts
{
Url = $"nats://wrongtoken@127.0.0.1:{port}",
MaxReconnectRetry = 0,
});
var ex = await Should.ThrowAsync<NatsException>(async () =>
{
await client.ConnectAsync();
await client.PingAsync();
});
ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue(
$"Expected 'Authorization Violation' in exception chain, but got: {ex}");
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
[Fact]
public async Task UserPassword_auth_success()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions
{
Username = "admin",
Password = "secret",
});
try
{
await using var client = new NatsConnection(new NatsOpts
{
Url = $"nats://admin:secret@127.0.0.1:{port}",
});
await client.ConnectAsync();
await client.PingAsync();
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
[Fact]
public async Task UserPassword_auth_failure_disconnects()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions
{
Username = "admin",
Password = "secret",
});
try
{
await using var client = new NatsConnection(new NatsOpts
{
Url = $"nats://admin:wrong@127.0.0.1:{port}",
MaxReconnectRetry = 0,
});
var ex = await Should.ThrowAsync<NatsException>(async () =>
{
await client.ConnectAsync();
await client.PingAsync();
});
ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue(
$"Expected 'Authorization Violation' in exception chain, but got: {ex}");
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
[Fact]
public async Task MultiUser_auth_success()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions
{
Users =
[
new User { Username = "alice", Password = "pass1" },
new User { Username = "bob", Password = "pass2" },
],
});
try
{
await using var alice = new NatsConnection(new NatsOpts
{
Url = $"nats://alice:pass1@127.0.0.1:{port}",
});
await using var bob = new NatsConnection(new NatsOpts
{
Url = $"nats://bob:pass2@127.0.0.1:{port}",
});
await alice.ConnectAsync();
await alice.PingAsync();
await bob.ConnectAsync();
await bob.PingAsync();
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
[Fact]
public async Task No_credentials_when_auth_required_disconnects()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions
{
Authorization = "s3cr3t",
});
try
{
await using var client = new NatsConnection(new NatsOpts
{
Url = $"nats://127.0.0.1:{port}",
MaxReconnectRetry = 0,
});
var ex = await Should.ThrowAsync<NatsException>(async () =>
{
await client.ConnectAsync();
await client.PingAsync();
});
ExceptionChainContains(ex, "Authorization Violation").ShouldBeTrue(
$"Expected 'Authorization Violation' in exception chain, but got: {ex}");
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
[Fact]
public async Task No_auth_configured_allows_all()
{
var (server, port, cts) = await StartServerAsync(new NatsOptions());
try
{
await using var client = new NatsConnection(new NatsOpts
{
Url = $"nats://127.0.0.1:{port}",
});
await client.ConnectAsync();
await client.PingAsync();
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
}

View File

@@ -6,6 +6,7 @@ using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Auth;
using NATS.Server.Protocol;
namespace NATS.Server.Tests;
@@ -39,7 +40,8 @@ public class ClientTests : IAsyncDisposable
Port = 4222,
};
_natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, NullLogger.Instance);
var authService = AuthService.Build(new NatsOptions());
_natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance);
}
public async ValueTask DisposeAsync()