From 3b6bd08248822dc892ab0fec9d49fe5b6b3c0c42 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 22:40:03 -0500 Subject: [PATCH] feat: add TLS mixed mode tests and monitoring TLS field verification Add TlsMixedModeTests verifying that a server with AllowNonTls=true accepts both plaintext and TLS clients on the same port. Add MonitorTlsTests verifying that /connz reports TlsVersion and TlsCipherSuite for TLS-connected clients. --- tests/NATS.Server.Tests/MonitorTests.cs | 88 +++++++++++++++++++++++ tests/NATS.Server.Tests/TlsServerTests.cs | 88 +++++++++++++++++++++++ 2 files changed, 176 insertions(+) diff --git a/tests/NATS.Server.Tests/MonitorTests.cs b/tests/NATS.Server.Tests/MonitorTests.cs index 898557a..65a1399 100644 --- a/tests/NATS.Server.Tests/MonitorTests.cs +++ b/tests/NATS.Server.Tests/MonitorTests.cs @@ -1,6 +1,8 @@ using System.Net; using System.Net.Http.Json; +using System.Net.Security; using System.Net.Sockets; +using System.Text; using Microsoft.Extensions.Logging.Abstractions; using NATS.Server.Monitoring; @@ -184,3 +186,89 @@ public class MonitorTests : IAsyncLifetime return ((IPEndPoint)sock.LocalEndPoint!).Port; } } + +public class MonitorTlsTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _natsPort; + private readonly int _monitorPort; + private readonly CancellationTokenSource _cts = new(); + private readonly HttpClient _http = new(); + private readonly string _certPath; + private readonly string _keyPath; + + public MonitorTlsTests() + { + _natsPort = GetFreePort(); + _monitorPort = GetFreePort(); + (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles(); + _server = new NatsServer( + new NatsOptions + { + Port = _natsPort, + MonitorPort = _monitorPort, + TlsCert = _certPath, + TlsKey = _keyPath, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + // Wait for monitoring HTTP server to be ready + for (int i = 0; i < 50; i++) + { + try + { + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz"); + if (response.IsSuccessStatusCode) break; + } + catch (HttpRequestException) { } + await Task.Delay(50); + } + } + + public async Task DisposeAsync() + { + _http.Dispose(); + await _cts.CancelAsync(); + _server.Dispose(); + File.Delete(_certPath); + File.Delete(_keyPath); + } + + [Fact] + public async Task Connz_shows_tls_info_for_tls_client() + { + // Connect and upgrade to TLS + using var tcp = new TcpClient(); + await tcp.ConnectAsync(IPAddress.Loopback, _natsPort); + using var netStream = tcp.GetStream(); + var buf = new byte[4096]; + _ = await netStream.ReadAsync(buf); // Read INFO + + using var ssl = new SslStream(netStream, false, (_, _, _, _) => true); + await ssl.AuthenticateAsClientAsync("localhost"); + + await ssl.WriteAsync("CONNECT {}\r\n"u8.ToArray()); + await ssl.FlushAsync(); + await Task.Delay(200); + + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz"); + var connz = await response.Content.ReadFromJsonAsync(); + + connz!.Conns.Length.ShouldBeGreaterThanOrEqualTo(1); + var conn = connz.Conns[0]; + conn.TlsVersion.ShouldNotBeNullOrEmpty(); + conn.TlsCipherSuite.ShouldNotBeNullOrEmpty(); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +} diff --git a/tests/NATS.Server.Tests/TlsServerTests.cs b/tests/NATS.Server.Tests/TlsServerTests.cs index 8e87a75..703b1bf 100644 --- a/tests/NATS.Server.Tests/TlsServerTests.cs +++ b/tests/NATS.Server.Tests/TlsServerTests.cs @@ -135,3 +135,91 @@ public class TlsServerTests : IAsyncLifetime return ((IPEndPoint)sock.LocalEndPoint!).Port; } } + +public class TlsMixedModeTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + private readonly string _certPath; + private readonly string _keyPath; + + public TlsMixedModeTests() + { + _port = GetFreePort(); + (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles(); + _server = new NatsServer( + new NatsOptions + { + Port = _port, + TlsCert = _certPath, + TlsKey = _keyPath, + AllowNonTls = true, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + File.Delete(_certPath); + File.Delete(_keyPath); + } + + [Fact] + public async Task Mixed_mode_accepts_plain_client() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port)); + using var stream = new NetworkStream(sock); + + var buf = new byte[4096]; + var read = await stream.ReadAsync(buf); + var info = Encoding.ASCII.GetString(buf, 0, read); + info.ShouldContain("\"tls_available\":true"); + + await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray()); + await stream.FlushAsync(); + + var pongBuf = new byte[64]; + read = await stream.ReadAsync(pongBuf); + var pong = Encoding.ASCII.GetString(pongBuf, 0, read); + pong.ShouldContain("PONG"); + } + + [Fact] + public async Task Mixed_mode_accepts_tls_client() + { + using var tcp = new TcpClient(); + await tcp.ConnectAsync(IPAddress.Loopback, _port); + using var netStream = tcp.GetStream(); + + var buf = new byte[4096]; + _ = await netStream.ReadAsync(buf); // Read INFO + + using var ssl = new SslStream(netStream, false, (_, _, _, _) => true); + await ssl.AuthenticateAsClientAsync("localhost"); + + await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray()); + await ssl.FlushAsync(); + + var pongBuf = new byte[64]; + var read = await ssl.ReadAsync(pongBuf); + var pong = Encoding.ASCII.GetString(pongBuf, 0, read); + pong.ShouldContain("PONG"); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +}