Files
natsdotnet/tests/NATS.Server.Tests/TlsServerTests.cs
Joseph Doherty 3b6bd08248 feat: add TLS mixed mode tests and monitoring TLS field verification
Add TlsMixedModeTests verifying that a server with AllowNonTls=true
accepts both plaintext and TLS clients on the same port. Add
MonitorTlsTests verifying that /connz reports TlsVersion and
TlsCipherSuite for TLS-connected clients.
2026-02-22 22:40:03 -05:00

226 lines
7.2 KiB
C#

using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
namespace NATS.Server.Tests;
public class TlsServerTests : IAsyncLifetime
{
    // Upper bound for any single network step (a protocol read or a TLS handshake),
    // so a misbehaving server fails the test instead of hanging the test run.
    private const int IoTimeoutMs = 5000;

    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();
    private readonly string _certPath;
    private readonly string _keyPath;

    /// <summary>
    /// Creates a TLS-configured server on a free loopback port, using a freshly
    /// generated test certificate/key pair written to disk.
    /// </summary>
    public TlsServerTests()
    {
        _port = GetFreePort();
        (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                TlsCert = _certPath,
                TlsKey = _keyPath,
            },
            NullLoggerFactory.Instance);
    }

    public async Task InitializeAsync()
    {
        // The start task is intentionally not awaited; the test waits on
        // WaitForReadyAsync instead, and DisposeAsync cancels _cts to stop the server.
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    public async Task DisposeAsync()
    {
        try
        {
            await _cts.CancelAsync();
            _server.Dispose();
        }
        finally
        {
            // Always release the token source and remove the generated cert files,
            // even if shutting the server down throws.
            _cts.Dispose();
            File.Delete(_certPath);
            File.Delete(_keyPath);
        }
    }

    [Fact]
    public async Task Tls_client_connects_and_receives_info()
    {
        using var tcp = new TcpClient();
        await tcp.ConnectAsync(IPAddress.Loopback, _port);
        using var netStream = tcp.GetStream();

        // Read INFO (sent before TLS upgrade in Mode 2). Read through the line
        // terminator so no INFO bytes are left behind to corrupt the TLS handshake.
        var info = await ReadUntilAsync(netStream, "\r\n");
        info.ShouldStartWith("INFO ");
        info.ShouldContain("\"tls_required\":true");

        // Upgrade to TLS. Certificate validation is disabled because the server
        // uses a generated test certificate.
        using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
        await sslStream.AuthenticateAsClientAsync("localhost")
            .WaitAsync(TimeSpan.FromMilliseconds(IoTimeoutMs));

        // Send CONNECT + PING over TLS
        await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await sslStream.FlushAsync();

        // Read PONG (may arrive across multiple reads/TLS records)
        var pong = await ReadUntilAsync(sslStream, "PONG");
        pong.ShouldContain("PONG");
    }

    [Fact]
    public async Task Tls_pubsub_works_over_encrypted_connection()
    {
        using var tcp1 = new TcpClient();
        await tcp1.ConnectAsync(IPAddress.Loopback, _port);
        using var ssl1 = await UpgradeToTlsAsync(tcp1);

        using var tcp2 = new TcpClient();
        await tcp2.ConnectAsync(IPAddress.Loopback, _port);
        using var ssl2 = await UpgradeToTlsAsync(tcp2);

        // Sub on client 1
        await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
        await ssl1.FlushAsync();

        // Wait for PONG to confirm subscription is registered
        var pongStr = await ReadUntilAsync(ssl1, "PONG");
        pongStr.ShouldContain("PONG");

        // Pub on client 2
        await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
        await ssl2.FlushAsync();

        // Client 1 should receive MSG (may arrive across multiple TLS records)
        var msg = await ReadUntilAsync(ssl1, "hello");
        msg.ShouldContain("MSG test 1 5");
        msg.ShouldContain("hello");
    }

    /// <summary>
    /// Reads from <paramref name="stream"/> until the accumulated ASCII text contains
    /// <paramref name="expected"/> or the peer closes the stream.
    /// </summary>
    /// <param name="stream">Stream to read from (plain or TLS).</param>
    /// <param name="expected">Substring to wait for (ordinal match).</param>
    /// <param name="timeoutMs">Overall time budget for the whole read loop.</param>
    /// <returns>
    /// All text read so far; it may lack <paramref name="expected"/> if the stream closed first.
    /// </returns>
    /// <exception cref="TimeoutException">
    /// <paramref name="expected"/> did not arrive within <paramref name="timeoutMs"/>.
    /// </exception>
    private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = IoTimeoutMs)
    {
        using var cts = new CancellationTokenSource(timeoutMs);
        var sb = new StringBuilder();
        var buf = new byte[4096];
        var received = string.Empty;

        while (!received.Contains(expected, StringComparison.Ordinal))
        {
            int n;
            try
            {
                n = await stream.ReadAsync(buf, cts.Token);
            }
            catch (OperationCanceledException) when (cts.IsCancellationRequested)
            {
                // Turn the bare cancellation into a failure message that shows what did arrive.
                throw new TimeoutException(
                    $"Timed out after {timeoutMs} ms waiting for \"{expected}\"; received: \"{received}\"");
            }

            if (n == 0)
            {
                break; // peer closed the connection
            }

            sb.Append(Encoding.ASCII.GetString(buf, 0, n));
            received = sb.ToString();
        }

        return received;
    }

    /// <summary>
    /// Consumes the plaintext INFO line on <paramref name="tcp"/> and then performs a
    /// client-side TLS handshake for "localhost", accepting any server certificate.
    /// </summary>
    /// <returns>An authenticated <see cref="SslStream"/> that owns the TCP stream.</returns>
    private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
    {
        var netStream = tcp.GetStream();

        // Read the complete INFO line (discard) so no plaintext remains ahead of the handshake.
        _ = await ReadUntilAsync(netStream, "\r\n");

        var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
        try
        {
            await ssl.AuthenticateAsClientAsync("localhost")
                .WaitAsync(TimeSpan.FromMilliseconds(IoTimeoutMs));
            return ssl;
        }
        catch
        {
            // Do not leak the stream when the handshake fails or times out.
            ssl.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Asks the OS for an ephemeral loopback port by binding to port 0, then releases it.
    /// </summary>
    private static int GetFreePort()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        return ((IPEndPoint)sock.LocalEndPoint!).Port;
    }
}
public class TlsMixedModeTests : IAsyncLifetime
{
    // Upper bound for any single network step (a protocol read or a TLS handshake),
    // so a misbehaving server fails the test instead of hanging the test run.
    private const int IoTimeoutMs = 5000;

    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();
    private readonly string _certPath;
    private readonly string _keyPath;

    /// <summary>
    /// Creates a TLS-configured server with <c>AllowNonTls = true</c>, so the same port
    /// is expected to serve both plaintext and TLS clients.
    /// </summary>
    public TlsMixedModeTests()
    {
        _port = GetFreePort();
        (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                TlsCert = _certPath,
                TlsKey = _keyPath,
                AllowNonTls = true,
            },
            NullLoggerFactory.Instance);
    }

    public async Task InitializeAsync()
    {
        // The start task is intentionally not awaited; the test waits on
        // WaitForReadyAsync instead, and DisposeAsync cancels _cts to stop the server.
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    public async Task DisposeAsync()
    {
        try
        {
            await _cts.CancelAsync();
            _server.Dispose();
        }
        finally
        {
            // Always release the token source and remove the generated cert files,
            // even if shutting the server down throws.
            _cts.Dispose();
            File.Delete(_certPath);
            File.Delete(_keyPath);
        }
    }

    [Fact]
    public async Task Mixed_mode_accepts_plain_client()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
        using var stream = new NetworkStream(sock);

        // INFO should advertise that TLS is available; read the whole line.
        var info = await ReadUntilAsync(stream, "\r\n");
        info.ShouldContain("\"tls_available\":true");

        // Stay in plaintext: CONNECT + PING without any TLS handshake.
        await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await stream.FlushAsync();

        var pong = await ReadUntilAsync(stream, "PONG");
        pong.ShouldContain("PONG");
    }

    [Fact]
    public async Task Mixed_mode_accepts_tls_client()
    {
        using var tcp = new TcpClient();
        await tcp.ConnectAsync(IPAddress.Loopback, _port);
        using var netStream = tcp.GetStream();

        // Read the complete INFO line so no plaintext remains ahead of the handshake.
        _ = await ReadUntilAsync(netStream, "\r\n");

        // Certificate validation is disabled because the server uses a generated test certificate.
        using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
        await ssl.AuthenticateAsClientAsync("localhost")
            .WaitAsync(TimeSpan.FromMilliseconds(IoTimeoutMs));

        await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await ssl.FlushAsync();

        // PONG may arrive across multiple reads/TLS records.
        var pong = await ReadUntilAsync(ssl, "PONG");
        pong.ShouldContain("PONG");
    }

    /// <summary>
    /// Reads from <paramref name="stream"/> until the accumulated ASCII text contains
    /// <paramref name="expected"/> or the peer closes the stream.
    /// </summary>
    /// <param name="stream">Stream to read from (plain or TLS).</param>
    /// <param name="expected">Substring to wait for (ordinal match).</param>
    /// <param name="timeoutMs">Overall time budget for the whole read loop.</param>
    /// <returns>
    /// All text read so far; it may lack <paramref name="expected"/> if the stream closed first.
    /// </returns>
    /// <exception cref="TimeoutException">
    /// <paramref name="expected"/> did not arrive within <paramref name="timeoutMs"/>.
    /// </exception>
    private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = IoTimeoutMs)
    {
        using var cts = new CancellationTokenSource(timeoutMs);
        var sb = new StringBuilder();
        var buf = new byte[4096];
        var received = string.Empty;

        while (!received.Contains(expected, StringComparison.Ordinal))
        {
            int n;
            try
            {
                n = await stream.ReadAsync(buf, cts.Token);
            }
            catch (OperationCanceledException) when (cts.IsCancellationRequested)
            {
                // Turn the bare cancellation into a failure message that shows what did arrive.
                throw new TimeoutException(
                    $"Timed out after {timeoutMs} ms waiting for \"{expected}\"; received: \"{received}\"");
            }

            if (n == 0)
            {
                break; // peer closed the connection
            }

            sb.Append(Encoding.ASCII.GetString(buf, 0, n));
            received = sb.ToString();
        }

        return received;
    }

    /// <summary>
    /// Asks the OS for an ephemeral loopback port by binding to port 0, then releases it.
    /// </summary>
    private static int GetFreePort()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        return ((IPEndPoint)sock.LocalEndPoint!).Port;
    }
}