feat: add WsConnection Stream wrapper for transparent framing
This commit is contained in:
192
src/NATS.Server/WebSocket/WsConnection.cs
Normal file
192
src/NATS.Server/WebSocket/WsConnection.cs
Normal file
@@ -0,0 +1,192 @@
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
|
||||
/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged.
|
||||
/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern.
|
||||
/// </summary>
|
||||
public sealed class WsConnection : Stream
|
||||
{
|
||||
private readonly Stream _inner;
|
||||
private readonly bool _compress;
|
||||
private readonly bool _maskRead;
|
||||
private readonly bool _maskWrite;
|
||||
private readonly bool _browser;
|
||||
private readonly bool _noCompFrag;
|
||||
private WsReadInfo _readInfo;
|
||||
private readonly Queue<byte[]> _readQueue = new();
|
||||
private int _readOffset;
|
||||
private readonly object _writeLock = new();
|
||||
private readonly List<ControlFrameAction> _pendingControlWrites = [];
|
||||
|
||||
public bool CloseReceived => _readInfo.CloseReceived;
|
||||
public int CloseStatus => _readInfo.CloseStatus;
|
||||
|
||||
public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
|
||||
{
|
||||
_inner = inner;
|
||||
_compress = compress;
|
||||
_maskRead = maskRead;
|
||||
_maskWrite = maskWrite;
|
||||
_browser = browser;
|
||||
_noCompFrag = noCompFrag;
|
||||
_readInfo = new WsReadInfo(expectMask: maskRead);
|
||||
}
|
||||
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
|
||||
{
|
||||
// Drain any buffered decoded payloads first
|
||||
if (_readQueue.Count > 0)
|
||||
return DrainReadQueue(buffer.Span);
|
||||
|
||||
// Read raw bytes from inner stream
|
||||
var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
|
||||
int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct);
|
||||
if (bytesRead == 0) return 0;
|
||||
|
||||
// Decode frames
|
||||
var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
|
||||
|
||||
// Collect control frame responses
|
||||
if (_readInfo.PendingControlFrames.Count > 0)
|
||||
{
|
||||
lock (_writeLock)
|
||||
_pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
|
||||
_readInfo.PendingControlFrames.Clear();
|
||||
// Write pending control frames
|
||||
await FlushControlFramesAsync(ct);
|
||||
}
|
||||
|
||||
if (_readInfo.CloseReceived)
|
||||
return 0;
|
||||
|
||||
foreach (var payload in payloads)
|
||||
_readQueue.Enqueue(payload);
|
||||
|
||||
if (_readQueue.Count == 0)
|
||||
return 0;
|
||||
|
||||
return DrainReadQueue(buffer.Span);
|
||||
}
|
||||
|
||||
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
|
||||
{
|
||||
var data = buffer.Span;
|
||||
|
||||
if (_compress && data.Length > WsConstants.CompressThreshold)
|
||||
{
|
||||
var compressed = WsCompression.Compress(data);
|
||||
await WriteFramedAsync(compressed, compressed: true, ct);
|
||||
}
|
||||
else
|
||||
{
|
||||
await WriteFramedAsync(data.ToArray(), compressed: false, ct);
|
||||
}
|
||||
}
|
||||
|
||||
private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
|
||||
{
|
||||
if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
|
||||
{
|
||||
// Fragment for browsers
|
||||
int offset = 0;
|
||||
bool first = true;
|
||||
while (offset < payload.Length)
|
||||
{
|
||||
int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
|
||||
bool final = offset + chunkLen >= payload.Length;
|
||||
var fh = new byte[WsConstants.MaxFrameHeaderSize];
|
||||
var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
|
||||
first: first, final: final, compressed: first && compressed,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
|
||||
|
||||
var chunk = payload.AsSpan(offset, chunkLen).ToArray();
|
||||
if (_maskWrite && key != null)
|
||||
WsFrameWriter.MaskBuf(key, chunk);
|
||||
|
||||
await _inner.WriteAsync(fh.AsMemory(0, n), ct);
|
||||
await _inner.WriteAsync(chunk.AsMemory(), ct);
|
||||
offset += chunkLen;
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
|
||||
if (_maskWrite && key != null)
|
||||
WsFrameWriter.MaskBuf(key, payload);
|
||||
await _inner.WriteAsync(header.AsMemory(), ct);
|
||||
await _inner.WriteAsync(payload.AsMemory(), ct);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task FlushControlFramesAsync(CancellationToken ct)
|
||||
{
|
||||
List<ControlFrameAction> toWrite;
|
||||
lock (_writeLock)
|
||||
{
|
||||
if (_pendingControlWrites.Count == 0) return;
|
||||
toWrite = [.. _pendingControlWrites];
|
||||
_pendingControlWrites.Clear();
|
||||
}
|
||||
|
||||
foreach (var action in toWrite)
|
||||
{
|
||||
var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
|
||||
await _inner.WriteAsync(frame, ct);
|
||||
}
|
||||
await _inner.FlushAsync(ct);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a WebSocket close frame.
|
||||
/// </summary>
|
||||
public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
|
||||
{
|
||||
var status = WsFrameWriter.MapCloseStatus(reason);
|
||||
var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
|
||||
await _inner.WriteAsync(frame, ct);
|
||||
await _inner.FlushAsync(ct);
|
||||
}
|
||||
|
||||
private int DrainReadQueue(Span<byte> buffer)
|
||||
{
|
||||
int written = 0;
|
||||
while (_readQueue.Count > 0 && written < buffer.Length)
|
||||
{
|
||||
var current = _readQueue.Peek();
|
||||
int available = current.Length - _readOffset;
|
||||
int toCopy = Math.Min(available, buffer.Length - written);
|
||||
current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
|
||||
written += toCopy;
|
||||
_readOffset += toCopy;
|
||||
if (_readOffset >= current.Length)
|
||||
{
|
||||
_readQueue.Dequeue();
|
||||
_readOffset = 0;
|
||||
}
|
||||
}
|
||||
return written;
|
||||
}
|
||||
|
||||
// Stream abstract members
|
||||
public override bool CanRead => true;
|
||||
public override bool CanWrite => true;
|
||||
public override bool CanSeek => false;
|
||||
public override long Length => throw new NotSupportedException();
|
||||
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
|
||||
public override void Flush() => _inner.Flush();
|
||||
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
|
||||
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
|
||||
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
|
||||
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
|
||||
public override void SetLength(long value) => throw new NotSupportedException();
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (disposing)
|
||||
_inner.Dispose();
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
}
|
||||
88
tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
Normal file
88
tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
Normal file
@@ -0,0 +1,88 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsConnectionTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
|
||||
{
|
||||
var payload = "SUB test 1\r\n"u8.ToArray();
|
||||
var frame = BuildUnmaskedFrame(payload);
|
||||
var inner = new MemoryStream(frame);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
|
||||
n.ShouldBe(payload.Length);
|
||||
buf[..n].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_FramesPayload()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// First 2 bytes should be WS frame header
|
||||
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
int len = written[1] & 0x7F;
|
||||
len.ShouldBe(payload.Length);
|
||||
written[2..].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_WithCompression_CompressesLargePayload()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = new byte[200];
|
||||
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// RSV1 bit should be set for compressed frame
|
||||
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||
// Compressed size should be less than original
|
||||
written.Length.ShouldBeLessThan(payload.Length + 10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// RSV1 bit should NOT be set for small payloads
|
||||
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
|
||||
}
|
||||
|
||||
private static byte[] BuildUnmaskedFrame(byte[] payload)
|
||||
{
|
||||
var header = new byte[2];
|
||||
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
|
||||
header[1] = (byte)payload.Length;
|
||||
var frame = new byte[2 + payload.Length];
|
||||
header.CopyTo(frame, 0);
|
||||
payload.CopyTo(frame, 2);
|
||||
return frame;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user