feat: add WebSocket frame reader state machine
This commit is contained in:
313
src/NATS.Server/WebSocket/WsReadInfo.cs
Normal file
313
src/NATS.Server/WebSocket/WsReadInfo.cs
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
using System.Buffers.Binary;
|
||||||
|
using System.Text;
|
||||||
|
|
||||||
|
namespace NATS.Server.WebSocket;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Per-connection WebSocket frame reading state machine.
|
||||||
|
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
|
||||||
|
/// </summary>
|
||||||
|
public struct WsReadInfo
|
||||||
|
{
|
||||||
|
public int Remaining;
|
||||||
|
public bool FrameStart;
|
||||||
|
public bool FirstFrame;
|
||||||
|
public bool FrameCompressed;
|
||||||
|
public bool ExpectMask;
|
||||||
|
public byte MaskKeyPos;
|
||||||
|
public byte[] MaskKey;
|
||||||
|
public List<byte[]>? CompressedBuffers;
|
||||||
|
public int CompressedOffset;
|
||||||
|
|
||||||
|
// Control frame outputs
|
||||||
|
public List<ControlFrameAction> PendingControlFrames;
|
||||||
|
public bool CloseReceived;
|
||||||
|
public int CloseStatus;
|
||||||
|
public string? CloseBody;
|
||||||
|
|
||||||
|
public WsReadInfo(bool expectMask)
|
||||||
|
{
|
||||||
|
Remaining = 0;
|
||||||
|
FrameStart = true;
|
||||||
|
FirstFrame = true;
|
||||||
|
FrameCompressed = false;
|
||||||
|
ExpectMask = expectMask;
|
||||||
|
MaskKeyPos = 0;
|
||||||
|
MaskKey = new byte[4];
|
||||||
|
CompressedBuffers = null;
|
||||||
|
CompressedOffset = 0;
|
||||||
|
PendingControlFrames = [];
|
||||||
|
CloseReceived = false;
|
||||||
|
CloseStatus = 0;
|
||||||
|
CloseBody = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void SetMaskKey(ReadOnlySpan<byte> key)
|
||||||
|
{
|
||||||
|
key[..4].CopyTo(MaskKey);
|
||||||
|
MaskKeyPos = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Unmask buffer in-place using current mask key and position.
|
||||||
|
/// Optimized for 8-byte chunks when buffer is large enough.
|
||||||
|
/// Ported from websocket.go lines 509-536.
|
||||||
|
/// </summary>
|
||||||
|
public void Unmask(Span<byte> buf)
|
||||||
|
{
|
||||||
|
int p = MaskKeyPos;
|
||||||
|
if (buf.Length < 16)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < buf.Length; i++)
|
||||||
|
{
|
||||||
|
buf[i] ^= MaskKey[p & 3];
|
||||||
|
p++;
|
||||||
|
}
|
||||||
|
MaskKeyPos = (byte)(p & 3);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build 8-byte key for bulk XOR
|
||||||
|
Span<byte> k = stackalloc byte[8];
|
||||||
|
for (int i = 0; i < 8; i++)
|
||||||
|
k[i] = MaskKey[(p + i) & 3];
|
||||||
|
ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);
|
||||||
|
|
||||||
|
int n = (buf.Length / 8) * 8;
|
||||||
|
for (int i = 0; i < n; i += 8)
|
||||||
|
{
|
||||||
|
ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
|
||||||
|
tmp ^= km;
|
||||||
|
BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle remaining bytes
|
||||||
|
p += n;
|
||||||
|
var tail = buf[n..];
|
||||||
|
for (int i = 0; i < tail.Length; i++)
|
||||||
|
{
|
||||||
|
tail[i] ^= MaskKey[p & 3];
|
||||||
|
p++;
|
||||||
|
}
|
||||||
|
MaskKeyPos = (byte)(p & 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Read and decode WebSocket frames from a buffer.
|
||||||
|
/// Returns list of decoded payload byte arrays.
|
||||||
|
/// Ported from websocket.go lines 208-351.
|
||||||
|
/// </summary>
|
||||||
|
public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload)
|
||||||
|
{
|
||||||
|
var bufs = new List<byte[]>();
|
||||||
|
var buf = new byte[available];
|
||||||
|
int bytesRead = 0;
|
||||||
|
|
||||||
|
// Fill the buffer from the stream
|
||||||
|
while (bytesRead < available)
|
||||||
|
{
|
||||||
|
int n = stream.Read(buf, bytesRead, available - bytesRead);
|
||||||
|
if (n == 0) break;
|
||||||
|
bytesRead += n;
|
||||||
|
}
|
||||||
|
|
||||||
|
int pos = 0;
|
||||||
|
int max = bytesRead;
|
||||||
|
|
||||||
|
while (pos < max)
|
||||||
|
{
|
||||||
|
if (r.FrameStart)
|
||||||
|
{
|
||||||
|
if (pos >= max) break;
|
||||||
|
byte b0 = buf[pos];
|
||||||
|
int frameType = b0 & 0x0F;
|
||||||
|
bool final = (b0 & WsConstants.FinalBit) != 0;
|
||||||
|
bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
|
||||||
|
pos++;
|
||||||
|
|
||||||
|
// Read second byte
|
||||||
|
var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
|
||||||
|
pos = newPos;
|
||||||
|
byte b1 = b1Buf[0];
|
||||||
|
|
||||||
|
// Check mask bit
|
||||||
|
if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
|
||||||
|
throw new InvalidOperationException("mask bit missing");
|
||||||
|
|
||||||
|
r.Remaining = b1 & 0x7F;
|
||||||
|
|
||||||
|
// Validate frame types
|
||||||
|
if (WsConstants.IsControlFrame(frameType))
|
||||||
|
{
|
||||||
|
if (r.Remaining > WsConstants.MaxControlPayloadSize)
|
||||||
|
throw new InvalidOperationException("control frame length too large");
|
||||||
|
if (!final)
|
||||||
|
throw new InvalidOperationException("control frame does not have final bit set");
|
||||||
|
}
|
||||||
|
else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
|
||||||
|
{
|
||||||
|
if (!r.FirstFrame)
|
||||||
|
throw new InvalidOperationException("new message before previous finished");
|
||||||
|
r.FirstFrame = final;
|
||||||
|
r.FrameCompressed = compressed;
|
||||||
|
}
|
||||||
|
else if (frameType == WsConstants.ContinuationFrame)
|
||||||
|
{
|
||||||
|
if (r.FirstFrame || compressed)
|
||||||
|
throw new InvalidOperationException("invalid continuation frame");
|
||||||
|
r.FirstFrame = final;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw new InvalidOperationException($"unknown opcode {frameType}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extended payload length
|
||||||
|
switch (r.Remaining)
|
||||||
|
{
|
||||||
|
case 126:
|
||||||
|
{
|
||||||
|
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
|
||||||
|
pos = p2;
|
||||||
|
r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 127:
|
||||||
|
{
|
||||||
|
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
|
||||||
|
pos = p2;
|
||||||
|
r.Remaining = (int)BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read mask key
|
||||||
|
if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0)
|
||||||
|
{
|
||||||
|
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
|
||||||
|
pos = p2;
|
||||||
|
keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
|
||||||
|
r.MaskKeyPos = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle control frames
|
||||||
|
if (WsConstants.IsControlFrame(frameType))
|
||||||
|
{
|
||||||
|
pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
r.FrameStart = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos < max)
|
||||||
|
{
|
||||||
|
int n = r.Remaining;
|
||||||
|
if (pos + n > max) n = max - pos;
|
||||||
|
|
||||||
|
var payloadSlice = buf.AsSpan(pos, n).ToArray();
|
||||||
|
pos += n;
|
||||||
|
r.Remaining -= n;
|
||||||
|
|
||||||
|
if (r.ExpectMask)
|
||||||
|
r.Unmask(payloadSlice);
|
||||||
|
|
||||||
|
bool addToBufs = true;
|
||||||
|
if (r.FrameCompressed)
|
||||||
|
{
|
||||||
|
addToBufs = false;
|
||||||
|
r.CompressedBuffers ??= [];
|
||||||
|
r.CompressedBuffers.Add(payloadSlice);
|
||||||
|
|
||||||
|
if (r.FirstFrame && r.Remaining == 0)
|
||||||
|
{
|
||||||
|
var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
|
||||||
|
r.CompressedBuffers = null;
|
||||||
|
r.FrameCompressed = false;
|
||||||
|
addToBufs = true;
|
||||||
|
payloadSlice = decompressed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (addToBufs && payloadSlice.Length > 0)
|
||||||
|
bufs.Add(payloadSlice);
|
||||||
|
|
||||||
|
if (r.Remaining == 0)
|
||||||
|
r.FrameStart = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bufs;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
|
||||||
|
{
|
||||||
|
byte[]? payload = null;
|
||||||
|
if (r.Remaining > 0)
|
||||||
|
{
|
||||||
|
var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
|
||||||
|
pos = newPos;
|
||||||
|
payload = payloadBuf;
|
||||||
|
if (r.ExpectMask)
|
||||||
|
r.Unmask(payload);
|
||||||
|
r.Remaining = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (frameType)
|
||||||
|
{
|
||||||
|
case WsConstants.CloseMessage:
|
||||||
|
r.CloseReceived = true;
|
||||||
|
r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
|
||||||
|
if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
|
||||||
|
{
|
||||||
|
r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
|
||||||
|
if (payload.Length > WsConstants.CloseStatusSize)
|
||||||
|
r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
|
||||||
|
}
|
||||||
|
if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
|
||||||
|
{
|
||||||
|
var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
|
||||||
|
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case WsConstants.PingMessage:
|
||||||
|
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
|
||||||
|
break;
|
||||||
|
|
||||||
|
case WsConstants.PongMessage:
|
||||||
|
// Nothing to do
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets needed bytes from buffer or reads from stream.
|
||||||
|
/// Ported from websocket.go lines 178-193.
|
||||||
|
/// </summary>
|
||||||
|
private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
|
||||||
|
{
|
||||||
|
int avail = max - pos;
|
||||||
|
if (avail >= needed)
|
||||||
|
return (buf[pos..(pos + needed)], pos + needed);
|
||||||
|
|
||||||
|
var b = new byte[needed];
|
||||||
|
int start = 0;
|
||||||
|
if (avail > 0)
|
||||||
|
{
|
||||||
|
Buffer.BlockCopy(buf, pos, b, 0, avail);
|
||||||
|
start = avail;
|
||||||
|
}
|
||||||
|
while (start < needed)
|
||||||
|
{
|
||||||
|
int n = stream.Read(b, start, needed - start);
|
||||||
|
if (n == 0) throw new IOException("unexpected end of stream");
|
||||||
|
start += n;
|
||||||
|
}
|
||||||
|
return (b, pos + avail);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);
|
||||||
163
tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
Normal file
163
tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
using System.Buffers.Binary;
|
||||||
|
using NATS.Server.WebSocket;
|
||||||
|
using Shouldly;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.WebSocket;
|
||||||
|
|
||||||
|
public class WsFrameReadTests
|
||||||
|
{
|
||||||
|
/// <summary>Helper: build a single unmasked binary frame.</summary>
|
||||||
|
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
|
||||||
|
{
|
||||||
|
int payloadLen = payload.Length;
|
||||||
|
byte b0 = (byte)opcode;
|
||||||
|
if (fin) b0 |= WsConstants.FinalBit;
|
||||||
|
if (compressed) b0 |= WsConstants.Rsv1Bit;
|
||||||
|
byte b1 = 0;
|
||||||
|
if (mask) b1 |= WsConstants.MaskBit;
|
||||||
|
|
||||||
|
byte[] lenBytes;
|
||||||
|
if (payloadLen <= 125)
|
||||||
|
{
|
||||||
|
lenBytes = [(byte)(b1 | (byte)payloadLen)];
|
||||||
|
}
|
||||||
|
else if (payloadLen < 65536)
|
||||||
|
{
|
||||||
|
lenBytes = new byte[3];
|
||||||
|
lenBytes[0] = (byte)(b1 | 126);
|
||||||
|
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
lenBytes = new byte[9];
|
||||||
|
lenBytes[0] = (byte)(b1 | 127);
|
||||||
|
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
|
||||||
|
}
|
||||||
|
|
||||||
|
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
|
||||||
|
var frame = new byte[totalLen];
|
||||||
|
frame[0] = b0;
|
||||||
|
lenBytes.CopyTo(frame.AsSpan(1));
|
||||||
|
int pos = 1 + lenBytes.Length;
|
||||||
|
if (mask && maskKey != null)
|
||||||
|
{
|
||||||
|
maskKey.CopyTo(frame.AsSpan(pos));
|
||||||
|
pos += 4;
|
||||||
|
var maskedPayload = payload.ToArray();
|
||||||
|
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
|
||||||
|
maskedPayload.CopyTo(frame.AsSpan(pos));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
payload.CopyTo(frame.AsSpan(pos));
|
||||||
|
}
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReadSingleUnmaskedFrame()
|
||||||
|
{
|
||||||
|
var payload = "Hello"u8.ToArray();
|
||||||
|
var frame = BuildFrame(payload);
|
||||||
|
|
||||||
|
var readInfo = new WsReadInfo(expectMask: false);
|
||||||
|
var stream = new MemoryStream(frame);
|
||||||
|
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||||
|
|
||||||
|
result.Count.ShouldBe(1);
|
||||||
|
result[0].ShouldBe(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReadMaskedFrame()
|
||||||
|
{
|
||||||
|
var payload = "Hello"u8.ToArray();
|
||||||
|
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
|
||||||
|
var frame = BuildFrame(payload, mask: true, maskKey: key);
|
||||||
|
|
||||||
|
var readInfo = new WsReadInfo(expectMask: true);
|
||||||
|
var stream = new MemoryStream(frame);
|
||||||
|
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||||
|
|
||||||
|
result.Count.ShouldBe(1);
|
||||||
|
result[0].ShouldBe(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Read16BitLengthFrame()
|
||||||
|
{
|
||||||
|
var payload = new byte[200];
|
||||||
|
Random.Shared.NextBytes(payload);
|
||||||
|
var frame = BuildFrame(payload);
|
||||||
|
|
||||||
|
var readInfo = new WsReadInfo(expectMask: false);
|
||||||
|
var stream = new MemoryStream(frame);
|
||||||
|
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||||
|
|
||||||
|
result.Count.ShouldBe(1);
|
||||||
|
result[0].ShouldBe(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReadPingFrame_ReturnsPongAction()
|
||||||
|
{
|
||||||
|
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
|
||||||
|
|
||||||
|
var readInfo = new WsReadInfo(expectMask: false);
|
||||||
|
var stream = new MemoryStream(frame);
|
||||||
|
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||||
|
|
||||||
|
result.Count.ShouldBe(0); // control frames don't produce payload
|
||||||
|
readInfo.PendingControlFrames.Count.ShouldBe(1);
|
||||||
|
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReadCloseFrame_ReturnsCloseAction()
|
||||||
|
{
|
||||||
|
var closePayload = new byte[2];
|
||||||
|
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
|
||||||
|
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
|
||||||
|
|
||||||
|
var readInfo = new WsReadInfo(expectMask: false);
|
||||||
|
var stream = new MemoryStream(frame);
|
||||||
|
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||||
|
|
||||||
|
result.Count.ShouldBe(0);
|
||||||
|
readInfo.CloseReceived.ShouldBeTrue();
|
||||||
|
readInfo.CloseStatus.ShouldBe(1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ReadPongFrame_NoAction()
|
||||||
|
{
|
||||||
|
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
|
||||||
|
|
||||||
|
var readInfo = new WsReadInfo(expectMask: false);
|
||||||
|
var stream = new MemoryStream(frame);
|
||||||
|
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||||
|
|
||||||
|
result.Count.ShouldBe(0);
|
||||||
|
readInfo.PendingControlFrames.Count.ShouldBe(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void Unmask_Optimized_8ByteChunks()
|
||||||
|
{
|
||||||
|
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||||
|
var original = new byte[32];
|
||||||
|
Random.Shared.NextBytes(original);
|
||||||
|
var masked = original.ToArray();
|
||||||
|
|
||||||
|
// Mask it
|
||||||
|
for (int i = 0; i < masked.Length; i++)
|
||||||
|
masked[i] ^= key[i & 3];
|
||||||
|
|
||||||
|
// Unmask using the state machine
|
||||||
|
var info = new WsReadInfo(expectMask: true);
|
||||||
|
info.SetMaskKey(key);
|
||||||
|
info.Unmask(masked);
|
||||||
|
|
||||||
|
masked.ShouldBe(original);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user