feat(batch26): implement websocket frame/core feature group A

This commit is contained in:
Joseph Doherty
2026-02-28 21:45:05 -05:00
parent e98e686ef2
commit 3653345a37
11 changed files with 1252 additions and 703 deletions

View File

@@ -0,0 +1,253 @@
// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
using System.Buffers.Binary;
using System.Text;
using ZB.MOM.NatsNet.Server.WebSocket;
namespace ZB.MOM.NatsNet.Server;
public sealed partial class ClientConnection
{
internal (List<OutboundChunk> chunks, long attempted) WsCollapsePtoNB() => CollapsePtoNB();
internal List<byte[]> WsRead(WsReadInfo readInfo, Stream reader, byte[] buffer)
{
var bufs = new List<byte[]>();
var pos = 0;
var max = buffer.Length;
var maxPayload = Volatile.Read(ref _mpay);
while (pos != max)
{
if (readInfo.FrameStart)
{
var b0 = buffer[pos];
var frameType = (WsOpCode)(b0 & 0xF);
var final = (b0 & WsConstants.FinalBit) != 0;
var compressed = (b0 & WsConstants.Rsv1Bit) != 0;
pos++;
var (firstLen, newPos) = WebSocketHelpers.WsGet(reader, buffer, pos, 1);
pos = newPos;
var b1 = firstLen[0];
if (readInfo.Mask && (b1 & WsConstants.MaskBit) == 0)
throw WsHandleProtocolError("mask bit missing");
readInfo.Rem = b1 & 0x7F;
switch (frameType)
{
case WsOpCode.Ping:
case WsOpCode.Pong:
case WsOpCode.Close:
if (readInfo.Rem > WsConstants.MaxControlPayloadSize)
throw WsHandleProtocolError($"control frame length bigger than maximum allowed of {WsConstants.MaxControlPayloadSize} bytes");
if (!final)
throw WsHandleProtocolError("control frame does not have final bit set");
break;
case WsOpCode.Text:
case WsOpCode.Binary:
if (!readInfo.FinalFrameReceived)
throw WsHandleProtocolError("new message started before final frame for previous message was received");
readInfo.FinalFrameReceived = final;
readInfo.FrameCompressed = compressed;
break;
case WsOpCode.Continuation:
if (readInfo.FinalFrameReceived || compressed)
throw WsHandleProtocolError("invalid continuation frame");
readInfo.FinalFrameReceived = final;
break;
default:
throw WsHandleProtocolError($"unknown opcode {(int)frameType}");
}
if (readInfo.Rem == 126)
{
var (extended, p) = WebSocketHelpers.WsGet(reader, buffer, pos, 2);
pos = p;
readInfo.Rem = BinaryPrimitives.ReadUInt16BigEndian(extended);
}
else if (readInfo.Rem == 127)
{
var (extended, p) = WebSocketHelpers.WsGet(reader, buffer, pos, 8);
pos = p;
readInfo.Rem = checked((int)BinaryPrimitives.ReadUInt64BigEndian(extended));
}
if (readInfo.Mask)
{
var (maskKey, p) = WebSocketHelpers.WsGet(reader, buffer, pos, 4);
pos = p;
Array.Copy(maskKey, 0, readInfo.MaskKey, 0, 4);
readInfo.MaskKeyPosition = 0;
}
if (WebSocketHelpers.WsIsControlFrame(frameType))
{
pos = WsHandleControlFrame(readInfo, frameType, reader, buffer, pos);
continue;
}
readInfo.FrameStart = false;
}
if (pos >= max)
continue;
var n = readInfo.Rem;
if (pos + n > max)
n = max - pos;
var payload = buffer.AsSpan(pos, n).ToArray();
pos += n;
readInfo.Rem -= n;
if (readInfo.Mask)
readInfo.Unmask(payload);
var addToBufs = true;
if (readInfo.FrameCompressed)
{
addToBufs = false;
readInfo.CompressedBuffers.Add(payload);
if (readInfo.FinalFrameReceived && readInfo.Rem == 0)
{
payload = readInfo.Decompress(maxPayload);
readInfo.FrameCompressed = false;
addToBufs = true;
}
}
if (addToBufs)
bufs.Add(payload);
if (readInfo.Rem == 0)
readInfo.FrameStart = true;
}
return bufs;
}
internal int WsHandleControlFrame(WsReadInfo readInfo, WsOpCode frameType, Stream networkConnection, byte[] buffer, int pos)
{
byte[] payload = [];
if (readInfo.Rem > 0)
{
(payload, pos) = WebSocketHelpers.WsGet(networkConnection, buffer, pos, readInfo.Rem);
if (readInfo.Mask)
readInfo.Unmask(payload);
readInfo.Rem = 0;
}
switch (frameType)
{
case WsOpCode.Close:
{
var status = WsConstants.CloseNoStatusReceived;
string body = string.Empty;
var payloadLength = payload.Length;
var hasStatus = payloadLength >= WsConstants.CloseStatusSize;
var hasBody = payloadLength > WsConstants.CloseStatusSize;
if (hasStatus)
{
status = BinaryPrimitives.ReadUInt16BigEndian(payload.AsSpan(0, 2));
if (hasBody)
{
body = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
if (!Encoding.UTF8.GetBytes(body).AsSpan().SequenceEqual(payload.AsSpan(WsConstants.CloseStatusSize)))
{
status = WsConstants.CloseInvalidPayloadData;
body = "invalid utf8 body in close frame";
}
}
}
byte[]? closeMessage = null;
if (status != WsConstants.CloseNoStatusReceived)
closeMessage = WebSocketHelpers.WsCreateCloseMessage(status, body);
WsEnqueueControlMessage(WsOpCode.Close, closeMessage ?? []);
throw new EndOfStreamException();
}
case WsOpCode.Ping:
WsEnqueueControlMessage(WsOpCode.Pong, payload);
break;
case WsOpCode.Pong:
break;
}
return pos;
}
internal void WsEnqueueControlMessage(WsOpCode controlMessage, byte[] payload)
{
lock (_mu)
WsEnqueueControlMessageLocked(controlMessage, payload);
}
internal void WsEnqueueControlMessageLocked(WsOpCode controlMessage, byte[] payload)
{
if (Ws == null)
Ws = new WebsocketConnection();
var useMasking = Ws.MaskWrite;
var headerSize = 2 + (useMasking ? 4 : 0);
var control = NbPool.Get(headerSize + payload.Length);
var (n, key) = WebSocketHelpers.WsFillFrameHeader(control, useMasking, first: true, final: true, compressed: false, controlMessage, payload.Length);
var totalLen = n;
if (payload.Length > 0)
{
Array.Copy(payload, 0, control, n, payload.Length);
if (useMasking && key != null)
WebSocketHelpers.WsMaskBuf(key, control.AsSpan(n, payload.Length));
totalLen += payload.Length;
}
var frame = control[..totalLen];
OutPb += totalLen;
if (controlMessage == WsOpCode.Close)
{
Ws.CloseSent = true;
Ws.CloseMessage = frame;
}
else
{
Ws.Frames.Add(frame);
Ws.FrameSize += frame.Length;
}
FlushSignal();
}
internal void WsEnqueueCloseMessage(ClosedState reason)
{
var status = reason switch
{
ClosedState.ClientClosed => WsConstants.CloseNormalClosure,
ClosedState.AuthenticationTimeout or ClosedState.AuthenticationViolation or ClosedState.SlowConsumerPendingBytes or
ClosedState.SlowConsumerWriteDeadline or ClosedState.MaxAccountConnectionsExceeded or ClosedState.MaxConnectionsExceeded or
ClosedState.MaxControlLineExceeded or ClosedState.MaxSubscriptionsExceeded or ClosedState.MissingAccount or
ClosedState.AuthenticationExpired or ClosedState.Revocation => WsConstants.ClosePolicyViolation,
ClosedState.TlsHandshakeError => WsConstants.CloseTlsHandshake,
ClosedState.ParseError or ClosedState.ProtocolViolation or ClosedState.BadClientProtocolVersion => WsConstants.CloseProtocolError,
ClosedState.MaxPayloadExceeded => WsConstants.CloseMessageTooBig,
ClosedState.WriteError or ClosedState.ReadError or ClosedState.StaleConnection or ClosedState.ServerShutdown => WsConstants.CloseGoingAway,
_ => WsConstants.CloseInternalError,
};
var body = WebSocketHelpers.WsCreateCloseMessage(status, reason.ToString());
WsEnqueueControlMessageLocked(WsOpCode.Close, body);
}
internal Exception WsHandleProtocolError(string message)
{
var payload = WebSocketHelpers.WsCreateCloseMessage(WsConstants.CloseProtocolError, message);
WsEnqueueControlMessage(WsOpCode.Close, payload);
return new InvalidDataException(message);
}
}

View File

@@ -26,6 +26,7 @@ using ZB.MOM.NatsNet.Server.Auth;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Internal.DataStructures;
using ZB.MOM.NatsNet.Server.Protocol;
using ZB.MOM.NatsNet.Server.WebSocket;
namespace ZB.MOM.NatsNet.Server;
@@ -113,6 +114,7 @@ public sealed partial class ClientConnection
// Client options (from CONNECT message).
internal ClientOptions Opts = ClientOptions.Default;
internal Route? Route;
internal WebsocketConnection? Ws;
// Flags and state.
internal ClientFlags Flags; // mirrors c.flags clientFlag
@@ -1484,10 +1486,26 @@ public sealed partial class ClientConnection
internal (List<OutboundChunk> chunks, long attempted) CollapsePtoNB()
{
var chunks = OutNb;
if (Ws != null && Ws.Frames.Count > 0)
{
chunks = [..OutNb];
foreach (var frame in Ws.Frames)
chunks.Add(new OutboundChunk(frame, frame.Length));
Ws.Frames.Clear();
Ws.FrameSize = 0;
}
if (Ws is { CloseSent: true, CloseMessage: not null } && OutPb == Ws.CloseMessage.Length)
{
chunks = [..chunks, new OutboundChunk(Ws.CloseMessage, Ws.CloseMessage.Length)];
Ws.CloseMessage = null;
}
long attempted = 0;
foreach (var chunk in OutNb)
foreach (var chunk in chunks)
attempted += chunk.Count;
return (OutNb, attempted);
return (chunks, attempted);
}
internal bool FlushOutbound()
@@ -1784,7 +1802,7 @@ public sealed partial class ClientConnection
// =========================================================================
internal bool IsMqtt() => false; // Deferred to session 22 (MQTT).
internal bool IsWebSocket() => false; // Deferred to session 23 (WebSocket).
internal bool IsWebSocket() => Ws != null;
internal bool IsHubLeafNode() => false; // Deferred to session 15 (leaf nodes).
internal string RemoteCluster() => string.Empty; // Deferred to sessions 14/15.
}

View File

@@ -503,7 +503,7 @@ public sealed partial class NatsServer
gatewayErr = _gatewayListenerErr;
leafOk = opts.LeafNode.Port == 0 || _leafNodeListener != null;
leafErr = _leafNodeListenerErr;
wsOk = opts.Websocket.Port == 0;
wsOk = opts.Websocket.Port == 0 || _websocket.Listener != null;
mqttOk = opts.Mqtt.Port == 0;
_mu.ExitReadLock();
@@ -952,10 +952,10 @@ public sealed partial class NatsServer
}
/// <summary>
/// Stub — closes WebSocket server if running (session 23).
/// Closes the WebSocket server if running.
/// Returns the number of done-channel signals to expect.
/// </summary>
private int CloseWebsocketServer() => 0;
private int CloseWebsocketServer() => CloseWebsocketServerCore();
/// <summary>
/// Iterates over all route connections. Stub — session 14.

View File

@@ -189,6 +189,7 @@ public sealed partial class NatsServer
_clientConnectUrls.Clear();
_clientConnectUrls.AddRange(GetClientConnectURLs());
_listener = l;
StartWebsocketServer();
// Start the accept goroutine.
_ = Task.Run(() =>
@@ -801,7 +802,8 @@ public sealed partial class NatsServer
if (wsUpdated)
{
// WebSocket connect URLs stub — session 23.
var wsUrls = _websocket.ConnectUrlsMap.GetAsStringSlice();
_info.WsConnectUrls = wsUrls.Length > 0 ? wsUrls : null;
}
if (cliUpdated || wsUpdated)
@@ -1140,7 +1142,7 @@ public sealed partial class NatsServer
if (opts.Cluster.Port != 0) list.Add(_routeListener);
if (opts.HttpPort != 0 || opts.HttpsPort != 0) list.Add(_http);
if (opts.ProfPort != 0) list.Add(_profiler);
// WebSocket listener — session 23.
if (opts.Websocket.Port != 0) list.Add(_websocket.Listener);
return list.ToArray();
}

View File

@@ -0,0 +1,307 @@
// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
using System.Collections.Specialized;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.WebSocket;
namespace ZB.MOM.NatsNet.Server;
public sealed partial class NatsServer
{
private sealed class WsHttpError : Exception
{
public int StatusCode { get; }
public WsHttpError(int statusCode, string message)
: base(message)
{
StatusCode = statusCode;
}
}
internal static bool WsHeaderContains(NameValueCollection headers, string key, string expected)
{
var values = headers.GetValues(key);
if (values == null || values.Length == 0)
return false;
foreach (var value in values)
{
if (string.IsNullOrEmpty(value))
continue;
var tokens = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
foreach (var token in tokens)
{
if (string.Equals(token, expected, StringComparison.OrdinalIgnoreCase))
return true;
}
}
return false;
}
internal static (bool supported, bool noContext) WsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver)
{
var values = headers.GetValues("Sec-WebSocket-Extensions");
if (values == null || values.Length == 0)
return (false, false);
var foundPmc = false;
var noContext = false;
foreach (var value in values)
{
var extensions = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
foreach (var ext in extensions)
{
var parts = ext.Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (parts.Length == 0)
continue;
if (!parts[0].Equals(WsConstants.PMCExtension, StringComparison.OrdinalIgnoreCase))
continue;
foundPmc = true;
if (!checkNoContextTakeOver)
return (true, false);
noContext = parts.Any(p => p.Equals(WsConstants.PMCSrvNoCtx, StringComparison.OrdinalIgnoreCase)) &&
parts.Any(p => p.Equals(WsConstants.PMCCliNoCtx, StringComparison.OrdinalIgnoreCase));
}
}
return (foundPmc && (!checkNoContextTakeOver || noContext), noContext);
}
internal static Exception WsReturnHTTPError(int statusCode, string message) =>
new WsHttpError(statusCode, message);
internal static string WsGetHostAndPort(string hostPort, out int port)
{
port = 0;
if (string.IsNullOrWhiteSpace(hostPort))
return string.Empty;
if (hostPort.Contains(':'))
{
var idx = hostPort.LastIndexOf(':');
var host = hostPort[..idx];
var portText = hostPort[(idx + 1)..];
if (int.TryParse(portText, out var parsed))
{
port = parsed;
return host;
}
}
return hostPort;
}
internal static byte[] WsMakeChallengeKey(string key)
{
ArgumentNullException.ThrowIfNull(key);
return Encoding.ASCII.GetBytes(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
}
internal static string WsAcceptKey(string key)
{
using var sha = SHA1.Create();
var digest = sha.ComputeHash(WsMakeChallengeKey(key));
return Convert.ToBase64String(digest);
}
internal static Exception? ValidateWebsocketOptions(WebsocketOpts options)
{
if (options.Port < 0 || options.Port > 65535)
return new ArgumentException("websocket port out of range");
if (options.NoTls && options.TlsConfig != null)
return new ArgumentException("websocket no_tls and tls options are mutually exclusive");
if (options.HandshakeTimeout < TimeSpan.Zero)
return new ArgumentException("websocket handshake timeout can not be negative");
return null;
}
private void WsSetOriginOptions()
{
lock (_websocket.Mu)
{
_websocket.AllowedOrigins.Clear();
var opts = GetOpts().Websocket;
_websocket.SameOrigin = opts.SameOrigin;
foreach (var entry in opts.AllowedOrigins)
{
if (!Uri.TryCreate(entry, UriKind.Absolute, out var uri) || string.IsNullOrEmpty(uri.Host))
continue;
var port = uri.IsDefaultPort
? uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase) ? "443" : "80"
: uri.Port.ToString();
_websocket.AllowedOrigins[uri.Host] = new AllowedOrigin(uri.Scheme, port);
}
}
}
private void WsSetHeadersOptions()
{
var headers = GetOpts().Websocket.Headers;
if (headers.Count == 0)
{
_websocket.RawHeaders = string.Empty;
return;
}
var builder = new StringBuilder();
foreach (var (name, value) in headers)
{
if (string.IsNullOrWhiteSpace(name))
continue;
builder.Append(name.Trim());
builder.Append(": ");
builder.Append(value?.Trim() ?? string.Empty);
builder.Append("\r\n");
}
_websocket.RawHeaders = builder.ToString();
}
private void WsConfigAuth()
{
var ws = GetOpts().Websocket;
_websocket.AuthOverride = !string.IsNullOrWhiteSpace(ws.Username)
|| !string.IsNullOrWhiteSpace(ws.Password)
|| !string.IsNullOrWhiteSpace(ws.Token)
|| !string.IsNullOrWhiteSpace(ws.JwtCookie)
|| !string.IsNullOrWhiteSpace(ws.TokenCookie)
|| !string.IsNullOrWhiteSpace(ws.UsernameCookie)
|| !string.IsNullOrWhiteSpace(ws.PasswordCookie)
|| !string.IsNullOrWhiteSpace(ws.NoAuthUser);
}
private void StartWebsocketServer()
{
var opts = GetOpts().Websocket;
if (opts.Port == 0)
return;
if (_websocket.Listener != null)
return;
var host = string.IsNullOrWhiteSpace(opts.Host) ? ServerConstants.DefaultHost : opts.Host;
var ip = IPAddress.TryParse(host, out var parsed) ? parsed : IPAddress.Any;
var listener = new TcpListener(ip, opts.Port);
listener.Start();
_websocket.Listener = listener;
_websocket.ListenerErr = null;
_websocket.Port = ((IPEndPoint)listener.LocalEndpoint).Port;
_websocket.Host = host;
_websocket.Compression = opts.Compression;
_websocket.TlsConfig = opts.TlsConfig;
WsSetOriginOptions();
WsSetHeadersOptions();
WsConfigAuth();
var connectUrl = WebsocketUrl();
_websocket.ConnectUrls.Clear();
_websocket.ConnectUrls.Add(connectUrl);
UpdateServerINFOAndSendINFOToClients([], [connectUrl], add: true);
}
private int CloseWebsocketServerCore()
{
if (_websocket.Listener == null)
return 0;
try
{
_websocket.Listener.Stop();
}
catch (Exception ex)
{
_websocket.ListenerErr = ex;
}
_websocket.Listener = null;
if (_websocket.ConnectUrls.Count > 0)
{
var urls = _websocket.ConnectUrls.ToArray();
_websocket.ConnectUrls.Clear();
UpdateServerINFOAndSendINFOToClients([], urls, add: false);
}
return 0;
}
internal ClientConnection CreateWSClient(Stream nc, ClientKind kind)
{
var client = new ClientConnection(kind, this, nc)
{
Ws = new WebsocketConnection
{
Compress = _websocket.Compression,
MaskRead = true,
},
};
return client;
}
internal (WebsocketConnection? ws, ClientKind kind, Exception? err) WsUpgrade(HttpListenerRequest request)
{
var kind = ClientKind.Client;
if (request.Url != null)
{
var path = request.Url.AbsolutePath;
if (path.EndsWith("/leafnode", StringComparison.OrdinalIgnoreCase))
kind = ClientKind.Leaf;
else if (path.EndsWith("/mqtt", StringComparison.OrdinalIgnoreCase))
kind = ClientKind.Client;
}
if (!string.Equals(request.HttpMethod, "GET", StringComparison.OrdinalIgnoreCase))
return (null, kind, WsReturnHTTPError(405, "request method must be GET"));
if (string.IsNullOrWhiteSpace(request.UserHostName))
return (null, kind, WsReturnHTTPError(400, "'Host' missing in request"));
if (!WsHeaderContains(request.Headers, "Upgrade", "websocket"))
return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Upgrade'"));
if (!WsHeaderContains(request.Headers, "Connection", "Upgrade"))
return (null, kind, WsReturnHTTPError(400, "invalid value for header 'Connection'"));
if (string.IsNullOrWhiteSpace(request.Headers["Sec-WebSocket-Key"]))
return (null, kind, WsReturnHTTPError(400, "key missing"));
if (!WsHeaderContains(request.Headers, "Sec-WebSocket-Version", "13"))
return (null, kind, WsReturnHTTPError(400, "invalid version"));
var ws = new WebsocketConnection
{
Compress = GetOpts().Websocket.Compression && WsPMCExtensionSupport(request.Headers, true).supported,
MaskRead = !string.Equals(request.Headers[WsConstants.NoMaskingHeader], WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase),
};
var xff = request.Headers.GetValues(WsConstants.XForwardedForHeader);
if (xff != null && xff.Length > 0 && IPAddress.TryParse(xff[0], out _))
ws.ClientIP = xff[0];
return (ws, kind, null);
}
internal static bool IsWSURL(string url)
{
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
return false;
return uri.Scheme.Equals(WsConstants.SchemePrefix, StringComparison.OrdinalIgnoreCase);
}
internal static bool IsWSSURL(string url)
{
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
return false;
return uri.Scheme.Equals(WsConstants.SchemePrefixTls, StringComparison.OrdinalIgnoreCase);
}
}

View File

@@ -35,6 +35,8 @@ internal enum WsOpCode : int
/// </summary>
internal static class WsConstants
{
public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff];
// Frame header bits
public const int FinalBit = 1 << 7;
public const int Rsv1Bit = 1 << 6; // Used for per-message compression (RFC 7692)

View File

@@ -0,0 +1,28 @@
// Copyright 2020-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0.
namespace ZB.MOM.NatsNet.Server.WebSocket;
internal sealed class WebSocketHandler
{
public static (byte[] bytes, int newPos) WsGet(Stream reader, byte[] buffer, int pos, int needed) =>
WebSocketHelpers.WsGet(reader, buffer, pos, needed);
public static bool WsIsControlFrame(WsOpCode frameType) =>
WebSocketHelpers.WsIsControlFrame(frameType);
public static (byte[] header, byte[]? key) WsCreateFrameHeader(bool useMasking, bool compressed, WsOpCode frameType, int length) =>
WebSocketHelpers.WsCreateFrameHeader(useMasking, compressed, frameType, length);
public static (int n, byte[]? key) WsFillFrameHeader(byte[] frameHeader, bool useMasking, bool first, bool final, bool compressed, WsOpCode frameType, int length) =>
WebSocketHelpers.WsFillFrameHeader(frameHeader, useMasking, first, final, compressed, frameType, length);
public static void WsMaskBuf(byte[] key, byte[] buffer) =>
WebSocketHelpers.WsMaskBuf(key, buffer);
public static void WsMaskBufs(byte[] key, IReadOnlyList<byte[]> buffers) =>
WebSocketHelpers.WsMaskBufs(key, buffers);
public static byte[] WsCreateCloseMessage(int status, string body) =>
WebSocketHelpers.WsCreateCloseMessage(status, body);
}

View File

@@ -0,0 +1,135 @@
// Copyright 2020-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
using System.Buffers.Binary;
using System.Security.Cryptography;
namespace ZB.MOM.NatsNet.Server.WebSocket;
internal static class WebSocketHelpers
{
public static (byte[] bytes, int newPos) WsGet(Stream reader, byte[] buf, int pos, int needed)
{
var available = buf.Length - pos;
if (available >= needed)
return (buf[pos..(pos + needed)], pos + needed);
var b = new byte[needed];
var start = 0;
if (available > 0)
{
System.Buffer.BlockCopy(buf, pos, b, 0, available);
start = available;
}
while (start < needed)
{
var n = reader.Read(b, start, needed - start);
if (n <= 0)
throw new EndOfStreamException();
start += n;
}
return (b, pos + available);
}
public static bool WsIsControlFrame(WsOpCode frameType) => frameType >= WsOpCode.Close;
public static (byte[] header, byte[]? key) WsCreateFrameHeader(bool useMasking, bool compressed, WsOpCode frameType, int length)
{
var frameHeader = NbPool.Get(WsConstants.MaxFrameHeaderSize);
var (n, key) = WsFillFrameHeader(frameHeader, useMasking, first: true, final: true, compressed, frameType, length);
return (frameHeader[..n], key);
}
public static (int n, byte[]? key) WsFillFrameHeader(
byte[] frameHeader,
bool useMasking,
bool first,
bool final,
bool compressed,
WsOpCode frameType,
int length)
{
byte b0 = 0;
if (first)
b0 = (byte)frameType;
if (final)
b0 |= WsConstants.FinalBit;
if (compressed)
b0 |= WsConstants.Rsv1Bit;
byte b1 = 0;
if (useMasking)
b1 |= WsConstants.MaskBit;
int n;
if (length <= 125)
{
n = 2;
frameHeader[0] = b0;
frameHeader[1] = (byte)(b1 | (byte)length);
}
else if (length < 65536)
{
n = 4;
frameHeader[0] = b0;
frameHeader[1] = (byte)(b1 | 126);
BinaryPrimitives.WriteUInt16BigEndian(frameHeader.AsSpan(2), (ushort)length);
}
else
{
n = 10;
frameHeader[0] = b0;
frameHeader[1] = (byte)(b1 | 127);
BinaryPrimitives.WriteUInt64BigEndian(frameHeader.AsSpan(2), (ulong)length);
}
byte[]? key = null;
if (useMasking)
{
key = new byte[4];
RandomNumberGenerator.Fill(key);
Array.Copy(key, 0, frameHeader, n, 4);
n += 4;
}
return (n, key);
}
public static void WsMaskBuf(ReadOnlySpan<byte> key, Span<byte> buffer)
{
for (var i = 0; i < buffer.Length; i++)
buffer[i] ^= key[i & 0x3];
}
public static void WsMaskBufs(ReadOnlySpan<byte> key, IReadOnlyList<byte[]> buffers)
{
var pos = 0;
for (var i = 0; i < buffers.Count; i++)
{
var buffer = buffers[i];
for (var j = 0; j < buffer.Length; j++)
{
buffer[j] ^= key[pos & 0x3];
pos++;
}
}
}
public static byte[] WsCreateCloseMessage(int status, string body)
{
if (body.Length > WsConstants.MaxControlPayloadSize - 2)
body = string.Concat(body.AsSpan(0, WsConstants.MaxControlPayloadSize - 5), "...");
var payload = System.Text.Encoding.UTF8.GetBytes(body);
var buffer = new byte[2 + payload.Length];
BinaryPrimitives.WriteUInt16BigEndian(buffer.AsSpan(0, 2), (ushort)status);
payload.CopyTo(buffer.AsSpan(2));
return buffer;
}
}

View File

@@ -13,6 +13,7 @@
//
// Adapted from server/websocket.go in the NATS server Go source.
using System.IO.Compression;
using ZB.MOM.NatsNet.Server.Internal;
namespace ZB.MOM.NatsNet.Server.WebSocket;
@@ -23,88 +24,185 @@ namespace ZB.MOM.NatsNet.Server.WebSocket;
/// </summary>
internal sealed class WsReadInfo
{
/// <summary>Whether masking is disabled for this connection (e.g. leaf node).</summary>
public bool NoMasking { get; set; }
public int Rem { get; set; }
public bool FrameStart { get; set; }
public bool FinalFrameReceived { get; set; }
public bool FrameCompressed { get; set; }
public bool Mask { get; set; }
public byte MaskKeyPosition { get; set; }
public byte[] MaskKey { get; } = new byte[4];
public List<byte[]> CompressedBuffers { get; } = [];
public int CompressedOffset { get; set; }
/// <summary>Whether per-message deflate compression is active.</summary>
public bool Compressed { get; set; }
public void Init()
{
FrameStart = true;
FinalFrameReceived = true;
}
/// <summary>The current frame opcode.</summary>
public WsOpCode FrameType { get; set; }
public int Read(byte[] destination, int offset, int count)
{
if (count == 0)
return 0;
if (CompressedBuffers.Count == 0)
return 0;
/// <summary>Number of payload bytes remaining in the current frame.</summary>
public int PayloadLeft { get; set; }
var copied = 0;
var remaining = count;
while (CompressedBuffers.Count > 0 && remaining > 0)
{
var buffer = CompressedBuffers[0];
var available = buffer.Length - CompressedOffset;
if (available <= 0)
{
NextCBuf();
continue;
}
/// <summary>The 4-byte masking key (only valid when masking is active).</summary>
public int[] Mask { get; set; } = new int[4];
var n = Math.Min(available, remaining);
Array.Copy(buffer, CompressedOffset, destination, offset + copied, n);
copied += n;
remaining -= n;
CompressedOffset += n;
NextCBuf();
}
/// <summary>Current offset into <see cref="Mask"/>.</summary>
public int MaskOffset { get; set; }
return copied;
}
/// <summary>Accumulated compressed payload buffers awaiting decompression.</summary>
public byte[]? Compress { get; set; }
public byte[]? NextCBuf()
{
if (CompressedBuffers.Count == 0)
return null;
if (CompressedOffset != CompressedBuffers[0].Length)
return CompressedBuffers[0];
public WsReadInfo() { }
CompressedOffset = 0;
if (CompressedBuffers.Count == 1)
{
CompressedBuffers.Clear();
return null;
}
CompressedBuffers.RemoveAt(0);
return CompressedBuffers[0];
}
public byte ReadByte()
{
if (CompressedBuffers.Count == 0)
throw new EndOfStreamException();
var b = CompressedBuffers[0][CompressedOffset];
CompressedOffset++;
NextCBuf();
return b;
}
public byte[] Decompress(int maxPayload)
{
if (maxPayload <= 0)
maxPayload = ServerConstants.MaxPayloadSize;
CompressedOffset = 0;
var input = new MemoryStream();
foreach (var buffer in CompressedBuffers)
input.Write(buffer, 0, buffer.Length);
input.Write(WsConstants.CompressLastBlock, 0, WsConstants.CompressLastBlock.Length);
input.Position = 0;
using var deflate = new DeflateStream(input, System.IO.Compression.CompressionMode.Decompress, leaveOpen: true);
using var output = new MemoryStream();
var tmp = new byte[4096];
while (true)
{
var n = deflate.Read(tmp, 0, tmp.Length);
if (n == 0)
break;
if (output.Length + n > maxPayload)
throw ServerErrors.ErrMaxPayload;
output.Write(tmp, 0, n);
}
CompressedBuffers.Clear();
return output.ToArray();
}
public void Unmask(byte[] buf) => Unmask(buf.AsSpan());
public void Unmask(Span<byte> buf)
{
var p = (int)MaskKeyPosition;
if (buf.Length < 16)
{
for (var i = 0; i < buf.Length; i++)
{
buf[i] ^= MaskKey[p & 3];
p++;
}
MaskKeyPosition = (byte)(p & 3);
return;
}
var key8 = new byte[8];
for (var i = 0; i < 8; i++)
key8[i] = MaskKey[(p + i) & 3];
var mask64 = BitConverter.ToUInt64(key8, 0);
var n8 = (buf.Length / 8) * 8;
for (var i = 0; i < n8; i += 8)
{
var value = BitConverter.ToUInt64(buf[i..(i + 8)]) ^ mask64;
var bytes = BitConverter.GetBytes(value);
bytes.CopyTo(buf[i..(i + 8)]);
}
for (var i = n8; i < buf.Length; i++)
{
buf[i] ^= MaskKey[p & 3];
p++;
}
MaskKeyPosition = (byte)(p & 3);
}
}
/// <summary>
/// Client-level WebSocket runtime state.
/// Mirrors Go <c>websocket</c> struct in websocket.go.
/// </summary>
internal sealed class WebsocketConnection
{
public List<byte[]> Frames { get; } = [];
public long FrameSize { get; set; }
public byte[]? CloseMessage { get; set; }
public bool Compress { get; set; }
public bool CloseSent { get; set; }
public bool Browser { get; set; }
public bool NoCompressedFragment { get; set; }
public bool MaskRead { get; set; } = true;
public bool MaskWrite { get; set; }
public string ClientIP { get; set; } = string.Empty;
}
/// <summary>
/// Server-level WebSocket state, shared across all WebSocket connections.
/// Mirrors Go <c>srvWebsocket</c> struct in server/websocket.go.
/// Replaces the stub in NatsServerTypes.cs.
/// </summary>
internal sealed class SrvWebsocket
{
/// <summary>
/// Tracks WebSocket connect URLs per server (ref-counted).
/// Mirrors Go <c>connectURLsMap refCountedUrlSet</c>.
/// </summary>
public RefCountedUrlSet ConnectUrlsMap { get; set; } = new();
/// <summary>
/// TLS configuration for the WebSocket listener.
/// Mirrors Go <c>tls bool</c> field (true if TLS is required).
/// </summary>
public Lock Mu { get; } = new();
public System.Net.Sockets.TcpListener? Listener { get; set; }
public Exception? ListenerErr { get; set; }
public Dictionary<string, AllowedOrigin> AllowedOrigins { get; } = new(StringComparer.OrdinalIgnoreCase);
public bool SameOrigin { get; set; }
public List<string> ConnectUrls { get; } = [];
public RefCountedUrlSet ConnectUrlsMap { get; } = new();
public bool AuthOverride { get; set; }
public string RawHeaders { get; set; } = string.Empty;
public System.Net.Security.SslServerAuthenticationOptions? TlsConfig { get; set; }
/// <summary>Whether per-message deflate compression is enabled globally.</summary>
public bool Compression { get; set; }
/// <summary>Host the WebSocket server is listening on.</summary>
public string Host { get; set; } = string.Empty;
/// <summary>Port the WebSocket server is listening on (may be ephemeral).</summary>
public int Port { get; set; }
}
/// <summary>
/// Handles WebSocket upgrade and framing for a single connection.
/// Mirrors the WebSocket-related methods on Go <c>client</c> in server/websocket.go.
/// Full implementation is deferred to session 23.
/// </summary>
internal sealed class WebSocketHandler
{
private readonly NatsServer _server;
public WebSocketHandler(NatsServer server)
{
_server = server;
}
/// <summary>Upgrades an HTTP connection to WebSocket protocol.</summary>
public void UpgradeToWebSocket(
System.IO.Stream stream,
System.Net.Http.Headers.HttpRequestHeaders headers)
=> throw new NotImplementedException("TODO: session 23 — websocket");
/// <summary>Parses a WebSocket frame from the given buffer slice.</summary>
public void ParseFrame(byte[] data, int offset, int count)
=> throw new NotImplementedException("TODO: session 23 — websocket");
/// <summary>Writes a WebSocket frame with the given payload.</summary>
public void WriteFrame(WsOpCode opCode, byte[] payload, bool final, bool compress)
=> throw new NotImplementedException("TODO: session 23 — websocket");
/// <summary>Writes a WebSocket close frame with the given status code and reason.</summary>
public void WriteCloseFrame(int statusCode, string reason)
=> throw new NotImplementedException("TODO: session 23 — websocket");
}
internal readonly record struct AllowedOrigin(string Scheme, string Port);