diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 88e78f9..2978e25 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -7,6 +7,7 @@ using System.Text.Json; using Microsoft.Extensions.Logging; using NATS.Server.Protocol; using NATS.Server.Subscriptions; +using NATS.Server.Tls; namespace NATS.Server; @@ -55,6 +56,8 @@ public sealed class NatsClient : IDisposable private int _pingsOut; private long _lastIn; + public TlsConnectionState? TlsState { get; set; } + public IReadOnlyDictionary Subscriptions => _subs; public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo, diff --git a/src/NATS.Server/Tls/PeekableStream.cs b/src/NATS.Server/Tls/PeekableStream.cs new file mode 100644 index 0000000..29abf07 --- /dev/null +++ b/src/NATS.Server/Tls/PeekableStream.cs @@ -0,0 +1,71 @@ +namespace NATS.Server.Tls; + +public sealed class PeekableStream : Stream +{ + private readonly Stream _inner; + private byte[]? _peekedBytes; + private int _peekedOffset; + private int _peekedCount; + + public PeekableStream(Stream inner) => _inner = inner; + + public async Task PeekAsync(int count, CancellationToken ct = default) + { + var buf = new byte[count]; + int read = await _inner.ReadAsync(buf.AsMemory(0, count), ct); + if (read < count) Array.Resize(ref buf, read); + _peekedBytes = buf; + _peekedOffset = 0; + _peekedCount = read; + return buf; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) + { + if (_peekedBytes != null && _peekedOffset < _peekedCount) + { + int available = _peekedCount - _peekedOffset; + int toCopy = Math.Min(available, buffer.Length); + _peekedBytes.AsMemory(_peekedOffset, toCopy).CopyTo(buffer); + _peekedOffset += toCopy; + if (_peekedOffset >= _peekedCount) _peekedBytes = null; + return toCopy; + } + return await _inner.ReadAsync(buffer, ct); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_peekedBytes != null && _peekedOffset < _peekedCount) + { + int available = _peekedCount - _peekedOffset; + int toCopy = Math.Min(available, count); + Array.Copy(_peekedBytes, _peekedOffset, buffer, offset, toCopy); + _peekedOffset += toCopy; + if (_peekedOffset >= _peekedCount) _peekedBytes = null; + return toCopy; + } + return _inner.Read(buffer, offset, count); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) + => ReadAsync(buffer.AsMemory(offset, count), ct).AsTask(); + + // Write passthrough + public override void Write(byte[] buffer, int offset, int count) => _inner.Write(buffer, offset, count); + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) => _inner.WriteAsync(buffer, offset, count, ct); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ct = default) => _inner.WriteAsync(buffer, ct); + public override void Flush() => _inner.Flush(); + public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct); + + // Required Stream overrides + public override bool CanRead => _inner.CanRead; + public override bool CanSeek => false; + public override bool CanWrite => _inner.CanWrite; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); } +} diff --git a/src/NATS.Server/Tls/TlsConnectionState.cs b/src/NATS.Server/Tls/TlsConnectionState.cs new file mode 100644 index 0000000..0fe788a --- /dev/null +++ b/src/NATS.Server/Tls/TlsConnectionState.cs @@ -0,0 +1,9 @@ +using System.Security.Cryptography.X509Certificates; + +namespace NATS.Server.Tls; + +public sealed record TlsConnectionState( + string? TlsVersion, + string? CipherSuite, + X509Certificate2? PeerCert +); diff --git a/src/NATS.Server/Tls/TlsHelper.cs b/src/NATS.Server/Tls/TlsHelper.cs new file mode 100644 index 0000000..cdc5ef6 --- /dev/null +++ b/src/NATS.Server/Tls/TlsHelper.cs @@ -0,0 +1,65 @@ +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace NATS.Server.Tls; + +public static class TlsHelper +{ + public static X509Certificate2 LoadCertificate(string certPath, string? keyPath) + { + if (keyPath != null) + return X509Certificate2.CreateFromPemFile(certPath, keyPath); + return X509CertificateLoader.LoadCertificateFromFile(certPath); + } + + public static X509Certificate2Collection LoadCaCertificates(string caPath) + { + var collection = new X509Certificate2Collection(); + collection.ImportFromPemFile(caPath); + return collection; + } + + public static SslServerAuthenticationOptions BuildServerAuthOptions(NatsOptions opts) + { + var cert = LoadCertificate(opts.TlsCert!, opts.TlsKey); + var authOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + EnabledSslProtocols = opts.TlsMinVersion, + ClientCertificateRequired = opts.TlsVerify, + }; + + if (opts.TlsVerify && opts.TlsCaCert != null) + { + var caCerts = LoadCaCertificates(opts.TlsCaCert); + authOpts.RemoteCertificateValidationCallback = (_, cert, chain, errors) => + { + if (cert == null) return false; + using var chain2 = new X509Chain(); + chain2.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; + foreach (var ca in caCerts) + chain2.ChainPolicy.CustomTrustStore.Add(ca); + chain2.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck; + var cert2 = cert as X509Certificate2 ?? X509CertificateLoader.LoadCertificate(cert.GetRawCertData()); + return chain2.Build(cert2); + }; + } + + return authOpts; + } + + public static string GetCertificateHash(X509Certificate2 cert) + { + var spki = cert.PublicKey.ExportSubjectPublicKeyInfo(); + var hash = SHA256.HashData(spki); + return Convert.ToHexStringLower(hash); + } + + public static bool MatchesPinnedCert(X509Certificate2 cert, HashSet pinned) + { + var hash = GetCertificateHash(cert); + return pinned.Contains(hash); + } +} diff --git a/src/NATS.Server/Tls/TlsRateLimiter.cs b/src/NATS.Server/Tls/TlsRateLimiter.cs new file mode 100644 index 0000000..75741a5 --- /dev/null +++ b/src/NATS.Server/Tls/TlsRateLimiter.cs @@ -0,0 +1,25 @@ +namespace NATS.Server.Tls; + +public sealed class TlsRateLimiter : IDisposable +{ + private readonly SemaphoreSlim _semaphore; + private readonly Timer _refillTimer; + private readonly int _tokensPerSecond; + + public TlsRateLimiter(long tokensPerSecond) + { + _tokensPerSecond = (int)Math.Max(1, tokensPerSecond); + _semaphore = new SemaphoreSlim(_tokensPerSecond, _tokensPerSecond); + _refillTimer = new Timer(Refill, null, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); + } + + private void Refill(object? state) + { + int toRelease = _tokensPerSecond - _semaphore.CurrentCount; + if (toRelease > 0) _semaphore.Release(toRelease); + } + + public Task WaitAsync(CancellationToken ct) => _semaphore.WaitAsync(ct); + + public void Dispose() { _refillTimer.Dispose(); _semaphore.Dispose(); } +} diff --git a/tests/NATS.Server.Tests/TlsHelperTests.cs b/tests/NATS.Server.Tests/TlsHelperTests.cs new file mode 100644 index 0000000..c8d1cfa --- /dev/null +++ b/tests/NATS.Server.Tests/TlsHelperTests.cs @@ -0,0 +1,110 @@ +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using NATS.Server; +using NATS.Server.Tls; + +namespace NATS.Server.Tests; + +public class TlsHelperTests +{ + [Fact] + public void LoadCertificate_loads_pem_cert_and_key() + { + var (certPath, keyPath) = GenerateTestCertFiles(); + try + { + var cert = TlsHelper.LoadCertificate(certPath, keyPath); + cert.ShouldNotBeNull(); + cert.HasPrivateKey.ShouldBeTrue(); + } + finally { File.Delete(certPath); File.Delete(keyPath); } + } + + [Fact] + public void BuildServerAuthOptions_creates_valid_options() + { + var (certPath, keyPath) = GenerateTestCertFiles(); + try + { + var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath }; + var authOpts = TlsHelper.BuildServerAuthOptions(opts); + authOpts.ShouldNotBeNull(); + authOpts.ServerCertificate.ShouldNotBeNull(); + } + finally { File.Delete(certPath); File.Delete(keyPath); } + } + + [Fact] + public void MatchesPinnedCert_matches_correct_hash() + { + var (cert, _) = GenerateTestCert(); + var hash = TlsHelper.GetCertificateHash(cert); + var pinned = new HashSet { hash }; + TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue(); + } + + [Fact] + public void MatchesPinnedCert_rejects_wrong_hash() + { + var (cert, _) = GenerateTestCert(); + var pinned = new HashSet { "0000000000000000000000000000000000000000000000000000000000000000" }; + TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse(); + } + + [Fact] + public async Task PeekableStream_peeks_and_replays() + { + var data = "Hello, World!"u8.ToArray(); + using var ms = new MemoryStream(data); + using var peekable = new PeekableStream(ms); + + var peeked = await peekable.PeekAsync(1); + peeked.Length.ShouldBe(1); + peeked[0].ShouldBe((byte)'H'); + + var buf = new byte[data.Length]; + int total = 0; + while (total < data.Length) + { + var read = await peekable.ReadAsync(buf.AsMemory(total)); + if (read == 0) break; + total += read; + } + total.ShouldBe(data.Length); + buf.ShouldBe(data); + } + + [Fact] + public async Task TlsRateLimiter_allows_within_limit() + { + using var limiter = new TlsRateLimiter(10); + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + for (int i = 0; i < 5; i++) + await limiter.WaitAsync(cts.Token); + } + + // Public helper methods used by other test classes + public static (string certPath, string keyPath) GenerateTestCertFiles() + { + var (cert, key) = GenerateTestCert(); + var certPath = Path.GetTempFileName(); + var keyPath = Path.GetTempFileName(); + File.WriteAllText(certPath, cert.ExportCertificatePem()); + File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem()); + return (certPath, keyPath); + } + + public static (X509Certificate2 cert, RSA key) GenerateTestCert() + { + var key = RSA.Create(2048); + var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false)); + var sanBuilder = new SubjectAlternativeNameBuilder(); + sanBuilder.AddIpAddress(IPAddress.Loopback); + sanBuilder.AddDnsName("localhost"); + req.CertificateExtensions.Add(sanBuilder.Build()); + var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1)); + return (cert, key); + } +}