feat(batch26): implement websocket frame/core feature group A

This commit is contained in:
Joseph Doherty
2026-02-28 21:45:05 -05:00
parent e98e686ef2
commit 3653345a37
11 changed files with 1252 additions and 703 deletions

View File

@@ -0,0 +1,307 @@
// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
using System.Collections.Specialized;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.WebSocket;
namespace ZB.MOM.NatsNet.Server;
public sealed partial class NatsServer
{
private sealed class WsHttpError : Exception
{
public int StatusCode { get; }
public WsHttpError(int statusCode, string message)
: base(message)
{
StatusCode = statusCode;
}
}
internal static bool WsHeaderContains(NameValueCollection headers, string key, string expected)
{
var values = headers.GetValues(key);
if (values == null || values.Length == 0)
return false;
foreach (var value in values)
{
if (string.IsNullOrEmpty(value))
continue;
var tokens = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
foreach (var token in tokens)
{
if (string.Equals(token, expected, StringComparison.OrdinalIgnoreCase))
return true;
}
}
return false;
}
internal static (bool supported, bool noContext) WsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver)
{
var values = headers.GetValues("Sec-WebSocket-Extensions");
if (values == null || values.Length == 0)
return (false, false);
var foundPmc = false;
var noContext = false;
foreach (var value in values)
{
var extensions = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
foreach (var ext in extensions)
{
var parts = ext.Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (parts.Length == 0)
continue;
if (!parts[0].Equals(WsConstants.PMCExtension, StringComparison.OrdinalIgnoreCase))
continue;
foundPmc = true;
if (!checkNoContextTakeOver)
return (true, false);
noContext = parts.Any(p => p.Equals(WsConstants.PMCSrvNoCtx, StringComparison.OrdinalIgnoreCase)) &&
parts.Any(p => p.Equals(WsConstants.PMCCliNoCtx, StringComparison.OrdinalIgnoreCase));
}
}
return (foundPmc && (!checkNoContextTakeOver || noContext), noContext);
}
internal static Exception WsReturnHTTPError(int statusCode, string message) =>
new WsHttpError(statusCode, message);
internal static string WsGetHostAndPort(string hostPort, out int port)
{
port = 0;
if (string.IsNullOrWhiteSpace(hostPort))
return string.Empty;
if (hostPort.Contains(':'))
{
var idx = hostPort.LastIndexOf(':');
var host = hostPort[..idx];
var portText = hostPort[(idx + 1)..];
if (int.TryParse(portText, out var parsed))
{
port = parsed;
return host;
}
}
return hostPort;
}
internal static byte[] WsMakeChallengeKey(string key)
{
ArgumentNullException.ThrowIfNull(key);
return Encoding.ASCII.GetBytes(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
}
internal static string WsAcceptKey(string key)
{
using var sha = SHA1.Create();
var digest = sha.ComputeHash(WsMakeChallengeKey(key));
return Convert.ToBase64String(digest);
}
internal static Exception? ValidateWebsocketOptions(WebsocketOpts options)
{
if (options.Port < 0 || options.Port > 65535)
return new ArgumentException("websocket port out of range");
if (options.NoTls && options.TlsConfig != null)
return new ArgumentException("websocket no_tls and tls options are mutually exclusive");
if (options.HandshakeTimeout < TimeSpan.Zero)
return new ArgumentException("websocket handshake timeout can not be negative");
return null;
}
private void WsSetOriginOptions()
{
lock (_websocket.Mu)
{
_websocket.AllowedOrigins.Clear();
var opts = GetOpts().Websocket;
_websocket.SameOrigin = opts.SameOrigin;
foreach (var entry in opts.AllowedOrigins)
{
if (!Uri.TryCreate(entry, UriKind.Absolute, out var uri) || string.IsNullOrEmpty(uri.Host))
continue;
var port = uri.IsDefaultPort
? uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase) ? "443" : "80"
: uri.Port.ToString();
_websocket.AllowedOrigins[uri.Host] = new AllowedOrigin(uri.Scheme, port);
}
}
}
private void WsSetHeadersOptions()
{
var headers = GetOpts().Websocket.Headers;
if (headers.Count == 0)
{
_websocket.RawHeaders = string.Empty;
return;
}
var builder = new StringBuilder();
foreach (var (name, value) in headers)
{
if (string.IsNullOrWhiteSpace(name))
continue;
builder.Append(name.Trim());
builder.Append(": ");
builder.Append(value?.Trim() ?? string.Empty);
builder.Append("\r\n");
}
_websocket.RawHeaders = builder.ToString();
}
private void WsConfigAuth()
{
var ws = GetOpts().Websocket;
_websocket.AuthOverride = !string.IsNullOrWhiteSpace(ws.Username)
|| !string.IsNullOrWhiteSpace(ws.Password)
|| !string.IsNullOrWhiteSpace(ws.Token)
|| !string.IsNullOrWhiteSpace(ws.JwtCookie)
|| !string.IsNullOrWhiteSpace(ws.TokenCookie)
|| !string.IsNullOrWhiteSpace(ws.UsernameCookie)
|| !string.IsNullOrWhiteSpace(ws.PasswordCookie)
|| !string.IsNullOrWhiteSpace(ws.NoAuthUser);
}
private void StartWebsocketServer()
{
var opts = GetOpts().Websocket;
if (opts.Port == 0)
return;
if (_websocket.Listener != null)
return;
var host = string.IsNullOrWhiteSpace(opts.Host) ? ServerConstants.DefaultHost : opts.Host;
var ip = IPAddress.TryParse(host, out var parsed) ? parsed : IPAddress.Any;
var listener = new TcpListener(ip, opts.Port);
listener.Start();
_websocket.Listener = listener;
_websocket.ListenerErr = null;
_websocket.Port = ((IPEndPoint)listener.LocalEndpoint).Port;
_websocket.Host = host;
_websocket.Compression = opts.Compression;
_websocket.TlsConfig = opts.TlsConfig;
WsSetOriginOptions();
WsSetHeadersOptions();
WsConfigAuth();
var connectUrl = WebsocketUrl();
_websocket.ConnectUrls.Clear();
_websocket.ConnectUrls.Add(connectUrl);
UpdateServerINFOAndSendINFOToClients([], [connectUrl], add: true);
}
private int CloseWebsocketServerCore()
{
if (_websocket.Listener == null)
return 0;
try
{
_websocket.Listener.Stop();
}
catch (Exception ex)
{
_websocket.ListenerErr = ex;
}
_websocket.Listener = null;
if (_websocket.ConnectUrls.Count > 0)
{
var urls = _websocket.ConnectUrls.ToArray();
_websocket.ConnectUrls.Clear();
UpdateServerINFOAndSendINFOToClients([], urls, add: false);
}
return 0;
}
internal ClientConnection CreateWSClient(Stream nc, ClientKind kind)
{
var client = new ClientConnection(kind, this, nc)
{
Ws = new WebsocketConnection
{
Compress = _websocket.Compression,
MaskRead = true,
},
};
return client;
}
internal (WebsocketConnection? ws, ClientKind kind, Exception? err) WsUpgrade(HttpListenerRequest request)
{
var kind = ClientKind.Client;
if (request.Url != null)
{
var path = request.Url.AbsolutePath;
if (path.EndsWith("/leafnode", StringComparison.OrdinalIgnoreCase))
kind = ClientKind.Leaf;
else if (path.EndsWith("/mqtt", StringComparison.OrdinalIgnoreCase))
kind = ClientKind.Client;
}
if (!string.Equals(request.HttpMethod, "GET", StringComparison.OrdinalIgnoreCase))
return (null, kind, WsReturnHTTPError(405, "request method must be GET"));
if (string.IsNullOrWhiteSpace(request.UserHostName))
return (null, kind, WsReturnHTTPError(400, "'Host' missing in request"));
if (!WsHeaderContains(request.Headers, "Upgrade", "websocket"))
return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Upgrade'"));
if (!WsHeaderContains(request.Headers, "Connection", "Upgrade"))
return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Connection'"));
if (string.IsNullOrWhiteSpace(request.Headers["Sec-WebSocket-Key"]))
return (null, kind, WsReturnHTTPError(400, "key missing"));
if (!WsHeaderContains(request.Headers, "Sec-WebSocket-Version", "13"))
return (null, kind, WsReturnHTTPError(400, "invalid version"));
var ws = new WebsocketConnection
{
Compress = GetOpts().Websocket.Compression && WsPMCExtensionSupport(request.Headers, true).supported,
MaskRead = !string.Equals(request.Headers[WsConstants.NoMaskingHeader], WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase),
};
var xff = request.Headers.GetValues(WsConstants.XForwardedForHeader);
if (xff != null && xff.Length > 0 && IPAddress.TryParse(xff[0], out _))
ws.ClientIP = xff[0];
return (ws, kind, null);
}
internal static bool IsWSURL(string url)
{
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
return false;
return uri.Scheme.Equals(WsConstants.SchemePrefix, StringComparison.OrdinalIgnoreCase);
}
internal static bool IsWSSURL(string url)
{
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
return false;
return uri.Scheme.Equals(WsConstants.SchemePrefixTls, StringComparison.OrdinalIgnoreCase);
}
}