using System.Net;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using NATS.Server;
using NATS.Server.Tls;

namespace NATS.Server.Tests;

/// <summary>
/// Tests for <c>TlsHelper</c>, <c>PeekableStream</c> and <c>TlsRateLimiter</c>, plus shared
/// helpers that produce throwaway self-signed certificates for other test classes.
/// </summary>
public class TlsHelperTests
{
    [Fact]
    public void LoadCertificate_loads_pem_cert_and_key()
    {
        var (certPath, keyPath) = GenerateTestCertFiles();
        try
        {
            using var cert = TlsHelper.LoadCertificate(certPath, keyPath);

            cert.ShouldNotBeNull();
            cert.HasPrivateKey.ShouldBeTrue();
        }
        finally
        {
            File.Delete(certPath);
            File.Delete(keyPath);
        }
    }

    [Fact]
    public void BuildServerAuthOptions_creates_valid_options()
    {
        var (certPath, keyPath) = GenerateTestCertFiles();
        try
        {
            var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };

            var authOpts = TlsHelper.BuildServerAuthOptions(opts);

            authOpts.ShouldNotBeNull();
            authOpts.ServerCertificate.ShouldNotBeNull();
        }
        finally
        {
            File.Delete(certPath);
            File.Delete(keyPath);
        }
    }

    [Fact]
    public void MatchesPinnedCert_matches_correct_hash()
    {
        var (cert, key) = GenerateTestCert();
        using (cert)
        using (key)
        {
            var hash = TlsHelper.GetCertificateHash(cert);
            var pinned = new HashSet<string> { hash };

            TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue();
        }
    }

    [Fact]
    public void MatchesPinnedCert_rejects_wrong_hash()
    {
        var (cert, key) = GenerateTestCert();
        using (cert)
        using (key)
        {
            var pinned = new HashSet<string>
            {
                "0000000000000000000000000000000000000000000000000000000000000000",
            };

            TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse();
        }
    }

    [Fact]
    public async Task PeekableStream_peeks_and_replays()
    {
        var data = "Hello, World!"u8.ToArray();
        using var ms = new MemoryStream(data);
        using var peekable = new PeekableStream(ms);

        var peeked = await peekable.PeekAsync(1);

        peeked.Length.ShouldBe(1);
        peeked[0].ShouldBe((byte)'H');

        // After peeking, a full read must still return every byte, including the peeked one.
        var buf = new byte[data.Length];
        int total = 0;
        while (total < data.Length)
        {
            var read = await peekable.ReadAsync(buf.AsMemory(total));
            if (read == 0)
            {
                break;
            }
            total += read;
        }

        total.ShouldBe(data.Length);
        buf.ShouldBe(data);
    }

    [Fact]
    public async Task TlsRateLimiter_allows_within_limit()
    {
        using var limiter = new TlsRateLimiter(10);
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));

        // Five waits against a limiter constructed with 10 must all finish before the
        // 2-second token fires. NOTE(review): relies on WaitAsync throwing when the token
        // is cancelled, so that a blocked limiter fails the test instead of passing silently.
        for (int i = 0; i < 5; i++)
        {
            await limiter.WaitAsync(cts.Token);
        }
    }

    // Public helper methods used by other test classes

    /// <summary>
    /// Writes a freshly generated self-signed certificate and its private key to two
    /// temporary PEM files. The caller owns both files and must delete them.
    /// </summary>
    /// <returns>
    /// <c>certPath</c>: the certificate PEM file; <c>keyPath</c>: the PKCS#8 private-key PEM file.
    /// </returns>
    public static (string certPath, string keyPath) GenerateTestCertFiles()
    {
        var (cert, key) = GenerateTestCert();
        using (cert)
        using (key)
        {
            var certPath = Path.GetTempFileName();
            var keyPath = Path.GetTempFileName();
            try
            {
                File.WriteAllText(certPath, cert.ExportCertificatePem());
                File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem());
            }
            catch
            {
                // Don't leave orphaned temp files behind if exporting or writing fails.
                File.Delete(certPath);
                File.Delete(keyPath);
                throw;
            }

            return (certPath, keyPath);
        }
    }

    /// <summary>
    /// Creates a self-signed, non-CA certificate for <c>CN=localhost</c>. It is signed with
    /// RSA-2048/SHA-256, carries SANs for the IPv4 loopback address and the DNS name
    /// "localhost", and is valid for one year from now.
    /// </summary>
    /// <returns>
    /// The certificate and the RSA key that signed it. The caller owns both and should
    /// dispose them.
    /// </returns>
    public static (X509Certificate2 cert, RSA key) GenerateTestCert()
    {
        var key = RSA.Create(2048);
        var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);

        // Leaf certificate: not a CA, no path-length constraint, extension not critical.
        req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));

        var sanBuilder = new SubjectAlternativeNameBuilder();
        sanBuilder.AddIpAddress(IPAddress.Loopback);
        sanBuilder.AddDnsName("localhost");
        req.CertificateExtensions.Add(sanBuilder.Build());

        var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
        return (cert, key);
    }
}