feat: add WebSocket frame writer with masking and close status mapping
This commit is contained in:
160
src/NATS.Server/WebSocket/WsFrameWriter.cs
Normal file
160
src/NATS.Server/WebSocket/WsFrameWriter.cs
Normal file
@@ -0,0 +1,160 @@
|
||||
using System.Buffers.Binary;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// WebSocket frame construction, masking, and control message creation.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 543-726.
|
||||
/// </summary>
|
||||
public static class WsFrameWriter
|
||||
{
|
||||
/// <summary>
|
||||
/// Creates a complete frame header for a single-frame message (first=true, final=true).
|
||||
/// Returns (header bytes, mask key or null).
|
||||
/// </summary>
|
||||
public static (byte[] header, byte[]? key) CreateFrameHeader(
|
||||
bool useMasking, bool compressed, int opcode, int payloadLength)
|
||||
{
|
||||
var fh = new byte[WsConstants.MaxFrameHeaderSize];
|
||||
var (n, key) = FillFrameHeader(fh, useMasking,
|
||||
first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength);
|
||||
return (fh[..n], key);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Fills a pre-allocated frame header buffer.
|
||||
/// Returns (bytes written, mask key or null).
|
||||
/// </summary>
|
||||
public static (int written, byte[]? key) FillFrameHeader(
|
||||
Span<byte> fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength)
|
||||
{
|
||||
byte b0 = first ? (byte)opcode : (byte)0;
|
||||
if (final) b0 |= WsConstants.FinalBit;
|
||||
if (compressed) b0 |= WsConstants.Rsv1Bit;
|
||||
|
||||
byte b1 = 0;
|
||||
if (useMasking) b1 |= WsConstants.MaskBit;
|
||||
|
||||
int n;
|
||||
switch (payloadLength)
|
||||
{
|
||||
case <= 125:
|
||||
n = 2;
|
||||
fh[0] = b0;
|
||||
fh[1] = (byte)(b1 | (byte)payloadLength);
|
||||
break;
|
||||
case < 65536:
|
||||
n = 4;
|
||||
fh[0] = b0;
|
||||
fh[1] = (byte)(b1 | 126);
|
||||
BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength);
|
||||
break;
|
||||
default:
|
||||
n = 10;
|
||||
fh[0] = b0;
|
||||
fh[1] = (byte)(b1 | 127);
|
||||
BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength);
|
||||
break;
|
||||
}
|
||||
|
||||
byte[]? key = null;
|
||||
if (useMasking)
|
||||
{
|
||||
key = new byte[4];
|
||||
RandomNumberGenerator.Fill(key);
|
||||
key.CopyTo(fh[n..]);
|
||||
n += 4;
|
||||
}
|
||||
|
||||
return (n, key);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// XOR masks a buffer with a 4-byte key. Applies in-place.
|
||||
/// </summary>
|
||||
public static void MaskBuf(ReadOnlySpan<byte> key, Span<byte> buf)
|
||||
{
|
||||
for (int i = 0; i < buf.Length; i++)
|
||||
buf[i] ^= key[i & 3];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// XOR masks multiple contiguous buffers as if they were one.
|
||||
/// </summary>
|
||||
public static void MaskBufs(ReadOnlySpan<byte> key, List<byte[]> bufs)
|
||||
{
|
||||
int pos = 0;
|
||||
foreach (var buf in bufs)
|
||||
{
|
||||
for (int j = 0; j < buf.Length; j++)
|
||||
{
|
||||
buf[j] ^= key[pos & 3];
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a close message payload: 2-byte status code + optional UTF-8 body.
|
||||
/// Body truncated to fit MaxControlPayloadSize with "..." suffix.
|
||||
/// </summary>
|
||||
public static byte[] CreateCloseMessage(int status, string body)
|
||||
{
|
||||
if (body.Length > WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize)
|
||||
{
|
||||
body = body[..(WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize - 3)] + "...";
|
||||
}
|
||||
|
||||
var bodyBytes = Encoding.UTF8.GetBytes(body);
|
||||
var buf = new byte[WsConstants.CloseStatusSize + bodyBytes.Length];
|
||||
BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status);
|
||||
bodyBytes.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize));
|
||||
return buf;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Builds a complete control frame (header + payload, optional masking).
|
||||
/// </summary>
|
||||
public static byte[] BuildControlFrame(int opcode, ReadOnlySpan<byte> payload, bool useMasking)
|
||||
{
|
||||
int headerSize = 2 + (useMasking ? 4 : 0);
|
||||
var frame = new byte[headerSize + payload.Length];
|
||||
var span = frame.AsSpan();
|
||||
var (n, key) = FillFrameHeader(span, useMasking,
|
||||
first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length);
|
||||
if (payload.Length > 0)
|
||||
{
|
||||
payload.CopyTo(span[n..]);
|
||||
if (useMasking && key != null)
|
||||
MaskBuf(key, span[n..]);
|
||||
}
|
||||
|
||||
return frame;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Maps a ClientClosedReason to a WebSocket close status code.
|
||||
/// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694.
|
||||
/// </summary>
|
||||
public static int MapCloseStatus(ClientClosedReason reason) => reason switch
|
||||
{
|
||||
ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure,
|
||||
ClientClosedReason.AuthenticationTimeout or
|
||||
ClientClosedReason.AuthenticationViolation or
|
||||
ClientClosedReason.SlowConsumerPendingBytes or
|
||||
ClientClosedReason.SlowConsumerWriteDeadline or
|
||||
ClientClosedReason.MaxSubscriptionsExceeded or
|
||||
ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation,
|
||||
ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake,
|
||||
ClientClosedReason.ParseError or
|
||||
ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError,
|
||||
ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig,
|
||||
ClientClosedReason.WriteError or
|
||||
ClientClosedReason.ReadError or
|
||||
ClientClosedReason.StaleConnection or
|
||||
ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway,
|
||||
_ => WsConstants.CloseStatusInternalSrvError,
|
||||
};
|
||||
}
|
||||
152
tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs
Normal file
152
tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs
Normal file
@@ -0,0 +1,152 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsFrameWriterTests
|
||||
{
|
||||
[Fact]
|
||||
public void CreateFrameHeader_SmallPayload_7BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 100);
|
||||
header.Length.ShouldBe(2);
|
||||
(header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
|
||||
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
(header[1] & 0x7F).ShouldBe(100);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_MediumPayload_16BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
|
||||
header.Length.ShouldBe(4);
|
||||
(header[1] & 0x7F).ShouldBe(126);
|
||||
BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_LargePayload_64BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 70000);
|
||||
header.Length.ShouldBe(10);
|
||||
(header[1] & 0x7F).ShouldBe(127);
|
||||
BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
|
||||
{
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 10);
|
||||
header.Length.ShouldBe(6); // 2 header + 4 mask key
|
||||
(header[1] & WsConstants.MaskBit).ShouldNotBe(0);
|
||||
key.ShouldNotBeNull();
|
||||
key.Length.ShouldBe(4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_Compressed_SetsRsv1Bit()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: true,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 10);
|
||||
(header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_XorsCorrectly()
|
||||
{
|
||||
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||
byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
|
||||
byte[] expected = new byte[data.Length];
|
||||
for (int i = 0; i < data.Length; i++)
|
||||
expected[i] = (byte)(data[i] ^ key[i & 3]);
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(expected);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_RoundTrip()
|
||||
{
|
||||
byte[] key = [0x12, 0x34, 0x56, 0x78];
|
||||
byte[] original = "Hello, WebSocket!"u8.ToArray();
|
||||
var data = original.ToArray();
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldNotBe(original);
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_WithStatusAndBody()
|
||||
{
|
||||
var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
|
||||
msg.Length.ShouldBe(2 + "normal closure".Length);
|
||||
BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_LongBody_Truncated()
|
||||
{
|
||||
var longBody = new string('x', 200);
|
||||
var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
|
||||
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ClientClosed_NormalClosure()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
|
||||
.ShouldBe(WsConstants.CloseStatusNormalClosure);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_AuthTimeout_PolicyViolation()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
|
||||
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ParseError_ProtocolError()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
|
||||
.ShouldBe(WsConstants.CloseStatusProtocolError);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_MaxPayload_MessageTooBig()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
|
||||
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildControlFrame_PingNomask()
|
||||
{
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
|
||||
frame.Length.ShouldBe(2);
|
||||
(frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||
(frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
|
||||
(frame[1] & 0x7F).ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildControlFrame_PongWithPayload()
|
||||
{
|
||||
byte[] payload = [1, 2, 3, 4];
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
|
||||
frame.Length.ShouldBe(2 + 4);
|
||||
frame[2..].ShouldBe(payload);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user