feat: add WsConnection Stream wrapper for transparent framing

This commit is contained in:
Joseph Doherty
2026-02-23 04:58:56 -05:00
parent fe304dfe01
commit 6d0a4d259e
2 changed files with 280 additions and 0 deletions

View File

@@ -0,0 +1,192 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged.
/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern.
/// </summary>
public sealed class WsConnection : Stream
{
private readonly Stream _inner;
private readonly bool _compress;
private readonly bool _maskRead;
private readonly bool _maskWrite;
private readonly bool _browser;
private readonly bool _noCompFrag;
private WsReadInfo _readInfo;
private readonly Queue<byte[]> _readQueue = new();
private int _readOffset;
private readonly object _writeLock = new();
private readonly List<ControlFrameAction> _pendingControlWrites = [];
public bool CloseReceived => _readInfo.CloseReceived;
public int CloseStatus => _readInfo.CloseStatus;
public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
{
_inner = inner;
_compress = compress;
_maskRead = maskRead;
_maskWrite = maskWrite;
_browser = browser;
_noCompFrag = noCompFrag;
_readInfo = new WsReadInfo(expectMask: maskRead);
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
{
// Drain any buffered decoded payloads first
if (_readQueue.Count > 0)
return DrainReadQueue(buffer.Span);
// Read raw bytes from inner stream
var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct);
if (bytesRead == 0) return 0;
// Decode frames
var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
// Collect control frame responses
if (_readInfo.PendingControlFrames.Count > 0)
{
lock (_writeLock)
_pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
_readInfo.PendingControlFrames.Clear();
// Write pending control frames
await FlushControlFramesAsync(ct);
}
if (_readInfo.CloseReceived)
return 0;
foreach (var payload in payloads)
_readQueue.Enqueue(payload);
if (_readQueue.Count == 0)
return 0;
return DrainReadQueue(buffer.Span);
}
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
{
var data = buffer.Span;
if (_compress && data.Length > WsConstants.CompressThreshold)
{
var compressed = WsCompression.Compress(data);
await WriteFramedAsync(compressed, compressed: true, ct);
}
else
{
await WriteFramedAsync(data.ToArray(), compressed: false, ct);
}
}
private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
{
if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
{
// Fragment for browsers
int offset = 0;
bool first = true;
while (offset < payload.Length)
{
int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
bool final = offset + chunkLen >= payload.Length;
var fh = new byte[WsConstants.MaxFrameHeaderSize];
var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
first: first, final: final, compressed: first && compressed,
opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
var chunk = payload.AsSpan(offset, chunkLen).ToArray();
if (_maskWrite && key != null)
WsFrameWriter.MaskBuf(key, chunk);
await _inner.WriteAsync(fh.AsMemory(0, n), ct);
await _inner.WriteAsync(chunk.AsMemory(), ct);
offset += chunkLen;
first = false;
}
}
else
{
var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
if (_maskWrite && key != null)
WsFrameWriter.MaskBuf(key, payload);
await _inner.WriteAsync(header.AsMemory(), ct);
await _inner.WriteAsync(payload.AsMemory(), ct);
}
}
private async Task FlushControlFramesAsync(CancellationToken ct)
{
List<ControlFrameAction> toWrite;
lock (_writeLock)
{
if (_pendingControlWrites.Count == 0) return;
toWrite = [.. _pendingControlWrites];
_pendingControlWrites.Clear();
}
foreach (var action in toWrite)
{
var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
await _inner.WriteAsync(frame, ct);
}
await _inner.FlushAsync(ct);
}
/// <summary>
/// Sends a WebSocket close frame.
/// </summary>
public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
{
var status = WsFrameWriter.MapCloseStatus(reason);
var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
await _inner.WriteAsync(frame, ct);
await _inner.FlushAsync(ct);
}
private int DrainReadQueue(Span<byte> buffer)
{
int written = 0;
while (_readQueue.Count > 0 && written < buffer.Length)
{
var current = _readQueue.Peek();
int available = current.Length - _readOffset;
int toCopy = Math.Min(available, buffer.Length - written);
current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
written += toCopy;
_readOffset += toCopy;
if (_readOffset >= current.Length)
{
_readQueue.Dequeue();
_readOffset = 0;
}
}
return written;
}
// Stream abstract members
public override bool CanRead => true;
public override bool CanWrite => true;
public override bool CanSeek => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override void Flush() => _inner.Flush();
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
protected override void Dispose(bool disposing)
{
if (disposing)
_inner.Dispose();
base.Dispose(disposing);
}
}

View File

@@ -0,0 +1,88 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsConnectionTests
{
[Fact]
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
{
var payload = "SUB test 1\r\n"u8.ToArray();
var frame = BuildUnmaskedFrame(payload);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(payload.Length);
buf[..n].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_FramesPayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// First 2 bytes should be WS frame header
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
int len = written[1] & 0x7F;
len.ShouldBe(payload.Length);
written[2..].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_WithCompression_CompressesLargePayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = new byte[200];
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should be set for compressed frame
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
// Compressed size should be less than original
written.Length.ShouldBeLessThan(payload.Length + 10);
}
[Fact]
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should NOT be set for small payloads
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
}
private static byte[] BuildUnmaskedFrame(byte[] payload)
{
var header = new byte[2];
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
header[1] = (byte)payload.Length;
var frame = new byte[2 + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, 2);
return frame;
}
}