diff --git a/src/NATS.Server/WebSocket/WsConnection.cs b/src/NATS.Server/WebSocket/WsConnection.cs new file mode 100644 index 0000000..61e51e7 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsConnection.cs @@ -0,0 +1,192 @@ +namespace NATS.Server.WebSocket; + +/// +/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O. +/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged. +/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern. +/// +public sealed class WsConnection : Stream +{ + private readonly Stream _inner; + private readonly bool _compress; + private readonly bool _maskRead; + private readonly bool _maskWrite; + private readonly bool _browser; + private readonly bool _noCompFrag; + private WsReadInfo _readInfo; + private readonly Queue _readQueue = new(); + private int _readOffset; + private readonly object _writeLock = new(); + private readonly List _pendingControlWrites = []; + + public bool CloseReceived => _readInfo.CloseReceived; + public int CloseStatus => _readInfo.CloseStatus; + + public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag) + { + _inner = inner; + _compress = compress; + _maskRead = maskRead; + _maskWrite = maskWrite; + _browser = browser; + _noCompFrag = noCompFrag; + _readInfo = new WsReadInfo(expectMask: maskRead); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) + { + // Drain any buffered decoded payloads first + if (_readQueue.Count > 0) + return DrainReadQueue(buffer.Span); + + // Read raw bytes from inner stream + var rawBuf = new byte[Math.Max(buffer.Length, 4096)]; + int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct); + if (bytesRead == 0) return 0; + + // Decode frames + var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); + + // Collect control frame responses + if (_readInfo.PendingControlFrames.Count > 0) + { + lock (_writeLock) + _pendingControlWrites.AddRange(_readInfo.PendingControlFrames); + _readInfo.PendingControlFrames.Clear(); + // Write pending control frames + await FlushControlFramesAsync(ct); + } + + if (_readInfo.CloseReceived) + return 0; + + foreach (var payload in payloads) + _readQueue.Enqueue(payload); + + if (_readQueue.Count == 0) + return 0; + + return DrainReadQueue(buffer.Span); + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ct = default) + { + var data = buffer.Span; + + if (_compress && data.Length > WsConstants.CompressThreshold) + { + var compressed = WsCompression.Compress(data); + await WriteFramedAsync(compressed, compressed: true, ct); + } + else + { + await WriteFramedAsync(data.ToArray(), compressed: false, ct); + } + } + + private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct) + { + if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed)) + { + // Fragment for browsers + int offset = 0; + bool first = true; + while (offset < payload.Length) + { + int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset); + bool final = offset + chunkLen >= payload.Length; + var fh = new byte[WsConstants.MaxFrameHeaderSize]; + var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite, + first: first, final: final, compressed: first && compressed, + opcode: WsConstants.BinaryMessage, payloadLength: chunkLen); + + var chunk = payload.AsSpan(offset, chunkLen).ToArray(); + if (_maskWrite && key != null) + WsFrameWriter.MaskBuf(key, chunk); + + await _inner.WriteAsync(fh.AsMemory(0, n), ct); + await _inner.WriteAsync(chunk.AsMemory(), ct); + offset += chunkLen; + first = false; + } + } + else + { + var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length); + if (_maskWrite && key != null) + WsFrameWriter.MaskBuf(key, payload); + await _inner.WriteAsync(header.AsMemory(), ct); + await _inner.WriteAsync(payload.AsMemory(), ct); + } + } + + private async Task FlushControlFramesAsync(CancellationToken ct) + { + List toWrite; + lock (_writeLock) + { + if (_pendingControlWrites.Count == 0) return; + toWrite = [.. _pendingControlWrites]; + _pendingControlWrites.Clear(); + } + + foreach (var action in toWrite) + { + var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite); + await _inner.WriteAsync(frame, ct); + } + await _inner.FlushAsync(ct); + } + + /// + /// Sends a WebSocket close frame. + /// + public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default) + { + var status = WsFrameWriter.MapCloseStatus(reason); + var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString()); + var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite); + await _inner.WriteAsync(frame, ct); + await _inner.FlushAsync(ct); + } + + private int DrainReadQueue(Span buffer) + { + int written = 0; + while (_readQueue.Count > 0 && written < buffer.Length) + { + var current = _readQueue.Peek(); + int available = current.Length - _readOffset; + int toCopy = Math.Min(available, buffer.Length - written); + current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]); + written += toCopy; + _readOffset += toCopy; + if (_readOffset >= current.Length) + { + _readQueue.Dequeue(); + _readOffset = 0; + } + } + return written; + } + + // Stream abstract members + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override void Flush() => _inner.Flush(); + public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct); + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync"); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync"); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (disposing) + _inner.Dispose(); + base.Dispose(disposing); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs b/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs new file mode 100644 index 0000000..2955b1d --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs @@ -0,0 +1,88 @@ +using System.Buffers.Binary; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsConnectionTests +{ + [Fact] + public async Task ReadAsync_DecodesFrameAndReturnsPayload() + { + var payload = "SUB test 1\r\n"u8.ToArray(); + var frame = BuildUnmaskedFrame(payload); + var inner = new MemoryStream(frame); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + + n.ShouldBe(payload.Length); + buf[..n].ShouldBe(payload); + } + + [Fact] + public async Task WriteAsync_FramesPayload() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray(); + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // First 2 bytes should be WS frame header + (written[0] & WsConstants.FinalBit).ShouldNotBe(0); + (written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage); + int len = written[1] & 0x7F; + len.ShouldBe(payload.Length); + written[2..].ShouldBe(payload); + } + + [Fact] + public async Task WriteAsync_WithCompression_CompressesLargePayload() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = new byte[200]; + Array.Fill(payload, 0x41); // 'A' repeated - very compressible + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // RSV1 bit should be set for compressed frame + (written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0); + // Compressed size should be less than original + written.Length.ShouldBeLessThan(payload.Length + 10); + } + + [Fact] + public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = "Hi"u8.ToArray(); // Below CompressThreshold + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // RSV1 bit should NOT be set for small payloads + (written[0] & WsConstants.Rsv1Bit).ShouldBe(0); + } + + private static byte[] BuildUnmaskedFrame(byte[] payload) + { + var header = new byte[2]; + header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage); + header[1] = (byte)payload.Length; + var frame = new byte[2 + payload.Length]; + header.CopyTo(frame, 0); + payload.CopyTo(frame, 2); + return frame; + } +}