feat: integrate authentication into server accept loop and client CONNECT processing
Wire AuthService into NatsServer and NatsClient to enforce authentication on incoming connections. The server builds an AuthService from NatsOptions, sets auth_required in ServerInfo, and generates per-client nonces when NKey auth is configured. NatsClient validates credentials in ProcessConnect, enforces publish/subscribe permissions, and implements an auth timeout that closes connections that don't send CONNECT in time. Existing tests without auth continue to work since AuthService.IsAuthRequired is false by default.
This commit is contained in:
@@ -4,6 +4,7 @@ using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Auth;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
|
||||
@@ -27,16 +28,20 @@ public sealed class NatsClient : IDisposable
|
||||
private readonly NetworkStream _stream;
|
||||
private readonly NatsOptions _options;
|
||||
private readonly ServerInfo _serverInfo;
|
||||
private readonly AuthService _authService;
|
||||
private readonly byte[]? _nonce;
|
||||
private readonly NatsParser _parser;
|
||||
private readonly SemaphoreSlim _writeLock = new(1, 1);
|
||||
private CancellationTokenSource? _clientCts;
|
||||
private readonly Dictionary<string, Subscription> _subs = new();
|
||||
private readonly ILogger _logger;
|
||||
private ClientPermissions? _permissions;
|
||||
|
||||
public ulong Id { get; }
|
||||
public ClientOptions? ClientOpts { get; private set; }
|
||||
public IMessageRouter? Router { get; set; }
|
||||
public bool ConnectReceived { get; private set; }
|
||||
public Account? Account { get; private set; }
|
||||
|
||||
// Stats
|
||||
public long InMsgs;
|
||||
@@ -50,13 +55,16 @@ public sealed class NatsClient : IDisposable
|
||||
|
||||
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
|
||||
|
||||
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, ILogger logger)
|
||||
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo,
|
||||
AuthService authService, byte[]? nonce, ILogger logger)
|
||||
{
|
||||
Id = id;
|
||||
_socket = socket;
|
||||
_stream = new NetworkStream(socket, ownsSocket: false);
|
||||
_options = options;
|
||||
_serverInfo = serverInfo;
|
||||
_authService = authService;
|
||||
_nonce = nonce;
|
||||
_logger = logger;
|
||||
_parser = new NatsParser(options.MaxPayload);
|
||||
}
|
||||
@@ -71,6 +79,28 @@ public sealed class NatsClient : IDisposable
|
||||
// Send INFO
|
||||
await SendInfoAsync(_clientCts.Token);
|
||||
|
||||
// Start auth timeout if auth is required
|
||||
Task? authTimeoutTask = null;
|
||||
if (_authService.IsAuthRequired)
|
||||
{
|
||||
authTimeoutTask = Task.Run(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
await Task.Delay(_options.AuthTimeout, _clientCts!.Token);
|
||||
if (!ConnectReceived)
|
||||
{
|
||||
_logger.LogDebug("Client {ClientId} auth timeout", Id);
|
||||
await SendErrAndCloseAsync(NatsProtocol.ErrAuthTimeout);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Normal — client connected or was cancelled
|
||||
}
|
||||
}, _clientCts.Token);
|
||||
}
|
||||
|
||||
// Start read pump, command processing, and ping timer in parallel
|
||||
var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token);
|
||||
var processTask = ProcessCommandsAsync(pipe.Reader, _clientCts.Token);
|
||||
@@ -147,10 +177,28 @@ public sealed class NatsClient : IDisposable
|
||||
|
||||
private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
|
||||
{
|
||||
// If auth is required and CONNECT hasn't been received yet,
|
||||
// only allow CONNECT and PING commands
|
||||
if (_authService.IsAuthRequired && !ConnectReceived)
|
||||
{
|
||||
switch (cmd.Type)
|
||||
{
|
||||
case CommandType.Connect:
|
||||
await ProcessConnectAsync(cmd);
|
||||
return;
|
||||
case CommandType.Ping:
|
||||
await WriteAsync(NatsProtocol.PongBytes, ct);
|
||||
return;
|
||||
default:
|
||||
// Ignore all other commands until authenticated
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
switch (cmd.Type)
|
||||
{
|
||||
case CommandType.Connect:
|
||||
ProcessConnect(cmd);
|
||||
await ProcessConnectAsync(cmd);
|
||||
break;
|
||||
|
||||
case CommandType.Ping:
|
||||
@@ -162,7 +210,7 @@ public sealed class NatsClient : IDisposable
|
||||
break;
|
||||
|
||||
case CommandType.Sub:
|
||||
ProcessSub(cmd);
|
||||
await ProcessSubAsync(cmd);
|
||||
break;
|
||||
|
||||
case CommandType.Unsub:
|
||||
@@ -176,16 +224,56 @@ public sealed class NatsClient : IDisposable
|
||||
}
|
||||
}
|
||||
|
||||
private void ProcessConnect(ParsedCommand cmd)
|
||||
private async ValueTask ProcessConnectAsync(ParsedCommand cmd)
|
||||
{
|
||||
ClientOpts = JsonSerializer.Deserialize<ClientOptions>(cmd.Payload.Span)
|
||||
?? new ClientOptions();
|
||||
|
||||
// Authenticate if auth is required
|
||||
if (_authService.IsAuthRequired)
|
||||
{
|
||||
var context = new ClientAuthContext
|
||||
{
|
||||
Opts = ClientOpts,
|
||||
Nonce = _nonce ?? [],
|
||||
};
|
||||
|
||||
var result = _authService.Authenticate(context);
|
||||
if (result == null)
|
||||
{
|
||||
_logger.LogWarning("Client {ClientId} authentication failed", Id);
|
||||
await SendErrAndCloseAsync(NatsProtocol.ErrAuthorizationViolation);
|
||||
return;
|
||||
}
|
||||
|
||||
// Build permissions from auth result
|
||||
_permissions = ClientPermissions.Build(result.Permissions);
|
||||
|
||||
// Resolve account
|
||||
if (Router is NatsServer server)
|
||||
{
|
||||
var accountName = result.AccountName ?? Account.GlobalAccountName;
|
||||
Account = server.GetOrCreateAccount(accountName);
|
||||
Account.AddClient(Id);
|
||||
}
|
||||
|
||||
_logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity);
|
||||
}
|
||||
|
||||
ConnectReceived = true;
|
||||
_logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);
|
||||
}
|
||||
|
||||
private void ProcessSub(ParsedCommand cmd)
|
||||
private async ValueTask ProcessSubAsync(ParsedCommand cmd)
|
||||
{
|
||||
// Permission check for subscribe
|
||||
if (_permissions != null && !_permissions.IsSubscribeAllowed(cmd.Subject!, cmd.Queue))
|
||||
{
|
||||
_logger.LogDebug("Client {ClientId} subscribe permission denied for {Subject}", Id, cmd.Subject);
|
||||
await SendErrAsync(NatsProtocol.ErrPermissionsSubscribe);
|
||||
return;
|
||||
}
|
||||
|
||||
var sub = new Subscription
|
||||
{
|
||||
Subject = cmd.Subject!,
|
||||
@@ -244,6 +332,14 @@ public sealed class NatsClient : IDisposable
|
||||
return;
|
||||
}
|
||||
|
||||
// Permission check for publish
|
||||
if (_permissions != null && !_permissions.IsPublishAllowed(cmd.Subject!))
|
||||
{
|
||||
_logger.LogDebug("Client {ClientId} publish permission denied for {Subject}", Id, cmd.Subject);
|
||||
await SendErrAsync(NatsProtocol.ErrPermissionsPublish);
|
||||
return;
|
||||
}
|
||||
|
||||
ReadOnlyMemory<byte> headers = default;
|
||||
ReadOnlyMemory<byte> payload = cmd.Payload;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Auth;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
|
||||
@@ -17,6 +18,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
private readonly ILogger<NatsServer> _logger;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
private readonly AuthService _authService;
|
||||
private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
|
||||
private readonly Account _globalAccount;
|
||||
private Socket? _listener;
|
||||
private ulong _nextClientId;
|
||||
|
||||
@@ -29,6 +33,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
_options = options;
|
||||
_loggerFactory = loggerFactory;
|
||||
_logger = loggerFactory.CreateLogger<NatsServer>();
|
||||
_authService = AuthService.Build(options);
|
||||
_globalAccount = new Account(Account.GlobalAccountName);
|
||||
_accounts[Account.GlobalAccountName] = _globalAccount;
|
||||
_serverInfo = new ServerInfo
|
||||
{
|
||||
ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(),
|
||||
@@ -37,6 +44,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
Host = options.Host,
|
||||
Port = options.Port,
|
||||
MaxPayload = options.MaxPayload,
|
||||
AuthRequired = _authService.IsAuthRequired,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -87,8 +95,27 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
|
||||
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
|
||||
|
||||
// Build per-client ServerInfo with nonce if NKey auth is configured
|
||||
byte[]? nonce = null;
|
||||
var clientInfo = _serverInfo;
|
||||
if (_authService.NonceRequired)
|
||||
{
|
||||
nonce = _authService.GenerateNonce();
|
||||
clientInfo = new ServerInfo
|
||||
{
|
||||
ServerId = _serverInfo.ServerId,
|
||||
ServerName = _serverInfo.ServerName,
|
||||
Version = _serverInfo.Version,
|
||||
Host = _serverInfo.Host,
|
||||
Port = _serverInfo.Port,
|
||||
MaxPayload = _serverInfo.MaxPayload,
|
||||
AuthRequired = _serverInfo.AuthRequired,
|
||||
Nonce = _authService.EncodeNonce(nonce),
|
||||
};
|
||||
}
|
||||
|
||||
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
|
||||
var client = new NatsClient(clientId, socket, _options, _serverInfo, clientLogger);
|
||||
var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
|
||||
client.Router = this;
|
||||
_clients[clientId] = client;
|
||||
|
||||
@@ -169,11 +196,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
_ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None);
|
||||
}
|
||||
|
||||
public Account GetOrCreateAccount(string name)
|
||||
{
|
||||
return _accounts.GetOrAdd(name, n => new Account(n));
|
||||
}
|
||||
|
||||
public void RemoveClient(NatsClient client)
|
||||
{
|
||||
_clients.TryRemove(client.Id, out _);
|
||||
_logger.LogDebug("Removed client {ClientId}", client.Id);
|
||||
client.RemoveAllSubscriptions(_subList);
|
||||
client.Account?.RemoveClient(client.Id);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
@@ -182,5 +215,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
foreach (var client in _clients.Values)
|
||||
client.Dispose();
|
||||
_subList.Dispose();
|
||||
foreach (var account in _accounts.Values)
|
||||
account.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user