From 0409acc745f9c305cc458b37c12db552cfc485c2 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 22:21:11 -0500 Subject: [PATCH] feat: add TlsConnectionWrapper with 4-mode TLS negotiation --- src/NATS.Server/Tls/TlsConnectionWrapper.cs | 172 +++++++++++++++ .../TlsConnectionWrapperTests.cs | 202 ++++++++++++++++++ 2 files changed, 374 insertions(+) create mode 100644 src/NATS.Server/Tls/TlsConnectionWrapper.cs create mode 100644 tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs diff --git a/src/NATS.Server/Tls/TlsConnectionWrapper.cs b/src/NATS.Server/Tls/TlsConnectionWrapper.cs new file mode 100644 index 0000000..5db463f --- /dev/null +++ b/src/NATS.Server/Tls/TlsConnectionWrapper.cs @@ -0,0 +1,172 @@ +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.Logging; +using NATS.Server.Protocol; + +namespace NATS.Server.Tls; + +public static class TlsConnectionWrapper +{ + private const byte TlsRecordMarker = 0x16; + + public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync( + Socket socket, + Stream networkStream, + NatsOptions options, + SslServerAuthenticationOptions? sslOptions, + ServerInfo serverInfo, + ILogger logger, + CancellationToken ct) + { + // Mode 1: No TLS + if (sslOptions == null || !options.HasTls) + return (networkStream, false); + + // Mode 3: TLS First + if (options.TlsHandshakeFirst) + return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct); + + // Mode 2 & 4: Send INFO first, then decide + serverInfo.TlsRequired = !options.AllowNonTls; + serverInfo.TlsAvailable = options.AllowNonTls; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(networkStream, serverInfo, ct); + + // Peek first byte to detect TLS + var peekable = new PeekableStream(networkStream); + var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct); + + if (peeked.Length == 0) + { + // Client disconnected or timed out + return (peekable, true); + } + + if (peeked[0] == TlsRecordMarker) + { + // Client is starting TLS + var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); + using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + handshakeCts.CancelAfter(options.TlsTimeout); + + await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); + logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}", + sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); + + // Validate pinned certs + if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) + { + if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + { + logger.LogWarning("Certificate pinning check failed"); + sslStream.Dispose(); + throw new InvalidOperationException("Certificate pinning check failed"); + } + } + + return (sslStream, true); + } + + // Mode 4: Mixed — client chose plaintext + if (options.AllowNonTls) + { + logger.LogDebug("Client connected without TLS (mixed mode)"); + return (peekable, true); + } + + // TLS required but client sent plaintext + logger.LogWarning("TLS required but client sent plaintext data"); + throw new InvalidOperationException("TLS required"); + } + + private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync( + Socket socket, + Stream networkStream, + NatsOptions options, + SslServerAuthenticationOptions sslOptions, + ServerInfo serverInfo, + ILogger logger, + CancellationToken ct) + { + // Wait for data with fallback timeout + var peekable = new PeekableStream(networkStream); + var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct); + + if (peeked.Length > 0 && peeked[0] == TlsRecordMarker) + { + // Client started TLS immediately — handshake first, then send INFO + var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); + using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + handshakeCts.CancelAfter(options.TlsTimeout); + + await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); + logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}", + sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); + + // Validate pinned certs + if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) + { + if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + { + sslStream.Dispose(); + throw new InvalidOperationException("Certificate pinning check failed"); + } + } + + // Now send INFO over encrypted stream + serverInfo.TlsRequired = true; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(sslStream, serverInfo, ct); + return (sslStream, true); + } + + // Fallback: timeout expired or non-TLS data — send INFO and negotiate normally + logger.LogDebug("TLS-first fallback: sending INFO"); + serverInfo.TlsRequired = !options.AllowNonTls; + serverInfo.TlsAvailable = options.AllowNonTls; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(peekable, serverInfo, ct); + + if (peeked.Length == 0) + { + // Timeout — INFO was sent, return stream for normal flow + return (peekable, true); + } + + // Non-TLS data received during fallback window + if (options.AllowNonTls) + { + return (peekable, true); + } + + // TLS required but got plaintext + throw new InvalidOperationException("TLS required but client sent plaintext"); + } + + private static async Task PeekWithTimeoutAsync( + PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(timeout); + try + { + return await stream.PeekAsync(count, cts.Token); + } + catch (OperationCanceledException) when (!ct.IsCancellationRequested) + { + // Timeout — not a cancellation of the outer token + return []; + } + } + + private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct) + { + var infoJson = JsonSerializer.Serialize(serverInfo); + var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n"); + await stream.WriteAsync(infoLine, ct); + await stream.FlushAsync(ct); + } +} diff --git a/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs b/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs new file mode 100644 index 0000000..bdfc6e5 --- /dev/null +++ b/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs @@ -0,0 +1,202 @@ +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server; +using NATS.Server.Protocol; +using NATS.Server.Tls; + +namespace NATS.Server.Tests; + +public class TlsConnectionWrapperTests +{ + [Fact] + public async Task NoTls_returns_plain_stream() + { + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var serverStream = new NetworkStream(serverSocket, ownsSocket: true); + using var clientStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions(); // No TLS configured + var serverInfo = CreateServerInfo(); + + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None); + + stream.ShouldBe(serverStream); // Same stream, no wrapping + infoSent.ShouldBeFalse(); + } + + [Fact] + public async Task TlsRequired_upgrades_to_ssl() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: read INFO then start TLS + var clientTask = Task.Run(async () => + { + // Read INFO line + var buf = new byte[4096]; + var read = await clientNetStream.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + // Upgrade to TLS + var sslClient = new SslStream(clientNetStream, true, + (_, _, _, _) => true); // Trust all for testing + await sslClient.AuthenticateAsClientAsync("localhost"); + return sslClient; + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + + stream.ShouldBeOfType(); + infoSent.ShouldBeTrue(); + + var clientSsl = await clientTask; + + // Verify encrypted communication works + await stream.WriteAsync("PING\r\n"u8.ToArray()); + await stream.FlushAsync(); + + var readBuf = new byte[64]; + var bytesRead = await clientSsl.ReadAsync(readBuf); + var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead); + msg.ShouldBe("PING\r\n"); + + stream.Dispose(); + clientSsl.Dispose(); + } + + [Fact] + public async Task MixedMode_allows_plaintext_when_AllowNonTls() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions + { + TlsCert = "dummy", + TlsKey = "dummy", + AllowNonTls = true, + TlsTimeout = TimeSpan.FromSeconds(2), + }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: read INFO then send plaintext (not TLS) + var clientTask = Task.Run(async () => + { + var buf = new byte[4096]; + var read = await clientNetStream.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + // Send plaintext CONNECT (not a TLS handshake) + var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n"); + await clientNetStream.WriteAsync(connectLine); + await clientNetStream.FlushAsync(); + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + + await clientTask; + + // In mixed mode with plaintext client, we get a PeekableStream, not SslStream + stream.ShouldBeOfType(); + infoSent.ShouldBeTrue(); + + stream.Dispose(); + } + + [Fact] + public async Task TlsRequired_rejects_plaintext() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions + { + TlsCert = "dummy", + TlsKey = "dummy", + AllowNonTls = false, + TlsTimeout = TimeSpan.FromSeconds(2), + }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: read INFO then send plaintext + var clientTask = Task.Run(async () => + { + var buf = new byte[4096]; + var read = await clientNetStream.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + // Send plaintext data (first byte is 'C', not 0x16 TLS marker) + var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n"); + await clientNetStream.WriteAsync(connectLine); + await clientNetStream.FlushAsync(); + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + + await Should.ThrowAsync(async () => + { + await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + }); + + await clientTask; + serverNetStream.Dispose(); + } + + private static ServerInfo CreateServerInfo() => new() + { + ServerId = "TEST", + ServerName = "test", + Version = NatsProtocol.Version, + Host = "127.0.0.1", + Port = 4222, + }; + + private static async Task<(Socket server, Socket client)> CreateSocketPairAsync() + { + using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + var port = ((IPEndPoint)listener.LocalEndPoint!).Port; + + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port)); + var server = await listener.AcceptAsync(); + + return (server, client); + } +}