fix: address TlsConnectionWrapper review — clone ServerInfo, fix SslStream leak, add TLS-first test

This commit is contained in:
Joseph Doherty
2026-02-22 22:28:19 -05:00
parent a52db677e2
commit 63198ef83b
2 changed files with 111 additions and 29 deletions

View File

@@ -25,6 +25,21 @@ public static class TlsConnectionWrapper
if (sslOptions == null || !options.HasTls) if (sslOptions == null || !options.HasTls)
return (networkStream, false); return (networkStream, false);
// Clone to avoid mutating shared instance
serverInfo = new ServerInfo
{
ServerId = serverInfo.ServerId,
ServerName = serverInfo.ServerName,
Version = serverInfo.Version,
Proto = serverInfo.Proto,
Host = serverInfo.Host,
Port = serverInfo.Port,
Headers = serverInfo.Headers,
MaxPayload = serverInfo.MaxPayload,
ClientId = serverInfo.ClientId,
ClientIp = serverInfo.ClientIp,
};
// Mode 3: TLS First // Mode 3: TLS First
if (options.TlsHandshakeFirst) if (options.TlsHandshakeFirst)
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct); return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
@@ -49,23 +64,30 @@ public static class TlsConnectionWrapper
{ {
// Client is starting TLS // Client is starting TLS
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); try
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{ {
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{ {
logger.LogWarning("Certificate pinning check failed"); if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
sslStream.Dispose(); {
throw new InvalidOperationException("Certificate pinning check failed"); logger.LogWarning("Certificate pinning check failed");
throw new InvalidOperationException("Certificate pinning check failed");
}
} }
} }
catch
{
sslStream.Dispose();
throw;
}
return (sslStream, true); return (sslStream, true);
} }
@@ -99,27 +121,35 @@ public static class TlsConnectionWrapper
{ {
// Client started TLS immediately — handshake first, then send INFO // Client started TLS immediately — handshake first, then send INFO
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); try
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{ {
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{ {
sslStream.Dispose(); if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
throw new InvalidOperationException("Certificate pinning check failed"); {
throw new InvalidOperationException("Certificate pinning check failed");
}
} }
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
}
catch
{
sslStream.Dispose();
throw;
} }
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
return (sslStream, true); return (sslStream, true);
} }

View File

@@ -177,6 +177,58 @@ public class TlsConnectionWrapperTests
serverNetStream.Dispose(); serverNetStream.Dispose();
} }
[Fact]
public async Task TlsFirst_handshakes_before_sending_info()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: immediately start TLS (no INFO first)
var clientTask = Task.Run(async () =>
{
var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true);
await sslClient.AuthenticateAsClientAsync("localhost");
// After TLS, read INFO over encrypted stream
var buf = new byte[4096];
var read = await sslClient.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
return sslClient;
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBeOfType<SslStream>();
infoSent.ShouldBeTrue();
var clientSsl = await clientTask;
// Verify encrypted communication works
await stream.WriteAsync("PING\r\n"u8.ToArray());
await stream.FlushAsync();
var readBuf = new byte[64];
var bytesRead = await clientSsl.ReadAsync(readBuf);
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
msg.ShouldBe("PING\r\n");
stream.Dispose();
clientSsl.Dispose();
}
private static ServerInfo CreateServerInfo() => new() private static ServerInfo CreateServerInfo() => new()
{ {
ServerId = "TEST", ServerId = "TEST",