using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Protocol;
using NATS.Server.TestUtilities;
using NATS.Server.Tls;

namespace NATS.Server.Transport.Tests;

/// <summary>
/// Tests for <c>TlsConnectionWrapper.NegotiateAsync</c> over a real loopback socket pair, covering:
/// no TLS, TLS required (INFO then upgrade), mixed mode with <c>AllowNonTls</c>,
/// rejection of plaintext when TLS is required, and TLS-handshake-first.
/// </summary>
public class TlsConnectionWrapperTests
{
    /// <summary>
    /// Upper bound for each test. Without it, a faulted peer task leaves the other side blocked
    /// forever on a socket read or TLS handshake and the whole test run hangs.
    /// </summary>
    private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);

    [Fact]
    public async Task NoTls_returns_plain_stream()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
        using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);
        var opts = new NatsOptions(); // No TLS configured
        var serverInfo = CreateServerInfo();

        var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, cts.Token);

        stream.ShouldBeSameAs(serverStream); // Same stream instance, no wrapping
        infoSent.ShouldBeFalse();
    }

    [Fact]
    public async Task TlsRequired_upgrades_to_ssl()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();
        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
        var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
        var sslOpts = new SslServerAuthenticationOptions { ServerCertificate = cert };
        var serverInfo = CreateServerInfo();

        // Client side: read the plaintext INFO line, then upgrade to TLS.
        var clientTask = Task.Run(async () =>
        {
            var info = await ReadLineAsync(clientNetStream, cts.Token);
            info.ShouldStartWith("INFO ");
            return await ConnectTlsClientAsync(clientNetStream, cts.Token);
        });

        var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        using var serverStream = stream;

        serverStream.ShouldBeOfType<SslStream>();
        infoSent.ShouldBeTrue();

        using var clientSsl = await clientTask;
        // Verify encrypted communication works.
        await AssertServerToClientPingAsync(serverStream, clientSsl, cts.Token);
    }

    [Fact]
    public async Task MixedMode_allows_plaintext_when_AllowNonTls()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();
        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
        var opts = new NatsOptions
        {
            TlsCert = "dummy",
            TlsKey = "dummy",
            AllowNonTls = true,
            TlsTimeout = TimeSpan.FromSeconds(2),
        };
        var sslOpts = new SslServerAuthenticationOptions { ServerCertificate = cert };
        var serverInfo = CreateServerInfo();

        // Client side: read INFO, then send plaintext (not a TLS handshake).
        var clientTask = Task.Run(() => SendPlaintextConnectAfterInfoAsync(clientNetStream, cts.Token));

        var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        using var serverStream = stream;
        await clientTask;

        // In mixed mode with a plaintext client we get a PeekableStream, not an SslStream.
        serverStream.ShouldBeOfType<PeekableStream>();
        infoSent.ShouldBeTrue();
    }

    [Fact]
    public async Task TlsRequired_rejects_plaintext()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();
        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
        var opts = new NatsOptions
        {
            TlsCert = "dummy",
            TlsKey = "dummy",
            AllowNonTls = false,
            TlsTimeout = TimeSpan.FromSeconds(2),
        };
        var sslOpts = new SslServerAuthenticationOptions { ServerCertificate = cert };
        var serverInfo = CreateServerInfo();

        // Client side: read INFO, then send plaintext data (first byte 'C', not the 0x16 TLS marker).
        var clientTask = Task.Run(() => SendPlaintextConnectAfterInfoAsync(clientNetStream, cts.Token));

        // NOTE(review): the expected exception type is not pinned here; Exception accepts any
        // failure. Tighten it to the concrete type NegotiateAsync throws for a plaintext client.
        await Should.ThrowAsync<Exception>(async () =>
        {
            await TlsConnectionWrapper.NegotiateAsync(
                serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        });
        // Guard against a false pass: the failure must come from the wrapper, not our timeout.
        cts.IsCancellationRequested.ShouldBeFalse("NegotiateAsync should reject plaintext, not time out");

        await clientTask;
    }

    [Fact]
    public async Task TlsFirst_handshakes_before_sending_info()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();
        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
        var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
        var sslOpts = new SslServerAuthenticationOptions { ServerCertificate = cert };
        var serverInfo = CreateServerInfo();

        // Client side: start TLS immediately (no INFO first), then read INFO over the encrypted stream.
        var clientTask = Task.Run(async () =>
        {
            var sslClient = await ConnectTlsClientAsync(clientNetStream, cts.Token);
            try
            {
                var info = await ReadLineAsync(sslClient, cts.Token);
                info.ShouldStartWith("INFO ");
                return sslClient;
            }
            catch
            {
                await sslClient.DisposeAsync();
                throw;
            }
        });

        var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        using var serverStream = stream;

        serverStream.ShouldBeOfType<SslStream>();
        infoSent.ShouldBeTrue();

        using var clientSsl = await clientTask;
        // Verify encrypted communication works.
        await AssertServerToClientPingAsync(serverStream, clientSsl, cts.Token);
    }

    private static ServerInfo CreateServerInfo() => new()
    {
        ServerId = "TEST",
        ServerName = "test",
        Version = NatsProtocol.Version,
        Host = "127.0.0.1",
        Port = 4222,
    };

    /// <summary>
    /// Reads bytes one at a time up to and including the first <c>'\n'</c> and returns them as ASCII.
    /// A single <c>ReadAsync</c> may return only part of a line. Reading byte-by-byte also never
    /// consumes bytes past the line, which matters when the same stream is then handed to TLS.
    /// </summary>
    /// <exception cref="EndOfStreamException">The stream ended before a line terminator was read.</exception>
    private static async Task<string> ReadLineAsync(Stream stream, CancellationToken ct)
    {
        var line = new List<byte>();
        var one = new byte[1];
        do
        {
            if (await stream.ReadAsync(one, ct) == 0)
            {
                throw new EndOfStreamException(
                    $"Stream closed after {line.Count} bytes without a line terminator.");
            }

            line.Add(one[0]);
        }
        while (one[0] != (byte)'\n');

        return Encoding.ASCII.GetString(line.ToArray());
    }

    /// <summary>
    /// Client behaviour for the plaintext tests: read the INFO line, then send a plaintext
    /// <c>CONNECT</c> whose first byte ('C') is not the 0x16 TLS record marker.
    /// </summary>
    private static async Task SendPlaintextConnectAfterInfoAsync(Stream client, CancellationToken ct)
    {
        var info = await ReadLineAsync(client, ct);
        info.ShouldStartWith("INFO ");

        var connectLine = Encoding.ASCII.GetBytes("CONNECT {}\r\n");
        await client.WriteAsync(connectLine, ct);
        await client.FlushAsync(ct);
    }

    /// <summary>
    /// Wraps <paramref name="inner"/> in a client <see cref="SslStream"/> (leaving the inner stream
    /// open) and authenticates against "localhost". The SslStream is disposed if the handshake fails.
    /// </summary>
    private static async Task<SslStream> ConnectTlsClientAsync(Stream inner, CancellationToken ct)
    {
        var ssl = new SslStream(inner, leaveInnerStreamOpen: true);
        try
        {
            await ssl.AuthenticateAsClientAsync(
                new SslClientAuthenticationOptions
                {
                    TargetHost = "localhost",
                    // Trust all for testing; the generated test certificate is presumably self-signed.
                    RemoteCertificateValidationCallback = (_, _, _, _) => true,
                },
                ct);
            return ssl;
        }
        catch
        {
            await ssl.DisposeAsync();
            throw;
        }
    }

    /// <summary>
    /// Writes "PING\r\n" on the server-side stream and asserts that the client reads back exactly that line.
    /// </summary>
    private static async Task AssertServerToClientPingAsync(Stream server, Stream client, CancellationToken ct)
    {
        await server.WriteAsync("PING\r\n"u8.ToArray(), ct);
        await server.FlushAsync(ct);

        var msg = await ReadLineAsync(client, ct);
        msg.ShouldBe("PING\r\n");
    }

    /// <summary>
    /// Creates a connected TCP socket pair on an ephemeral loopback port.
    /// The client socket is disposed if connecting or accepting fails.
    /// </summary>
    private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
    {
        using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        listener.Listen(1);
        var port = ((IPEndPoint)listener.LocalEndPoint!).Port;

        var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
            var server = await listener.AcceptAsync();
            return (server, client);
        }
        catch
        {
            client.Dispose();
            throw;
        }
    }
}