From 1c948b5b0f6b2d2de491b1bb87eea87914ddd374 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 23 Feb 2026 04:53:21 -0500 Subject: [PATCH] feat: add WebSocket HTTP upgrade handshake --- src/NATS.Server/WebSocket/WsUpgrade.cs | 259 ++++++++++++++++++ .../WebSocket/WsUpgradeTests.cs | 226 +++++++++++++++ 2 files changed, 485 insertions(+) create mode 100644 src/NATS.Server/WebSocket/WsUpgrade.cs create mode 100644 tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs diff --git a/src/NATS.Server/WebSocket/WsUpgrade.cs b/src/NATS.Server/WebSocket/WsUpgrade.cs new file mode 100644 index 0000000..ba91191 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsUpgrade.cs @@ -0,0 +1,259 @@ +using System.Net; +using System.Security.Cryptography; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// WebSocket HTTP upgrade handshake handler. +/// Ported from golang/nats-server/server/websocket.go lines 731-917. +/// +public static class WsUpgrade +{ + public static async Task TryUpgradeAsync( + Stream inputStream, Stream outputStream, WebSocketOptions options) + { + try + { + var (method, path, headers) = await ReadHttpRequestAsync(inputStream); + + if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase)) + return await FailAsync(outputStream, 405, "request method must be GET"); + + if (!headers.ContainsKey("Host")) + return await FailAsync(outputStream, 400, "'Host' missing in request"); + + if (!HeaderContains(headers, "Upgrade", "websocket")) + return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'"); + + if (!HeaderContains(headers, "Connection", "Upgrade")) + return await FailAsync(outputStream, 400, "invalid value for header 'Connection'"); + + if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key)) + return await FailAsync(outputStream, 400, "key missing"); + + if (!HeaderContains(headers, "Sec-WebSocket-Version", "13")) + return await FailAsync(outputStream, 400, "invalid version"); + + var kind = path switch + { + _ when path.EndsWith("/leafnode") => WsClientKind.Leaf, + _ when path.EndsWith("/mqtt") => WsClientKind.Mqtt, + _ => WsClientKind.Client, + }; + + // Origin checking + if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 }) + { + var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins); + headers.TryGetValue("Origin", out var origin); + if (string.IsNullOrEmpty(origin)) + headers.TryGetValue("Sec-WebSocket-Origin", out origin); + var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false); + if (originErr != null) + return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}"); + } + + // Compression negotiation + bool compress = options.Compression; + if (compress) + { + compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) && + ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase); + } + + // No-masking support + bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) && + string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase); + + // Browser detection + bool browser = false; + bool noCompFrag = false; + if (kind is WsClientKind.Client or WsClientKind.Mqtt && + headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/")) + { + browser = true; + // Disable fragmentation of compressed frames for Safari browsers. + // Safari has both "Version/" and "Safari/" in the user agent string, + // while Chrome on macOS has "Safari/" but not "Version/". + noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/"); + } + + // Cookie extraction + string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null; + if ((kind is WsClientKind.Client or WsClientKind.Mqtt) && + headers.TryGetValue("Cookie", out var cookieHeader)) + { + var cookies = ParseCookies(cookieHeader); + if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt); + if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername); + if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword); + if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken); + } + + // X-Forwarded-For client IP extraction + string? clientIp = null; + if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff)) + { + var ip = xff.Split(',')[0].Trim(); + if (IPAddress.TryParse(ip, out _)) + clientIp = ip; + } + + // Build the 101 Switching Protocols response + var response = new StringBuilder(); + response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "); + response.Append(ComputeAcceptKey(key)); + response.Append("\r\n"); + if (compress) + response.Append(WsConstants.PmcFullResponse); + if (noMasking) + response.Append(WsConstants.NoMaskingFullResponse); + if (options.Headers != null) + { + foreach (var (k, v) in options.Headers) + { + response.Append(k); + response.Append(": "); + response.Append(v); + response.Append("\r\n"); + } + } + + response.Append("\r\n"); + + var responseBytes = Encoding.ASCII.GetBytes(response.ToString()); + await outputStream.WriteAsync(responseBytes); + await outputStream.FlushAsync(); + + return new WsUpgradeResult( + Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag, + MaskRead: !noMasking, MaskWrite: false, + CookieJwt: cookieJwt, CookieUsername: cookieUsername, + CookiePassword: cookiePassword, CookieToken: cookieToken, + ClientIp: clientIp, Kind: kind); + } + catch (Exception) + { + return WsUpgradeResult.Failed; + } + } + + /// + /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2. + /// + public static string ComputeAcceptKey(string clientKey) + { + var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + var hash = SHA1.HashData(combined); + return Convert.ToBase64String(hash); + } + + private static async Task FailAsync(Stream output, int statusCode, string reason) + { + var statusText = statusCode switch + { + 400 => "Bad Request", + 403 => "Forbidden", + 405 => "Method Not Allowed", + _ => "Internal Server Error", + }; + var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}"; + await output.WriteAsync(Encoding.ASCII.GetBytes(response)); + await output.FlushAsync(); + return WsUpgradeResult.Failed; + } + + private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync(Stream stream) + { + var headerBytes = new List(4096); + var buf = new byte[1]; + while (true) + { + int n = await stream.ReadAsync(buf); + if (n == 0) throw new IOException("connection closed during handshake"); + headerBytes.Add(buf[0]); + if (headerBytes.Count >= 4 && + headerBytes[^4] == '\r' && headerBytes[^3] == '\n' && + headerBytes[^2] == '\r' && headerBytes[^1] == '\n') + break; + if (headerBytes.Count > 8192) + throw new InvalidOperationException("HTTP header too large"); + } + + var text = Encoding.ASCII.GetString(headerBytes.ToArray()); + var lines = text.Split("\r\n", StringSplitOptions.None); + if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request"); + + var parts = lines[0].Split(' '); + if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line"); + var method = parts[0]; + var path = parts[1]; + + var headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + for (int i = 1; i < lines.Length; i++) + { + var line = lines[i]; + if (string.IsNullOrEmpty(line)) break; + var colonIdx = line.IndexOf(':'); + if (colonIdx > 0) + { + var name = line[..colonIdx].Trim(); + var value = line[(colonIdx + 1)..].Trim(); + headers[name] = value; + } + } + + return (method, path, headers); + } + + private static bool HeaderContains(Dictionary headers, string name, string value) + { + if (!headers.TryGetValue(name, out var headerValue)) + return false; + foreach (var token in headerValue.Split(',')) + { + if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase)) + return true; + } + + return false; + } + + private static Dictionary ParseCookies(string cookieHeader) + { + var cookies = new Dictionary(StringComparer.Ordinal); + foreach (var pair in cookieHeader.Split(';')) + { + var trimmed = pair.Trim(); + var eqIdx = trimmed.IndexOf('='); + if (eqIdx > 0) + cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim(); + } + + return cookies; + } +} + +/// +/// Result of a WebSocket upgrade handshake attempt. +/// +public readonly record struct WsUpgradeResult( + bool Success, + bool Compress, + bool Browser, + bool NoCompFrag, + bool MaskRead, + bool MaskWrite, + string? CookieJwt, + string? CookieUsername, + string? CookiePassword, + string? CookieToken, + string? ClientIp, + WsClientKind Kind) +{ + public static readonly WsUpgradeResult Failed = new( + Success: false, Compress: false, Browser: false, NoCompFrag: false, + MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null, + CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client); +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs b/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs new file mode 100644 index 0000000..a5e1168 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs @@ -0,0 +1,226 @@ +using System.Text; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsUpgradeTests +{ + private static string BuildValidRequest(string path = "/", string? extraHeaders = null) + { + var sb = new StringBuilder(); + sb.Append($"GET {path} HTTP/1.1\r\n"); + sb.Append("Host: localhost:4222\r\n"); + sb.Append("Upgrade: websocket\r\n"); + sb.Append("Connection: Upgrade\r\n"); + sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"); + sb.Append("Sec-WebSocket-Version: 13\r\n"); + if (extraHeaders != null) + sb.Append(extraHeaders); + sb.Append("\r\n"); + return sb.ToString(); + } + + [Fact] + public async Task ValidUpgrade_Returns101() + { + var request = BuildValidRequest(); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Client); + var response = ReadResponse(outputStream); + response.ShouldContain("HTTP/1.1 101"); + response.ShouldContain("Upgrade: websocket"); + response.ShouldContain("Sec-WebSocket-Accept:"); + } + + [Fact] + public async Task MissingUpgradeHeader_Returns400() + { + var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + ReadResponse(outputStream).ShouldContain("400"); + } + + [Fact] + public async Task MissingHost_Returns400() + { + var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + } + + [Fact] + public async Task WrongVersion_Returns400() + { + var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + } + + [Fact] + public async Task LeafNodePath_ReturnsLeafKind() + { + var request = BuildValidRequest("/leafnode"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Leaf); + } + + [Fact] + public async Task MqttPath_ReturnsMqttKind() + { + var request = BuildValidRequest("/mqtt"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Mqtt); + } + + [Fact] + public async Task CompressionNegotiation_WhenEnabled() + { + var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true }); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeTrue(); + ReadResponse(outputStream).ShouldContain("permessage-deflate"); + } + + [Fact] + public async Task CompressionNegotiation_WhenDisabled() + { + var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false }); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeFalse(); + } + + [Fact] + public async Task NoMaskingHeader_ForLeaf() + { + var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.MaskRead.ShouldBeFalse(); + } + + [Fact] + public async Task BrowserDetection_Mozilla() + { + var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Browser.ShouldBeTrue(); + } + + [Fact] + public async Task SafariDetection_NoCompFrag() + { + var request = BuildValidRequest(extraHeaders: + "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" + + $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true }); + + result.Success.ShouldBeTrue(); + result.NoCompFrag.ShouldBeTrue(); + } + + [Fact] + public void AcceptKey_MatchesRfc6455Example() + { + // RFC 6455 Section 4.2.2 example + var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + [Fact] + public async Task CookieExtraction() + { + var request = BuildValidRequest(extraHeaders: + "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions + { + NoTls = true, + JwtCookie = "jwt_token", + UsernameCookie = "nats_user", + PasswordCookie = "nats_pass", + }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.CookieJwt.ShouldBe("my-jwt"); + result.CookieUsername.ShouldBe("admin"); + result.CookiePassword.ShouldBe("secret"); + } + + [Fact] + public async Task XForwardedFor_ExtractsClientIp() + { + var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.ClientIp.ShouldBe("192.168.1.100"); + } + + [Fact] + public async Task PostMethod_Returns405() + { + var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + ReadResponse(outputStream).ShouldContain("405"); + } + + // Helper: create a readable input stream and writable output stream + private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest) + { + var inputBytes = Encoding.ASCII.GetBytes(httpRequest); + return (new MemoryStream(inputBytes), new MemoryStream()); + } + + private static string ReadResponse(MemoryStream output) + { + output.Position = 0; + return Encoding.ASCII.GetString(output.ToArray()); + } +}