Files
natsdotnet/tests/NATS.Server.Transport.Tests/TlsServerTests.cs
Joseph Doherty d2c04fcca5 refactor: extract NATS.Server.Transport.Tests project
Move TLS, OCSP, WebSocket, Networking, and IO test files from
NATS.Server.Tests into a dedicated NATS.Server.Transport.Tests
project. Update namespaces, replace private GetFreePort/ReadUntilAsync
with shared TestUtilities helpers, extract TestCertHelper to
TestUtilities, and replace Task.Delay polling loops with
PollHelper.WaitUntilAsync/YieldForAsync for proper synchronization.
2026-03-12 14:57:35 -04:00

199 lines
6.2 KiB
C#

using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.TestUtilities;
namespace NATS.Server.Transport.Tests;
/// <summary>
/// End-to-end tests for a server configured with a TLS certificate and key (and without
/// <c>AllowNonTls</c>). Clients read the plain-text INFO line first, then upgrade the same
/// TCP connection to TLS before sending any protocol commands.
/// </summary>
public class TlsServerTests : IAsyncLifetime
{
    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();
    private readonly string _certPath;
    private readonly string _keyPath;

    /// <summary>
    /// Allocates a free port, writes a test certificate/key pair to disk and constructs
    /// (but does not start) the server under test.
    /// </summary>
    public TlsServerTests()
    {
        _port = TestPortAllocator.GetFreePort();
        (_certPath, _keyPath) = TestCertHelper.GenerateTestCertFiles();
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                TlsCert = _certPath,
                TlsKey = _keyPath,
            },
            NullLoggerFactory.Instance);
    }

    /// <summary>Starts the server in the background and waits until it reports ready.</summary>
    public async Task InitializeAsync()
    {
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    /// <summary>
    /// Stops the server and deletes the temporary certificate files. The files are
    /// removed even if shutting the server down throws.
    /// </summary>
    public async Task DisposeAsync()
    {
        try
        {
            await _cts.CancelAsync();
            _server.Dispose();
        }
        finally
        {
            File.Delete(_certPath);
            File.Delete(_keyPath);
        }
    }

    /// <summary>
    /// A client that reads INFO, completes the TLS handshake and sends CONNECT + PING
    /// over the encrypted stream must get a PONG back. INFO must advertise
    /// <c>tls_required</c>.
    /// </summary>
    [Fact]
    public async Task Tls_client_connects_and_receives_info()
    {
        using var tcp = new TcpClient();
        await tcp.ConnectAsync(IPAddress.Loopback, _port);
        using var netStream = tcp.GetStream();

        // INFO is sent in plain text before the TLS upgrade (Mode 2). Read up to the end of
        // the INFO line: a single ReadAsync may return only part of it, and any unread INFO
        // bytes would then be consumed by the TLS handshake below and break it.
        var info = await SocketTestHelper.ReadUntilAsync(netStream, "\r\n");
        info.ShouldStartWith("INFO ");
        info.ShouldContain("\"tls_required\":true");

        // Upgrade to TLS; the validation callback accepts any server certificate.
        using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
        await sslStream.AuthenticateAsClientAsync("localhost");

        await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await sslStream.FlushAsync();

        // Accumulate until PONG rather than trusting one read to contain the whole reply.
        var pong = await SocketTestHelper.ReadUntilAsync(sslStream, "PONG");
        pong.ShouldContain("PONG");
    }

    /// <summary>
    /// A message published by one TLS client is delivered to a subscription held by
    /// another TLS client.
    /// </summary>
    [Fact]
    public async Task Tls_pubsub_works_over_encrypted_connection()
    {
        using var tcp1 = new TcpClient();
        await tcp1.ConnectAsync(IPAddress.Loopback, _port);
        using var ssl1 = await UpgradeToTlsAsync(tcp1);

        using var tcp2 = new TcpClient();
        await tcp2.ConnectAsync(IPAddress.Loopback, _port);
        using var ssl2 = await UpgradeToTlsAsync(tcp2);

        // Subscribe on client 1. The PONG reply to the trailing PING confirms that the
        // server has processed the SUB before anything is published.
        await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
        await ssl1.FlushAsync();
        var pongStr = await SocketTestHelper.ReadUntilAsync(ssl1, "PONG");
        pongStr.ShouldContain("PONG");

        // Publish on client 2.
        await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
        await ssl2.FlushAsync();

        // Client 1 should receive the MSG (it may arrive across multiple TLS records).
        var msg = await SocketTestHelper.ReadUntilAsync(ssl1, "hello");
        msg.ShouldContain("MSG test 1 5");
        msg.ShouldContain("hello");
    }

    /// <summary>
    /// Consumes the plain-text INFO line from <paramref name="tcp"/> and upgrades the
    /// connection to TLS, accepting any server certificate.
    /// </summary>
    /// <param name="tcp">A connected client whose INFO line has not been read yet.</param>
    /// <returns>
    /// An authenticated <see cref="SslStream"/>. The caller owns it; disposing it also
    /// closes the underlying network stream.
    /// </returns>
    private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
    {
        var netStream = tcp.GetStream();

        // Read the whole INFO line (discarded) so no leftover plain-text bytes reach
        // the TLS handshake.
        _ = await SocketTestHelper.ReadUntilAsync(netStream, "\r\n");

        // leaveInnerStreamOpen = false: disposing the SslStream also disposes netStream.
        var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
        try
        {
            await ssl.AuthenticateAsClientAsync("localhost");
            return ssl;
        }
        catch
        {
            // Do not leak the stream when the handshake fails.
            await ssl.DisposeAsync();
            throw;
        }
    }
}
/// <summary>
/// End-to-end tests for a server that has a TLS certificate configured and also sets
/// <c>AllowNonTls = true</c>. Both plain-text clients and clients that upgrade to TLS
/// after reading INFO must be able to complete a CONNECT/PING round trip.
/// </summary>
public class TlsMixedModeTests : IAsyncLifetime
{
    private readonly NatsServer _server;
    private readonly int _port;
    private readonly CancellationTokenSource _cts = new();
    private readonly string _certPath;
    private readonly string _keyPath;

    /// <summary>
    /// Allocates a free port, writes a test certificate/key pair to disk and constructs
    /// (but does not start) a server that accepts both TLS and non-TLS clients.
    /// </summary>
    public TlsMixedModeTests()
    {
        _port = TestPortAllocator.GetFreePort();
        (_certPath, _keyPath) = TestCertHelper.GenerateTestCertFiles();
        _server = new NatsServer(
            new NatsOptions
            {
                Port = _port,
                TlsCert = _certPath,
                TlsKey = _keyPath,
                AllowNonTls = true,
            },
            NullLoggerFactory.Instance);
    }

    /// <summary>Starts the server in the background and waits until it reports ready.</summary>
    public async Task InitializeAsync()
    {
        _ = _server.StartAsync(_cts.Token);
        await _server.WaitForReadyAsync();
    }

    /// <summary>
    /// Stops the server and deletes the temporary certificate files. The files are
    /// removed even if shutting the server down throws.
    /// </summary>
    public async Task DisposeAsync()
    {
        try
        {
            await _cts.CancelAsync();
            _server.Dispose();
        }
        finally
        {
            File.Delete(_certPath);
            File.Delete(_keyPath);
        }
    }

    /// <summary>
    /// A client that never upgrades to TLS is accepted. INFO advertises
    /// <c>tls_available</c>, and CONNECT + PING over plain TCP yields PONG.
    /// </summary>
    [Fact]
    public async Task Mixed_mode_accepts_plain_client()
    {
        using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
        using var stream = new NetworkStream(sock);

        // Read the full INFO line; a single ReadAsync may return only part of it.
        var info = await SocketTestHelper.ReadUntilAsync(stream, "\r\n");
        info.ShouldContain("\"tls_available\":true");

        await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await stream.FlushAsync();

        var pong = await SocketTestHelper.ReadUntilAsync(stream, "PONG");
        pong.ShouldContain("PONG");
    }

    /// <summary>
    /// A client that reads INFO and then upgrades to TLS is also accepted and gets PONG
    /// over the encrypted stream.
    /// </summary>
    [Fact]
    public async Task Mixed_mode_accepts_tls_client()
    {
        using var tcp = new TcpClient();
        await tcp.ConnectAsync(IPAddress.Loopback, _port);
        using var netStream = tcp.GetStream();

        // Consume the whole plain-text INFO line (discarded). Leftover INFO bytes would
        // otherwise be read by the TLS handshake below and make it fail.
        _ = await SocketTestHelper.ReadUntilAsync(netStream, "\r\n");

        // The validation callback accepts any server certificate.
        using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
        await ssl.AuthenticateAsClientAsync("localhost");

        await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
        await ssl.FlushAsync();

        var pong = await SocketTestHelper.ReadUntilAsync(ssl, "PONG");
        pong.ShouldContain("PONG");
    }
}