feat: add TlsConnectionWrapper with 4-mode TLS negotiation
This commit is contained in:
172
src/NATS.Server/Tls/TlsConnectionWrapper.cs
Normal file
172
src/NATS.Server/Tls/TlsConnectionWrapper.cs
Normal file
@@ -0,0 +1,172 @@
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Protocol;
|
||||
|
||||
namespace NATS.Server.Tls;
|
||||
|
||||
public static class TlsConnectionWrapper
|
||||
{
|
||||
private const byte TlsRecordMarker = 0x16;
|
||||
|
||||
public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
|
||||
Socket socket,
|
||||
Stream networkStream,
|
||||
NatsOptions options,
|
||||
SslServerAuthenticationOptions? sslOptions,
|
||||
ServerInfo serverInfo,
|
||||
ILogger logger,
|
||||
CancellationToken ct)
|
||||
{
|
||||
// Mode 1: No TLS
|
||||
if (sslOptions == null || !options.HasTls)
|
||||
return (networkStream, false);
|
||||
|
||||
// Mode 3: TLS First
|
||||
if (options.TlsHandshakeFirst)
|
||||
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
|
||||
|
||||
// Mode 2 & 4: Send INFO first, then decide
|
||||
serverInfo.TlsRequired = !options.AllowNonTls;
|
||||
serverInfo.TlsAvailable = options.AllowNonTls;
|
||||
serverInfo.TlsVerify = options.TlsVerify;
|
||||
await SendInfoAsync(networkStream, serverInfo, ct);
|
||||
|
||||
// Peek first byte to detect TLS
|
||||
var peekable = new PeekableStream(networkStream);
|
||||
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
|
||||
|
||||
if (peeked.Length == 0)
|
||||
{
|
||||
// Client disconnected or timed out
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
if (peeked[0] == TlsRecordMarker)
|
||||
{
|
||||
// Client is starting TLS
|
||||
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
||||
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
handshakeCts.CancelAfter(options.TlsTimeout);
|
||||
|
||||
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
|
||||
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
|
||||
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
|
||||
|
||||
// Validate pinned certs
|
||||
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
|
||||
{
|
||||
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
||||
{
|
||||
logger.LogWarning("Certificate pinning check failed");
|
||||
sslStream.Dispose();
|
||||
throw new InvalidOperationException("Certificate pinning check failed");
|
||||
}
|
||||
}
|
||||
|
||||
return (sslStream, true);
|
||||
}
|
||||
|
||||
// Mode 4: Mixed — client chose plaintext
|
||||
if (options.AllowNonTls)
|
||||
{
|
||||
logger.LogDebug("Client connected without TLS (mixed mode)");
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
// TLS required but client sent plaintext
|
||||
logger.LogWarning("TLS required but client sent plaintext data");
|
||||
throw new InvalidOperationException("TLS required");
|
||||
}
|
||||
|
||||
private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
|
||||
Socket socket,
|
||||
Stream networkStream,
|
||||
NatsOptions options,
|
||||
SslServerAuthenticationOptions sslOptions,
|
||||
ServerInfo serverInfo,
|
||||
ILogger logger,
|
||||
CancellationToken ct)
|
||||
{
|
||||
// Wait for data with fallback timeout
|
||||
var peekable = new PeekableStream(networkStream);
|
||||
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);
|
||||
|
||||
if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
|
||||
{
|
||||
// Client started TLS immediately — handshake first, then send INFO
|
||||
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
||||
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
handshakeCts.CancelAfter(options.TlsTimeout);
|
||||
|
||||
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
|
||||
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
|
||||
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
|
||||
|
||||
// Validate pinned certs
|
||||
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
|
||||
{
|
||||
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
||||
{
|
||||
sslStream.Dispose();
|
||||
throw new InvalidOperationException("Certificate pinning check failed");
|
||||
}
|
||||
}
|
||||
|
||||
// Now send INFO over encrypted stream
|
||||
serverInfo.TlsRequired = true;
|
||||
serverInfo.TlsVerify = options.TlsVerify;
|
||||
await SendInfoAsync(sslStream, serverInfo, ct);
|
||||
return (sslStream, true);
|
||||
}
|
||||
|
||||
// Fallback: timeout expired or non-TLS data — send INFO and negotiate normally
|
||||
logger.LogDebug("TLS-first fallback: sending INFO");
|
||||
serverInfo.TlsRequired = !options.AllowNonTls;
|
||||
serverInfo.TlsAvailable = options.AllowNonTls;
|
||||
serverInfo.TlsVerify = options.TlsVerify;
|
||||
await SendInfoAsync(peekable, serverInfo, ct);
|
||||
|
||||
if (peeked.Length == 0)
|
||||
{
|
||||
// Timeout — INFO was sent, return stream for normal flow
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
// Non-TLS data received during fallback window
|
||||
if (options.AllowNonTls)
|
||||
{
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
// TLS required but got plaintext
|
||||
throw new InvalidOperationException("TLS required but client sent plaintext");
|
||||
}
|
||||
|
||||
private static async Task<byte[]> PeekWithTimeoutAsync(
|
||||
PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
|
||||
{
|
||||
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
cts.CancelAfter(timeout);
|
||||
try
|
||||
{
|
||||
return await stream.PeekAsync(count, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException) when (!ct.IsCancellationRequested)
|
||||
{
|
||||
// Timeout — not a cancellation of the outer token
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
|
||||
{
|
||||
var infoJson = JsonSerializer.Serialize(serverInfo);
|
||||
var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
|
||||
await stream.WriteAsync(infoLine, ct);
|
||||
await stream.FlushAsync(ct);
|
||||
}
|
||||
}
|
||||
202
tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs
Normal file
202
tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs
Normal file
@@ -0,0 +1,202 @@
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class TlsConnectionWrapperTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task NoTls_returns_plain_stream()
|
||||
{
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions(); // No TLS configured
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBe(serverStream); // Same stream, no wrapping
|
||||
infoSent.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRequired_upgrades_to_ssl()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then start TLS
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
// Read INFO line
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Upgrade to TLS
|
||||
var sslClient = new SslStream(clientNetStream, true,
|
||||
(_, _, _, _) => true); // Trust all for testing
|
||||
await sslClient.AuthenticateAsClientAsync("localhost");
|
||||
return sslClient;
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBeOfType<SslStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
var clientSsl = await clientTask;
|
||||
|
||||
// Verify encrypted communication works
|
||||
await stream.WriteAsync("PING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var readBuf = new byte[64];
|
||||
var bytesRead = await clientSsl.ReadAsync(readBuf);
|
||||
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
|
||||
msg.ShouldBe("PING\r\n");
|
||||
|
||||
stream.Dispose();
|
||||
clientSsl.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MixedMode_allows_plaintext_when_AllowNonTls()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsCert = "dummy",
|
||||
TlsKey = "dummy",
|
||||
AllowNonTls = true,
|
||||
TlsTimeout = TimeSpan.FromSeconds(2),
|
||||
};
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then send plaintext (not TLS)
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Send plaintext CONNECT (not a TLS handshake)
|
||||
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
|
||||
await clientNetStream.WriteAsync(connectLine);
|
||||
await clientNetStream.FlushAsync();
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
await clientTask;
|
||||
|
||||
// In mixed mode with plaintext client, we get a PeekableStream, not SslStream
|
||||
stream.ShouldBeOfType<PeekableStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
stream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRequired_rejects_plaintext()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsCert = "dummy",
|
||||
TlsKey = "dummy",
|
||||
AllowNonTls = false,
|
||||
TlsTimeout = TimeSpan.FromSeconds(2),
|
||||
};
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then send plaintext
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Send plaintext data (first byte is 'C', not 0x16 TLS marker)
|
||||
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
|
||||
await clientNetStream.WriteAsync(connectLine);
|
||||
await clientNetStream.FlushAsync();
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
|
||||
await Should.ThrowAsync<InvalidOperationException>(async () =>
|
||||
{
|
||||
await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
});
|
||||
|
||||
await clientTask;
|
||||
serverNetStream.Dispose();
|
||||
}
|
||||
|
||||
private static ServerInfo CreateServerInfo() => new()
|
||||
{
|
||||
ServerId = "TEST",
|
||||
ServerName = "test",
|
||||
Version = NatsProtocol.Version,
|
||||
Host = "127.0.0.1",
|
||||
Port = 4222,
|
||||
};
|
||||
|
||||
private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
|
||||
{
|
||||
using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
listener.Listen(1);
|
||||
var port = ((IPEndPoint)listener.LocalEndPoint!).Port;
|
||||
|
||||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
|
||||
var server = await listener.AcceptAsync();
|
||||
|
||||
return (server, client);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user