using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;

namespace NATS.Server.Tests;

/// <summary>
/// End-to-end tests against a <see cref="NatsServer"/> configured with a TLS certificate and key.
/// Clients read the plain-text INFO line first, then upgrade the socket to TLS before
/// sending any protocol commands.
/// </summary>
public class TlsServerTests : IAsyncLifetime
{
    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();
    private readonly string _certPath;
    private readonly string _keyPath;

    public TlsServerTests()
    {
        _port = GetFreePort();
        (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                TlsCert = _certPath,
                TlsKey = _keyPath,
            },
            NullLoggerFactory.Instance);
    }

    public async Task InitializeAsync()
    {
        // The server task is not awaited here; it is stopped via _cts in DisposeAsync.
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    public async Task DisposeAsync()
    {
        await _cts.CancelAsync();
        _server.Dispose();
        File.Delete(_certPath);
        File.Delete(_keyPath);
    }

    [Fact]
    public async Task Tls_client_connects_and_receives_info()
    {
        using var tcp = new TcpClient();
        await tcp.ConnectAsync(IPAddress.Loopback, _port);
        using var netStream = tcp.GetStream();

        // Read INFO (sent before TLS upgrade in Mode 2). Read the whole line: a single
        // ReadAsync may return only part of it, and unread INFO bytes would then be
        // consumed by the TLS handshake below.
        var info = await ReadUntilAsync(netStream, "\r\n");
        info.ShouldStartWith("INFO ");
        info.ShouldContain("\"tls_required\":true");

        // Upgrade to TLS; server-certificate validation is skipped for the generated test cert.
        using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
        await sslStream.AuthenticateAsClientAsync("localhost");

        // Send CONNECT + PING over TLS
        await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await sslStream.FlushAsync();

        // Read PONG (bounded by ReadUntilAsync's timeout instead of blocking indefinitely)
        var pong = await ReadUntilAsync(sslStream, "PONG\r\n");
        pong.ShouldContain("PONG");
    }

    [Fact]
    public async Task Tls_pubsub_works_over_encrypted_connection()
    {
        using var tcp1 = new TcpClient();
        await tcp1.ConnectAsync(IPAddress.Loopback, _port);
        using var ssl1 = await UpgradeToTlsAsync(tcp1);

        using var tcp2 = new TcpClient();
        await tcp2.ConnectAsync(IPAddress.Loopback, _port);
        using var ssl2 = await UpgradeToTlsAsync(tcp2);

        // Sub on client 1
        await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
        await ssl1.FlushAsync();

        // Wait for the full PONG line to confirm the subscription is registered
        var pongStr = await ReadUntilAsync(ssl1, "PONG\r\n");
        pongStr.ShouldContain("PONG");

        // Pub on client 2
        await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
        await ssl2.FlushAsync();

        // Client 1 should receive MSG (may arrive across multiple TLS records)
        var msg = await ReadUntilAsync(ssl1, "hello");
        msg.ShouldContain("MSG test 1 5");
        msg.ShouldContain("hello");
    }

    /// <summary>
    /// Reads from <paramref name="stream"/>, decoding as ASCII, until the accumulated text
    /// contains <paramref name="expected"/> or the peer closes the stream.
    /// </summary>
    /// <param name="stream">Stream to read from.</param>
    /// <param name="expected">Substring (ordinal match) that ends the read loop.</param>
    /// <param name="timeoutMs">Overall time budget; when it elapses the pending read is cancelled
    /// and the cancellation surfaces as an exception.</param>
    /// <returns>All text read; it lacks <paramref name="expected"/> only if the stream hit end-of-stream first.</returns>
    internal static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000)
    {
        using var cts = new CancellationTokenSource(timeoutMs);
        var sb = new StringBuilder();
        var buf = new byte[4096];
        while (!sb.ToString().Contains(expected, StringComparison.Ordinal))
        {
            var n = await stream.ReadAsync(buf, cts.Token);
            if (n == 0)
            {
                break; // peer closed the connection
            }

            sb.Append(Encoding.ASCII.GetString(buf, 0, n));
        }

        return sb.ToString();
    }

    /// <summary>
    /// Consumes the plain-text INFO line from <paramref name="tcp"/>'s stream, then performs a
    /// client-side TLS handshake for "localhost" without validating the server certificate.
    /// </summary>
    /// <param name="tcp">An already-connected client.</param>
    /// <returns>The authenticated stream. It was created with <c>leaveInnerStreamOpen: false</c>,
    /// so disposing it also disposes the underlying network stream.</returns>
    internal static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
    {
        var netStream = tcp.GetStream();

        // Read the complete INFO line (discarded) so no plain-text bytes reach the handshake.
        _ = await ReadUntilAsync(netStream, "\r\n");

        var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
        try
        {
            await ssl.AuthenticateAsClientAsync("localhost");
        }
        catch
        {
            // Don't leak the SslStream when the handshake fails.
            await ssl.DisposeAsync();
            throw;
        }

        return ssl;
    }

    /// <summary>
    /// Returns a loopback TCP port that was free at the time of the call. The probe socket is
    /// closed before returning, so another process could in principle grab the port first.
    /// </summary>
    internal static int GetFreePort()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        return ((IPEndPoint)sock.LocalEndPoint!).Port;
    }
}

/// <summary>
/// End-to-end tests against a server configured with a TLS certificate and
/// <c>AllowNonTls = true</c>: both plain-text and TLS clients must be able to connect.
/// Shares the connection helpers defined on <see cref="TlsServerTests"/>.
/// </summary>
public class TlsMixedModeTests : IAsyncLifetime
{
    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();
    private readonly string _certPath;
    private readonly string _keyPath;

    public TlsMixedModeTests()
    {
        _port = TlsServerTests.GetFreePort();
        (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                TlsCert = _certPath,
                TlsKey = _keyPath,
                AllowNonTls = true,
            },
            NullLoggerFactory.Instance);
    }

    public async Task InitializeAsync()
    {
        // The server task is not awaited here; it is stopped via _cts in DisposeAsync.
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    public async Task DisposeAsync()
    {
        await _cts.CancelAsync();
        _server.Dispose();
        File.Delete(_certPath);
        File.Delete(_keyPath);
    }

    [Fact]
    public async Task Mixed_mode_accepts_plain_client()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
        using var stream = new NetworkStream(sock);

        // Read the complete INFO line rather than trusting a single read.
        var info = await TlsServerTests.ReadUntilAsync(stream, "\r\n");
        info.ShouldContain("\"tls_available\":true");

        await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await stream.FlushAsync();

        var pong = await TlsServerTests.ReadUntilAsync(stream, "PONG\r\n");
        pong.ShouldContain("PONG");
    }

    [Fact]
    public async Task Mixed_mode_accepts_tls_client()
    {
        using var tcp = new TcpClient();
        await tcp.ConnectAsync(IPAddress.Loopback, _port);

        // Reads INFO, then performs the TLS handshake.
        using var ssl = await TlsServerTests.UpgradeToTlsAsync(tcp);

        await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await ssl.FlushAsync();

        var pong = await TlsServerTests.ReadUntilAsync(ssl, "PONG\r\n");
        pong.ShouldContain("PONG");
    }
}