diff --git a/src/NATS.Server/WebSocket/WsReadInfo.cs b/src/NATS.Server/WebSocket/WsReadInfo.cs new file mode 100644 index 0000000..2385f38 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsReadInfo.cs @@ -0,0 +1,313 @@ +using System.Buffers.Binary; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// Per-connection WebSocket frame reading state machine. +/// Ported from golang/nats-server/server/websocket.go lines 156-506. +/// +public struct WsReadInfo +{ + public int Remaining; + public bool FrameStart; + public bool FirstFrame; + public bool FrameCompressed; + public bool ExpectMask; + public byte MaskKeyPos; + public byte[] MaskKey; + public List? CompressedBuffers; + public int CompressedOffset; + + // Control frame outputs + public List PendingControlFrames; + public bool CloseReceived; + public int CloseStatus; + public string? CloseBody; + + public WsReadInfo(bool expectMask) + { + Remaining = 0; + FrameStart = true; + FirstFrame = true; + FrameCompressed = false; + ExpectMask = expectMask; + MaskKeyPos = 0; + MaskKey = new byte[4]; + CompressedBuffers = null; + CompressedOffset = 0; + PendingControlFrames = []; + CloseReceived = false; + CloseStatus = 0; + CloseBody = null; + } + + public void SetMaskKey(ReadOnlySpan key) + { + key[..4].CopyTo(MaskKey); + MaskKeyPos = 0; + } + + /// + /// Unmask buffer in-place using current mask key and position. + /// Optimized for 8-byte chunks when buffer is large enough. + /// Ported from websocket.go lines 509-536. + /// + public void Unmask(Span buf) + { + int p = MaskKeyPos; + if (buf.Length < 16) + { + for (int i = 0; i < buf.Length; i++) + { + buf[i] ^= MaskKey[p & 3]; + p++; + } + MaskKeyPos = (byte)(p & 3); + return; + } + + // Build 8-byte key for bulk XOR + Span k = stackalloc byte[8]; + for (int i = 0; i < 8; i++) + k[i] = MaskKey[(p + i) & 3]; + ulong km = BinaryPrimitives.ReadUInt64BigEndian(k); + + int n = (buf.Length / 8) * 8; + for (int i = 0; i < n; i += 8) + { + ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]); + tmp ^= km; + BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp); + } + + // Handle remaining bytes + p += n; + var tail = buf[n..]; + for (int i = 0; i < tail.Length; i++) + { + tail[i] ^= MaskKey[p & 3]; + p++; + } + MaskKeyPos = (byte)(p & 3); + } + + /// + /// Read and decode WebSocket frames from a buffer. + /// Returns list of decoded payload byte arrays. + /// Ported from websocket.go lines 208-351. + /// + public static List ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload) + { + var bufs = new List(); + var buf = new byte[available]; + int bytesRead = 0; + + // Fill the buffer from the stream + while (bytesRead < available) + { + int n = stream.Read(buf, bytesRead, available - bytesRead); + if (n == 0) break; + bytesRead += n; + } + + int pos = 0; + int max = bytesRead; + + while (pos < max) + { + if (r.FrameStart) + { + if (pos >= max) break; + byte b0 = buf[pos]; + int frameType = b0 & 0x0F; + bool final = (b0 & WsConstants.FinalBit) != 0; + bool compressed = (b0 & WsConstants.Rsv1Bit) != 0; + pos++; + + // Read second byte + var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1); + pos = newPos; + byte b1 = b1Buf[0]; + + // Check mask bit + if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0) + throw new InvalidOperationException("mask bit missing"); + + r.Remaining = b1 & 0x7F; + + // Validate frame types + if (WsConstants.IsControlFrame(frameType)) + { + if (r.Remaining > WsConstants.MaxControlPayloadSize) + throw new InvalidOperationException("control frame length too large"); + if (!final) + throw new InvalidOperationException("control frame does not have final bit set"); + } + else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage) + { + if (!r.FirstFrame) + throw new InvalidOperationException("new message before previous finished"); + r.FirstFrame = final; + r.FrameCompressed = compressed; + } + else if (frameType == WsConstants.ContinuationFrame) + { + if (r.FirstFrame || compressed) + throw new InvalidOperationException("invalid continuation frame"); + r.FirstFrame = final; + } + else + { + throw new InvalidOperationException($"unknown opcode {frameType}"); + } + + // Extended payload length + switch (r.Remaining) + { + case 126: + { + var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2); + pos = p2; + r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf); + break; + } + case 127: + { + var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8); + pos = p2; + r.Remaining = (int)BinaryPrimitives.ReadUInt64BigEndian(lenBuf); + break; + } + } + + // Read mask key + if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0) + { + var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4); + pos = p2; + keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey); + r.MaskKeyPos = 0; + } + + // Handle control frames + if (WsConstants.IsControlFrame(frameType)) + { + pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max); + continue; + } + + r.FrameStart = false; + } + + if (pos < max) + { + int n = r.Remaining; + if (pos + n > max) n = max - pos; + + var payloadSlice = buf.AsSpan(pos, n).ToArray(); + pos += n; + r.Remaining -= n; + + if (r.ExpectMask) + r.Unmask(payloadSlice); + + bool addToBufs = true; + if (r.FrameCompressed) + { + addToBufs = false; + r.CompressedBuffers ??= []; + r.CompressedBuffers.Add(payloadSlice); + + if (r.FirstFrame && r.Remaining == 0) + { + var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload); + r.CompressedBuffers = null; + r.FrameCompressed = false; + addToBufs = true; + payloadSlice = decompressed; + } + } + + if (addToBufs && payloadSlice.Length > 0) + bufs.Add(payloadSlice); + + if (r.Remaining == 0) + r.FrameStart = true; + } + } + + return bufs; + } + + private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max) + { + byte[]? payload = null; + if (r.Remaining > 0) + { + var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining); + pos = newPos; + payload = payloadBuf; + if (r.ExpectMask) + r.Unmask(payload); + r.Remaining = 0; + } + + switch (frameType) + { + case WsConstants.CloseMessage: + r.CloseReceived = true; + r.CloseStatus = WsConstants.CloseStatusNoStatusReceived; + if (payload != null && payload.Length >= WsConstants.CloseStatusSize) + { + r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload); + if (payload.Length > WsConstants.CloseStatusSize) + r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize)); + } + if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived) + { + var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? ""); + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg)); + } + break; + + case WsConstants.PingMessage: + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? [])); + break; + + case WsConstants.PongMessage: + // Nothing to do + break; + } + + return pos; + } + + /// + /// Gets needed bytes from buffer or reads from stream. + /// Ported from websocket.go lines 178-193. + /// + private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed) + { + int avail = max - pos; + if (avail >= needed) + return (buf[pos..(pos + needed)], pos + needed); + + var b = new byte[needed]; + int start = 0; + if (avail > 0) + { + Buffer.BlockCopy(buf, pos, b, 0, avail); + start = avail; + } + while (start < needed) + { + int n = stream.Read(b, start, needed - start); + if (n == 0) throw new IOException("unexpected end of stream"); + start += n; + } + return (b, pos + avail); + } +} + +public readonly record struct ControlFrameAction(int Opcode, byte[] Payload); diff --git a/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs b/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs new file mode 100644 index 0000000..9a21b53 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs @@ -0,0 +1,163 @@ +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsFrameReadTests +{ + /// Helper: build a single unmasked binary frame. + private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null) + { + int payloadLen = payload.Length; + byte b0 = (byte)opcode; + if (fin) b0 |= WsConstants.FinalBit; + if (compressed) b0 |= WsConstants.Rsv1Bit; + byte b1 = 0; + if (mask) b1 |= WsConstants.MaskBit; + + byte[] lenBytes; + if (payloadLen <= 125) + { + lenBytes = [(byte)(b1 | (byte)payloadLen)]; + } + else if (payloadLen < 65536) + { + lenBytes = new byte[3]; + lenBytes[0] = (byte)(b1 | 126); + BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen); + } + else + { + lenBytes = new byte[9]; + lenBytes[0] = (byte)(b1 | 127); + BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen); + } + + int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen; + var frame = new byte[totalLen]; + frame[0] = b0; + lenBytes.CopyTo(frame.AsSpan(1)); + int pos = 1 + lenBytes.Length; + if (mask && maskKey != null) + { + maskKey.CopyTo(frame.AsSpan(pos)); + pos += 4; + var maskedPayload = payload.ToArray(); + WsFrameWriter.MaskBuf(maskKey, maskedPayload); + maskedPayload.CopyTo(frame.AsSpan(pos)); + } + else + { + payload.CopyTo(frame.AsSpan(pos)); + } + return frame; + } + + [Fact] + public void ReadSingleUnmaskedFrame() + { + var payload = "Hello"u8.ToArray(); + var frame = BuildFrame(payload); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void ReadMaskedFrame() + { + var payload = "Hello"u8.ToArray(); + byte[] key = [0x37, 0xFA, 0x21, 0x3D]; + var frame = BuildFrame(payload, mask: true, maskKey: key); + + var readInfo = new WsReadInfo(expectMask: true); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void Read16BitLengthFrame() + { + var payload = new byte[200]; + Random.Shared.NextBytes(payload); + var frame = BuildFrame(payload); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void ReadPingFrame_ReturnsPongAction() + { + var frame = BuildFrame([], opcode: WsConstants.PingMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); // control frames don't produce payload + readInfo.PendingControlFrames.Count.ShouldBe(1); + readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage); + } + + [Fact] + public void ReadCloseFrame_ReturnsCloseAction() + { + var closePayload = new byte[2]; + BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000); + var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); + readInfo.CloseReceived.ShouldBeTrue(); + readInfo.CloseStatus.ShouldBe(1000); + } + + [Fact] + public void ReadPongFrame_NoAction() + { + var frame = BuildFrame([], opcode: WsConstants.PongMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); + readInfo.PendingControlFrames.Count.ShouldBe(0); + } + + [Fact] + public void Unmask_Optimized_8ByteChunks() + { + byte[] key = [0xAA, 0xBB, 0xCC, 0xDD]; + var original = new byte[32]; + Random.Shared.NextBytes(original); + var masked = original.ToArray(); + + // Mask it + for (int i = 0; i < masked.Length; i++) + masked[i] ^= key[i & 3]; + + // Unmask using the state machine + var info = new WsReadInfo(expectMask: true); + info.SetMaskKey(key); + info.Unmask(masked); + + masked.ShouldBe(original); + } +}