using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using NATS.Server.Protocol;

namespace NATS.Server.Tls;

/// <summary>
/// Chooses and establishes the transport for a freshly accepted client connection:
/// plaintext (TLS not configured), TLS after a plaintext INFO, TLS-first, or mixed mode
/// in which TLS is offered but plaintext is also accepted.
/// </summary>
public static class TlsConnectionWrapper
{
    // TLS record content type 22 (handshake): the first byte a client sends when it starts TLS.
    private const byte TlsRecordMarker = 0x16;

    /// <summary>
    /// Negotiates the client transport according to <paramref name="options"/>.
    /// </summary>
    /// <param name="socket">The accepted client socket. Not used by the current implementation.</param>
    /// <param name="networkStream">The raw stream for the connection.</param>
    /// <param name="options">Server options selecting the TLS mode and its timeouts.</param>
    /// <param name="sslOptions">TLS server options; <c>null</c> means TLS is not configured.</param>
    /// <param name="serverInfo">
    /// INFO template. It is copied before any TLS flags are set, so the caller's instance is not mutated.
    /// </param>
    /// <param name="logger">Logger for negotiation diagnostics.</param>
    /// <param name="ct">Cancels the whole negotiation.</param>
    /// <returns>
    /// The stream to use for protocol traffic (an <see cref="SslStream"/> when TLS was negotiated),
    /// and whether INFO has already been written to the client. INFO is left unsent only when
    /// TLS is not configured.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// TLS is required but the client sent plaintext or did not start TLS in time, or the client
    /// certificate failed the pinning check. Exceptions raised by the TLS handshake itself
    /// (including cancellation when the handshake exceeds <c>TlsTimeout</c>) propagate unchanged.
    /// </exception>
    public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
        Socket socket,
        Stream networkStream,
        NatsOptions options,
        SslServerAuthenticationOptions? sslOptions,
        ServerInfo serverInfo,
        ILogger logger,
        CancellationToken ct)
    {
        // Mode 1: no TLS. The caller is responsible for sending INFO.
        if (sslOptions == null || !options.HasTls)
            return (networkStream, false);

        // Clone so the per-connection TLS flags set below never leak into the shared instance.
        // NOTE(review): only the properties listed here are copied; confirm this covers every
        // ServerInfo property that must reach clients.
        serverInfo = new ServerInfo
        {
            ServerId = serverInfo.ServerId,
            ServerName = serverInfo.ServerName,
            Version = serverInfo.Version,
            Proto = serverInfo.Proto,
            Host = serverInfo.Host,
            Port = serverInfo.Port,
            Headers = serverInfo.Headers,
            MaxPayload = serverInfo.MaxPayload,
            ClientId = serverInfo.ClientId,
            ClientIp = serverInfo.ClientIp,
        };

        // Mode 3: TLS first.
        if (options.TlsHandshakeFirst)
            return await NegotiateTlsFirstAsync(networkStream, options, sslOptions, serverInfo, logger, ct);

        // Modes 2 & 4: send a plaintext INFO advertising TLS, then look at the client's first byte.
        await SendPlaintextInfoAsync(networkStream, options, serverInfo, ct);

        var peekable = new PeekableStream(networkStream);
        var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
        return await CompleteAfterInfoAsync(peekable, peeked, options, sslOptions, logger, ct);
    }

    /// <summary>
    /// TLS-first negotiation. If the client starts TLS within <c>TlsHandshakeFirstFallback</c>, the
    /// handshake runs before INFO and INFO is sent encrypted. Otherwise, a plaintext INFO is sent.
    /// If the client stayed silent, negotiation continues exactly as in modes 2/4: it may still start
    /// TLS, and TLS is enforced when required.
    /// </summary>
    private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
        Stream networkStream,
        NatsOptions options,
        SslServerAuthenticationOptions sslOptions,
        ServerInfo serverInfo,
        ILogger logger,
        CancellationToken ct)
    {
        var peekable = new PeekableStream(networkStream);
        var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);

        if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
        {
            // Client started TLS immediately: handshake first, then send INFO over the encrypted stream.
            var sslStream = await AuthenticateAsync(peekable, options, sslOptions, logger, tlsFirst: true, ct);
            try
            {
                serverInfo.TlsRequired = true;
                serverInfo.TlsVerify = options.TlsVerify;
                await SendInfoAsync(sslStream, serverInfo, ct);
            }
            catch
            {
                sslStream.Dispose();
                throw;
            }

            return (sslStream, true);
        }

        // Fallback: the window expired or the client sent non-TLS data. Send a plaintext INFO.
        logger.LogDebug("TLS-first fallback: sending INFO");
        await SendPlaintextInfoAsync(peekable, options, serverInfo, ct);

        if (peeked.Length == 0)
        {
            // The client sent nothing during the fallback window. Continue with the normal flow:
            // it may now start TLS in response to INFO. Returning the plaintext stream here would
            // skip that handshake and would also let plaintext through when TLS is required.
            peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
            return await CompleteAfterInfoAsync(peekable, peeked, options, sslOptions, logger, ct);
        }

        // Non-TLS data arrived during the fallback window.
        if (options.AllowNonTls)
            return (peekable, true);

        throw new InvalidOperationException("TLS required but client sent plaintext");
    }

    /// <summary>
    /// Decides the transport after a plaintext INFO has been sent, from the bytes the client sent
    /// back (<paramref name="peeked"/> is empty if the peek timed out or returned nothing).
    /// </summary>
    private static async Task<(Stream stream, bool infoAlreadySent)> CompleteAfterInfoAsync(
        PeekableStream peekable,
        byte[] peeked,
        NatsOptions options,
        SslServerAuthenticationOptions sslOptions,
        ILogger logger,
        CancellationToken ct)
    {
        if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
        {
            // Client is starting TLS.
            var sslStream = await AuthenticateAsync(peekable, options, sslOptions, logger, tlsFirst: false, ct);
            return (sslStream, true);
        }

        if (options.AllowNonTls)
        {
            // Mode 4: plaintext is acceptable. When nothing was received (timeout or disconnect),
            // the caller's subsequent reads observe that.
            if (peeked.Length > 0)
                logger.LogDebug("Client connected without TLS (mixed mode)");
            return (peekable, true);
        }

        if (peeked.Length == 0)
        {
            // TLS is required and the client never began a handshake. Do not hand back a
            // plaintext stream, or a client could bypass TLS simply by waiting.
            logger.LogWarning("TLS required but client did not start TLS before the timeout");
            throw new InvalidOperationException("TLS required but client did not start TLS");
        }

        logger.LogWarning("TLS required but client sent plaintext data");
        throw new InvalidOperationException("TLS required");
    }

    /// <summary>
    /// Runs the server side of the TLS handshake over <paramref name="peekable"/>, bounded by
    /// <c>TlsTimeout</c>, then enforces certificate pinning when <c>TlsPinnedCerts</c> is set.
    /// On any failure the <see cref="SslStream"/> is disposed and the exception is rethrown. The
    /// stream is created with <c>leaveInnerStreamOpen: false</c>, so this also disposes
    /// <paramref name="peekable"/>.
    /// </summary>
    private static async Task<SslStream> AuthenticateAsync(
        PeekableStream peekable,
        NatsOptions options,
        SslServerAuthenticationOptions sslOptions,
        ILogger logger,
        bool tlsFirst,
        CancellationToken ct)
    {
        var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
        try
        {
            using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            handshakeCts.CancelAfter(options.TlsTimeout);
            await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);

            if (tlsFirst)
            {
                logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
                    sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
            }
            else
            {
                logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
                    sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
            }

            if (options.TlsPinnedCerts != null)
            {
                // Fail closed: a client that presents no certificate cannot match a pinned one.
                if (sslStream.RemoteCertificate is not X509Certificate2 remoteCert
                    || !TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
                {
                    logger.LogWarning("Certificate pinning check failed");
                    throw new InvalidOperationException("Certificate pinning check failed");
                }
            }

            return sslStream;
        }
        catch
        {
            sslStream.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Sets the INFO TLS flags used before a handshake (<c>TlsRequired</c> unless plaintext is
    /// allowed, <c>TlsAvailable</c> when it is, and <c>TlsVerify</c> from the options), then
    /// writes INFO to <paramref name="stream"/> in plaintext.
    /// </summary>
    private static Task SendPlaintextInfoAsync(
        Stream stream, NatsOptions options, ServerInfo serverInfo, CancellationToken ct)
    {
        serverInfo.TlsRequired = !options.AllowNonTls;
        serverInfo.TlsAvailable = options.AllowNonTls;
        serverInfo.TlsVerify = options.TlsVerify;
        return SendInfoAsync(stream, serverInfo, ct);
    }

    /// <summary>
    /// Peeks up to <paramref name="count"/> bytes without consuming them. Returns an empty array
    /// when <paramref name="timeout"/> elapses first. Cancellation of <paramref name="ct"/> itself
    /// is propagated as <see cref="OperationCanceledException"/>.
    /// </summary>
    private static async Task<byte[]> PeekWithTimeoutAsync(
        PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(timeout);
        try
        {
            return await stream.PeekAsync(count, cts.Token);
        }
        catch (OperationCanceledException) when (!ct.IsCancellationRequested)
        {
            // Our timeout fired, not the caller's token.
            return [];
        }
    }

    /// <summary>
    /// Writes <c>INFO {json}\r\n</c> for <paramref name="serverInfo"/> to <paramref name="stream"/> and flushes.
    /// </summary>
    private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
    {
        var infoJson = JsonSerializer.Serialize(serverInfo);
        var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
        await stream.WriteAsync(infoLine, ct);
        await stream.FlushAsync(ct);
    }
}