From dac641c52ced4eac0a480fd24d461e177c768ed3 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 23 Feb 2026 04:26:40 -0500 Subject: [PATCH] docs: add WebSocket implementation plan with 11 tasks TDD-based plan covering constants, origin checker, frame writer, frame reader, compression, HTTP upgrade, connection wrapper, server/client integration, differences.md update, and verification. --- docs/plans/2026-02-23-websocket-plan.md | 2792 +++++++++++++++++ .../2026-02-23-websocket-plan.md.tasks.json | 17 + 2 files changed, 2809 insertions(+) create mode 100644 docs/plans/2026-02-23-websocket-plan.md create mode 100644 docs/plans/2026-02-23-websocket-plan.md.tasks.json diff --git a/docs/plans/2026-02-23-websocket-plan.md b/docs/plans/2026-02-23-websocket-plan.md new file mode 100644 index 0000000..fcb0c70 --- /dev/null +++ b/docs/plans/2026-02-23-websocket-plan.md @@ -0,0 +1,2792 @@ +# WebSocket Support Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers-extended-cc:executing-plans to implement this plan task-by-task. + +**Goal:** Port full WebSocket connection support from the Go NATS server to the .NET solution, enabling NATS clients to connect over WebSocket with compression, masking, origin checking, and cookie-based auth. + +**Architecture:** Self-contained `WebSocket/` module under `src/NATS.Server/` with custom frame parser (no System.Net.WebSockets). A `WsConnection` Stream wrapper integrates transparently with existing `NatsClient` read/write loops. Second TCP accept loop in `NatsServer` handles WebSocket port. + +**Tech Stack:** .NET 10, System.IO.Compression (DeflateStream), System.Security.Cryptography (SHA1), xUnit 3, Shouldly + +--- + +### Task 0: Add WebSocketOptions configuration + +**Files:** +- Modify: `src/NATS.Server/NatsOptions.cs` + +**Step 1: Write the failing test** + +Create test file `tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs`: + +```csharp +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WebSocketOptionsTests +{ + [Fact] + public void DefaultOptions_PortIsZero_Disabled() + { + var opts = new WebSocketOptions(); + opts.Port.ShouldBe(0); + opts.Host.ShouldBe("0.0.0.0"); + opts.Compression.ShouldBeFalse(); + opts.NoTls.ShouldBeFalse(); + opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2)); + opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2)); + } + + [Fact] + public void NatsOptions_HasWebSocketProperty() + { + var opts = new NatsOptions(); + opts.WebSocket.ShouldNotBeNull(); + opts.WebSocket.Port.ShouldBe(0); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocketOptionsTests" -v normal` +Expected: FAIL — `WebSocketOptions` type does not exist + +**Step 3: Write minimal implementation** + +Add to `src/NATS.Server/NatsOptions.cs` — a new `WebSocketOptions` class and property on `NatsOptions`: + +```csharp +public sealed class WebSocketOptions +{ + public string Host { get; set; } = "0.0.0.0"; + public int Port { get; set; } + public string? Advertise { get; set; } + public string? NoAuthUser { get; set; } + public string? JwtCookie { get; set; } + public string? UsernameCookie { get; set; } + public string? PasswordCookie { get; set; } + public string? TokenCookie { get; set; } + public string? Username { get; set; } + public string? Password { get; set; } + public string? Token { get; set; } + public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2); + public bool NoTls { get; set; } + public string? TlsCert { get; set; } + public string? TlsKey { get; set; } + public bool SameOrigin { get; set; } + public List? AllowedOrigins { get; set; } + public bool Compression { get; set; } + public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2); + public TimeSpan? PingInterval { get; set; } + public Dictionary? Headers { get; set; } +} +``` + +Add to `NatsOptions`: +```csharp +public WebSocketOptions WebSocket { get; set; } = new(); +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocketOptionsTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/NatsOptions.cs tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs +git commit -m "feat: add WebSocketOptions configuration class" +``` + +--- + +### Task 1: Add WsConstants + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsConstants.cs` + +**Reference:** `golang/nats-server/server/websocket.go` lines 41-106 + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs`: + +```csharp +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsConstantsTests +{ + [Fact] + public void OpCodes_MatchRfc6455() + { + WsConstants.TextMessage.ShouldBe(1); + WsConstants.BinaryMessage.ShouldBe(2); + WsConstants.CloseMessage.ShouldBe(8); + WsConstants.PingMessage.ShouldBe(9); + WsConstants.PongMessage.ShouldBe(10); + } + + [Fact] + public void FrameBits_MatchRfc6455() + { + WsConstants.FinalBit.ShouldBe(0x80); + WsConstants.Rsv1Bit.ShouldBe(0x40); + WsConstants.MaskBit.ShouldBe(0x80); + } + + [Fact] + public void CloseStatusCodes_MatchRfc6455() + { + WsConstants.CloseStatusNormalClosure.ShouldBe(1000); + WsConstants.CloseStatusGoingAway.ShouldBe(1001); + WsConstants.CloseStatusProtocolError.ShouldBe(1002); + WsConstants.CloseStatusPolicyViolation.ShouldBe(1008); + WsConstants.CloseStatusMessageTooBig.ShouldBe(1009); + } + + [Theory] + [InlineData(WsConstants.CloseMessage)] + [InlineData(WsConstants.PingMessage)] + [InlineData(WsConstants.PongMessage)] + public void IsControlFrame_True(int opcode) + { + WsConstants.IsControlFrame(opcode).ShouldBeTrue(); + } + + [Theory] + [InlineData(WsConstants.TextMessage)] + [InlineData(WsConstants.BinaryMessage)] + [InlineData(0)] + public void IsControlFrame_False(int opcode) + { + WsConstants.IsControlFrame(opcode).ShouldBeFalse(); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConstantsTests" -v normal` +Expected: FAIL — `WsConstants` does not exist + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsConstants.cs`: + +```csharp +namespace NATS.Server.WebSocket; + +/// +/// WebSocket protocol constants (RFC 6455). +/// Ported from golang/nats-server/server/websocket.go lines 41-106. +/// +public static class WsConstants +{ + // Opcodes (RFC 6455 Section 5.2) + public const int TextMessage = 1; + public const int BinaryMessage = 2; + public const int CloseMessage = 8; + public const int PingMessage = 9; + public const int PongMessage = 10; + public const int ContinuationFrame = 0; + + // Frame header bits + public const byte FinalBit = 0x80; // 1 << 7 + public const byte Rsv1Bit = 0x40; // 1 << 6 (compression, RFC 7692) + public const byte Rsv2Bit = 0x20; // 1 << 5 + public const byte Rsv3Bit = 0x10; // 1 << 4 + public const byte MaskBit = 0x80; // 1 << 7 (in second byte) + + // Frame size limits + public const int MaxFrameHeaderSize = 14; + public const int MaxControlPayloadSize = 125; + public const int FrameSizeForBrowsers = 4096; + public const int CompressThreshold = 64; + public const int CloseStatusSize = 2; + + // Close status codes (RFC 6455 Section 11.7) + public const int CloseStatusNormalClosure = 1000; + public const int CloseStatusGoingAway = 1001; + public const int CloseStatusProtocolError = 1002; + public const int CloseStatusUnsupportedData = 1003; + public const int CloseStatusNoStatusReceived = 1005; + public const int CloseStatusInvalidPayloadData = 1007; + public const int CloseStatusPolicyViolation = 1008; + public const int CloseStatusMessageTooBig = 1009; + public const int CloseStatusInternalSrvError = 1011; + public const int CloseStatusTlsHandshake = 1015; + + // Compression constants (RFC 7692) + public const string PmcExtension = "permessage-deflate"; + public const string PmcSrvNoCtx = "server_no_context_takeover"; + public const string PmcCliNoCtx = "client_no_context_takeover"; + public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}"; + public static readonly string PmcFullResponse = $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n"; + + // Header names + public const string NoMaskingHeader = "Nats-No-Masking"; + public const string NoMaskingValue = "true"; + public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n"; + public const string XForwardedForHeader = "X-Forwarded-For"; + + // Path routing + public const string ClientPath = "/"; + public const string LeafNodePath = "/leafnode"; + public const string MqttPath = "/mqtt"; + + // WebSocket GUID (RFC 6455 Section 1.3) + public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray(); + + // Compression trailer (RFC 7692 Section 7.2.2) + public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff]; + + // Decompression trailer appended before decompressing + public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff]; + + public static bool IsControlFrame(int opcode) => opcode >= CloseMessage; +} + +public enum WsClientKind +{ + Client, + Leaf, + Mqtt, +} +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConstantsTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsConstants.cs tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs +git commit -m "feat: add WebSocket constants (RFC 6455/7692)" +``` + +--- + +### Task 2: Add WsOriginChecker + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsOriginChecker.cs` +- Create: `tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs` + +**Reference:** `golang/nats-server/server/websocket.go` lines 933-1000 (`checkOrigin`, `wsGetHostAndPort`) + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs`: + +```csharp +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsOriginCheckerTests +{ + [Fact] + public void NoOriginHeader_Accepted() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false) + .ShouldBeNull(); + } + + [Fact] + public void NeitherSameNorList_AlwaysAccepted() + { + var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null); + checker.CheckOrigin("https://evil.com", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_Match() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://localhost:4222", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_Mismatch() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://other:4222", "localhost:4222", false) + .ShouldNotBeNull(); + } + + [Fact] + public void SameOrigin_DefaultPort_Http() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + // No port in origin means port 80 for http + checker.CheckOrigin("http://localhost", "localhost:80", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_DefaultPort_Https() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("https://localhost", "localhost:443", true) + .ShouldBeNull(); + } + + [Fact] + public void AllowedOrigins_Match() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("https://app.example.com", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void AllowedOrigins_Mismatch() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("https://evil.example.com", "localhost:4222", false) + .ShouldNotBeNull(); + } + + [Fact] + public void AllowedOrigins_SchemeMismatch() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("http://app.example.com", "localhost:4222", false) + .ShouldNotBeNull(); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsOriginCheckerTests" -v normal` +Expected: FAIL — `WsOriginChecker` does not exist + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsOriginChecker.cs`: + +```csharp +namespace NATS.Server.WebSocket; + +/// +/// Validates WebSocket Origin headers per RFC 6455 Section 10.2. +/// Ported from golang/nats-server/server/websocket.go lines 933-1000. +/// +public sealed class WsOriginChecker +{ + private readonly bool _sameOrigin; + private readonly Dictionary? _allowedOrigins; + + public WsOriginChecker(bool sameOrigin, List? allowedOrigins) + { + _sameOrigin = sameOrigin; + if (allowedOrigins is { Count: > 0 }) + { + _allowedOrigins = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (var ao in allowedOrigins) + { + if (Uri.TryCreate(ao, UriKind.Absolute, out var uri)) + { + var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port); + _allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port); + } + } + } + } + + /// + /// Returns null if origin is allowed, or an error message if rejected. + /// + public string? CheckOrigin(string? origin, string requestHost, bool isTls) + { + if (!_sameOrigin && _allowedOrigins == null) + return null; + + if (string.IsNullOrEmpty(origin)) + return null; + + if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri)) + return $"invalid origin: {origin}"; + + var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port); + + if (_sameOrigin) + { + var (rh, rp) = ParseHostPort(requestHost, isTls); + if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp) + return "not same origin"; + } + + if (_allowedOrigins != null) + { + if (!_allowedOrigins.TryGetValue(oh, out var allowed) || + !string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) || + op != allowed.Port) + { + return "not in the allowed list"; + } + } + + return null; + } + + private static (string host, int port) GetHostAndPort(bool tls, string host, int port) + { + if (port <= 0) + port = tls ? 443 : 80; + return (host.ToLowerInvariant(), port); + } + + private static (string host, int port) ParseHostPort(string hostPort, bool isTls) + { + var colonIdx = hostPort.LastIndexOf(':'); + if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port)) + return (hostPort[..colonIdx].ToLowerInvariant(), port); + return (hostPort.ToLowerInvariant(), isTls ? 443 : 80); + } + + private readonly record struct AllowedOrigin(string Scheme, int Port); +} +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsOriginCheckerTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsOriginChecker.cs tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs +git commit -m "feat: add WebSocket origin checker" +``` + +--- + +### Task 3: Add WsFrameWriter (frame header construction, masking, control frames) + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsFrameWriter.cs` +- Create: `tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs` + +**Reference:** `golang/nats-server/server/websocket.go` lines 543-726 (`wsFillFrameHeader`, `wsCreateFrameHeader`, `wsMaskBuf`, `wsCreateCloseMessage`, `wsEnqueueControlMessageLocked`) + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs`: + +```csharp +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsFrameWriterTests +{ + [Fact] + public void CreateFrameHeader_SmallPayload_7BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 100); + header.Length.ShouldBe(2); + (header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set + (header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage); + (header[1] & 0x7F).ShouldBe(100); + } + + [Fact] + public void CreateFrameHeader_MediumPayload_16BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 1000); + header.Length.ShouldBe(4); + (header[1] & 0x7F).ShouldBe(126); + BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000); + } + + [Fact] + public void CreateFrameHeader_LargePayload_64BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 70000); + header.Length.ShouldBe(10); + (header[1] & 0x7F).ShouldBe(127); + BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL); + } + + [Fact] + public void CreateFrameHeader_WithMasking_Adds4ByteKey() + { + var (header, key) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 10); + header.Length.ShouldBe(6); // 2 header + 4 mask key + (header[1] & WsConstants.MaskBit).ShouldNotBe(0); + key.ShouldNotBeNull(); + key.Length.ShouldBe(4); + } + + [Fact] + public void CreateFrameHeader_Compressed_SetsRsv1Bit() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: true, + opcode: WsConstants.BinaryMessage, payloadLength: 10); + (header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0); + } + + [Fact] + public void MaskBuf_XorsCorrectly() + { + byte[] key = [0xAA, 0xBB, 0xCC, 0xDD]; + byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]; + byte[] expected = new byte[data.Length]; + for (int i = 0; i < data.Length; i++) + expected[i] = (byte)(data[i] ^ key[i & 3]); + + WsFrameWriter.MaskBuf(key, data); + data.ShouldBe(expected); + } + + [Fact] + public void MaskBuf_RoundTrip() + { + byte[] key = [0x12, 0x34, 0x56, 0x78]; + byte[] original = "Hello, WebSocket!"u8.ToArray(); + var data = original.ToArray(); + + WsFrameWriter.MaskBuf(key, data); + data.ShouldNotBe(original); + WsFrameWriter.MaskBuf(key, data); + data.ShouldBe(original); + } + + [Fact] + public void CreateCloseMessage_WithStatusAndBody() + { + var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure"); + msg.Length.ShouldBe(2 + "normal closure".Length); + BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000); + } + + [Fact] + public void CreateCloseMessage_LongBody_Truncated() + { + var longBody = new string('x', 200); + var msg = WsFrameWriter.CreateCloseMessage(1000, longBody); + msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize); + } + + [Fact] + public void MapCloseStatus_ClientClosed_NormalClosure() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed) + .ShouldBe(WsConstants.CloseStatusNormalClosure); + } + + [Fact] + public void MapCloseStatus_AuthTimeout_PolicyViolation() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout) + .ShouldBe(WsConstants.CloseStatusPolicyViolation); + } + + [Fact] + public void MapCloseStatus_ParseError_ProtocolError() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError) + .ShouldBe(WsConstants.CloseStatusProtocolError); + } + + [Fact] + public void MapCloseStatus_MaxPayload_MessageTooBig() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded) + .ShouldBe(WsConstants.CloseStatusMessageTooBig); + } + + [Fact] + public void BuildControlFrame_PingNomask() + { + var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false); + frame.Length.ShouldBe(2); + (frame[0] & WsConstants.FinalBit).ShouldNotBe(0); + (frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage); + (frame[1] & 0x7F).ShouldBe(0); + } + + [Fact] + public void BuildControlFrame_PongWithPayload() + { + byte[] payload = [1, 2, 3, 4]; + var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false); + frame.Length.ShouldBe(2 + 4); + frame[2..].ShouldBe(payload); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameWriterTests" -v normal` +Expected: FAIL + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsFrameWriter.cs`: + +```csharp +using System.Buffers.Binary; +using System.Security.Cryptography; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// WebSocket frame construction, masking, and control message creation. +/// Ported from golang/nats-server/server/websocket.go lines 543-726. +/// +public static class WsFrameWriter +{ + /// + /// Creates a complete frame header for a single-frame message (first=true, final=true). + /// Returns (header bytes, mask key or null). + /// + public static (byte[] header, byte[]? key) CreateFrameHeader( + bool useMasking, bool compressed, int opcode, int payloadLength) + { + var fh = new byte[WsConstants.MaxFrameHeaderSize]; + var (n, key) = FillFrameHeader(fh, useMasking, + first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength); + return (fh[..n], key); + } + + /// + /// Fills a pre-allocated frame header buffer. + /// Returns (bytes written, mask key or null). + /// + public static (int written, byte[]? key) FillFrameHeader( + Span fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength) + { + byte b0 = first ? (byte)opcode : (byte)0; + if (final) b0 |= WsConstants.FinalBit; + if (compressed) b0 |= WsConstants.Rsv1Bit; + + byte b1 = 0; + if (useMasking) b1 |= WsConstants.MaskBit; + + int n; + switch (payloadLength) + { + case <= 125: + n = 2; + fh[0] = b0; + fh[1] = (byte)(b1 | (byte)payloadLength); + break; + case < 65536: + n = 4; + fh[0] = b0; + fh[1] = (byte)(b1 | 126); + BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength); + break; + default: + n = 10; + fh[0] = b0; + fh[1] = (byte)(b1 | 127); + BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength); + break; + } + + byte[]? key = null; + if (useMasking) + { + key = new byte[4]; + RandomNumberGenerator.Fill(key); + key.CopyTo(fh[n..]); + n += 4; + } + + return (n, key); + } + + /// + /// XOR masks a buffer with a 4-byte key. Applies in-place. + /// + public static void MaskBuf(ReadOnlySpan key, Span buf) + { + for (int i = 0; i < buf.Length; i++) + buf[i] ^= key[i & 3]; + } + + /// + /// XOR masks multiple contiguous buffers as if they were one. + /// + public static void MaskBufs(ReadOnlySpan key, List bufs) + { + int pos = 0; + foreach (var buf in bufs) + { + for (int j = 0; j < buf.Length; j++) + { + buf[j] ^= key[pos & 3]; + pos++; + } + } + } + + /// + /// Creates a close message payload: 2-byte status code + optional UTF-8 body. + /// Body truncated to fit MaxControlPayloadSize with "..." suffix. + /// + public static byte[] CreateCloseMessage(int status, string body) + { + if (body.Length > WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize) + { + body = body[..(WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize - 3)] + "..."; + } + + var bodyBytes = Encoding.UTF8.GetBytes(body); + var buf = new byte[WsConstants.CloseStatusSize + bodyBytes.Length]; + BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status); + bodyBytes.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize)); + return buf; + } + + /// + /// Builds a complete control frame (header + payload, optional masking). + /// + public static byte[] BuildControlFrame(int opcode, ReadOnlySpan payload, bool useMasking) + { + int headerSize = 2 + (useMasking ? 4 : 0); + var frame = new byte[headerSize + payload.Length]; + var span = frame.AsSpan(); + var (n, key) = FillFrameHeader(span, useMasking, + first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length); + if (payload.Length > 0) + { + payload.CopyTo(span[n..]); + if (useMasking && key != null) + MaskBuf(key, span[n..]); + } + + return frame; + } + + /// + /// Maps a ClientClosedReason to a WebSocket close status code. + /// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694. + /// + public static int MapCloseStatus(ClientClosedReason reason) => reason switch + { + ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure, + ClientClosedReason.AuthenticationTimeout or + ClientClosedReason.AuthenticationViolation or + ClientClosedReason.SlowConsumerPendingBytes or + ClientClosedReason.SlowConsumerWriteDeadline or + ClientClosedReason.MaxSubscriptionsExceeded or + ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation, + ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake, + ClientClosedReason.ParseError or + ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError, + ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig, + ClientClosedReason.WriteError or + ClientClosedReason.ReadError or + ClientClosedReason.StaleConnection or + ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway, + _ => WsConstants.CloseStatusInternalSrvError, + }; +} +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameWriterTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsFrameWriter.cs tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs +git commit -m "feat: add WebSocket frame writer with masking and close status mapping" +``` + +--- + +### Task 4: Add WsReadInfo (frame reader state machine) + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsReadInfo.cs` +- Create: `tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs` + +**Reference:** `golang/nats-server/server/websocket.go` lines 156-440 (`wsReadInfo`, `wsRead`, `unmask`, `decompress`) + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs`: + +```csharp +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsFrameReadTests +{ + /// Helper: build a single unmasked binary frame. + private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null) + { + int headerLen = 2; + int payloadLen = payload.Length; + byte b0 = (byte)opcode; + if (fin) b0 |= WsConstants.FinalBit; + if (compressed) b0 |= WsConstants.Rsv1Bit; + byte b1 = 0; + if (mask) b1 |= WsConstants.MaskBit; + + byte[] lenBytes; + if (payloadLen <= 125) + { + lenBytes = [(byte)(b1 | (byte)payloadLen)]; + } + else if (payloadLen < 65536) + { + lenBytes = new byte[3]; + lenBytes[0] = (byte)(b1 | 126); + BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen); + } + else + { + lenBytes = new byte[9]; + lenBytes[0] = (byte)(b1 | 127); + BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen); + } + + int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen; + var frame = new byte[totalLen]; + frame[0] = b0; + lenBytes.CopyTo(frame.AsSpan(1)); + int pos = 1 + lenBytes.Length; + if (mask && maskKey != null) + { + maskKey.CopyTo(frame.AsSpan(pos)); + pos += 4; + var maskedPayload = payload.ToArray(); + WsFrameWriter.MaskBuf(maskKey, maskedPayload); + maskedPayload.CopyTo(frame.AsSpan(pos)); + } + else + { + payload.CopyTo(frame.AsSpan(pos)); + } + return frame; + } + + [Fact] + public void ReadSingleUnmaskedFrame() + { + var payload = "Hello"u8.ToArray(); + var frame = BuildFrame(payload); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void ReadMaskedFrame() + { + var payload = "Hello"u8.ToArray(); + byte[] key = [0x37, 0xFA, 0x21, 0x3D]; + var frame = BuildFrame(payload, mask: true, maskKey: key); + + var readInfo = new WsReadInfo(expectMask: true); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void Read16BitLengthFrame() + { + var payload = new byte[200]; + Random.Shared.NextBytes(payload); + var frame = BuildFrame(payload); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void ReadPingFrame_ReturnsPongAction() + { + var frame = BuildFrame([], opcode: WsConstants.PingMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); // control frames don't produce payload + readInfo.PendingControlFrames.Count.ShouldBe(1); + readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage); + } + + [Fact] + public void ReadCloseFrame_ReturnsCloseAction() + { + var closePayload = new byte[2]; + BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000); + var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); + readInfo.CloseReceived.ShouldBeTrue(); + readInfo.CloseStatus.ShouldBe(1000); + } + + [Fact] + public void ReadPongFrame_NoAction() + { + var frame = BuildFrame([], opcode: WsConstants.PongMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); + readInfo.PendingControlFrames.Count.ShouldBe(0); + } + + [Fact] + public void Unmask_Optimized_8ByteChunks() + { + byte[] key = [0xAA, 0xBB, 0xCC, 0xDD]; + var original = new byte[32]; + Random.Shared.NextBytes(original); + var masked = original.ToArray(); + + // Mask it + for (int i = 0; i < masked.Length; i++) + masked[i] ^= key[i & 3]; + + // Unmask using the state machine + var info = new WsReadInfo(expectMask: true); + info.SetMaskKey(key); + info.Unmask(masked); + + masked.ShouldBe(original); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameReadTests" -v normal` +Expected: FAIL + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsReadInfo.cs`: + +```csharp +using System.Buffers.Binary; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// Per-connection WebSocket frame reading state machine. +/// Ported from golang/nats-server/server/websocket.go lines 156-506. +/// +public struct WsReadInfo +{ + public int Remaining; + public bool FrameStart; + public bool FirstFrame; + public bool FrameCompressed; + public bool ExpectMask; + public byte MaskKeyPos; + public byte[] MaskKey; + public List? CompressedBuffers; + public int CompressedOffset; + + // Control frame outputs + public List PendingControlFrames; + public bool CloseReceived; + public int CloseStatus; + public string? CloseBody; + + public WsReadInfo(bool expectMask) + { + Remaining = 0; + FrameStart = true; + FirstFrame = true; + FrameCompressed = false; + ExpectMask = expectMask; + MaskKeyPos = 0; + MaskKey = new byte[4]; + CompressedBuffers = null; + CompressedOffset = 0; + PendingControlFrames = []; + CloseReceived = false; + CloseStatus = 0; + CloseBody = null; + } + + public void SetMaskKey(ReadOnlySpan key) + { + key[..4].CopyTo(MaskKey); + MaskKeyPos = 0; + } + + /// + /// Unmask buffer in-place using current mask key and position. + /// Optimized for 8-byte chunks when buffer is large enough. + /// Ported from websocket.go lines 509-536. + /// + public void Unmask(Span buf) + { + int p = MaskKeyPos; + if (buf.Length < 16) + { + for (int i = 0; i < buf.Length; i++) + { + buf[i] ^= MaskKey[p & 3]; + p++; + } + MaskKeyPos = (byte)(p & 3); + return; + } + + // Build 8-byte key for bulk XOR + Span k = stackalloc byte[8]; + for (int i = 0; i < 8; i++) + k[i] = MaskKey[(p + i) & 3]; + ulong km = BinaryPrimitives.ReadUInt64BigEndian(k); + + int n = (buf.Length / 8) * 8; + for (int i = 0; i < n; i += 8) + { + ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]); + tmp ^= km; + BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp); + } + + // Handle remaining bytes + var tail = buf[n..]; + for (int i = 0; i < tail.Length; i++) + { + tail[i] ^= MaskKey[p & 3]; + p++; + } + MaskKeyPos = (byte)(p & 3); + } + + /// + /// Read and decode WebSocket frames from a buffer. + /// Returns list of decoded payload byte arrays. + /// Ported from websocket.go lines 208-351. + /// + public static List ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload) + { + var bufs = new List(); + var buf = new byte[available]; + int bytesRead = 0; + + // Fill the buffer from the stream + while (bytesRead < available) + { + int n = stream.Read(buf, bytesRead, available - bytesRead); + if (n == 0) break; + bytesRead += n; + } + + int pos = 0; + int max = bytesRead; + + while (pos < max) + { + if (r.FrameStart) + { + if (pos >= max) break; + byte b0 = buf[pos]; + int frameType = b0 & 0x0F; + bool final = (b0 & WsConstants.FinalBit) != 0; + bool compressed = (b0 & WsConstants.Rsv1Bit) != 0; + pos++; + + // Read second byte + var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1); + pos = newPos; + byte b1 = b1Buf[0]; + + // Check mask bit + if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0) + throw new InvalidOperationException("mask bit missing"); + + r.Remaining = b1 & 0x7F; + + // Validate frame types + if (WsConstants.IsControlFrame(frameType)) + { + if (r.Remaining > WsConstants.MaxControlPayloadSize) + throw new InvalidOperationException("control frame length too large"); + if (!final) + throw new InvalidOperationException("control frame does not have final bit set"); + } + else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage) + { + if (!r.FirstFrame) + throw new InvalidOperationException("new message before previous finished"); + r.FirstFrame = final; + r.FrameCompressed = compressed; + } + else if (frameType == WsConstants.ContinuationFrame) + { + if (r.FirstFrame || compressed) + throw new InvalidOperationException("invalid continuation frame"); + r.FirstFrame = final; + } + else + { + throw new InvalidOperationException($"unknown opcode {frameType}"); + } + + // Extended payload length + switch (r.Remaining) + { + case 126: + { + var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2); + pos = p2; + r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf); + break; + } + case 127: + { + var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8); + pos = p2; + r.Remaining = (int)BinaryPrimitives.ReadUInt64BigEndian(lenBuf); + break; + } + } + + // Read mask key + if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0) + { + var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4); + pos = p2; + keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey); + r.MaskKeyPos = 0; + } + + // Handle control frames + if (WsConstants.IsControlFrame(frameType)) + { + pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max); + continue; + } + + r.FrameStart = false; + } + + if (pos < max) + { + int n = r.Remaining; + if (pos + n > max) n = max - pos; + + var payloadSlice = buf.AsSpan(pos, n).ToArray(); + pos += n; + r.Remaining -= n; + + if (r.ExpectMask) + r.Unmask(payloadSlice); + + bool addToBufs = true; + if (r.FrameCompressed) + { + addToBufs = false; + r.CompressedBuffers ??= []; + r.CompressedBuffers.Add(payloadSlice); + + if (r.FirstFrame && r.Remaining == 0) + { + var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload); + r.CompressedBuffers = null; + r.FrameCompressed = false; + addToBufs = true; + payloadSlice = decompressed; + } + } + + if (addToBufs && payloadSlice.Length > 0) + bufs.Add(payloadSlice); + + if (r.Remaining == 0) + r.FrameStart = true; + } + } + + return bufs; + } + + private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max) + { + byte[]? payload = null; + if (r.Remaining > 0) + { + var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining); + pos = newPos; + payload = payloadBuf; + if (r.ExpectMask) + r.Unmask(payload); + r.Remaining = 0; + } + + switch (frameType) + { + case WsConstants.CloseMessage: + r.CloseReceived = true; + r.CloseStatus = WsConstants.CloseStatusNoStatusReceived; + if (payload != null && payload.Length >= WsConstants.CloseStatusSize) + { + r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload); + if (payload.Length > WsConstants.CloseStatusSize) + r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize)); + } + if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived) + { + var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? ""); + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg)); + } + break; + + case WsConstants.PingMessage: + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? [])); + break; + + case WsConstants.PongMessage: + // Nothing to do + break; + } + + return pos; + } + + /// + /// Gets needed bytes from buffer or reads from stream. + /// Ported from websocket.go lines 178-193. + /// + private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed) + { + int avail = max - pos; + if (avail >= needed) + return (buf[pos..(pos + needed)], pos + needed); + + var b = new byte[needed]; + int start = 0; + if (avail > 0) + { + Buffer.BlockCopy(buf, pos, b, 0, avail); + start = avail; + } + while (start < needed) + { + int n = stream.Read(b, start, needed - start); + if (n == 0) throw new IOException("unexpected end of stream"); + start += n; + } + return (b, pos + avail); + } +} + +public readonly record struct ControlFrameAction(int Opcode, byte[] Payload); +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameReadTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsReadInfo.cs tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs +git commit -m "feat: add WebSocket frame reader state machine" +``` + +--- + +### Task 5: Add WsCompression (permessage-deflate) + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsCompression.cs` +- Create: `tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs` + +**Reference:** `golang/nats-server/server/websocket.go` lines 403-440 (decompress), lines 1391-1466 (compress) + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs`: + +```csharp +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsCompressionTests +{ + [Fact] + public void CompressDecompress_RoundTrip() + { + var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray(); + var compressed = WsCompression.Compress(original); + compressed.ShouldNotBeNull(); + compressed.Length.ShouldBeGreaterThan(0); + + var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096); + decompressed.ShouldBe(original); + } + + [Fact] + public void Decompress_ExceedsMaxPayload_Throws() + { + var original = new byte[1000]; + Random.Shared.NextBytes(original); + var compressed = WsCompression.Compress(original); + + Should.Throw(() => + WsCompression.Decompress([compressed], maxPayload: 100)); + } + + [Fact] + public void Compress_RemovesTrailing4Bytes() + { + var data = new byte[200]; + Random.Shared.NextBytes(data); + var compressed = WsCompression.Compress(data); + + // The compressed data should be valid for decompression when we add the trailer back + var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096); + decompressed.ShouldBe(data); + } + + [Fact] + public void Decompress_MultipleBuffers() + { + var original = new byte[500]; + Random.Shared.NextBytes(original); + var compressed = WsCompression.Compress(original); + + // Split compressed data into multiple chunks + int mid = compressed.Length / 2; + var chunk1 = compressed[..mid]; + var chunk2 = compressed[mid..]; + + var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096); + decompressed.ShouldBe(original); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsCompressionTests" -v normal` +Expected: FAIL + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsCompression.cs`: + +```csharp +using System.IO.Compression; + +namespace NATS.Server.WebSocket; + +/// +/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692). +/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466. +/// +public static class WsCompression +{ + /// + /// Compresses data using deflate. Removes trailing 4 bytes (sync marker) + /// per RFC 7692 Section 7.2.1. + /// + public static byte[] Compress(ReadOnlySpan data) + { + using var output = new MemoryStream(); + using (var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true)) + { + deflate.Write(data); + deflate.Flush(); + } + + var compressed = output.ToArray(); + + // Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692 + if (compressed.Length >= 4) + return compressed[..^4]; + + return compressed; + } + + /// + /// Decompresses collected compressed buffers. + /// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2. + /// + public static byte[] Decompress(List compressedBuffers, int maxPayload) + { + if (maxPayload <= 0) + maxPayload = 1024 * 1024; // Default 1MB + + // Concatenate all compressed buffers + trailer + int totalLen = 0; + foreach (var buf in compressedBuffers) + totalLen += buf.Length; + totalLen += WsConstants.DecompressTrailer.Length; + + var combined = new byte[totalLen]; + int offset = 0; + foreach (var buf in compressedBuffers) + { + buf.CopyTo(combined, offset); + offset += buf.Length; + } + WsConstants.DecompressTrailer.CopyTo(combined, offset); + + using var input = new MemoryStream(combined); + using var deflate = new DeflateStream(input, CompressionMode.Decompress); + using var output = new MemoryStream(); + + var readBuf = new byte[4096]; + int totalRead = 0; + int n; + while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0) + { + totalRead += n; + if (totalRead > maxPayload) + throw new InvalidOperationException("decompressed data exceeds maximum payload size"); + output.Write(readBuf, 0, n); + } + + return output.ToArray(); + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsCompressionTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsCompression.cs tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs +git commit -m "feat: add WebSocket permessage-deflate compression" +``` + +--- + +### Task 6: Add WsUpgrade (HTTP upgrade handshake) + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsUpgrade.cs` +- Create: `tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs` + +**Reference:** `golang/nats-server/server/websocket.go` lines 731-917 (`wsUpgrade`, `wsHeaderContains`, `wsAcceptKey`, `wsPMCExtensionSupport`) + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs`: + +```csharp +using System.Text; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsUpgradeTests +{ + private static string BuildValidRequest(string path = "/", string? extraHeaders = null) + { + var sb = new StringBuilder(); + sb.AppendLine($"GET {path} HTTP/1.1"); + sb.AppendLine("Host: localhost:4222"); + sb.AppendLine("Upgrade: websocket"); + sb.AppendLine("Connection: Upgrade"); + sb.AppendLine("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ=="); + sb.AppendLine("Sec-WebSocket-Version: 13"); + if (extraHeaders != null) + sb.Append(extraHeaders); + sb.AppendLine(); + return sb.ToString(); + } + + [Fact] + public async Task ValidUpgrade_Returns101() + { + var request = BuildValidRequest(); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Client); + var response = ReadResponse(outputStream); + response.ShouldContain("HTTP/1.1 101"); + response.ShouldContain("Upgrade: websocket"); + response.ShouldContain("Sec-WebSocket-Accept:"); + } + + [Fact] + public async Task MissingUpgradeHeader_Returns400() + { + var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + ReadResponse(outputStream).ShouldContain("400"); + } + + [Fact] + public async Task MissingHost_Returns400() + { + var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + } + + [Fact] + public async Task WrongVersion_Returns400() + { + var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + } + + [Fact] + public async Task LeafNodePath_ReturnsLeafKind() + { + var request = BuildValidRequest("/leafnode"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Leaf); + } + + [Fact] + public async Task MqttPath_ReturnsMqttKind() + { + var request = BuildValidRequest("/mqtt"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Mqtt); + } + + [Fact] + public async Task CompressionNegotiation_WhenEnabled() + { + var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true }); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeTrue(); + ReadResponse(outputStream).ShouldContain("permessage-deflate"); + } + + [Fact] + public async Task CompressionNegotiation_WhenDisabled() + { + var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false }); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeFalse(); + } + + [Fact] + public async Task NoMaskingHeader_ForLeaf() + { + var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.MaskRead.ShouldBeFalse(); + } + + [Fact] + public async Task BrowserDetection_Mozilla() + { + var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Browser.ShouldBeTrue(); + } + + [Fact] + public async Task SafariDetection_NoCompFrag() + { + var request = BuildValidRequest(extraHeaders: + "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" + + $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true }); + + result.Success.ShouldBeTrue(); + result.NoCompFrag.ShouldBeTrue(); + } + + [Fact] + public async Task AcceptKey_MatchesRfc6455Example() + { + // RFC 6455 Section 4.2.2 example + var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + [Fact] + public async Task CookieExtraction() + { + var request = BuildValidRequest(extraHeaders: + "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions + { + NoTls = true, + JwtCookie = "jwt_token", + UsernameCookie = "nats_user", + PasswordCookie = "nats_pass", + }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.CookieJwt.ShouldBe("my-jwt"); + result.CookieUsername.ShouldBe("admin"); + result.CookiePassword.ShouldBe("secret"); + } + + [Fact] + public async Task XForwardedFor_ExtractsClientIp() + { + var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.ClientIp.ShouldBe("192.168.1.100"); + } + + [Fact] + public async Task PostMethod_Returns405() + { + var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + ReadResponse(outputStream).ShouldContain("405"); + } + + // Helper: create a readable input stream and writable output stream + private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest) + { + var inputBytes = Encoding.ASCII.GetBytes(httpRequest); + return (new MemoryStream(inputBytes), new MemoryStream()); + } + + private static string ReadResponse(MemoryStream output) + { + output.Position = 0; + return Encoding.ASCII.GetString(output.ToArray()); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsUpgradeTests" -v normal` +Expected: FAIL + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsUpgrade.cs`: + +```csharp +using System.Net; +using System.Security.Cryptography; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// WebSocket HTTP upgrade handshake handler. +/// Ported from golang/nats-server/server/websocket.go lines 731-917. +/// +public static class WsUpgrade +{ + /// + /// Attempts to read an HTTP upgrade request from the input stream, + /// validate per RFC 6455, and write the 101 response to the output stream. + /// + public static async Task TryUpgradeAsync( + Stream inputStream, Stream outputStream, WebSocketOptions options) + { + try + { + // Read HTTP request + var (method, path, headers) = await ReadHttpRequestAsync(inputStream); + + // RFC 6455 Section 4.2.1 validation + // Point 1: Method must be GET + if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase)) + return await FailAsync(outputStream, 405, "request method must be GET"); + + // Point 2: Host header required + if (!headers.ContainsKey("Host")) + return await FailAsync(outputStream, 400, "'Host' missing in request"); + + // Point 3: Upgrade header must contain "websocket" + if (!HeaderContains(headers, "Upgrade", "websocket")) + return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'"); + + // Point 4: Connection header must contain "Upgrade" + if (!HeaderContains(headers, "Connection", "Upgrade")) + return await FailAsync(outputStream, 400, "invalid value for header 'Connection'"); + + // Point 5: Sec-WebSocket-Key required + if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key)) + return await FailAsync(outputStream, 400, "key missing"); + + // Point 6: Version must be 13 + if (!HeaderContains(headers, "Sec-WebSocket-Version", "13")) + return await FailAsync(outputStream, 400, "invalid version"); + + // Path routing + var kind = path switch + { + _ when path.EndsWith("/leafnode") => WsClientKind.Leaf, + _ when path.EndsWith("/mqtt") => WsClientKind.Mqtt, + _ => WsClientKind.Client, + }; + + // Origin checking + if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 }) + { + var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins); + headers.TryGetValue("Origin", out var origin); + if (string.IsNullOrEmpty(origin)) + headers.TryGetValue("Sec-WebSocket-Origin", out origin); + var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false); + if (originErr != null) + return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}"); + } + + // Compression negotiation + bool compress = options.Compression; + if (compress) + { + compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) && + ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase); + } + + // No-masking negotiation + bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) && + string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase); + + // Browser detection + bool browser = false; + bool noCompFrag = false; + if (kind is WsClientKind.Client or WsClientKind.Mqtt && + headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/")) + { + browser = true; + noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/"); + } + + // Cookie extraction + string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null; + if ((kind is WsClientKind.Client or WsClientKind.Mqtt) && + headers.TryGetValue("Cookie", out var cookieHeader)) + { + var cookies = ParseCookies(cookieHeader); + if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt); + if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername); + if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword); + if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken); + } + + // X-Forwarded-For + string? clientIp = null; + if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff)) + { + var ip = xff.Split(',')[0].Trim(); + if (IPAddress.TryParse(ip, out _)) + clientIp = ip; + } + + // Build 101 response + var response = new StringBuilder(); + response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "); + response.Append(ComputeAcceptKey(key)); + response.Append("\r\n"); + if (compress) + response.Append(WsConstants.PmcFullResponse); + if (noMasking) + response.Append(WsConstants.NoMaskingFullResponse); + if (options.Headers != null) + { + foreach (var (k, v) in options.Headers) + { + response.Append(k); + response.Append(": "); + response.Append(v); + response.Append("\r\n"); + } + } + response.Append("\r\n"); + + var responseBytes = Encoding.ASCII.GetBytes(response.ToString()); + await outputStream.WriteAsync(responseBytes); + await outputStream.FlushAsync(); + + return new WsUpgradeResult( + Success: true, + Compress: compress, + Browser: browser, + NoCompFrag: noCompFrag, + MaskRead: !noMasking, + MaskWrite: false, + CookieJwt: cookieJwt, + CookieUsername: cookieUsername, + CookiePassword: cookiePassword, + CookieToken: cookieToken, + ClientIp: clientIp, + Kind: kind); + } + catch (Exception) + { + return WsUpgradeResult.Failed; + } + } + + /// + /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2. + /// + public static string ComputeAcceptKey(string clientKey) + { + var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + var hash = SHA1.HashData(combined); + return Convert.ToBase64String(hash); + } + + private static async Task FailAsync(Stream output, int statusCode, string reason) + { + var statusText = statusCode switch + { + 400 => "Bad Request", + 403 => "Forbidden", + 405 => "Method Not Allowed", + _ => "Internal Server Error", + }; + var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}"; + await output.WriteAsync(Encoding.ASCII.GetBytes(response)); + await output.FlushAsync(); + return WsUpgradeResult.Failed; + } + + private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync(Stream stream) + { + var headerBytes = new List(4096); + int prev = 0; + var buf = new byte[1]; + // Read until \r\n\r\n + while (true) + { + int n = await stream.ReadAsync(buf); + if (n == 0) throw new IOException("connection closed during handshake"); + headerBytes.Add(buf[0]); + if (headerBytes.Count >= 4 && + headerBytes[^4] == '\r' && headerBytes[^3] == '\n' && + headerBytes[^2] == '\r' && headerBytes[^1] == '\n') + break; + if (headerBytes.Count > 8192) + throw new InvalidOperationException("HTTP header too large"); + } + + var text = Encoding.ASCII.GetString(headerBytes.ToArray()); + var lines = text.Split("\r\n", StringSplitOptions.None); + if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request"); + + // Parse request line + var parts = lines[0].Split(' '); + if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line"); + var method = parts[0]; + var path = parts[1]; + + // Parse headers + var headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + for (int i = 1; i < lines.Length; i++) + { + var line = lines[i]; + if (string.IsNullOrEmpty(line)) break; + var colonIdx = line.IndexOf(':'); + if (colonIdx > 0) + { + var name = line[..colonIdx].Trim(); + var value = line[(colonIdx + 1)..].Trim(); + headers[name] = value; + } + } + + return (method, path, headers); + } + + private static bool HeaderContains(Dictionary headers, string name, string value) + { + if (!headers.TryGetValue(name, out var headerValue)) + return false; + foreach (var token in headerValue.Split(',')) + { + if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase)) + return true; + } + return false; + } + + private static Dictionary ParseCookies(string cookieHeader) + { + var cookies = new Dictionary(StringComparer.Ordinal); + foreach (var pair in cookieHeader.Split(';')) + { + var trimmed = pair.Trim(); + var eqIdx = trimmed.IndexOf('='); + if (eqIdx > 0) + cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim(); + } + return cookies; + } +} + +public readonly record struct WsUpgradeResult( + bool Success, + bool Compress, + bool Browser, + bool NoCompFrag, + bool MaskRead, + bool MaskWrite, + string? CookieJwt, + string? CookieUsername, + string? CookiePassword, + string? CookieToken, + string? ClientIp, + WsClientKind Kind) +{ + public static readonly WsUpgradeResult Failed = new( + Success: false, Compress: false, Browser: false, NoCompFrag: false, + MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null, + CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client); +} +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsUpgradeTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsUpgrade.cs tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs +git commit -m "feat: add WebSocket HTTP upgrade handshake" +``` + +--- + +### Task 7: Add WsConnection Stream wrapper + +**Files:** +- Create: `src/NATS.Server/WebSocket/WsConnection.cs` +- Create: `tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs` + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs`: + +```csharp +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsConnectionTests +{ + [Fact] + public async Task ReadAsync_DecodesFrameAndReturnsPayload() + { + var payload = "SUB test 1\r\n"u8.ToArray(); + var frame = BuildUnmaskedFrame(payload); + var inner = new MemoryStream(frame); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + + n.ShouldBe(payload.Length); + buf[..n].ShouldBe(payload); + } + + [Fact] + public async Task WriteAsync_FramesPayload() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray(); + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // First 2 bytes should be WS frame header + (written[0] & WsConstants.FinalBit).ShouldNotBe(0); + (written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage); + int len = written[1] & 0x7F; + len.ShouldBe(payload.Length); + written[2..].ShouldBe(payload); + } + + [Fact] + public async Task WriteAsync_WithCompression_CompressesLargePayload() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = new byte[200]; + Array.Fill(payload, 0x41); // 'A' repeated - very compressible + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // RSV1 bit should be set for compressed frame + (written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0); + // Compressed size should be less than original + written.Length.ShouldBeLessThan(payload.Length + 10); + } + + [Fact] + public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = "Hi"u8.ToArray(); // Below CompressThreshold + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // RSV1 bit should NOT be set for small payloads + (written[0] & WsConstants.Rsv1Bit).ShouldBe(0); + } + + private static byte[] BuildUnmaskedFrame(byte[] payload) + { + var header = new byte[2]; + header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage); + header[1] = (byte)payload.Length; + var frame = new byte[2 + payload.Length]; + header.CopyTo(frame, 0); + payload.CopyTo(frame, 2); + return frame; + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConnectionTests" -v normal` +Expected: FAIL + +**Step 3: Write minimal implementation** + +Create `src/NATS.Server/WebSocket/WsConnection.cs`: + +```csharp +namespace NATS.Server.WebSocket; + +/// +/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O. +/// NatsClient uses this as its _stream — FillPipeAsync and RunWriteLoopAsync work unchanged. +/// +public sealed class WsConnection : Stream +{ + private readonly Stream _inner; + private readonly bool _compress; + private readonly bool _maskRead; + private readonly bool _maskWrite; + private readonly bool _browser; + private readonly bool _noCompFrag; + private WsReadInfo _readInfo; + private readonly Queue _readQueue = new(); + private int _readOffset; + private readonly object _writeLock = new(); + private readonly List _pendingControlWrites = []; + + public bool CloseReceived => _readInfo.CloseReceived; + public int CloseStatus => _readInfo.CloseStatus; + + public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag) + { + _inner = inner; + _compress = compress; + _maskRead = maskRead; + _maskWrite = maskWrite; + _browser = browser; + _noCompFrag = noCompFrag; + _readInfo = new WsReadInfo(expectMask: maskRead); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) + { + // Drain any buffered decoded payloads first + if (_readQueue.Count > 0) + return DrainReadQueue(buffer.Span); + + // Read raw bytes from inner stream + var rawBuf = new byte[Math.Max(buffer.Length, 4096)]; + int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(0, rawBuf.Length), ct); + if (bytesRead == 0) return 0; + + // Decode frames + var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); + + // Collect control frame responses + if (_readInfo.PendingControlFrames.Count > 0) + { + lock (_writeLock) + _pendingControlWrites.AddRange(_readInfo.PendingControlFrames); + _readInfo.PendingControlFrames.Clear(); + // Write pending control frames + await FlushControlFramesAsync(ct); + } + + if (_readInfo.CloseReceived) + return 0; + + foreach (var payload in payloads) + _readQueue.Enqueue(payload); + + if (_readQueue.Count == 0) + return 0; + + return DrainReadQueue(buffer.Span); + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ct = default) + { + var data = buffer.Span; + + if (_compress && data.Length > WsConstants.CompressThreshold) + { + var compressed = WsCompression.Compress(data); + WriteFramed(compressed, compressed: true, ct); + } + else + { + WriteFramed(data.ToArray(), compressed: false, ct); + } + } + + private void WriteFramed(byte[] payload, bool compressed, CancellationToken ct) + { + if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed)) + { + // Fragment for browsers + int offset = 0; + bool first = true; + while (offset < payload.Length) + { + int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset); + bool final = offset + chunkLen >= payload.Length; + var fh = new byte[WsConstants.MaxFrameHeaderSize]; + var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite, + first: first, final: final, compressed: first && compressed, + opcode: WsConstants.BinaryMessage, payloadLength: chunkLen); + + var chunk = payload.AsSpan(offset, chunkLen).ToArray(); + if (_maskWrite && key != null) + WsFrameWriter.MaskBuf(key, chunk); + + _inner.Write(fh, 0, n); + _inner.Write(chunk, 0, chunkLen); + offset += chunkLen; + first = false; + } + } + else + { + var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length); + if (_maskWrite && key != null) + WsFrameWriter.MaskBuf(key, payload); + _inner.Write(header); + _inner.Write(payload); + } + } + + private async Task FlushControlFramesAsync(CancellationToken ct) + { + List toWrite; + lock (_writeLock) + { + if (_pendingControlWrites.Count == 0) return; + toWrite = [.. _pendingControlWrites]; + _pendingControlWrites.Clear(); + } + + foreach (var action in toWrite) + { + var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite); + await _inner.WriteAsync(frame, ct); + } + await _inner.FlushAsync(ct); + } + + /// + /// Sends a WebSocket close frame. + /// + public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default) + { + var status = WsFrameWriter.MapCloseStatus(reason); + var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString()); + var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite); + await _inner.WriteAsync(frame, ct); + await _inner.FlushAsync(ct); + } + + private int DrainReadQueue(Span buffer) + { + int written = 0; + while (_readQueue.Count > 0 && written < buffer.Length) + { + var current = _readQueue.Peek(); + int available = current.Length - _readOffset; + int toCopy = Math.Min(available, buffer.Length - written); + current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]); + written += toCopy; + _readOffset += toCopy; + if (_readOffset >= current.Length) + { + _readQueue.Dequeue(); + _readOffset = 0; + } + } + return written; + } + + // Stream abstract members + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override void Flush() => _inner.Flush(); + public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct); + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync"); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync"); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (disposing) + _inner.Dispose(); + base.Dispose(disposing); + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConnectionTests" -v normal` +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/NATS.Server/WebSocket/WsConnection.cs tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs +git commit -m "feat: add WsConnection Stream wrapper for transparent framing" +``` + +--- + +### Task 8: Integrate WebSocket into NatsServer and NatsClient + +**Files:** +- Modify: `src/NATS.Server/NatsServer.cs` +- Modify: `src/NATS.Server/NatsClient.cs` + +**Step 1: Write the failing test** + +Create `tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs`: + +```csharp +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using System.Security.Cryptography; +using System.Text; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsIntegrationTests : IAsyncLifetime +{ + private NatsServer _server = null!; + private NatsOptions _options = null!; + + public async Task InitializeAsync() + { + _options = new NatsOptions + { + Port = 0, + WebSocket = new WebSocketOptions { Port = 0, NoTls = true }, + }; + var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { }); + _server = new NatsServer(_options, loggerFactory); + _ = _server.StartAsync(CancellationToken.None); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _server.ShutdownAsync(); + _server.Dispose(); + } + + [Fact] + public async Task WebSocket_ConnectAndReceiveInfo() + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port)); + using var stream = new NetworkStream(socket, ownsSocket: false); + + // Send WebSocket upgrade request + await SendUpgradeRequest(stream); + + // Read 101 response + var response = await ReadHttpResponse(stream); + response.ShouldContain("101"); + + // Now read the INFO line through WebSocket frames + var wsFrame = await ReadWsFrame(stream); + var info = Encoding.ASCII.GetString(wsFrame); + info.ShouldStartWith("INFO "); + } + + [Fact] + public async Task WebSocket_PubSub() + { + // Connect two WS clients + using var sub = await ConnectWsClient(); + using var pub = await ConnectWsClient(); + + // Subscribe on first client + await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n"); + await Task.Delay(100); + + // Publish on second client + await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n"); + await Task.Delay(100); + + // Read from subscriber + var msg = await ReadWsFrame(sub); + Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5"); + } + + private async Task ConnectWsClient() + { + var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port)); + var stream = new NetworkStream(socket, ownsSocket: true); + + await SendUpgradeRequest(stream); + var response = await ReadHttpResponse(stream); + response.ShouldContain("101"); + + // Read INFO frame + await ReadWsFrame(stream); + + return stream; + } + + private static async Task SendUpgradeRequest(NetworkStream stream) + { + var keyBytes = new byte[16]; + RandomNumberGenerator.Fill(keyBytes); + var key = Convert.ToBase64String(keyBytes); + + var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n"; + await stream.WriteAsync(Encoding.ASCII.GetBytes(request)); + await stream.FlushAsync(); + } + + private static async Task ReadHttpResponse(NetworkStream stream) + { + var buf = new byte[4096]; + var sb = new StringBuilder(); + while (true) + { + int n = await stream.ReadAsync(buf); + if (n == 0) break; + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + if (sb.ToString().Contains("\r\n\r\n")) break; + } + return sb.ToString(); + } + + private static async Task ReadWsFrame(NetworkStream stream) + { + var header = new byte[2]; + await stream.ReadExactlyAsync(header); + int len = header[1] & 0x7F; + byte[]? extLen = null; + if (len == 126) + { + extLen = new byte[2]; + await stream.ReadExactlyAsync(extLen); + len = BinaryPrimitives.ReadUInt16BigEndian(extLen); + } + else if (len == 127) + { + extLen = new byte[8]; + await stream.ReadExactlyAsync(extLen); + len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen); + } + var payload = new byte[len]; + if (len > 0) await stream.ReadExactlyAsync(payload); + return payload; + } + + private static async Task SendWsText(NetworkStream stream, string text) + { + var payload = Encoding.ASCII.GetBytes(text); + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: payload.Length); + // The masking key is in the header — we need to mask the payload + var maskKey = header[^4..]; + WsFrameWriter.MaskBuf(maskKey, payload); + await stream.WriteAsync(header); + await stream.WriteAsync(payload); + await stream.FlushAsync(); + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsIntegrationTests" -v normal` +Expected: FAIL — NatsServer has no WebSocket listener + +**Step 3: Modify NatsServer.cs** + +Add these fields to `NatsServer`: + +```csharp +private Socket? _wsListener; +private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously); +``` + +In `StartAsync`, after the monitoring server startup and before the main accept loop, add: + +```csharp +if (_options.WebSocket.Port > 0) +{ + _wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + _wsListener.Bind(new IPEndPoint( + _options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host), + _options.WebSocket.Port)); + _wsListener.Listen(128); + + if (_options.WebSocket.Port == 0) + { + _options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port; + } + + _logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}", + _options.WebSocket.Host, _options.WebSocket.Port); + + if (_options.WebSocket.NoTls) + _logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!"); + + _ = RunWebSocketAcceptLoopAsync(linked.Token); +} +``` + +Add the WebSocket accept loop method: + +```csharp +private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct) +{ + var tmpDelay = AcceptMinSleep; + try + { + while (!ct.IsCancellationRequested) + { + Socket socket; + try + { + socket = await _wsListener!.AcceptAsync(ct); + tmpDelay = AcceptMinSleep; + } + catch (OperationCanceledException) { break; } + catch (ObjectDisposedException) { break; } + catch (SocketException ex) + { + if (IsShuttingDown || IsLameDuckMode) break; + _logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds); + try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; } + tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks)); + continue; + } + + if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections) + { + socket.Dispose(); + continue; + } + + var clientId = Interlocked.Increment(ref _nextClientId); + Interlocked.Increment(ref _stats.TotalConnections); + Interlocked.Increment(ref _activeClientCount); + + _ = AcceptWebSocketClientAsync(socket, clientId, ct); + } + } + finally + { + _wsAcceptLoopExited.TrySetResult(); + } +} + +private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct) +{ + try + { + var networkStream = new NetworkStream(socket, ownsSocket: false); + Stream stream = networkStream; + + // TLS negotiation if configured + if (_sslOptions != null && !_options.WebSocket.NoTls) + { + var (tlsStream, _) = await Tls.TlsConnectionWrapper.NegotiateAsync( + socket, networkStream, _options, _sslOptions, _serverInfo, + _loggerFactory.CreateLogger("NATS.Server.Tls"), ct); + stream = tlsStream; + } + + // HTTP upgrade handshake + var upgradeResult = await WebSocket.WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket); + if (!upgradeResult.Success) + { + _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId); + socket.Dispose(); + Interlocked.Decrement(ref _activeClientCount); + return; + } + + // Create WsConnection wrapper + var wsConn = new WebSocket.WsConnection(stream, + compress: upgradeResult.Compress, + maskRead: upgradeResult.MaskRead, + maskWrite: upgradeResult.MaskWrite, + browser: upgradeResult.Browser, + noCompFrag: upgradeResult.NoCompFrag); + + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); + var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo, + _authService, null, clientLogger, _stats); + client.Router = this; + client.IsWebSocket = true; + client.WsInfo = upgradeResult; + _clients[clientId] = client; + + await RunClientAsync(client, ct); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId); + try { socket.Shutdown(SocketShutdown.Both); } catch { } + socket.Dispose(); + Interlocked.Decrement(ref _activeClientCount); + } +} +``` + +In `ShutdownAsync`, add before `_listener?.Close()`: + +```csharp +_wsListener?.Close(); +``` + +And after `_acceptLoopExited.Task.WaitAsync(...)`, add: + +```csharp +await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); +``` + +In `Dispose`, add: + +```csharp +_wsListener?.Dispose(); +``` + +**Step 4: Modify NatsClient.cs** + +Add two properties: + +```csharp +public bool IsWebSocket { get; set; } +public WsUpgradeResult? WsInfo { get; set; } +``` + +**Step 5: Run test to verify it passes** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsIntegrationTests" -v normal` +Expected: PASS + +**Step 6: Commit** + +```bash +git add src/NATS.Server/NatsServer.cs src/NATS.Server/NatsClient.cs tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs +git commit -m "feat: integrate WebSocket accept loop into NatsServer" +``` + +--- + +### Task 9: Update differences.md + +**Files:** +- Modify: `differences.md` + +**Step 1: Update WebSocket row in Connection Types table** + +Change line 70 from: +``` +| WebSocket clients | Y | N | | +``` +To: +``` +| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth | +``` + +**Step 2: Update Missing Options Categories** + +Change the line: +``` +- WebSocket/MQTT options +``` +To: +``` +- ~~WebSocket options~~ — WebSocketOptions with port, compression, origin checking, cookie auth, custom headers +- MQTT options +``` + +**Step 3: Commit** + +```bash +git add differences.md +git commit -m "docs: update differences.md to reflect WebSocket implementation" +``` + +--- + +### Task 10: Run full test suite and verify + +**Step 1: Build** + +Run: `dotnet build` +Expected: Build succeeded + +**Step 2: Run all tests** + +Run: `dotnet test -v normal` +Expected: All tests pass (both existing and new WebSocket tests) + +**Step 3: Run only WebSocket tests** + +Run: `dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocket" -v normal` +Expected: All WebSocket tests pass + +**Step 4: Final commit (if any fixes needed)** + +```bash +git add -A +git commit -m "fix: address test failures from full suite run" +``` diff --git a/docs/plans/2026-02-23-websocket-plan.md.tasks.json b/docs/plans/2026-02-23-websocket-plan.md.tasks.json new file mode 100644 index 0000000..5c9617c --- /dev/null +++ b/docs/plans/2026-02-23-websocket-plan.md.tasks.json @@ -0,0 +1,17 @@ +{ + "planPath": "docs/plans/2026-02-23-websocket-plan.md", + "tasks": [ + {"id": 6, "subject": "Task 0: Add WebSocketOptions configuration", "status": "pending"}, + {"id": 7, "subject": "Task 1: Add WsConstants", "status": "pending", "blockedBy": [6]}, + {"id": 8, "subject": "Task 2: Add WsOriginChecker", "status": "pending", "blockedBy": [6, 7]}, + {"id": 9, "subject": "Task 3: Add WsFrameWriter", "status": "pending", "blockedBy": [7, 8]}, + {"id": 10, "subject": "Task 4: Add WsReadInfo frame reader state machine", "status": "pending", "blockedBy": [7, 8, 9]}, + {"id": 11, "subject": "Task 5: Add WsCompression (permessage-deflate)", "status": "pending", "blockedBy": [7]}, + {"id": 12, "subject": "Task 6: Add WsUpgrade HTTP handshake", "status": "pending", "blockedBy": [7, 8, 11]}, + {"id": 13, "subject": "Task 7: Add WsConnection Stream wrapper", "status": "pending", "blockedBy": [7, 9, 10, 11]}, + {"id": 14, "subject": "Task 8: Integrate WebSocket into NatsServer and NatsClient", "status": "pending", "blockedBy": [6, 7, 12, 13]}, + {"id": 15, "subject": "Task 9: Update differences.md", "status": "pending", "blockedBy": [14]}, + {"id": 16, "subject": "Task 10: Run full test suite and verify", "status": "pending", "blockedBy": [14, 15]} + ], + "lastUpdated": "2026-02-23T00:00:00Z" +}