// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

using System.Collections.Specialized;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.WebSocket;

namespace ZB.MOM.NatsNet.Server;

public sealed partial class NatsServer
{
    /// <summary>
    /// Exception that carries the HTTP status code associated with a rejected
    /// WebSocket upgrade request.
    /// </summary>
    private sealed class WsHttpError : Exception
    {
        /// <summary>HTTP status code describing the failure.</summary>
        public int StatusCode { get; }

        public WsHttpError(int statusCode, string message)
            : base(message)
        {
            StatusCode = statusCode;
        }
    }

    /// <summary>
    /// Returns <c>true</c> when any value of header <paramref name="key"/>, split on commas
    /// and trimmed, equals <paramref name="expected"/> (ordinal, case-insensitive).
    /// </summary>
    /// <param name="headers">Request headers to inspect.</param>
    /// <param name="key">Header name to look up.</param>
    /// <param name="expected">Token to search for.</param>
    internal static bool WsHeaderContains(NameValueCollection headers, string key, string expected)
    {
        var values = headers.GetValues(key);
        if (values == null || values.Length == 0)
        {
            return false;
        }

        foreach (var value in values)
        {
            if (string.IsNullOrEmpty(value))
            {
                continue;
            }

            // A header may carry a token list, e.g. "Connection: keep-alive, Upgrade".
            var tokens = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
            foreach (var token in tokens)
            {
                if (string.Equals(token, expected, StringComparison.OrdinalIgnoreCase))
                {
                    return true;
                }
            }
        }

        return false;
    }

    /// <summary>
    /// Inspects <c>Sec-WebSocket-Extensions</c> for a per-message compression offer.
    /// </summary>
    /// <param name="headers">Request headers to inspect.</param>
    /// <param name="checkNoContextTakeOver">
    /// When <c>false</c>, any offer naming the extension yields <c>(true, false)</c>.
    /// When <c>true</c>, an offer is only accepted if it carries both the server and
    /// client no-context-takeover parameters; the result is then <c>(true, true)</c>.
    /// </param>
    /// <returns>
    /// <c>(supported, noContext)</c>; <c>(false, false)</c> when no acceptable offer exists.
    /// </returns>
    internal static (bool supported, bool noContext) WsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver)
    {
        var values = headers.GetValues("Sec-WebSocket-Extensions");
        if (values == null || values.Length == 0)
        {
            return (false, false);
        }

        foreach (var value in values)
        {
            if (string.IsNullOrEmpty(value))
            {
                continue;
            }

            var extensions = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
            foreach (var ext in extensions)
            {
                var parts = ext.Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
                if (parts.Length == 0 || !parts[0].Equals(WsConstants.PMCExtension, StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                if (!checkNoContextTakeOver)
                {
                    return (true, false);
                }

                var serverNoCtx = false;
                var clientNoCtx = false;
                for (var i = 1; i < parts.Length; i++)
                {
                    if (parts[i].Equals(WsConstants.PMCSrvNoCtx, StringComparison.OrdinalIgnoreCase))
                    {
                        serverNoCtx = true;
                    }
                    else if (parts[i].Equals(WsConstants.PMCCliNoCtx, StringComparison.OrdinalIgnoreCase))
                    {
                        clientNoCtx = true;
                    }
                }

                if (serverNoCtx && clientNoCtx)
                {
                    return (true, true);
                }

                // A client may send several alternative offers; keep looking instead of
                // letting a later, less specific offer override an acceptable one.
            }
        }

        return (false, false);
    }

    /// <summary>Creates the exception used to report a failed upgrade with an HTTP status.</summary>
    internal static Exception WsReturnHTTPError(int statusCode, string message) => new WsHttpError(statusCode, message);

    private static bool wsHeaderContains(NameValueCollection headers, string key, string expected) =>
        WsHeaderContains(headers, key, expected);

    private static (bool supported, bool noContext) wsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver) =>
        WsPMCExtensionSupport(headers, checkNoContextTakeOver);

    private static Exception wsReturnHTTPError(int statusCode, string message) => WsReturnHTTPError(statusCode, message);

    /// <summary>
    /// Splits <paramref name="hostPort"/> into host and port.
    /// </summary>
    /// <param name="hostPort">Text such as <c>host:80</c>, <c>[::1]:443</c>, <c>host</c> or <c>::1</c>.</param>
    /// <param name="port">
    /// The parsed port (0-65535), or 0 when no valid port is present.
    /// </param>
    /// <returns>
    /// The host part (IPv6 brackets are kept), the input unchanged when no valid port was
    /// found, or an empty string for null/whitespace input.
    /// </returns>
    internal static string WsGetHostAndPort(string hostPort, out int port)
    {
        port = 0;
        if (string.IsNullOrWhiteSpace(hostPort))
        {
            return string.Empty;
        }

        var idx = hostPort.LastIndexOf(':');
        if (idx < 0)
        {
            return hostPort;
        }

        var host = hostPort[..idx];

        // Another ':' in an unbracketed host means a bare IPv6 literal ("::1", "fe80::1");
        // its last group is part of the address, not a port.
        if (host.Contains(':') && !(host.StartsWith('[') && host.EndsWith(']')))
        {
            return hostPort;
        }

        var portText = hostPort[(idx + 1)..];

        // Digits only: rejects signs, whitespace and culture-specific forms.
        if (int.TryParse(portText, NumberStyles.None, CultureInfo.InvariantCulture, out var parsed) && parsed <= 65535)
        {
            port = parsed;
            return host;
        }

        return hostPort;
    }

    private static string wsGetHostAndPort(string hostPort, out int port) => WsGetHostAndPort(hostPort, out port);

    /// <summary>
    /// Returns the ASCII bytes of <paramref name="key"/> followed by the WebSocket
    /// handshake GUID (RFC 6455, section 4.2.2).
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    internal static byte[] WsMakeChallengeKey(string key)
    {
        ArgumentNullException.ThrowIfNull(key);
        return Encoding.ASCII.GetBytes(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    }

    /// <summary>
    /// Computes the <c>Sec-WebSocket-Accept</c> value: Base64 of the SHA-1 digest of the
    /// challenge key.
    /// </summary>
    internal static string WsAcceptKey(string key) =>
        Convert.ToBase64String(SHA1.HashData(WsMakeChallengeKey(key)));

    private static byte[] wsMakeChallengeKey(string key) => WsMakeChallengeKey(key);

    private static string wsAcceptKey(string key) => WsAcceptKey(key);

    /// <summary>
    /// Validates WebSocket options.
    /// </summary>
    /// <returns>An <see cref="ArgumentException"/> describing the first problem, or <c>null</c>.</returns>
    internal static Exception? ValidateWebsocketOptions(WebsocketOpts options)
    {
        if (options.Port < 0 || options.Port > 65535)
        {
            return new ArgumentException("websocket port out of range");
        }

        if (options.NoTls && options.TlsConfig != null)
        {
            return new ArgumentException("websocket no_tls and tls options are mutually exclusive");
        }

        if (options.HandshakeTimeout < TimeSpan.Zero)
        {
            return new ArgumentException("websocket handshake timeout can not be negative");
        }

        return null;
    }

    private static Exception? validateWebsocketOptions(WebsocketOpts options) => ValidateWebsocketOptions(options);

    /// <summary>
    /// Rebuilds <c>_websocket.AllowedOrigins</c> (keyed by host) and
    /// <c>_websocket.SameOrigin</c> from the current options, under <c>_websocket.Mu</c>.
    /// Entries that are not absolute URIs with a host are skipped.
    /// </summary>
    private void WsSetOriginOptions()
    {
        lock (_websocket.Mu)
        {
            _websocket.AllowedOrigins.Clear();
            var opts = GetOpts().Websocket;
            _websocket.SameOrigin = opts.SameOrigin;

            foreach (var entry in opts.AllowedOrigins)
            {
                if (!Uri.TryCreate(entry, UriKind.Absolute, out var uri) || string.IsNullOrEmpty(uri.Host))
                {
                    continue;
                }

                string port;
                if (!uri.IsDefaultPort)
                {
                    port = uri.Port.ToString(CultureInfo.InvariantCulture);
                }
                else
                {
                    // No explicit port: use 443 for https, 80 for anything else.
                    port = uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase) ? "443" : "80";
                }

                _websocket.AllowedOrigins[uri.Host] = new AllowedOrigin(uri.Scheme, port);
            }
        }
    }

    /// <summary>
    /// Pre-renders configured extra headers as <c>"Name: value\r\n"</c> lines into
    /// <c>_websocket.RawHeaders</c>. Entries with a blank name are skipped.
    /// </summary>
    private void WsSetHeadersOptions()
    {
        var headers = GetOpts().Websocket.Headers;
        if (headers.Count == 0)
        {
            _websocket.RawHeaders = string.Empty;
            return;
        }

        var builder = new StringBuilder();
        foreach (var (name, value) in headers)
        {
            if (string.IsNullOrWhiteSpace(name))
            {
                continue;
            }

            builder.Append(name.Trim());
            builder.Append(": ");
            builder.Append(value?.Trim() ?? string.Empty);
            builder.Append("\r\n");
        }

        _websocket.RawHeaders = builder.ToString();
    }

    /// <summary>
    /// Sets <c>_websocket.AuthOverride</c> when any WebSocket-specific credential or
    /// cookie-name option is configured.
    /// </summary>
    private void WsConfigAuth()
    {
        var ws = GetOpts().Websocket;
        _websocket.AuthOverride =
            !string.IsNullOrWhiteSpace(ws.Username) ||
            !string.IsNullOrWhiteSpace(ws.Password) ||
            !string.IsNullOrWhiteSpace(ws.Token) ||
            !string.IsNullOrWhiteSpace(ws.JwtCookie) ||
            !string.IsNullOrWhiteSpace(ws.TokenCookie) ||
            !string.IsNullOrWhiteSpace(ws.UsernameCookie) ||
            !string.IsNullOrWhiteSpace(ws.PasswordCookie) ||
            !string.IsNullOrWhiteSpace(ws.NoAuthUser);
    }

    /// <summary>
    /// Starts the WebSocket TCP listener when a port is configured and no listener exists.
    /// It applies origin, header and auth options, then advertises the connect URL to clients.
    /// </summary>
    /// <remarks>
    /// Exceptions from <see cref="TcpListener.Start()"/> (e.g. port in use) propagate to the caller.
    /// </remarks>
    private void StartWebsocketServer()
    {
        var opts = GetOpts().Websocket;
        if (opts.Port == 0)
        {
            return;
        }

        if (_websocket.Listener != null)
        {
            return;
        }

        var host = string.IsNullOrWhiteSpace(opts.Host) ? ServerConstants.DefaultHost : opts.Host;
        var listener = new TcpListener(ResolveListenAddress(host), opts.Port);
        listener.Start();

        _websocket.Listener = listener;
        _websocket.ListenerErr = null;
        _websocket.Port = ((IPEndPoint)listener.LocalEndpoint).Port;
        _websocket.Host = host;
        _websocket.Compression = opts.Compression;
        _websocket.TlsConfig = opts.TlsConfig;

        WsSetOriginOptions();
        WsSetHeadersOptions();
        WsConfigAuth();

        var connectUrl = WebsocketUrl();
        _websocket.ConnectUrls.Clear();
        _websocket.ConnectUrls.Add(connectUrl);
        UpdateServerINFOAndSendINFOToClients([], [connectUrl], add: true);

        // Literal IPs are used as-is. Host names (e.g. "localhost") are resolved so the
        // listener does not silently bind to every interface. IPv4 is preferred, and
        // IPAddress.Any is used only when resolution fails.
        static IPAddress ResolveListenAddress(string name)
        {
            if (IPAddress.TryParse(name, out var literal))
            {
                return literal;
            }

            try
            {
                var addresses = Dns.GetHostAddresses(name);
                return addresses.FirstOrDefault(a => a.AddressFamily == AddressFamily.InterNetwork)
                    ?? addresses.FirstOrDefault()
                    ?? IPAddress.Any;
            }
            catch (SocketException)
            {
                return IPAddress.Any;
            }
            catch (ArgumentException)
            {
                return IPAddress.Any;
            }
        }
    }

    /// <summary>
    /// Stops the WebSocket listener, if any.
    /// </summary>
    /// <remarks>
    /// Any stop failure is recorded in <c>_websocket.ListenerErr</c>, and previously
    /// advertised connect URLs are withdrawn from clients.
    /// </remarks>
    /// <returns>Always 0.</returns>
    private int CloseWebsocketServerCore()
    {
        if (_websocket.Listener == null)
        {
            return 0;
        }

        try
        {
            _websocket.Listener.Stop();
        }
        catch (Exception ex)
        {
            _websocket.ListenerErr = ex;
        }

        _websocket.Listener = null;

        if (_websocket.ConnectUrls.Count > 0)
        {
            var urls = _websocket.ConnectUrls.ToArray();
            _websocket.ConnectUrls.Clear();
            UpdateServerINFOAndSendINFOToClients([], urls, add: false);
        }

        return 0;
    }

    /// <summary>
    /// Creates a client connection over <paramref name="nc"/> with WebSocket state attached.
    /// Compression follows the server-wide setting, and masked reads are expected.
    /// </summary>
    internal ClientConnection CreateWSClient(Stream nc, ClientKind kind)
    {
        var client = new ClientConnection(kind, this, nc)
        {
            Ws = new WebsocketConnection
            {
                Compress = _websocket.Compression,
                MaskRead = true,
            },
        };

        return client;
    }

    /// <summary>
    /// Validates an HTTP request as a WebSocket upgrade and builds per-connection state.
    /// </summary>
    /// <returns>
    /// <c>(ws, kind, null)</c> on success. On failure, <c>(null, kind, err)</c>, where
    /// <c>err</c> was created by <see cref="WsReturnHTTPError"/> with status 405 or 400.
    /// </returns>
    internal (WebsocketConnection? ws, ClientKind kind, Exception? err) WsUpgrade(HttpListenerRequest request)
    {
        var kind = ClientKind.Client;
        if (request.Url != null)
        {
            var path = request.Url.AbsolutePath;
            if (path.EndsWith("/leafnode", StringComparison.OrdinalIgnoreCase))
            {
                kind = ClientKind.Leaf;
            }
            else if (path.EndsWith("/mqtt", StringComparison.OrdinalIgnoreCase))
            {
                // NOTE(review): "/mqtt" currently maps to Client, same as the default.
                // Confirm whether a dedicated MQTT kind is intended here.
                kind = ClientKind.Client;
            }
        }

        if (!string.Equals(request.HttpMethod, "GET", StringComparison.OrdinalIgnoreCase))
        {
            return (null, kind, WsReturnHTTPError(405, "request method must be GET"));
        }

        if (string.IsNullOrWhiteSpace(request.UserHostName))
        {
            return (null, kind, WsReturnHTTPError(400, "'Host' missing in request"));
        }

        if (!WsHeaderContains(request.Headers, "Upgrade", "websocket"))
        {
            return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Upgrade'"));
        }

        if (!WsHeaderContains(request.Headers, "Connection", "Upgrade"))
        {
            return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Connection'"));
        }

        if (string.IsNullOrWhiteSpace(request.Headers["Sec-WebSocket-Key"]))
        {
            return (null, kind, WsReturnHTTPError(400, "key missing"));
        }

        if (!WsHeaderContains(request.Headers, "Sec-WebSocket-Version", "13"))
        {
            return (null, kind, WsReturnHTTPError(400, "invalid version"));
        }

        var ws = new WebsocketConnection
        {
            Compress = GetOpts().Websocket.Compression && WsPMCExtensionSupport(request.Headers, true).supported,
            // Reads are unmasked only when the client sends the no-masking header with the expected value.
            MaskRead = !string.Equals(request.Headers[WsConstants.NoMaskingHeader], WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase),
        };

        // Record the forwarded client IP only when the first value parses as an IP address.
        var xff = request.Headers.GetValues(WsConstants.XForwardedForHeader);
        if (xff != null && xff.Length > 0 && IPAddress.TryParse(xff[0], out _))
        {
            ws.ClientIP = xff[0];
        }

        return (ws, kind, null);
    }

    /// <summary>Returns <c>true</c> when <paramref name="url"/> is absolute and its scheme equals <c>WsConstants.SchemePrefix</c>.</summary>
    internal static bool IsWSURL(string url)
    {
        if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
        {
            return false;
        }

        return uri.Scheme.Equals(WsConstants.SchemePrefix, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>Returns <c>true</c> when <paramref name="url"/> is absolute and its scheme equals <c>WsConstants.SchemePrefixTls</c>.</summary>
    internal static bool IsWSSURL(string url)
    {
        if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
        {
            return false;
        }

        return uri.Scheme.Equals(WsConstants.SchemePrefixTls, StringComparison.OrdinalIgnoreCase);
    }

    private static bool isWSURL(string url) => IsWSURL(url);

    private static bool isWSSURL(string url) => IsWSSURL(url);

    private System.Net.Security.SslServerAuthenticationOptions? wsGetTLSConfig() => GetOpts().Websocket.TlsConfig;
}