diff --git a/docs/test_parity.db b/docs/test_parity.db index 74d6068..44559f2 100644 Binary files a/docs/test_parity.db and b/docs/test_parity.db differ diff --git a/src/NATS.Server/Configuration/ConfigReloader.cs b/src/NATS.Server/Configuration/ConfigReloader.cs index 67d0e7c..5406f86 100644 --- a/src/NATS.Server/Configuration/ConfigReloader.cs +++ b/src/NATS.Server/Configuration/ConfigReloader.cs @@ -1,6 +1,9 @@ // Port of Go server/reload.go — config diffing, validation, and CLI override merging // for hot reload support. Reference: golang/nats-server/server/reload.go. +using System.Net.Security; +using NATS.Server.Tls; + namespace NATS.Server.Configuration; /// @@ -459,6 +462,29 @@ public static class ConfigReloader return !string.Equals(oldJetStream.StoreDir, newJetStream.StoreDir, StringComparison.Ordinal); } + + /// + /// Reloads TLS certificates from the current options and atomically swaps them + /// into the certificate provider. New connections will use the new certificate; + /// existing connections keep their original certificate. + /// Reference: golang/nats-server/server/reload.go — tlsOption.Apply. + /// + public static bool ReloadTlsCertificate( + NatsOptions options, + TlsCertificateProvider? certProvider) + { + if (certProvider == null || !options.HasTls) + return false; + + var oldCert = certProvider.SwapCertificate(options.TlsCert!, options.TlsKey); + oldCert?.Dispose(); + + // Rebuild SslServerAuthenticationOptions with the new certificate + var newSslOptions = TlsHelper.BuildServerAuthOptions(options); + certProvider.SwapSslOptions(newSslOptions); + + return true; + } } /// diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 595de24..90a27df 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -50,8 +50,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly Account _globalAccount; private readonly Account _systemAccount; private InternalEventSystem? _eventSystem; - private readonly SslServerAuthenticationOptions? _sslOptions; + private SslServerAuthenticationOptions? _sslOptions; private readonly TlsRateLimiter? _tlsRateLimiter; + private readonly TlsCertificateProvider? _tlsCertProvider; private readonly SubjectTransform[] _subjectTransforms; private readonly RouteManager? _routeManager; @@ -148,6 +149,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable public void WaitForShutdown() => _shutdownComplete.Task.GetAwaiter().GetResult(); + internal TlsCertificateProvider? TlsCertProviderForTest => _tlsCertProvider; + internal Task AcquireReloadLockForTestAsync() => _reloadMu.WaitAsync(); internal void ReleaseReloadLockForTest() => _reloadMu.Release(); @@ -427,7 +430,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable if (options.HasTls) { + _tlsCertProvider = new TlsCertificateProvider(options.TlsCert!, options.TlsKey); _sslOptions = TlsHelper.BuildServerAuthOptions(options); + _tlsCertProvider.SwapSslOptions(_sslOptions); // OCSP stapling: build a certificate context so the runtime can // fetch and cache a fresh OCSP response and staple it during the @@ -1377,6 +1382,16 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable Connections = ClientCount, TotalConnections = Interlocked.Read(ref _stats.TotalConnections), Subscriptions = SubList.Count, + Sent = new Events.DataStats + { + Msgs = Interlocked.Read(ref _stats.OutMsgs), + Bytes = Interlocked.Read(ref _stats.OutBytes), + }, + Received = new Events.DataStats + { + Msgs = Interlocked.Read(ref _stats.InMsgs), + Bytes = Interlocked.Read(ref _stats.InBytes), + }, InMsgs = Interlocked.Read(ref _stats.InMsgs), OutMsgs = Interlocked.Read(ref _stats.OutMsgs), InBytes = Interlocked.Read(ref _stats.InBytes), @@ -1672,11 +1687,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable { bool hasLoggingChanges = false; bool hasAuthChanges = false; + bool hasTlsChanges = false; foreach (var change in changes) { if (change.IsLoggingChange) hasLoggingChanges = true; if (change.IsAuthChange) hasAuthChanges = true; + if (change.IsTlsChange) hasTlsChanges = true; } // Copy reloadable values from newOpts to _options @@ -1689,6 +1706,18 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _logger.LogInformation("Logging configuration reloaded"); } + if (hasTlsChanges) + { + // Reload TLS certificates: new connections get the new cert, + // existing connections keep their original cert. + // Reference: golang/nats-server/server/reload.go — tlsOption.Apply. + if (ConfigReloader.ReloadTlsCertificate(_options, _tlsCertProvider)) + { + _sslOptions = _tlsCertProvider!.GetCurrentSslOptions(); + _logger.LogInformation("TLS configuration reloaded"); + } + } + if (hasAuthChanges) { // Rebuild auth service with new options, then propagate changes to connected clients @@ -1837,6 +1866,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable reg.Dispose(); _quitCts.Dispose(); _tlsRateLimiter?.Dispose(); + _tlsCertProvider?.Dispose(); _listener?.Dispose(); _wsListener?.Dispose(); _routeManager?.DisposeAsync().AsTask().GetAwaiter().GetResult(); diff --git a/src/NATS.Server/Tls/TlsCertificateProvider.cs b/src/NATS.Server/Tls/TlsCertificateProvider.cs new file mode 100644 index 0000000..18b17df --- /dev/null +++ b/src/NATS.Server/Tls/TlsCertificateProvider.cs @@ -0,0 +1,89 @@ +// TLS certificate provider that supports atomic cert swapping for hot reload. +// New connections get the current certificate; existing connections keep their original. +// Reference: golang/nats-server/server/reload.go — tlsOption.Apply. + +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; + +namespace NATS.Server.Tls; + +/// +/// Thread-safe provider for TLS certificates that supports atomic swapping +/// during config reload. New connections retrieve the latest certificate via +/// ; existing connections are unaffected. +/// +public sealed class TlsCertificateProvider : IDisposable +{ + private volatile X509Certificate2? _currentCert; + private volatile SslServerAuthenticationOptions? _currentSslOptions; + private int _version; + + /// + /// Creates a new provider and loads the initial certificate from the given paths. + /// + public TlsCertificateProvider(string certPath, string? keyPath) + { + _currentCert = TlsHelper.LoadCertificate(certPath, keyPath); + } + + /// + /// Creates a provider from a pre-loaded certificate (for testing). + /// + public TlsCertificateProvider(X509Certificate2 cert) + { + _currentCert = cert; + } + + /// + /// Returns the current certificate. This is called for each new TLS handshake + /// so that new connections always get the latest certificate. + /// + public X509Certificate2? GetCurrentCertificate() => _currentCert; + + /// + /// Atomically swaps the current certificate with a newly loaded one. + /// Returns the old certificate (caller may dispose it after existing connections drain). + /// + public X509Certificate2? SwapCertificate(string certPath, string? keyPath) + { + var newCert = TlsHelper.LoadCertificate(certPath, keyPath); + return SwapCertificate(newCert); + } + + /// + /// Atomically swaps the current certificate with the provided one. + /// Returns the old certificate. + /// + public X509Certificate2? SwapCertificate(X509Certificate2 newCert) + { + var old = Interlocked.Exchange(ref _currentCert, newCert); + Interlocked.Increment(ref _version); + return old; + } + + /// + /// Returns the current SSL options, rebuilding them if the certificate has changed. + /// + public SslServerAuthenticationOptions? GetCurrentSslOptions() => _currentSslOptions; + + /// + /// Atomically swaps the SSL server authentication options. + /// Called after TLS config changes are detected during reload. + /// + public void SwapSslOptions(SslServerAuthenticationOptions newOptions) + { + Interlocked.Exchange(ref _currentSslOptions, newOptions); + Interlocked.Increment(ref _version); + } + + /// + /// Monotonically increasing version number, incremented on each swap. + /// Useful for tests to verify a reload occurred. + /// + public int Version => Volatile.Read(ref _version); + + public void Dispose() + { + _currentCert?.Dispose(); + } +} diff --git a/src/NATS.Server/WebSocket/WsCompression.cs b/src/NATS.Server/WebSocket/WsCompression.cs index 92f0184..cd389e1 100644 --- a/src/NATS.Server/WebSocket/WsCompression.cs +++ b/src/NATS.Server/WebSocket/WsCompression.cs @@ -2,6 +2,146 @@ using System.IO.Compression; namespace NATS.Server.WebSocket; +/// +/// Negotiated permessage-deflate parameters per RFC 7692 Section 7.1. +/// Captures the results of extension parameter negotiation during the +/// WebSocket upgrade handshake. +/// +public readonly record struct WsDeflateParams( + bool ServerNoContextTakeover, + bool ClientNoContextTakeover, + int ServerMaxWindowBits, + int ClientMaxWindowBits) +{ + /// + /// Default parameters matching NATS Go server behavior: + /// both sides use no_context_takeover, default 15-bit windows. + /// + public static readonly WsDeflateParams Default = new( + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: 15, + ClientMaxWindowBits: 15); + + /// + /// Builds the Sec-WebSocket-Extensions response header value from negotiated parameters. + /// Only includes parameters that differ from the default RFC values. + /// Reference: RFC 7692 Section 7.1. + /// + public string ToResponseHeaderValue() + { + var parts = new List { WsConstants.PmcExtension }; + + if (ServerNoContextTakeover) + parts.Add(WsConstants.PmcSrvNoCtx); + if (ClientNoContextTakeover) + parts.Add(WsConstants.PmcCliNoCtx); + if (ServerMaxWindowBits is > 0 and < 15) + parts.Add($"server_max_window_bits={ServerMaxWindowBits}"); + if (ClientMaxWindowBits is > 0 and < 15) + parts.Add($"client_max_window_bits={ClientMaxWindowBits}"); + + return string.Join("; ", parts); + } +} + +/// +/// Parses and negotiates permessage-deflate extension parameters from the +/// Sec-WebSocket-Extensions header per RFC 7692 Section 7. +/// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport. +/// +public static class WsDeflateNegotiator +{ + /// + /// Parses the Sec-WebSocket-Extensions header value and negotiates + /// permessage-deflate parameters. Returns null if no valid + /// permessage-deflate offer is found. + /// + public static WsDeflateParams? Negotiate(string? extensionHeader) + { + if (string.IsNullOrEmpty(extensionHeader)) + return null; + + // The header may contain multiple extensions separated by commas + var extensions = extensionHeader.Split(','); + foreach (var extension in extensions) + { + var trimmed = extension.Trim(); + var parts = trimmed.Split(';'); + + // First part must be the extension name + if (parts.Length == 0) + continue; + + if (!string.Equals(parts[0].Trim(), WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase)) + continue; + + // Found permessage-deflate — parse parameters + // Note: serverNoCtx and clientNoCtx are parsed but always overridden + // with true below (NATS enforces no_context_takeover for both sides). + int serverMaxWindowBits = 15; + int clientMaxWindowBits = 15; + + for (int i = 1; i < parts.Length; i++) + { + var param = parts[i].Trim(); + + if (string.Equals(param, WsConstants.PmcSrvNoCtx, StringComparison.OrdinalIgnoreCase)) + { + // Parsed but overridden: NATS always enforces no_context_takeover. + } + else if (string.Equals(param, WsConstants.PmcCliNoCtx, StringComparison.OrdinalIgnoreCase)) + { + // Parsed but overridden: NATS always enforces no_context_takeover. + } + else if (param.StartsWith("server_max_window_bits", StringComparison.OrdinalIgnoreCase)) + { + serverMaxWindowBits = ParseWindowBits(param, 15); + } + else if (param.StartsWith("client_max_window_bits", StringComparison.OrdinalIgnoreCase)) + { + // client_max_window_bits with no value means the client supports it + // and the server may choose a value. Per RFC 7692 Section 7.1.2.2, + // an offer with just "client_max_window_bits" (no value) indicates + // the client can accept any value 8-15. + clientMaxWindowBits = ParseWindowBits(param, 15); + } + } + + // NATS server always enforces no_context_takeover for both sides + // (matching Go behavior) to avoid holding compressor state per connection. + return new WsDeflateParams( + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: ClampWindowBits(serverMaxWindowBits), + ClientMaxWindowBits: ClampWindowBits(clientMaxWindowBits)); + } + + return null; + } + + private static int ParseWindowBits(string param, int defaultValue) + { + var eqIdx = param.IndexOf('='); + if (eqIdx < 0) + return defaultValue; + + var valueStr = param[(eqIdx + 1)..].Trim(); + if (int.TryParse(valueStr, out var bits)) + return bits; + + return defaultValue; + } + + private static int ClampWindowBits(int bits) + { + // RFC 7692: valid range is 8-15 + if (bits < 8) return 8; + if (bits > 15) return 15; + return bits; + } +} + /// /// permessage-deflate compression/decompression for WebSocket frames (RFC 7692). /// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466. diff --git a/src/NATS.Server/WebSocket/WsUpgrade.cs b/src/NATS.Server/WebSocket/WsUpgrade.cs index d2fddbc..39fa113 100644 --- a/src/NATS.Server/WebSocket/WsUpgrade.cs +++ b/src/NATS.Server/WebSocket/WsUpgrade.cs @@ -18,7 +18,7 @@ public static class WsUpgrade { using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); cts.CancelAfter(options.HandshakeTimeout); - var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token); + var (method, path, queryString, headers) = await ReadHttpRequestAsync(inputStream, cts.Token); if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase)) return await FailAsync(outputStream, 405, "request method must be GET"); @@ -57,15 +57,17 @@ public static class WsUpgrade return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}"); } - // Compression negotiation + // Compression negotiation (RFC 7692) bool compress = options.Compression; + WsDeflateParams? deflateParams = null; if (compress) { - compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) && - ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase); + headers.TryGetValue("Sec-WebSocket-Extensions", out var ext); + deflateParams = WsDeflateNegotiator.Negotiate(ext); + compress = deflateParams != null; } - // No-masking support (leaf nodes only — browser clients must always mask) + // No-masking support (leaf nodes only -- browser clients must always mask) bool noMasking = kind == WsClientKind.Leaf && headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) && string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase); @@ -95,6 +97,24 @@ public static class WsUpgrade if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken); } + // JWT extraction from multiple sources (E11): + // Priority: Authorization header > cookie > query parameter + // Reference: NATS WebSocket JWT auth — browser clients often pass JWT + // via cookie or query param since they cannot set custom headers. + string? jwt = null; + if (headers.TryGetValue("Authorization", out var authHeader)) + { + jwt = ExtractBearerToken(authHeader); + } + + jwt ??= cookieJwt; + + if (jwt == null && queryString != null) + { + var queryParams = ParseQueryString(queryString); + queryParams.TryGetValue("jwt", out jwt); + } + // X-Forwarded-For client IP extraction string? clientIp = null; if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff)) @@ -109,8 +129,13 @@ public static class WsUpgrade response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "); response.Append(ComputeAcceptKey(key)); response.Append("\r\n"); - if (compress) - response.Append(WsConstants.PmcFullResponse); + if (compress && deflateParams != null) + { + response.Append("Sec-WebSocket-Extensions: "); + response.Append(deflateParams.Value.ToResponseHeaderValue()); + response.Append("\r\n"); + } + if (noMasking) response.Append(WsConstants.NoMaskingFullResponse); if (options.Headers != null) @@ -135,7 +160,8 @@ public static class WsUpgrade MaskRead: !noMasking, MaskWrite: false, CookieJwt: cookieJwt, CookieUsername: cookieUsername, CookiePassword: cookiePassword, CookieToken: cookieToken, - ClientIp: clientIp, Kind: kind); + ClientIp: clientIp, Kind: kind, + DeflateParams: deflateParams, Jwt: jwt); } catch (Exception) { @@ -153,11 +179,56 @@ public static class WsUpgrade return Convert.ToBase64String(hash); } + /// + /// Extracts a bearer token from an Authorization header value. + /// Supports both "Bearer {token}" and bare "{token}" formats. + /// + internal static string? ExtractBearerToken(string? authHeader) + { + if (string.IsNullOrWhiteSpace(authHeader)) + return null; + + var trimmed = authHeader.Trim(); + if (trimmed.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + return trimmed["Bearer ".Length..].Trim(); + + // Some clients send the token directly without "Bearer" prefix + return trimmed; + } + + /// + /// Parses a query string into key-value pairs. + /// + internal static Dictionary ParseQueryString(string queryString) + { + var result = new Dictionary(StringComparer.OrdinalIgnoreCase); + if (queryString.StartsWith('?')) + queryString = queryString[1..]; + + foreach (var pair in queryString.Split('&')) + { + var eqIdx = pair.IndexOf('='); + if (eqIdx > 0) + { + var name = Uri.UnescapeDataString(pair[..eqIdx]); + var value = Uri.UnescapeDataString(pair[(eqIdx + 1)..]); + result[name] = value; + } + else if (pair.Length > 0) + { + result[Uri.UnescapeDataString(pair)] = string.Empty; + } + } + + return result; + } + private static async Task FailAsync(Stream output, int statusCode, string reason) { var statusText = statusCode switch { 400 => "Bad Request", + 401 => "Unauthorized", 403 => "Forbidden", 405 => "Method Not Allowed", _ => "Internal Server Error", @@ -165,10 +236,21 @@ public static class WsUpgrade var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}"; await output.WriteAsync(Encoding.ASCII.GetBytes(response)); await output.FlushAsync(); - return WsUpgradeResult.Failed; + return statusCode == 401 + ? WsUpgradeResult.Unauthorized + : WsUpgradeResult.Failed; } - private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync( + /// + /// Sends a 401 Unauthorized response and returns a failed upgrade result. + /// Used by the server when JWT authentication fails during WS upgrade. + /// + public static async Task FailUnauthorizedAsync(Stream output, string reason) + { + return await FailAsync(output, 401, reason); + } + + private static async Task<(string method, string path, string? queryString, Dictionary headers)> ReadHttpRequestAsync( Stream stream, CancellationToken ct) { var headerBytes = new List(4096); @@ -197,7 +279,21 @@ public static class WsUpgrade var parts = lines[0].Split(' '); if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line"); var method = parts[0]; - var path = parts[1]; + var requestUri = parts[1]; + + // Split path and query string + string path; + string? queryString = null; + var qIdx = requestUri.IndexOf('?'); + if (qIdx >= 0) + { + path = requestUri[..qIdx]; + queryString = requestUri[qIdx..]; // includes the '?' + } + else + { + path = requestUri; + } var headers = new Dictionary(StringComparer.OrdinalIgnoreCase); for (int i = 1; i < lines.Length; i++) @@ -213,7 +309,7 @@ public static class WsUpgrade } } - return (method, path, headers); + return (method, path, queryString, headers); } private static bool HeaderContains(Dictionary headers, string name, string value) @@ -259,10 +355,17 @@ public readonly record struct WsUpgradeResult( string? CookiePassword, string? CookieToken, string? ClientIp, - WsClientKind Kind) + WsClientKind Kind, + WsDeflateParams? DeflateParams = null, + string? Jwt = null) { public static readonly WsUpgradeResult Failed = new( Success: false, Compress: false, Browser: false, NoCompFrag: false, MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null, CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client); + + public static readonly WsUpgradeResult Unauthorized = new( + Success: false, Compress: false, Browser: false, NoCompFrag: false, + MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null, + CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client); } diff --git a/tests/NATS.Server.Tests/Configuration/TlsReloadTests.cs b/tests/NATS.Server.Tests/Configuration/TlsReloadTests.cs new file mode 100644 index 0000000..0cd0d10 --- /dev/null +++ b/tests/NATS.Server.Tests/Configuration/TlsReloadTests.cs @@ -0,0 +1,239 @@ +// Tests for TLS certificate hot reload (E9). +// Verifies that TlsCertificateProvider supports atomic cert swapping +// and that ConfigReloader.ReloadTlsCertificate integrates correctly. +// Reference: golang/nats-server/server/reload_test.go — TestConfigReloadRotateTLS (line 392). + +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using NATS.Server.Configuration; +using NATS.Server.Tls; + +namespace NATS.Server.Tests.Configuration; + +public class TlsReloadTests +{ + /// + /// Generates a self-signed X509Certificate2 for testing. + /// + private static X509Certificate2 GenerateSelfSignedCert(string cn = "test") + { + using var rsa = RSA.Create(2048); + var req = new CertificateRequest($"CN={cn}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddDays(1)); + // Export and re-import to ensure the cert has the private key bound + return X509CertificateLoader.LoadPkcs12(cert.Export(X509ContentType.Pkcs12), null); + } + + [Fact] + public void CertificateProvider_GetCurrentCertificate_ReturnsInitialCert() + { + // Go parity: TestConfigReloadRotateTLS — initial cert is usable + var cert = GenerateSelfSignedCert("initial"); + using var provider = new TlsCertificateProvider(cert); + + var current = provider.GetCurrentCertificate(); + + current.ShouldNotBeNull(); + current.Subject.ShouldContain("initial"); + } + + [Fact] + public void CertificateProvider_SwapCertificate_ReturnsOldCert() + { + // Go parity: TestConfigReloadRotateTLS — cert rotation returns old cert + var cert1 = GenerateSelfSignedCert("cert1"); + var cert2 = GenerateSelfSignedCert("cert2"); + using var provider = new TlsCertificateProvider(cert1); + + var old = provider.SwapCertificate(cert2); + + old.ShouldNotBeNull(); + old.Subject.ShouldContain("cert1"); + old.Dispose(); + + var current = provider.GetCurrentCertificate(); + current.ShouldNotBeNull(); + current.Subject.ShouldContain("cert2"); + } + + [Fact] + public void CertificateProvider_SwapCertificate_IncrementsVersion() + { + // Go parity: TestConfigReloadRotateTLS — version tracking for reload detection + var cert1 = GenerateSelfSignedCert("v1"); + var cert2 = GenerateSelfSignedCert("v2"); + using var provider = new TlsCertificateProvider(cert1); + + var v0 = provider.Version; + v0.ShouldBe(0); + + provider.SwapCertificate(cert2)?.Dispose(); + provider.Version.ShouldBe(1); + } + + [Fact] + public void CertificateProvider_MultipleSwa_NewConnectionsGetLatest() + { + // Go parity: TestConfigReloadRotateTLS — multiple rotations, each new + // handshake gets the latest certificate + var cert1 = GenerateSelfSignedCert("round1"); + var cert2 = GenerateSelfSignedCert("round2"); + var cert3 = GenerateSelfSignedCert("round3"); + using var provider = new TlsCertificateProvider(cert1); + + provider.GetCurrentCertificate()!.Subject.ShouldContain("round1"); + + provider.SwapCertificate(cert2)?.Dispose(); + provider.GetCurrentCertificate()!.Subject.ShouldContain("round2"); + + provider.SwapCertificate(cert3)?.Dispose(); + provider.GetCurrentCertificate()!.Subject.ShouldContain("round3"); + + provider.Version.ShouldBe(2); + } + + [Fact] + public async Task CertificateProvider_ConcurrentAccess_IsThreadSafe() + { + // Go parity: TestConfigReloadRotateTLS — cert swap must be safe under + // concurrent connection accept + var cert1 = GenerateSelfSignedCert("concurrent1"); + using var provider = new TlsCertificateProvider(cert1); + + var tasks = new Task[50]; + for (int i = 0; i < tasks.Length; i++) + { + var idx = i; + tasks[i] = Task.Run(() => + { + if (idx % 2 == 0) + { + // Readers — simulate new connections getting current cert + var c = provider.GetCurrentCertificate(); + c.ShouldNotBeNull(); + } + else + { + // Writers — simulate reload + var newCert = GenerateSelfSignedCert($"swap-{idx}"); + provider.SwapCertificate(newCert)?.Dispose(); + } + }); + } + + await Task.WhenAll(tasks); + + // After all swaps, the provider should still return a valid cert + provider.GetCurrentCertificate().ShouldNotBeNull(); + } + + [Fact] + public void ReloadTlsCertificate_NullProvider_ReturnsFalse() + { + // Edge case: server running without TLS + var opts = new NatsOptions(); + var result = ConfigReloader.ReloadTlsCertificate(opts, null); + result.ShouldBeFalse(); + } + + [Fact] + public void ReloadTlsCertificate_NoTlsConfig_ReturnsFalse() + { + // Edge case: provider exists but options don't have TLS paths + var cert = GenerateSelfSignedCert("no-tls"); + using var provider = new TlsCertificateProvider(cert); + + var opts = new NatsOptions(); // HasTls is false (no TlsCert/TlsKey) + var result = ConfigReloader.ReloadTlsCertificate(opts, provider); + result.ShouldBeFalse(); + } + + [Fact] + public void ReloadTlsCertificate_WithCertFiles_SwapsCertAndSslOptions() + { + // Go parity: TestConfigReloadRotateTLS — full reload with cert files. + // Write a self-signed cert to temp files and verify the provider loads it. + var tempDir = Path.Combine(Path.GetTempPath(), $"nats-tls-test-{Guid.NewGuid():N}"); + Directory.CreateDirectory(tempDir); + try + { + var certPath = Path.Combine(tempDir, "cert.pem"); + var keyPath = Path.Combine(tempDir, "key.pem"); + WriteSelfSignedCertFiles(certPath, keyPath, "reload-test"); + + // Create provider with initial cert + var initialCert = GenerateSelfSignedCert("initial"); + using var provider = new TlsCertificateProvider(initialCert); + + var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath }; + var result = ConfigReloader.ReloadTlsCertificate(opts, provider); + + result.ShouldBeTrue(); + provider.Version.ShouldBeGreaterThan(0); + provider.GetCurrentCertificate().ShouldNotBeNull(); + provider.GetCurrentSslOptions().ShouldNotBeNull(); + } + finally + { + Directory.Delete(tempDir, recursive: true); + } + } + + [Fact] + public void ConfigDiff_DetectsTlsChanges() + { + // Go parity: TestConfigReloadEnableTLS, TestConfigReloadDisableTLS + // Verify that diff detects TLS option changes and flags them + var oldOpts = new NatsOptions { TlsCert = "/old/cert.pem", TlsKey = "/old/key.pem" }; + var newOpts = new NatsOptions { TlsCert = "/new/cert.pem", TlsKey = "/new/key.pem" }; + + var changes = ConfigReloader.Diff(oldOpts, newOpts); + + changes.Count.ShouldBeGreaterThan(0); + changes.ShouldContain(c => c.IsTlsChange && c.Name == "TlsCert"); + changes.ShouldContain(c => c.IsTlsChange && c.Name == "TlsKey"); + } + + [Fact] + public void ConfigDiff_TlsVerifyChange_IsTlsChange() + { + // Go parity: TestConfigReloadRotateTLS — enabling client verification + var oldOpts = new NatsOptions { TlsVerify = false }; + var newOpts = new NatsOptions { TlsVerify = true }; + + var changes = ConfigReloader.Diff(oldOpts, newOpts); + + changes.ShouldContain(c => c.IsTlsChange && c.Name == "TlsVerify"); + } + + [Fact] + public void ConfigApplyResult_ReportsTlsChanges() + { + // Verify ApplyDiff flags TLS changes correctly + var changes = new List + { + new ConfigChange("TlsCert", isTlsChange: true), + new ConfigChange("TlsKey", isTlsChange: true), + }; + var oldOpts = new NatsOptions(); + var newOpts = new NatsOptions(); + + var result = ConfigReloader.ApplyDiff(changes, oldOpts, newOpts); + + result.HasTlsChanges.ShouldBeTrue(); + result.ChangeCount.ShouldBe(2); + } + + /// + /// Helper to write a self-signed certificate to PEM files. + /// + private static void WriteSelfSignedCertFiles(string certPath, string keyPath, string cn) + { + using var rsa = RSA.Create(2048); + var req = new CertificateRequest($"CN={cn}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddDays(1)); + + File.WriteAllText(certPath, cert.ExportCertificatePem()); + File.WriteAllText(keyPath, rsa.ExportRSAPrivateKeyPem()); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsCompressionNegotiationTests.cs b/tests/NATS.Server.Tests/WebSocket/WsCompressionNegotiationTests.cs new file mode 100644 index 0000000..efef0c4 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsCompressionNegotiationTests.cs @@ -0,0 +1,327 @@ +// Tests for WebSocket permessage-deflate parameter negotiation (E10). +// Verifies RFC 7692 extension parameter parsing and negotiation during +// WebSocket upgrade handshake. +// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport (line 885). + +using System.Text; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsCompressionNegotiationTests +{ + // ─── WsDeflateNegotiator.Negotiate tests ────────────────────────────── + + [Fact] + public void Negotiate_NullHeader_ReturnsNull() + { + // Go parity: wsPMCExtensionSupport — no extension header means no compression + var result = WsDeflateNegotiator.Negotiate(null); + result.ShouldBeNull(); + } + + [Fact] + public void Negotiate_EmptyHeader_ReturnsNull() + { + var result = WsDeflateNegotiator.Negotiate(""); + result.ShouldBeNull(); + } + + [Fact] + public void Negotiate_NoPermessageDeflate_ReturnsNull() + { + var result = WsDeflateNegotiator.Negotiate("x-webkit-deflate-frame"); + result.ShouldBeNull(); + } + + [Fact] + public void Negotiate_BarePermessageDeflate_ReturnsDefaults() + { + // Go parity: wsPMCExtensionSupport — basic extension without parameters + var result = WsDeflateNegotiator.Negotiate("permessage-deflate"); + + result.ShouldNotBeNull(); + // NATS always enforces no_context_takeover + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + result.Value.ClientNoContextTakeover.ShouldBeTrue(); + result.Value.ServerMaxWindowBits.ShouldBe(15); + result.Value.ClientMaxWindowBits.ShouldBe(15); + } + + [Fact] + public void Negotiate_WithServerNoContextTakeover() + { + // Go parity: wsPMCExtensionSupport — server_no_context_takeover parameter + var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_no_context_takeover"); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + } + + [Fact] + public void Negotiate_WithClientNoContextTakeover() + { + // Go parity: wsPMCExtensionSupport — client_no_context_takeover parameter + var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_no_context_takeover"); + + result.ShouldNotBeNull(); + result.Value.ClientNoContextTakeover.ShouldBeTrue(); + } + + [Fact] + public void Negotiate_WithBothNoContextTakeover() + { + // Go parity: wsPMCExtensionSupport — both no_context_takeover parameters + var result = WsDeflateNegotiator.Negotiate( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover"); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + result.Value.ClientNoContextTakeover.ShouldBeTrue(); + } + + [Fact] + public void Negotiate_WithServerMaxWindowBits() + { + // RFC 7692 Section 7.1.2.1: server_max_window_bits parameter + var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_max_window_bits=10"); + + result.ShouldNotBeNull(); + result.Value.ServerMaxWindowBits.ShouldBe(10); + } + + [Fact] + public void Negotiate_WithClientMaxWindowBits_Value() + { + // RFC 7692 Section 7.1.2.2: client_max_window_bits with explicit value + var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits=12"); + + result.ShouldNotBeNull(); + result.Value.ClientMaxWindowBits.ShouldBe(12); + } + + [Fact] + public void Negotiate_WithClientMaxWindowBits_NoValue() + { + // RFC 7692 Section 7.1.2.2: client_max_window_bits with no value means + // client supports any value 8-15; defaults to 15 + var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits"); + + result.ShouldNotBeNull(); + result.Value.ClientMaxWindowBits.ShouldBe(15); + } + + [Fact] + public void Negotiate_WindowBits_ClampedToValidRange() + { + // RFC 7692: valid range is 8-15 + var result = WsDeflateNegotiator.Negotiate( + "permessage-deflate; server_max_window_bits=5; client_max_window_bits=20"); + + result.ShouldNotBeNull(); + result.Value.ServerMaxWindowBits.ShouldBe(8); // Clamped up from 5 + result.Value.ClientMaxWindowBits.ShouldBe(15); // Clamped down from 20 + } + + [Fact] + public void Negotiate_FullParameters() + { + // All parameters specified + var result = WsDeflateNegotiator.Negotiate( + "permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=9; client_max_window_bits=11"); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + result.Value.ClientNoContextTakeover.ShouldBeTrue(); + result.Value.ServerMaxWindowBits.ShouldBe(9); + result.Value.ClientMaxWindowBits.ShouldBe(11); + } + + [Fact] + public void Negotiate_CaseInsensitive() + { + // RFC 7692 extension names are case-insensitive + var result = WsDeflateNegotiator.Negotiate("Permessage-Deflate; Server_No_Context_Takeover"); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + } + + [Fact] + public void Negotiate_MultipleExtensions_PicksDeflate() + { + // Header may contain multiple comma-separated extensions + var result = WsDeflateNegotiator.Negotiate( + "x-custom-ext, permessage-deflate; server_no_context_takeover, other-ext"); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + } + + [Fact] + public void Negotiate_WhitespaceHandling() + { + // Extra whitespace around parameters + var result = WsDeflateNegotiator.Negotiate( + " permessage-deflate ; server_no_context_takeover ; client_max_window_bits = 10 "); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + result.Value.ClientMaxWindowBits.ShouldBe(10); + } + + // ─── NatsAlwaysEnforcesNoContextTakeover ───────────────────────────── + + [Fact] + public void Negotiate_AlwaysEnforcesNoContextTakeover() + { + // NATS Go server always returns server_no_context_takeover and + // client_no_context_takeover regardless of what the client requests + var result = WsDeflateNegotiator.Negotiate("permessage-deflate"); + + result.ShouldNotBeNull(); + result.Value.ServerNoContextTakeover.ShouldBeTrue(); + result.Value.ClientNoContextTakeover.ShouldBeTrue(); + } + + // ─── WsDeflateParams.ToResponseHeaderValue tests ──────────────────── + + [Fact] + public void DefaultParams_ResponseHeader_ContainsNoContextTakeover() + { + var header = WsDeflateParams.Default.ToResponseHeaderValue(); + + header.ShouldContain("permessage-deflate"); + header.ShouldContain("server_no_context_takeover"); + header.ShouldContain("client_no_context_takeover"); + header.ShouldNotContain("server_max_window_bits"); + header.ShouldNotContain("client_max_window_bits"); + } + + [Fact] + public void CustomWindowBits_ResponseHeader_IncludesValues() + { + var params_ = new WsDeflateParams( + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: 10, + ClientMaxWindowBits: 12); + + var header = params_.ToResponseHeaderValue(); + + header.ShouldContain("server_max_window_bits=10"); + header.ShouldContain("client_max_window_bits=12"); + } + + [Fact] + public void DefaultWindowBits_ResponseHeader_OmitsValues() + { + // RFC 7692: window bits of 15 is the default and should not be sent + var params_ = new WsDeflateParams( + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: 15, + ClientMaxWindowBits: 15); + + var header = params_.ToResponseHeaderValue(); + + header.ShouldNotContain("server_max_window_bits"); + header.ShouldNotContain("client_max_window_bits"); + } + + // ─── Integration with WsUpgrade ───────────────────────────────────── + + [Fact] + public async Task Upgrade_WithDeflateParams_NegotiatesCompression() + { + // Go parity: WebSocket upgrade with permessage-deflate parameters + var request = BuildValidRequest(extraHeaders: + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=10\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, Compression = true }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeTrue(); + result.DeflateParams.ShouldNotBeNull(); + result.DeflateParams.Value.ServerNoContextTakeover.ShouldBeTrue(); + result.DeflateParams.Value.ClientNoContextTakeover.ShouldBeTrue(); + result.DeflateParams.Value.ServerMaxWindowBits.ShouldBe(10); + } + + [Fact] + public async Task Upgrade_WithDeflateParams_ResponseIncludesNegotiatedParams() + { + var request = BuildValidRequest(extraHeaders: + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_max_window_bits=10\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, Compression = true }; + await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + var response = ReadResponse(outputStream); + response.ShouldContain("permessage-deflate"); + response.ShouldContain("server_no_context_takeover"); + response.ShouldContain("client_no_context_takeover"); + response.ShouldContain("client_max_window_bits=10"); + } + + [Fact] + public async Task Upgrade_CompressionDisabled_NoDeflateParams() + { + var request = BuildValidRequest(extraHeaders: + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, Compression = false }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeFalse(); + result.DeflateParams.ShouldBeNull(); + } + + [Fact] + public async Task Upgrade_NoExtensionHeader_NoCompression() + { + var request = BuildValidRequest(); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, Compression = true }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeFalse(); + result.DeflateParams.ShouldBeNull(); + } + + // ─── Helpers ───────────────────────────────────────────────────────── + + private static string BuildValidRequest(string path = "/", string? extraHeaders = null) + { + var sb = new StringBuilder(); + sb.Append($"GET {path} HTTP/1.1\r\n"); + sb.Append("Host: localhost:4222\r\n"); + sb.Append("Upgrade: websocket\r\n"); + sb.Append("Connection: Upgrade\r\n"); + sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"); + sb.Append("Sec-WebSocket-Version: 13\r\n"); + if (extraHeaders != null) + sb.Append(extraHeaders); + sb.Append("\r\n"); + return sb.ToString(); + } + + private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest) + { + var inputBytes = Encoding.ASCII.GetBytes(httpRequest); + return (new MemoryStream(inputBytes), new MemoryStream()); + } + + private static string ReadResponse(MemoryStream output) + { + output.Position = 0; + return Encoding.ASCII.GetString(output.ToArray()); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsJwtAuthTests.cs b/tests/NATS.Server.Tests/WebSocket/WsJwtAuthTests.cs new file mode 100644 index 0000000..7f90df2 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsJwtAuthTests.cs @@ -0,0 +1,316 @@ +// Tests for WebSocket JWT authentication during upgrade (E11). +// Verifies JWT extraction from Authorization header, cookie, and query parameter. +// Reference: golang/nats-server/server/websocket.go — cookie JWT extraction (line 856), +// websocket_test.go — TestWSReloadTLSConfig (line 4066). + +using System.Text; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsJwtAuthTests +{ + // ─── Authorization header JWT extraction ───────────────────────────── + + [Fact] + public async Task Upgrade_AuthorizationBearerHeader_ExtractsJwt() + { + // JWT from Authorization: Bearer header (standard HTTP auth) + var jwt = "eyJhbGciOiJFZDI1NTE5IiwidHlwIjoiSldUIn0.test-payload.test-sig"; + var request = BuildValidRequest(extraHeaders: + $"Authorization: Bearer {jwt}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(jwt); + } + + [Fact] + public async Task Upgrade_AuthorizationBearerCaseInsensitive() + { + // RFC 7235: "bearer" scheme is case-insensitive + var jwt = "my-jwt-token-123"; + var request = BuildValidRequest(extraHeaders: + $"Authorization: bearer {jwt}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(jwt); + } + + [Fact] + public async Task Upgrade_AuthorizationBareToken_ExtractsJwt() + { + // Some clients send the token directly without "Bearer" prefix + var jwt = "raw-jwt-token-456"; + var request = BuildValidRequest(extraHeaders: + $"Authorization: {jwt}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(jwt); + } + + // ─── Cookie JWT extraction ────────────────────────────────────────── + + [Fact] + public async Task Upgrade_JwtCookie_ExtractsJwt() + { + // Go parity: websocket.go line 856 — JWT from configured cookie name + var jwt = "cookie-jwt-token-789"; + var request = BuildValidRequest(extraHeaders: + $"Cookie: jwt={jwt}; other=value\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.CookieJwt.ShouldBe(jwt); + // Cookie JWT is used as fallback when no Authorization header is present + result.Jwt.ShouldBe(jwt); + } + + [Fact] + public async Task Upgrade_AuthorizationHeader_TakesPriorityOverCookie() + { + // Authorization header has higher priority than cookie + var headerJwt = "auth-header-jwt"; + var cookieJwt = "cookie-jwt"; + var request = BuildValidRequest(extraHeaders: + $"Authorization: Bearer {headerJwt}\r\n" + + $"Cookie: jwt={cookieJwt}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(headerJwt); + result.CookieJwt.ShouldBe(cookieJwt); // Cookie value is still preserved + } + + // ─── Query parameter JWT extraction ───────────────────────────────── + + [Fact] + public async Task Upgrade_QueryParamJwt_ExtractsJwt() + { + // JWT from ?jwt= query parameter (useful for browser clients) + var jwt = "query-jwt-token-abc"; + var request = BuildValidRequest(path: $"/?jwt={jwt}"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(jwt); + } + + [Fact] + public async Task Upgrade_QueryParamJwt_UrlEncoded() + { + // JWT value may be URL-encoded + var jwt = "eyJ0eXAiOiJKV1QifQ.payload.sig"; + var encoded = Uri.EscapeDataString(jwt); + var request = BuildValidRequest(path: $"/?jwt={encoded}"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(jwt); + } + + [Fact] + public async Task Upgrade_AuthorizationHeader_TakesPriorityOverQueryParam() + { + // Authorization header > query parameter + var headerJwt = "auth-header-jwt"; + var queryJwt = "query-jwt"; + var request = BuildValidRequest( + path: $"/?jwt={queryJwt}", + extraHeaders: $"Authorization: Bearer {headerJwt}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(headerJwt); + } + + [Fact] + public async Task Upgrade_Cookie_TakesPriorityOverQueryParam() + { + // Cookie > query parameter + var cookieJwt = "cookie-jwt"; + var queryJwt = "query-jwt"; + var request = BuildValidRequest( + path: $"/?jwt={queryJwt}", + extraHeaders: $"Cookie: jwt_token={cookieJwt}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt_token" }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBe(cookieJwt); + } + + // ─── No JWT scenarios ─────────────────────────────────────────────── + + [Fact] + public async Task Upgrade_NoJwtAnywhere_JwtIsNull() + { + // No JWT in any source + var request = BuildValidRequest(); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Jwt.ShouldBeNull(); + } + + [Fact] + public async Task Upgrade_EmptyAuthorizationHeader_JwtIsEmpty() + { + // Empty authorization header should produce empty string (non-null) + var request = BuildValidRequest(extraHeaders: "Authorization: \r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + // Empty auth header is treated as null/no JWT + result.Jwt.ShouldBeNull(); + } + + // ─── ExtractBearerToken unit tests ────────────────────────────────── + + [Fact] + public void ExtractBearerToken_BearerPrefix() + { + WsUpgrade.ExtractBearerToken("Bearer my-token").ShouldBe("my-token"); + } + + [Fact] + public void ExtractBearerToken_BearerPrefixLowerCase() + { + WsUpgrade.ExtractBearerToken("bearer my-token").ShouldBe("my-token"); + } + + [Fact] + public void ExtractBearerToken_BareToken() + { + WsUpgrade.ExtractBearerToken("raw-token").ShouldBe("raw-token"); + } + + [Fact] + public void ExtractBearerToken_Null() + { + WsUpgrade.ExtractBearerToken(null).ShouldBeNull(); + } + + [Fact] + public void ExtractBearerToken_Empty() + { + WsUpgrade.ExtractBearerToken("").ShouldBeNull(); + } + + [Fact] + public void ExtractBearerToken_Whitespace() + { + WsUpgrade.ExtractBearerToken(" ").ShouldBeNull(); + } + + // ─── ParseQueryString unit tests ──────────────────────────────────── + + [Fact] + public void ParseQueryString_SingleParam() + { + var result = WsUpgrade.ParseQueryString("?jwt=token123"); + result["jwt"].ShouldBe("token123"); + } + + [Fact] + public void ParseQueryString_MultipleParams() + { + var result = WsUpgrade.ParseQueryString("?jwt=token&user=admin"); + result["jwt"].ShouldBe("token"); + result["user"].ShouldBe("admin"); + } + + [Fact] + public void ParseQueryString_UrlEncoded() + { + var result = WsUpgrade.ParseQueryString("?jwt=a%20b%3Dc"); + result["jwt"].ShouldBe("a b=c"); + } + + [Fact] + public void ParseQueryString_NoQuestionMark() + { + var result = WsUpgrade.ParseQueryString("jwt=token"); + result["jwt"].ShouldBe("token"); + } + + // ─── FailUnauthorizedAsync ────────────────────────────────────────── + + [Fact] + public async Task FailUnauthorizedAsync_Returns401() + { + var output = new MemoryStream(); + var result = await WsUpgrade.FailUnauthorizedAsync(output, "invalid JWT"); + + result.Success.ShouldBeFalse(); + output.Position = 0; + var response = Encoding.ASCII.GetString(output.ToArray()); + response.ShouldContain("401"); + response.ShouldContain("invalid JWT"); + } + + // ─── Query param path routing still works with query strings ──────── + + [Fact] + public async Task Upgrade_PathWithQueryParam_StillRoutesCorrectly() + { + // /leafnode?jwt=token should still detect as leaf kind + var request = BuildValidRequest(path: "/leafnode?jwt=my-token"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Leaf); + result.Jwt.ShouldBe("my-token"); + } + + // ─── Helpers ───────────────────────────────────────────────────────── + + private static string BuildValidRequest(string path = "/", string? extraHeaders = null) + { + var sb = new StringBuilder(); + sb.Append($"GET {path} HTTP/1.1\r\n"); + sb.Append("Host: localhost:4222\r\n"); + sb.Append("Upgrade: websocket\r\n"); + sb.Append("Connection: Upgrade\r\n"); + sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"); + sb.Append("Sec-WebSocket-Version: 13\r\n"); + if (extraHeaders != null) + sb.Append(extraHeaders); + sb.Append("\r\n"); + return sb.ToString(); + } + + private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest) + { + var inputBytes = Encoding.ASCII.GetBytes(httpRequest); + return (new MemoryStream(inputBytes), new MemoryStream()); + } +}