feat: wire TLS negotiation into NatsServer accept loop
Integrate TLS support into the server's connection accept path: - Add SslServerAuthenticationOptions and TlsRateLimiter fields to NatsServer - Extract AcceptClientAsync method for TLS negotiation, rate limiting, and TLS state extraction (protocol version, cipher suite, peer certificate) - Add InfoAlreadySent flag to NatsClient to skip redundant INFO when TlsConnectionWrapper already sent it during negotiation - Add TlsServerTests verifying TLS connect+INFO and TLS pub/sub
This commit is contained in:
137
tests/NATS.Server.Tests/TlsServerTests.cs
Normal file
137
tests/NATS.Server.Tests/TlsServerTests.cs
Normal file
@@ -0,0 +1,137 @@
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class TlsServerTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _port;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly string _certPath;
|
||||
private readonly string _keyPath;
|
||||
|
||||
public TlsServerTests()
|
||||
{
|
||||
_port = GetFreePort();
|
||||
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions
|
||||
{
|
||||
Port = _port,
|
||||
TlsCert = _certPath,
|
||||
TlsKey = _keyPath,
|
||||
},
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
File.Delete(_certPath);
|
||||
File.Delete(_keyPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Tls_client_connects_and_receives_info()
|
||||
{
|
||||
using var tcp = new TcpClient();
|
||||
await tcp.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var netStream = tcp.GetStream();
|
||||
|
||||
// Read INFO (sent before TLS upgrade in Mode 2)
|
||||
var buf = new byte[4096];
|
||||
var read = await netStream.ReadAsync(buf);
|
||||
var info = Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
info.ShouldContain("\"tls_required\":true");
|
||||
|
||||
// Upgrade to TLS
|
||||
using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await sslStream.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
// Send CONNECT + PING over TLS
|
||||
await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await sslStream.FlushAsync();
|
||||
|
||||
// Read PONG
|
||||
var pongBuf = new byte[256];
|
||||
read = await sslStream.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Tls_pubsub_works_over_encrypted_connection()
|
||||
{
|
||||
using var tcp1 = new TcpClient();
|
||||
await tcp1.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var ssl1 = await UpgradeToTlsAsync(tcp1);
|
||||
|
||||
using var tcp2 = new TcpClient();
|
||||
await tcp2.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var ssl2 = await UpgradeToTlsAsync(tcp2);
|
||||
|
||||
// Sub on client 1
|
||||
await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
|
||||
await ssl1.FlushAsync();
|
||||
|
||||
// Wait for PONG to confirm subscription is registered
|
||||
var pongBuf = new byte[256];
|
||||
var pongRead = await ssl1.ReadAsync(pongBuf);
|
||||
var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead);
|
||||
pongStr.ShouldContain("PONG");
|
||||
|
||||
// Pub on client 2
|
||||
await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
|
||||
await ssl2.FlushAsync();
|
||||
|
||||
// Client 1 should receive MSG (may arrive across multiple TLS records)
|
||||
var msg = await ReadUntilAsync(ssl1, "hello");
|
||||
msg.ShouldContain("MSG test 1 5");
|
||||
msg.ShouldContain("hello");
|
||||
}
|
||||
|
||||
private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000)
|
||||
{
|
||||
using var cts = new CancellationTokenSource(timeoutMs);
|
||||
var sb = new StringBuilder();
|
||||
var buf = new byte[4096];
|
||||
while (!sb.ToString().Contains(expected))
|
||||
{
|
||||
var n = await stream.ReadAsync(buf, cts.Token);
|
||||
if (n == 0) break;
|
||||
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
|
||||
}
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
|
||||
{
|
||||
var netStream = tcp.GetStream();
|
||||
var buf = new byte[4096];
|
||||
_ = await netStream.ReadAsync(buf); // Read INFO (discard)
|
||||
|
||||
var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await ssl.AuthenticateAsClientAsync("localhost");
|
||||
return ssl;
|
||||
}
|
||||
|
||||
private static int GetFreePort()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user