feat: add WsConnection Stream wrapper for transparent framing
This commit is contained in:
192
src/NATS.Server/WebSocket/WsConnection.cs
Normal file
192
src/NATS.Server/WebSocket/WsConnection.cs
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
namespace NATS.Server.WebSocket;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
|
||||||
|
/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged.
|
||||||
|
/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern.
|
||||||
|
/// </summary>
|
||||||
|
public sealed class WsConnection : Stream
|
||||||
|
{
|
||||||
|
private readonly Stream _inner;
|
||||||
|
private readonly bool _compress;
|
||||||
|
private readonly bool _maskRead;
|
||||||
|
private readonly bool _maskWrite;
|
||||||
|
private readonly bool _browser;
|
||||||
|
private readonly bool _noCompFrag;
|
||||||
|
private WsReadInfo _readInfo;
|
||||||
|
private readonly Queue<byte[]> _readQueue = new();
|
||||||
|
private int _readOffset;
|
||||||
|
private readonly object _writeLock = new();
|
||||||
|
private readonly List<ControlFrameAction> _pendingControlWrites = [];
|
||||||
|
|
||||||
|
public bool CloseReceived => _readInfo.CloseReceived;
|
||||||
|
public int CloseStatus => _readInfo.CloseStatus;
|
||||||
|
|
||||||
|
public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
|
||||||
|
{
|
||||||
|
_inner = inner;
|
||||||
|
_compress = compress;
|
||||||
|
_maskRead = maskRead;
|
||||||
|
_maskWrite = maskWrite;
|
||||||
|
_browser = browser;
|
||||||
|
_noCompFrag = noCompFrag;
|
||||||
|
_readInfo = new WsReadInfo(expectMask: maskRead);
|
||||||
|
}
|
||||||
|
|
||||||
|
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
|
||||||
|
{
|
||||||
|
// Drain any buffered decoded payloads first
|
||||||
|
if (_readQueue.Count > 0)
|
||||||
|
return DrainReadQueue(buffer.Span);
|
||||||
|
|
||||||
|
// Read raw bytes from inner stream
|
||||||
|
var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
|
||||||
|
int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct);
|
||||||
|
if (bytesRead == 0) return 0;
|
||||||
|
|
||||||
|
// Decode frames
|
||||||
|
var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
|
||||||
|
|
||||||
|
// Collect control frame responses
|
||||||
|
if (_readInfo.PendingControlFrames.Count > 0)
|
||||||
|
{
|
||||||
|
lock (_writeLock)
|
||||||
|
_pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
|
||||||
|
_readInfo.PendingControlFrames.Clear();
|
||||||
|
// Write pending control frames
|
||||||
|
await FlushControlFramesAsync(ct);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_readInfo.CloseReceived)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
foreach (var payload in payloads)
|
||||||
|
_readQueue.Enqueue(payload);
|
||||||
|
|
||||||
|
if (_readQueue.Count == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
return DrainReadQueue(buffer.Span);
|
||||||
|
}
|
||||||
|
|
||||||
|
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
|
||||||
|
{
|
||||||
|
var data = buffer.Span;
|
||||||
|
|
||||||
|
if (_compress && data.Length > WsConstants.CompressThreshold)
|
||||||
|
{
|
||||||
|
var compressed = WsCompression.Compress(data);
|
||||||
|
await WriteFramedAsync(compressed, compressed: true, ct);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
await WriteFramedAsync(data.ToArray(), compressed: false, ct);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
|
||||||
|
{
|
||||||
|
if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
|
||||||
|
{
|
||||||
|
// Fragment for browsers
|
||||||
|
int offset = 0;
|
||||||
|
bool first = true;
|
||||||
|
while (offset < payload.Length)
|
||||||
|
{
|
||||||
|
int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
|
||||||
|
bool final = offset + chunkLen >= payload.Length;
|
||||||
|
var fh = new byte[WsConstants.MaxFrameHeaderSize];
|
||||||
|
var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
|
||||||
|
first: first, final: final, compressed: first && compressed,
|
||||||
|
opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
|
||||||
|
|
||||||
|
var chunk = payload.AsSpan(offset, chunkLen).ToArray();
|
||||||
|
if (_maskWrite && key != null)
|
||||||
|
WsFrameWriter.MaskBuf(key, chunk);
|
||||||
|
|
||||||
|
await _inner.WriteAsync(fh.AsMemory(0, n), ct);
|
||||||
|
await _inner.WriteAsync(chunk.AsMemory(), ct);
|
||||||
|
offset += chunkLen;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
|
||||||
|
if (_maskWrite && key != null)
|
||||||
|
WsFrameWriter.MaskBuf(key, payload);
|
||||||
|
await _inner.WriteAsync(header.AsMemory(), ct);
|
||||||
|
await _inner.WriteAsync(payload.AsMemory(), ct);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task FlushControlFramesAsync(CancellationToken ct)
|
||||||
|
{
|
||||||
|
List<ControlFrameAction> toWrite;
|
||||||
|
lock (_writeLock)
|
||||||
|
{
|
||||||
|
if (_pendingControlWrites.Count == 0) return;
|
||||||
|
toWrite = [.. _pendingControlWrites];
|
||||||
|
_pendingControlWrites.Clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
foreach (var action in toWrite)
|
||||||
|
{
|
||||||
|
var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
|
||||||
|
await _inner.WriteAsync(frame, ct);
|
||||||
|
}
|
||||||
|
await _inner.FlushAsync(ct);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Sends a WebSocket close frame.
|
||||||
|
/// </summary>
|
||||||
|
public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
|
||||||
|
{
|
||||||
|
var status = WsFrameWriter.MapCloseStatus(reason);
|
||||||
|
var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
|
||||||
|
var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
|
||||||
|
await _inner.WriteAsync(frame, ct);
|
||||||
|
await _inner.FlushAsync(ct);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int DrainReadQueue(Span<byte> buffer)
|
||||||
|
{
|
||||||
|
int written = 0;
|
||||||
|
while (_readQueue.Count > 0 && written < buffer.Length)
|
||||||
|
{
|
||||||
|
var current = _readQueue.Peek();
|
||||||
|
int available = current.Length - _readOffset;
|
||||||
|
int toCopy = Math.Min(available, buffer.Length - written);
|
||||||
|
current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
|
||||||
|
written += toCopy;
|
||||||
|
_readOffset += toCopy;
|
||||||
|
if (_readOffset >= current.Length)
|
||||||
|
{
|
||||||
|
_readQueue.Dequeue();
|
||||||
|
_readOffset = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return written;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream abstract members
|
||||||
|
public override bool CanRead => true;
|
||||||
|
public override bool CanWrite => true;
|
||||||
|
public override bool CanSeek => false;
|
||||||
|
public override long Length => throw new NotSupportedException();
|
||||||
|
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
|
||||||
|
public override void Flush() => _inner.Flush();
|
||||||
|
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
|
||||||
|
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
|
||||||
|
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
|
||||||
|
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
|
||||||
|
public override void SetLength(long value) => throw new NotSupportedException();
|
||||||
|
|
||||||
|
protected override void Dispose(bool disposing)
|
||||||
|
{
|
||||||
|
if (disposing)
|
||||||
|
_inner.Dispose();
|
||||||
|
base.Dispose(disposing);
|
||||||
|
}
|
||||||
|
}
|
||||||
88
tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
Normal file
88
tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
using System.Buffers.Binary;
|
||||||
|
using NATS.Server.WebSocket;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.WebSocket;
|
||||||
|
|
||||||
|
public class WsConnectionTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
|
||||||
|
{
|
||||||
|
var payload = "SUB test 1\r\n"u8.ToArray();
|
||||||
|
var frame = BuildUnmaskedFrame(payload);
|
||||||
|
var inner = new MemoryStream(frame);
|
||||||
|
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||||
|
|
||||||
|
var buf = new byte[256];
|
||||||
|
int n = await ws.ReadAsync(buf);
|
||||||
|
|
||||||
|
n.ShouldBe(payload.Length);
|
||||||
|
buf[..n].ShouldBe(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task WriteAsync_FramesPayload()
|
||||||
|
{
|
||||||
|
var inner = new MemoryStream();
|
||||||
|
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||||
|
|
||||||
|
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
|
||||||
|
await ws.WriteAsync(payload);
|
||||||
|
await ws.FlushAsync();
|
||||||
|
|
||||||
|
inner.Position = 0;
|
||||||
|
var written = inner.ToArray();
|
||||||
|
// First 2 bytes should be WS frame header
|
||||||
|
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||||
|
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||||
|
int len = written[1] & 0x7F;
|
||||||
|
len.ShouldBe(payload.Length);
|
||||||
|
written[2..].ShouldBe(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task WriteAsync_WithCompression_CompressesLargePayload()
|
||||||
|
{
|
||||||
|
var inner = new MemoryStream();
|
||||||
|
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||||
|
|
||||||
|
var payload = new byte[200];
|
||||||
|
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
|
||||||
|
await ws.WriteAsync(payload);
|
||||||
|
await ws.FlushAsync();
|
||||||
|
|
||||||
|
inner.Position = 0;
|
||||||
|
var written = inner.ToArray();
|
||||||
|
// RSV1 bit should be set for compressed frame
|
||||||
|
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||||
|
// Compressed size should be less than original
|
||||||
|
written.Length.ShouldBeLessThan(payload.Length + 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
|
||||||
|
{
|
||||||
|
var inner = new MemoryStream();
|
||||||
|
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||||
|
|
||||||
|
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
|
||||||
|
await ws.WriteAsync(payload);
|
||||||
|
await ws.FlushAsync();
|
||||||
|
|
||||||
|
inner.Position = 0;
|
||||||
|
var written = inner.ToArray();
|
||||||
|
// RSV1 bit should NOT be set for small payloads
|
||||||
|
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte[] BuildUnmaskedFrame(byte[] payload)
|
||||||
|
{
|
||||||
|
var header = new byte[2];
|
||||||
|
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
|
||||||
|
header[1] = (byte)payload.Length;
|
||||||
|
var frame = new byte[2 + payload.Length];
|
||||||
|
header.CopyTo(frame, 0);
|
||||||
|
payload.CopyTo(frame, 2);
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user