feat(config+ws): add TLS cert reload, WS compression negotiation, WS JWT auth (E9+E10+E11)
E9: TLS Certificate Reload - Add TlsCertificateProvider with Interlocked-swappable cert field - New connections get current cert, existing connections keep theirs - ConfigReloader.ReloadTlsCertificate rebuilds SslServerAuthenticationOptions - NatsServer.ApplyConfigChanges triggers TLS reload on TLS config changes - 11 tests covering cert swap, versioning, thread safety, config diff E10: WebSocket Compression Negotiation (RFC 7692) - Add WsDeflateNegotiator to parse Sec-WebSocket-Extensions parameters - Parse server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits - WsDeflateParams record struct with ToResponseHeaderValue() - NATS always enforces no_context_takeover (matching Go server) - WsUpgrade returns negotiated WsDeflateParams in upgrade result - 22 tests covering parameter parsing, clamping, response headers E11: WebSocket JWT Authentication - Extract JWT from Authorization header (Bearer token), cookie, or ?jwt= query param - Priority: Authorization header > cookie > query parameter - WsUpgrade.TryUpgradeAsync now parses query string from request URI - Add FailUnauthorizedAsync for 401 responses - 24 tests covering all JWT extraction sources and priority ordering
This commit is contained in:
Binary file not shown.
@@ -1,6 +1,9 @@
|
|||||||
// Port of Go server/reload.go — config diffing, validation, and CLI override merging
|
// Port of Go server/reload.go — config diffing, validation, and CLI override merging
|
||||||
// for hot reload support. Reference: golang/nats-server/server/reload.go.
|
// for hot reload support. Reference: golang/nats-server/server/reload.go.
|
||||||
|
|
||||||
|
using System.Net.Security;
|
||||||
|
using NATS.Server.Tls;
|
||||||
|
|
||||||
namespace NATS.Server.Configuration;
|
namespace NATS.Server.Configuration;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -459,6 +462,29 @@ public static class ConfigReloader
|
|||||||
|
|
||||||
return !string.Equals(oldJetStream.StoreDir, newJetStream.StoreDir, StringComparison.Ordinal);
|
return !string.Equals(oldJetStream.StoreDir, newJetStream.StoreDir, StringComparison.Ordinal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Reloads TLS certificates from the current options and atomically swaps them
|
||||||
|
/// into the certificate provider. New connections will use the new certificate;
|
||||||
|
/// existing connections keep their original certificate.
|
||||||
|
/// Reference: golang/nats-server/server/reload.go — tlsOption.Apply.
|
||||||
|
/// </summary>
|
||||||
|
public static bool ReloadTlsCertificate(
|
||||||
|
NatsOptions options,
|
||||||
|
TlsCertificateProvider? certProvider)
|
||||||
|
{
|
||||||
|
if (certProvider == null || !options.HasTls)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
var oldCert = certProvider.SwapCertificate(options.TlsCert!, options.TlsKey);
|
||||||
|
oldCert?.Dispose();
|
||||||
|
|
||||||
|
// Rebuild SslServerAuthenticationOptions with the new certificate
|
||||||
|
var newSslOptions = TlsHelper.BuildServerAuthOptions(options);
|
||||||
|
certProvider.SwapSslOptions(newSslOptions);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|||||||
@@ -50,8 +50,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
private readonly Account _globalAccount;
|
private readonly Account _globalAccount;
|
||||||
private readonly Account _systemAccount;
|
private readonly Account _systemAccount;
|
||||||
private InternalEventSystem? _eventSystem;
|
private InternalEventSystem? _eventSystem;
|
||||||
private readonly SslServerAuthenticationOptions? _sslOptions;
|
private SslServerAuthenticationOptions? _sslOptions;
|
||||||
private readonly TlsRateLimiter? _tlsRateLimiter;
|
private readonly TlsRateLimiter? _tlsRateLimiter;
|
||||||
|
private readonly TlsCertificateProvider? _tlsCertProvider;
|
||||||
private readonly SubjectTransform[] _subjectTransforms;
|
private readonly SubjectTransform[] _subjectTransforms;
|
||||||
private readonly RouteManager? _routeManager;
|
private readonly RouteManager? _routeManager;
|
||||||
|
|
||||||
@@ -148,6 +149,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
|
|
||||||
public void WaitForShutdown() => _shutdownComplete.Task.GetAwaiter().GetResult();
|
public void WaitForShutdown() => _shutdownComplete.Task.GetAwaiter().GetResult();
|
||||||
|
|
||||||
|
internal TlsCertificateProvider? TlsCertProviderForTest => _tlsCertProvider;
|
||||||
|
|
||||||
internal Task AcquireReloadLockForTestAsync() => _reloadMu.WaitAsync();
|
internal Task AcquireReloadLockForTestAsync() => _reloadMu.WaitAsync();
|
||||||
|
|
||||||
internal void ReleaseReloadLockForTest() => _reloadMu.Release();
|
internal void ReleaseReloadLockForTest() => _reloadMu.Release();
|
||||||
@@ -427,7 +430,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
|
|
||||||
if (options.HasTls)
|
if (options.HasTls)
|
||||||
{
|
{
|
||||||
|
_tlsCertProvider = new TlsCertificateProvider(options.TlsCert!, options.TlsKey);
|
||||||
_sslOptions = TlsHelper.BuildServerAuthOptions(options);
|
_sslOptions = TlsHelper.BuildServerAuthOptions(options);
|
||||||
|
_tlsCertProvider.SwapSslOptions(_sslOptions);
|
||||||
|
|
||||||
// OCSP stapling: build a certificate context so the runtime can
|
// OCSP stapling: build a certificate context so the runtime can
|
||||||
// fetch and cache a fresh OCSP response and staple it during the
|
// fetch and cache a fresh OCSP response and staple it during the
|
||||||
@@ -1377,6 +1382,16 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
Connections = ClientCount,
|
Connections = ClientCount,
|
||||||
TotalConnections = Interlocked.Read(ref _stats.TotalConnections),
|
TotalConnections = Interlocked.Read(ref _stats.TotalConnections),
|
||||||
Subscriptions = SubList.Count,
|
Subscriptions = SubList.Count,
|
||||||
|
Sent = new Events.DataStats
|
||||||
|
{
|
||||||
|
Msgs = Interlocked.Read(ref _stats.OutMsgs),
|
||||||
|
Bytes = Interlocked.Read(ref _stats.OutBytes),
|
||||||
|
},
|
||||||
|
Received = new Events.DataStats
|
||||||
|
{
|
||||||
|
Msgs = Interlocked.Read(ref _stats.InMsgs),
|
||||||
|
Bytes = Interlocked.Read(ref _stats.InBytes),
|
||||||
|
},
|
||||||
InMsgs = Interlocked.Read(ref _stats.InMsgs),
|
InMsgs = Interlocked.Read(ref _stats.InMsgs),
|
||||||
OutMsgs = Interlocked.Read(ref _stats.OutMsgs),
|
OutMsgs = Interlocked.Read(ref _stats.OutMsgs),
|
||||||
InBytes = Interlocked.Read(ref _stats.InBytes),
|
InBytes = Interlocked.Read(ref _stats.InBytes),
|
||||||
@@ -1672,11 +1687,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
{
|
{
|
||||||
bool hasLoggingChanges = false;
|
bool hasLoggingChanges = false;
|
||||||
bool hasAuthChanges = false;
|
bool hasAuthChanges = false;
|
||||||
|
bool hasTlsChanges = false;
|
||||||
|
|
||||||
foreach (var change in changes)
|
foreach (var change in changes)
|
||||||
{
|
{
|
||||||
if (change.IsLoggingChange) hasLoggingChanges = true;
|
if (change.IsLoggingChange) hasLoggingChanges = true;
|
||||||
if (change.IsAuthChange) hasAuthChanges = true;
|
if (change.IsAuthChange) hasAuthChanges = true;
|
||||||
|
if (change.IsTlsChange) hasTlsChanges = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy reloadable values from newOpts to _options
|
// Copy reloadable values from newOpts to _options
|
||||||
@@ -1689,6 +1706,18 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
_logger.LogInformation("Logging configuration reloaded");
|
_logger.LogInformation("Logging configuration reloaded");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (hasTlsChanges)
|
||||||
|
{
|
||||||
|
// Reload TLS certificates: new connections get the new cert,
|
||||||
|
// existing connections keep their original cert.
|
||||||
|
// Reference: golang/nats-server/server/reload.go — tlsOption.Apply.
|
||||||
|
if (ConfigReloader.ReloadTlsCertificate(_options, _tlsCertProvider))
|
||||||
|
{
|
||||||
|
_sslOptions = _tlsCertProvider!.GetCurrentSslOptions();
|
||||||
|
_logger.LogInformation("TLS configuration reloaded");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (hasAuthChanges)
|
if (hasAuthChanges)
|
||||||
{
|
{
|
||||||
// Rebuild auth service with new options, then propagate changes to connected clients
|
// Rebuild auth service with new options, then propagate changes to connected clients
|
||||||
@@ -1837,6 +1866,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
reg.Dispose();
|
reg.Dispose();
|
||||||
_quitCts.Dispose();
|
_quitCts.Dispose();
|
||||||
_tlsRateLimiter?.Dispose();
|
_tlsRateLimiter?.Dispose();
|
||||||
|
_tlsCertProvider?.Dispose();
|
||||||
_listener?.Dispose();
|
_listener?.Dispose();
|
||||||
_wsListener?.Dispose();
|
_wsListener?.Dispose();
|
||||||
_routeManager?.DisposeAsync().AsTask().GetAwaiter().GetResult();
|
_routeManager?.DisposeAsync().AsTask().GetAwaiter().GetResult();
|
||||||
|
|||||||
89
src/NATS.Server/Tls/TlsCertificateProvider.cs
Normal file
89
src/NATS.Server/Tls/TlsCertificateProvider.cs
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
// TLS certificate provider that supports atomic cert swapping for hot reload.
|
||||||
|
// New connections get the current certificate; existing connections keep their original.
|
||||||
|
// Reference: golang/nats-server/server/reload.go — tlsOption.Apply.
|
||||||
|
|
||||||
|
using System.Net.Security;
|
||||||
|
using System.Security.Cryptography.X509Certificates;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tls;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Thread-safe provider for TLS certificates that supports atomic swapping
|
||||||
|
/// during config reload. New connections retrieve the latest certificate via
|
||||||
|
/// <see cref="GetCurrentCertificate"/>; existing connections are unaffected.
|
||||||
|
/// </summary>
|
||||||
|
public sealed class TlsCertificateProvider : IDisposable
|
||||||
|
{
|
||||||
|
private volatile X509Certificate2? _currentCert;
|
||||||
|
private volatile SslServerAuthenticationOptions? _currentSslOptions;
|
||||||
|
private int _version;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a new provider and loads the initial certificate from the given paths.
|
||||||
|
/// </summary>
|
||||||
|
public TlsCertificateProvider(string certPath, string? keyPath)
|
||||||
|
{
|
||||||
|
_currentCert = TlsHelper.LoadCertificate(certPath, keyPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a provider from a pre-loaded certificate (for testing).
|
||||||
|
/// </summary>
|
||||||
|
public TlsCertificateProvider(X509Certificate2 cert)
|
||||||
|
{
|
||||||
|
_currentCert = cert;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns the current certificate. This is called for each new TLS handshake
|
||||||
|
/// so that new connections always get the latest certificate.
|
||||||
|
/// </summary>
|
||||||
|
public X509Certificate2? GetCurrentCertificate() => _currentCert;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Atomically swaps the current certificate with a newly loaded one.
|
||||||
|
/// Returns the old certificate (caller may dispose it after existing connections drain).
|
||||||
|
/// </summary>
|
||||||
|
public X509Certificate2? SwapCertificate(string certPath, string? keyPath)
|
||||||
|
{
|
||||||
|
var newCert = TlsHelper.LoadCertificate(certPath, keyPath);
|
||||||
|
return SwapCertificate(newCert);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Atomically swaps the current certificate with the provided one.
|
||||||
|
/// Returns the old certificate.
|
||||||
|
/// </summary>
|
||||||
|
public X509Certificate2? SwapCertificate(X509Certificate2 newCert)
|
||||||
|
{
|
||||||
|
var old = Interlocked.Exchange(ref _currentCert, newCert);
|
||||||
|
Interlocked.Increment(ref _version);
|
||||||
|
return old;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns the current SSL options, rebuilding them if the certificate has changed.
|
||||||
|
/// </summary>
|
||||||
|
public SslServerAuthenticationOptions? GetCurrentSslOptions() => _currentSslOptions;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Atomically swaps the SSL server authentication options.
|
||||||
|
/// Called after TLS config changes are detected during reload.
|
||||||
|
/// </summary>
|
||||||
|
public void SwapSslOptions(SslServerAuthenticationOptions newOptions)
|
||||||
|
{
|
||||||
|
Interlocked.Exchange(ref _currentSslOptions, newOptions);
|
||||||
|
Interlocked.Increment(ref _version);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Monotonically increasing version number, incremented on each swap.
|
||||||
|
/// Useful for tests to verify a reload occurred.
|
||||||
|
/// </summary>
|
||||||
|
public int Version => Volatile.Read(ref _version);
|
||||||
|
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
_currentCert?.Dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,146 @@ using System.IO.Compression;
|
|||||||
|
|
||||||
namespace NATS.Server.WebSocket;
|
namespace NATS.Server.WebSocket;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Negotiated permessage-deflate parameters per RFC 7692 Section 7.1.
|
||||||
|
/// Captures the results of extension parameter negotiation during the
|
||||||
|
/// WebSocket upgrade handshake.
|
||||||
|
/// </summary>
|
||||||
|
public readonly record struct WsDeflateParams(
|
||||||
|
bool ServerNoContextTakeover,
|
||||||
|
bool ClientNoContextTakeover,
|
||||||
|
int ServerMaxWindowBits,
|
||||||
|
int ClientMaxWindowBits)
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Default parameters matching NATS Go server behavior:
|
||||||
|
/// both sides use no_context_takeover, default 15-bit windows.
|
||||||
|
/// </summary>
|
||||||
|
public static readonly WsDeflateParams Default = new(
|
||||||
|
ServerNoContextTakeover: true,
|
||||||
|
ClientNoContextTakeover: true,
|
||||||
|
ServerMaxWindowBits: 15,
|
||||||
|
ClientMaxWindowBits: 15);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Builds the Sec-WebSocket-Extensions response header value from negotiated parameters.
|
||||||
|
/// Only includes parameters that differ from the default RFC values.
|
||||||
|
/// Reference: RFC 7692 Section 7.1.
|
||||||
|
/// </summary>
|
||||||
|
public string ToResponseHeaderValue()
|
||||||
|
{
|
||||||
|
var parts = new List<string> { WsConstants.PmcExtension };
|
||||||
|
|
||||||
|
if (ServerNoContextTakeover)
|
||||||
|
parts.Add(WsConstants.PmcSrvNoCtx);
|
||||||
|
if (ClientNoContextTakeover)
|
||||||
|
parts.Add(WsConstants.PmcCliNoCtx);
|
||||||
|
if (ServerMaxWindowBits is > 0 and < 15)
|
||||||
|
parts.Add($"server_max_window_bits={ServerMaxWindowBits}");
|
||||||
|
if (ClientMaxWindowBits is > 0 and < 15)
|
||||||
|
parts.Add($"client_max_window_bits={ClientMaxWindowBits}");
|
||||||
|
|
||||||
|
return string.Join("; ", parts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Parses and negotiates permessage-deflate extension parameters from the
|
||||||
|
/// Sec-WebSocket-Extensions header per RFC 7692 Section 7.
|
||||||
|
/// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport.
|
||||||
|
/// </summary>
|
||||||
|
public static class WsDeflateNegotiator
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Parses the Sec-WebSocket-Extensions header value and negotiates
|
||||||
|
/// permessage-deflate parameters. Returns null if no valid
|
||||||
|
/// permessage-deflate offer is found.
|
||||||
|
/// </summary>
|
||||||
|
public static WsDeflateParams? Negotiate(string? extensionHeader)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrEmpty(extensionHeader))
|
||||||
|
return null;
|
||||||
|
|
||||||
|
// The header may contain multiple extensions separated by commas
|
||||||
|
var extensions = extensionHeader.Split(',');
|
||||||
|
foreach (var extension in extensions)
|
||||||
|
{
|
||||||
|
var trimmed = extension.Trim();
|
||||||
|
var parts = trimmed.Split(';');
|
||||||
|
|
||||||
|
// First part must be the extension name
|
||||||
|
if (parts.Length == 0)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (!string.Equals(parts[0].Trim(), WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Found permessage-deflate — parse parameters
|
||||||
|
// Note: serverNoCtx and clientNoCtx are parsed but always overridden
|
||||||
|
// with true below (NATS enforces no_context_takeover for both sides).
|
||||||
|
int serverMaxWindowBits = 15;
|
||||||
|
int clientMaxWindowBits = 15;
|
||||||
|
|
||||||
|
for (int i = 1; i < parts.Length; i++)
|
||||||
|
{
|
||||||
|
var param = parts[i].Trim();
|
||||||
|
|
||||||
|
if (string.Equals(param, WsConstants.PmcSrvNoCtx, StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
// Parsed but overridden: NATS always enforces no_context_takeover.
|
||||||
|
}
|
||||||
|
else if (string.Equals(param, WsConstants.PmcCliNoCtx, StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
// Parsed but overridden: NATS always enforces no_context_takeover.
|
||||||
|
}
|
||||||
|
else if (param.StartsWith("server_max_window_bits", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
serverMaxWindowBits = ParseWindowBits(param, 15);
|
||||||
|
}
|
||||||
|
else if (param.StartsWith("client_max_window_bits", StringComparison.OrdinalIgnoreCase))
|
||||||
|
{
|
||||||
|
// client_max_window_bits with no value means the client supports it
|
||||||
|
// and the server may choose a value. Per RFC 7692 Section 7.1.2.2,
|
||||||
|
// an offer with just "client_max_window_bits" (no value) indicates
|
||||||
|
// the client can accept any value 8-15.
|
||||||
|
clientMaxWindowBits = ParseWindowBits(param, 15);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NATS server always enforces no_context_takeover for both sides
|
||||||
|
// (matching Go behavior) to avoid holding compressor state per connection.
|
||||||
|
return new WsDeflateParams(
|
||||||
|
ServerNoContextTakeover: true,
|
||||||
|
ClientNoContextTakeover: true,
|
||||||
|
ServerMaxWindowBits: ClampWindowBits(serverMaxWindowBits),
|
||||||
|
ClientMaxWindowBits: ClampWindowBits(clientMaxWindowBits));
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int ParseWindowBits(string param, int defaultValue)
|
||||||
|
{
|
||||||
|
var eqIdx = param.IndexOf('=');
|
||||||
|
if (eqIdx < 0)
|
||||||
|
return defaultValue;
|
||||||
|
|
||||||
|
var valueStr = param[(eqIdx + 1)..].Trim();
|
||||||
|
if (int.TryParse(valueStr, out var bits))
|
||||||
|
return bits;
|
||||||
|
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int ClampWindowBits(int bits)
|
||||||
|
{
|
||||||
|
// RFC 7692: valid range is 8-15
|
||||||
|
if (bits < 8) return 8;
|
||||||
|
if (bits > 15) return 15;
|
||||||
|
return bits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
|
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
|
||||||
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
|
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ public static class WsUpgrade
|
|||||||
{
|
{
|
||||||
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||||
cts.CancelAfter(options.HandshakeTimeout);
|
cts.CancelAfter(options.HandshakeTimeout);
|
||||||
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
|
var (method, path, queryString, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
|
||||||
|
|
||||||
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
|
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
|
||||||
return await FailAsync(outputStream, 405, "request method must be GET");
|
return await FailAsync(outputStream, 405, "request method must be GET");
|
||||||
@@ -57,15 +57,17 @@ public static class WsUpgrade
|
|||||||
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
|
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compression negotiation
|
// Compression negotiation (RFC 7692)
|
||||||
bool compress = options.Compression;
|
bool compress = options.Compression;
|
||||||
|
WsDeflateParams? deflateParams = null;
|
||||||
if (compress)
|
if (compress)
|
||||||
{
|
{
|
||||||
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
|
headers.TryGetValue("Sec-WebSocket-Extensions", out var ext);
|
||||||
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
|
deflateParams = WsDeflateNegotiator.Negotiate(ext);
|
||||||
|
compress = deflateParams != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// No-masking support (leaf nodes only — browser clients must always mask)
|
// No-masking support (leaf nodes only -- browser clients must always mask)
|
||||||
bool noMasking = kind == WsClientKind.Leaf &&
|
bool noMasking = kind == WsClientKind.Leaf &&
|
||||||
headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
|
headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
|
||||||
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
|
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
|
||||||
@@ -95,6 +97,24 @@ public static class WsUpgrade
|
|||||||
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
|
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JWT extraction from multiple sources (E11):
|
||||||
|
// Priority: Authorization header > cookie > query parameter
|
||||||
|
// Reference: NATS WebSocket JWT auth — browser clients often pass JWT
|
||||||
|
// via cookie or query param since they cannot set custom headers.
|
||||||
|
string? jwt = null;
|
||||||
|
if (headers.TryGetValue("Authorization", out var authHeader))
|
||||||
|
{
|
||||||
|
jwt = ExtractBearerToken(authHeader);
|
||||||
|
}
|
||||||
|
|
||||||
|
jwt ??= cookieJwt;
|
||||||
|
|
||||||
|
if (jwt == null && queryString != null)
|
||||||
|
{
|
||||||
|
var queryParams = ParseQueryString(queryString);
|
||||||
|
queryParams.TryGetValue("jwt", out jwt);
|
||||||
|
}
|
||||||
|
|
||||||
// X-Forwarded-For client IP extraction
|
// X-Forwarded-For client IP extraction
|
||||||
string? clientIp = null;
|
string? clientIp = null;
|
||||||
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
|
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
|
||||||
@@ -109,8 +129,13 @@ public static class WsUpgrade
|
|||||||
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
|
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
|
||||||
response.Append(ComputeAcceptKey(key));
|
response.Append(ComputeAcceptKey(key));
|
||||||
response.Append("\r\n");
|
response.Append("\r\n");
|
||||||
if (compress)
|
if (compress && deflateParams != null)
|
||||||
response.Append(WsConstants.PmcFullResponse);
|
{
|
||||||
|
response.Append("Sec-WebSocket-Extensions: ");
|
||||||
|
response.Append(deflateParams.Value.ToResponseHeaderValue());
|
||||||
|
response.Append("\r\n");
|
||||||
|
}
|
||||||
|
|
||||||
if (noMasking)
|
if (noMasking)
|
||||||
response.Append(WsConstants.NoMaskingFullResponse);
|
response.Append(WsConstants.NoMaskingFullResponse);
|
||||||
if (options.Headers != null)
|
if (options.Headers != null)
|
||||||
@@ -135,7 +160,8 @@ public static class WsUpgrade
|
|||||||
MaskRead: !noMasking, MaskWrite: false,
|
MaskRead: !noMasking, MaskWrite: false,
|
||||||
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
|
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
|
||||||
CookiePassword: cookiePassword, CookieToken: cookieToken,
|
CookiePassword: cookiePassword, CookieToken: cookieToken,
|
||||||
ClientIp: clientIp, Kind: kind);
|
ClientIp: clientIp, Kind: kind,
|
||||||
|
DeflateParams: deflateParams, Jwt: jwt);
|
||||||
}
|
}
|
||||||
catch (Exception)
|
catch (Exception)
|
||||||
{
|
{
|
||||||
@@ -153,11 +179,56 @@ public static class WsUpgrade
|
|||||||
return Convert.ToBase64String(hash);
|
return Convert.ToBase64String(hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Extracts a bearer token from an Authorization header value.
|
||||||
|
/// Supports both "Bearer {token}" and bare "{token}" formats.
|
||||||
|
/// </summary>
|
||||||
|
internal static string? ExtractBearerToken(string? authHeader)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(authHeader))
|
||||||
|
return null;
|
||||||
|
|
||||||
|
var trimmed = authHeader.Trim();
|
||||||
|
if (trimmed.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
|
||||||
|
return trimmed["Bearer ".Length..].Trim();
|
||||||
|
|
||||||
|
// Some clients send the token directly without "Bearer" prefix
|
||||||
|
return trimmed;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Parses a query string into key-value pairs.
|
||||||
|
/// </summary>
|
||||||
|
internal static Dictionary<string, string> ParseQueryString(string queryString)
|
||||||
|
{
|
||||||
|
var result = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||||
|
if (queryString.StartsWith('?'))
|
||||||
|
queryString = queryString[1..];
|
||||||
|
|
||||||
|
foreach (var pair in queryString.Split('&'))
|
||||||
|
{
|
||||||
|
var eqIdx = pair.IndexOf('=');
|
||||||
|
if (eqIdx > 0)
|
||||||
|
{
|
||||||
|
var name = Uri.UnescapeDataString(pair[..eqIdx]);
|
||||||
|
var value = Uri.UnescapeDataString(pair[(eqIdx + 1)..]);
|
||||||
|
result[name] = value;
|
||||||
|
}
|
||||||
|
else if (pair.Length > 0)
|
||||||
|
{
|
||||||
|
result[Uri.UnescapeDataString(pair)] = string.Empty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
|
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
|
||||||
{
|
{
|
||||||
var statusText = statusCode switch
|
var statusText = statusCode switch
|
||||||
{
|
{
|
||||||
400 => "Bad Request",
|
400 => "Bad Request",
|
||||||
|
401 => "Unauthorized",
|
||||||
403 => "Forbidden",
|
403 => "Forbidden",
|
||||||
405 => "Method Not Allowed",
|
405 => "Method Not Allowed",
|
||||||
_ => "Internal Server Error",
|
_ => "Internal Server Error",
|
||||||
@@ -165,10 +236,21 @@ public static class WsUpgrade
|
|||||||
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
|
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
|
||||||
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
|
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
|
||||||
await output.FlushAsync();
|
await output.FlushAsync();
|
||||||
return WsUpgradeResult.Failed;
|
return statusCode == 401
|
||||||
|
? WsUpgradeResult.Unauthorized
|
||||||
|
: WsUpgradeResult.Failed;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
|
/// <summary>
|
||||||
|
/// Sends a 401 Unauthorized response and returns a failed upgrade result.
|
||||||
|
/// Used by the server when JWT authentication fails during WS upgrade.
|
||||||
|
/// </summary>
|
||||||
|
public static async Task<WsUpgradeResult> FailUnauthorizedAsync(Stream output, string reason)
|
||||||
|
{
|
||||||
|
return await FailAsync(output, 401, reason);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task<(string method, string path, string? queryString, Dictionary<string, string> headers)> ReadHttpRequestAsync(
|
||||||
Stream stream, CancellationToken ct)
|
Stream stream, CancellationToken ct)
|
||||||
{
|
{
|
||||||
var headerBytes = new List<byte>(4096);
|
var headerBytes = new List<byte>(4096);
|
||||||
@@ -197,7 +279,21 @@ public static class WsUpgrade
|
|||||||
var parts = lines[0].Split(' ');
|
var parts = lines[0].Split(' ');
|
||||||
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
|
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
|
||||||
var method = parts[0];
|
var method = parts[0];
|
||||||
var path = parts[1];
|
var requestUri = parts[1];
|
||||||
|
|
||||||
|
// Split path and query string
|
||||||
|
string path;
|
||||||
|
string? queryString = null;
|
||||||
|
var qIdx = requestUri.IndexOf('?');
|
||||||
|
if (qIdx >= 0)
|
||||||
|
{
|
||||||
|
path = requestUri[..qIdx];
|
||||||
|
queryString = requestUri[qIdx..]; // includes the '?'
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
path = requestUri;
|
||||||
|
}
|
||||||
|
|
||||||
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||||
for (int i = 1; i < lines.Length; i++)
|
for (int i = 1; i < lines.Length; i++)
|
||||||
@@ -213,7 +309,7 @@ public static class WsUpgrade
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return (method, path, headers);
|
return (method, path, queryString, headers);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
|
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
|
||||||
@@ -259,10 +355,17 @@ public readonly record struct WsUpgradeResult(
|
|||||||
string? CookiePassword,
|
string? CookiePassword,
|
||||||
string? CookieToken,
|
string? CookieToken,
|
||||||
string? ClientIp,
|
string? ClientIp,
|
||||||
WsClientKind Kind)
|
WsClientKind Kind,
|
||||||
|
WsDeflateParams? DeflateParams = null,
|
||||||
|
string? Jwt = null)
|
||||||
{
|
{
|
||||||
public static readonly WsUpgradeResult Failed = new(
|
public static readonly WsUpgradeResult Failed = new(
|
||||||
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
||||||
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
||||||
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
||||||
|
|
||||||
|
public static readonly WsUpgradeResult Unauthorized = new(
|
||||||
|
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
||||||
|
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
||||||
|
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
||||||
}
|
}
|
||||||
|
|||||||
239
tests/NATS.Server.Tests/Configuration/TlsReloadTests.cs
Normal file
239
tests/NATS.Server.Tests/Configuration/TlsReloadTests.cs
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
// Tests for TLS certificate hot reload (E9).
|
||||||
|
// Verifies that TlsCertificateProvider supports atomic cert swapping
|
||||||
|
// and that ConfigReloader.ReloadTlsCertificate integrates correctly.
|
||||||
|
// Reference: golang/nats-server/server/reload_test.go — TestConfigReloadRotateTLS (line 392).
|
||||||
|
|
||||||
|
using System.Security.Cryptography;
|
||||||
|
using System.Security.Cryptography.X509Certificates;
|
||||||
|
using NATS.Server.Configuration;
|
||||||
|
using NATS.Server.Tls;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.Configuration;
|
||||||
|
|
||||||
|
public class TlsReloadTests
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Generates a self-signed X509Certificate2 for testing.
|
||||||
|
/// </summary>
|
||||||
|
private static X509Certificate2 GenerateSelfSignedCert(string cn = "test")
|
||||||
|
{
|
||||||
|
using var rsa = RSA.Create(2048);
|
||||||
|
var req = new CertificateRequest($"CN={cn}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||||
|
var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddDays(1));
|
||||||
|
// Export and re-import to ensure the cert has the private key bound
|
||||||
|
return X509CertificateLoader.LoadPkcs12(cert.Export(X509ContentType.Pkcs12), null);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void CertificateProvider_GetCurrentCertificate_ReturnsInitialCert()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — initial cert is usable
|
||||||
|
var cert = GenerateSelfSignedCert("initial");
|
||||||
|
using var provider = new TlsCertificateProvider(cert);
|
||||||
|
|
||||||
|
var current = provider.GetCurrentCertificate();
|
||||||
|
|
||||||
|
current.ShouldNotBeNull();
|
||||||
|
current.Subject.ShouldContain("initial");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void CertificateProvider_SwapCertificate_ReturnsOldCert()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — cert rotation returns old cert
|
||||||
|
var cert1 = GenerateSelfSignedCert("cert1");
|
||||||
|
var cert2 = GenerateSelfSignedCert("cert2");
|
||||||
|
using var provider = new TlsCertificateProvider(cert1);
|
||||||
|
|
||||||
|
var old = provider.SwapCertificate(cert2);
|
||||||
|
|
||||||
|
old.ShouldNotBeNull();
|
||||||
|
old.Subject.ShouldContain("cert1");
|
||||||
|
old.Dispose();
|
||||||
|
|
||||||
|
var current = provider.GetCurrentCertificate();
|
||||||
|
current.ShouldNotBeNull();
|
||||||
|
current.Subject.ShouldContain("cert2");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void CertificateProvider_SwapCertificate_IncrementsVersion()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — version tracking for reload detection
|
||||||
|
var cert1 = GenerateSelfSignedCert("v1");
|
||||||
|
var cert2 = GenerateSelfSignedCert("v2");
|
||||||
|
using var provider = new TlsCertificateProvider(cert1);
|
||||||
|
|
||||||
|
var v0 = provider.Version;
|
||||||
|
v0.ShouldBe(0);
|
||||||
|
|
||||||
|
provider.SwapCertificate(cert2)?.Dispose();
|
||||||
|
provider.Version.ShouldBe(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void CertificateProvider_MultipleSwa_NewConnectionsGetLatest()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — multiple rotations, each new
|
||||||
|
// handshake gets the latest certificate
|
||||||
|
var cert1 = GenerateSelfSignedCert("round1");
|
||||||
|
var cert2 = GenerateSelfSignedCert("round2");
|
||||||
|
var cert3 = GenerateSelfSignedCert("round3");
|
||||||
|
using var provider = new TlsCertificateProvider(cert1);
|
||||||
|
|
||||||
|
provider.GetCurrentCertificate()!.Subject.ShouldContain("round1");
|
||||||
|
|
||||||
|
provider.SwapCertificate(cert2)?.Dispose();
|
||||||
|
provider.GetCurrentCertificate()!.Subject.ShouldContain("round2");
|
||||||
|
|
||||||
|
provider.SwapCertificate(cert3)?.Dispose();
|
||||||
|
provider.GetCurrentCertificate()!.Subject.ShouldContain("round3");
|
||||||
|
|
||||||
|
provider.Version.ShouldBe(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task CertificateProvider_ConcurrentAccess_IsThreadSafe()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — cert swap must be safe under
|
||||||
|
// concurrent connection accept
|
||||||
|
var cert1 = GenerateSelfSignedCert("concurrent1");
|
||||||
|
using var provider = new TlsCertificateProvider(cert1);
|
||||||
|
|
||||||
|
var tasks = new Task[50];
|
||||||
|
for (int i = 0; i < tasks.Length; i++)
|
||||||
|
{
|
||||||
|
var idx = i;
|
||||||
|
tasks[i] = Task.Run(() =>
|
||||||
|
{
|
||||||
|
if (idx % 2 == 0)
|
||||||
|
{
|
||||||
|
// Readers — simulate new connections getting current cert
|
||||||
|
var c = provider.GetCurrentCertificate();
|
||||||
|
c.ShouldNotBeNull();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Writers — simulate reload
|
||||||
|
var newCert = GenerateSelfSignedCert($"swap-{idx}");
|
||||||
|
provider.SwapCertificate(newCert)?.Dispose();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
await Task.WhenAll(tasks);
|
||||||
|
|
||||||
|
// After all swaps, the provider should still return a valid cert
|
||||||
|
provider.GetCurrentCertificate().ShouldNotBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReloadTlsCertificate_NullProvider_ReturnsFalse()
|
||||||
|
{
|
||||||
|
// Edge case: server running without TLS
|
||||||
|
var opts = new NatsOptions();
|
||||||
|
var result = ConfigReloader.ReloadTlsCertificate(opts, null);
|
||||||
|
result.ShouldBeFalse();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReloadTlsCertificate_NoTlsConfig_ReturnsFalse()
|
||||||
|
{
|
||||||
|
// Edge case: provider exists but options don't have TLS paths
|
||||||
|
var cert = GenerateSelfSignedCert("no-tls");
|
||||||
|
using var provider = new TlsCertificateProvider(cert);
|
||||||
|
|
||||||
|
var opts = new NatsOptions(); // HasTls is false (no TlsCert/TlsKey)
|
||||||
|
var result = ConfigReloader.ReloadTlsCertificate(opts, provider);
|
||||||
|
result.ShouldBeFalse();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReloadTlsCertificate_WithCertFiles_SwapsCertAndSslOptions()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — full reload with cert files.
|
||||||
|
// Write a self-signed cert to temp files and verify the provider loads it.
|
||||||
|
var tempDir = Path.Combine(Path.GetTempPath(), $"nats-tls-test-{Guid.NewGuid():N}");
|
||||||
|
Directory.CreateDirectory(tempDir);
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var certPath = Path.Combine(tempDir, "cert.pem");
|
||||||
|
var keyPath = Path.Combine(tempDir, "key.pem");
|
||||||
|
WriteSelfSignedCertFiles(certPath, keyPath, "reload-test");
|
||||||
|
|
||||||
|
// Create provider with initial cert
|
||||||
|
var initialCert = GenerateSelfSignedCert("initial");
|
||||||
|
using var provider = new TlsCertificateProvider(initialCert);
|
||||||
|
|
||||||
|
var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };
|
||||||
|
var result = ConfigReloader.ReloadTlsCertificate(opts, provider);
|
||||||
|
|
||||||
|
result.ShouldBeTrue();
|
||||||
|
provider.Version.ShouldBeGreaterThan(0);
|
||||||
|
provider.GetCurrentCertificate().ShouldNotBeNull();
|
||||||
|
provider.GetCurrentSslOptions().ShouldNotBeNull();
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
Directory.Delete(tempDir, recursive: true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ConfigDiff_DetectsTlsChanges()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadEnableTLS, TestConfigReloadDisableTLS
|
||||||
|
// Verify that diff detects TLS option changes and flags them
|
||||||
|
var oldOpts = new NatsOptions { TlsCert = "/old/cert.pem", TlsKey = "/old/key.pem" };
|
||||||
|
var newOpts = new NatsOptions { TlsCert = "/new/cert.pem", TlsKey = "/new/key.pem" };
|
||||||
|
|
||||||
|
var changes = ConfigReloader.Diff(oldOpts, newOpts);
|
||||||
|
|
||||||
|
changes.Count.ShouldBeGreaterThan(0);
|
||||||
|
changes.ShouldContain(c => c.IsTlsChange && c.Name == "TlsCert");
|
||||||
|
changes.ShouldContain(c => c.IsTlsChange && c.Name == "TlsKey");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ConfigDiff_TlsVerifyChange_IsTlsChange()
|
||||||
|
{
|
||||||
|
// Go parity: TestConfigReloadRotateTLS — enabling client verification
|
||||||
|
var oldOpts = new NatsOptions { TlsVerify = false };
|
||||||
|
var newOpts = new NatsOptions { TlsVerify = true };
|
||||||
|
|
||||||
|
var changes = ConfigReloader.Diff(oldOpts, newOpts);
|
||||||
|
|
||||||
|
changes.ShouldContain(c => c.IsTlsChange && c.Name == "TlsVerify");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ConfigApplyResult_ReportsTlsChanges()
|
||||||
|
{
|
||||||
|
// Verify ApplyDiff flags TLS changes correctly
|
||||||
|
var changes = new List<IConfigChange>
|
||||||
|
{
|
||||||
|
new ConfigChange("TlsCert", isTlsChange: true),
|
||||||
|
new ConfigChange("TlsKey", isTlsChange: true),
|
||||||
|
};
|
||||||
|
var oldOpts = new NatsOptions();
|
||||||
|
var newOpts = new NatsOptions();
|
||||||
|
|
||||||
|
var result = ConfigReloader.ApplyDiff(changes, oldOpts, newOpts);
|
||||||
|
|
||||||
|
result.HasTlsChanges.ShouldBeTrue();
|
||||||
|
result.ChangeCount.ShouldBe(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Helper to write a self-signed certificate to PEM files.
|
||||||
|
/// </summary>
|
||||||
|
private static void WriteSelfSignedCertFiles(string certPath, string keyPath, string cn)
|
||||||
|
{
|
||||||
|
using var rsa = RSA.Create(2048);
|
||||||
|
var req = new CertificateRequest($"CN={cn}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||||
|
var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddDays(1));
|
||||||
|
|
||||||
|
File.WriteAllText(certPath, cert.ExportCertificatePem());
|
||||||
|
File.WriteAllText(keyPath, rsa.ExportRSAPrivateKeyPem());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,327 @@
|
|||||||
|
// Tests for WebSocket permessage-deflate parameter negotiation (E10).
|
||||||
|
// Verifies RFC 7692 extension parameter parsing and negotiation during
|
||||||
|
// WebSocket upgrade handshake.
|
||||||
|
// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport (line 885).
|
||||||
|
|
||||||
|
using System.Text;
|
||||||
|
using NATS.Server.WebSocket;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.WebSocket;
|
||||||
|
|
||||||
|
public class WsCompressionNegotiationTests
|
||||||
|
{
|
||||||
|
// ─── WsDeflateNegotiator.Negotiate tests ──────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_NullHeader_ReturnsNull()
|
||||||
|
{
|
||||||
|
// Go parity: wsPMCExtensionSupport — no extension header means no compression
|
||||||
|
var result = WsDeflateNegotiator.Negotiate(null);
|
||||||
|
result.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_EmptyHeader_ReturnsNull()
|
||||||
|
{
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("");
|
||||||
|
result.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_NoPermessageDeflate_ReturnsNull()
|
||||||
|
{
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("x-webkit-deflate-frame");
|
||||||
|
result.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_BarePermessageDeflate_ReturnsDefaults()
|
||||||
|
{
|
||||||
|
// Go parity: wsPMCExtensionSupport — basic extension without parameters
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
// NATS always enforces no_context_takeover
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ServerMaxWindowBits.ShouldBe(15);
|
||||||
|
result.Value.ClientMaxWindowBits.ShouldBe(15);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WithServerNoContextTakeover()
|
||||||
|
{
|
||||||
|
// Go parity: wsPMCExtensionSupport — server_no_context_takeover parameter
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_no_context_takeover");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WithClientNoContextTakeover()
|
||||||
|
{
|
||||||
|
// Go parity: wsPMCExtensionSupport — client_no_context_takeover parameter
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_no_context_takeover");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WithBothNoContextTakeover()
|
||||||
|
{
|
||||||
|
// Go parity: wsPMCExtensionSupport — both no_context_takeover parameters
|
||||||
|
var result = WsDeflateNegotiator.Negotiate(
|
||||||
|
"permessage-deflate; server_no_context_takeover; client_no_context_takeover");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WithServerMaxWindowBits()
|
||||||
|
{
|
||||||
|
// RFC 7692 Section 7.1.2.1: server_max_window_bits parameter
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_max_window_bits=10");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerMaxWindowBits.ShouldBe(10);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WithClientMaxWindowBits_Value()
|
||||||
|
{
|
||||||
|
// RFC 7692 Section 7.1.2.2: client_max_window_bits with explicit value
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits=12");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ClientMaxWindowBits.ShouldBe(12);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WithClientMaxWindowBits_NoValue()
|
||||||
|
{
|
||||||
|
// RFC 7692 Section 7.1.2.2: client_max_window_bits with no value means
|
||||||
|
// client supports any value 8-15; defaults to 15
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ClientMaxWindowBits.ShouldBe(15);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WindowBits_ClampedToValidRange()
|
||||||
|
{
|
||||||
|
// RFC 7692: valid range is 8-15
|
||||||
|
var result = WsDeflateNegotiator.Negotiate(
|
||||||
|
"permessage-deflate; server_max_window_bits=5; client_max_window_bits=20");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerMaxWindowBits.ShouldBe(8); // Clamped up from 5
|
||||||
|
result.Value.ClientMaxWindowBits.ShouldBe(15); // Clamped down from 20
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_FullParameters()
|
||||||
|
{
|
||||||
|
// All parameters specified
|
||||||
|
var result = WsDeflateNegotiator.Negotiate(
|
||||||
|
"permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=9; client_max_window_bits=11");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ServerMaxWindowBits.ShouldBe(9);
|
||||||
|
result.Value.ClientMaxWindowBits.ShouldBe(11);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_CaseInsensitive()
|
||||||
|
{
|
||||||
|
// RFC 7692 extension names are case-insensitive
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("Permessage-Deflate; Server_No_Context_Takeover");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_MultipleExtensions_PicksDeflate()
|
||||||
|
{
|
||||||
|
// Header may contain multiple comma-separated extensions
|
||||||
|
var result = WsDeflateNegotiator.Negotiate(
|
||||||
|
"x-custom-ext, permessage-deflate; server_no_context_takeover, other-ext");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_WhitespaceHandling()
|
||||||
|
{
|
||||||
|
// Extra whitespace around parameters
|
||||||
|
var result = WsDeflateNegotiator.Negotiate(
|
||||||
|
" permessage-deflate ; server_no_context_takeover ; client_max_window_bits = 10 ");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ClientMaxWindowBits.ShouldBe(10);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── NatsAlwaysEnforcesNoContextTakeover ─────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Negotiate_AlwaysEnforcesNoContextTakeover()
|
||||||
|
{
|
||||||
|
// NATS Go server always returns server_no_context_takeover and
|
||||||
|
// client_no_context_takeover regardless of what the client requests
|
||||||
|
var result = WsDeflateNegotiator.Negotiate("permessage-deflate");
|
||||||
|
|
||||||
|
result.ShouldNotBeNull();
|
||||||
|
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── WsDeflateParams.ToResponseHeaderValue tests ────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void DefaultParams_ResponseHeader_ContainsNoContextTakeover()
|
||||||
|
{
|
||||||
|
var header = WsDeflateParams.Default.ToResponseHeaderValue();
|
||||||
|
|
||||||
|
header.ShouldContain("permessage-deflate");
|
||||||
|
header.ShouldContain("server_no_context_takeover");
|
||||||
|
header.ShouldContain("client_no_context_takeover");
|
||||||
|
header.ShouldNotContain("server_max_window_bits");
|
||||||
|
header.ShouldNotContain("client_max_window_bits");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void CustomWindowBits_ResponseHeader_IncludesValues()
|
||||||
|
{
|
||||||
|
var params_ = new WsDeflateParams(
|
||||||
|
ServerNoContextTakeover: true,
|
||||||
|
ClientNoContextTakeover: true,
|
||||||
|
ServerMaxWindowBits: 10,
|
||||||
|
ClientMaxWindowBits: 12);
|
||||||
|
|
||||||
|
var header = params_.ToResponseHeaderValue();
|
||||||
|
|
||||||
|
header.ShouldContain("server_max_window_bits=10");
|
||||||
|
header.ShouldContain("client_max_window_bits=12");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void DefaultWindowBits_ResponseHeader_OmitsValues()
|
||||||
|
{
|
||||||
|
// RFC 7692: window bits of 15 is the default and should not be sent
|
||||||
|
var params_ = new WsDeflateParams(
|
||||||
|
ServerNoContextTakeover: true,
|
||||||
|
ClientNoContextTakeover: true,
|
||||||
|
ServerMaxWindowBits: 15,
|
||||||
|
ClientMaxWindowBits: 15);
|
||||||
|
|
||||||
|
var header = params_.ToResponseHeaderValue();
|
||||||
|
|
||||||
|
header.ShouldNotContain("server_max_window_bits");
|
||||||
|
header.ShouldNotContain("client_max_window_bits");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Integration with WsUpgrade ─────────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_WithDeflateParams_NegotiatesCompression()
|
||||||
|
{
|
||||||
|
// Go parity: WebSocket upgrade with permessage-deflate parameters
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=10\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, Compression = true };
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Compress.ShouldBeTrue();
|
||||||
|
result.DeflateParams.ShouldNotBeNull();
|
||||||
|
result.DeflateParams.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.DeflateParams.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||||
|
result.DeflateParams.Value.ServerMaxWindowBits.ShouldBe(10);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_WithDeflateParams_ResponseIncludesNegotiatedParams()
|
||||||
|
{
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_max_window_bits=10\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, Compression = true };
|
||||||
|
await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
var response = ReadResponse(outputStream);
|
||||||
|
response.ShouldContain("permessage-deflate");
|
||||||
|
response.ShouldContain("server_no_context_takeover");
|
||||||
|
response.ShouldContain("client_no_context_takeover");
|
||||||
|
response.ShouldContain("client_max_window_bits=10");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_CompressionDisabled_NoDeflateParams()
|
||||||
|
{
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, Compression = false };
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Compress.ShouldBeFalse();
|
||||||
|
result.DeflateParams.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_NoExtensionHeader_NoCompression()
|
||||||
|
{
|
||||||
|
var request = BuildValidRequest();
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, Compression = true };
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Compress.ShouldBeFalse();
|
||||||
|
result.DeflateParams.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
|
||||||
|
{
|
||||||
|
var sb = new StringBuilder();
|
||||||
|
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||||
|
sb.Append("Host: localhost:4222\r\n");
|
||||||
|
sb.Append("Upgrade: websocket\r\n");
|
||||||
|
sb.Append("Connection: Upgrade\r\n");
|
||||||
|
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||||
|
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||||
|
if (extraHeaders != null)
|
||||||
|
sb.Append(extraHeaders);
|
||||||
|
sb.Append("\r\n");
|
||||||
|
return sb.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||||
|
{
|
||||||
|
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||||
|
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string ReadResponse(MemoryStream output)
|
||||||
|
{
|
||||||
|
output.Position = 0;
|
||||||
|
return Encoding.ASCII.GetString(output.ToArray());
|
||||||
|
}
|
||||||
|
}
|
||||||
316
tests/NATS.Server.Tests/WebSocket/WsJwtAuthTests.cs
Normal file
316
tests/NATS.Server.Tests/WebSocket/WsJwtAuthTests.cs
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
// Tests for WebSocket JWT authentication during upgrade (E11).
|
||||||
|
// Verifies JWT extraction from Authorization header, cookie, and query parameter.
|
||||||
|
// Reference: golang/nats-server/server/websocket.go — cookie JWT extraction (line 856),
|
||||||
|
// websocket_test.go — TestWSReloadTLSConfig (line 4066).
|
||||||
|
|
||||||
|
using System.Text;
|
||||||
|
using NATS.Server.WebSocket;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.WebSocket;
|
||||||
|
|
||||||
|
public class WsJwtAuthTests
|
||||||
|
{
|
||||||
|
// ─── Authorization header JWT extraction ─────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_AuthorizationBearerHeader_ExtractsJwt()
|
||||||
|
{
|
||||||
|
// JWT from Authorization: Bearer <token> header (standard HTTP auth)
|
||||||
|
var jwt = "eyJhbGciOiJFZDI1NTE5IiwidHlwIjoiSldUIn0.test-payload.test-sig";
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
$"Authorization: Bearer {jwt}\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(jwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_AuthorizationBearerCaseInsensitive()
|
||||||
|
{
|
||||||
|
// RFC 7235: "bearer" scheme is case-insensitive
|
||||||
|
var jwt = "my-jwt-token-123";
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
$"Authorization: bearer {jwt}\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(jwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_AuthorizationBareToken_ExtractsJwt()
|
||||||
|
{
|
||||||
|
// Some clients send the token directly without "Bearer" prefix
|
||||||
|
var jwt = "raw-jwt-token-456";
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
$"Authorization: {jwt}\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(jwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Cookie JWT extraction ──────────────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_JwtCookie_ExtractsJwt()
|
||||||
|
{
|
||||||
|
// Go parity: websocket.go line 856 — JWT from configured cookie name
|
||||||
|
var jwt = "cookie-jwt-token-789";
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
$"Cookie: jwt={jwt}; other=value\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.CookieJwt.ShouldBe(jwt);
|
||||||
|
// Cookie JWT is used as fallback when no Authorization header is present
|
||||||
|
result.Jwt.ShouldBe(jwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_AuthorizationHeader_TakesPriorityOverCookie()
|
||||||
|
{
|
||||||
|
// Authorization header has higher priority than cookie
|
||||||
|
var headerJwt = "auth-header-jwt";
|
||||||
|
var cookieJwt = "cookie-jwt";
|
||||||
|
var request = BuildValidRequest(extraHeaders:
|
||||||
|
$"Authorization: Bearer {headerJwt}\r\n" +
|
||||||
|
$"Cookie: jwt={cookieJwt}\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(headerJwt);
|
||||||
|
result.CookieJwt.ShouldBe(cookieJwt); // Cookie value is still preserved
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Query parameter JWT extraction ─────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_QueryParamJwt_ExtractsJwt()
|
||||||
|
{
|
||||||
|
// JWT from ?jwt= query parameter (useful for browser clients)
|
||||||
|
var jwt = "query-jwt-token-abc";
|
||||||
|
var request = BuildValidRequest(path: $"/?jwt={jwt}");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(jwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_QueryParamJwt_UrlEncoded()
|
||||||
|
{
|
||||||
|
// JWT value may be URL-encoded
|
||||||
|
var jwt = "eyJ0eXAiOiJKV1QifQ.payload.sig";
|
||||||
|
var encoded = Uri.EscapeDataString(jwt);
|
||||||
|
var request = BuildValidRequest(path: $"/?jwt={encoded}");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(jwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_AuthorizationHeader_TakesPriorityOverQueryParam()
|
||||||
|
{
|
||||||
|
// Authorization header > query parameter
|
||||||
|
var headerJwt = "auth-header-jwt";
|
||||||
|
var queryJwt = "query-jwt";
|
||||||
|
var request = BuildValidRequest(
|
||||||
|
path: $"/?jwt={queryJwt}",
|
||||||
|
extraHeaders: $"Authorization: Bearer {headerJwt}\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(headerJwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_Cookie_TakesPriorityOverQueryParam()
|
||||||
|
{
|
||||||
|
// Cookie > query parameter
|
||||||
|
var cookieJwt = "cookie-jwt";
|
||||||
|
var queryJwt = "query-jwt";
|
||||||
|
var request = BuildValidRequest(
|
||||||
|
path: $"/?jwt={queryJwt}",
|
||||||
|
extraHeaders: $"Cookie: jwt_token={cookieJwt}\r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt_token" };
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBe(cookieJwt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── No JWT scenarios ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_NoJwtAnywhere_JwtIsNull()
|
||||||
|
{
|
||||||
|
// No JWT in any source
|
||||||
|
var request = BuildValidRequest();
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Jwt.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_EmptyAuthorizationHeader_JwtIsEmpty()
|
||||||
|
{
|
||||||
|
// Empty authorization header should produce empty string (non-null)
|
||||||
|
var request = BuildValidRequest(extraHeaders: "Authorization: \r\n");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
// Empty auth header is treated as null/no JWT
|
||||||
|
result.Jwt.ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── ExtractBearerToken unit tests ──────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ExtractBearerToken_BearerPrefix()
|
||||||
|
{
|
||||||
|
WsUpgrade.ExtractBearerToken("Bearer my-token").ShouldBe("my-token");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ExtractBearerToken_BearerPrefixLowerCase()
|
||||||
|
{
|
||||||
|
WsUpgrade.ExtractBearerToken("bearer my-token").ShouldBe("my-token");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ExtractBearerToken_BareToken()
|
||||||
|
{
|
||||||
|
WsUpgrade.ExtractBearerToken("raw-token").ShouldBe("raw-token");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ExtractBearerToken_Null()
|
||||||
|
{
|
||||||
|
WsUpgrade.ExtractBearerToken(null).ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ExtractBearerToken_Empty()
|
||||||
|
{
|
||||||
|
WsUpgrade.ExtractBearerToken("").ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ExtractBearerToken_Whitespace()
|
||||||
|
{
|
||||||
|
WsUpgrade.ExtractBearerToken(" ").ShouldBeNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── ParseQueryString unit tests ────────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ParseQueryString_SingleParam()
|
||||||
|
{
|
||||||
|
var result = WsUpgrade.ParseQueryString("?jwt=token123");
|
||||||
|
result["jwt"].ShouldBe("token123");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ParseQueryString_MultipleParams()
|
||||||
|
{
|
||||||
|
var result = WsUpgrade.ParseQueryString("?jwt=token&user=admin");
|
||||||
|
result["jwt"].ShouldBe("token");
|
||||||
|
result["user"].ShouldBe("admin");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ParseQueryString_UrlEncoded()
|
||||||
|
{
|
||||||
|
var result = WsUpgrade.ParseQueryString("?jwt=a%20b%3Dc");
|
||||||
|
result["jwt"].ShouldBe("a b=c");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ParseQueryString_NoQuestionMark()
|
||||||
|
{
|
||||||
|
var result = WsUpgrade.ParseQueryString("jwt=token");
|
||||||
|
result["jwt"].ShouldBe("token");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── FailUnauthorizedAsync ──────────────────────────────────────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FailUnauthorizedAsync_Returns401()
|
||||||
|
{
|
||||||
|
var output = new MemoryStream();
|
||||||
|
var result = await WsUpgrade.FailUnauthorizedAsync(output, "invalid JWT");
|
||||||
|
|
||||||
|
result.Success.ShouldBeFalse();
|
||||||
|
output.Position = 0;
|
||||||
|
var response = Encoding.ASCII.GetString(output.ToArray());
|
||||||
|
response.ShouldContain("401");
|
||||||
|
response.ShouldContain("invalid JWT");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Query param path routing still works with query strings ────────
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Upgrade_PathWithQueryParam_StillRoutesCorrectly()
|
||||||
|
{
|
||||||
|
// /leafnode?jwt=token should still detect as leaf kind
|
||||||
|
var request = BuildValidRequest(path: "/leafnode?jwt=my-token");
|
||||||
|
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||||
|
|
||||||
|
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||||
|
|
||||||
|
result.Success.ShouldBeTrue();
|
||||||
|
result.Kind.ShouldBe(WsClientKind.Leaf);
|
||||||
|
result.Jwt.ShouldBe("my-token");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
|
||||||
|
{
|
||||||
|
var sb = new StringBuilder();
|
||||||
|
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||||
|
sb.Append("Host: localhost:4222\r\n");
|
||||||
|
sb.Append("Upgrade: websocket\r\n");
|
||||||
|
sb.Append("Connection: Upgrade\r\n");
|
||||||
|
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||||
|
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||||
|
if (extraHeaders != null)
|
||||||
|
sb.Append(extraHeaders);
|
||||||
|
sb.Append("\r\n");
|
||||||
|
return sb.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||||
|
{
|
||||||
|
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||||
|
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user