fix: address TlsConnectionWrapper review — clone ServerInfo, fix SslStream leak, add TLS-first test

This commit is contained in:
Joseph Doherty
2026-02-22 22:28:19 -05:00
parent a52db677e2
commit 63198ef83b
2 changed files with 111 additions and 29 deletions

View File

@@ -25,6 +25,21 @@ public static class TlsConnectionWrapper
if (sslOptions == null || !options.HasTls)
return (networkStream, false);
// Clone to avoid mutating shared instance
serverInfo = new ServerInfo
{
ServerId = serverInfo.ServerId,
ServerName = serverInfo.ServerName,
Version = serverInfo.Version,
Proto = serverInfo.Proto,
Host = serverInfo.Host,
Port = serverInfo.Port,
Headers = serverInfo.Headers,
MaxPayload = serverInfo.MaxPayload,
ClientId = serverInfo.ClientId,
ClientIp = serverInfo.ClientIp,
};
// Mode 3: TLS First
if (options.TlsHandshakeFirst)
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
@@ -49,23 +64,30 @@ public static class TlsConnectionWrapper
{
// Client is starting TLS
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
try
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
logger.LogWarning("Certificate pinning check failed");
sslStream.Dispose();
throw new InvalidOperationException("Certificate pinning check failed");
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
logger.LogWarning("Certificate pinning check failed");
throw new InvalidOperationException("Certificate pinning check failed");
}
}
}
catch
{
sslStream.Dispose();
throw;
}
return (sslStream, true);
}
@@ -99,27 +121,35 @@ public static class TlsConnectionWrapper
{
// Client started TLS immediately — handshake first, then send INFO
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
try
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
sslStream.Dispose();
throw new InvalidOperationException("Certificate pinning check failed");
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
throw new InvalidOperationException("Certificate pinning check failed");
}
}
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
}
catch
{
sslStream.Dispose();
throw;
}
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
return (sslStream, true);
}