333 lines
12 KiB
C#
333 lines
12 KiB
C#
// Copyright 2012-2026 The NATS Authors
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
using System.Collections.Specialized;
|
|
using System.Linq;
|
|
using System.Net;
|
|
using System.Net.Sockets;
|
|
using System.Security.Cryptography;
|
|
using System.Text;
|
|
using ZB.MOM.NatsNet.Server.Internal;
|
|
using ZB.MOM.NatsNet.Server.WebSocket;
|
|
|
|
namespace ZB.MOM.NatsNet.Server;
|
|
|
|
public sealed partial class NatsServer
{
    /// <summary>
    /// Exception describing a rejected websocket upgrade request, together with the HTTP
    /// status code that should be reported for it.
    /// </summary>
    private sealed class WsHttpError : Exception
    {
        /// <summary>HTTP status code associated with the failure (for example 400 or 405).</summary>
        public int StatusCode { get; }

        /// <summary>Creates an error for <paramref name="statusCode"/> carrying <paramref name="message"/>.</summary>
        public WsHttpError(int statusCode, string message)
            : base(message)
        {
            StatusCode = statusCode;
        }
    }

    /// <summary>
    /// Returns <c>true</c> when any value of header <paramref name="key"/> contains
    /// <paramref name="expected"/> as one of its comma-separated tokens.
    /// </summary>
    /// <remarks>
    /// Tokens are trimmed and compared with <see cref="StringComparison.OrdinalIgnoreCase"/>;
    /// null or empty header values are skipped.
    /// </remarks>
    internal static bool WsHeaderContains(NameValueCollection headers, string key, string expected)
    {
        var values = headers.GetValues(key);
        if (values is null || values.Length == 0)
        {
            return false;
        }

        foreach (var value in values)
        {
            if (string.IsNullOrEmpty(value))
            {
                continue;
            }

            foreach (var token in value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries))
            {
                if (string.Equals(token, expected, StringComparison.OrdinalIgnoreCase))
                {
                    return true;
                }
            }
        }

        return false;
    }

    /// <summary>
    /// Inspects the <c>Sec-WebSocket-Extensions</c> request header for a permessage-deflate offer.
    /// </summary>
    /// <param name="headers">The request headers.</param>
    /// <param name="checkNoContextTakeOver">
    /// When <c>false</c>, any permessage-deflate offer is sufficient and <c>(true, false)</c> is
    /// returned for the first one found. When <c>true</c>, an offer is only accepted if it also lists
    /// both <see cref="WsConstants.PMCSrvNoCtx"/> and <see cref="WsConstants.PMCCliNoCtx"/>.
    /// </param>
    /// <returns>
    /// <c>supported</c>: an acceptable offer was found. <c>noContext</c>: the accepted offer carried both
    /// no-context-takeover parameters (only ever <c>true</c> when
    /// <paramref name="checkNoContextTakeOver"/> is <c>true</c>).
    /// </returns>
    internal static (bool supported, bool noContext) WsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver)
    {
        var values = headers.GetValues("Sec-WebSocket-Extensions");
        if (values is null || values.Length == 0)
        {
            return (false, false);
        }

        foreach (var value in values)
        {
            if (string.IsNullOrEmpty(value))
            {
                continue;
            }

            foreach (var ext in value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries))
            {
                // An extension offer is "name; param; param=value; ...": the name comes first.
                var parts = ext.Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
                if (parts.Length == 0 || !parts[0].Equals(WsConstants.PMCExtension, StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                if (!checkNoContextTakeOver)
                {
                    return (true, false);
                }

                var serverNoContext = false;
                var clientNoContext = false;
                for (var i = 1; i < parts.Length; i++)
                {
                    if (parts[i].Equals(WsConstants.PMCSrvNoCtx, StringComparison.OrdinalIgnoreCase))
                    {
                        serverNoContext = true;
                    }
                    else if (parts[i].Equals(WsConstants.PMCCliNoCtx, StringComparison.OrdinalIgnoreCase))
                    {
                        clientNoContext = true;
                    }
                }

                // Accept the first offer that satisfies the requirement. A later, less suitable
                // offer must not cancel an earlier acceptable one.
                if (serverNoContext && clientNoContext)
                {
                    return (true, true);
                }
            }
        }

        return (false, false);
    }

    /// <summary>Creates the exception used to reject an upgrade request with <paramref name="statusCode"/>.</summary>
    internal static Exception WsReturnHTTPError(int statusCode, string message) =>
        new WsHttpError(statusCode, message);

    // Lower-case aliases kept for callers that use the Go-style names.
    private static bool wsHeaderContains(NameValueCollection headers, string key, string expected) =>
        WsHeaderContains(headers, key, expected);

    private static (bool supported, bool noContext) wsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver) =>
        WsPMCExtensionSupport(headers, checkNoContextTakeOver);

    private static Exception wsReturnHTTPError(int statusCode, string message) =>
        WsReturnHTTPError(statusCode, message);

    /// <summary>
    /// Splits a <c>host:port</c> string into its host part and numeric port.
    /// </summary>
    /// <param name="hostPort">Input such as <c>"example.com:8080"</c>, <c>"[::1]:443"</c> or <c>"example.com"</c>.</param>
    /// <param name="port">
    /// Receives the parsed port (0–65535), or 0 when no valid port is present.
    /// </param>
    /// <returns>
    /// The host part when a valid port was parsed (brackets around IPv6 literals are kept).
    /// Otherwise the input unchanged, or <see cref="string.Empty"/> for null/whitespace input.
    /// </returns>
    internal static string WsGetHostAndPort(string hostPort, out int port)
    {
        port = 0;
        if (string.IsNullOrWhiteSpace(hostPort))
        {
            return string.Empty;
        }

        var idx = hostPort.LastIndexOf(':');
        if (idx < 0)
        {
            return hostPort;
        }

        var host = hostPort[..idx];

        // An unbracketed host that still contains ':' is a bare IPv6 literal ("::1", "fe80::1").
        // Its last segment belongs to the address and is not a port.
        if (host.Contains(':') && !(host.StartsWith('[') && host.EndsWith(']')))
        {
            return hostPort;
        }

        var portText = hostPort[(idx + 1)..];

        // Digits only: no sign, no whitespace, culture-independent, and within the TCP port range.
        if (int.TryParse(portText, System.Globalization.NumberStyles.None, System.Globalization.CultureInfo.InvariantCulture, out var parsed)
            && parsed <= 65535)
        {
            port = parsed;
            return host;
        }

        return hostPort;
    }

    private static string wsGetHostAndPort(string hostPort, out int port) =>
        WsGetHostAndPort(hostPort, out port);

    /// <summary>
    /// Builds the bytes that are hashed to produce <c>Sec-WebSocket-Accept</c>:
    /// the client key followed by the RFC 6455 GUID, ASCII-encoded.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    internal static byte[] WsMakeChallengeKey(string key)
    {
        // GUID fixed by RFC 6455 for computing the handshake accept value.
        const string websocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

        ArgumentNullException.ThrowIfNull(key);
        return Encoding.ASCII.GetBytes(key + websocketGuid);
    }

    /// <summary>
    /// Computes the <c>Sec-WebSocket-Accept</c> value: base64(SHA-1(key + GUID)).
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    internal static string WsAcceptKey(string key) =>
        Convert.ToBase64String(SHA1.HashData(WsMakeChallengeKey(key)));

    private static byte[] wsMakeChallengeKey(string key) => WsMakeChallengeKey(key);

    private static string wsAcceptKey(string key) => WsAcceptKey(key);

    /// <summary>
    /// Performs basic sanity checks on websocket options.
    /// </summary>
    /// <returns>An <see cref="ArgumentException"/> describing the first problem found, or <c>null</c> when valid.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
    internal static Exception? ValidateWebsocketOptions(WebsocketOpts options)
    {
        ArgumentNullException.ThrowIfNull(options);

        if (options.Port < 0 || options.Port > 65535)
        {
            return new ArgumentException("websocket port out of range");
        }

        if (options.NoTls && options.TlsConfig != null)
        {
            return new ArgumentException("websocket no_tls and tls options are mutually exclusive");
        }

        if (options.HandshakeTimeout < TimeSpan.Zero)
        {
            return new ArgumentException("websocket handshake timeout can not be negative");
        }

        return null;
    }

    private static Exception? validateWebsocketOptions(WebsocketOpts options) => ValidateWebsocketOptions(options);

    /// <summary>
    /// Rebuilds <c>_websocket.AllowedOrigins</c> and copies <c>SameOrigin</c> from the current options,
    /// under <c>_websocket.Mu</c>.
    /// </summary>
    /// <remarks>
    /// Entries that are not absolute URIs with a host are skipped silently. When an entry has no
    /// explicit port, "443" is recorded for <c>https</c> and "80" for every other scheme.
    /// </remarks>
    private void WsSetOriginOptions()
    {
        var opts = GetOpts().Websocket;

        lock (_websocket.Mu)
        {
            _websocket.AllowedOrigins.Clear();
            _websocket.SameOrigin = opts.SameOrigin;

            foreach (var entry in opts.AllowedOrigins)
            {
                if (!Uri.TryCreate(entry, UriKind.Absolute, out var uri) || string.IsNullOrEmpty(uri.Host))
                {
                    continue;
                }

                var port = uri.IsDefaultPort
                    ? uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase) ? "443" : "80"
                    : uri.Port.ToString(System.Globalization.CultureInfo.InvariantCulture);

                _websocket.AllowedOrigins[uri.Host] = new AllowedOrigin(uri.Scheme, port);
            }
        }
    }

    /// <summary>
    /// Pre-renders the configured extra websocket response headers into <c>_websocket.RawHeaders</c>
    /// as <c>"Name: value\r\n"</c> lines.
    /// </summary>
    /// <remarks>
    /// Entries with a blank name are skipped; names and values are trimmed, and a null value is
    /// written as empty. The result is stored under <c>_websocket.Mu</c>, the same lock
    /// <see cref="WsSetOriginOptions"/> uses for websocket config.
    /// </remarks>
    private void WsSetHeadersOptions()
    {
        var headers = GetOpts().Websocket.Headers;

        var builder = new StringBuilder();
        foreach (var (name, value) in headers)
        {
            if (string.IsNullOrWhiteSpace(name))
            {
                continue;
            }

            builder.Append(name.Trim())
                .Append(": ")
                .Append(value?.Trim() ?? string.Empty)
                .Append("\r\n");
        }

        // Produces string.Empty when there are no usable headers.
        var rawHeaders = builder.ToString();

        lock (_websocket.Mu)
        {
            _websocket.RawHeaders = rawHeaders;
        }
    }

    /// <summary>
    /// Sets <c>_websocket.AuthOverride</c> to <c>true</c> when any websocket-specific credential or
    /// cookie-name option is non-blank.
    /// </summary>
    /// <remarks>
    /// NOTE(review): Password and the *Cookie options also trigger the override here. Confirm the
    /// websocket auth path expects that, rather than only Username/Token/NoAuthUser.
    /// </remarks>
    private void WsConfigAuth()
    {
        var ws = GetOpts().Websocket;
        _websocket.AuthOverride = !string.IsNullOrWhiteSpace(ws.Username)
            || !string.IsNullOrWhiteSpace(ws.Password)
            || !string.IsNullOrWhiteSpace(ws.Token)
            || !string.IsNullOrWhiteSpace(ws.JwtCookie)
            || !string.IsNullOrWhiteSpace(ws.TokenCookie)
            || !string.IsNullOrWhiteSpace(ws.UsernameCookie)
            || !string.IsNullOrWhiteSpace(ws.PasswordCookie)
            || !string.IsNullOrWhiteSpace(ws.NoAuthUser);
    }

    /// <summary>
    /// Starts the websocket TCP listener (when a port is configured and none is running yet).
    /// </summary>
    /// <remarks>
    /// After starting, this method:
    /// <list type="bullet">
    /// <item>records the bound port, host, compression and TLS settings;</item>
    /// <item>applies the origin, header and auth options;</item>
    /// <item>advertises the websocket connect URL to clients.</item>
    /// </list>
    /// The listener is created from the configured host:
    /// <list type="bullet">
    /// <item>an IP literal is used directly;</item>
    /// <item>a hostname is resolved through DNS (IPv4 preferred) instead of silently binding every interface;</item>
    /// <item>a name that resolves to nothing results in a <see cref="SocketException"/>.</item>
    /// </list>
    /// NOTE(review): this method does not start accepting connections itself. Confirm an accept
    /// loop is started elsewhere.
    /// </remarks>
    private void StartWebsocketServer()
    {
        var opts = GetOpts().Websocket;
        if (opts.Port == 0 || _websocket.Listener != null)
        {
            return;
        }

        var host = string.IsNullOrWhiteSpace(opts.Host) ? ServerConstants.DefaultHost : opts.Host;
        var listener = new TcpListener(ResolveListenAddress(host), opts.Port);
        listener.Start();

        _websocket.Listener = listener;
        _websocket.ListenerErr = null;
        _websocket.Port = ((IPEndPoint)listener.LocalEndpoint).Port;
        _websocket.Host = host;
        _websocket.Compression = opts.Compression;
        _websocket.TlsConfig = opts.TlsConfig;

        WsSetOriginOptions();
        WsSetHeadersOptions();
        WsConfigAuth();

        var connectUrl = WebsocketUrl();
        _websocket.ConnectUrls.Clear();
        _websocket.ConnectUrls.Add(connectUrl);

        UpdateServerINFOAndSendINFOToClients([], [connectUrl], add: true);

        // Maps the configured host to the address to bind: IP literals directly, names via DNS.
        static IPAddress ResolveListenAddress(string configuredHost)
        {
            if (IPAddress.TryParse(configuredHost, out var literal))
            {
                return literal;
            }

            var candidates = Dns.GetHostAddresses(configuredHost);
            return candidates.FirstOrDefault(a => a.AddressFamily == AddressFamily.InterNetwork)
                ?? candidates.FirstOrDefault()
                ?? throw new SocketException((int)SocketError.HostNotFound);
        }
    }

    /// <summary>
    /// Stops the websocket listener, if any, and withdraws its advertised connect URLs.
    /// </summary>
    /// <remarks>
    /// An exception thrown by <see cref="TcpListener.Stop"/> is stored in
    /// <c>_websocket.ListenerErr</c> instead of propagating.
    /// </remarks>
    /// <returns>Always 0.</returns>
    private int CloseWebsocketServerCore()
    {
        if (_websocket.Listener == null)
        {
            return 0;
        }

        try
        {
            _websocket.Listener.Stop();
        }
        catch (Exception ex)
        {
            _websocket.ListenerErr = ex;
        }

        _websocket.Listener = null;

        if (_websocket.ConnectUrls.Count > 0)
        {
            var urls = _websocket.ConnectUrls.ToArray();
            _websocket.ConnectUrls.Clear();
            UpdateServerINFOAndSendINFOToClients([], urls, add: false);
        }

        return 0;
    }

    /// <summary>
    /// Creates a client connection over <paramref name="nc"/> with websocket state attached.
    /// Compression follows the server-level setting, and reads are expected to be masked.
    /// </summary>
    internal ClientConnection CreateWSClient(Stream nc, ClientKind kind)
    {
        var client = new ClientConnection(kind, this, nc)
        {
            Ws = new WebsocketConnection
            {
                Compress = _websocket.Compression,
                MaskRead = true,
            },
        };
        return client;
    }

    /// <summary>
    /// Validates a websocket upgrade request and builds the per-connection websocket state.
    /// </summary>
    /// <returns>
    /// On success, <c>(ws, kind, null)</c>. On failure, <c>(null, kind, error)</c>, where the error
    /// was created by <see cref="WsReturnHTTPError"/> with the HTTP status to report.
    /// </returns>
    /// <remarks>
    /// The connection kind is derived from the URL path: paths ending in <c>/leafnode</c> are leaf
    /// connections, everything else is a client.
    /// NOTE(review): this method does not consult <c>_websocket.AllowedOrigins</c> or <c>SameOrigin</c>.
    /// Confirm the Origin check is enforced elsewhere.
    /// </remarks>
    internal (WebsocketConnection? ws, ClientKind kind, Exception? err) WsUpgrade(HttpListenerRequest request)
    {
        var kind = ClientKind.Client;
        if (request.Url is { } url)
        {
            var path = url.AbsolutePath;
            if (path.EndsWith("/leafnode", StringComparison.OrdinalIgnoreCase))
            {
                kind = ClientKind.Leaf;
            }
            else if (path.EndsWith("/mqtt", StringComparison.OrdinalIgnoreCase))
            {
                // MQTT-over-websocket is currently treated as a regular client connection.
                kind = ClientKind.Client;
            }
        }

        if (!string.Equals(request.HttpMethod, "GET", StringComparison.OrdinalIgnoreCase))
        {
            return (null, kind, WsReturnHTTPError(405, "request method must be GET"));
        }

        if (string.IsNullOrWhiteSpace(request.UserHostName))
        {
            return (null, kind, WsReturnHTTPError(400, "'Host' missing in request"));
        }

        if (!WsHeaderContains(request.Headers, "Upgrade", "websocket"))
        {
            return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Upgrade'"));
        }

        if (!WsHeaderContains(request.Headers, "Connection", "Upgrade"))
        {
            return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Connection'"));
        }

        if (string.IsNullOrWhiteSpace(request.Headers["Sec-WebSocket-Key"]))
        {
            return (null, kind, WsReturnHTTPError(400, "key missing"));
        }

        if (!WsHeaderContains(request.Headers, "Sec-WebSocket-Version", "13"))
        {
            return (null, kind, WsReturnHTTPError(400, "invalid version"));
        }

        var ws = new WebsocketConnection
        {
            // NOTE(review): passing true means compression is only enabled when the client offers
            // both no-context-takeover parameters. Confirm this is intended.
            Compress = GetOpts().Websocket.Compression && WsPMCExtensionSupport(request.Headers, true).supported,
            MaskRead = !string.Equals(request.Headers[WsConstants.NoMaskingHeader], WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase),
        };

        // X-Forwarded-For may carry a proxy chain ("client, proxy1, proxy2"); the left-most entry is
        // the originating client. NOTE(review): this header is client-controlled unless a trusted
        // proxy overwrites it.
        var xff = request.Headers.GetValues(WsConstants.XForwardedForHeader);
        if (xff is { Length: > 0 } && !string.IsNullOrEmpty(xff[0]))
        {
            var first = xff[0].Split(',', 2)[0].Trim();
            if (IPAddress.TryParse(first, out _))
            {
                ws.ClientIP = first;
            }
        }

        return (ws, kind, null);
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="url"/> is an absolute URI whose scheme equals
    /// <see cref="WsConstants.SchemePrefix"/> (case-insensitive).
    /// </summary>
    internal static bool IsWSURL(string url) =>
        Uri.TryCreate(url, UriKind.Absolute, out var uri)
        && uri.Scheme.Equals(WsConstants.SchemePrefix, StringComparison.OrdinalIgnoreCase);

    /// <summary>
    /// Returns <c>true</c> when <paramref name="url"/> is an absolute URI whose scheme equals
    /// <see cref="WsConstants.SchemePrefixTls"/> (case-insensitive).
    /// </summary>
    internal static bool IsWSSURL(string url) =>
        Uri.TryCreate(url, UriKind.Absolute, out var uri)
        && uri.Scheme.Equals(WsConstants.SchemePrefixTls, StringComparison.OrdinalIgnoreCase);

    private static bool isWSURL(string url) => IsWSURL(url);

    private static bool isWSSURL(string url) => IsWSSURL(url);

    /// <summary>Returns the TLS server options configured for the websocket listener, if any.</summary>
    private System.Net.Security.SslServerAuthenticationOptions? wsGetTLSConfig() =>
        GetOpts().Websocket.TlsConfig;
}
|