feat: add TlsConnectionWrapper with 4-mode TLS negotiation

This commit is contained in:
Joseph Doherty
2026-02-22 22:21:11 -05:00
parent f2badc3488
commit 0409acc745
2 changed files with 374 additions and 0 deletions

View File

@@ -0,0 +1,172 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using NATS.Server.Protocol;
namespace NATS.Server.Tls;
public static class TlsConnectionWrapper
{
private const byte TlsRecordMarker = 0x16;
public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
Socket socket,
Stream networkStream,
NatsOptions options,
SslServerAuthenticationOptions? sslOptions,
ServerInfo serverInfo,
ILogger logger,
CancellationToken ct)
{
// Mode 1: No TLS
if (sslOptions == null || !options.HasTls)
return (networkStream, false);
// Mode 3: TLS First
if (options.TlsHandshakeFirst)
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
// Mode 2 & 4: Send INFO first, then decide
serverInfo.TlsRequired = !options.AllowNonTls;
serverInfo.TlsAvailable = options.AllowNonTls;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(networkStream, serverInfo, ct);
// Peek first byte to detect TLS
var peekable = new PeekableStream(networkStream);
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
if (peeked.Length == 0)
{
// Client disconnected or timed out
return (peekable, true);
}
if (peeked[0] == TlsRecordMarker)
{
// Client is starting TLS
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
logger.LogWarning("Certificate pinning check failed");
sslStream.Dispose();
throw new InvalidOperationException("Certificate pinning check failed");
}
}
return (sslStream, true);
}
// Mode 4: Mixed — client chose plaintext
if (options.AllowNonTls)
{
logger.LogDebug("Client connected without TLS (mixed mode)");
return (peekable, true);
}
// TLS required but client sent plaintext
logger.LogWarning("TLS required but client sent plaintext data");
throw new InvalidOperationException("TLS required");
}
private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
Socket socket,
Stream networkStream,
NatsOptions options,
SslServerAuthenticationOptions sslOptions,
ServerInfo serverInfo,
ILogger logger,
CancellationToken ct)
{
// Wait for data with fallback timeout
var peekable = new PeekableStream(networkStream);
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);
if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
{
// Client started TLS immediately — handshake first, then send INFO
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
sslStream.Dispose();
throw new InvalidOperationException("Certificate pinning check failed");
}
}
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
return (sslStream, true);
}
// Fallback: timeout expired or non-TLS data — send INFO and negotiate normally
logger.LogDebug("TLS-first fallback: sending INFO");
serverInfo.TlsRequired = !options.AllowNonTls;
serverInfo.TlsAvailable = options.AllowNonTls;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(peekable, serverInfo, ct);
if (peeked.Length == 0)
{
// Timeout — INFO was sent, return stream for normal flow
return (peekable, true);
}
// Non-TLS data received during fallback window
if (options.AllowNonTls)
{
return (peekable, true);
}
// TLS required but got plaintext
throw new InvalidOperationException("TLS required but client sent plaintext");
}
private static async Task<byte[]> PeekWithTimeoutAsync(
PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(timeout);
try
{
return await stream.PeekAsync(count, cts.Token);
}
catch (OperationCanceledException) when (!ct.IsCancellationRequested)
{
// Timeout — not a cancellation of the outer token
return [];
}
}
private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
{
var infoJson = JsonSerializer.Serialize(serverInfo);
var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
await stream.WriteAsync(infoLine, ct);
await stream.FlushAsync(ct);
}
}