// Moved as part of: "Move TLS, OCSP, WebSocket, Networking, and IO test files from NATS.Server.Tests
// into a dedicated NATS.Server.Transport.Tests project. Update namespaces, replace private
// GetFreePort/ReadUntilAsync with shared TestUtilities helpers, extract TestCertHelper to
// TestUtilities, and replace Task.Delay polling loops with PollHelper.WaitUntilAsync/YieldForAsync
// for proper synchronization."
using System.Net;
|
|
using System.Net.Security;
|
|
using System.Net.Sockets;
|
|
using System.Security.Cryptography;
|
|
using System.Security.Cryptography.X509Certificates;
|
|
using Microsoft.Extensions.Logging.Abstractions;
|
|
using NATS.Server;
|
|
using NATS.Server.Protocol;
|
|
using NATS.Server.TestUtilities;
|
|
using NATS.Server.Tls;
|
|
|
|
namespace NATS.Server.Transport.Tests;
|
|
|
|
/// <summary>
/// Tests for <c>TlsConnectionWrapper.NegotiateAsync</c> over a real loopback socket pair, covering
/// the no-TLS, TLS-required, mixed (<c>AllowNonTls</c>) and TLS-handshake-first paths.
/// </summary>
public class TlsConnectionWrapperTests
{
    // Upper bound for a whole test. Without it a stalled handshake or a missing INFO line
    // would hang the test run instead of failing it.
    private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);

    [Fact]
    public async Task NoTls_returns_plain_stream()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
        using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);

        var opts = new NatsOptions(); // No TLS configured
        var serverInfo = CreateServerInfo();

        var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, cts.Token);

        stream.ShouldBe(serverStream); // Same stream, no wrapping
        infoSent.ShouldBeFalse();
    }

    [Fact]
    public async Task TlsRequired_upgrades_to_ssl()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();

        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);

        var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
        var sslOpts = CreateServerSslOptions(cert);
        var serverInfo = CreateServerInfo();

        // Client side: read the plaintext INFO line, then start TLS.
        var clientTask = Task.Run(async () =>
        {
            var info = await ReadLineAsync(clientNetStream, cts.Token);
            info.ShouldStartWith("INFO ");
            return await AuthenticateClientAsync(clientNetStream, cts.Token);
        });

        var (negotiated, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        using var stream = negotiated;

        stream.ShouldBeOfType<SslStream>();
        infoSent.ShouldBeTrue();

        using var clientSsl = await clientTask;

        // Verify encrypted communication works.
        await AssertServerPingReachesClientAsync(stream, clientSsl, cts.Token);
    }

    [Fact]
    public async Task MixedMode_allows_plaintext_when_AllowNonTls()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();

        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);

        var opts = new NatsOptions
        {
            TlsCert = "dummy",
            TlsKey = "dummy",
            AllowNonTls = true,
            TlsTimeout = TimeSpan.FromSeconds(2),
        };
        var sslOpts = CreateServerSslOptions(cert);
        var serverInfo = CreateServerInfo();

        // Client side: read INFO, then send plaintext (not TLS).
        var clientTask = Task.Run(() => SendPlaintextConnectAfterInfoAsync(clientNetStream, cts.Token));

        var (negotiated, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        using var stream = negotiated;

        await clientTask;

        // In mixed mode with a plaintext client, we get a PeekableStream, not an SslStream.
        stream.ShouldBeOfType<PeekableStream>();
        infoSent.ShouldBeTrue();
    }

    [Fact]
    public async Task TlsRequired_rejects_plaintext()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();

        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);

        var opts = new NatsOptions
        {
            TlsCert = "dummy",
            TlsKey = "dummy",
            AllowNonTls = false,
            TlsTimeout = TimeSpan.FromSeconds(2),
        };
        var sslOpts = CreateServerSslOptions(cert);
        var serverInfo = CreateServerInfo();

        // Client side: read INFO, then send plaintext.
        var clientTask = Task.Run(() => SendPlaintextConnectAfterInfoAsync(clientNetStream, cts.Token));

        await Should.ThrowAsync<InvalidOperationException>(async () =>
        {
            await TlsConnectionWrapper.NegotiateAsync(
                serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        });

        await clientTask;
    }

    [Fact]
    public async Task TlsFirst_handshakes_before_sending_info()
    {
        using var cts = new CancellationTokenSource(TestTimeout);
        var (cert, _) = TestCertHelper.GenerateTestCert();

        var (serverSocket, clientSocket) = await CreateSocketPairAsync();
        using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
        using var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);

        var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
        var sslOpts = CreateServerSslOptions(cert);
        var serverInfo = CreateServerInfo();

        // Client side: start TLS immediately (no INFO first), then read INFO over the encrypted stream.
        var clientTask = Task.Run(async () =>
        {
            var sslClient = await AuthenticateClientAsync(clientNetStream, cts.Token);
            var info = await ReadLineAsync(sslClient, cts.Token);
            info.ShouldStartWith("INFO ");
            return sslClient;
        });

        var (negotiated, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
            serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, cts.Token);
        using var stream = negotiated;

        stream.ShouldBeOfType<SslStream>();
        infoSent.ShouldBeTrue();

        using var clientSsl = await clientTask;

        // Verify encrypted communication works.
        await AssertServerPingReachesClientAsync(stream, clientSsl, cts.Token);
    }

    private static ServerInfo CreateServerInfo() => new()
    {
        ServerId = "TEST",
        ServerName = "test",
        Version = NatsProtocol.Version,
        Host = "127.0.0.1",
        Port = 4222,
    };

    /// <summary>Builds server-side TLS options presenting <paramref name="certificate"/>.</summary>
    private static SslServerAuthenticationOptions CreateServerSslOptions(X509Certificate certificate) => new()
    {
        ServerCertificate = certificate,
    };

    /// <summary>
    /// Runs a TLS client handshake over <paramref name="inner"/> for host "localhost", accepting any
    /// server certificate (the test certificate is self-signed). The inner stream is left open;
    /// the caller owns it. Returns the authenticated stream; disposes it if the handshake fails.
    /// </summary>
    private static async Task<SslStream> AuthenticateClientAsync(Stream inner, CancellationToken ct)
    {
        var sslClient = new SslStream(inner, leaveInnerStreamOpen: true);
        try
        {
            await sslClient.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
            {
                TargetHost = "localhost",
                RemoteCertificateValidationCallback = (_, _, _, _) => true, // Trust all for testing
            }, ct);
            return sslClient;
        }
        catch
        {
            await sslClient.DisposeAsync();
            throw;
        }
    }

    /// <summary>
    /// Plaintext client behaviour: reads the server's INFO line, then sends <c>CONNECT {}\r\n</c>
    /// without starting TLS.
    /// </summary>
    private static async Task SendPlaintextConnectAfterInfoAsync(Stream stream, CancellationToken ct)
    {
        var info = await ReadLineAsync(stream, ct);
        info.ShouldStartWith("INFO ");

        // Send plaintext data (first byte is 'C', not the 0x16 TLS marker).
        var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
        await stream.WriteAsync(connectLine, ct);
        await stream.FlushAsync(ct);
    }

    /// <summary>
    /// Writes <c>PING\r\n</c> on <paramref name="serverStream"/> and asserts that
    /// <paramref name="clientStream"/> receives exactly that line.
    /// </summary>
    private static async Task AssertServerPingReachesClientAsync(Stream serverStream, Stream clientStream, CancellationToken ct)
    {
        await serverStream.WriteAsync("PING\r\n"u8.ToArray(), ct);
        await serverStream.FlushAsync(ct);

        var msg = await ReadLineAsync(clientStream, ct);
        msg.ShouldBe("PING\r\n");
    }

    /// <summary>
    /// Reads one byte at a time until a <c>'\n'</c> arrives and returns the bytes read, terminator
    /// included, decoded as ASCII. Byte-at-a-time reading tolerates fragmented TCP reads. It also
    /// never consumes bytes past the line, which matters when the same stream carries a TLS
    /// handshake next.
    /// </summary>
    /// <exception cref="EndOfStreamException">The peer closed the stream before a line terminator.</exception>
    private static async Task<string> ReadLineAsync(Stream stream, CancellationToken ct)
    {
        using var line = new MemoryStream();
        var one = new byte[1];
        while (true)
        {
            var read = await stream.ReadAsync(one, ct);
            if (read == 0)
            {
                throw new EndOfStreamException(
                    $"Stream closed after {line.Length} bytes without a line terminator.");
            }

            line.WriteByte(one[0]);
            if (one[0] == (byte)'\n')
            {
                return System.Text.Encoding.ASCII.GetString(line.ToArray());
            }
        }
    }

    /// <summary>
    /// Creates a connected loopback TCP socket pair via a temporary listener on an ephemeral port.
    /// If connect or accept fails, the client socket is disposed before the exception propagates.
    /// </summary>
    private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
    {
        using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        listener.Listen(1);
        var port = ((IPEndPoint)listener.LocalEndPoint!).Port;

        var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
            var server = await listener.AcceptAsync();
            return (server, client);
        }
        catch
        {
            client.Dispose();
            throw;
        }
    }
}
|