feat: add TlsHelper, PeekableStream, and TlsRateLimiter

Add TLS utility classes for certificate loading, peekable stream for TLS
detection, token-bucket rate limiter for handshake throttling, and
TlsConnectionState for post-handshake info. Add TlsState property to
NatsClient. Fix X509Certificate2 constructor usage for .NET 10 compat.
This commit is contained in:
Joseph Doherty
2026-02-22 22:13:53 -05:00
parent 045c12cce7
commit f6b38df291
6 changed files with 283 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ using System.Text.Json;
using Microsoft.Extensions.Logging;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
namespace NATS.Server;
@@ -55,6 +56,8 @@ public sealed class NatsClient : IDisposable
private int _pingsOut;
private long _lastIn;
public TlsConnectionState? TlsState { get; set; }
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,

View File

@@ -0,0 +1,71 @@
namespace NATS.Server.Tls;
public sealed class PeekableStream : Stream
{
private readonly Stream _inner;
private byte[]? _peekedBytes;
private int _peekedOffset;
private int _peekedCount;
public PeekableStream(Stream inner) => _inner = inner;
public async Task<byte[]> PeekAsync(int count, CancellationToken ct = default)
{
var buf = new byte[count];
int read = await _inner.ReadAsync(buf.AsMemory(0, count), ct);
if (read < count) Array.Resize(ref buf, read);
_peekedBytes = buf;
_peekedOffset = 0;
_peekedCount = read;
return buf;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
{
if (_peekedBytes != null && _peekedOffset < _peekedCount)
{
int available = _peekedCount - _peekedOffset;
int toCopy = Math.Min(available, buffer.Length);
_peekedBytes.AsMemory(_peekedOffset, toCopy).CopyTo(buffer);
_peekedOffset += toCopy;
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
return toCopy;
}
return await _inner.ReadAsync(buffer, ct);
}
public override int Read(byte[] buffer, int offset, int count)
{
if (_peekedBytes != null && _peekedOffset < _peekedCount)
{
int available = _peekedCount - _peekedOffset;
int toCopy = Math.Min(available, count);
Array.Copy(_peekedBytes, _peekedOffset, buffer, offset, toCopy);
_peekedOffset += toCopy;
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
return toCopy;
}
return _inner.Read(buffer, offset, count);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
=> ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
// Write passthrough
public override void Write(byte[] buffer, int offset, int count) => _inner.Write(buffer, offset, count);
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) => _inner.WriteAsync(buffer, offset, count, ct);
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default) => _inner.WriteAsync(buffer, ct);
public override void Flush() => _inner.Flush();
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
// Required Stream overrides
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => false;
public override bool CanWrite => _inner.CanWrite;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); }
}

View File

@@ -0,0 +1,9 @@
using System.Security.Cryptography.X509Certificates;
namespace NATS.Server.Tls;
public sealed record TlsConnectionState(
string? TlsVersion,
string? CipherSuite,
X509Certificate2? PeerCert
);

View File

@@ -0,0 +1,65 @@
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
namespace NATS.Server.Tls;
public static class TlsHelper
{
public static X509Certificate2 LoadCertificate(string certPath, string? keyPath)
{
if (keyPath != null)
return X509Certificate2.CreateFromPemFile(certPath, keyPath);
return X509CertificateLoader.LoadCertificateFromFile(certPath);
}
public static X509Certificate2Collection LoadCaCertificates(string caPath)
{
var collection = new X509Certificate2Collection();
collection.ImportFromPemFile(caPath);
return collection;
}
public static SslServerAuthenticationOptions BuildServerAuthOptions(NatsOptions opts)
{
var cert = LoadCertificate(opts.TlsCert!, opts.TlsKey);
var authOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
EnabledSslProtocols = opts.TlsMinVersion,
ClientCertificateRequired = opts.TlsVerify,
};
if (opts.TlsVerify && opts.TlsCaCert != null)
{
var caCerts = LoadCaCertificates(opts.TlsCaCert);
authOpts.RemoteCertificateValidationCallback = (_, cert, chain, errors) =>
{
if (cert == null) return false;
using var chain2 = new X509Chain();
chain2.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
foreach (var ca in caCerts)
chain2.ChainPolicy.CustomTrustStore.Add(ca);
chain2.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
var cert2 = cert as X509Certificate2 ?? X509CertificateLoader.LoadCertificate(cert.GetRawCertData());
return chain2.Build(cert2);
};
}
return authOpts;
}
public static string GetCertificateHash(X509Certificate2 cert)
{
var spki = cert.PublicKey.ExportSubjectPublicKeyInfo();
var hash = SHA256.HashData(spki);
return Convert.ToHexStringLower(hash);
}
public static bool MatchesPinnedCert(X509Certificate2 cert, HashSet<string> pinned)
{
var hash = GetCertificateHash(cert);
return pinned.Contains(hash);
}
}

View File

@@ -0,0 +1,25 @@
namespace NATS.Server.Tls;
public sealed class TlsRateLimiter : IDisposable
{
private readonly SemaphoreSlim _semaphore;
private readonly Timer _refillTimer;
private readonly int _tokensPerSecond;
public TlsRateLimiter(long tokensPerSecond)
{
_tokensPerSecond = (int)Math.Max(1, tokensPerSecond);
_semaphore = new SemaphoreSlim(_tokensPerSecond, _tokensPerSecond);
_refillTimer = new Timer(Refill, null, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
}
private void Refill(object? state)
{
int toRelease = _tokensPerSecond - _semaphore.CurrentCount;
if (toRelease > 0) _semaphore.Release(toRelease);
}
public Task WaitAsync(CancellationToken ct) => _semaphore.WaitAsync(ct);
public void Dispose() { _refillTimer.Dispose(); _semaphore.Dispose(); }
}

View File

@@ -0,0 +1,110 @@
using System.Net;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using NATS.Server;
using NATS.Server.Tls;
namespace NATS.Server.Tests;
public class TlsHelperTests
{
[Fact]
public void LoadCertificate_loads_pem_cert_and_key()
{
var (certPath, keyPath) = GenerateTestCertFiles();
try
{
var cert = TlsHelper.LoadCertificate(certPath, keyPath);
cert.ShouldNotBeNull();
cert.HasPrivateKey.ShouldBeTrue();
}
finally { File.Delete(certPath); File.Delete(keyPath); }
}
[Fact]
public void BuildServerAuthOptions_creates_valid_options()
{
var (certPath, keyPath) = GenerateTestCertFiles();
try
{
var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };
var authOpts = TlsHelper.BuildServerAuthOptions(opts);
authOpts.ShouldNotBeNull();
authOpts.ServerCertificate.ShouldNotBeNull();
}
finally { File.Delete(certPath); File.Delete(keyPath); }
}
[Fact]
public void MatchesPinnedCert_matches_correct_hash()
{
var (cert, _) = GenerateTestCert();
var hash = TlsHelper.GetCertificateHash(cert);
var pinned = new HashSet<string> { hash };
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue();
}
[Fact]
public void MatchesPinnedCert_rejects_wrong_hash()
{
var (cert, _) = GenerateTestCert();
var pinned = new HashSet<string> { "0000000000000000000000000000000000000000000000000000000000000000" };
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse();
}
[Fact]
public async Task PeekableStream_peeks_and_replays()
{
var data = "Hello, World!"u8.ToArray();
using var ms = new MemoryStream(data);
using var peekable = new PeekableStream(ms);
var peeked = await peekable.PeekAsync(1);
peeked.Length.ShouldBe(1);
peeked[0].ShouldBe((byte)'H');
var buf = new byte[data.Length];
int total = 0;
while (total < data.Length)
{
var read = await peekable.ReadAsync(buf.AsMemory(total));
if (read == 0) break;
total += read;
}
total.ShouldBe(data.Length);
buf.ShouldBe(data);
}
[Fact]
public async Task TlsRateLimiter_allows_within_limit()
{
using var limiter = new TlsRateLimiter(10);
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
for (int i = 0; i < 5; i++)
await limiter.WaitAsync(cts.Token);
}
// Public helper methods used by other test classes
public static (string certPath, string keyPath) GenerateTestCertFiles()
{
var (cert, key) = GenerateTestCert();
var certPath = Path.GetTempFileName();
var keyPath = Path.GetTempFileName();
File.WriteAllText(certPath, cert.ExportCertificatePem());
File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem());
return (certPath, keyPath);
}
public static (X509Certificate2 cert, RSA key) GenerateTestCert()
{
var key = RSA.Create(2048);
var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));
var sanBuilder = new SubjectAlternativeNameBuilder();
sanBuilder.AddIpAddress(IPAddress.Loopback);
sanBuilder.AddDnsName("localhost");
req.CertificateExtensions.Add(sanBuilder.Build());
var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
return (cert, key);
}
}