using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

/// <summary>
/// Tests for <see cref="WsReadInfo"/>: reading data frames with 7-bit and 16-bit
/// length encodings, masked frames, control frames (ping/pong/close) and unmasking.
/// </summary>
public class WsFrameReadTests
{
    /// <summary>Number of bytes in a WebSocket masking key.</summary>
    private const int MaskKeyLength = 4;

    /// <summary>
    /// Builds the raw bytes of a single WebSocket frame: header byte, length
    /// encoding (7-bit, 16-bit or 64-bit depending on payload size), optional
    /// masking key, and the (optionally masked) payload.
    /// </summary>
    /// <param name="payload">Frame payload; it is never modified (masking works on a copy).</param>
    /// <param name="fin">When true, sets <c>WsConstants.FinalBit</c> in the first byte.</param>
    /// <param name="compressed">When true, sets <c>WsConstants.Rsv1Bit</c> in the first byte.</param>
    /// <param name="opcode">Opcode placed in the first byte; defaults to a binary message.</param>
    /// <param name="mask">When true, sets the mask bit, emits the key and masks the payload.</param>
    /// <param name="maskKey">The 4-byte masking key; required when <paramref name="mask"/> is true, ignored otherwise.</param>
    /// <returns>The encoded frame.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="mask"/> is true but <paramref name="maskKey"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="mask"/> is true but <paramref name="maskKey"/> is not 4 bytes long.</exception>
    private static byte[] BuildFrame(
        byte[] payload,
        bool fin = true,
        bool compressed = false,
        int opcode = WsConstants.BinaryMessage,
        bool mask = false,
        byte[]? maskKey = null)
    {
        // Validate up front: previously a masked frame without a key was emitted with
        // the mask bit set and 4 reserved bytes, but with the raw payload written where
        // the key belongs, which produced a silently malformed frame.
        byte[]? key = null;
        if (mask)
        {
            ArgumentNullException.ThrowIfNull(maskKey);
            if (maskKey.Length != MaskKeyLength)
            {
                throw new ArgumentException(
                    $"Mask key must be exactly {MaskKeyLength} bytes.", nameof(maskKey));
            }

            key = maskKey;
        }

        int payloadLen = payload.Length;

        byte b0 = (byte)opcode;
        if (fin)
        {
            b0 |= WsConstants.FinalBit;
        }

        if (compressed)
        {
            b0 |= WsConstants.Rsv1Bit;
        }

        byte b1 = 0;
        if (mask)
        {
            b1 |= WsConstants.MaskBit;
        }

        // Length encoding: <= 125 fits in the second header byte; 126 announces a
        // 16-bit big-endian length; 127 announces a 64-bit big-endian length.
        byte[] lenBytes;
        if (payloadLen <= 125)
        {
            lenBytes = [(byte)(b1 | (byte)payloadLen)];
        }
        else if (payloadLen < 65536)
        {
            lenBytes = new byte[3];
            lenBytes[0] = (byte)(b1 | 126);
            BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
        }
        else
        {
            lenBytes = new byte[9];
            lenBytes[0] = (byte)(b1 | 127);
            BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
        }

        int keyLen = key is null ? 0 : MaskKeyLength;
        var frame = new byte[1 + lenBytes.Length + keyLen + payloadLen];

        frame[0] = b0;
        lenBytes.CopyTo(frame.AsSpan(1));
        int pos = 1 + lenBytes.Length;

        if (key is not null)
        {
            key.CopyTo(frame.AsSpan(pos));
            pos += MaskKeyLength;

            // Mask a copy so the caller's payload stays usable as the expected value.
            var maskedPayload = payload.ToArray();
            WsFrameWriter.MaskBuf(key, maskedPayload);
            maskedPayload.CopyTo(frame.AsSpan(pos));
        }
        else
        {
            payload.CopyTo(frame.AsSpan(pos));
        }

        return frame;
    }

    /// <summary>A small unmasked binary frame yields exactly its payload.</summary>
    [Fact]
    public void ReadSingleUnmaskedFrame()
    {
        var payload = "Hello"u8.ToArray();
        var frame = BuildFrame(payload);
        var readInfo = new WsReadInfo(expectMask: false);
        using var stream = new MemoryStream(frame);

        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    /// <summary>A masked frame is unmasked by the reader and yields the original payload.</summary>
    [Fact]
    public void ReadMaskedFrame()
    {
        var payload = "Hello"u8.ToArray();
        byte[] key = [0x37, 0xFA, 0x21, 0x3D];
        var frame = BuildFrame(payload, mask: true, maskKey: key);
        var readInfo = new WsReadInfo(expectMask: true);
        using var stream = new MemoryStream(frame);

        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    /// <summary>A 200-byte payload (which needs the 16-bit length encoding) round-trips.</summary>
    [Fact]
    public void Read16BitLengthFrame()
    {
        var payload = new byte[200];
        Random.Shared.NextBytes(payload);
        var frame = BuildFrame(payload);
        var readInfo = new WsReadInfo(expectMask: false);
        using var stream = new MemoryStream(frame);

        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    /// <summary>An empty ping yields no data payload and queues one pending pong control frame.</summary>
    [Fact]
    public void ReadPingFrame_ReturnsPongAction()
    {
        var frame = BuildFrame([], opcode: WsConstants.PingMessage);
        var readInfo = new WsReadInfo(expectMask: false);
        using var stream = new MemoryStream(frame);

        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0); // control frames don't produce payload
        readInfo.PendingControlFrames.Count.ShouldBe(1);
        readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
    }

    /// <summary>A close frame carrying status 1000 is recorded on the read state and yields no data payload.</summary>
    [Fact]
    public void ReadCloseFrame_ReturnsCloseAction()
    {
        // Close payload is the status code as a big-endian 16-bit integer.
        var closePayload = new byte[2];
        BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
        var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
        var readInfo = new WsReadInfo(expectMask: false);
        using var stream = new MemoryStream(frame);

        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0);
        readInfo.CloseReceived.ShouldBeTrue();
        readInfo.CloseStatus.ShouldBe(1000);
    }

    /// <summary>An empty pong yields no data payload and queues no control frames.</summary>
    [Fact]
    public void ReadPongFrame_NoAction()
    {
        var frame = BuildFrame([], opcode: WsConstants.PongMessage);
        var readInfo = new WsReadInfo(expectMask: false);
        using var stream = new MemoryStream(frame);

        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0);
        readInfo.PendingControlFrames.Count.ShouldBe(0);
    }

    /// <summary>
    /// <c>WsReadInfo.Unmask</c> on a 32-byte buffer reverses a reference byte-by-byte
    /// XOR mask applied with the same key.
    /// </summary>
    [Fact]
    public void Unmask_Optimized_8ByteChunks()
    {
        byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
        var original = new byte[32];
        Random.Shared.NextBytes(original);
        var masked = original.ToArray();

        // Reference masking: XOR each byte with the key byte at (index mod 4).
        for (int i = 0; i < masked.Length; i++)
        {
            masked[i] ^= key[i & 3];
        }

        // Unmask using the state machine
        var info = new WsReadInfo(expectMask: true);
        info.SetMaskKey(key);
        info.Unmask(masked);

        masked.ShouldBe(original);
    }
}