diff --git a/src/NATS.Server/Tls/TlsConnectionWrapper.cs b/src/NATS.Server/Tls/TlsConnectionWrapper.cs index 5db463f..0ca0961 100644 --- a/src/NATS.Server/Tls/TlsConnectionWrapper.cs +++ b/src/NATS.Server/Tls/TlsConnectionWrapper.cs @@ -25,6 +25,21 @@ public static class TlsConnectionWrapper if (sslOptions == null || !options.HasTls) return (networkStream, false); + // Clone to avoid mutating shared instance + serverInfo = new ServerInfo + { + ServerId = serverInfo.ServerId, + ServerName = serverInfo.ServerName, + Version = serverInfo.Version, + Proto = serverInfo.Proto, + Host = serverInfo.Host, + Port = serverInfo.Port, + Headers = serverInfo.Headers, + MaxPayload = serverInfo.MaxPayload, + ClientId = serverInfo.ClientId, + ClientIp = serverInfo.ClientIp, + }; + // Mode 3: TLS First if (options.TlsHandshakeFirst) return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct); @@ -49,23 +64,30 @@ public static class TlsConnectionWrapper { // Client is starting TLS var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); - using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); - handshakeCts.CancelAfter(options.TlsTimeout); - - await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); - logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}", - sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); - - // Validate pinned certs - if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) + try { - if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + handshakeCts.CancelAfter(options.TlsTimeout); + + await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); + logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}", + sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); + + // Validate pinned certs + if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) { - logger.LogWarning("Certificate pinning check failed"); - sslStream.Dispose(); - throw new InvalidOperationException("Certificate pinning check failed"); + if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + { + logger.LogWarning("Certificate pinning check failed"); + throw new InvalidOperationException("Certificate pinning check failed"); + } } } + catch + { + sslStream.Dispose(); + throw; + } return (sslStream, true); } @@ -99,27 +121,35 @@ public static class TlsConnectionWrapper { // Client started TLS immediately — handshake first, then send INFO var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); - using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); - handshakeCts.CancelAfter(options.TlsTimeout); - - await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); - logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}", - sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); - - // Validate pinned certs - if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) + try { - if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + handshakeCts.CancelAfter(options.TlsTimeout); + + await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); + logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}", + sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); + + // Validate pinned certs + if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) { - sslStream.Dispose(); - throw new InvalidOperationException("Certificate pinning check failed"); + if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + { + throw new InvalidOperationException("Certificate pinning check failed"); + } } + + // Now send INFO over encrypted stream + serverInfo.TlsRequired = true; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(sslStream, serverInfo, ct); + } + catch + { + sslStream.Dispose(); + throw; } - // Now send INFO over encrypted stream - serverInfo.TlsRequired = true; - serverInfo.TlsVerify = options.TlsVerify; - await SendInfoAsync(sslStream, serverInfo, ct); return (sslStream, true); } diff --git a/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs b/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs index bdfc6e5..55df6cc 100644 --- a/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs +++ b/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs @@ -177,6 +177,58 @@ public class TlsConnectionWrapperTests serverNetStream.Dispose(); } + [Fact] + public async Task TlsFirst_handshakes_before_sending_info() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: immediately start TLS (no INFO first) + var clientTask = Task.Run(async () => + { + var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true); + await sslClient.AuthenticateAsClientAsync("localhost"); + + // After TLS, read INFO over encrypted stream + var buf = new byte[4096]; + var read = await sslClient.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + return sslClient; + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + + stream.ShouldBeOfType(); + infoSent.ShouldBeTrue(); + + var clientSsl = await clientTask; + + // Verify encrypted communication works + await stream.WriteAsync("PING\r\n"u8.ToArray()); + await stream.FlushAsync(); + + var readBuf = new byte[64]; + var bytesRead = await clientSsl.ReadAsync(readBuf); + var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead); + msg.ShouldBe("PING\r\n"); + + stream.Dispose(); + clientSsl.Dispose(); + } + private static ServerInfo CreateServerInfo() => new() { ServerId = "TEST",