diff --git a/src/NATS.Server/WebSocket/WsUpgrade.cs b/src/NATS.Server/WebSocket/WsUpgrade.cs
new file mode 100644
index 0000000..ba91191
--- /dev/null
+++ b/src/NATS.Server/WebSocket/WsUpgrade.cs
@@ -0,0 +1,259 @@
+using System.Net;
+using System.Security.Cryptography;
+using System.Text;
+
+namespace NATS.Server.WebSocket;
+
+///
+/// WebSocket HTTP upgrade handshake handler.
+/// Ported from golang/nats-server/server/websocket.go lines 731-917.
+///
+public static class WsUpgrade
+{
+ public static async Task TryUpgradeAsync(
+ Stream inputStream, Stream outputStream, WebSocketOptions options)
+ {
+ try
+ {
+ var (method, path, headers) = await ReadHttpRequestAsync(inputStream);
+
+ if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
+ return await FailAsync(outputStream, 405, "request method must be GET");
+
+ if (!headers.ContainsKey("Host"))
+ return await FailAsync(outputStream, 400, "'Host' missing in request");
+
+ if (!HeaderContains(headers, "Upgrade", "websocket"))
+ return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");
+
+ if (!HeaderContains(headers, "Connection", "Upgrade"))
+ return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");
+
+ if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
+ return await FailAsync(outputStream, 400, "key missing");
+
+ if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
+ return await FailAsync(outputStream, 400, "invalid version");
+
+ var kind = path switch
+ {
+ _ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
+ _ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
+ _ => WsClientKind.Client,
+ };
+
+ // Origin checking
+ if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
+ {
+ var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
+ headers.TryGetValue("Origin", out var origin);
+ if (string.IsNullOrEmpty(origin))
+ headers.TryGetValue("Sec-WebSocket-Origin", out origin);
+ var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
+ if (originErr != null)
+ return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
+ }
+
+ // Compression negotiation
+ bool compress = options.Compression;
+ if (compress)
+ {
+ compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
+ ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
+ }
+
+ // No-masking support
+ bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
+ string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
+
+ // Browser detection
+ bool browser = false;
+ bool noCompFrag = false;
+ if (kind is WsClientKind.Client or WsClientKind.Mqtt &&
+ headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
+ {
+ browser = true;
+ // Disable fragmentation of compressed frames for Safari browsers.
+ // Safari has both "Version/" and "Safari/" in the user agent string,
+ // while Chrome on macOS has "Safari/" but not "Version/".
+ noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
+ }
+
+ // Cookie extraction
+ string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
+ if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
+ headers.TryGetValue("Cookie", out var cookieHeader))
+ {
+ var cookies = ParseCookies(cookieHeader);
+ if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
+ if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
+ if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
+ if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
+ }
+
+ // X-Forwarded-For client IP extraction
+ string? clientIp = null;
+ if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
+ {
+ var ip = xff.Split(',')[0].Trim();
+ if (IPAddress.TryParse(ip, out _))
+ clientIp = ip;
+ }
+
+ // Build the 101 Switching Protocols response
+ var response = new StringBuilder();
+ response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
+ response.Append(ComputeAcceptKey(key));
+ response.Append("\r\n");
+ if (compress)
+ response.Append(WsConstants.PmcFullResponse);
+ if (noMasking)
+ response.Append(WsConstants.NoMaskingFullResponse);
+ if (options.Headers != null)
+ {
+ foreach (var (k, v) in options.Headers)
+ {
+ response.Append(k);
+ response.Append(": ");
+ response.Append(v);
+ response.Append("\r\n");
+ }
+ }
+
+ response.Append("\r\n");
+
+ var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
+ await outputStream.WriteAsync(responseBytes);
+ await outputStream.FlushAsync();
+
+ return new WsUpgradeResult(
+ Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
+ MaskRead: !noMasking, MaskWrite: false,
+ CookieJwt: cookieJwt, CookieUsername: cookieUsername,
+ CookiePassword: cookiePassword, CookieToken: cookieToken,
+ ClientIp: clientIp, Kind: kind);
+ }
+ catch (Exception)
+ {
+ return WsUpgradeResult.Failed;
+ }
+ }
+
+ ///
+ /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
+ ///
+ public static string ComputeAcceptKey(string clientKey)
+ {
+ var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+ var hash = SHA1.HashData(combined);
+ return Convert.ToBase64String(hash);
+ }
+
+ private static async Task FailAsync(Stream output, int statusCode, string reason)
+ {
+ var statusText = statusCode switch
+ {
+ 400 => "Bad Request",
+ 403 => "Forbidden",
+ 405 => "Method Not Allowed",
+ _ => "Internal Server Error",
+ };
+ var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
+ await output.WriteAsync(Encoding.ASCII.GetBytes(response));
+ await output.FlushAsync();
+ return WsUpgradeResult.Failed;
+ }
+
+ private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync(Stream stream)
+ {
+ var headerBytes = new List(4096);
+ var buf = new byte[1];
+ while (true)
+ {
+ int n = await stream.ReadAsync(buf);
+ if (n == 0) throw new IOException("connection closed during handshake");
+ headerBytes.Add(buf[0]);
+ if (headerBytes.Count >= 4 &&
+ headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
+ headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
+ break;
+ if (headerBytes.Count > 8192)
+ throw new InvalidOperationException("HTTP header too large");
+ }
+
+ var text = Encoding.ASCII.GetString(headerBytes.ToArray());
+ var lines = text.Split("\r\n", StringSplitOptions.None);
+ if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");
+
+ var parts = lines[0].Split(' ');
+ if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
+ var method = parts[0];
+ var path = parts[1];
+
+ var headers = new Dictionary(StringComparer.OrdinalIgnoreCase);
+ for (int i = 1; i < lines.Length; i++)
+ {
+ var line = lines[i];
+ if (string.IsNullOrEmpty(line)) break;
+ var colonIdx = line.IndexOf(':');
+ if (colonIdx > 0)
+ {
+ var name = line[..colonIdx].Trim();
+ var value = line[(colonIdx + 1)..].Trim();
+ headers[name] = value;
+ }
+ }
+
+ return (method, path, headers);
+ }
+
+ private static bool HeaderContains(Dictionary headers, string name, string value)
+ {
+ if (!headers.TryGetValue(name, out var headerValue))
+ return false;
+ foreach (var token in headerValue.Split(','))
+ {
+ if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
+ return true;
+ }
+
+ return false;
+ }
+
+ private static Dictionary ParseCookies(string cookieHeader)
+ {
+ var cookies = new Dictionary(StringComparer.Ordinal);
+ foreach (var pair in cookieHeader.Split(';'))
+ {
+ var trimmed = pair.Trim();
+ var eqIdx = trimmed.IndexOf('=');
+ if (eqIdx > 0)
+ cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
+ }
+
+ return cookies;
+ }
+}
+
+///
+/// Result of a WebSocket upgrade handshake attempt.
+///
+public readonly record struct WsUpgradeResult(
+ bool Success,
+ bool Compress,
+ bool Browser,
+ bool NoCompFrag,
+ bool MaskRead,
+ bool MaskWrite,
+ string? CookieJwt,
+ string? CookieUsername,
+ string? CookiePassword,
+ string? CookieToken,
+ string? ClientIp,
+ WsClientKind Kind)
+{
+ public static readonly WsUpgradeResult Failed = new(
+ Success: false, Compress: false, Browser: false, NoCompFrag: false,
+ MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
+ CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
+}
diff --git a/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs b/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs
new file mode 100644
index 0000000..a5e1168
--- /dev/null
+++ b/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs
@@ -0,0 +1,226 @@
+using System.Text;
+using NATS.Server.WebSocket;
+
+namespace NATS.Server.Tests.WebSocket;
+
+public class WsUpgradeTests
+{
+ private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
+ {
+ var sb = new StringBuilder();
+ sb.Append($"GET {path} HTTP/1.1\r\n");
+ sb.Append("Host: localhost:4222\r\n");
+ sb.Append("Upgrade: websocket\r\n");
+ sb.Append("Connection: Upgrade\r\n");
+ sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
+ sb.Append("Sec-WebSocket-Version: 13\r\n");
+ if (extraHeaders != null)
+ sb.Append(extraHeaders);
+ sb.Append("\r\n");
+ return sb.ToString();
+ }
+
+ [Fact]
+ public async Task ValidUpgrade_Returns101()
+ {
+ var request = BuildValidRequest();
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeTrue();
+ result.Kind.ShouldBe(WsClientKind.Client);
+ var response = ReadResponse(outputStream);
+ response.ShouldContain("HTTP/1.1 101");
+ response.ShouldContain("Upgrade: websocket");
+ response.ShouldContain("Sec-WebSocket-Accept:");
+ }
+
+ [Fact]
+ public async Task MissingUpgradeHeader_Returns400()
+ {
+ var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeFalse();
+ ReadResponse(outputStream).ShouldContain("400");
+ }
+
+ [Fact]
+ public async Task MissingHost_Returns400()
+ {
+ var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeFalse();
+ }
+
+ [Fact]
+ public async Task WrongVersion_Returns400()
+ {
+ var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeFalse();
+ }
+
+ [Fact]
+ public async Task LeafNodePath_ReturnsLeafKind()
+ {
+ var request = BuildValidRequest("/leafnode");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeTrue();
+ result.Kind.ShouldBe(WsClientKind.Leaf);
+ }
+
+ [Fact]
+ public async Task MqttPath_ReturnsMqttKind()
+ {
+ var request = BuildValidRequest("/mqtt");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeTrue();
+ result.Kind.ShouldBe(WsClientKind.Mqtt);
+ }
+
+ [Fact]
+ public async Task CompressionNegotiation_WhenEnabled()
+ {
+ var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
+
+ result.Success.ShouldBeTrue();
+ result.Compress.ShouldBeTrue();
+ ReadResponse(outputStream).ShouldContain("permessage-deflate");
+ }
+
+ [Fact]
+ public async Task CompressionNegotiation_WhenDisabled()
+ {
+ var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
+
+ result.Success.ShouldBeTrue();
+ result.Compress.ShouldBeFalse();
+ }
+
+ [Fact]
+ public async Task NoMaskingHeader_ForLeaf()
+ {
+ var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeTrue();
+ result.MaskRead.ShouldBeFalse();
+ }
+
+ [Fact]
+ public async Task BrowserDetection_Mozilla()
+ {
+ var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeTrue();
+ result.Browser.ShouldBeTrue();
+ }
+
+ [Fact]
+ public async Task SafariDetection_NoCompFrag()
+ {
+ var request = BuildValidRequest(extraHeaders:
+ "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
+ $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
+
+ result.Success.ShouldBeTrue();
+ result.NoCompFrag.ShouldBeTrue();
+ }
+
+ [Fact]
+ public void AcceptKey_MatchesRfc6455Example()
+ {
+ // RFC 6455 Section 4.2.2 example
+ var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
+ key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
+ }
+
+ [Fact]
+ public async Task CookieExtraction()
+ {
+ var request = BuildValidRequest(extraHeaders:
+ "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var opts = new WebSocketOptions
+ {
+ NoTls = true,
+ JwtCookie = "jwt_token",
+ UsernameCookie = "nats_user",
+ PasswordCookie = "nats_pass",
+ };
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
+
+ result.Success.ShouldBeTrue();
+ result.CookieJwt.ShouldBe("my-jwt");
+ result.CookieUsername.ShouldBe("admin");
+ result.CookiePassword.ShouldBe("secret");
+ }
+
+ [Fact]
+ public async Task XForwardedFor_ExtractsClientIp()
+ {
+ var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeTrue();
+ result.ClientIp.ShouldBe("192.168.1.100");
+ }
+
+ [Fact]
+ public async Task PostMethod_Returns405()
+ {
+ var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
+ var (inputStream, outputStream) = CreateStreamPair(request);
+
+ var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
+
+ result.Success.ShouldBeFalse();
+ ReadResponse(outputStream).ShouldContain("405");
+ }
+
+ // Helper: create a readable input stream and writable output stream
+ private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
+ {
+ var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
+ return (new MemoryStream(inputBytes), new MemoryStream());
+ }
+
+ private static string ReadResponse(MemoryStream output)
+ {
+ output.Position = 0;
+ return Encoding.ASCII.GetString(output.ToArray());
+ }
+}