203 lines
7.3 KiB
C#
203 lines
7.3 KiB
C#
using System.Net.Security;
|
|
using System.Net.Sockets;
|
|
using System.Security.Cryptography.X509Certificates;
|
|
using System.Text;
|
|
using System.Text.Json;
|
|
using Microsoft.Extensions.Logging;
|
|
using NATS.Server.Protocol;
|
|
|
|
namespace NATS.Server.Tls;
|
|
|
|
/// <summary>
/// Negotiates the transport for an accepted client connection before the NATS protocol
/// starts. Depending on <see cref="NatsOptions"/>, the connection is one of:
/// <list type="bullet">
/// <item>left as plain TCP (no TLS configured);</item>
/// <item>upgraded to TLS after a plaintext INFO line (TLS required);</item>
/// <item>allowed to pick TLS or plaintext after INFO (mixed mode);</item>
/// <item>expected to open with a TLS handshake before INFO ("TLS first"). If the client
/// does not start TLS within the fallback window, the INFO-first flow is used.</item>
/// </list>
/// </summary>
public static class TlsConnectionWrapper
{
    // First byte of a TLS record whose content type is "handshake" (22). A client that
    // starts TLS opens with such a record (its ClientHello).
    private const byte TlsRecordMarker = 0x16;

    /// <summary>
    /// Performs TLS negotiation for a newly accepted client.
    /// </summary>
    /// <param name="socket">The accepted socket. Not used by the current implementation.</param>
    /// <param name="networkStream">The raw stream for the connection.</param>
    /// <param name="options">Server options that select the TLS mode and its timeouts.</param>
    /// <param name="sslOptions">TLS server options. <see langword="null"/> means TLS is not used.</param>
    /// <param name="serverInfo">
    /// INFO template. When TLS is configured it is cloned before the TLS flags are set,
    /// so the caller's instance is never modified.
    /// </param>
    /// <param name="logger">Logger for handshake diagnostics.</param>
    /// <param name="ct">Cancellation token for the whole negotiation.</param>
    /// <returns>
    /// The stream to use for the rest of the connection. Also returns whether INFO has
    /// already been written to the client: <see langword="false"/> only when TLS is not
    /// configured.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown in three cases: TLS is required but the client sent plaintext or never
    /// started a handshake; or the client failed the pinned-certificate check.
    /// Exceptions from the TLS handshake itself also propagate, including cancellation
    /// when <c>TlsTimeout</c> elapses.
    /// </exception>
    public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
        Socket socket,
        Stream networkStream,
        NatsOptions options,
        SslServerAuthenticationOptions? sslOptions,
        ServerInfo serverInfo,
        ILogger logger,
        CancellationToken ct)
    {
        // Mode 1: No TLS
        if (sslOptions == null || !options.HasTls)
        {
            return (networkStream, false);
        }

        // Clone to avoid mutating shared instance
        serverInfo = CloneServerInfo(serverInfo);

        // Mode 3: TLS First
        if (options.TlsHandshakeFirst)
        {
            return await NegotiateTlsFirstAsync(networkStream, options, sslOptions, serverInfo, logger, ct);
        }

        // Mode 2 & 4: Send INFO first, then let the client's first byte decide
        await SendPlaintextInfoAsync(networkStream, options, serverInfo, ct);

        var peekable = new PeekableStream(networkStream);
        var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);

        return await CompleteAfterInfoAsync(peekable, peeked, options, sslOptions, logger, ct);
    }

    // TLS-first negotiation. Wait up to TlsHandshakeFirstFallback for a ClientHello.
    // If one arrives, handshake first and send INFO encrypted; otherwise send INFO in
    // plaintext and continue like the INFO-first modes.
    private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
        Stream networkStream,
        NatsOptions options,
        SslServerAuthenticationOptions sslOptions,
        ServerInfo serverInfo,
        ILogger logger,
        CancellationToken ct)
    {
        // Wait for data with fallback timeout
        var peekable = new PeekableStream(networkStream);
        var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);

        if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
        {
            // Client started TLS immediately — handshake first, then send INFO
            var sslStream = await AuthenticateAsync(peekable, options, sslOptions, logger, tlsFirst: true, ct);
            try
            {
                // Now send INFO over encrypted stream
                serverInfo.TlsRequired = true;
                serverInfo.TlsVerify = options.TlsVerify;
                await SendInfoAsync(sslStream, serverInfo, ct);
            }
            catch
            {
                sslStream.Dispose();
                throw;
            }

            return (sslStream, true);
        }

        // Fallback: timeout expired or non-TLS data — send INFO and negotiate normally
        logger.LogDebug("TLS-first fallback: sending INFO");
        await SendPlaintextInfoAsync(peekable, options, serverInfo, ct);

        if (peeked.Length == 0)
        {
            // Nothing arrived within the fallback window, so the client presumably waits
            // for INFO before upgrading. The INFO just sent may advertise tls_required, so
            // returning the plaintext stream here would either break the client's upgrade
            // or let it bypass TLS. Give it the normal TlsTimeout to start the handshake.
            // Like the original flow, this assumes the stream is still usable after the
            // first (timed-out) peek.
            peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
            return await CompleteAfterInfoAsync(peekable, peeked, options, sslOptions, logger, ct);
        }

        // Non-TLS data received during fallback window (before INFO was sent)
        if (options.AllowNonTls)
        {
            return (peekable, true);
        }

        // TLS required but got plaintext
        logger.LogWarning("TLS required but client sent plaintext data");
        throw new InvalidOperationException("TLS required but client sent plaintext");
    }

    // Acts on the client's first byte after a plaintext INFO has been sent. An empty
    // 'peeked' means no data was observed (timeout, or EOF as reported by PeekableStream).
    private static async Task<(Stream stream, bool infoAlreadySent)> CompleteAfterInfoAsync(
        PeekableStream peekable,
        byte[] peeked,
        NatsOptions options,
        SslServerAuthenticationOptions sslOptions,
        ILogger logger,
        CancellationToken ct)
    {
        if (peeked.Length == 0)
        {
            // Client disconnected or timed out. In mixed mode, plaintext is acceptable.
            if (options.AllowNonTls)
            {
                return (peekable, true);
            }

            // When TLS is required, handing back the plaintext stream would let a client
            // that waits out TlsTimeout continue without TLS.
            logger.LogWarning("TLS required but client did not start a TLS handshake in time");
            throw new InvalidOperationException("TLS required but client did not start a TLS handshake");
        }

        if (peeked[0] == TlsRecordMarker)
        {
            // Client is starting TLS
            var sslStream = await AuthenticateAsync(peekable, options, sslOptions, logger, tlsFirst: false, ct);
            return (sslStream, true);
        }

        // Mode 4: Mixed — client chose plaintext
        if (options.AllowNonTls)
        {
            logger.LogDebug("Client connected without TLS (mixed mode)");
            return (peekable, true);
        }

        // TLS required but client sent plaintext
        logger.LogWarning("TLS required but client sent plaintext data");
        throw new InvalidOperationException("TLS required");
    }

    // Runs the server-side TLS handshake over 'peekable', bounded by TlsTimeout, and then
    // enforces pinned client certificates. If either step fails, the SslStream (and,
    // because leaveInnerStreamOpen is false, the inner stream) is disposed.
    private static async Task<SslStream> AuthenticateAsync(
        PeekableStream peekable,
        NatsOptions options,
        SslServerAuthenticationOptions sslOptions,
        ILogger logger,
        bool tlsFirst,
        CancellationToken ct)
    {
        var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
        try
        {
            using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            handshakeCts.CancelAfter(options.TlsTimeout);

            await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);

            if (tlsFirst)
            {
                logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
                    sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
            }
            else
            {
                logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
                    sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
            }

            // Validate pinned certs. When pinning is configured, a client that presents
            // no certificate must fail the check rather than skip it.
            if (options.TlsPinnedCerts != null
                && (sslStream.RemoteCertificate is not X509Certificate2 remoteCert
                    || !TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)))
            {
                logger.LogWarning("Certificate pinning check failed");
                throw new InvalidOperationException("Certificate pinning check failed");
            }

            return sslStream;
        }
        catch
        {
            sslStream.Dispose();
            throw;
        }
    }

    // Advertises TLS in a plaintext INFO line. TLS is "required" unless non-TLS clients
    // are allowed, in which case it is only "available".
    private static async Task SendPlaintextInfoAsync(
        Stream stream, NatsOptions options, ServerInfo serverInfo, CancellationToken ct)
    {
        serverInfo.TlsRequired = !options.AllowNonTls;
        serverInfo.TlsAvailable = options.AllowNonTls;
        serverInfo.TlsVerify = options.TlsVerify;
        await SendInfoAsync(stream, serverInfo, ct);
    }

    // Copies the INFO fields listed below into a new instance so TLS flags can be set per
    // connection without mutating the caller's instance.
    // NOTE(review): only these properties are copied; confirm the list matches ServerInfo.
    private static ServerInfo CloneServerInfo(ServerInfo source) => new ServerInfo
    {
        ServerId = source.ServerId,
        ServerName = source.ServerName,
        Version = source.Version,
        Proto = source.Proto,
        Host = source.Host,
        Port = source.Port,
        Headers = source.Headers,
        MaxPayload = source.MaxPayload,
        ClientId = source.ClientId,
        ClientIp = source.ClientIp,
    };

    // Peeks 'count' bytes, giving up after 'timeout'. A timeout yields an empty array.
    // Cancellation of the caller's 'ct' still propagates as OperationCanceledException.
    private static async Task<byte[]> PeekWithTimeoutAsync(
        PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(timeout);
        try
        {
            return await stream.PeekAsync(count, cts.Token);
        }
        catch (OperationCanceledException) when (!ct.IsCancellationRequested)
        {
            // Timeout — not a cancellation of the outer token
            return [];
        }
    }

    // Writes "INFO <json>\r\n" for 'serverInfo' and flushes it to the client.
    private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
    {
        var infoJson = JsonSerializer.Serialize(serverInfo);
        var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
        await stream.WriteAsync(infoLine, ct);
        await stream.FlushAsync(ct);
    }
}
|