using System.Net;
using System.Security.Cryptography;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
    // Upper bound, in bytes, on the request line plus headers; larger requests are rejected.
    private const int MaxHeaderBytes = 8192;

    // GUID concatenated with the client key to derive Sec-WebSocket-Accept (RFC 6455).
    private const string AcceptKeyGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// <summary>
    /// Reads an HTTP upgrade request from <paramref name="inputStream"/> and validates it as a
    /// WebSocket handshake. On success a <c>101 Switching Protocols</c> response is written to
    /// <paramref name="outputStream"/>; on a validation failure an HTTP error response is written instead.
    /// </summary>
    /// <param name="inputStream">Stream the HTTP request head is read from.</param>
    /// <param name="outputStream">Stream the HTTP response is written to.</param>
    /// <param name="options">
    /// WebSocket configuration: origin policy, compression, auth cookie names and extra response headers.
    /// </param>
    /// <param name="isTls">
    /// Passed to <c>WsOriginChecker.CheckOrigin</c> as its <c>isTls</c> argument when origin checking is
    /// enabled. Callers should pass <see langword="true"/> for TLS connections; the default
    /// (<see langword="false"/>) matches the previous hard-coded behavior.
    /// </param>
    /// <param name="cancellationToken">Token that aborts reading the request or writing the response.</param>
    /// <returns>
    /// The negotiated connection parameters, or <see cref="WsUpgradeResult.Failed"/> when the request was
    /// rejected, the peer closed the connection, an error occurred, or the operation was canceled.
    /// Exceptions are not propagated to the caller.
    /// </returns>
    public static async Task<WsUpgradeResult> TryUpgradeAsync(
        Stream inputStream,
        Stream outputStream,
        WebSocketOptions options,
        bool isTls = false,
        CancellationToken cancellationToken = default)
    {
        try
        {
            var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cancellationToken);

            if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
                return await FailAsync(outputStream, 405, "request method must be GET", cancellationToken);
            // An empty Host value is rejected just like a missing one.
            if (!headers.TryGetValue("Host", out var host) || string.IsNullOrEmpty(host))
                return await FailAsync(outputStream, 400, "'Host' missing in request", cancellationToken);
            if (!HeaderContains(headers, "Upgrade", "websocket"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'", cancellationToken);
            if (!HeaderContains(headers, "Connection", "Upgrade"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Connection'", cancellationToken);
            if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
                return await FailAsync(outputStream, 400, "key missing", cancellationToken);
            if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
                return await FailAsync(outputStream, 400, "invalid version", cancellationToken);

            var kind = GetClientKind(path);
            bool isClientOrMqtt = kind is WsClientKind.Client or WsClientKind.Mqtt;

            // Origin checking is only performed when same-origin or an allow-list is configured.
            if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
            {
                var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
                headers.TryGetValue("Origin", out var origin);
                if (string.IsNullOrEmpty(origin))
                    headers.TryGetValue("Sec-WebSocket-Origin", out origin);
                var originErr = checker.CheckOrigin(origin, host, isTls: isTls);
                if (originErr != null)
                    return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}", cancellationToken);
            }

            // Compression is negotiated only when enabled in the options AND offered by the client.
            bool compress = options.Compression
                && headers.TryGetValue("Sec-WebSocket-Extensions", out var ext)
                && ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);

            // No-masking support requested by the peer.
            bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal)
                && string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);

            // Browser detection.
            bool browser = false;
            bool noCompFrag = false;
            if (isClientOrMqtt
                && headers.TryGetValue("User-Agent", out var ua)
                && ua.StartsWith("Mozilla/", StringComparison.Ordinal))
            {
                browser = true;
                // Disable fragmentation of compressed frames for Safari browsers.
                // Safari has both "Version/" and "Safari/" in the user agent string,
                // while Chrome on macOS has "Safari/" but not "Version/".
                noCompFrag = compress
                    && ua.Contains("Version/", StringComparison.Ordinal)
                    && ua.Contains("Safari/", StringComparison.Ordinal);
            }

            // Auth cookies are only honored for client and MQTT connections.
            string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
            if (isClientOrMqtt && headers.TryGetValue("Cookie", out var cookieHeader))
            {
                var cookies = ParseCookies(cookieHeader);
                if (options.JwtCookie != null)
                    cookies.TryGetValue(options.JwtCookie, out cookieJwt);
                if (options.UsernameCookie != null)
                    cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
                if (options.PasswordCookie != null)
                    cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
                if (options.TokenCookie != null)
                    cookies.TryGetValue(options.TokenCookie, out cookieToken);
            }

            // The first X-Forwarded-For entry is used as the client IP, only if it parses as an address.
            string? clientIp = null;
            if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
            {
                var ip = xff.Split(',')[0].Trim();
                if (IPAddress.TryParse(ip, out _))
                    clientIp = ip;
            }

            // Build the 101 Switching Protocols response.
            var response = new StringBuilder();
            response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
            response.Append(ComputeAcceptKey(key));
            response.Append("\r\n");
            if (compress)
                response.Append(WsConstants.PmcFullResponse);
            if (noMasking)
                response.Append(WsConstants.NoMaskingFullResponse);
            if (options.Headers != null)
            {
                foreach (var (k, v) in options.Headers)
                {
                    response.Append(k).Append(": ").Append(v).Append("\r\n");
                }
            }
            response.Append("\r\n");

            var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
            await outputStream.WriteAsync(responseBytes, cancellationToken);
            await outputStream.FlushAsync(cancellationToken);

            return new WsUpgradeResult(
                Success: true,
                Compress: compress,
                Browser: browser,
                NoCompFrag: noCompFrag,
                MaskRead: !noMasking,
                MaskWrite: false,
                CookieJwt: cookieJwt,
                CookieUsername: cookieUsername,
                CookiePassword: cookiePassword,
                CookieToken: cookieToken,
                ClientIp: clientIp,
                Kind: kind);
        }
        catch (Exception)
        {
            // Deliberate best effort: any failure (peer closed, malformed request, I/O error,
            // cancellation) rejects the upgrade instead of propagating to the caller.
            return WsUpgradeResult.Failed;
        }
    }

    /// <summary>
    /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2:
    /// base64(SHA-1(clientKey + GUID)).
    /// </summary>
    /// <param name="clientKey">The value of the client's Sec-WebSocket-Key header.</param>
    /// <returns>The base64-encoded accept key.</returns>
    public static string ComputeAcceptKey(string clientKey)
    {
        var combined = Encoding.ASCII.GetBytes(clientKey + AcceptKeyGuid);
        var hash = SHA1.HashData(combined);
        return Convert.ToBase64String(hash);
    }

    /// <summary>
    /// Classifies the connection by the suffix of the request path. Any query string is ignored,
    /// so <c>/mqtt?x=1</c> is still recognized as MQTT.
    /// </summary>
    private static WsClientKind GetClientKind(string requestTarget)
    {
        var queryIdx = requestTarget.IndexOf('?');
        var path = queryIdx >= 0 ? requestTarget[..queryIdx] : requestTarget;
        if (path.EndsWith("/leafnode", StringComparison.Ordinal))
            return WsClientKind.Leaf;
        if (path.EndsWith("/mqtt", StringComparison.Ordinal))
            return WsClientKind.Mqtt;
        return WsClientKind.Client;
    }

    /// <summary>
    /// Writes a plain-text HTTP error response carrying <paramref name="reason"/> as its body.
    /// </summary>
    /// <returns>Always <see cref="WsUpgradeResult.Failed"/>.</returns>
    private static async Task<WsUpgradeResult> FailAsync(
        Stream output, int statusCode, string reason, CancellationToken cancellationToken = default)
    {
        var statusText = statusCode switch
        {
            400 => "Bad Request",
            403 => "Forbidden",
            405 => "Method Not Allowed",
            _ => "Internal Server Error",
        };
        // Content-Length must be the encoded byte count, not the UTF-16 character count.
        var bodyLength = Encoding.ASCII.GetByteCount(reason);
        var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {bodyLength}\r\n\r\n{reason}";
        await output.WriteAsync(Encoding.ASCII.GetBytes(response), cancellationToken);
        await output.FlushAsync(cancellationToken);
        return WsUpgradeResult.Failed;
    }

    /// <summary>
    /// Reads the HTTP request head (request line and headers, up to the blank line) and parses it.
    /// </summary>
    /// <returns>The request method, the request target, and a case-insensitive header map.</returns>
    /// <exception cref="IOException">The peer closed the connection before the head was complete.</exception>
    /// <exception cref="InvalidOperationException">The head is too large or the request line is malformed.</exception>
    private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
        Stream stream, CancellationToken cancellationToken = default)
    {
        var headerBytes = new List<byte>(4096);
        var buf = new byte[1];
        // Read one byte at a time so that nothing past the terminating CRLFCRLF is consumed from the stream.
        while (true)
        {
            int n = await stream.ReadAsync(buf, cancellationToken);
            if (n == 0)
                throw new IOException("connection closed during handshake");
            headerBytes.Add(buf[0]);
            if (headerBytes.Count >= 4 &&
                headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
                headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
                break;
            if (headerBytes.Count > MaxHeaderBytes)
                throw new InvalidOperationException("HTTP header too large");
        }

        var text = Encoding.ASCII.GetString(headerBytes.ToArray());
        var lines = text.Split("\r\n", StringSplitOptions.None);

        var parts = lines[0].Split(' ');
        if (parts.Length < 3)
            throw new InvalidOperationException("invalid HTTP request line");

        var method = parts[0];
        var path = parts[1];
        var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        for (int i = 1; i < lines.Length; i++)
        {
            var line = lines[i];
            if (string.IsNullOrEmpty(line))
                break;
            var colonIdx = line.IndexOf(':');
            if (colonIdx <= 0)
                continue;
            var name = line[..colonIdx].Trim();
            var value = line[(colonIdx + 1)..].Trim();
            // Repeated fields are combined rather than overwritten (RFC 7230 Section 3.2.2), so a
            // token such as "Upgrade" split across two Connection lines is still found.
            headers[name] = headers.TryGetValue(name, out var existing)
                ? CombineHeaderValues(name, existing, value)
                : value;
        }

        return (method, path, headers);
    }

    /// <summary>
    /// Joins two values of the same header field: "; " for Cookie, ", " for everything else.
    /// Empty values are not joined, so no leading or trailing separator is produced.
    /// </summary>
    private static string CombineHeaderValues(string name, string existing, string value)
    {
        if (existing.Length == 0)
            return value;
        if (value.Length == 0)
            return existing;
        var separator = string.Equals(name, "Cookie", StringComparison.OrdinalIgnoreCase) ? "; " : ", ";
        return existing + separator + value;
    }

    /// <summary>
    /// Returns true when header <paramref name="name"/> contains <paramref name="value"/> as one of its
    /// comma-separated tokens, compared case-insensitively.
    /// </summary>
    private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
    {
        if (!headers.TryGetValue(name, out var headerValue))
            return false;
        foreach (var token in headerValue.Split(','))
        {
            if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }

    /// <summary>
    /// Parses a Cookie header ("a=1; b=2") into a case-sensitive name/value map. When a name repeats,
    /// the last value wins. Values wrapped in double quotes are unquoted (RFC 6265 cookie-value).
    /// </summary>
    private static Dictionary<string, string> ParseCookies(string cookieHeader)
    {
        var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
        foreach (var pair in cookieHeader.Split(';'))
        {
            var trimmed = pair.Trim();
            var eqIdx = trimmed.IndexOf('=');
            if (eqIdx <= 0)
                continue;
            var value = trimmed[(eqIdx + 1)..].Trim();
            if (value.Length >= 2 && value[0] == '"' && value[^1] == '"')
                value = value[1..^1];
            cookies[trimmed[..eqIdx].Trim()] = value;
        }
        return cookies;
    }
}

/// <summary>
/// Result of a WebSocket upgrade handshake attempt.
/// </summary>
/// <param name="Success">True when the handshake was accepted and the 101 response was written.</param>
/// <param name="Compress">
/// True when compression is enabled in the options and the client offered the configured extension.
/// </param>
/// <param name="Browser">True for client/MQTT connections whose User-Agent starts with "Mozilla/".</param>
/// <param name="NoCompFrag">True when compressed frames should not be fragmented (Safari heuristic).</param>
/// <param name="MaskRead">False when the client sent the no-masking header; true otherwise.</param>
/// <param name="MaskWrite">Always false in results produced by <see cref="WsUpgrade.TryUpgradeAsync"/>.</param>
/// <param name="CookieJwt">Value of the configured JWT cookie, if present.</param>
/// <param name="CookieUsername">Value of the configured username cookie, if present.</param>
/// <param name="CookiePassword">Value of the configured password cookie, if present.</param>
/// <param name="CookieToken">Value of the configured token cookie, if present.</param>
/// <param name="ClientIp">First X-Forwarded-For entry, when it parses as an IP address.</param>
/// <param name="Kind">Connection kind derived from the request path.</param>
public readonly record struct WsUpgradeResult(
    bool Success,
    bool Compress,
    bool Browser,
    bool NoCompFrag,
    bool MaskRead,
    bool MaskWrite,
    string? CookieJwt,
    string? CookieUsername,
    string? CookiePassword,
    string? CookieToken,
    string? ClientIp,
    WsClientKind Kind)
{
    /// <summary>Result returned for a rejected or aborted handshake.</summary>
    public static readonly WsUpgradeResult Failed = new(
        Success: false,
        Compress: false,
        Browser: false,
        NoCompFrag: false,
        MaskRead: true,
        MaskWrite: false,
        CookieJwt: null,
        CookieUsername: null,
        CookiePassword: null,
        CookieToken: null,
        ClientIp: null,
        Kind: WsClientKind.Client);
}