diff --git a/src/NATS.Server/WebSocket/WsFrameWriter.cs b/src/NATS.Server/WebSocket/WsFrameWriter.cs new file mode 100644 index 0000000..1d0848d --- /dev/null +++ b/src/NATS.Server/WebSocket/WsFrameWriter.cs @@ -0,0 +1,160 @@ +using System.Buffers.Binary; +using System.Security.Cryptography; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// WebSocket frame construction, masking, and control message creation. +/// Ported from golang/nats-server/server/websocket.go lines 543-726. +/// +public static class WsFrameWriter +{ + /// + /// Creates a complete frame header for a single-frame message (first=true, final=true). + /// Returns (header bytes, mask key or null). + /// + public static (byte[] header, byte[]? key) CreateFrameHeader( + bool useMasking, bool compressed, int opcode, int payloadLength) + { + var fh = new byte[WsConstants.MaxFrameHeaderSize]; + var (n, key) = FillFrameHeader(fh, useMasking, + first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength); + return (fh[..n], key); + } + + /// + /// Fills a pre-allocated frame header buffer. + /// Returns (bytes written, mask key or null). + /// + public static (int written, byte[]? key) FillFrameHeader( + Span fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength) + { + byte b0 = first ? (byte)opcode : (byte)0; + if (final) b0 |= WsConstants.FinalBit; + if (compressed) b0 |= WsConstants.Rsv1Bit; + + byte b1 = 0; + if (useMasking) b1 |= WsConstants.MaskBit; + + int n; + switch (payloadLength) + { + case <= 125: + n = 2; + fh[0] = b0; + fh[1] = (byte)(b1 | (byte)payloadLength); + break; + case < 65536: + n = 4; + fh[0] = b0; + fh[1] = (byte)(b1 | 126); + BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength); + break; + default: + n = 10; + fh[0] = b0; + fh[1] = (byte)(b1 | 127); + BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength); + break; + } + + byte[]? key = null; + if (useMasking) + { + key = new byte[4]; + RandomNumberGenerator.Fill(key); + key.CopyTo(fh[n..]); + n += 4; + } + + return (n, key); + } + + /// + /// XOR masks a buffer with a 4-byte key. Applies in-place. + /// + public static void MaskBuf(ReadOnlySpan key, Span buf) + { + for (int i = 0; i < buf.Length; i++) + buf[i] ^= key[i & 3]; + } + + /// + /// XOR masks multiple contiguous buffers as if they were one. + /// + public static void MaskBufs(ReadOnlySpan key, List bufs) + { + int pos = 0; + foreach (var buf in bufs) + { + for (int j = 0; j < buf.Length; j++) + { + buf[j] ^= key[pos & 3]; + pos++; + } + } + } + + /// + /// Creates a close message payload: 2-byte status code + optional UTF-8 body. + /// Body truncated to fit MaxControlPayloadSize with "..." suffix. + /// + public static byte[] CreateCloseMessage(int status, string body) + { + if (body.Length > WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize) + { + body = body[..(WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize - 3)] + "..."; + } + + var bodyBytes = Encoding.UTF8.GetBytes(body); + var buf = new byte[WsConstants.CloseStatusSize + bodyBytes.Length]; + BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status); + bodyBytes.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize)); + return buf; + } + + /// + /// Builds a complete control frame (header + payload, optional masking). + /// + public static byte[] BuildControlFrame(int opcode, ReadOnlySpan payload, bool useMasking) + { + int headerSize = 2 + (useMasking ? 4 : 0); + var frame = new byte[headerSize + payload.Length]; + var span = frame.AsSpan(); + var (n, key) = FillFrameHeader(span, useMasking, + first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length); + if (payload.Length > 0) + { + payload.CopyTo(span[n..]); + if (useMasking && key != null) + MaskBuf(key, span[n..]); + } + + return frame; + } + + /// + /// Maps a ClientClosedReason to a WebSocket close status code. + /// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694. + /// + public static int MapCloseStatus(ClientClosedReason reason) => reason switch + { + ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure, + ClientClosedReason.AuthenticationTimeout or + ClientClosedReason.AuthenticationViolation or + ClientClosedReason.SlowConsumerPendingBytes or + ClientClosedReason.SlowConsumerWriteDeadline or + ClientClosedReason.MaxSubscriptionsExceeded or + ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation, + ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake, + ClientClosedReason.ParseError or + ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError, + ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig, + ClientClosedReason.WriteError or + ClientClosedReason.ReadError or + ClientClosedReason.StaleConnection or + ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway, + _ => WsConstants.CloseStatusInternalSrvError, + }; +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs b/tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs new file mode 100644 index 0000000..153b120 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs @@ -0,0 +1,152 @@ +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsFrameWriterTests +{ + [Fact] + public void CreateFrameHeader_SmallPayload_7BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 100); + header.Length.ShouldBe(2); + (header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set + (header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage); + (header[1] & 0x7F).ShouldBe(100); + } + + [Fact] + public void CreateFrameHeader_MediumPayload_16BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 1000); + header.Length.ShouldBe(4); + (header[1] & 0x7F).ShouldBe(126); + BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000); + } + + [Fact] + public void CreateFrameHeader_LargePayload_64BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 70000); + header.Length.ShouldBe(10); + (header[1] & 0x7F).ShouldBe(127); + BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL); + } + + [Fact] + public void CreateFrameHeader_WithMasking_Adds4ByteKey() + { + var (header, key) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 10); + header.Length.ShouldBe(6); // 2 header + 4 mask key + (header[1] & WsConstants.MaskBit).ShouldNotBe(0); + key.ShouldNotBeNull(); + key.Length.ShouldBe(4); + } + + [Fact] + public void CreateFrameHeader_Compressed_SetsRsv1Bit() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: true, + opcode: WsConstants.BinaryMessage, payloadLength: 10); + (header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0); + } + + [Fact] + public void MaskBuf_XorsCorrectly() + { + byte[] key = [0xAA, 0xBB, 0xCC, 0xDD]; + byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]; + byte[] expected = new byte[data.Length]; + for (int i = 0; i < data.Length; i++) + expected[i] = (byte)(data[i] ^ key[i & 3]); + + WsFrameWriter.MaskBuf(key, data); + data.ShouldBe(expected); + } + + [Fact] + public void MaskBuf_RoundTrip() + { + byte[] key = [0x12, 0x34, 0x56, 0x78]; + byte[] original = "Hello, WebSocket!"u8.ToArray(); + var data = original.ToArray(); + + WsFrameWriter.MaskBuf(key, data); + data.ShouldNotBe(original); + WsFrameWriter.MaskBuf(key, data); + data.ShouldBe(original); + } + + [Fact] + public void CreateCloseMessage_WithStatusAndBody() + { + var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure"); + msg.Length.ShouldBe(2 + "normal closure".Length); + BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000); + } + + [Fact] + public void CreateCloseMessage_LongBody_Truncated() + { + var longBody = new string('x', 200); + var msg = WsFrameWriter.CreateCloseMessage(1000, longBody); + msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize); + } + + [Fact] + public void MapCloseStatus_ClientClosed_NormalClosure() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed) + .ShouldBe(WsConstants.CloseStatusNormalClosure); + } + + [Fact] + public void MapCloseStatus_AuthTimeout_PolicyViolation() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout) + .ShouldBe(WsConstants.CloseStatusPolicyViolation); + } + + [Fact] + public void MapCloseStatus_ParseError_ProtocolError() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError) + .ShouldBe(WsConstants.CloseStatusProtocolError); + } + + [Fact] + public void MapCloseStatus_MaxPayload_MessageTooBig() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded) + .ShouldBe(WsConstants.CloseStatusMessageTooBig); + } + + [Fact] + public void BuildControlFrame_PingNomask() + { + var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false); + frame.Length.ShouldBe(2); + (frame[0] & WsConstants.FinalBit).ShouldNotBe(0); + (frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage); + (frame[1] & 0x7F).ShouldBe(0); + } + + [Fact] + public void BuildControlFrame_PongWithPayload() + { + byte[] payload = [1, 2, 3, 4]; + var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false); + frame.Length.ShouldBe(2 + 4); + frame[2..].ShouldBe(payload); + } +}