# WebSocket Support Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers-extended-cc:executing-plans to implement this plan task-by-task.

**Goal:** Port full WebSocket connection support from the Go NATS server to the .NET solution, enabling NATS clients to connect over WebSocket with compression, masking, origin checking, and cookie-based auth.

**Architecture:** A self-contained `WebSocket/` module under `src/NATS.Server/` with a custom frame parser (no System.Net.WebSockets). A `WsConnection` Stream wrapper integrates transparently with the existing `NatsClient` read/write loops. A second TCP accept loop in `NatsServer` handles the WebSocket port.

**Tech Stack:** .NET 10, System.IO.Compression (DeflateStream), System.Security.Cryptography (SHA1), xUnit 3, Shouldly

---

### Task 0: Add WebSocketOptions configuration

**Files:**
- Modify: `src/NATS.Server/NatsOptions.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs`

**Step 1: Write the failing test**

Create test file `tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs`:

```csharp
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WebSocketOptionsTests
{
    [Fact]
    public void DefaultOptions_PortIsZero_Disabled()
    {
        var opts = new WebSocketOptions();
        opts.Port.ShouldBe(0);
        opts.Host.ShouldBe("0.0.0.0");
        opts.Compression.ShouldBeFalse();
        opts.NoTls.ShouldBeFalse();
        opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
        opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
    }

    [Fact]
    public void NatsOptions_HasWebSocketProperty()
    {
        var opts = new NatsOptions();
        opts.WebSocket.ShouldNotBeNull();
        opts.WebSocket.Port.ShouldBe(0);
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocketOptionsTests" -v normal`
Expected: FAIL (`WebSocketOptions` type does not exist)

**Step 3: Write minimal implementation**

Add a new `WebSocketOptions` class to `src/NATS.Server/NatsOptions.cs`, plus a property on `NatsOptions`:

```csharp
public sealed class WebSocketOptions
{
    public string Host { get; set; } = "0.0.0.0";
    public int Port { get; set; }
    public string? Advertise { get; set; }
    public string? NoAuthUser { get; set; }
    public string? JwtCookie { get; set; }
    public string? UsernameCookie { get; set; }
    public string? PasswordCookie { get; set; }
    public string? TokenCookie { get; set; }
    public string? Username { get; set; }
    public string? Password { get; set; }
    public string? Token { get; set; }
    public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
    public bool NoTls { get; set; }
    public string? TlsCert { get; set; }
    public string? TlsKey { get; set; }
    public bool SameOrigin { get; set; }
    public List<string>? AllowedOrigins { get; set; }
    public bool Compression { get; set; }
    public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2);
    public TimeSpan? PingInterval { get; set; }
    public Dictionary<string, string>? Headers { get; set; }
}
```

Add to `NatsOptions`:

```csharp
public WebSocketOptions WebSocket { get; set; } = new();
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocketOptionsTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/NatsOptions.cs tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs
git commit -m "feat: add WebSocketOptions configuration class"
```

---

### Task 1: Add WsConstants

**Files:**
- Create: `src/NATS.Server/WebSocket/WsConstants.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs`

**Reference:** `golang/nats-server/server/websocket.go` lines 41-106

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs`:

```csharp
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsConstantsTests
{
    [Fact]
    public void OpCodes_MatchRfc6455()
    {
        WsConstants.TextMessage.ShouldBe(1);
        WsConstants.BinaryMessage.ShouldBe(2);
        WsConstants.CloseMessage.ShouldBe(8);
        WsConstants.PingMessage.ShouldBe(9);
        WsConstants.PongMessage.ShouldBe(10);
    }

    [Fact]
    public void FrameBits_MatchRfc6455()
    {
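        // Optional extra check: the reserved RSV2 and RSV3 bits (RFC 6455 Section 5.2)
        // sit just below RSV1 in the first header byte.
        WsConstants.Rsv2Bit.ShouldBe((byte)0x20);
        WsConstants.Rsv3Bit.ShouldBe((byte)0x10);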
        // The bits are declared as byte, so compare against byte values: Shouldly's
        // generic ShouldBe<T> would otherwise infer T = int and reject the byte receiver.
        WsConstants.FinalBit.ShouldBe((byte)0x80);
        WsConstants.Rsv1Bit.ShouldBe((byte)0x40);
        WsConstants.MaskBit.ShouldBe((byte)0x80);
    }

    [Fact]
    public void CloseStatusCodes_MatchRfc6455()
    {
        WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
        WsConstants.CloseStatusGoingAway.ShouldBe(1001);
        WsConstants.CloseStatusProtocolError.ShouldBe(1002);
        WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
        WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
    }

    [Theory]
    [InlineData(WsConstants.CloseMessage)]
    [InlineData(WsConstants.PingMessage)]
    [InlineData(WsConstants.PongMessage)]
    public void IsControlFrame_True(int opcode)
    {
        WsConstants.IsControlFrame(opcode).ShouldBeTrue();
    }

    [Theory]
    [InlineData(WsConstants.TextMessage)]
    [InlineData(WsConstants.BinaryMessage)]
    [InlineData(0)]
    public void IsControlFrame_False(int opcode)
    {
        WsConstants.IsControlFrame(opcode).ShouldBeFalse();
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConstantsTests" -v normal`
Expected: FAIL (`WsConstants` does not exist)

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsConstants.cs`:

```csharp
namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket protocol constants (RFC 6455).
/// Ported from golang/nats-server/server/websocket.go lines 41-106.
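/// Control opcodes occupy 0x8-0xF (the high bit of the 4-bit opcode is set),
/// which is why IsControlFrame only needs to test opcode >= CloseMessage.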
/// </summary>
public static class WsConstants
{
    // Opcodes (RFC 6455 Section 5.2)
    public const int TextMessage = 1;
    public const int BinaryMessage = 2;
    public const int CloseMessage = 8;
    public const int PingMessage = 9;
    public const int PongMessage = 10;
    public const int ContinuationFrame = 0;

    // Frame header bits
    public const byte FinalBit = 0x80; // 1 << 7
    public const byte Rsv1Bit = 0x40;  // 1 << 6 (compression, RFC 7692)
    public const byte Rsv2Bit = 0x20;  // 1 << 5
    public const byte Rsv3Bit = 0x10;  // 1 << 4
    public const byte MaskBit = 0x80;  // 1 << 7 (in second byte)

    // Frame size limits
    public const int MaxFrameHeaderSize = 14;
    public const int MaxControlPayloadSize = 125;
    public const int FrameSizeForBrowsers = 4096;
    public const int CompressThreshold = 64;
    public const int CloseStatusSize = 2;

    // Close status codes (RFC 6455 Section 11.7)
    public const int CloseStatusNormalClosure = 1000;
    public const int CloseStatusGoingAway = 1001;
    public const int CloseStatusProtocolError = 1002;
    public const int CloseStatusUnsupportedData = 1003;
    public const int CloseStatusNoStatusReceived = 1005;
    public const int CloseStatusInvalidPayloadData = 1007;
    public const int CloseStatusPolicyViolation = 1008;
    public const int CloseStatusMessageTooBig = 1009;
    public const int CloseStatusInternalSrvError = 1011;
    public const int CloseStatusTlsHandshake = 1015;

    // Compression constants (RFC 7692)
    public const string PmcExtension = "permessage-deflate";
    public const string PmcSrvNoCtx = "server_no_context_takeover";
    public const string PmcCliNoCtx = "client_no_context_takeover";
    public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}";
    public static readonly string PmcFullResponse =
        $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n";

    // Header names
    public const string NoMaskingHeader = "Nats-No-Masking";
    public const string NoMaskingValue = "true";
    public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n";
    public const string XForwardedForHeader = "X-Forwarded-For";

    // Path routing
    public const string ClientPath = "/";
    public const string LeafNodePath = "/leafnode";
    public const string MqttPath = "/mqtt";

    // WebSocket GUID (RFC 6455 Section 1.3)
    public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray();

    // Compression trailer (RFC 7692 Section 7.2.2)
    public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff];

    // Decompression trailer appended before decompressing
    public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];

    public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;
}

public enum WsClientKind
{
    Client,
    Leaf,
    Mqtt,
}
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConstantsTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsConstants.cs tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs
git commit -m "feat: add WebSocket constants (RFC 6455/7692)"
```

---

### Task 2: Add WsOriginChecker

**Files:**
- Create: `src/NATS.Server/WebSocket/WsOriginChecker.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs`

**Reference:** `golang/nats-server/server/websocket.go` lines 933-1000 (`checkOrigin`, `wsGetHostAndPort`)

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs`:

```csharp
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsOriginCheckerTests
{
    [Fact]
    public void NoOriginHeader_Accepted()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
            .ShouldBeNull();
    }

    [Fact]
    public void NeitherSameNorList_AlwaysAccepted()
    {
        var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
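        // Optional extra check: with neither rule configured, CheckOrigin returns before
        // parsing the header, so even a malformed Origin value is accepted.
        checker.CheckOrigin("not a valid origin", "localhost:4222", false)
            .ShouldBeNull();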
        checker.CheckOrigin("https://evil.com", "localhost:4222", false)
            .ShouldBeNull();
    }

    [Fact]
    public void SameOrigin_Match()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
            .ShouldBeNull();
    }

    [Fact]
    public void SameOrigin_Mismatch()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin("http://other:4222", "localhost:4222", false)
            .ShouldNotBeNull();
    }

    [Fact]
    public void SameOrigin_DefaultPort_Http()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        // No port in origin means port 80 for http
        checker.CheckOrigin("http://localhost", "localhost:80", false)
            .ShouldBeNull();
    }

    [Fact]
    public void SameOrigin_DefaultPort_Https()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin("https://localhost", "localhost:443", true)
            .ShouldBeNull();
    }

    [Fact]
    public void AllowedOrigins_Match()
    {
        var checker = new WsOriginChecker(sameOrigin: false,
            allowedOrigins: ["https://app.example.com"]);
        checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
            .ShouldBeNull();
    }

    [Fact]
    public void AllowedOrigins_Mismatch()
    {
        var checker = new WsOriginChecker(sameOrigin: false,
            allowedOrigins: ["https://app.example.com"]);
        checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
            .ShouldNotBeNull();
    }

    [Fact]
    public void AllowedOrigins_SchemeMismatch()
    {
        var checker = new WsOriginChecker(sameOrigin: false,
            allowedOrigins: ["https://app.example.com"]);
        checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
            .ShouldNotBeNull();
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsOriginCheckerTests" -v normal`
Expected: FAIL (`WsOriginChecker` does not exist)

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsOriginChecker.cs`:

```csharp
namespace NATS.Server.WebSocket;

/// <summary>
/// Validates WebSocket Origin headers per RFC 6455 Section 10.2.
/// Ported from golang/nats-server/server/websocket.go lines 933-1000.
/// </summary>
public sealed class WsOriginChecker
{
    private readonly bool _sameOrigin;
    private readonly Dictionary<string, AllowedOrigin>? _allowedOrigins;

    public WsOriginChecker(bool sameOrigin, List<string>? allowedOrigins)
    {
        _sameOrigin = sameOrigin;
        if (allowedOrigins is { Count: > 0 })
        {
            _allowedOrigins = new Dictionary<string, AllowedOrigin>(StringComparer.OrdinalIgnoreCase);
            foreach (var ao in allowedOrigins)
            {
                if (Uri.TryCreate(ao, UriKind.Absolute, out var uri))
                {
                    var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port);
                    _allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port);
                }
            }
        }
    }

    /// <summary>
    /// Returns null if the origin is allowed, or an error message if it is rejected.
    /// </summary>
    public string? CheckOrigin(string? origin, string requestHost, bool isTls)
    {
        if (!_sameOrigin && _allowedOrigins == null)
            return null;

        if (string.IsNullOrEmpty(origin))
            return null;

        if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
            return $"invalid origin: {origin}";

        var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port);

        if (_sameOrigin)
        {
            var (rh, rp) = ParseHostPort(requestHost, isTls);
            if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp)
                return "not same origin";
        }

        if (_allowedOrigins != null)
        {
            if (!_allowedOrigins.TryGetValue(oh, out var allowed) ||
                !string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) ||
                op != allowed.Port)
            {
                return "not in the allowed list";
            }
        }

        return null;
    }

    private static (string host, int port) GetHostAndPort(bool tls, string host, int port)
    {
        if (port <= 0)
            port = tls ? 443 : 80;
        return (host.ToLowerInvariant(), port);
    }

    private static (string host, int port) ParseHostPort(string hostPort, bool isTls)
    {
        var colonIdx = hostPort.LastIndexOf(':');
        if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port))
            return (hostPort[..colonIdx].ToLowerInvariant(), port);
        return (hostPort.ToLowerInvariant(), isTls ? 443 : 80);
    }

    private readonly record struct AllowedOrigin(string Scheme, int Port);
}
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsOriginCheckerTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsOriginChecker.cs tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs
git commit -m "feat: add WebSocket origin checker"
```

---

### Task 3: Add WsFrameWriter (frame header construction, masking, control frames)

**Files:**
- Create: `src/NATS.Server/WebSocket/WsFrameWriter.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs`

**Reference:** `golang/nats-server/server/websocket.go` lines 543-726 (`wsFillFrameHeader`, `wsCreateFrameHeader`, `wsMaskBuf`, `wsCreateCloseMessage`, `wsEnqueueControlMessageLocked`)

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs`:

```csharp
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsFrameWriterTests
{
    [Fact]
    public void CreateFrameHeader_SmallPayload_7BitLength()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false,
            compressed: false,
            opcode: WsConstants.BinaryMessage,
            payloadLength: 100);

        header.Length.ShouldBe(2);
        (header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
        (header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
        (header[1] & 0x7F).ShouldBe(100);
    }

    [Fact]
    public void CreateFrameHeader_MediumPayload_16BitLength()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false,
            compressed: false,
            opcode: WsConstants.BinaryMessage,
            payloadLength: 1000);

        header.Length.ShouldBe(4);
        (header[1] & 0x7F).ShouldBe(126);
        BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
    }

    [Fact]
    public void CreateFrameHeader_LargePayload_64BitLength()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false,
            compressed: false,
            opcode: WsConstants.BinaryMessage,
            payloadLength: 70000);

        header.Length.ShouldBe(10);
        (header[1] & 0x7F).ShouldBe(127);
        BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
    }

    [Fact]
    public void CreateFrameHeader_WithMasking_Adds4ByteKey()
    {
        var (header, key) = WsFrameWriter.CreateFrameHeader(
            useMasking: true,
            compressed: false,
            opcode: WsConstants.BinaryMessage,
            payloadLength: 10);

        header.Length.ShouldBe(6); // 2 header + 4 mask key
        (header[1] & WsConstants.MaskBit).ShouldNotBe(0);
        key.ShouldNotBeNull();
        key.Length.ShouldBe(4);
    }

    [Fact]
    public void CreateFrameHeader_Compressed_SetsRsv1Bit()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false,
            compressed: true,
            opcode: WsConstants.BinaryMessage,
            payloadLength: 10);

        (header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
    }

    [Fact]
    public void MaskBuf_XorsCorrectly()
    {
        byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
        byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        byte[] expected = new byte[data.Length];
        for (int i = 0; i < data.Length; i++)
            expected[i] = (byte)(data[i] ^ key[i & 3]);

        WsFrameWriter.MaskBuf(key, data);
        data.ShouldBe(expected);
    }

    [Fact]
    public void MaskBuf_RoundTrip()
    {
        byte[] key = [0x12, 0x34, 0x56, 0x78];
        byte[] original = "Hello, WebSocket!"u8.ToArray();
        var data = original.ToArray();

        WsFrameWriter.MaskBuf(key, data);
        data.ShouldNotBe(original);

        WsFrameWriter.MaskBuf(key, data);
        data.ShouldBe(original);
    }

    [Fact]
    public void CreateCloseMessage_WithStatusAndBody()
    {
        var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
        msg.Length.ShouldBe(2 + "normal closure".Length);
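        // Optional extra check: the UTF-8 reason text follows the 2-byte status code.
        System.Text.Encoding.UTF8.GetString(msg.AsSpan(WsConstants.CloseStatusSize))
            .ShouldBe("normal closure");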
        BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
    }

    [Fact]
    public void CreateCloseMessage_LongBody_Truncated()
    {
        var longBody = new string('x', 200);
        var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
        msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
    }

    [Fact]
    public void MapCloseStatus_ClientClosed_NormalClosure()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
            .ShouldBe(WsConstants.CloseStatusNormalClosure);
    }

    [Fact]
    public void MapCloseStatus_AuthTimeout_PolicyViolation()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
            .ShouldBe(WsConstants.CloseStatusPolicyViolation);
    }

    [Fact]
    public void MapCloseStatus_ParseError_ProtocolError()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
            .ShouldBe(WsConstants.CloseStatusProtocolError);
    }

    [Fact]
    public void MapCloseStatus_MaxPayload_MessageTooBig()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
            .ShouldBe(WsConstants.CloseStatusMessageTooBig);
    }

    [Fact]
    public void BuildControlFrame_PingNomask()
    {
        var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
        frame.Length.ShouldBe(2);
        (frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
        (frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
        (frame[1] & 0x7F).ShouldBe(0);
    }

    [Fact]
    public void BuildControlFrame_PongWithPayload()
    {
        byte[] payload = [1, 2, 3, 4];
        var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
        frame.Length.ShouldBe(2 + 4);
        frame[2..].ShouldBe(payload);
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameWriterTests" -v normal`
Expected: FAIL (`WsFrameWriter` does not exist)

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsFrameWriter.cs`:

```csharp
using System.Buffers.Binary;
using System.Security.Cryptography;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket frame construction, masking, and control message creation.
/// Ported from golang/nats-server/server/websocket.go lines 543-726.
/// </summary>
public static class WsFrameWriter
{
    /// <summary>
    /// Creates a complete frame header for a single-frame message (first=true, final=true).
    /// Returns (header bytes, mask key or null).
    /// </summary>
    public static (byte[] header, byte[]? key) CreateFrameHeader(
        bool useMasking, bool compressed, int opcode, int payloadLength)
    {
        var fh = new byte[WsConstants.MaxFrameHeaderSize];
        var (n, key) = FillFrameHeader(fh, useMasking,
            first: true, final: true, compressed: compressed,
            opcode: opcode, payloadLength: payloadLength);
        return (fh[..n], key);
    }

    /// <summary>
    /// Fills a pre-allocated frame header buffer.
    /// Returns (bytes written, mask key or null).
    /// </summary>
    public static (int written, byte[]? key) FillFrameHeader(
        Span<byte> fh, bool useMasking, bool first, bool final,
        bool compressed, int opcode, int payloadLength)
    {
        byte b0 = first ? (byte)opcode : (byte)0;
        if (final) b0 |= WsConstants.FinalBit;
        if (compressed) b0 |= WsConstants.Rsv1Bit;

        byte b1 = 0;
        if (useMasking) b1 |= WsConstants.MaskBit;

        int n;
        switch (payloadLength)
        {
            case <= 125:
                n = 2;
                fh[0] = b0;
                fh[1] = (byte)(b1 | (byte)payloadLength);
                break;
            case < 65536:
                n = 4;
                fh[0] = b0;
                fh[1] = (byte)(b1 | 126);
                BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength);
                break;
            default:
                n = 10;
                fh[0] = b0;
                fh[1] = (byte)(b1 | 127);
                BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength);
                break;
        }

        byte[]? key = null;
        if (useMasking)
        {
            key = new byte[4];
            RandomNumberGenerator.Fill(key);
            key.CopyTo(fh[n..]);
            n += 4;
        }

        return (n, key);
    }

    /// <summary>
    /// XOR masks a buffer with a 4-byte key. Applies in-place.
    /// </summary>
    public static void MaskBuf(ReadOnlySpan<byte> key, Span<byte> buf)
    {
        for (int i = 0; i < buf.Length; i++)
            buf[i] ^= key[i & 3];
    }

    /// <summary>
    /// XOR masks multiple contiguous buffers as if they were one.
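    /// The key position carries across buffer boundaries, so the result is the same as
    /// calling MaskBuf once on the concatenation of all the buffers.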
    /// </summary>
    public static void MaskBufs(ReadOnlySpan<byte> key, List<byte[]> bufs)
    {
        int pos = 0;
        foreach (var buf in bufs)
        {
            for (int j = 0; j < buf.Length; j++)
            {
                buf[j] ^= key[pos & 3];
                pos++;
            }
        }
    }

    /// <summary>
    /// Creates a close message payload: 2-byte status code + optional UTF-8 body.
    /// The body is truncated to fit MaxControlPayloadSize, with a "..." suffix.
    /// </summary>
    public static byte[] CreateCloseMessage(int status, string body)
    {
        const int maxBody = WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize;
        var bodyBytes = Encoding.UTF8.GetBytes(body);
        if (bodyBytes.Length > maxBody)
        {
            // Limit by encoded byte count (not char count) so multi-byte characters cannot
            // push the payload past 125 bytes. Like Go, this may split a UTF-8 sequence.
            var truncated = new byte[maxBody];
            bodyBytes.AsSpan(0, maxBody - 3).CopyTo(truncated);
            "..."u8.CopyTo(truncated.AsSpan(maxBody - 3));
            bodyBytes = truncated;
        }

        var buf = new byte[WsConstants.CloseStatusSize + bodyBytes.Length];
        BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status);
        bodyBytes.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize));
        return buf;
    }

    /// <summary>
    /// Builds a complete control frame (header + payload, optional masking).
    /// </summary>
    public static byte[] BuildControlFrame(int opcode, ReadOnlySpan<byte> payload, bool useMasking)
    {
        // Control payloads are at most 125 bytes, so no extended length field is needed.
        int headerSize = 2 + (useMasking ? 4 : 0);
        var frame = new byte[headerSize + payload.Length];
        var span = frame.AsSpan();
        var (n, key) = FillFrameHeader(span, useMasking,
            first: true, final: true, compressed: false,
            opcode: opcode, payloadLength: payload.Length);

        if (payload.Length > 0)
        {
            payload.CopyTo(span[n..]);
            if (useMasking && key != null)
                MaskBuf(key, span[n..]);
        }

        return frame;
    }

    /// <summary>
    /// Maps a ClientClosedReason to a WebSocket close status code.
    /// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694.
    /// </summary>
    public static int MapCloseStatus(ClientClosedReason reason) => reason switch
    {
        ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure,
        ClientClosedReason.AuthenticationTimeout
            or ClientClosedReason.AuthenticationViolation
            or ClientClosedReason.SlowConsumerPendingBytes
            or ClientClosedReason.SlowConsumerWriteDeadline
            or ClientClosedReason.MaxSubscriptionsExceeded
            or ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation,
        ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake,
        ClientClosedReason.ParseError
            or ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError,
        ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig,
        ClientClosedReason.WriteError
            or ClientClosedReason.ReadError
            or ClientClosedReason.StaleConnection
            or ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway,
        _ => WsConstants.CloseStatusInternalSrvError,
    };
}
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameWriterTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsFrameWriter.cs tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs
git commit -m "feat: add WebSocket frame writer with masking and close status mapping"
```

---

### Task 4: Add WsReadInfo (frame reader state machine)

**Files:**
- Create: `src/NATS.Server/WebSocket/WsReadInfo.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs`

**Reference:** `golang/nats-server/server/websocket.go` lines 156-440 (`wsReadInfo`, `wsRead`, `unmask`, `decompress`)

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs`:

```csharp
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsFrameReadTests
{
    /// <summary>Helper: build a single frame (an unmasked, final binary frame by default).</summary>
    private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false,
        int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
    {
        int payloadLen = payload.Length;
        byte b0 = (byte)opcode;
        if (fin) b0 |= WsConstants.FinalBit;
        if (compressed) b0 |= WsConstants.Rsv1Bit;
        byte b1 = 0;
        if (mask) b1 |= WsConstants.MaskBit;

        byte[] lenBytes;
        if (payloadLen <= 125)
        {
            lenBytes = [(byte)(b1 | (byte)payloadLen)];
        }
        else if (payloadLen < 65536)
        {
            lenBytes = new byte[3];
            lenBytes[0] = (byte)(b1 | 126);
            BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
        }
        else
        {
            lenBytes = new byte[9];
            lenBytes[0] = (byte)(b1 | 127);
            BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
        }

        int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
        var frame = new byte[totalLen];
        frame[0] = b0;
        lenBytes.CopyTo(frame.AsSpan(1));
        int pos = 1 + lenBytes.Length;

        if (mask && maskKey != null)
        {
            maskKey.CopyTo(frame.AsSpan(pos));
            pos += 4;
            var maskedPayload = payload.ToArray();
            WsFrameWriter.MaskBuf(maskKey, maskedPayload);
            maskedPayload.CopyTo(frame.AsSpan(pos));
        }
        else
        {
            payload.CopyTo(frame.AsSpan(pos));
        }

        return frame;
    }

    [Fact]
    public void ReadSingleUnmaskedFrame()
    {
        var payload = "Hello"u8.ToArray();
        var frame = BuildFrame(payload);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    [Fact]
    public void ReadMaskedFrame()
    {
        var payload = "Hello"u8.ToArray();
        byte[] key = [0x37, 0xFA, 0x21, 0x3D];
        var frame = BuildFrame(payload, mask: true, maskKey: key);

        var readInfo = new WsReadInfo(expectMask: true);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    [Fact]
    public void Read16BitLengthFrame()
    {
        var payload = new byte[200];
        Random.Shared.NextBytes(payload);
        var frame = BuildFrame(payload);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    [Fact]
    public void ReadPingFrame_ReturnsPongAction()
    {
        var frame = BuildFrame([], opcode: WsConstants.PingMessage);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0); // control frames don't produce payload
        readInfo.PendingControlFrames.Count.ShouldBe(1);
        readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
    }

    [Fact]
    public void ReadCloseFrame_ReturnsCloseAction()
    {
        var closePayload = new byte[2];
        BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
        var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0);
        readInfo.CloseReceived.ShouldBeTrue();
        readInfo.CloseStatus.ShouldBe(1000);
    }

    [Fact]
    public void ReadPongFrame_NoAction()
    {
        var frame = BuildFrame([], opcode: WsConstants.PongMessage);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0);
        readInfo.PendingControlFrames.Count.ShouldBe(0);
    }

    [Fact]
    public void Unmask_Optimized_8ByteChunks()
    {
        byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
        var original = new byte[32];
        Random.Shared.NextBytes(original);
        var masked = original.ToArray();

        // Mask it
        for (int i = 0; i < masked.Length; i++)
            masked[i] ^= key[i & 3];

        // Unmask using the state machine
        var info = new WsReadInfo(expectMask: true);
        info.SetMaskKey(key);
        info.Unmask(masked);

        masked.ShouldBe(original);
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameReadTests" -v normal`
Expected: FAIL (`WsReadInfo` does not exist)

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsReadInfo.cs`:

```csharp
using System.Buffers.Binary;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// Per-connection WebSocket frame reading state machine.
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
/// </summary>
public struct WsReadInfo
{
    public int Remaining;
    public bool FrameStart;
    public bool FirstFrame;
    public bool FrameCompressed;
    public bool ExpectMask;
    public byte MaskKeyPos;
    public byte[] MaskKey;
    public List<byte[]>? CompressedBuffers;
    public int CompressedOffset;

    // Control frame outputs
    public List<ControlFrameAction> PendingControlFrames;
    public bool CloseReceived;
    public int CloseStatus;
    public string? CloseBody;

    public WsReadInfo(bool expectMask)
    {
        Remaining = 0;
        FrameStart = true;
        FirstFrame = true;
        FrameCompressed = false;
        ExpectMask = expectMask;
        MaskKeyPos = 0;
        MaskKey = new byte[4];
        CompressedBuffers = null;
        CompressedOffset = 0;
        PendingControlFrames = [];
        CloseReceived = false;
        CloseStatus = 0;
        CloseBody = null;
    }

    public void SetMaskKey(ReadOnlySpan<byte> key)
    {
        key[..4].CopyTo(MaskKey);
        MaskKeyPos = 0;
    }

    /// <summary>
    /// Unmask buffer in-place using current mask key and position.
    /// Uses 8-byte chunks when the buffer is at least 16 bytes long.
    /// Ported from websocket.go lines 509-536.
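    /// The bulk path is valid because the key repeats every 4 bytes: the same 8-byte
    /// key window applies to every 8-byte chunk, and after a multiple of 8 bytes the
    /// key position is unchanged, so the tail loop can resume from p.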
    /// </summary>
    public void Unmask(Span<byte> buf)
    {
        int p = MaskKeyPos;
        if (buf.Length < 16)
        {
            for (int i = 0; i < buf.Length; i++)
            {
                buf[i] ^= MaskKey[p & 3];
                p++;
            }
            MaskKeyPos = (byte)(p & 3);
            return;
        }

        // Build 8-byte key for bulk XOR
        Span<byte> k = stackalloc byte[8];
        for (int i = 0; i < 8; i++)
            k[i] = MaskKey[(p + i) & 3];
        ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);

        int n = (buf.Length / 8) * 8;
        for (int i = 0; i < n; i += 8)
        {
            ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
            tmp ^= km;
            BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
        }

        // Handle remaining bytes
        var tail = buf[n..];
        for (int i = 0; i < tail.Length; i++)
        {
            tail[i] ^= MaskKey[p & 3];
            p++;
        }
        MaskKeyPos = (byte)(p & 3);
    }

    /// <summary>
    /// Reads up to <paramref name="available"/> bytes from the stream and decodes the
    /// WebSocket frames they contain. Returns the decoded payloads.
    /// Ported from websocket.go lines 208-351.
    /// </summary>
    public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload)
    {
        var bufs = new List<byte[]>();
        var buf = new byte[available];
        int bytesRead = 0;

        // Fill the buffer from the stream
        while (bytesRead < available)
        {
            int n = stream.Read(buf, bytesRead, available - bytesRead);
            if (n == 0) break;
            bytesRead += n;
        }

        int pos = 0;
        int max = bytesRead;

        while (pos < max)
        {
            if (r.FrameStart)
            {
                if (pos >= max) break;

                byte b0 = buf[pos];
                int frameType = b0 & 0x0F;
                bool final = (b0 & WsConstants.FinalBit) != 0;
                bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
                pos++;

                // Read second byte
                var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
                pos = newPos;
                byte b1 = b1Buf[0];

                // Check mask bit
                if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
                    throw new InvalidOperationException("mask bit missing");

                r.Remaining = b1 & 0x7F;

                // Validate frame types
                if (WsConstants.IsControlFrame(frameType))
                {
                    if (r.Remaining > WsConstants.MaxControlPayloadSize)
                        throw new InvalidOperationException("control frame length too large");
                    if (!final)
                        throw new InvalidOperationException("control frame does not have final bit set");
                }
                else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
                {
                    if (!r.FirstFrame)
                        throw new InvalidOperationException("new message before previous finished");
                    r.FirstFrame = final;
                    r.FrameCompressed = compressed;
                }
                else if (frameType == WsConstants.ContinuationFrame)
                {
                    if (r.FirstFrame || compressed)
                        throw new InvalidOperationException("invalid continuation frame");
                    r.FirstFrame = final;
                }
                else
                {
                    throw new InvalidOperationException($"unknown opcode {frameType}");
                }

                // Extended payload length
                switch (r.Remaining)
                {
                    case 126:
                    {
                        var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
                        pos = p2;
                        r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
                        break;
                    }
                    case 127:
                    {
                        var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
                        pos = p2;
                        ulong len64 = BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
                        // Guard the cast: a length above int.MaxValue would wrap to a negative int.
                        if (len64 > int.MaxValue)
                            throw new InvalidOperationException("frame payload length too large");
                        r.Remaining = (int)len64;
                        break;
                    }
                }

                // Read mask key
                if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0)
                {
                    var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
                    pos = p2;
                    keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
                    r.MaskKeyPos = 0;
                }

                // Handle control frames
                if (WsConstants.IsControlFrame(frameType))
                {
                    pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max);
                    continue;
                }

                r.FrameStart = false;
            }

            if (pos < max)
            {
                int n = r.Remaining;
                if (pos + n > max) n = max - pos;

                var payloadSlice = buf.AsSpan(pos, n).ToArray();
                pos += n;
                r.Remaining -= n;

                if (r.ExpectMask)
                    r.Unmask(payloadSlice);

                bool addToBufs = true;
                if (r.FrameCompressed)
                {
                    addToBufs = false;
                    r.CompressedBuffers ??= [];
                    r.CompressedBuffers.Add(payloadSlice);

                    if (r.FirstFrame && r.Remaining == 0)
                    {
                        var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
                        r.CompressedBuffers = null;
                        r.FrameCompressed = false;
                        addToBufs = true;
                        payloadSlice = decompressed;
                    }
                }

                if (addToBufs && payloadSlice.Length > 0)
                    bufs.Add(payloadSlice);

                if (r.Remaining == 0)
                    r.FrameStart = true;
            }
        }

        return bufs;
    }

    private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream,
        byte[] buf, int pos, int max)
    {
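        // RFC 6455 Section 5.5: control frames carry at most 125 payload bytes and must
        // not be fragmented. ReadFrames has already rejected violations of both rules, so
        // this debug-only assertion just documents the precondition.
        System.Diagnostics.Debug.Assert(r.Remaining <= WsConstants.MaxControlPayloadSize);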
        byte[]? payload = null;
        if (r.Remaining > 0)
        {
            var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
            pos = newPos;
            payload = payloadBuf;
            if (r.ExpectMask)
                r.Unmask(payload);
            r.Remaining = 0;
        }

        switch (frameType)
        {
            case WsConstants.CloseMessage:
                r.CloseReceived = true;
                r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
                if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
                {
                    r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
                    if (payload.Length > WsConstants.CloseStatusSize)
                        r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
                }
                // Queue a close frame echoing the peer's status code
                if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
                {
                    var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
                    r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
                }
                break;

            case WsConstants.PingMessage:
                // Answer with a pong carrying the same payload
                r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
                break;

            case WsConstants.PongMessage:
                // Nothing to do
                break;
        }

        return pos;
    }

    /// <summary>
    /// Gets needed bytes from buffer or reads from stream.
    /// Ported from websocket.go lines 178-193.
    /// </summary>
    private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
    {
        int avail = max - pos;
        if (avail >= needed)
            return (buf[pos..(pos + needed)], pos + needed);

        // Not enough buffered bytes: copy what is left, then read the rest from the stream
        var b = new byte[needed];
        int start = 0;
        if (avail > 0)
        {
            Buffer.BlockCopy(buf, pos, b, 0, avail);
            start = avail;
        }

        while (start < needed)
        {
            int n = stream.Read(b, start, needed - start);
            if (n == 0)
                throw new IOException("unexpected end of stream");
            start += n;
        }

        return (b, pos + avail);
    }
}

public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameReadTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsReadInfo.cs tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
git commit -m "feat: add WebSocket frame reader state machine"
```

---

### Task 5: Add WsCompression (permessage-deflate)

**Files:**
- Create: `src/NATS.Server/WebSocket/WsCompression.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs`

**Reference:** `golang/nats-server/server/websocket.go` lines 403-440 (decompress), lines 1391-1466 (compress)

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs`:

```csharp
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsCompressionTests
{
    [Fact]
    public void CompressDecompress_RoundTrip()
    {
        var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
        var compressed = WsCompression.Compress(original);
        compressed.ShouldNotBeNull();
        compressed.Length.ShouldBeGreaterThan(0);

        var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
        decompressed.ShouldBe(original);
    }

    [Fact]
    public void Decompress_ExceedsMaxPayload_Throws()
    {
        var original = new byte[1000];
        Random.Shared.NextBytes(original);
        var compressed = WsCompression.Compress(original);

        Should.Throw<InvalidOperationException>(() => WsCompression.Decompress([compressed], maxPayload: 100));
    }

    [Fact]
    public void Compress_RemovesTrailing4Bytes()
    {
        var data = new byte[200];
        Random.Shared.NextBytes(data);
        var compressed = WsCompression.Compress(data);

        // The compressed data should be valid for decompression when we add the trailer back
        var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
        decompressed.ShouldBe(data);
    }

    [Fact]
    public void Decompress_MultipleBuffers()
    {
        var original = new byte[500];
        Random.Shared.NextBytes(original);
        var compressed = WsCompression.Compress(original);

        // Split compressed data into multiple chunks
        int mid = compressed.Length / 2;
        var chunk1 = compressed[..mid];
        var chunk2 = compressed[mid..];

        var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
        decompressed.ShouldBe(original);
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsCompressionTests" -v normal`
Expected: FAIL

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsCompression.cs`:

```csharp
using System.IO.Compression;

namespace NATS.Server.WebSocket;

/// <summary>
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
/// </summary>
public static class WsCompression
{
    /// <summary>
    /// Compresses data using deflate. Removes the trailing 4 bytes (sync marker)
    /// per RFC 7692 Section 7.2.1.
    /// </summary>
    public static byte[] Compress(ReadOnlySpan<byte> data)
    {
        using var output = new MemoryStream();
        using var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true);
        deflate.Write(data);

        // Flush() performs a sync flush, so the output now ends with 0x00 0x00 0xff 0xff.
        // Snapshot the bytes before the stream is disposed: Dispose() appends a final
        // block after the marker, and stripping 4 bytes from that would corrupt the data.
        deflate.Flush();
        var compressed = output.ToArray();

        // Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692
        if (compressed.Length >= 4)
            return compressed[..^4];
        return compressed;
    }

    /// <summary>
    /// Decompresses collected compressed buffers.
    /// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2.
    /// </summary>
    public static byte[] Decompress(List<byte[]> compressedBuffers, int maxPayload)
    {
        if (maxPayload <= 0)
            maxPayload = 1024 * 1024; // Default 1MB

        // Concatenate all compressed buffers + trailer
        int totalLen = 0;
        foreach (var buf in compressedBuffers)
            totalLen += buf.Length;
        totalLen += WsConstants.DecompressTrailer.Length;

        var combined = new byte[totalLen];
        int offset = 0;
        foreach (var buf in compressedBuffers)
        {
            buf.CopyTo(combined, offset);
            offset += buf.Length;
        }
        WsConstants.DecompressTrailer.CopyTo(combined, offset);

        using var input = new MemoryStream(combined);
        using var deflate = new DeflateStream(input, CompressionMode.Decompress);
        using var output = new MemoryStream();

        var readBuf = new byte[4096];
        int totalRead = 0;
        int n;
        while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0)
        {
            totalRead += n;
            // Stop as soon as the inflated size exceeds the limit
            if (totalRead > maxPayload)
                throw new InvalidOperationException("decompressed data exceeds maximum payload size");
            output.Write(readBuf, 0, n);
        }

        return output.ToArray();
    }
}
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsCompressionTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsCompression.cs tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs
git commit -m "feat: add WebSocket permessage-deflate compression"
```

---

### Task 6: Add WsUpgrade (HTTP upgrade handshake)

**Files:**
- Create: `src/NATS.Server/WebSocket/WsUpgrade.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs`

**Reference:** `golang/nats-server/server/websocket.go` lines 731-917 (`wsUpgrade`, `wsHeaderContains`, `wsAcceptKey`, `wsPMCExtensionSupport`)

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs`:

```csharp
using System.Text;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsUpgradeTests
{
    private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
    {
        // Use explicit CRLF: AppendLine emits "\n" on Linux, and the parser waits for "\r\n\r\n"
        var sb = new StringBuilder();
        sb.Append($"GET {path} HTTP/1.1\r\n");
        sb.Append("Host: localhost:4222\r\n");
        sb.Append("Upgrade: websocket\r\n");
        sb.Append("Connection: Upgrade\r\n");
        sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
        sb.Append("Sec-WebSocket-Version: 13\r\n");
        if (extraHeaders != null)
            sb.Append(extraHeaders);
        sb.Append("\r\n");
        return sb.ToString();
    }

    [Fact]
    public async Task ValidUpgrade_Returns101()
    {
        var request = BuildValidRequest();
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Client);
        var response = ReadResponse(outputStream);
        response.ShouldContain("HTTP/1.1 101");
        response.ShouldContain("Upgrade: websocket");
        response.ShouldContain("Sec-WebSocket-Accept:");
    }

    [Fact]
    public async Task MissingUpgradeHeader_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task MissingHost_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task WrongVersion_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task LeafNodePath_ReturnsLeafKind()
    {
        var request = BuildValidRequest("/leafnode");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Leaf);
    }

    [Fact]
    public async Task MqttPath_ReturnsMqttKind()
    {
        var request = BuildValidRequest("/mqtt");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Mqtt);
    }

    [Fact]
    public async Task CompressionNegotiation_WhenEnabled()
    {
        var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });

        result.Success.ShouldBeTrue();
        result.Compress.ShouldBeTrue();
        ReadResponse(outputStream).ShouldContain("permessage-deflate");
    }

    [Fact]
    public async Task CompressionNegotiation_WhenDisabled()
    {
        var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });

        result.Success.ShouldBeTrue();
        result.Compress.ShouldBeFalse();
    }

    [Fact]
    public async Task NoMaskingHeader_ForLeaf()
    {
        var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.MaskRead.ShouldBeFalse();
    }

    [Fact]
    public async Task BrowserDetection_Mozilla()
    {
        var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Browser.ShouldBeTrue();
    }

    [Fact]
    public async Task SafariDetection_NoCompFrag()
    {
        var request = BuildValidRequest(extraHeaders:
            "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
            $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });

        result.Success.ShouldBeTrue();
        result.NoCompFrag.ShouldBeTrue();
    }

    [Fact]
    public void AcceptKey_MatchesRfc6455Example()
    {
        // RFC 6455 Section 4.2.2 example
        var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
        key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    [Fact]
    public async Task CookieExtraction()
    {
        var request = BuildValidRequest(extraHeaders: "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);
        var opts = new WebSocketOptions
        {
            NoTls = true,
            JwtCookie = "jwt_token",
            UsernameCookie = "nats_user",
            PasswordCookie = "nats_pass",
        };

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);

        result.Success.ShouldBeTrue();
        result.CookieJwt.ShouldBe("my-jwt");
        result.CookieUsername.ShouldBe("admin");
        result.CookiePassword.ShouldBe("secret");
    }

    [Fact]
    public async Task XForwardedFor_ExtractsClientIp()
    {
        var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.ClientIp.ShouldBe("192.168.1.100");
    }

    [Fact]
    public async Task PostMethod_Returns405()
    {
        var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("405");
    }

    // Helper: create a readable input stream and writable output stream
    private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
    {
        var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
        return (new MemoryStream(inputBytes), new MemoryStream());
    }

    private static string ReadResponse(MemoryStream output)
    {
        output.Position = 0;
        return Encoding.ASCII.GetString(output.ToArray());
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsUpgradeTests" -v normal`
Expected: FAIL

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsUpgrade.cs`:

```csharp
using System.Net;
using System.Security.Cryptography;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
    /// <summary>
    /// Attempts to read an HTTP upgrade request from the input stream,
    /// validate it per RFC 6455, and write the 101 response to the output stream.
    /// </summary>
    public static async Task<WsUpgradeResult> TryUpgradeAsync(
        Stream inputStream, Stream outputStream, WebSocketOptions options)
    {
        try
        {
            // Read HTTP request
            var (method, path, headers) = await ReadHttpRequestAsync(inputStream);

            // RFC 6455 Section 4.2.1 validation

            // Point 1: Method must be GET
            if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
                return await FailAsync(outputStream, 405, "request method must be GET");

            // Point 2: Host header required
            if (!headers.ContainsKey("Host"))
                return await FailAsync(outputStream, 400, "'Host' missing in request");

            // Point 3: Upgrade header must contain "websocket"
            if (!HeaderContains(headers, "Upgrade", "websocket"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");

            // Point 4: Connection header must contain "Upgrade"
            if (!HeaderContains(headers, "Connection", "Upgrade"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");

            // Point 5: Sec-WebSocket-Key required
            if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
                return await FailAsync(outputStream, 400, "key missing");

            // Point 6: Version must be 13
            if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
                return await FailAsync(outputStream, 400, "invalid version");

            // Path routing
            var kind = path switch
            {
                _ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
                _ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
                _ => WsClientKind.Client,
            };

            // Origin checking
            if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
            {
                var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
                headers.TryGetValue("Origin", out var origin);
                if (string.IsNullOrEmpty(origin))
                    headers.TryGetValue("Sec-WebSocket-Origin", out origin);
                var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
                if (originErr != null)
                    return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
            }

            // Compression negotiation
            bool compress = options.Compression;
            if (compress)
            {
                compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
                           ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
            }

            // No-masking negotiation
            bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
                             string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);

            // Browser detection
            bool browser = false;
            bool noCompFrag = false;
            if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
                headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
            {
                browser = true;
                // Do not fragment compressed frames for Safari, identified by both "Version/" and "Safari/"
                noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
            }

            // Cookie extraction
            string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
            if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
                headers.TryGetValue("Cookie", out var cookieHeader))
            {
                var cookies = ParseCookies(cookieHeader);
                if (options.JwtCookie != null)
                    cookies.TryGetValue(options.JwtCookie, out cookieJwt);
                if (options.UsernameCookie != null)
                    cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
                if (options.PasswordCookie != null)
                    cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
                if (options.TokenCookie != null)
                    cookies.TryGetValue(options.TokenCookie, out cookieToken);
            }

            // X-Forwarded-For: keep the first address if it parses as an IP
            string? clientIp = null;
            if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
            {
                var ip = xff.Split(',')[0].Trim();
                if (IPAddress.TryParse(ip, out _))
                    clientIp = ip;
            }

            // Build 101 response
            var response = new StringBuilder();
            response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
            response.Append(ComputeAcceptKey(key));
            response.Append("\r\n");
            if (compress)
                response.Append(WsConstants.PmcFullResponse);
            if (noMasking)
                response.Append(WsConstants.NoMaskingFullResponse);
            if (options.Headers != null)
            {
                foreach (var (k, v) in options.Headers)
                {
                    response.Append(k);
                    response.Append(": ");
                    response.Append(v);
                    response.Append("\r\n");
                }
            }
            response.Append("\r\n");

            var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
            await outputStream.WriteAsync(responseBytes);
            await outputStream.FlushAsync();

            return new WsUpgradeResult(
                Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
                MaskRead: !noMasking, MaskWrite: false,
                CookieJwt: cookieJwt, CookieUsername: cookieUsername,
                CookiePassword: cookiePassword, CookieToken: cookieToken,
                ClientIp: clientIp, Kind: kind);
        }
        catch (Exception)
        {
            return WsUpgradeResult.Failed;
        }
    }

    /// <summary>
    /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
    /// </summary>
    public static string ComputeAcceptKey(string clientKey)
    {
        // SHA-1 of the client key concatenated with the RFC 6455 GUID, base64-encoded
        var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
        var hash = SHA1.HashData(combined);
        return Convert.ToBase64String(hash);
    }

    private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
    {
        var statusText = statusCode switch
        {
            400 => "Bad Request",
            403 => "Forbidden",
            405 => "Method Not Allowed",
            _ => "Internal Server Error",
        };
        var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
        await output.WriteAsync(Encoding.ASCII.GetBytes(response));
        await output.FlushAsync();
        return WsUpgradeResult.Failed;
    }

    private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
    {
        var headerBytes = new List<byte>(4096);
        var buf = new byte[1];

        // Read one byte at a time until \r\n\r\n so that no bytes of the first
        // WebSocket frame are consumed from the stream
        while (true)
        {
            int n = await stream.ReadAsync(buf);
            if (n == 0)
                throw new IOException("connection closed during handshake");
            headerBytes.Add(buf[0]);
            if (headerBytes.Count >= 4 &&
                headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
                headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
                break;
            if (headerBytes.Count > 8192)
                throw new InvalidOperationException("HTTP header too large");
        }

        var text = Encoding.ASCII.GetString(headerBytes.ToArray());
        var lines = text.Split("\r\n", StringSplitOptions.None);
        if (lines.Length < 1)
            throw new InvalidOperationException("invalid HTTP request");

        // Parse request line
        var parts = lines[0].Split(' ');
        if (parts.Length < 3)
            throw new InvalidOperationException("invalid HTTP request line");
        var method = parts[0];
        var path = parts[1];

        // Parse headers
        var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        for (int i = 1; i < lines.Length; i++)
        {
            var line = lines[i];
            if (string.IsNullOrEmpty(line))
                break;
            var colonIdx = line.IndexOf(':');
            if (colonIdx > 0)
            {
                var name = line[..colonIdx].Trim();
                var value = line[(colonIdx + 1)..].Trim();
                headers[name] = value;
            }
        }

        return (method, path, headers);
    }

    private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
    {
        if (!headers.TryGetValue(name, out var headerValue))
            return false;
        foreach (var token in headerValue.Split(','))
        {
            if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }

    private static Dictionary<string, string> ParseCookies(string cookieHeader)
    {
        var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
        foreach (var pair in cookieHeader.Split(';'))
        {
            var trimmed = pair.Trim();
            var eqIdx = trimmed.IndexOf('=');
            if (eqIdx > 0)
                cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
        }
        return cookies;
    }
}

public readonly record struct WsUpgradeResult(
    bool Success,
    bool Compress,
    bool Browser,
    bool NoCompFrag,
    bool MaskRead,
    bool MaskWrite,
    string? CookieJwt,
    string? CookieUsername,
    string? CookiePassword,
    string? CookieToken,
    string? ClientIp,
    WsClientKind Kind)
{
    public static readonly WsUpgradeResult Failed = new(
        Success: false, Compress: false, Browser: false, NoCompFrag: false,
        MaskRead: true, MaskWrite: false,
        CookieJwt: null, CookieUsername: null, CookiePassword: null, CookieToken: null,
        ClientIp: null, Kind: WsClientKind.Client);
}
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsUpgradeTests" -v normal`
Expected: PASS

**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsUpgrade.cs tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs
git commit -m "feat: add WebSocket HTTP upgrade handshake"
```

---

### Task 7: Add WsConnection Stream wrapper

**Files:**
- Create: `src/NATS.Server/WebSocket/WsConnection.cs`
- Create: `tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs`

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs`:

```csharp
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsConnectionTests
{
    [Fact]
    public async Task ReadAsync_DecodesFrameAndReturnsPayload()
    {
        var payload = "SUB test 1\r\n"u8.ToArray();
        var frame = BuildUnmaskedFrame(payload);
        var inner = new MemoryStream(frame);
        var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var buf = new byte[256];
        int n = await ws.ReadAsync(buf);

        n.ShouldBe(payload.Length);
        buf[..n].ShouldBe(payload);
    }

    [Fact]
    public async Task WriteAsync_FramesPayload()
    {
        var inner = new MemoryStream();
        var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
        await ws.WriteAsync(payload);
        await ws.FlushAsync();

        inner.Position = 0;
        var written = inner.ToArray();
        // First 2 bytes should be the WS frame header
        (written[0] & WsConstants.FinalBit).ShouldNotBe(0);
        (written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
        int len = written[1] & 0x7F;
        len.ShouldBe(payload.Length);
        written[2..].ShouldBe(payload);
    }

    [Fact]
    public async Task WriteAsync_WithCompression_CompressesLargePayload()
    {
        var inner = new MemoryStream();
        var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var payload = new byte[200];
        Array.Fill(payload, (byte)0x41); // 'A' repeated - very compressible
        await ws.WriteAsync(payload);
        await ws.FlushAsync();

        inner.Position = 0;
        var written = inner.ToArray();
        // RSV1 bit should be set for compressed frame
        (written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
        // Compressed size should be less than original
        written.Length.ShouldBeLessThan(payload.Length + 10);
    }

    [Fact]
    public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
    {
        var inner = new MemoryStream();
        var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var payload = "Hi"u8.ToArray(); // Below CompressThreshold
        await ws.WriteAsync(payload);
        await ws.FlushAsync();

        inner.Position = 0;
        var written = inner.ToArray();
        // RSV1 bit should NOT be set for small payloads
        (written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
    }

    private static byte[] BuildUnmaskedFrame(byte[] payload)
    {
        var header = new byte[2];
        header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
        header[1] = (byte)payload.Length;
        var frame = new byte[2 + payload.Length];
        header.CopyTo(frame, 0);
        payload.CopyTo(frame, 2);
        return frame;
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConnectionTests" -v normal`
Expected: FAIL

**Step 3: Write minimal implementation**

Create `src/NATS.Server/WebSocket/WsConnection.cs`:

```csharp
namespace NATS.Server.WebSocket;

/// <summary>
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
/// NatsClient uses this as its _stream, so FillPipeAsync and RunWriteLoopAsync work unchanged.
/// </summary>
public sealed class WsConnection : Stream
{
    private readonly Stream _inner;
    private readonly bool _compress;
    private readonly bool _maskRead;
    private readonly bool _maskWrite;
    private readonly bool _browser;
    private readonly bool _noCompFrag;
    private WsReadInfo _readInfo;
    private readonly Queue<byte[]> _readQueue = new();
    private int _readOffset;
    private readonly object _writeLock = new();
    private readonly List<ControlFrameAction> _pendingControlWrites = [];

    // Serializes writes to _inner: data frames (write loop) and control frames
    // (read loop) must not interleave on the wire
    private readonly SemaphoreSlim _writeSem = new(1, 1);

    public bool CloseReceived => _readInfo.CloseReceived;
    public int CloseStatus => _readInfo.CloseStatus;

    public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
    {
        _inner = inner;
        _compress = compress;
        _maskRead = maskRead;
        _maskWrite = maskWrite;
        _browser = browser;
        _noCompFrag = noCompFrag;
        _readInfo = new WsReadInfo(expectMask: maskRead);
    }

    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
    {
        while (true)
        {
            // Drain any buffered decoded payloads first
            if (_readQueue.Count > 0)
                return DrainReadQueue(buffer.Span);

            // Read raw bytes from inner stream
            var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
            int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(0, rawBuf.Length), ct);
            if (bytesRead == 0)
                return 0;

            // Decode frames
            var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);

            // Collect control frame responses (pong, close echo) and write them out
            if (_readInfo.PendingControlFrames.Count > 0)
            {
                lock (_writeLock)
                    _pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
                _readInfo.PendingControlFrames.Clear();
                await FlushControlFramesAsync(ct);
            }

            if (_readInfo.CloseReceived)
                return 0;

            foreach (var payload in payloads)
                _readQueue.Enqueue(payload);

            // If this read produced only control frames or a partial frame, keep reading:
            // returning 0 here would signal end-of-stream to the caller
        }
    }

    public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
    {
        bool compressed = _compress && buffer.Length > WsConstants.CompressThreshold;
        byte[] payload = compressed ? WsCompression.Compress(buffer.Span) : buffer.ToArray();

        await _writeSem.WaitAsync(ct);
        try
        {
            await WriteFramedAsync(payload, compressed, ct);
        }
        finally
        {
            _writeSem.Release();
        }
    }

    // Caller must hold _writeSem
    private async Task WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
    {
        if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
        {
            // Fragment for browsers
            int offset = 0;
            bool first = true;
            while (offset < payload.Length)
            {
                int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
                bool final = offset + chunkLen >= payload.Length;
                var fh = new byte[WsConstants.MaxFrameHeaderSize];
                var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
                    first: first, final: final, compressed: first && compressed,
                    opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
                var chunk = payload.AsSpan(offset, chunkLen).ToArray();
                if (_maskWrite && key != null)
                    WsFrameWriter.MaskBuf(key, chunk);
                await _inner.WriteAsync(fh.AsMemory(0, n), ct);
                await _inner.WriteAsync(chunk, ct);
                offset += chunkLen;
                first = false;
            }
        }
        else
        {
            var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
            if (_maskWrite && key != null)
                WsFrameWriter.MaskBuf(key, payload);
            await _inner.WriteAsync(header, ct);
            await _inner.WriteAsync(payload, ct);
        }
    }

    private async Task FlushControlFramesAsync(CancellationToken ct)
    {
        List<ControlFrameAction> toWrite;
        lock (_writeLock)
        {
            if (_pendingControlWrites.Count == 0)
                return;
            toWrite = [.. _pendingControlWrites];
            _pendingControlWrites.Clear();
        }

        await _writeSem.WaitAsync(ct);
        try
        {
            foreach (var action in toWrite)
            {
                var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
                await _inner.WriteAsync(frame, ct);
            }
            await _inner.FlushAsync(ct);
        }
        finally
        {
            _writeSem.Release();
        }
    }

    /// <summary>
    /// Sends a WebSocket close frame.
    /// </summary>
    public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
    {
        var status = WsFrameWriter.MapCloseStatus(reason);
        var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
        var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);

        await _writeSem.WaitAsync(ct);
        try
        {
            await _inner.WriteAsync(frame, ct);
            await _inner.FlushAsync(ct);
        }
        finally
        {
            _writeSem.Release();
        }
    }

    private int DrainReadQueue(Span<byte> buffer)
    {
        int written = 0;
        while (_readQueue.Count > 0 && written < buffer.Length)
        {
            var current = _readQueue.Peek();
            int available = current.Length - _readOffset;
            int toCopy = Math.Min(available, buffer.Length - written);
            current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
            written += toCopy;
            _readOffset += toCopy;
            if (_readOffset >= current.Length)
            {
                _readQueue.Dequeue();
                _readOffset = 0;
            }
        }
        return written;
    }

    // Stream abstract members
    public override bool CanRead => true;
    public override bool CanWrite => true;
    public override bool CanSeek => false;
    public override long Length => throw new NotSupportedException();
    public override long Position
    {
        get => throw new NotSupportedException();
        set => throw new NotSupportedException();
    }

    public override void Flush() => _inner.Flush();
    public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);

    // Route the array-based async overloads to the Memory-based ones; the base Stream
    // versions fall back to the synchronous Read/Write, which throw here
    public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) =>
        ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) =>
        WriteAsync(buffer.AsMemory(offset, count), ct).AsTask();

    public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
    public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
    public override void SetLength(long value) => throw new NotSupportedException();

    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            _inner.Dispose();
            _writeSem.Dispose();
        }
        base.Dispose(disposing);
    }
}
```

**Step 4: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConnectionTests" -v normal`
Expected: PASS
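
The browser-fragmentation branch of `WsConnection`'s write path passes `opcode: WsConstants.BinaryMessage` for every chunk and appears to rely on `WsFrameWriter.FillFrameHeader` (from an earlier task) to turn non-first chunks into continuation frames via the `first` flag. As a cross-check, here is a self-contained sketch of the frame sequence that should reach the wire, assuming RFC 6455 fragmentation rules (only the first frame carries the message opcode, later frames use opcode `0x0`, FIN is set only on the last frame) and the RFC 7692 rule that RSV1 is set only on the first frame of a compressed message. `FragmentPlan`, `Split`, and `Fragment` are hypothetical names for illustration, not part of the plan's API.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of the plan): computes the fragment sequence
// that browser fragmentation should emit for one binary message.
public static class FragmentPlan
{
    public const int ContinuationOpcode = 0x0;
    public const int BinaryOpcode = 0x2;

    public readonly record struct Fragment(int Opcode, bool Fin, bool Rsv1, int Offset, int Length);

    public static List<Fragment> Split(int payloadLength, int maxFrameSize, bool compressed)
    {
        if (payloadLength < 0) throw new ArgumentOutOfRangeException(nameof(payloadLength));
        if (maxFrameSize <= 0) throw new ArgumentOutOfRangeException(nameof(maxFrameSize));

        var frames = new List<Fragment>();
        int offset = 0;
        do
        {
            int len = Math.Min(maxFrameSize, payloadLength - offset);
            bool first = offset == 0;
            bool fin = offset + len >= payloadLength;
            // Only the first frame carries the message opcode and the RSV1 (compressed) bit;
            // later frames are continuations, and FIN marks the last one.
            frames.Add(new Fragment(first ? BinaryOpcode : ContinuationOpcode, fin, first && compressed, offset, len));
            offset += len;
        } while (offset < payloadLength);

        return frames;
    }
}
```

For example, `Split(10_000, 4_000, compressed: true)` yields three frames: a binary frame with RSV1 set (4,000 bytes), a continuation frame (4,000 bytes), and a final continuation frame with FIN (2,000 bytes). A `WsConnectionTests` case could parse the headers `WsConnection` writes for a browser client and compare them against this sequence.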
**Step 5: Commit**

```bash
git add src/NATS.Server/WebSocket/WsConnection.cs tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
git commit -m "feat: add WsConnection Stream wrapper for transparent framing"
```

---

### Task 8: Integrate WebSocket into NatsServer and NatsClient

**Files:**
- Modify: `src/NATS.Server/NatsServer.cs`
- Modify: `src/NATS.Server/NatsClient.cs`

**Step 1: Write the failing test**

Create `tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs`:

```csharp
using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsIntegrationTests : IAsyncLifetime
{
    private NatsServer _server = null!;
    private NatsOptions _options = null!;

    // xUnit v3: IAsyncLifetime uses ValueTask for both InitializeAsync and DisposeAsync.
    public async ValueTask InitializeAsync()
    {
        _options = new NatsOptions
        {
            Port = 0,
            // A WebSocket port of 0 means "disabled" (Task 0), so reserve a concrete free port.
            WebSocket = new WebSocketOptions { Port = GetFreePort(), NoTls = true },
        };
        var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
        _server = new NatsServer(_options, loggerFactory);
        _ = _server.StartAsync(CancellationToken.None);
        await _server.WaitForReadyAsync();
    }

    public async ValueTask DisposeAsync()
    {
        await _server.ShutdownAsync();
        _server.Dispose();
    }

    [Fact]
    public async Task WebSocket_ConnectAndReceiveInfo()
    {
        using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
        using var stream = new NetworkStream(socket, ownsSocket: false);

        // Send WebSocket upgrade request
        await SendUpgradeRequest(stream);

        // Read 101 response
        var response = await ReadHttpResponse(stream);
        response.ShouldContain("101");

        // Now read the INFO line through WebSocket frames
        var wsFrame = await ReadWsFrame(stream);
        var info = Encoding.ASCII.GetString(wsFrame);
        info.ShouldStartWith("INFO ");
    }

    [Fact]
    public async Task WebSocket_PubSub()
    {
        // Connect two WS clients
        using var sub = await ConnectWsClient();
        using var pub = await ConnectWsClient();

        // Subscribe on first client
        await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n");
        await Task.Delay(100);

        // Publish on second client
        await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
        await Task.Delay(100);

        // Read from subscriber
        var msg = await ReadWsFrame(sub);
        Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
    }

    // Binds to port 0 on loopback so the OS picks a free port, then releases it.
    private static int GetFreePort()
    {
        using var probe = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        probe.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        return ((IPEndPoint)probe.LocalEndPoint!).Port;
    }

    private async Task<NetworkStream> ConnectWsClient()
    {
        var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
        var stream = new NetworkStream(socket, ownsSocket: true);
        await SendUpgradeRequest(stream);
        var response = await ReadHttpResponse(stream);
        response.ShouldContain("101");
        // Read INFO frame
        await ReadWsFrame(stream);
        return stream;
    }

    private static async Task SendUpgradeRequest(NetworkStream stream)
    {
        var keyBytes = new byte[16];
        RandomNumberGenerator.Fill(keyBytes);
        var key = Convert.ToBase64String(keyBytes);
        var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
        await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
        await stream.FlushAsync();
    }

    private static async Task<string> ReadHttpResponse(NetworkStream stream)
    {
        // Read one byte at a time so nothing past the blank line that ends the
        // HTTP headers (the start of the INFO frame) is consumed here.
        var sb = new StringBuilder();
        var one = new byte[1];
        while (!sb.ToString().EndsWith("\r\n\r\n", StringComparison.Ordinal))
        {
            int n = await stream.ReadAsync(one);
            if (n == 0) break;
            sb.Append((char)one[0]);
        }
        return sb.ToString();
    }

    private static async Task<byte[]> ReadWsFrame(NetworkStream stream)
    {
        // Server-to-client frames are unmasked; compression is not negotiated by these tests.
        var header = new byte[2];
        await stream.ReadExactlyAsync(header);
        int len = header[1] & 0x7F;
        byte[]? extLen = null;
        if (len == 126)
        {
            extLen = new byte[2];
            await stream.ReadExactlyAsync(extLen);
            len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
        }
        else if (len == 127)
        {
            extLen = new byte[8];
            await stream.ReadExactlyAsync(extLen);
            len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
        }

        var payload = new byte[len];
        if (len > 0) await stream.ReadExactlyAsync(payload);
        return payload;
    }

    private static async Task SendWsText(NetworkStream stream, string text)
    {
        var payload = Encoding.ASCII.GetBytes(text);
        var (header, key) = WsFrameWriter.CreateFrameHeader(
            useMasking: true, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
        // Client-to-server frames must be masked; apply the key that was written into the header.
        WsFrameWriter.MaskBuf(key!, payload);
        await stream.WriteAsync(header);
        await stream.WriteAsync(payload);
        await stream.FlushAsync();
    }
}
```

**Step 2: Run test to verify it fails**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsIntegrationTests" -v normal`
Expected: FAIL — NatsServer has no WebSocket listener

**Step 3: Modify NatsServer.cs**

Add these fields to `NatsServer`:

```csharp
private Socket? _wsListener;
private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
```

In `StartAsync`, after the monitoring server startup and before the main accept loop, add:

```csharp
if (_options.WebSocket.Port > 0)
{
    _wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
    _wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
    _wsListener.Bind(new IPEndPoint(
        _options.WebSocket.Host == "0.0.0.0" ?
            IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host),
        _options.WebSocket.Port));
    _wsListener.Listen(128);

    if (_options.WebSocket.Port == 0)
    {
        _options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port;
    }

    _logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}",
        _options.WebSocket.Host, _options.WebSocket.Port);

    if (_options.WebSocket.NoTls)
        _logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!");

    _ = RunWebSocketAcceptLoopAsync(linked.Token);
}
```

Add the WebSocket accept loop method:

```csharp
private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct)
{
    var tmpDelay = AcceptMinSleep;
    try
    {
        while (!ct.IsCancellationRequested)
        {
            Socket socket;
            try
            {
                socket = await _wsListener!.AcceptAsync(ct);
                tmpDelay = AcceptMinSleep;
            }
            catch (OperationCanceledException) { break; }
            catch (ObjectDisposedException) { break; }
            catch (SocketException ex)
            {
                if (IsShuttingDown || IsLameDuckMode) break;
                _logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms",
                    tmpDelay.TotalMilliseconds);
                try { await Task.Delay(tmpDelay, ct); }
                catch (OperationCanceledException) { break; }
                tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
                continue;
            }

            if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
            {
                socket.Dispose();
                continue;
            }

            var clientId = Interlocked.Increment(ref _nextClientId);
            Interlocked.Increment(ref _stats.TotalConnections);
            Interlocked.Increment(ref _activeClientCount);
            _ = AcceptWebSocketClientAsync(socket, clientId, ct);
        }
    }
    finally
    {
        _wsAcceptLoopExited.TrySetResult();
    }
}

private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct)
{
    try
    {
        var networkStream = new NetworkStream(socket, ownsSocket: false);
        Stream stream = networkStream;

        // TLS negotiation if configured
        if (_sslOptions != null && !_options.WebSocket.NoTls)
        {
            var (tlsStream, _) = await Tls.TlsConnectionWrapper.NegotiateAsync(
                socket,
                networkStream, _options, _sslOptions, _serverInfo,
                _loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
            stream = tlsStream;
        }

        // HTTP upgrade handshake
        var upgradeResult = await WebSocket.WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket);
        if (!upgradeResult.Success)
        {
            _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
            socket.Dispose();
            Interlocked.Decrement(ref _activeClientCount);
            return;
        }

        // Create WsConnection wrapper
        var wsConn = new WebSocket.WsConnection(stream,
            compress: upgradeResult.Compress,
            maskRead: upgradeResult.MaskRead,
            maskWrite: upgradeResult.MaskWrite,
            browser: upgradeResult.Browser,
            noCompFrag: upgradeResult.NoCompFrag);

        var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
        var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo,
            _authService, null, clientLogger, _stats);
        client.Router = this;
        client.IsWebSocket = true;
        client.WsInfo = upgradeResult;
        _clients[clientId] = client;

        await RunClientAsync(client, ct);
    }
    catch (Exception ex)
    {
        _logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId);
        try { socket.Shutdown(SocketShutdown.Both); } catch { }
        socket.Dispose();
        Interlocked.Decrement(ref _activeClientCount);
    }
}
```

In `ShutdownAsync`, add before `_listener?.Close()`:

```csharp
_wsListener?.Close();
```

And after `_acceptLoopExited.Task.WaitAsync(...)`, add the following. The `_wsListener != null` guard matters: when WebSocket is disabled the WebSocket accept loop never starts, so `_wsAcceptLoopExited` never completes and every shutdown would otherwise stall for the full 5 seconds.

```csharp
if (_wsListener != null)
    await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
```

In `Dispose`, add:

```csharp
_wsListener?.Dispose();
```

**Step 4: Modify NatsClient.cs**

Add two properties, importing the WebSocket namespace if `NatsClient.cs` does not already:

```csharp
using NATS.Server.WebSocket;

public bool IsWebSocket { get; set; }
public WsUpgradeResult? WsInfo { get; set; }
```

**Step 5: Run test to verify it passes**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsIntegrationTests" -v normal`
Expected: PASS

**Step 6: Commit**

```bash
git add src/NATS.Server/NatsServer.cs src/NATS.Server/NatsClient.cs tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs
git commit -m "feat: integrate WebSocket accept loop into NatsServer"
```

---

### Task 9: Update differences.md

**Files:**
- Modify: `differences.md`

**Step 1: Update WebSocket row in Connection Types table**

Change line 70 from:

```
| WebSocket clients | Y | N | |
```

To:

```
| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth |
```

**Step 2: Update Missing Options Categories**

Change the line:

```
- WebSocket/MQTT options
```

To:

```
- ~~WebSocket options~~ — WebSocketOptions with port, compression, origin checking, cookie auth, custom headers
- MQTT options
```

**Step 3: Commit**

```bash
git add differences.md
git commit -m "docs: update differences.md to reflect WebSocket implementation"
```

---

### Task 10: Run full test suite and verify

**Step 1: Build**

Run: `dotnet build`
Expected: Build succeeded

**Step 2: Run all tests**

Run: `dotnet test -v normal`
Expected: All tests pass (both existing and new WebSocket tests)

**Step 3: Run only WebSocket tests**

Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocket" -v normal`
Expected: All WebSocket tests pass

**Step 4: Final commit (if any fixes needed)**

```bash
git add -A
git commit -m "fix: address test failures from full suite run"
```
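If the verification pass calls for a stricter handshake check: the Task 8 integration tests only assert that the upgrade response contains `101`. RFC 6455 also fixes the exact `Sec-WebSocket-Accept` value a server must return, namely the Base64 encoding of the SHA-1 hash of the client's `Sec-WebSocket-Key` concatenated with the GUID `258EAFA5-E914-47DA-95CA-C5AB0DC85B11`. The sketch below computes and checks that value. `WsHandshakeCheck`, `ComputeAccept`, and `HasValidAccept` are hypothetical test helpers, and wiring them in would assume `SendUpgradeRequest` is changed to return the key it generated.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical test helper: validates the server's Sec-WebSocket-Accept header per RFC 6455.
public static class WsHandshakeCheck
{
    // Fixed GUID defined by RFC 6455; the server appends it to the client's key.
    private const string WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Base64(SHA-1(key + GUID)).
    public static string ComputeAccept(string secWebSocketKey) =>
        Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(secWebSocketKey + WsGuid)));

    // True if the raw HTTP response carries the expected accept value.
    public static bool HasValidAccept(string httpResponse, string secWebSocketKey)
    {
        string expected = ComputeAccept(secWebSocketKey);
        foreach (var line in httpResponse.Split("\r\n"))
        {
            int colon = line.IndexOf(':');
            if (colon <= 0) continue; // status line or blank line

            // HTTP header names are case-insensitive.
            string name = line[..colon].Trim();
            if (!name.Equals("Sec-WebSocket-Accept", StringComparison.OrdinalIgnoreCase)) continue;

            return line[(colon + 1)..].Trim() == expected;
        }
        return false;
    }
}
```

The RFC's own sample key `dGhlIHNhbXBsZSBub25jZQ==` maps to `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`, which makes a convenient fixed test vector.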