natsdotnet/docs/plans/2026-02-23-websocket-plan.md
Joseph Doherty dac641c52c docs: add WebSocket implementation plan with 11 tasks
TDD-based plan covering constants, origin checker, frame writer,
frame reader, compression, HTTP upgrade, connection wrapper,
server/client integration, differences.md update, and verification.
2026-02-23 04:26:40 -05:00

WebSocket Support Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers-extended-cc:executing-plans to implement this plan task-by-task.

Goal: Port full WebSocket connection support from the Go NATS server to the .NET solution, enabling NATS clients to connect over WebSocket with compression, masking, origin checking, and cookie-based auth.

Architecture: Self-contained WebSocket/ module under src/NATS.Server/ with a custom frame parser (no System.Net.WebSockets). A WsConnection Stream wrapper integrates transparently with the existing NatsClient read/write loops. A second TCP accept loop in NatsServer handles the WebSocket port.

Tech Stack: .NET 10, System.IO.Compression (DeflateStream), System.Security.Cryptography (SHA1), xUnit 3, Shouldly


Task 0: Add WebSocketOptions configuration

Files:

  • Modify: src/NATS.Server/NatsOptions.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs

Step 1: Write the failing test

Create test file tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs:

using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WebSocketOptionsTests
{
    [Fact]
    public void DefaultOptions_PortIsZero_Disabled()
    {
        var opts = new WebSocketOptions();
        opts.Port.ShouldBe(0);
        opts.Host.ShouldBe("0.0.0.0");
        opts.Compression.ShouldBeFalse();
        opts.NoTls.ShouldBeFalse();
        opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
        opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
    }

    [Fact]
    public void NatsOptions_HasWebSocketProperty()
    {
        var opts = new NatsOptions();
        opts.WebSocket.ShouldNotBeNull();
        opts.WebSocket.Port.ShouldBe(0);
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocketOptionsTests" -v normal
Expected: FAIL, because the WebSocketOptions type does not exist

Step 3: Write minimal implementation

In src/NATS.Server/NatsOptions.cs, add a new WebSocketOptions class and a matching property on NatsOptions:

public sealed class WebSocketOptions
{
    public string Host { get; set; } = "0.0.0.0";
    public int Port { get; set; }
    public string? Advertise { get; set; }
    public string? NoAuthUser { get; set; }
    public string? JwtCookie { get; set; }
    public string? UsernameCookie { get; set; }
    public string? PasswordCookie { get; set; }
    public string? TokenCookie { get; set; }
    public string? Username { get; set; }
    public string? Password { get; set; }
    public string? Token { get; set; }
    public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
    public bool NoTls { get; set; }
    public string? TlsCert { get; set; }
    public string? TlsKey { get; set; }
    public bool SameOrigin { get; set; }
    public List<string>? AllowedOrigins { get; set; }
    public bool Compression { get; set; }
    public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2);
    public TimeSpan? PingInterval { get; set; }
    public Dictionary<string, string>? Headers { get; set; }
}

Add to NatsOptions:

public WebSocketOptions WebSocket { get; set; } = new();

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocketOptionsTests" -v normal
Expected: PASS

Step 5: Commit

git add src/NATS.Server/NatsOptions.cs tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs
git commit -m "feat: add WebSocketOptions configuration class"

Task 1: Add WsConstants

Files:

  • Create: src/NATS.Server/WebSocket/WsConstants.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs

Reference: golang/nats-server/server/websocket.go lines 41-106

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs:

using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsConstantsTests
{
    [Fact]
    public void OpCodes_MatchRfc6455()
    {
        WsConstants.TextMessage.ShouldBe(1);
        WsConstants.BinaryMessage.ShouldBe(2);
        WsConstants.CloseMessage.ShouldBe(8);
        WsConstants.PingMessage.ShouldBe(9);
        WsConstants.PongMessage.ShouldBe(10);
    }

    [Fact]
    public void FrameBits_MatchRfc6455()
    {
        WsConstants.FinalBit.ShouldBe(0x80);
        WsConstants.Rsv1Bit.ShouldBe(0x40);
        WsConstants.MaskBit.ShouldBe(0x80);
    }

    [Fact]
    public void CloseStatusCodes_MatchRfc6455()
    {
        WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
        WsConstants.CloseStatusGoingAway.ShouldBe(1001);
        WsConstants.CloseStatusProtocolError.ShouldBe(1002);
        WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
        WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
    }

    [Theory]
    [InlineData(WsConstants.CloseMessage)]
    [InlineData(WsConstants.PingMessage)]
    [InlineData(WsConstants.PongMessage)]
    public void IsControlFrame_True(int opcode)
    {
        WsConstants.IsControlFrame(opcode).ShouldBeTrue();
    }

    [Theory]
    [InlineData(WsConstants.TextMessage)]
    [InlineData(WsConstants.BinaryMessage)]
    [InlineData(0)]
    public void IsControlFrame_False(int opcode)
    {
        WsConstants.IsControlFrame(opcode).ShouldBeFalse();
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConstantsTests" -v normal
Expected: FAIL, because WsConstants does not exist

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsConstants.cs:

namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket protocol constants (RFC 6455).
/// Ported from golang/nats-server/server/websocket.go lines 41-106.
/// </summary>
public static class WsConstants
{
    // Opcodes (RFC 6455 Section 5.2)
    public const int TextMessage = 1;
    public const int BinaryMessage = 2;
    public const int CloseMessage = 8;
    public const int PingMessage = 9;
    public const int PongMessage = 10;
    public const int ContinuationFrame = 0;

    // Frame header bits
    public const byte FinalBit = 0x80;  // 1 << 7
    public const byte Rsv1Bit = 0x40;   // 1 << 6 (compression, RFC 7692)
    public const byte Rsv2Bit = 0x20;   // 1 << 5
    public const byte Rsv3Bit = 0x10;   // 1 << 4
    public const byte MaskBit = 0x80;   // 1 << 7 (in second byte)

    // Frame size limits
    public const int MaxFrameHeaderSize = 14;
    public const int MaxControlPayloadSize = 125;
    public const int FrameSizeForBrowsers = 4096;
    public const int CompressThreshold = 64;
    public const int CloseStatusSize = 2;

    // Close status codes (RFC 6455 Section 11.7)
    public const int CloseStatusNormalClosure = 1000;
    public const int CloseStatusGoingAway = 1001;
    public const int CloseStatusProtocolError = 1002;
    public const int CloseStatusUnsupportedData = 1003;
    public const int CloseStatusNoStatusReceived = 1005;
    public const int CloseStatusInvalidPayloadData = 1007;
    public const int CloseStatusPolicyViolation = 1008;
    public const int CloseStatusMessageTooBig = 1009;
    public const int CloseStatusInternalSrvError = 1011;
    public const int CloseStatusTlsHandshake = 1015;

    // Compression constants (RFC 7692)
    public const string PmcExtension = "permessage-deflate";
    public const string PmcSrvNoCtx = "server_no_context_takeover";
    public const string PmcCliNoCtx = "client_no_context_takeover";
    public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}";
    public static readonly string PmcFullResponse = $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n";

    // Header names
    public const string NoMaskingHeader = "Nats-No-Masking";
    public const string NoMaskingValue = "true";
    public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n";
    public const string XForwardedForHeader = "X-Forwarded-For";

    // Path routing
    public const string ClientPath = "/";
    public const string LeafNodePath = "/leafnode";
    public const string MqttPath = "/mqtt";

    // WebSocket GUID (RFC 6455 Section 1.3)
    public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray();

    // Compression trailer (RFC 7692 Section 7.2.2)
    public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff];

    // Decompression trailer appended before decompressing
    public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];

    public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;
}

public enum WsClientKind
{
    Client,
    Leaf,
    Mqtt,
}
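
As an aside, the WsGuid constant above exists for the handshake's Sec-WebSocket-Accept computation described in RFC 6455: the server appends the GUID to the client's Sec-WebSocket-Key, hashes the result with SHA-1, and Base64-encodes the digest. The following is only a minimal sketch of that step; the class and method names (WsAcceptKeySketch, ComputeAcceptKey) are hypothetical and not part of the plan, and the actual upgrade code in a later task may be structured differently.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical helper, for illustration only (not part of the plan's API).
public static class WsAcceptKeySketch
{
    // Same GUID value as WsConstants.WsGuid, repeated so the sketch is self-contained.
    private const string Guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// <summary>
    /// Computes the Sec-WebSocket-Accept value for a given Sec-WebSocket-Key.
    /// </summary>
    public static string ComputeAcceptKey(string clientKey)
    {
        // Concatenate key + GUID, SHA-1 hash the ASCII bytes, then Base64-encode.
        byte[] hash = SHA1.HashData(Encoding.ASCII.GetBytes(clientKey + Guid));
        return Convert.ToBase64String(hash);
    }
}
```

With the sample key from RFC 6455 Section 1.3, `ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==")` yields `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`.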

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConstantsTests" -v normal
Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsConstants.cs tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs
git commit -m "feat: add WebSocket constants (RFC 6455/7692)"

Task 2: Add WsOriginChecker

Files:

  • Create: src/NATS.Server/WebSocket/WsOriginChecker.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs

Reference: golang/nats-server/server/websocket.go lines 933-1000 (checkOrigin, wsGetHostAndPort)

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs:

using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsOriginCheckerTests
{
    [Fact]
    public void NoOriginHeader_Accepted()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
            .ShouldBeNull();
    }

    [Fact]
    public void NeitherSameNorList_AlwaysAccepted()
    {
        var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
        checker.CheckOrigin("https://evil.com", "localhost:4222", false)
            .ShouldBeNull();
    }

    [Fact]
    public void SameOrigin_Match()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
            .ShouldBeNull();
    }

    [Fact]
    public void SameOrigin_Mismatch()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin("http://other:4222", "localhost:4222", false)
            .ShouldNotBeNull();
    }

    [Fact]
    public void SameOrigin_DefaultPort_Http()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        // No port in origin means port 80 for http
        checker.CheckOrigin("http://localhost", "localhost:80", false)
            .ShouldBeNull();
    }

    [Fact]
    public void SameOrigin_DefaultPort_Https()
    {
        var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
        checker.CheckOrigin("https://localhost", "localhost:443", true)
            .ShouldBeNull();
    }

    [Fact]
    public void AllowedOrigins_Match()
    {
        var checker = new WsOriginChecker(sameOrigin: false,
            allowedOrigins: ["https://app.example.com"]);
        checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
            .ShouldBeNull();
    }

    [Fact]
    public void AllowedOrigins_Mismatch()
    {
        var checker = new WsOriginChecker(sameOrigin: false,
            allowedOrigins: ["https://app.example.com"]);
        checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
            .ShouldNotBeNull();
    }

    [Fact]
    public void AllowedOrigins_SchemeMismatch()
    {
        var checker = new WsOriginChecker(sameOrigin: false,
            allowedOrigins: ["https://app.example.com"]);
        checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
            .ShouldNotBeNull();
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsOriginCheckerTests" -v normal
Expected: FAIL, because WsOriginChecker does not exist

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsOriginChecker.cs:

namespace NATS.Server.WebSocket;

/// <summary>
/// Validates WebSocket Origin headers per RFC 6455 Section 10.2.
/// Ported from golang/nats-server/server/websocket.go lines 933-1000.
/// </summary>
public sealed class WsOriginChecker
{
    private readonly bool _sameOrigin;
    private readonly Dictionary<string, AllowedOrigin>? _allowedOrigins;

    public WsOriginChecker(bool sameOrigin, List<string>? allowedOrigins)
    {
        _sameOrigin = sameOrigin;
        if (allowedOrigins is { Count: > 0 })
        {
            _allowedOrigins = new Dictionary<string, AllowedOrigin>(StringComparer.OrdinalIgnoreCase);
            foreach (var ao in allowedOrigins)
            {
                if (Uri.TryCreate(ao, UriKind.Absolute, out var uri))
                {
                    var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port);
                    _allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port);
                }
            }
        }
    }

    /// <summary>
    /// Returns null if origin is allowed, or an error message if rejected.
    /// </summary>
    public string? CheckOrigin(string? origin, string requestHost, bool isTls)
    {
        if (!_sameOrigin && _allowedOrigins == null)
            return null;

        if (string.IsNullOrEmpty(origin))
            return null;

        if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
            return $"invalid origin: {origin}";

        var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port);

        if (_sameOrigin)
        {
            var (rh, rp) = ParseHostPort(requestHost, isTls);
            if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp)
                return "not same origin";
        }

        if (_allowedOrigins != null)
        {
            if (!_allowedOrigins.TryGetValue(oh, out var allowed) ||
                !string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) ||
                op != allowed.Port)
            {
                return "not in the allowed list";
            }
        }

        return null;
    }

    private static (string host, int port) GetHostAndPort(bool tls, string host, int port)
    {
        if (port <= 0)
            port = tls ? 443 : 80;
        return (host.ToLowerInvariant(), port);
    }

    private static (string host, int port) ParseHostPort(string hostPort, bool isTls)
    {
        var colonIdx = hostPort.LastIndexOf(':');
        if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port))
            return (hostPort[..colonIdx].ToLowerInvariant(), port);
        return (hostPort.ToLowerInvariant(), isTls ? 443 : 80);
    }

    private readonly record struct AllowedOrigin(string Scheme, int Port);
}
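
One detail worth keeping in mind when reading GetHostAndPort: System.Uri already substitutes the scheme's default port when the origin omits one, and reports -1 only for schemes without a registered default. The `port <= 0` fallback therefore mostly matters for unusual schemes. The sketch below illustrates that behavior in isolation; OriginPortSketch and Normalize are hypothetical names used only for this example.

```csharp
using System;

// Hypothetical helper, for illustration only (not part of the plan's API).
public static class OriginPortSketch
{
    /// <summary>
    /// Parses an absolute origin and returns its lower-cased host and effective port.
    /// </summary>
    public static (string host, int port) Normalize(string origin)
    {
        var uri = new Uri(origin, UriKind.Absolute);
        // Uri.Port is the explicit port if present, otherwise the scheme default
        // (80 for http, 443 for https).
        return (uri.Host.ToLowerInvariant(), uri.Port);
    }
}
```

For example, `Normalize("https://localhost")` gives `("localhost", 443)` and `Normalize("http://localhost:4222")` gives `("localhost", 4222)`, which is what the SameOrigin_DefaultPort tests rely on.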

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsOriginCheckerTests" -v normal
Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsOriginChecker.cs tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs
git commit -m "feat: add WebSocket origin checker"

Task 3: Add WsFrameWriter (frame header construction, masking, control frames)

Files:

  • Create: src/NATS.Server/WebSocket/WsFrameWriter.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs

Reference: golang/nats-server/server/websocket.go lines 543-726 (wsFillFrameHeader, wsCreateFrameHeader, wsMaskBuf, wsCreateCloseMessage, wsEnqueueControlMessageLocked)

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs:

using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsFrameWriterTests
{
    [Fact]
    public void CreateFrameHeader_SmallPayload_7BitLength()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: 100);
        header.Length.ShouldBe(2);
        (header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
        (header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
        (header[1] & 0x7F).ShouldBe(100);
    }

    [Fact]
    public void CreateFrameHeader_MediumPayload_16BitLength()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: 1000);
        header.Length.ShouldBe(4);
        (header[1] & 0x7F).ShouldBe(126);
        BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
    }

    [Fact]
    public void CreateFrameHeader_LargePayload_64BitLength()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: 70000);
        header.Length.ShouldBe(10);
        (header[1] & 0x7F).ShouldBe(127);
        BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
    }

    [Fact]
    public void CreateFrameHeader_WithMasking_Adds4ByteKey()
    {
        var (header, key) = WsFrameWriter.CreateFrameHeader(
            useMasking: true, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: 10);
        header.Length.ShouldBe(6); // 2 header + 4 mask key
        (header[1] & WsConstants.MaskBit).ShouldNotBe(0);
        key.ShouldNotBeNull();
        key.Length.ShouldBe(4);
    }

    [Fact]
    public void CreateFrameHeader_Compressed_SetsRsv1Bit()
    {
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: false, compressed: true,
            opcode: WsConstants.BinaryMessage, payloadLength: 10);
        (header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
    }

    [Fact]
    public void MaskBuf_XorsCorrectly()
    {
        byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
        byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        byte[] expected = new byte[data.Length];
        for (int i = 0; i < data.Length; i++)
            expected[i] = (byte)(data[i] ^ key[i & 3]);

        WsFrameWriter.MaskBuf(key, data);
        data.ShouldBe(expected);
    }

    [Fact]
    public void MaskBuf_RoundTrip()
    {
        byte[] key = [0x12, 0x34, 0x56, 0x78];
        byte[] original = "Hello, WebSocket!"u8.ToArray();
        var data = original.ToArray();

        WsFrameWriter.MaskBuf(key, data);
        data.ShouldNotBe(original);
        WsFrameWriter.MaskBuf(key, data);
        data.ShouldBe(original);
    }

    [Fact]
    public void CreateCloseMessage_WithStatusAndBody()
    {
        var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
        msg.Length.ShouldBe(2 + "normal closure".Length);
        BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
    }

    [Fact]
    public void CreateCloseMessage_LongBody_Truncated()
    {
        var longBody = new string('x', 200);
        var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
        msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
    }

    [Fact]
    public void MapCloseStatus_ClientClosed_NormalClosure()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
            .ShouldBe(WsConstants.CloseStatusNormalClosure);
    }

    [Fact]
    public void MapCloseStatus_AuthTimeout_PolicyViolation()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
            .ShouldBe(WsConstants.CloseStatusPolicyViolation);
    }

    [Fact]
    public void MapCloseStatus_ParseError_ProtocolError()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
            .ShouldBe(WsConstants.CloseStatusProtocolError);
    }

    [Fact]
    public void MapCloseStatus_MaxPayload_MessageTooBig()
    {
        WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
            .ShouldBe(WsConstants.CloseStatusMessageTooBig);
    }

    [Fact]
    public void BuildControlFrame_PingNomask()
    {
        var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
        frame.Length.ShouldBe(2);
        (frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
        (frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
        (frame[1] & 0x7F).ShouldBe(0);
    }

    [Fact]
    public void BuildControlFrame_PongWithPayload()
    {
        byte[] payload = [1, 2, 3, 4];
        var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
        frame.Length.ShouldBe(2 + 4);
        frame[2..].ShouldBe(payload);
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameWriterTests" -v normal
Expected: FAIL

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsFrameWriter.cs:

using System.Buffers.Binary;
using System.Security.Cryptography;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket frame construction, masking, and control message creation.
/// Ported from golang/nats-server/server/websocket.go lines 543-726.
/// </summary>
public static class WsFrameWriter
{
    /// <summary>
    /// Creates a complete frame header for a single-frame message (first=true, final=true).
    /// Returns (header bytes, mask key or null).
    /// </summary>
    public static (byte[] header, byte[]? key) CreateFrameHeader(
        bool useMasking, bool compressed, int opcode, int payloadLength)
    {
        var fh = new byte[WsConstants.MaxFrameHeaderSize];
        var (n, key) = FillFrameHeader(fh, useMasking,
            first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength);
        return (fh[..n], key);
    }

    /// <summary>
    /// Fills a pre-allocated frame header buffer.
    /// Returns (bytes written, mask key or null).
    /// </summary>
    public static (int written, byte[]? key) FillFrameHeader(
        Span<byte> fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength)
    {
        byte b0 = first ? (byte)opcode : (byte)0;
        if (final) b0 |= WsConstants.FinalBit;
        if (compressed) b0 |= WsConstants.Rsv1Bit;

        byte b1 = 0;
        if (useMasking) b1 |= WsConstants.MaskBit;

        int n;
        switch (payloadLength)
        {
            case <= 125:
                n = 2;
                fh[0] = b0;
                fh[1] = (byte)(b1 | (byte)payloadLength);
                break;
            case < 65536:
                n = 4;
                fh[0] = b0;
                fh[1] = (byte)(b1 | 126);
                BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength);
                break;
            default:
                n = 10;
                fh[0] = b0;
                fh[1] = (byte)(b1 | 127);
                BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength);
                break;
        }

        byte[]? key = null;
        if (useMasking)
        {
            key = new byte[4];
            RandomNumberGenerator.Fill(key);
            key.CopyTo(fh[n..]);
            n += 4;
        }

        return (n, key);
    }

    /// <summary>
    /// XOR masks a buffer with a 4-byte key. Applies in-place.
    /// </summary>
    public static void MaskBuf(ReadOnlySpan<byte> key, Span<byte> buf)
    {
        for (int i = 0; i < buf.Length; i++)
            buf[i] ^= key[i & 3];
    }

    /// <summary>
    /// XOR masks multiple contiguous buffers as if they were one.
    /// </summary>
    public static void MaskBufs(ReadOnlySpan<byte> key, List<byte[]> bufs)
    {
        int pos = 0;
        foreach (var buf in bufs)
        {
            for (int j = 0; j < buf.Length; j++)
            {
                buf[j] ^= key[pos & 3];
                pos++;
            }
        }
    }

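    /// <summary>
    /// Creates a close message payload: 2-byte status code + optional UTF-8 body.
    /// A body longer than the control-frame limit is truncated (measured in UTF-8
    /// bytes, on a character boundary) and given a "..." suffix.
    /// </summary>
    public static byte[] CreateCloseMessage(int status, string body)
    {
        // The 125-byte control payload limit applies to encoded bytes, not chars.
        var bodyBytes = Encoding.UTF8.GetBytes(body);
        const int maxBody = WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize;
        if (bodyBytes.Length > maxBody)
        {
            // Leave room for "..." and back up past UTF-8 continuation bytes.
            int cut = maxBody - 3;
            while (cut > 0 && (bodyBytes[cut] & 0xC0) == 0x80) cut--;
            var truncated = new byte[cut + 3];
            bodyBytes.AsSpan(0, cut).CopyTo(truncated);
            "..."u8.CopyTo(truncated.AsSpan(cut));
            bodyBytes = truncated;
        }

        var buf = new byte[WsConstants.CloseStatusSize + bodyBytes.Length];
        BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status);
        bodyBytes.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize));
        return buf;
    }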

    /// <summary>
    /// Builds a complete control frame (header + payload, optional masking).
    /// </summary>
    public static byte[] BuildControlFrame(int opcode, ReadOnlySpan<byte> payload, bool useMasking)
    {
        int headerSize = 2 + (useMasking ? 4 : 0);
        var frame = new byte[headerSize + payload.Length];
        var span = frame.AsSpan();
        var (n, key) = FillFrameHeader(span, useMasking,
            first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length);
        if (payload.Length > 0)
        {
            payload.CopyTo(span[n..]);
            if (useMasking && key != null)
                MaskBuf(key, span[n..]);
        }

        return frame;
    }

    /// <summary>
    /// Maps a ClientClosedReason to a WebSocket close status code.
    /// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694.
    /// </summary>
    public static int MapCloseStatus(ClientClosedReason reason) => reason switch
    {
        ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure,
        ClientClosedReason.AuthenticationTimeout or
        ClientClosedReason.AuthenticationViolation or
        ClientClosedReason.SlowConsumerPendingBytes or
        ClientClosedReason.SlowConsumerWriteDeadline or
        ClientClosedReason.MaxSubscriptionsExceeded or
        ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation,
        ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake,
        ClientClosedReason.ParseError or
        ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError,
        ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig,
        ClientClosedReason.WriteError or
        ClientClosedReason.ReadError or
        ClientClosedReason.StaleConnection or
        ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway,
        _ => WsConstants.CloseStatusInternalSrvError,
    };
}
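
To make the header layout and masking concrete, here is a small standalone sketch that reproduces the single-frame masked text example from RFC 6455 Section 5.7 ("Hello" with mask key 37 FA 21 3D). It deliberately does not call WsFrameWriter, so it can be run on its own; MaskedFrameSketch and BuildMaskedTextFrame are hypothetical names, and only the 7-bit length form (payloads up to 125 bytes) is handled.

```csharp
using System;
using System.Text;

// Hypothetical helper, for illustration only (not part of the plan's API).
public static class MaskedFrameSketch
{
    /// <summary>
    /// Builds a single masked text frame for a payload of at most 125 bytes.
    /// </summary>
    public static byte[] BuildMaskedTextFrame(string text, byte[] key)
    {
        byte[] payload = Encoding.UTF8.GetBytes(text);
        if (payload.Length > 125)
            throw new ArgumentException("sketch only covers the 7-bit length form", nameof(text));
        if (key.Length != 4)
            throw new ArgumentException("mask key must be 4 bytes", nameof(key));

        var frame = new byte[2 + 4 + payload.Length];
        frame[0] = 0x80 | 0x01;                    // FIN bit + text opcode
        frame[1] = (byte)(0x80 | payload.Length);  // MASK bit + 7-bit payload length
        key.CopyTo(frame, 2);                      // 4-byte masking key follows the length

        // XOR each payload byte with key[i mod 4], as MaskBuf does.
        for (int i = 0; i < payload.Length; i++)
            frame[6 + i] = (byte)(payload[i] ^ key[i & 3]);
        return frame;
    }
}
```

Called with `"Hello"` and the key `{ 0x37, 0xFA, 0x21, 0x3D }`, this produces the bytes `81 85 37 FA 21 3D 7F 9F 4D 51 58`, matching the RFC's example.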

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameWriterTests" -v normal
Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsFrameWriter.cs tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs
git commit -m "feat: add WebSocket frame writer with masking and close status mapping"

Task 4: Add WsReadInfo (frame reader state machine)

Files:

  • Create: src/NATS.Server/WebSocket/WsReadInfo.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs

Reference: golang/nats-server/server/websocket.go lines 156-440 (wsReadInfo, wsRead, unmask, decompress)
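
For orientation, the first thing a frame reader has to do is decode the payload length, which RFC 6455 encodes in one of three forms: a 7-bit value, 126 followed by a 16-bit big-endian length, or 127 followed by a 64-bit big-endian length. The sketch below shows only that decoding step under simplifying assumptions (the whole header is already in one buffer, no mask key handling); WsLengthSketch and DecodeLength are hypothetical names and are not the WsReadInfo API defined in this task.

```csharp
using System;
using System.Buffers.Binary;

// Hypothetical helper, for illustration only (not part of the plan's API).
public static class WsLengthSketch
{
    /// <summary>
    /// Returns the header size (excluding any mask key) and the payload length.
    /// </summary>
    public static (int headerSize, ulong payloadLength) DecodeLength(ReadOnlySpan<byte> frame)
    {
        if (frame.Length < 2)
            throw new ArgumentException("need at least 2 header bytes");

        int len7 = frame[1] & 0x7F; // strip the MASK bit
        switch (len7)
        {
            case 126: // 16-bit extended length
                if (frame.Length < 4) throw new ArgumentException("truncated 16-bit length");
                return (4, (ulong)BinaryPrimitives.ReadUInt16BigEndian(frame.Slice(2, 2)));
            case 127: // 64-bit extended length
                if (frame.Length < 10) throw new ArgumentException("truncated 64-bit length");
                return (10, BinaryPrimitives.ReadUInt64BigEndian(frame.Slice(2, 8)));
            default:  // length fits in 7 bits
                return (2, (ulong)len7);
        }
    }
}
```

A real reader also has to cope with headers split across reads, which is why the plan keeps this state in a per-connection struct rather than a pure function.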

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs:

using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsFrameReadTests
{
    /// <summary>Helper: build a single frame (unmasked binary by default).</summary>
    private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
    {
        int payloadLen = payload.Length;
        byte b0 = (byte)opcode;
        if (fin) b0 |= WsConstants.FinalBit;
        if (compressed) b0 |= WsConstants.Rsv1Bit;
        byte b1 = 0;
        if (mask) b1 |= WsConstants.MaskBit;

        byte[] lenBytes;
        if (payloadLen <= 125)
        {
            lenBytes = [(byte)(b1 | (byte)payloadLen)];
        }
        else if (payloadLen < 65536)
        {
            lenBytes = new byte[3];
            lenBytes[0] = (byte)(b1 | 126);
            BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
        }
        else
        {
            lenBytes = new byte[9];
            lenBytes[0] = (byte)(b1 | 127);
            BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
        }

        int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
        var frame = new byte[totalLen];
        frame[0] = b0;
        lenBytes.CopyTo(frame.AsSpan(1));
        int pos = 1 + lenBytes.Length;
        if (mask && maskKey != null)
        {
            maskKey.CopyTo(frame.AsSpan(pos));
            pos += 4;
            var maskedPayload = payload.ToArray();
            WsFrameWriter.MaskBuf(maskKey, maskedPayload);
            maskedPayload.CopyTo(frame.AsSpan(pos));
        }
        else
        {
            payload.CopyTo(frame.AsSpan(pos));
        }
        return frame;
    }

    [Fact]
    public void ReadSingleUnmaskedFrame()
    {
        var payload = "Hello"u8.ToArray();
        var frame = BuildFrame(payload);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    [Fact]
    public void ReadMaskedFrame()
    {
        var payload = "Hello"u8.ToArray();
        byte[] key = [0x37, 0xFA, 0x21, 0x3D];
        var frame = BuildFrame(payload, mask: true, maskKey: key);

        var readInfo = new WsReadInfo(expectMask: true);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    [Fact]
    public void Read16BitLengthFrame()
    {
        var payload = new byte[200];
        Random.Shared.NextBytes(payload);
        var frame = BuildFrame(payload);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(1);
        result[0].ShouldBe(payload);
    }

    [Fact]
    public void ReadPingFrame_ReturnsPongAction()
    {
        var frame = BuildFrame([], opcode: WsConstants.PingMessage);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0); // control frames don't produce payload
        readInfo.PendingControlFrames.Count.ShouldBe(1);
        readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
    }

    [Fact]
    public void ReadCloseFrame_ReturnsCloseAction()
    {
        var closePayload = new byte[2];
        BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
        var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0);
        readInfo.CloseReceived.ShouldBeTrue();
        readInfo.CloseStatus.ShouldBe(1000);
    }

    [Fact]
    public void ReadPongFrame_NoAction()
    {
        var frame = BuildFrame([], opcode: WsConstants.PongMessage);

        var readInfo = new WsReadInfo(expectMask: false);
        var stream = new MemoryStream(frame);
        var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);

        result.Count.ShouldBe(0);
        readInfo.PendingControlFrames.Count.ShouldBe(0);
    }

    [Fact]
    public void Unmask_Optimized_8ByteChunks()
    {
        byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
        var original = new byte[32];
        Random.Shared.NextBytes(original);
        var masked = original.ToArray();

        // Mask it
        for (int i = 0; i < masked.Length; i++)
            masked[i] ^= key[i & 3];

        // Unmask using the state machine
        var info = new WsReadInfo(expectMask: true);
        info.SetMaskKey(key);
        info.Unmask(masked);

        masked.ShouldBe(original);
    }
}
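The Read16BitLengthFrame test relies on the RFC 6455 rule that payloads longer than 125 bytes switch to an extended length field. The following is a minimal sketch of that rule, not the plan's BuildFrame helper: `EncodeHeader` is a hypothetical name introduced here, and it only covers unmasked, final, binary frames.

```csharp
using System;
using System.Buffers.Binary;

// Hypothetical helper (not part of the plan): builds the header of an
// unmasked, final, binary frame for the given payload length.
static byte[] EncodeHeader(long payloadLength)
{
    // 0x82 = FIN bit (0x80) | binary opcode (0x02)
    if (payloadLength <= 125)
        return [0x82, (byte)payloadLength];          // 7-bit length

    if (payloadLength <= ushort.MaxValue)
    {
        var h = new byte[4];
        h[0] = 0x82;
        h[1] = 126;                                   // marker: 16-bit length follows
        BinaryPrimitives.WriteUInt16BigEndian(h.AsSpan(2), (ushort)payloadLength);
        return h;
    }

    var h8 = new byte[10];
    h8[0] = 0x82;
    h8[1] = 127;                                      // marker: 64-bit length follows
    BinaryPrimitives.WriteUInt64BigEndian(h8.AsSpan(2), (ulong)payloadLength);
    return h8;
}

Console.WriteLine(Convert.ToHexString(EncodeHeader(5)));      // 8205
Console.WriteLine(Convert.ToHexString(EncodeHeader(200)));    // 827E00C8
Console.WriteLine(Convert.ToHexString(EncodeHeader(70000)));  // 827F0000000000011170
```

A 200-byte payload therefore produces a 4-byte header, which is the path the `case 126` branch of the reader has to decode.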

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameReadTests" -v normal

Expected: FAIL

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsReadInfo.cs:

using System.Buffers.Binary;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// Per-connection WebSocket frame reading state machine.
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
/// </summary>
public struct WsReadInfo
{
    public int Remaining;
    public bool FrameStart;
    public bool FirstFrame;
    public bool FrameCompressed;
    public bool ExpectMask;
    public byte MaskKeyPos;
    public byte[] MaskKey;
    public List<byte[]>? CompressedBuffers;
    public int CompressedOffset;

    // Control frame outputs
    public List<ControlFrameAction> PendingControlFrames;
    public bool CloseReceived;
    public int CloseStatus;
    public string? CloseBody;

    public WsReadInfo(bool expectMask)
    {
        Remaining = 0;
        FrameStart = true;
        FirstFrame = true;
        FrameCompressed = false;
        ExpectMask = expectMask;
        MaskKeyPos = 0;
        MaskKey = new byte[4];
        CompressedBuffers = null;
        CompressedOffset = 0;
        PendingControlFrames = [];
        CloseReceived = false;
        CloseStatus = 0;
        CloseBody = null;
    }

    public void SetMaskKey(ReadOnlySpan<byte> key)
    {
        key[..4].CopyTo(MaskKey);
        MaskKeyPos = 0;
    }

    /// <summary>
    /// Unmask buffer in-place using current mask key and position.
    /// Optimized for 8-byte chunks when buffer is large enough.
    /// Ported from websocket.go lines 509-536.
    /// </summary>
    public void Unmask(Span<byte> buf)
    {
        int p = MaskKeyPos;
        if (buf.Length < 16)
        {
            for (int i = 0; i < buf.Length; i++)
            {
                buf[i] ^= MaskKey[p & 3];
                p++;
            }
            MaskKeyPos = (byte)(p & 3);
            return;
        }

        // Build 8-byte key for bulk XOR
        Span<byte> k = stackalloc byte[8];
        for (int i = 0; i < 8; i++)
            k[i] = MaskKey[(p + i) & 3];
        ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);

        int n = (buf.Length / 8) * 8;
        for (int i = 0; i < n; i += 8)
        {
            ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
            tmp ^= km;
            BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
        }

        // Handle remaining bytes
        var tail = buf[n..];
        for (int i = 0; i < tail.Length; i++)
        {
            tail[i] ^= MaskKey[p & 3];
            p++;
        }
        MaskKeyPos = (byte)(p & 3);
    }

    /// <summary>
    /// Read and decode WebSocket frames from a buffer.
    /// Returns list of decoded payload byte arrays.
    /// Ported from websocket.go lines 208-351.
    /// </summary>
    public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload)
    {
        var bufs = new List<byte[]>();
        var buf = new byte[available];
        int bytesRead = 0;

        // Fill the buffer from the stream
        while (bytesRead < available)
        {
            int n = stream.Read(buf, bytesRead, available - bytesRead);
            if (n == 0) break;
            bytesRead += n;
        }

        int pos = 0;
        int max = bytesRead;

        while (pos < max)
        {
            if (r.FrameStart)
            {
                if (pos >= max) break;
                byte b0 = buf[pos];
                int frameType = b0 & 0x0F;
                bool final = (b0 & WsConstants.FinalBit) != 0;
                bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
                pos++;

                // Read second byte
                var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
                pos = newPos;
                byte b1 = b1Buf[0];

                // Check mask bit
                if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
                    throw new InvalidOperationException("mask bit missing");

                r.Remaining = b1 & 0x7F;

                // Validate frame types
                if (WsConstants.IsControlFrame(frameType))
                {
                    if (r.Remaining > WsConstants.MaxControlPayloadSize)
                        throw new InvalidOperationException("control frame length too large");
                    if (!final)
                        throw new InvalidOperationException("control frame does not have final bit set");
                }
                else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
                {
                    if (!r.FirstFrame)
                        throw new InvalidOperationException("new message before previous finished");
                    r.FirstFrame = final;
                    r.FrameCompressed = compressed;
                }
                else if (frameType == WsConstants.ContinuationFrame)
                {
                    if (r.FirstFrame || compressed)
                        throw new InvalidOperationException("invalid continuation frame");
                    r.FirstFrame = final;
                }
                else
                {
                    throw new InvalidOperationException($"unknown opcode {frameType}");
                }

                // Extended payload length
                switch (r.Remaining)
                {
                    case 126:
                    {
                        var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
                        pos = p2;
                        r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
                        break;
                    }
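                    case 127:
                    {
                        var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
                        pos = p2;
                        // Reject lengths that do not fit in an int instead of letting the cast wrap.
                        ulong len64 = BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
                        if (len64 > int.MaxValue)
                            throw new InvalidOperationException("frame payload length too large");
                        r.Remaining = (int)len64;
                        break;
                    }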
                }

                // Read mask key
                if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0)
                {
                    var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
                    pos = p2;
                    keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
                    r.MaskKeyPos = 0;
                }

                // Handle control frames
                if (WsConstants.IsControlFrame(frameType))
                {
                    pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max);
                    continue;
                }

                r.FrameStart = false;
            }

            if (pos < max)
            {
                int n = r.Remaining;
                if (pos + n > max) n = max - pos;

                var payloadSlice = buf.AsSpan(pos, n).ToArray();
                pos += n;
                r.Remaining -= n;

                if (r.ExpectMask)
                    r.Unmask(payloadSlice);

                bool addToBufs = true;
                if (r.FrameCompressed)
                {
                    addToBufs = false;
                    r.CompressedBuffers ??= [];
                    r.CompressedBuffers.Add(payloadSlice);

                    if (r.FirstFrame && r.Remaining == 0)
                    {
                        var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
                        r.CompressedBuffers = null;
                        r.FrameCompressed = false;
                        addToBufs = true;
                        payloadSlice = decompressed;
                    }
                }

                if (addToBufs && payloadSlice.Length > 0)
                    bufs.Add(payloadSlice);

                if (r.Remaining == 0)
                    r.FrameStart = true;
            }
        }

        return bufs;
    }

    private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
    {
        byte[]? payload = null;
        if (r.Remaining > 0)
        {
            var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
            pos = newPos;
            payload = payloadBuf;
            if (r.ExpectMask)
                r.Unmask(payload);
            r.Remaining = 0;
        }

        switch (frameType)
        {
            case WsConstants.CloseMessage:
                r.CloseReceived = true;
                r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
                if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
                {
                    r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
                    if (payload.Length > WsConstants.CloseStatusSize)
                        r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
                }
                if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
                {
                    var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
                    r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
                }
                break;

            case WsConstants.PingMessage:
                r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
                break;

            case WsConstants.PongMessage:
                // Nothing to do
                break;
        }

        return pos;
    }

    /// <summary>
    /// Gets needed bytes from buffer or reads from stream.
    /// Ported from websocket.go lines 178-193.
    /// </summary>
    private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
    {
        int avail = max - pos;
        if (avail >= needed)
            return (buf[pos..(pos + needed)], pos + needed);

        var b = new byte[needed];
        int start = 0;
        if (avail > 0)
        {
            Buffer.BlockCopy(buf, pos, b, 0, avail);
            start = avail;
        }
        while (start < needed)
        {
            int n = stream.Read(b, start, needed - start);
            if (n == 0) throw new IOException("unexpected end of stream");
            start += n;
        }
        return (b, pos + avail);
    }
}

public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);
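Because a frame's payload can arrive across several reads, the reader keeps MaskKeyPos between calls so that unmasking resumes at the correct key byte. The sketch below illustrates that idea in isolation; `UnmaskChunk` is a hypothetical stand-in for WsReadInfo.Unmask (without the 8-byte fast path), and the masked bytes are the "Hello" example from RFC 6455 Section 5.7.

```csharp
using System;
using System.Text;

// Hypothetical stand-in for WsReadInfo.Unmask: XORs buf with the 4-byte key,
// starting at keyPos, and returns the key position for the next chunk.
static int UnmaskChunk(Span<byte> buf, ReadOnlySpan<byte> key, int keyPos)
{
    for (int i = 0; i < buf.Length; i++)
        buf[i] ^= key[(keyPos + i) & 3];
    return (keyPos + buf.Length) & 3;
}

byte[] key = [0x37, 0xFA, 0x21, 0x3D];
byte[] masked = [0x7F, 0x9F, 0x4D, 0x51, 0x58]; // "Hello" XORed with key

// Deliver the payload as a 3-byte chunk followed by a 2-byte chunk,
// carrying the key position from the first call into the second.
int pos = UnmaskChunk(masked.AsSpan(0, 3), key, 0);
pos = UnmaskChunk(masked.AsSpan(3), key, pos);

Console.WriteLine(Encoding.ASCII.GetString(masked)); // Hello
```

Restarting the second chunk at key position 0 instead would decode the last two bytes incorrectly, which is why the position is stored in the per-connection state rather than in a local variable.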

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsFrameReadTests" -v normal

Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsReadInfo.cs tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
git commit -m "feat: add WebSocket frame reader state machine"

Task 5: Add WsCompression (permessage-deflate)

Files:

  • Create: src/NATS.Server/WebSocket/WsCompression.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs

Reference: golang/nats-server/server/websocket.go lines 403-440 (decompress), lines 1391-1466 (compress)

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs:

using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsCompressionTests
{
    [Fact]
    public void CompressDecompress_RoundTrip()
    {
        var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
        var compressed = WsCompression.Compress(original);
        compressed.ShouldNotBeNull();
        compressed.Length.ShouldBeGreaterThan(0);

        var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
        decompressed.ShouldBe(original);
    }

    [Fact]
    public void Decompress_ExceedsMaxPayload_Throws()
    {
        var original = new byte[1000];
        Random.Shared.NextBytes(original);
        var compressed = WsCompression.Compress(original);

        Should.Throw<InvalidOperationException>(() =>
            WsCompression.Decompress([compressed], maxPayload: 100));
    }

    [Fact]
    public void Compress_RemovesTrailing4Bytes()
    {
        var data = new byte[200];
        Random.Shared.NextBytes(data);
        var compressed = WsCompression.Compress(data);

        // The compressed data should be valid for decompression when we add the trailer back
        var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
        decompressed.ShouldBe(data);
    }

    [Fact]
    public void Decompress_MultipleBuffers()
    {
        var original = new byte[500];
        Random.Shared.NextBytes(original);
        var compressed = WsCompression.Compress(original);

        // Split compressed data into multiple chunks
        int mid = compressed.Length / 2;
        var chunk1 = compressed[..mid];
        var chunk2 = compressed[mid..];

        var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
        decompressed.ShouldBe(original);
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsCompressionTests" -v normal

Expected: FAIL

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsCompression.cs:

using System.IO.Compression;

namespace NATS.Server.WebSocket;

/// <summary>
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
/// </summary>
public static class WsCompression
{
    /// <summary>
    /// Compresses data using deflate. Removes trailing 4 bytes (sync marker)
    /// per RFC 7692 Section 7.2.1.
    /// </summary>
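    public static byte[] Compress(ReadOnlySpan<byte> data)
    {
        using var output = new MemoryStream();
        int flushedLength;
        using (var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true))
        {
            deflate.Write(data);
            deflate.Flush();
            // Record the length here. This assumes Flush() performs a sync flush
            // (as on current .NET), so the output ends with the sync marker at this
            // point. Disposing the DeflateStream appends a final block after it.
            flushedLength = (int)output.Length;
        }

        // Ignore the final block written on dispose.
        var compressed = output.ToArray().AsSpan(0, flushedLength);

        // Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692
        if (compressed.Length >= 4)
            return compressed[..^4].ToArray();

        return compressed.ToArray();
    }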

    /// <summary>
    /// Decompresses collected compressed buffers.
    /// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2.
    /// </summary>
    public static byte[] Decompress(List<byte[]> compressedBuffers, int maxPayload)
    {
        if (maxPayload <= 0)
            maxPayload = 1024 * 1024; // Default 1MB

        // Concatenate all compressed buffers + trailer
        int totalLen = 0;
        foreach (var buf in compressedBuffers)
            totalLen += buf.Length;
        totalLen += WsConstants.DecompressTrailer.Length;

        var combined = new byte[totalLen];
        int offset = 0;
        foreach (var buf in compressedBuffers)
        {
            buf.CopyTo(combined, offset);
            offset += buf.Length;
        }
        WsConstants.DecompressTrailer.CopyTo(combined, offset);

        using var input = new MemoryStream(combined);
        using var deflate = new DeflateStream(input, CompressionMode.Decompress);
        using var output = new MemoryStream();

        var readBuf = new byte[4096];
        int totalRead = 0;
        int n;
        while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0)
        {
            totalRead += n;
            if (totalRead > maxPayload)
                throw new InvalidOperationException("decompressed data exceeds maximum payload size");
            output.Write(readBuf, 0, n);
        }

        return output.ToArray();
    }
}
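The trailer handling can be checked in isolation. The sketch below makes two assumptions: that DeflateStream.Flush() performs a sync flush (ending the output with 0x00 0x00 0xff 0xff), and that the decompress trailer matches the 9-byte `compressLastBlock` value used by the Go server. `DeflateSyncFlush` and `InflateWithTrailer` are hypothetical helper names, not part of the plan.

```csharp
using System;
using System.IO;
using System.IO.Compression;

// Hypothetical helper: compresses data and returns the bytes written up to
// the flush, before the stream is disposed.
static byte[] DeflateSyncFlush(byte[] data)
{
    using var output = new MemoryStream();
    using var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true);
    deflate.Write(data, 0, data.Length);
    deflate.Flush();            // assumed to be a sync flush
    return output.ToArray();    // snapshot taken before Dispose runs
}

// Hypothetical helper: re-appends the sync marker plus an empty final
// stored block, then inflates.
static byte[] InflateWithTrailer(byte[] stripped)
{
    byte[] trailer = [0x00, 0x00, 0xFF, 0xFF, 0x01, 0x00, 0x00, 0xFF, 0xFF];
    var input = new byte[stripped.Length + trailer.Length];
    stripped.CopyTo(input, 0);
    trailer.CopyTo(input, stripped.Length);

    using var inflate = new DeflateStream(new MemoryStream(input), CompressionMode.Decompress);
    using var result = new MemoryStream();
    inflate.CopyTo(result);
    return result.ToArray();
}

var original = "permessage-deflate round trip"u8.ToArray();
var flushed = DeflateSyncFlush(original);

byte[] marker = [0x00, 0x00, 0xFF, 0xFF];
bool endsWithMarker = flushed.Length >= 4 &&
                      flushed.AsSpan(flushed.Length - 4).SequenceEqual(marker);

var restored = InflateWithTrailer(flushed[..^4]);
bool roundTripOk = restored.AsSpan().SequenceEqual(original);

Console.WriteLine($"marker present: {endsWithMarker}, round trip ok: {roundTripOk}");
```

If the marker check fails on a given runtime, stripping four bytes blindly would corrupt the stream, so a production version may want to verify the marker before removing it.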

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsCompressionTests" -v normal

Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsCompression.cs tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs
git commit -m "feat: add WebSocket permessage-deflate compression"

Task 6: Add WsUpgrade (HTTP upgrade handshake)

Files:

  • Create: src/NATS.Server/WebSocket/WsUpgrade.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs

Reference: golang/nats-server/server/websocket.go lines 731-917 (wsUpgrade, wsHeaderContains, wsAcceptKey, wsPMCExtensionSupport)

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs:

using System.Text;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsUpgradeTests
{
    private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
    {
        // Use explicit CRLF line endings: AppendLine emits "\n" on non-Windows
        // platforms, which is not a valid HTTP line terminator.
        var sb = new StringBuilder();
        sb.Append($"GET {path} HTTP/1.1\r\n");
        sb.Append("Host: localhost:4222\r\n");
        sb.Append("Upgrade: websocket\r\n");
        sb.Append("Connection: Upgrade\r\n");
        sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
        sb.Append("Sec-WebSocket-Version: 13\r\n");
        if (extraHeaders != null)
            sb.Append(extraHeaders);
        sb.Append("\r\n");
        return sb.ToString();
    }

    [Fact]
    public async Task ValidUpgrade_Returns101()
    {
        var request = BuildValidRequest();
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Client);
        var response = ReadResponse(outputStream);
        response.ShouldContain("HTTP/1.1 101");
        response.ShouldContain("Upgrade: websocket");
        response.ShouldContain("Sec-WebSocket-Accept:");
    }

    [Fact]
    public async Task MissingUpgradeHeader_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task MissingHost_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
    }

    [Fact]
    public async Task WrongVersion_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
    }

    [Fact]
    public async Task LeafNodePath_ReturnsLeafKind()
    {
        var request = BuildValidRequest("/leafnode");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Leaf);
    }

    [Fact]
    public async Task MqttPath_ReturnsMqttKind()
    {
        var request = BuildValidRequest("/mqtt");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Mqtt);
    }

    [Fact]
    public async Task CompressionNegotiation_WhenEnabled()
    {
        var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });

        result.Success.ShouldBeTrue();
        result.Compress.ShouldBeTrue();
        ReadResponse(outputStream).ShouldContain("permessage-deflate");
    }

    [Fact]
    public async Task CompressionNegotiation_WhenDisabled()
    {
        var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });

        result.Success.ShouldBeTrue();
        result.Compress.ShouldBeFalse();
    }

    [Fact]
    public async Task NoMaskingHeader_ForLeaf()
    {
        var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.MaskRead.ShouldBeFalse();
    }

    [Fact]
    public async Task BrowserDetection_Mozilla()
    {
        var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Browser.ShouldBeTrue();
    }

    [Fact]
    public async Task SafariDetection_NoCompFrag()
    {
        var request = BuildValidRequest(extraHeaders:
            "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
            $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });

        result.Success.ShouldBeTrue();
        result.NoCompFrag.ShouldBeTrue();
    }

    [Fact]
    public void AcceptKey_MatchesRfc6455Example()
    {
        // RFC 6455 Section 4.2.2 example
        var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
        key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    [Fact]
    public async Task CookieExtraction()
    {
        var request = BuildValidRequest(extraHeaders:
            "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var opts = new WebSocketOptions
        {
            NoTls = true,
            JwtCookie = "jwt_token",
            UsernameCookie = "nats_user",
            PasswordCookie = "nats_pass",
        };
        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);

        result.Success.ShouldBeTrue();
        result.CookieJwt.ShouldBe("my-jwt");
        result.CookieUsername.ShouldBe("admin");
        result.CookiePassword.ShouldBe("secret");
    }

    [Fact]
    public async Task XForwardedFor_ExtractsClientIp()
    {
        var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.ClientIp.ShouldBe("192.168.1.100");
    }

    [Fact]
    public async Task PostMethod_Returns405()
    {
        var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("405");
    }

    // Helper: create a readable input stream and writable output stream
    private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
    {
        var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
        return (new MemoryStream(inputBytes), new MemoryStream());
    }

    private static string ReadResponse(MemoryStream output)
    {
        output.Position = 0;
        return Encoding.ASCII.GetString(output.ToArray());
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsUpgradeTests" -v normal

Expected: FAIL

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsUpgrade.cs:

using System.Net;
using System.Security.Cryptography;
using System.Text;

namespace NATS.Server.WebSocket;

/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
    /// <summary>
    /// Attempts to read an HTTP upgrade request from the input stream,
    /// validate per RFC 6455, and write the 101 response to the output stream.
    /// </summary>
    public static async Task<WsUpgradeResult> TryUpgradeAsync(
        Stream inputStream, Stream outputStream, WebSocketOptions options)
    {
        try
        {
            // Read HTTP request
            var (method, path, headers) = await ReadHttpRequestAsync(inputStream);

            // RFC 6455 Section 4.2.1 validation
            // Point 1: Method must be GET
            if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
                return await FailAsync(outputStream, 405, "request method must be GET");

            // Point 2: Host header required
            if (!headers.ContainsKey("Host"))
                return await FailAsync(outputStream, 400, "'Host' missing in request");

            // Point 3: Upgrade header must contain "websocket"
            if (!HeaderContains(headers, "Upgrade", "websocket"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");

            // Point 4: Connection header must contain "Upgrade"
            if (!HeaderContains(headers, "Connection", "Upgrade"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");

            // Point 5: Sec-WebSocket-Key required
            if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
                return await FailAsync(outputStream, 400, "key missing");

            // Point 6: Version must be 13
            if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
                return await FailAsync(outputStream, 400, "invalid version");

            // Path routing
            var kind = path switch
            {
                _ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
                _ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
                _ => WsClientKind.Client,
            };

            // Origin checking
            if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
            {
                var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
                headers.TryGetValue("Origin", out var origin);
                if (string.IsNullOrEmpty(origin))
                    headers.TryGetValue("Sec-WebSocket-Origin", out origin);
                var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
                if (originErr != null)
                    return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
            }

            // Compression negotiation
            bool compress = options.Compression;
            if (compress)
            {
                compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
                           ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
            }

            // No-masking negotiation
            bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
                             string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);

            // Browser detection
            bool browser = false;
            bool noCompFrag = false;
            if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
                headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/", StringComparison.Ordinal))
            {
                browser = true;
                noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
            }

            // Cookie extraction
            string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
            if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
                headers.TryGetValue("Cookie", out var cookieHeader))
            {
                var cookies = ParseCookies(cookieHeader);
                if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
                if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
                if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
                if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
            }

            // X-Forwarded-For
            string? clientIp = null;
            if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
            {
                var ip = xff.Split(',')[0].Trim();
                if (IPAddress.TryParse(ip, out _))
                    clientIp = ip;
            }

            // Build 101 response
            var response = new StringBuilder();
            response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
            response.Append(ComputeAcceptKey(key));
            response.Append("\r\n");
            if (compress)
                response.Append(WsConstants.PmcFullResponse);
            if (noMasking)
                response.Append(WsConstants.NoMaskingFullResponse);
            if (options.Headers != null)
            {
                foreach (var (k, v) in options.Headers)
                {
                    response.Append(k);
                    response.Append(": ");
                    response.Append(v);
                    response.Append("\r\n");
                }
            }
            response.Append("\r\n");

            var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
            await outputStream.WriteAsync(responseBytes);
            await outputStream.FlushAsync();

            return new WsUpgradeResult(
                Success: true,
                Compress: compress,
                Browser: browser,
                NoCompFrag: noCompFrag,
                MaskRead: !noMasking,
                MaskWrite: false,
                CookieJwt: cookieJwt,
                CookieUsername: cookieUsername,
                CookiePassword: cookiePassword,
                CookieToken: cookieToken,
                ClientIp: clientIp,
                Kind: kind);
        }
        catch (Exception)
        {
            return WsUpgradeResult.Failed;
        }
    }

    /// <summary>
    /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
    /// </summary>
    public static string ComputeAcceptKey(string clientKey)
    {
        var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
        var hash = SHA1.HashData(combined);
        return Convert.ToBase64String(hash);
    }

    private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
    {
        var statusText = statusCode switch
        {
            400 => "Bad Request",
            403 => "Forbidden",
            405 => "Method Not Allowed",
            _ => "Internal Server Error",
        };
        var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
        await output.WriteAsync(Encoding.ASCII.GetBytes(response));
        await output.FlushAsync();
        return WsUpgradeResult.Failed;
    }

    private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
    {
        var headerBytes = new List<byte>(4096);
        var buf = new byte[1];
        // Read until \r\n\r\n
        while (true)
        {
            int n = await stream.ReadAsync(buf);
            if (n == 0) throw new IOException("connection closed during handshake");
            headerBytes.Add(buf[0]);
            if (headerBytes.Count >= 4 &&
                headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
                headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
                break;
            if (headerBytes.Count > 8192)
                throw new InvalidOperationException("HTTP header too large");
        }

        var text = Encoding.ASCII.GetString(headerBytes.ToArray());
        var lines = text.Split("\r\n", StringSplitOptions.None);
        if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");

        // Parse request line
        var parts = lines[0].Split(' ');
        if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
        var method = parts[0];
        var path = parts[1];

        // Parse headers
        var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        for (int i = 1; i < lines.Length; i++)
        {
            var line = lines[i];
            if (string.IsNullOrEmpty(line)) break;
            var colonIdx = line.IndexOf(':');
            if (colonIdx > 0)
            {
                var name = line[..colonIdx].Trim();
                var value = line[(colonIdx + 1)..].Trim();
                headers[name] = value;
            }
        }

        return (method, path, headers);
    }

    private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
    {
        if (!headers.TryGetValue(name, out var headerValue))
            return false;
        foreach (var token in headerValue.Split(','))
        {
            if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }

    private static Dictionary<string, string> ParseCookies(string cookieHeader)
    {
        var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
        foreach (var pair in cookieHeader.Split(';'))
        {
            var trimmed = pair.Trim();
            var eqIdx = trimmed.IndexOf('=');
            if (eqIdx > 0)
                cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
        }
        return cookies;
    }
}

public readonly record struct WsUpgradeResult(
    bool Success,
    bool Compress,
    bool Browser,
    bool NoCompFrag,
    bool MaskRead,
    bool MaskWrite,
    string? CookieJwt,
    string? CookieUsername,
    string? CookiePassword,
    string? CookieToken,
    string? ClientIp,
    WsClientKind Kind)
{
    public static readonly WsUpgradeResult Failed = new(
        Success: false, Compress: false, Browser: false, NoCompFrag: false,
        MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
        CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
}
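
The accept-key derivation in ComputeAcceptKey can be checked in isolation against the sample handshake from RFC 6455 Section 1.3. The snippet below is a minimal standalone sketch, not part of the plan's source files; `AcceptKeyFor` is a hypothetical name used only for illustration.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical standalone copy of the accept-key derivation, for illustration only.
static string AcceptKeyFor(string clientKey)
{
    // RFC 6455 fixes this GUID; the server appends it to the client's key before hashing.
    const string wsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    var hash = SHA1.HashData(Encoding.ASCII.GetBytes(clientKey + wsGuid));
    return Convert.ToBase64String(hash);
}

// Sample nonce from RFC 6455 Section 1.3.
Console.WriteLine(AcceptKeyFor("dGhlIHNhbXBsZSBub25jZQ=="));
// prints: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```

A test in WsUpgradeTests could assert the same pair of values against WsUpgrade.ComputeAcceptKey.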

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsUpgradeTests" -v normal
Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsUpgrade.cs tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs
git commit -m "feat: add WebSocket HTTP upgrade handshake"

Task 7: Add WsConnection Stream wrapper

Files:

  • Create: src/NATS.Server/WebSocket/WsConnection.cs
  • Create: tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs:

using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsConnectionTests
{
    [Fact]
    public async Task ReadAsync_DecodesFrameAndReturnsPayload()
    {
        var payload = "SUB test 1\r\n"u8.ToArray();
        var frame = BuildUnmaskedFrame(payload);
        var inner = new MemoryStream(frame);
        var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var buf = new byte[256];
        int n = await ws.ReadAsync(buf);

        n.ShouldBe(payload.Length);
        buf[..n].ShouldBe(payload);
    }

    [Fact]
    public async Task WriteAsync_FramesPayload()
    {
        var inner = new MemoryStream();
        var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
        await ws.WriteAsync(payload);
        await ws.FlushAsync();

        inner.Position = 0;
        var written = inner.ToArray();
        // First 2 bytes should be WS frame header
        (written[0] & WsConstants.FinalBit).ShouldNotBe(0);
        (written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
        int len = written[1] & 0x7F;
        len.ShouldBe(payload.Length);
        written[2..].ShouldBe(payload);
    }

    [Fact]
    public async Task WriteAsync_WithCompression_CompressesLargePayload()
    {
        var inner = new MemoryStream();
        var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var payload = new byte[200];
        Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
        await ws.WriteAsync(payload);
        await ws.FlushAsync();

        inner.Position = 0;
        var written = inner.ToArray();
        // RSV1 bit should be set for compressed frame
        (written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
        // Frame header plus compressed payload should be smaller than the original payload
        written.Length.ShouldBeLessThan(payload.Length);
    }

    [Fact]
    public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
    {
        var inner = new MemoryStream();
        var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);

        var payload = "Hi"u8.ToArray(); // Below CompressThreshold
        await ws.WriteAsync(payload);
        await ws.FlushAsync();

        inner.Position = 0;
        var written = inner.ToArray();
        // RSV1 bit should NOT be set for small payloads
        (written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
    }

    private static byte[] BuildUnmaskedFrame(byte[] payload)
    {
        var header = new byte[2];
        header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
        header[1] = (byte)payload.Length;
        var frame = new byte[2 + payload.Length];
        header.CopyTo(frame, 0);
        payload.CopyTo(frame, 2);
        return frame;
    }
}
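
The BuildUnmaskedFrame helper above only covers payloads shorter than 126 bytes, where the length fits in the second header byte. As a rough sketch of the other two length encodings defined by RFC 6455 (a 16-bit length after marker 126, a 64-bit length after marker 127), the following standalone snippet builds unmasked binary frame headers. `BuildHeader` is a hypothetical name, and the literal bit values are assumed to match WsConstants.FinalBit and WsConstants.BinaryMessage.

```csharp
using System;
using System.Buffers.Binary;

// Hypothetical helper: header for an unmasked, final, binary frame of any payload length.
static byte[] BuildHeader(long payloadLength)
{
    const byte finalBit = 0x80;      // FIN bit (assumed equal to WsConstants.FinalBit)
    const byte binaryOpcode = 0x02;  // binary opcode (assumed equal to WsConstants.BinaryMessage)
    byte first = (byte)(finalBit | binaryOpcode);

    // 0..125: the length is stored directly in the second byte.
    if (payloadLength < 126)
        return new byte[] { first, (byte)payloadLength };

    // 126..65535: marker 126, then a big-endian 16-bit length.
    if (payloadLength <= ushort.MaxValue)
    {
        var medium = new byte[4];
        medium[0] = first;
        medium[1] = 126;
        BinaryPrimitives.WriteUInt16BigEndian(medium.AsSpan(2), (ushort)payloadLength);
        return medium;
    }

    // Larger: marker 127, then a big-endian 64-bit length.
    var large = new byte[10];
    large[0] = first;
    large[1] = 127;
    BinaryPrimitives.WriteUInt64BigEndian(large.AsSpan(2), (ulong)payloadLength);
    return large;
}

Console.WriteLine(Convert.ToHexString(BuildHeader(12)));     // 820C
Console.WriteLine(Convert.ToHexString(BuildHeader(200)));    // 827E00C8
Console.WriteLine(Convert.ToHexString(BuildHeader(70000)));  // 827F0000000000011170
```

Note that the 200-byte payload in the compression test would need the 4-byte form if it were sent uncompressed.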

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConnectionTests" -v normal
Expected: FAIL

Step 3: Write minimal implementation

Create src/NATS.Server/WebSocket/WsConnection.cs:

namespace NATS.Server.WebSocket;

/// <summary>
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
/// NatsClient uses this as its _stream — FillPipeAsync and RunWriteLoopAsync work unchanged.
/// </summary>
public sealed class WsConnection : Stream
{
    private readonly Stream _inner;
    private readonly bool _compress;
    private readonly bool _maskRead;
    private readonly bool _maskWrite;
    private readonly bool _browser;
    private readonly bool _noCompFrag;
    private WsReadInfo _readInfo;
    private readonly Queue<byte[]> _readQueue = new();
    private int _readOffset;
    private readonly object _writeLock = new();
    private readonly List<ControlFrameAction> _pendingControlWrites = [];

    public bool CloseReceived => _readInfo.CloseReceived;
    public int CloseStatus => _readInfo.CloseStatus;

    public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
    {
        _inner = inner;
        _compress = compress;
        _maskRead = maskRead;
        _maskWrite = maskWrite;
        _browser = browser;
        _noCompFrag = noCompFrag;
        _readInfo = new WsReadInfo(expectMask: maskRead);
    }

    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
    {
        // Drain any buffered decoded payloads first
        if (_readQueue.Count > 0)
            return DrainReadQueue(buffer.Span);

        var rawBuf = new byte[Math.Max(buffer.Length, 4096)];

        // A read may yield only control frames or a partial frame. Returning 0 would
        // signal end-of-stream to the caller, so keep reading until data arrives.
        while (true)
        {
            // Read raw bytes from inner stream
            int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(0, rawBuf.Length), ct);
            if (bytesRead == 0) return 0;

            // Decode frames
            var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);

            // Collect and write control frame responses (pong, close echo)
            if (_readInfo.PendingControlFrames.Count > 0)
            {
                lock (_writeLock)
                    _pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
                _readInfo.PendingControlFrames.Clear();
                await FlushControlFramesAsync(ct);
            }

            if (_readInfo.CloseReceived)
                return 0;

            // Skip empty payloads so a zero-length frame is not mistaken for end-of-stream
            foreach (var payload in payloads)
            {
                if (payload.Length > 0)
                    _readQueue.Enqueue(payload);
            }

            if (_readQueue.Count > 0)
                return DrainReadQueue(buffer.Span);
        }
    }

    public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
    {
        if (_compress && buffer.Length > WsConstants.CompressThreshold)
        {
            var compressed = WsCompression.Compress(buffer.Span);
            await WriteFramedAsync(compressed, compressed: true, ct);
        }
        else
        {
            // Copy the payload: masking mutates it in place and the caller owns the buffer
            await WriteFramedAsync(buffer.ToArray(), compressed: false, ct);
        }
    }

    private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
    {
        if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
        {
            // Fragment for browsers
            int offset = 0;
            bool first = true;
            while (offset < payload.Length)
            {
                int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
                bool final = offset + chunkLen >= payload.Length;
                var fh = new byte[WsConstants.MaxFrameHeaderSize];
                // Only the first fragment carries the compression (RSV1) bit
                var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
                    first: first, final: final, compressed: first && compressed,
                    opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);

                var chunk = payload.AsSpan(offset, chunkLen).ToArray();
                if (_maskWrite && key != null)
                    WsFrameWriter.MaskBuf(key, chunk);

                await _inner.WriteAsync(fh.AsMemory(0, n), ct);
                await _inner.WriteAsync(chunk, ct);
                offset += chunkLen;
                first = false;
            }
        }
        else
        {
            var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
            if (_maskWrite && key != null)
                WsFrameWriter.MaskBuf(key, payload);
            await _inner.WriteAsync(header, ct);
            await _inner.WriteAsync(payload, ct);
        }
    }

    private async Task FlushControlFramesAsync(CancellationToken ct)
    {
        List<ControlFrameAction> toWrite;
        lock (_writeLock)
        {
            if (_pendingControlWrites.Count == 0) return;
            toWrite = [.. _pendingControlWrites];
            _pendingControlWrites.Clear();
        }

        foreach (var action in toWrite)
        {
            var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
            await _inner.WriteAsync(frame, ct);
        }
        await _inner.FlushAsync(ct);
    }

    /// <summary>
    /// Sends a WebSocket close frame.
    /// </summary>
    public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
    {
        var status = WsFrameWriter.MapCloseStatus(reason);
        var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
        var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
        await _inner.WriteAsync(frame, ct);
        await _inner.FlushAsync(ct);
    }

    private int DrainReadQueue(Span<byte> buffer)
    {
        int written = 0;
        while (_readQueue.Count > 0 && written < buffer.Length)
        {
            var current = _readQueue.Peek();
            int available = current.Length - _readOffset;
            int toCopy = Math.Min(available, buffer.Length - written);
            current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
            written += toCopy;
            _readOffset += toCopy;
            if (_readOffset >= current.Length)
            {
                _readQueue.Dequeue();
                _readOffset = 0;
            }
        }
        return written;
    }

    // Stream abstract members
    public override bool CanRead => true;
    public override bool CanWrite => true;
    public override bool CanSeek => false;
    public override long Length => throw new NotSupportedException();
    public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
    public override void Flush() => _inner.Flush();
    public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
    public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
    public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
    public override void SetLength(long value) => throw new NotSupportedException();

    protected override void Dispose(bool disposing)
    {
        if (disposing)
            _inner.Dispose();
        base.Dispose(disposing);
    }
}
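
WsConnection relies on WsFrameWriter.MaskBuf to mask outgoing payloads when maskWrite is set. As a minimal sketch of what RFC 6455 masking does (XOR each payload byte with the 4-byte key, cycling through the key), the standalone snippet below shows that applying the same key twice restores the original bytes. `Mask` is a hypothetical stand-in written for illustration; the actual MaskBuf implementation is defined in an earlier task and may differ.

```csharp
using System;
using System.Text;

// Hypothetical stand-in for the masking step: XOR in place with a repeating 4-byte key.
static void Mask(ReadOnlySpan<byte> key, Span<byte> data)
{
    for (int i = 0; i < data.Length; i++)
        data[i] ^= key[i & 3];
}

byte[] key = { 0x37, 0xFA, 0x21, 0x3D };
byte[] original = Encoding.ASCII.GetBytes("PING\r\n");
byte[] data = (byte[])original.Clone();

Mask(key, data);   // masked: differs from the original bytes
bool changed = !data.AsSpan().SequenceEqual(original);

Mask(key, data);   // masking is its own inverse
bool restored = data.AsSpan().SequenceEqual(original);

Console.WriteLine($"{changed} {restored}"); // True True
```

Because the operation mutates the buffer in place, the write path has to mask a private copy rather than the caller's buffer.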

Step 4: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsConnectionTests" -v normal
Expected: PASS

Step 5: Commit

git add src/NATS.Server/WebSocket/WsConnection.cs tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
git commit -m "feat: add WsConnection Stream wrapper for transparent framing"

Task 8: Integrate WebSocket into NatsServer and NatsClient

Files:

  • Modify: src/NATS.Server/NatsServer.cs
  • Modify: src/NATS.Server/NatsClient.cs

Step 1: Write the failing test

Create tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs:

using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;
using Shouldly;

namespace NATS.Server.Tests.WebSocket;

public class WsIntegrationTests : IAsyncLifetime
{
    private NatsServer _server = null!;
    private NatsOptions _options = null!;

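    public async ValueTask InitializeAsync()
    {
        // WebSocket.Port = 0 leaves the listener disabled, so reserve a free loopback port first
        var probe = new TcpListener(IPAddress.Loopback, 0);
        probe.Start();
        var wsPort = ((IPEndPoint)probe.LocalEndpoint).Port;
        probe.Stop();

        _options = new NatsOptions
        {
            Port = 0,
            WebSocket = new WebSocketOptions { Port = wsPort, NoTls = true },
        };
        var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
        _server = new NatsServer(_options, loggerFactory);
        _ = _server.StartAsync(CancellationToken.None);
        await _server.WaitForReadyAsync();
    }

    public async ValueTask DisposeAsync()
    {
        await _server.ShutdownAsync();
        _server.Dispose();
    }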

    [Fact]
    public async Task WebSocket_ConnectAndReceiveInfo()
    {
        using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
        using var stream = new NetworkStream(socket, ownsSocket: false);

        // Send WebSocket upgrade request
        await SendUpgradeRequest(stream);

        // Read 101 response
        var response = await ReadHttpResponse(stream);
        response.ShouldContain("101");

        // Now read the INFO line through WebSocket frames
        var wsFrame = await ReadWsFrame(stream);
        var info = Encoding.ASCII.GetString(wsFrame);
        info.ShouldStartWith("INFO ");
    }

    [Fact]
    public async Task WebSocket_PubSub()
    {
        // Connect two WS clients
        using var sub = await ConnectWsClient();
        using var pub = await ConnectWsClient();

        // Subscribe on first client
        await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n");
        await Task.Delay(100);

        // Publish on second client
        await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
        await Task.Delay(100);

        // Read from subscriber
        var msg = await ReadWsFrame(sub);
        Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
    }

    private async Task<NetworkStream> ConnectWsClient()
    {
        var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
        var stream = new NetworkStream(socket, ownsSocket: true);

        await SendUpgradeRequest(stream);
        var response = await ReadHttpResponse(stream);
        response.ShouldContain("101");

        // Read INFO frame
        await ReadWsFrame(stream);

        return stream;
    }

    private static async Task SendUpgradeRequest(NetworkStream stream)
    {
        var keyBytes = new byte[16];
        RandomNumberGenerator.Fill(keyBytes);
        var key = Convert.ToBase64String(keyBytes);

        var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
        await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
        await stream.FlushAsync();
    }

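    private static async Task<string> ReadHttpResponse(NetworkStream stream)
    {
        // Read one byte at a time and stop exactly at the blank line, so the
        // INFO frame that follows the 101 response is not consumed here.
        var buf = new byte[1];
        var sb = new StringBuilder();
        while (true)
        {
            int n = await stream.ReadAsync(buf);
            if (n == 0) break;
            sb.Append((char)buf[0]);
            if (sb.Length >= 4 &&
                sb[sb.Length - 4] == '\r' && sb[sb.Length - 3] == '\n' &&
                sb[sb.Length - 2] == '\r' && sb[sb.Length - 1] == '\n')
                break;
        }
        return sb.ToString();
    }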

    private static async Task<byte[]> ReadWsFrame(NetworkStream stream)
    {
        var header = new byte[2];
        await stream.ReadExactlyAsync(header);
        int len = header[1] & 0x7F;
        byte[]? extLen = null;
        if (len == 126)
        {
            extLen = new byte[2];
            await stream.ReadExactlyAsync(extLen);
            len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
        }
        else if (len == 127)
        {
            extLen = new byte[8];
            await stream.ReadExactlyAsync(extLen);
            len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
        }
        var payload = new byte[len];
        if (len > 0) await stream.ReadExactlyAsync(payload);
        return payload;
    }

    private static async Task SendWsText(NetworkStream stream, string text)
    {
        var payload = Encoding.ASCII.GetBytes(text);
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: true, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
        // The masking key is in the header — we need to mask the payload
        var maskKey = header[^4..];
        WsFrameWriter.MaskBuf(maskKey, payload);
        await stream.WriteAsync(header);
        await stream.WriteAsync(payload);
        await stream.FlushAsync();
    }
}

Step 2: Run test to verify it fails

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsIntegrationTests" -v normal
Expected: FAIL (NatsServer has no WebSocket listener)

Step 3: Modify NatsServer.cs

Add these fields to NatsServer:

private Socket? _wsListener;
private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);

In StartAsync, after the monitoring server startup and before the main accept loop, add:

// Port 0 means disabled; -1 requests an OS-assigned port (same convention as the Go server)
if (_options.WebSocket.Port != 0)
{
    var wsPort = _options.WebSocket.Port < 0 ? 0 : _options.WebSocket.Port;

    _wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
    _wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
    _wsListener.Bind(new IPEndPoint(
        _options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host),
        wsPort));
    _wsListener.Listen(128);

    // Publish the OS-assigned port so callers can connect to it
    if (_options.WebSocket.Port < 0)
    {
        _options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port;
    }

    _logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}",
        _options.WebSocket.Host, _options.WebSocket.Port);

    if (_options.WebSocket.NoTls)
        _logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!");

    _ = RunWebSocketAcceptLoopAsync(linked.Token);
}

Add the WebSocket accept loop method:

private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct)
{
    var tmpDelay = AcceptMinSleep;
    try
    {
        while (!ct.IsCancellationRequested)
        {
            Socket socket;
            try
            {
                socket = await _wsListener!.AcceptAsync(ct);
                tmpDelay = AcceptMinSleep;
            }
            catch (OperationCanceledException) { break; }
            catch (ObjectDisposedException) { break; }
            catch (SocketException ex)
            {
                if (IsShuttingDown || IsLameDuckMode) break;
                _logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
                try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; }
                tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
                continue;
            }

            if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
            {
                socket.Dispose();
                continue;
            }

            var clientId = Interlocked.Increment(ref _nextClientId);
            Interlocked.Increment(ref _stats.TotalConnections);
            Interlocked.Increment(ref _activeClientCount);

            _ = AcceptWebSocketClientAsync(socket, clientId, ct);
        }
    }
    finally
    {
        _wsAcceptLoopExited.TrySetResult();
    }
}

private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct)
{
    try
    {
        var networkStream = new NetworkStream(socket, ownsSocket: false);
        Stream stream = networkStream;

        // TLS negotiation if configured
        if (_sslOptions != null && !_options.WebSocket.NoTls)
        {
            var (tlsStream, _) = await Tls.TlsConnectionWrapper.NegotiateAsync(
                socket, networkStream, _options, _sslOptions, _serverInfo,
                _loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
            stream = tlsStream;
        }

        // HTTP upgrade handshake
        var upgradeResult = await WebSocket.WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket);
        if (!upgradeResult.Success)
        {
            _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
            socket.Dispose();
            Interlocked.Decrement(ref _activeClientCount);
            return;
        }

        // Create WsConnection wrapper
        var wsConn = new WebSocket.WsConnection(stream,
            compress: upgradeResult.Compress,
            maskRead: upgradeResult.MaskRead,
            maskWrite: upgradeResult.MaskWrite,
            browser: upgradeResult.Browser,
            noCompFrag: upgradeResult.NoCompFrag);

        var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
        var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo,
            _authService, null, clientLogger, _stats);
        client.Router = this;
        client.IsWebSocket = true;
        client.WsInfo = upgradeResult;
        _clients[clientId] = client;

        await RunClientAsync(client, ct);
    }
    catch (Exception ex)
    {
        _logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId);
        try { socket.Shutdown(SocketShutdown.Both); } catch { }
        socket.Dispose();
        Interlocked.Decrement(ref _activeClientCount);
    }
}
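
The accept loop above doubles its sleep after each consecutive SocketException and resets it after a successful accept. The sketch below isolates that backoff arithmetic. The 10 ms and 1 s bounds are assumptions taken from the Go server's defaults; the plan references AcceptMinSleep and AcceptMaxSleep without stating their values here, and `NextDelay` is a hypothetical helper name.

```csharp
using System;

// Doubles the delay, capped at max (mirrors the tmpDelay update in the accept loop).
static TimeSpan NextDelay(TimeSpan current, TimeSpan max) =>
    TimeSpan.FromTicks(Math.Min(current.Ticks * 2, max.Ticks));

// Assumed bounds, not confirmed by this plan.
var minSleep = TimeSpan.FromMilliseconds(10);
var maxSleep = TimeSpan.FromSeconds(1);

var delay = minSleep;
for (int i = 0; i < 8; i++)
{
    Console.WriteLine(delay.TotalMilliseconds);
    delay = NextDelay(delay, maxSleep);
}
// prints 10, 20, 40, 80, 160, 320, 640, 1000 (one value per line)
```

Under these assumed bounds, a persistent accept failure settles at one retry per second instead of spinning.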

In ShutdownAsync, add before _listener?.Close():

_wsListener?.Close();

And after _acceptLoopExited.Task.WaitAsync(...), add:

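// Only wait when the WebSocket listener was started; otherwise the task never completes
if (_wsListener != null)
    await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);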

In Dispose, add:

_wsListener?.Dispose();

Step 4: Modify NatsClient.cs

Add two properties (import NATS.Server.WebSocket for WsUpgradeResult if it is not already in scope):

public bool IsWebSocket { get; set; }
public WsUpgradeResult? WsInfo { get; set; }

Step 5: Run test to verify it passes

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WsIntegrationTests" -v normal
Expected: PASS

Step 6: Commit

git add src/NATS.Server/NatsServer.cs src/NATS.Server/NatsClient.cs tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs
git commit -m "feat: integrate WebSocket accept loop into NatsServer"

Task 9: Update differences.md

Files:

  • Modify: differences.md

Step 1: Update WebSocket row in Connection Types table

Change line 70 from:

| WebSocket clients | Y | N | |

To:

| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth |

Step 2: Update Missing Options Categories

Change the line:

- WebSocket/MQTT options

To:

- ~~WebSocket options~~ — WebSocketOptions with port, compression, origin checking, cookie auth, custom headers
- MQTT options

Step 3: Commit

git add differences.md
git commit -m "docs: update differences.md to reflect WebSocket implementation"

Task 10: Run full test suite and verify

Step 1: Build

Run: dotnet build
Expected: Build succeeded

Step 2: Run all tests

Run: dotnet test -v normal
Expected: All tests pass (both existing and new WebSocket tests)

Step 3: Run only WebSocket tests

Run: dotnet test tests/NATS.Server.Tests --filter "FullyQualifiedName~WebSocket" -v normal
Expected: All WebSocket tests pass

Step 4: Final commit (if any fixes needed)

git add -A
git commit -m "fix: address test failures from full suite run"