feat: wire TLS negotiation into NatsServer accept loop
Integrate TLS support into the server's connection accept path: - Add SslServerAuthenticationOptions and TlsRateLimiter fields to NatsServer - Extract AcceptClientAsync method for TLS negotiation, rate limiting, and TLS state extraction (protocol version, cipher suite, peer certificate) - Add InfoAlreadySent flag to NatsClient to skip redundant INFO when TlsConnectionWrapper already sent it during negotiation - Add TlsServerTests verifying TLS connect+INFO and TLS pub/sub
This commit is contained in:
@@ -57,6 +57,7 @@ public sealed class NatsClient : IDisposable
|
|||||||
private long _lastIn;
|
private long _lastIn;
|
||||||
|
|
||||||
public TlsConnectionState? TlsState { get; set; }
|
public TlsConnectionState? TlsState { get; set; }
|
||||||
|
public bool InfoAlreadySent { get; set; }
|
||||||
|
|
||||||
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
|
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
|
||||||
|
|
||||||
@@ -87,7 +88,8 @@ public sealed class NatsClient : IDisposable
|
|||||||
var pipe = new Pipe();
|
var pipe = new Pipe();
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// Send INFO
|
// Send INFO (skip if already sent during TLS negotiation)
|
||||||
|
if (!InfoAlreadySent)
|
||||||
await SendInfoAsync(_clientCts.Token);
|
await SendInfoAsync(_clientCts.Token);
|
||||||
|
|
||||||
// Start read pump, command processing, and ping timer in parallel
|
// Start read pump, command processing, and ping timer in parallel
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
using System.Collections.Concurrent;
|
using System.Collections.Concurrent;
|
||||||
using System.Net;
|
using System.Net;
|
||||||
|
using System.Net.Security;
|
||||||
using System.Net.Sockets;
|
using System.Net.Sockets;
|
||||||
|
using System.Security.Cryptography.X509Certificates;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using NATS.Server.Monitoring;
|
using NATS.Server.Monitoring;
|
||||||
using NATS.Server.Protocol;
|
using NATS.Server.Protocol;
|
||||||
using NATS.Server.Subscriptions;
|
using NATS.Server.Subscriptions;
|
||||||
|
using NATS.Server.Tls;
|
||||||
|
|
||||||
namespace NATS.Server;
|
namespace NATS.Server;
|
||||||
|
|
||||||
@@ -19,6 +22,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
private readonly ILoggerFactory _loggerFactory;
|
private readonly ILoggerFactory _loggerFactory;
|
||||||
private readonly ServerStats _stats = new();
|
private readonly ServerStats _stats = new();
|
||||||
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
|
private readonly SslServerAuthenticationOptions? _sslOptions;
|
||||||
|
private readonly TlsRateLimiter? _tlsRateLimiter;
|
||||||
private Socket? _listener;
|
private Socket? _listener;
|
||||||
private MonitorServer? _monitorServer;
|
private MonitorServer? _monitorServer;
|
||||||
private ulong _nextClientId;
|
private ulong _nextClientId;
|
||||||
@@ -48,6 +53,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
Port = options.Port,
|
Port = options.Port,
|
||||||
MaxPayload = options.MaxPayload,
|
MaxPayload = options.MaxPayload,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (options.HasTls)
|
||||||
|
{
|
||||||
|
_sslOptions = TlsHelper.BuildServerAuthOptions(options);
|
||||||
|
_serverInfo.TlsRequired = !options.AllowNonTls;
|
||||||
|
_serverInfo.TlsAvailable = options.AllowNonTls;
|
||||||
|
_serverInfo.TlsVerify = options.TlsVerify;
|
||||||
|
|
||||||
|
if (options.TlsRateLimit > 0)
|
||||||
|
_tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task StartAsync(CancellationToken ct)
|
public async Task StartAsync(CancellationToken ct)
|
||||||
@@ -105,13 +121,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
|
|
||||||
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
|
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
|
||||||
|
|
||||||
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
|
_ = AcceptClientAsync(socket, clientId, ct);
|
||||||
var networkStream = new NetworkStream(socket, ownsSocket: false);
|
|
||||||
var client = new NatsClient(clientId, networkStream, socket, _options, _serverInfo, clientLogger, _stats);
|
|
||||||
client.Router = this;
|
|
||||||
_clients[clientId] = client;
|
|
||||||
|
|
||||||
_ = RunClientAsync(client, ct);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (OperationCanceledException)
|
catch (OperationCanceledException)
|
||||||
@@ -120,6 +130,49 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
// Rate limit TLS handshakes
|
||||||
|
if (_tlsRateLimiter != null)
|
||||||
|
await _tlsRateLimiter.WaitAsync(ct);
|
||||||
|
|
||||||
|
var networkStream = new NetworkStream(socket, ownsSocket: false);
|
||||||
|
|
||||||
|
// TLS negotiation (no-op if not configured)
|
||||||
|
var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||||
|
socket, networkStream, _options, _sslOptions, _serverInfo,
|
||||||
|
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
|
||||||
|
|
||||||
|
// Extract TLS state
|
||||||
|
TlsConnectionState? tlsState = null;
|
||||||
|
if (stream is SslStream ssl)
|
||||||
|
{
|
||||||
|
tlsState = new TlsConnectionState(
|
||||||
|
ssl.SslProtocol.ToString(),
|
||||||
|
ssl.NegotiatedCipherSuite.ToString(),
|
||||||
|
ssl.RemoteCertificate as X509Certificate2);
|
||||||
|
}
|
||||||
|
|
||||||
|
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
|
||||||
|
var client = new NatsClient(clientId, stream, socket, _options, _serverInfo,
|
||||||
|
clientLogger, _stats);
|
||||||
|
client.Router = this;
|
||||||
|
client.TlsState = tlsState;
|
||||||
|
client.InfoAlreadySent = infoAlreadySent;
|
||||||
|
_clients[clientId] = client;
|
||||||
|
|
||||||
|
await RunClientAsync(client, ct);
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
_logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId);
|
||||||
|
try { socket.Shutdown(SocketShutdown.Both); } catch { }
|
||||||
|
socket.Dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
|
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
@@ -199,6 +252,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
{
|
{
|
||||||
if (_monitorServer != null)
|
if (_monitorServer != null)
|
||||||
_monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult();
|
_monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult();
|
||||||
|
_tlsRateLimiter?.Dispose();
|
||||||
_listener?.Dispose();
|
_listener?.Dispose();
|
||||||
foreach (var client in _clients.Values)
|
foreach (var client in _clients.Values)
|
||||||
client.Dispose();
|
client.Dispose();
|
||||||
|
|||||||
137
tests/NATS.Server.Tests/TlsServerTests.cs
Normal file
137
tests/NATS.Server.Tests/TlsServerTests.cs
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
using System.Net;
|
||||||
|
using System.Net.Security;
|
||||||
|
using System.Net.Sockets;
|
||||||
|
using System.Text;
|
||||||
|
using Microsoft.Extensions.Logging.Abstractions;
|
||||||
|
using NATS.Server;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests;
|
||||||
|
|
||||||
|
public class TlsServerTests : IAsyncLifetime
|
||||||
|
{
|
||||||
|
private readonly NatsServer _server;
|
||||||
|
private readonly int _port;
|
||||||
|
private readonly CancellationTokenSource _cts = new();
|
||||||
|
private readonly string _certPath;
|
||||||
|
private readonly string _keyPath;
|
||||||
|
|
||||||
|
public TlsServerTests()
|
||||||
|
{
|
||||||
|
_port = GetFreePort();
|
||||||
|
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
|
||||||
|
_server = new NatsServer(
|
||||||
|
new NatsOptions
|
||||||
|
{
|
||||||
|
Port = _port,
|
||||||
|
TlsCert = _certPath,
|
||||||
|
TlsKey = _keyPath,
|
||||||
|
},
|
||||||
|
NullLoggerFactory.Instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task InitializeAsync()
|
||||||
|
{
|
||||||
|
_ = _server.StartAsync(_cts.Token);
|
||||||
|
await _server.WaitForReadyAsync();
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task DisposeAsync()
|
||||||
|
{
|
||||||
|
await _cts.CancelAsync();
|
||||||
|
_server.Dispose();
|
||||||
|
File.Delete(_certPath);
|
||||||
|
File.Delete(_keyPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Tls_client_connects_and_receives_info()
|
||||||
|
{
|
||||||
|
using var tcp = new TcpClient();
|
||||||
|
await tcp.ConnectAsync(IPAddress.Loopback, _port);
|
||||||
|
using var netStream = tcp.GetStream();
|
||||||
|
|
||||||
|
// Read INFO (sent before TLS upgrade in Mode 2)
|
||||||
|
var buf = new byte[4096];
|
||||||
|
var read = await netStream.ReadAsync(buf);
|
||||||
|
var info = Encoding.ASCII.GetString(buf, 0, read);
|
||||||
|
info.ShouldStartWith("INFO ");
|
||||||
|
info.ShouldContain("\"tls_required\":true");
|
||||||
|
|
||||||
|
// Upgrade to TLS
|
||||||
|
using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||||
|
await sslStream.AuthenticateAsClientAsync("localhost");
|
||||||
|
|
||||||
|
// Send CONNECT + PING over TLS
|
||||||
|
await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||||
|
await sslStream.FlushAsync();
|
||||||
|
|
||||||
|
// Read PONG
|
||||||
|
var pongBuf = new byte[256];
|
||||||
|
read = await sslStream.ReadAsync(pongBuf);
|
||||||
|
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||||
|
pong.ShouldContain("PONG");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Tls_pubsub_works_over_encrypted_connection()
|
||||||
|
{
|
||||||
|
using var tcp1 = new TcpClient();
|
||||||
|
await tcp1.ConnectAsync(IPAddress.Loopback, _port);
|
||||||
|
using var ssl1 = await UpgradeToTlsAsync(tcp1);
|
||||||
|
|
||||||
|
using var tcp2 = new TcpClient();
|
||||||
|
await tcp2.ConnectAsync(IPAddress.Loopback, _port);
|
||||||
|
using var ssl2 = await UpgradeToTlsAsync(tcp2);
|
||||||
|
|
||||||
|
// Sub on client 1
|
||||||
|
await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
|
||||||
|
await ssl1.FlushAsync();
|
||||||
|
|
||||||
|
// Wait for PONG to confirm subscription is registered
|
||||||
|
var pongBuf = new byte[256];
|
||||||
|
var pongRead = await ssl1.ReadAsync(pongBuf);
|
||||||
|
var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead);
|
||||||
|
pongStr.ShouldContain("PONG");
|
||||||
|
|
||||||
|
// Pub on client 2
|
||||||
|
await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
|
||||||
|
await ssl2.FlushAsync();
|
||||||
|
|
||||||
|
// Client 1 should receive MSG (may arrive across multiple TLS records)
|
||||||
|
var msg = await ReadUntilAsync(ssl1, "hello");
|
||||||
|
msg.ShouldContain("MSG test 1 5");
|
||||||
|
msg.ShouldContain("hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000)
|
||||||
|
{
|
||||||
|
using var cts = new CancellationTokenSource(timeoutMs);
|
||||||
|
var sb = new StringBuilder();
|
||||||
|
var buf = new byte[4096];
|
||||||
|
while (!sb.ToString().Contains(expected))
|
||||||
|
{
|
||||||
|
var n = await stream.ReadAsync(buf, cts.Token);
|
||||||
|
if (n == 0) break;
|
||||||
|
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
|
||||||
|
}
|
||||||
|
return sb.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
|
||||||
|
{
|
||||||
|
var netStream = tcp.GetStream();
|
||||||
|
var buf = new byte[4096];
|
||||||
|
_ = await netStream.ReadAsync(buf); // Read INFO (discard)
|
||||||
|
|
||||||
|
var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||||
|
await ssl.AuthenticateAsClientAsync("localhost");
|
||||||
|
return ssl;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int GetFreePort()
|
||||||
|
{
|
||||||
|
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||||
|
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||||
|
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user