fix: address TlsConnectionWrapper review — clone ServerInfo, fix SslStream leak, add TLS-first test
This commit is contained in:
@@ -25,6 +25,21 @@ public static class TlsConnectionWrapper
|
|||||||
if (sslOptions == null || !options.HasTls)
|
if (sslOptions == null || !options.HasTls)
|
||||||
return (networkStream, false);
|
return (networkStream, false);
|
||||||
|
|
||||||
|
// Clone to avoid mutating shared instance
|
||||||
|
serverInfo = new ServerInfo
|
||||||
|
{
|
||||||
|
ServerId = serverInfo.ServerId,
|
||||||
|
ServerName = serverInfo.ServerName,
|
||||||
|
Version = serverInfo.Version,
|
||||||
|
Proto = serverInfo.Proto,
|
||||||
|
Host = serverInfo.Host,
|
||||||
|
Port = serverInfo.Port,
|
||||||
|
Headers = serverInfo.Headers,
|
||||||
|
MaxPayload = serverInfo.MaxPayload,
|
||||||
|
ClientId = serverInfo.ClientId,
|
||||||
|
ClientIp = serverInfo.ClientIp,
|
||||||
|
};
|
||||||
|
|
||||||
// Mode 3: TLS First
|
// Mode 3: TLS First
|
||||||
if (options.TlsHandshakeFirst)
|
if (options.TlsHandshakeFirst)
|
||||||
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
|
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
|
||||||
@@ -49,6 +64,8 @@ public static class TlsConnectionWrapper
|
|||||||
{
|
{
|
||||||
// Client is starting TLS
|
// Client is starting TLS
|
||||||
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
||||||
|
try
|
||||||
|
{
|
||||||
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||||
handshakeCts.CancelAfter(options.TlsTimeout);
|
handshakeCts.CancelAfter(options.TlsTimeout);
|
||||||
|
|
||||||
@@ -62,10 +79,15 @@ public static class TlsConnectionWrapper
|
|||||||
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
||||||
{
|
{
|
||||||
logger.LogWarning("Certificate pinning check failed");
|
logger.LogWarning("Certificate pinning check failed");
|
||||||
sslStream.Dispose();
|
|
||||||
throw new InvalidOperationException("Certificate pinning check failed");
|
throw new InvalidOperationException("Certificate pinning check failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
catch
|
||||||
|
{
|
||||||
|
sslStream.Dispose();
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
|
||||||
return (sslStream, true);
|
return (sslStream, true);
|
||||||
}
|
}
|
||||||
@@ -99,6 +121,8 @@ public static class TlsConnectionWrapper
|
|||||||
{
|
{
|
||||||
// Client started TLS immediately — handshake first, then send INFO
|
// Client started TLS immediately — handshake first, then send INFO
|
||||||
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
||||||
|
try
|
||||||
|
{
|
||||||
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||||
handshakeCts.CancelAfter(options.TlsTimeout);
|
handshakeCts.CancelAfter(options.TlsTimeout);
|
||||||
|
|
||||||
@@ -111,7 +135,6 @@ public static class TlsConnectionWrapper
|
|||||||
{
|
{
|
||||||
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
||||||
{
|
{
|
||||||
sslStream.Dispose();
|
|
||||||
throw new InvalidOperationException("Certificate pinning check failed");
|
throw new InvalidOperationException("Certificate pinning check failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -120,6 +143,13 @@ public static class TlsConnectionWrapper
|
|||||||
serverInfo.TlsRequired = true;
|
serverInfo.TlsRequired = true;
|
||||||
serverInfo.TlsVerify = options.TlsVerify;
|
serverInfo.TlsVerify = options.TlsVerify;
|
||||||
await SendInfoAsync(sslStream, serverInfo, ct);
|
await SendInfoAsync(sslStream, serverInfo, ct);
|
||||||
|
}
|
||||||
|
catch
|
||||||
|
{
|
||||||
|
sslStream.Dispose();
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
|
||||||
return (sslStream, true);
|
return (sslStream, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -177,6 +177,58 @@ public class TlsConnectionWrapperTests
|
|||||||
serverNetStream.Dispose();
|
serverNetStream.Dispose();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task TlsFirst_handshakes_before_sending_info()
|
||||||
|
{
|
||||||
|
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||||
|
|
||||||
|
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||||
|
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||||
|
|
||||||
|
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
|
||||||
|
var sslOpts = new SslServerAuthenticationOptions
|
||||||
|
{
|
||||||
|
ServerCertificate = cert,
|
||||||
|
};
|
||||||
|
var serverInfo = CreateServerInfo();
|
||||||
|
|
||||||
|
// Client side: immediately start TLS (no INFO first)
|
||||||
|
var clientTask = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true);
|
||||||
|
await sslClient.AuthenticateAsClientAsync("localhost");
|
||||||
|
|
||||||
|
// After TLS, read INFO over encrypted stream
|
||||||
|
var buf = new byte[4096];
|
||||||
|
var read = await sslClient.ReadAsync(buf);
|
||||||
|
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||||
|
info.ShouldStartWith("INFO ");
|
||||||
|
|
||||||
|
return sslClient;
|
||||||
|
});
|
||||||
|
|
||||||
|
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||||
|
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||||
|
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||||
|
|
||||||
|
stream.ShouldBeOfType<SslStream>();
|
||||||
|
infoSent.ShouldBeTrue();
|
||||||
|
|
||||||
|
var clientSsl = await clientTask;
|
||||||
|
|
||||||
|
// Verify encrypted communication works
|
||||||
|
await stream.WriteAsync("PING\r\n"u8.ToArray());
|
||||||
|
await stream.FlushAsync();
|
||||||
|
|
||||||
|
var readBuf = new byte[64];
|
||||||
|
var bytesRead = await clientSsl.ReadAsync(readBuf);
|
||||||
|
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
|
||||||
|
msg.ShouldBe("PING\r\n");
|
||||||
|
|
||||||
|
stream.Dispose();
|
||||||
|
clientSsl.Dispose();
|
||||||
|
}
|
||||||
|
|
||||||
private static ServerInfo CreateServerInfo() => new()
|
private static ServerInfo CreateServerInfo() => new()
|
||||||
{
|
{
|
||||||
ServerId = "TEST",
|
ServerId = "TEST",
|
||||||
|
|||||||
Reference in New Issue
Block a user