feat: add WebSocket frame reader state machine
This commit is contained in:
313
src/NATS.Server/WebSocket/WsReadInfo.cs
Normal file
313
src/NATS.Server/WebSocket/WsReadInfo.cs
Normal file
@@ -0,0 +1,313 @@
|
||||
using System.Buffers.Binary;
|
||||
using System.Text;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Per-connection WebSocket frame reading state machine.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
|
||||
/// </summary>
|
||||
public struct WsReadInfo
|
||||
{
|
||||
public int Remaining;
|
||||
public bool FrameStart;
|
||||
public bool FirstFrame;
|
||||
public bool FrameCompressed;
|
||||
public bool ExpectMask;
|
||||
public byte MaskKeyPos;
|
||||
public byte[] MaskKey;
|
||||
public List<byte[]>? CompressedBuffers;
|
||||
public int CompressedOffset;
|
||||
|
||||
// Control frame outputs
|
||||
public List<ControlFrameAction> PendingControlFrames;
|
||||
public bool CloseReceived;
|
||||
public int CloseStatus;
|
||||
public string? CloseBody;
|
||||
|
||||
public WsReadInfo(bool expectMask)
|
||||
{
|
||||
Remaining = 0;
|
||||
FrameStart = true;
|
||||
FirstFrame = true;
|
||||
FrameCompressed = false;
|
||||
ExpectMask = expectMask;
|
||||
MaskKeyPos = 0;
|
||||
MaskKey = new byte[4];
|
||||
CompressedBuffers = null;
|
||||
CompressedOffset = 0;
|
||||
PendingControlFrames = [];
|
||||
CloseReceived = false;
|
||||
CloseStatus = 0;
|
||||
CloseBody = null;
|
||||
}
|
||||
|
||||
public void SetMaskKey(ReadOnlySpan<byte> key)
|
||||
{
|
||||
key[..4].CopyTo(MaskKey);
|
||||
MaskKeyPos = 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Unmask buffer in-place using current mask key and position.
|
||||
/// Optimized for 8-byte chunks when buffer is large enough.
|
||||
/// Ported from websocket.go lines 509-536.
|
||||
/// </summary>
|
||||
public void Unmask(Span<byte> buf)
|
||||
{
|
||||
int p = MaskKeyPos;
|
||||
if (buf.Length < 16)
|
||||
{
|
||||
for (int i = 0; i < buf.Length; i++)
|
||||
{
|
||||
buf[i] ^= MaskKey[p & 3];
|
||||
p++;
|
||||
}
|
||||
MaskKeyPos = (byte)(p & 3);
|
||||
return;
|
||||
}
|
||||
|
||||
// Build 8-byte key for bulk XOR
|
||||
Span<byte> k = stackalloc byte[8];
|
||||
for (int i = 0; i < 8; i++)
|
||||
k[i] = MaskKey[(p + i) & 3];
|
||||
ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);
|
||||
|
||||
int n = (buf.Length / 8) * 8;
|
||||
for (int i = 0; i < n; i += 8)
|
||||
{
|
||||
ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
|
||||
tmp ^= km;
|
||||
BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
|
||||
}
|
||||
|
||||
// Handle remaining bytes
|
||||
p += n;
|
||||
var tail = buf[n..];
|
||||
for (int i = 0; i < tail.Length; i++)
|
||||
{
|
||||
tail[i] ^= MaskKey[p & 3];
|
||||
p++;
|
||||
}
|
||||
MaskKeyPos = (byte)(p & 3);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Read and decode WebSocket frames from a buffer.
|
||||
/// Returns list of decoded payload byte arrays.
|
||||
/// Ported from websocket.go lines 208-351.
|
||||
/// </summary>
|
||||
public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload)
|
||||
{
|
||||
var bufs = new List<byte[]>();
|
||||
var buf = new byte[available];
|
||||
int bytesRead = 0;
|
||||
|
||||
// Fill the buffer from the stream
|
||||
while (bytesRead < available)
|
||||
{
|
||||
int n = stream.Read(buf, bytesRead, available - bytesRead);
|
||||
if (n == 0) break;
|
||||
bytesRead += n;
|
||||
}
|
||||
|
||||
int pos = 0;
|
||||
int max = bytesRead;
|
||||
|
||||
while (pos < max)
|
||||
{
|
||||
if (r.FrameStart)
|
||||
{
|
||||
if (pos >= max) break;
|
||||
byte b0 = buf[pos];
|
||||
int frameType = b0 & 0x0F;
|
||||
bool final = (b0 & WsConstants.FinalBit) != 0;
|
||||
bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
|
||||
pos++;
|
||||
|
||||
// Read second byte
|
||||
var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
|
||||
pos = newPos;
|
||||
byte b1 = b1Buf[0];
|
||||
|
||||
// Check mask bit
|
||||
if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
|
||||
throw new InvalidOperationException("mask bit missing");
|
||||
|
||||
r.Remaining = b1 & 0x7F;
|
||||
|
||||
// Validate frame types
|
||||
if (WsConstants.IsControlFrame(frameType))
|
||||
{
|
||||
if (r.Remaining > WsConstants.MaxControlPayloadSize)
|
||||
throw new InvalidOperationException("control frame length too large");
|
||||
if (!final)
|
||||
throw new InvalidOperationException("control frame does not have final bit set");
|
||||
}
|
||||
else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
|
||||
{
|
||||
if (!r.FirstFrame)
|
||||
throw new InvalidOperationException("new message before previous finished");
|
||||
r.FirstFrame = final;
|
||||
r.FrameCompressed = compressed;
|
||||
}
|
||||
else if (frameType == WsConstants.ContinuationFrame)
|
||||
{
|
||||
if (r.FirstFrame || compressed)
|
||||
throw new InvalidOperationException("invalid continuation frame");
|
||||
r.FirstFrame = final;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"unknown opcode {frameType}");
|
||||
}
|
||||
|
||||
// Extended payload length
|
||||
switch (r.Remaining)
|
||||
{
|
||||
case 126:
|
||||
{
|
||||
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
|
||||
pos = p2;
|
||||
r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
|
||||
break;
|
||||
}
|
||||
case 127:
|
||||
{
|
||||
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
|
||||
pos = p2;
|
||||
r.Remaining = (int)BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Read mask key
|
||||
if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0)
|
||||
{
|
||||
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
|
||||
pos = p2;
|
||||
keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
|
||||
r.MaskKeyPos = 0;
|
||||
}
|
||||
|
||||
// Handle control frames
|
||||
if (WsConstants.IsControlFrame(frameType))
|
||||
{
|
||||
pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max);
|
||||
continue;
|
||||
}
|
||||
|
||||
r.FrameStart = false;
|
||||
}
|
||||
|
||||
if (pos < max)
|
||||
{
|
||||
int n = r.Remaining;
|
||||
if (pos + n > max) n = max - pos;
|
||||
|
||||
var payloadSlice = buf.AsSpan(pos, n).ToArray();
|
||||
pos += n;
|
||||
r.Remaining -= n;
|
||||
|
||||
if (r.ExpectMask)
|
||||
r.Unmask(payloadSlice);
|
||||
|
||||
bool addToBufs = true;
|
||||
if (r.FrameCompressed)
|
||||
{
|
||||
addToBufs = false;
|
||||
r.CompressedBuffers ??= [];
|
||||
r.CompressedBuffers.Add(payloadSlice);
|
||||
|
||||
if (r.FirstFrame && r.Remaining == 0)
|
||||
{
|
||||
var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
|
||||
r.CompressedBuffers = null;
|
||||
r.FrameCompressed = false;
|
||||
addToBufs = true;
|
||||
payloadSlice = decompressed;
|
||||
}
|
||||
}
|
||||
|
||||
if (addToBufs && payloadSlice.Length > 0)
|
||||
bufs.Add(payloadSlice);
|
||||
|
||||
if (r.Remaining == 0)
|
||||
r.FrameStart = true;
|
||||
}
|
||||
}
|
||||
|
||||
return bufs;
|
||||
}
|
||||
|
||||
private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
|
||||
{
|
||||
byte[]? payload = null;
|
||||
if (r.Remaining > 0)
|
||||
{
|
||||
var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
|
||||
pos = newPos;
|
||||
payload = payloadBuf;
|
||||
if (r.ExpectMask)
|
||||
r.Unmask(payload);
|
||||
r.Remaining = 0;
|
||||
}
|
||||
|
||||
switch (frameType)
|
||||
{
|
||||
case WsConstants.CloseMessage:
|
||||
r.CloseReceived = true;
|
||||
r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
|
||||
if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
|
||||
{
|
||||
r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
|
||||
if (payload.Length > WsConstants.CloseStatusSize)
|
||||
r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
|
||||
}
|
||||
if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
|
||||
{
|
||||
var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
|
||||
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
|
||||
}
|
||||
break;
|
||||
|
||||
case WsConstants.PingMessage:
|
||||
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
|
||||
break;
|
||||
|
||||
case WsConstants.PongMessage:
|
||||
// Nothing to do
|
||||
break;
|
||||
}
|
||||
|
||||
return pos;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets needed bytes from buffer or reads from stream.
|
||||
/// Ported from websocket.go lines 178-193.
|
||||
/// </summary>
|
||||
private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
|
||||
{
|
||||
int avail = max - pos;
|
||||
if (avail >= needed)
|
||||
return (buf[pos..(pos + needed)], pos + needed);
|
||||
|
||||
var b = new byte[needed];
|
||||
int start = 0;
|
||||
if (avail > 0)
|
||||
{
|
||||
Buffer.BlockCopy(buf, pos, b, 0, avail);
|
||||
start = avail;
|
||||
}
|
||||
while (start < needed)
|
||||
{
|
||||
int n = stream.Read(b, start, needed - start);
|
||||
if (n == 0) throw new IOException("unexpected end of stream");
|
||||
start += n;
|
||||
}
|
||||
return (b, pos + avail);
|
||||
}
|
||||
}
|
||||
|
||||
public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);
|
||||
163
tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
Normal file
163
tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
Normal file
@@ -0,0 +1,163 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsFrameReadTests
|
||||
{
|
||||
/// <summary>Helper: build a single unmasked binary frame.</summary>
|
||||
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
|
||||
{
|
||||
int payloadLen = payload.Length;
|
||||
byte b0 = (byte)opcode;
|
||||
if (fin) b0 |= WsConstants.FinalBit;
|
||||
if (compressed) b0 |= WsConstants.Rsv1Bit;
|
||||
byte b1 = 0;
|
||||
if (mask) b1 |= WsConstants.MaskBit;
|
||||
|
||||
byte[] lenBytes;
|
||||
if (payloadLen <= 125)
|
||||
{
|
||||
lenBytes = [(byte)(b1 | (byte)payloadLen)];
|
||||
}
|
||||
else if (payloadLen < 65536)
|
||||
{
|
||||
lenBytes = new byte[3];
|
||||
lenBytes[0] = (byte)(b1 | 126);
|
||||
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
lenBytes = new byte[9];
|
||||
lenBytes[0] = (byte)(b1 | 127);
|
||||
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
|
||||
}
|
||||
|
||||
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
|
||||
var frame = new byte[totalLen];
|
||||
frame[0] = b0;
|
||||
lenBytes.CopyTo(frame.AsSpan(1));
|
||||
int pos = 1 + lenBytes.Length;
|
||||
if (mask && maskKey != null)
|
||||
{
|
||||
maskKey.CopyTo(frame.AsSpan(pos));
|
||||
pos += 4;
|
||||
var maskedPayload = payload.ToArray();
|
||||
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
|
||||
maskedPayload.CopyTo(frame.AsSpan(pos));
|
||||
}
|
||||
else
|
||||
{
|
||||
payload.CopyTo(frame.AsSpan(pos));
|
||||
}
|
||||
return frame;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadSingleUnmaskedFrame()
|
||||
{
|
||||
var payload = "Hello"u8.ToArray();
|
||||
var frame = BuildFrame(payload);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMaskedFrame()
|
||||
{
|
||||
var payload = "Hello"u8.ToArray();
|
||||
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
|
||||
var frame = BuildFrame(payload, mask: true, maskKey: key);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: true);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Read16BitLengthFrame()
|
||||
{
|
||||
var payload = new byte[200];
|
||||
Random.Shared.NextBytes(payload);
|
||||
var frame = BuildFrame(payload);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadPingFrame_ReturnsPongAction()
|
||||
{
|
||||
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0); // control frames don't produce payload
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(1);
|
||||
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadCloseFrame_ReturnsCloseAction()
|
||||
{
|
||||
var closePayload = new byte[2];
|
||||
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
|
||||
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.CloseReceived.ShouldBeTrue();
|
||||
readInfo.CloseStatus.ShouldBe(1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadPongFrame_NoAction()
|
||||
{
|
||||
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Unmask_Optimized_8ByteChunks()
|
||||
{
|
||||
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||
var original = new byte[32];
|
||||
Random.Shared.NextBytes(original);
|
||||
var masked = original.ToArray();
|
||||
|
||||
// Mask it
|
||||
for (int i = 0; i < masked.Length; i++)
|
||||
masked[i] ^= key[i & 3];
|
||||
|
||||
// Unmask using the state machine
|
||||
var info = new WsReadInfo(expectMask: true);
|
||||
info.SetMaskKey(key);
|
||||
info.Unmask(masked);
|
||||
|
||||
masked.ShouldBe(original);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user