From 87746168ba90ba4e375c2a45e979437dc4094109 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 22:35:42 -0500 Subject: [PATCH] feat: wire TLS negotiation into NatsServer accept loop Integrate TLS support into the server's connection accept path: - Add SslServerAuthenticationOptions and TlsRateLimiter fields to NatsServer - Extract AcceptClientAsync method for TLS negotiation, rate limiting, and TLS state extraction (protocol version, cipher suite, peer certificate) - Add InfoAlreadySent flag to NatsClient to skip redundant INFO when TlsConnectionWrapper already sent it during negotiation - Add TlsServerTests verifying TLS connect+INFO and TLS pub/sub --- src/NATS.Server/NatsClient.cs | 6 +- src/NATS.Server/NatsServer.cs | 68 +++++++++-- tests/NATS.Server.Tests/TlsServerTests.cs | 137 ++++++++++++++++++++++ 3 files changed, 202 insertions(+), 9 deletions(-) create mode 100644 tests/NATS.Server.Tests/TlsServerTests.cs diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 2978e25..e70c453 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -57,6 +57,7 @@ public sealed class NatsClient : IDisposable private long _lastIn; public TlsConnectionState? TlsState { get; set; } + public bool InfoAlreadySent { get; set; } public IReadOnlyDictionary Subscriptions => _subs; @@ -87,8 +88,9 @@ public sealed class NatsClient : IDisposable var pipe = new Pipe(); try { - // Send INFO - await SendInfoAsync(_clientCts.Token); + // Send INFO (skip if already sent during TLS negotiation) + if (!InfoAlreadySent) + await SendInfoAsync(_clientCts.Token); // Start read pump, command processing, and ping timer in parallel var fillTask = FillPipeAsync(pipe.Writer, _clientCts.Token); diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 4315cd7..a9db6fb 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -1,11 +1,14 @@ using System.Collections.Concurrent; using System.Net; +using System.Net.Security; using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; using System.Text; using Microsoft.Extensions.Logging; using NATS.Server.Monitoring; using NATS.Server.Protocol; using NATS.Server.Subscriptions; +using NATS.Server.Tls; namespace NATS.Server; @@ -19,6 +22,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly ILoggerFactory _loggerFactory; private readonly ServerStats _stats = new(); private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly SslServerAuthenticationOptions? _sslOptions; + private readonly TlsRateLimiter? _tlsRateLimiter; private Socket? _listener; private MonitorServer? _monitorServer; private ulong _nextClientId; @@ -48,6 +53,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable Port = options.Port, MaxPayload = options.MaxPayload, }; + + if (options.HasTls) + { + _sslOptions = TlsHelper.BuildServerAuthOptions(options); + _serverInfo.TlsRequired = !options.AllowNonTls; + _serverInfo.TlsAvailable = options.AllowNonTls; + _serverInfo.TlsVerify = options.TlsVerify; + + if (options.TlsRateLimit > 0) + _tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit); + } } public async Task StartAsync(CancellationToken ct) @@ -105,13 +121,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); - var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); - var networkStream = new NetworkStream(socket, ownsSocket: false); - var client = new NatsClient(clientId, networkStream, socket, _options, _serverInfo, clientLogger, _stats); - client.Router = this; - _clients[clientId] = client; - - _ = RunClientAsync(client, ct); + _ = AcceptClientAsync(socket, clientId, ct); } } catch (OperationCanceledException) @@ -120,6 +130,49 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } } + private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct) + { + try + { + // Rate limit TLS handshakes + if (_tlsRateLimiter != null) + await _tlsRateLimiter.WaitAsync(ct); + + var networkStream = new NetworkStream(socket, ownsSocket: false); + + // TLS negotiation (no-op if not configured) + var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync( + socket, networkStream, _options, _sslOptions, _serverInfo, + _loggerFactory.CreateLogger("NATS.Server.Tls"), ct); + + // Extract TLS state + TlsConnectionState? tlsState = null; + if (stream is SslStream ssl) + { + tlsState = new TlsConnectionState( + ssl.SslProtocol.ToString(), + ssl.NegotiatedCipherSuite.ToString(), + ssl.RemoteCertificate as X509Certificate2); + } + + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); + var client = new NatsClient(clientId, stream, socket, _options, _serverInfo, + clientLogger, _stats); + client.Router = this; + client.TlsState = tlsState; + client.InfoAlreadySent = infoAlreadySent; + _clients[clientId] = client; + + await RunClientAsync(client, ct); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId); + try { socket.Shutdown(SocketShutdown.Both); } catch { } + socket.Dispose(); + } + } + private async Task RunClientAsync(NatsClient client, CancellationToken ct) { try @@ -199,6 +252,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable { if (_monitorServer != null) _monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult(); + _tlsRateLimiter?.Dispose(); _listener?.Dispose(); foreach (var client in _clients.Values) client.Dispose(); diff --git a/tests/NATS.Server.Tests/TlsServerTests.cs b/tests/NATS.Server.Tests/TlsServerTests.cs new file mode 100644 index 0000000..8e87a75 --- /dev/null +++ b/tests/NATS.Server.Tests/TlsServerTests.cs @@ -0,0 +1,137 @@ +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Text; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server; + +namespace NATS.Server.Tests; + +public class TlsServerTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + private readonly string _certPath; + private readonly string _keyPath; + + public TlsServerTests() + { + _port = GetFreePort(); + (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles(); + _server = new NatsServer( + new NatsOptions + { + Port = _port, + TlsCert = _certPath, + TlsKey = _keyPath, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + File.Delete(_certPath); + File.Delete(_keyPath); + } + + [Fact] + public async Task Tls_client_connects_and_receives_info() + { + using var tcp = new TcpClient(); + await tcp.ConnectAsync(IPAddress.Loopback, _port); + using var netStream = tcp.GetStream(); + + // Read INFO (sent before TLS upgrade in Mode 2) + var buf = new byte[4096]; + var read = await netStream.ReadAsync(buf); + var info = Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + info.ShouldContain("\"tls_required\":true"); + + // Upgrade to TLS + using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true); + await sslStream.AuthenticateAsClientAsync("localhost"); + + // Send CONNECT + PING over TLS + await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray()); + await sslStream.FlushAsync(); + + // Read PONG + var pongBuf = new byte[256]; + read = await sslStream.ReadAsync(pongBuf); + var pong = Encoding.ASCII.GetString(pongBuf, 0, read); + pong.ShouldContain("PONG"); + } + + [Fact] + public async Task Tls_pubsub_works_over_encrypted_connection() + { + using var tcp1 = new TcpClient(); + await tcp1.ConnectAsync(IPAddress.Loopback, _port); + using var ssl1 = await UpgradeToTlsAsync(tcp1); + + using var tcp2 = new TcpClient(); + await tcp2.ConnectAsync(IPAddress.Loopback, _port); + using var ssl2 = await UpgradeToTlsAsync(tcp2); + + // Sub on client 1 + await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray()); + await ssl1.FlushAsync(); + + // Wait for PONG to confirm subscription is registered + var pongBuf = new byte[256]; + var pongRead = await ssl1.ReadAsync(pongBuf); + var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead); + pongStr.ShouldContain("PONG"); + + // Pub on client 2 + await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray()); + await ssl2.FlushAsync(); + + // Client 1 should receive MSG (may arrive across multiple TLS records) + var msg = await ReadUntilAsync(ssl1, "hello"); + msg.ShouldContain("MSG test 1 5"); + msg.ShouldContain("hello"); + } + + private static async Task ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000) + { + using var cts = new CancellationTokenSource(timeoutMs); + var sb = new StringBuilder(); + var buf = new byte[4096]; + while (!sb.ToString().Contains(expected)) + { + var n = await stream.ReadAsync(buf, cts.Token); + if (n == 0) break; + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + } + return sb.ToString(); + } + + private static async Task UpgradeToTlsAsync(TcpClient tcp) + { + var netStream = tcp.GetStream(); + var buf = new byte[4096]; + _ = await netStream.ReadAsync(buf); // Read INFO (discard) + + var ssl = new SslStream(netStream, false, (_, _, _, _) => true); + await ssl.AuthenticateAsClientAsync("localhost"); + return ssl; + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +}