feat(config+ws): add TLS cert reload, WS compression negotiation, WS JWT auth (E9+E10+E11)
E9: TLS Certificate Reload - Add TlsCertificateProvider with Interlocked-swappable cert field - New connections get current cert, existing connections keep theirs - ConfigReloader.ReloadTlsCertificate rebuilds SslServerAuthenticationOptions - NatsServer.ApplyConfigChanges triggers TLS reload on TLS config changes - 11 tests covering cert swap, versioning, thread safety, config diff E10: WebSocket Compression Negotiation (RFC 7692) - Add WsDeflateNegotiator to parse Sec-WebSocket-Extensions parameters - Parse server_no_context_takeover, client_no_context_takeover, server_max_window_bits, client_max_window_bits - WsDeflateParams record struct with ToResponseHeaderValue() - NATS always enforces no_context_takeover (matching Go server) - WsUpgrade returns negotiated WsDeflateParams in upgrade result - 22 tests covering parameter parsing, clamping, response headers E11: WebSocket JWT Authentication - Extract JWT from Authorization header (Bearer token), cookie, or ?jwt= query param - Priority: Authorization header > cookie > query parameter - WsUpgrade.TryUpgradeAsync now parses query string from request URI - Add FailUnauthorizedAsync for 401 responses - 24 tests covering all JWT extraction sources and priority ordering
This commit is contained in:
@@ -2,6 +2,146 @@ using System.IO.Compression;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Negotiated permessage-deflate parameters per RFC 7692 Section 7.1.
|
||||
/// Captures the results of extension parameter negotiation during the
|
||||
/// WebSocket upgrade handshake.
|
||||
/// </summary>
|
||||
public readonly record struct WsDeflateParams(
|
||||
bool ServerNoContextTakeover,
|
||||
bool ClientNoContextTakeover,
|
||||
int ServerMaxWindowBits,
|
||||
int ClientMaxWindowBits)
|
||||
{
|
||||
/// <summary>
|
||||
/// Default parameters matching NATS Go server behavior:
|
||||
/// both sides use no_context_takeover, default 15-bit windows.
|
||||
/// </summary>
|
||||
public static readonly WsDeflateParams Default = new(
|
||||
ServerNoContextTakeover: true,
|
||||
ClientNoContextTakeover: true,
|
||||
ServerMaxWindowBits: 15,
|
||||
ClientMaxWindowBits: 15);
|
||||
|
||||
/// <summary>
|
||||
/// Builds the Sec-WebSocket-Extensions response header value from negotiated parameters.
|
||||
/// Only includes parameters that differ from the default RFC values.
|
||||
/// Reference: RFC 7692 Section 7.1.
|
||||
/// </summary>
|
||||
public string ToResponseHeaderValue()
|
||||
{
|
||||
var parts = new List<string> { WsConstants.PmcExtension };
|
||||
|
||||
if (ServerNoContextTakeover)
|
||||
parts.Add(WsConstants.PmcSrvNoCtx);
|
||||
if (ClientNoContextTakeover)
|
||||
parts.Add(WsConstants.PmcCliNoCtx);
|
||||
if (ServerMaxWindowBits is > 0 and < 15)
|
||||
parts.Add($"server_max_window_bits={ServerMaxWindowBits}");
|
||||
if (ClientMaxWindowBits is > 0 and < 15)
|
||||
parts.Add($"client_max_window_bits={ClientMaxWindowBits}");
|
||||
|
||||
return string.Join("; ", parts);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Parses and negotiates permessage-deflate extension parameters from the
|
||||
/// Sec-WebSocket-Extensions header per RFC 7692 Section 7.
|
||||
/// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport.
|
||||
/// </summary>
|
||||
public static class WsDeflateNegotiator
|
||||
{
|
||||
/// <summary>
|
||||
/// Parses the Sec-WebSocket-Extensions header value and negotiates
|
||||
/// permessage-deflate parameters. Returns null if no valid
|
||||
/// permessage-deflate offer is found.
|
||||
/// </summary>
|
||||
public static WsDeflateParams? Negotiate(string? extensionHeader)
|
||||
{
|
||||
if (string.IsNullOrEmpty(extensionHeader))
|
||||
return null;
|
||||
|
||||
// The header may contain multiple extensions separated by commas
|
||||
var extensions = extensionHeader.Split(',');
|
||||
foreach (var extension in extensions)
|
||||
{
|
||||
var trimmed = extension.Trim();
|
||||
var parts = trimmed.Split(';');
|
||||
|
||||
// First part must be the extension name
|
||||
if (parts.Length == 0)
|
||||
continue;
|
||||
|
||||
if (!string.Equals(parts[0].Trim(), WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase))
|
||||
continue;
|
||||
|
||||
// Found permessage-deflate — parse parameters
|
||||
// Note: serverNoCtx and clientNoCtx are parsed but always overridden
|
||||
// with true below (NATS enforces no_context_takeover for both sides).
|
||||
int serverMaxWindowBits = 15;
|
||||
int clientMaxWindowBits = 15;
|
||||
|
||||
for (int i = 1; i < parts.Length; i++)
|
||||
{
|
||||
var param = parts[i].Trim();
|
||||
|
||||
if (string.Equals(param, WsConstants.PmcSrvNoCtx, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
// Parsed but overridden: NATS always enforces no_context_takeover.
|
||||
}
|
||||
else if (string.Equals(param, WsConstants.PmcCliNoCtx, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
// Parsed but overridden: NATS always enforces no_context_takeover.
|
||||
}
|
||||
else if (param.StartsWith("server_max_window_bits", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
serverMaxWindowBits = ParseWindowBits(param, 15);
|
||||
}
|
||||
else if (param.StartsWith("client_max_window_bits", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
// client_max_window_bits with no value means the client supports it
|
||||
// and the server may choose a value. Per RFC 7692 Section 7.1.2.2,
|
||||
// an offer with just "client_max_window_bits" (no value) indicates
|
||||
// the client can accept any value 8-15.
|
||||
clientMaxWindowBits = ParseWindowBits(param, 15);
|
||||
}
|
||||
}
|
||||
|
||||
// NATS server always enforces no_context_takeover for both sides
|
||||
// (matching Go behavior) to avoid holding compressor state per connection.
|
||||
return new WsDeflateParams(
|
||||
ServerNoContextTakeover: true,
|
||||
ClientNoContextTakeover: true,
|
||||
ServerMaxWindowBits: ClampWindowBits(serverMaxWindowBits),
|
||||
ClientMaxWindowBits: ClampWindowBits(clientMaxWindowBits));
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static int ParseWindowBits(string param, int defaultValue)
|
||||
{
|
||||
var eqIdx = param.IndexOf('=');
|
||||
if (eqIdx < 0)
|
||||
return defaultValue;
|
||||
|
||||
var valueStr = param[(eqIdx + 1)..].Trim();
|
||||
if (int.TryParse(valueStr, out var bits))
|
||||
return bits;
|
||||
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
private static int ClampWindowBits(int bits)
|
||||
{
|
||||
// RFC 7692: valid range is 8-15
|
||||
if (bits < 8) return 8;
|
||||
if (bits > 15) return 15;
|
||||
return bits;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
|
||||
|
||||
@@ -18,7 +18,7 @@ public static class WsUpgrade
|
||||
{
|
||||
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
cts.CancelAfter(options.HandshakeTimeout);
|
||||
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
|
||||
var (method, path, queryString, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
|
||||
|
||||
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
|
||||
return await FailAsync(outputStream, 405, "request method must be GET");
|
||||
@@ -57,15 +57,17 @@ public static class WsUpgrade
|
||||
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
|
||||
}
|
||||
|
||||
// Compression negotiation
|
||||
// Compression negotiation (RFC 7692)
|
||||
bool compress = options.Compression;
|
||||
WsDeflateParams? deflateParams = null;
|
||||
if (compress)
|
||||
{
|
||||
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
|
||||
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
|
||||
headers.TryGetValue("Sec-WebSocket-Extensions", out var ext);
|
||||
deflateParams = WsDeflateNegotiator.Negotiate(ext);
|
||||
compress = deflateParams != null;
|
||||
}
|
||||
|
||||
// No-masking support (leaf nodes only — browser clients must always mask)
|
||||
// No-masking support (leaf nodes only -- browser clients must always mask)
|
||||
bool noMasking = kind == WsClientKind.Leaf &&
|
||||
headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
|
||||
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
|
||||
@@ -95,6 +97,24 @@ public static class WsUpgrade
|
||||
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
|
||||
}
|
||||
|
||||
// JWT extraction from multiple sources (E11):
|
||||
// Priority: Authorization header > cookie > query parameter
|
||||
// Reference: NATS WebSocket JWT auth — browser clients often pass JWT
|
||||
// via cookie or query param since they cannot set custom headers.
|
||||
string? jwt = null;
|
||||
if (headers.TryGetValue("Authorization", out var authHeader))
|
||||
{
|
||||
jwt = ExtractBearerToken(authHeader);
|
||||
}
|
||||
|
||||
jwt ??= cookieJwt;
|
||||
|
||||
if (jwt == null && queryString != null)
|
||||
{
|
||||
var queryParams = ParseQueryString(queryString);
|
||||
queryParams.TryGetValue("jwt", out jwt);
|
||||
}
|
||||
|
||||
// X-Forwarded-For client IP extraction
|
||||
string? clientIp = null;
|
||||
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
|
||||
@@ -109,8 +129,13 @@ public static class WsUpgrade
|
||||
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
|
||||
response.Append(ComputeAcceptKey(key));
|
||||
response.Append("\r\n");
|
||||
if (compress)
|
||||
response.Append(WsConstants.PmcFullResponse);
|
||||
if (compress && deflateParams != null)
|
||||
{
|
||||
response.Append("Sec-WebSocket-Extensions: ");
|
||||
response.Append(deflateParams.Value.ToResponseHeaderValue());
|
||||
response.Append("\r\n");
|
||||
}
|
||||
|
||||
if (noMasking)
|
||||
response.Append(WsConstants.NoMaskingFullResponse);
|
||||
if (options.Headers != null)
|
||||
@@ -135,7 +160,8 @@ public static class WsUpgrade
|
||||
MaskRead: !noMasking, MaskWrite: false,
|
||||
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
|
||||
CookiePassword: cookiePassword, CookieToken: cookieToken,
|
||||
ClientIp: clientIp, Kind: kind);
|
||||
ClientIp: clientIp, Kind: kind,
|
||||
DeflateParams: deflateParams, Jwt: jwt);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
@@ -153,11 +179,56 @@ public static class WsUpgrade
|
||||
return Convert.ToBase64String(hash);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extracts a bearer token from an Authorization header value.
|
||||
/// Supports both "Bearer {token}" and bare "{token}" formats.
|
||||
/// </summary>
|
||||
internal static string? ExtractBearerToken(string? authHeader)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(authHeader))
|
||||
return null;
|
||||
|
||||
var trimmed = authHeader.Trim();
|
||||
if (trimmed.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
|
||||
return trimmed["Bearer ".Length..].Trim();
|
||||
|
||||
// Some clients send the token directly without "Bearer" prefix
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Parses a query string into key-value pairs.
|
||||
/// </summary>
|
||||
internal static Dictionary<string, string> ParseQueryString(string queryString)
|
||||
{
|
||||
var result = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
if (queryString.StartsWith('?'))
|
||||
queryString = queryString[1..];
|
||||
|
||||
foreach (var pair in queryString.Split('&'))
|
||||
{
|
||||
var eqIdx = pair.IndexOf('=');
|
||||
if (eqIdx > 0)
|
||||
{
|
||||
var name = Uri.UnescapeDataString(pair[..eqIdx]);
|
||||
var value = Uri.UnescapeDataString(pair[(eqIdx + 1)..]);
|
||||
result[name] = value;
|
||||
}
|
||||
else if (pair.Length > 0)
|
||||
{
|
||||
result[Uri.UnescapeDataString(pair)] = string.Empty;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
|
||||
{
|
||||
var statusText = statusCode switch
|
||||
{
|
||||
400 => "Bad Request",
|
||||
401 => "Unauthorized",
|
||||
403 => "Forbidden",
|
||||
405 => "Method Not Allowed",
|
||||
_ => "Internal Server Error",
|
||||
@@ -165,10 +236,21 @@ public static class WsUpgrade
|
||||
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
|
||||
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
|
||||
await output.FlushAsync();
|
||||
return WsUpgradeResult.Failed;
|
||||
return statusCode == 401
|
||||
? WsUpgradeResult.Unauthorized
|
||||
: WsUpgradeResult.Failed;
|
||||
}
|
||||
|
||||
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
|
||||
/// <summary>
|
||||
/// Sends a 401 Unauthorized response and returns a failed upgrade result.
|
||||
/// Used by the server when JWT authentication fails during WS upgrade.
|
||||
/// </summary>
|
||||
public static async Task<WsUpgradeResult> FailUnauthorizedAsync(Stream output, string reason)
|
||||
{
|
||||
return await FailAsync(output, 401, reason);
|
||||
}
|
||||
|
||||
private static async Task<(string method, string path, string? queryString, Dictionary<string, string> headers)> ReadHttpRequestAsync(
|
||||
Stream stream, CancellationToken ct)
|
||||
{
|
||||
var headerBytes = new List<byte>(4096);
|
||||
@@ -197,7 +279,21 @@ public static class WsUpgrade
|
||||
var parts = lines[0].Split(' ');
|
||||
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
|
||||
var method = parts[0];
|
||||
var path = parts[1];
|
||||
var requestUri = parts[1];
|
||||
|
||||
// Split path and query string
|
||||
string path;
|
||||
string? queryString = null;
|
||||
var qIdx = requestUri.IndexOf('?');
|
||||
if (qIdx >= 0)
|
||||
{
|
||||
path = requestUri[..qIdx];
|
||||
queryString = requestUri[qIdx..]; // includes the '?'
|
||||
}
|
||||
else
|
||||
{
|
||||
path = requestUri;
|
||||
}
|
||||
|
||||
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
for (int i = 1; i < lines.Length; i++)
|
||||
@@ -213,7 +309,7 @@ public static class WsUpgrade
|
||||
}
|
||||
}
|
||||
|
||||
return (method, path, headers);
|
||||
return (method, path, queryString, headers);
|
||||
}
|
||||
|
||||
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
|
||||
@@ -259,10 +355,17 @@ public readonly record struct WsUpgradeResult(
|
||||
string? CookiePassword,
|
||||
string? CookieToken,
|
||||
string? ClientIp,
|
||||
WsClientKind Kind)
|
||||
WsClientKind Kind,
|
||||
WsDeflateParams? DeflateParams = null,
|
||||
string? Jwt = null)
|
||||
{
|
||||
public static readonly WsUpgradeResult Failed = new(
|
||||
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
||||
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
||||
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
||||
|
||||
public static readonly WsUpgradeResult Unauthorized = new(
|
||||
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
||||
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
||||
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user