Files
natsdotnet/src/NATS.Server/LeafNodes/WebSocketStreamAdapter.cs

249 lines
8.5 KiB
C#

using SystemWebSocket = System.Net.WebSockets.WebSocket;
using System.Net.WebSockets;
namespace NATS.Server.LeafNodes;
/// <summary>
/// Adapts a <see cref="System.Net.WebSockets.WebSocket"/> into a <see cref="Stream"/> suitable for use
/// by LeafConnection. Reads reassemble each WebSocket message (which may arrive in several
/// fragments) and expose its payload as part of a contiguous byte stream; each write is sent
/// as a single binary WebSocket message. Only the asynchronous read/write members are supported.
/// Go reference: leafnode.go wsCreateLeafConnection, client.go wsRead/wsWrite.
/// </summary>
public sealed class WebSocketStreamAdapter : Stream
{
    // Smallest staging buffer the adapter will allocate.
    private const int MinBufferSize = 64;

    // Free space guaranteed in the staging buffer before each ReceiveAsync call.
    private const int MinReceiveSpace = 1024;

    private readonly SystemWebSocket _ws;

    // Staging buffer for the most recently received message. Bytes in
    // [_readOffset, _readOffset + _readCount) have not yet been handed to a caller.
    private byte[] _readBuffer;
    private int _readOffset;
    private int _readCount;

    // Set once a Close frame has been received so that later reads keep reporting
    // end-of-stream instead of calling ReceiveAsync again after the peer has closed.
    private bool _closeReceived;
    private bool _disposed;

    /// <summary>
    /// Creates a stream adapter for a WebSocket-backed leaf-node transport.
    /// </summary>
    /// <param name="ws">
    /// WebSocket transport used for framed binary I/O. The adapter disposes it when the
    /// adapter itself is disposed.
    /// </param>
    /// <param name="initialBufferSize">
    /// Initial receive staging-buffer size; values below 64 are raised to 64. The buffer
    /// grows as larger messages arrive.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="ws"/> is <see langword="null"/>.</exception>
    public WebSocketStreamAdapter(SystemWebSocket ws, int initialBufferSize = 4096)
    {
        _ws = ws ?? throw new ArgumentNullException(nameof(ws));
        _readBuffer = new byte[Math.Max(initialBufferSize, MinBufferSize)];
        _readOffset = 0;
        _readCount = 0;
    }

    // Stream capability overrides

    /// <inheritdoc />
    /// <remarks>Returns <see langword="false"/> once the adapter has been disposed, per the <see cref="Stream"/> contract.</remarks>
    public override bool CanRead => !_disposed;

    /// <inheritdoc />
    /// <remarks>Returns <see langword="false"/> once the adapter has been disposed, per the <see cref="Stream"/> contract.</remarks>
    public override bool CanWrite => !_disposed;

    /// <inheritdoc />
    public override bool CanSeek => false;

    // Telemetry properties

    /// <summary>Whether the underlying WebSocket is currently open.</summary>
    public bool IsConnected => _ws.State == WebSocketState.Open;

    /// <summary>Total payload bytes of completed WebSocket messages received.</summary>
    public long BytesRead { get; private set; }

    /// <summary>Total bytes written to outbound WebSocket messages.</summary>
    public long BytesWritten { get; private set; }

    /// <summary>Total completed WebSocket messages read (including empty ones).</summary>
    public int MessagesRead { get; private set; }

    /// <summary>Total completed WebSocket messages written.</summary>
    public int MessagesWritten { get; private set; }

    /// <inheritdoc />
    /// <remarks>Delegates to <see cref="ReadAsync(Memory{byte}, CancellationToken)"/> after validating the arguments.</remarks>
    public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
    {
        ValidateBufferArguments(buffer, offset, count);
        return await ReadAsync(buffer.AsMemory(offset, count), ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Copies bytes from the current WebSocket message into <paramref name="buffer"/>, receiving
    /// the next message when all staged bytes have been consumed.
    /// </summary>
    /// <param name="buffer">Destination for the read bytes.</param>
    /// <param name="ct">Token that cancels the underlying receive.</param>
    /// <returns>
    /// The number of bytes copied. Returns 0 when <paramref name="buffer"/> is empty or once the
    /// peer has sent a Close frame; every later call also returns 0.
    /// </returns>
    /// <remarks>
    /// Empty WebSocket messages are skipped, because a 0 return means end-of-stream to
    /// <see cref="Stream"/> consumers. The code does not distinguish text from binary messages;
    /// both are surfaced as raw bytes.
    /// </remarks>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
    {
        ObjectDisposedException.ThrowIf(_disposed, this);

        // Receive only when nothing is staged; loop so empty messages are skipped.
        while (_readCount == 0)
        {
            if (_closeReceived || !await ReceiveMessageAsync(ct).ConfigureAwait(false))
                return 0;
        }

        // Hand out what fits; the remainder stays staged for the next read.
        var toCopy = Math.Min(_readCount, buffer.Length);
        _readBuffer.AsMemory(_readOffset, toCopy).CopyTo(buffer);
        _readOffset += toCopy;
        _readCount -= toCopy;
        if (_readCount == 0)
            _readOffset = 0;
        return toCopy;
    }

    /// <inheritdoc />
    /// <remarks>Delegates to <see cref="WriteAsync(ReadOnlyMemory{byte}, CancellationToken)"/> after validating the arguments.</remarks>
    public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct)
    {
        ValidateBufferArguments(buffer, offset, count);
        await WriteAsync(buffer.AsMemory(offset, count), ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Sends <paramref name="buffer"/> as one complete binary WebSocket message.
    /// </summary>
    /// <param name="buffer">Payload of the message.</param>
    /// <param name="ct">Token that cancels the send.</param>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
    {
        ObjectDisposedException.ThrowIf(_disposed, this);
        await _ws.SendAsync(
            buffer,
            WebSocketMessageType.Binary,
            endOfMessage: true,
            ct).ConfigureAwait(false);
        BytesWritten += buffer.Length;
        MessagesWritten++;
    }

    /// <inheritdoc/>
    /// <remarks>No-op: every write is already sent as a complete message.</remarks>
    public override Task FlushAsync(CancellationToken ct) => Task.CompletedTask;

    // Not-supported synchronous and seeking members

    /// <inheritdoc />
    public override long Length => throw new NotSupportedException();

    /// <inheritdoc />
    public override long Position
    {
        get => throw new NotSupportedException();
        set => throw new NotSupportedException();
    }

    /// <inheritdoc />
    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

    /// <inheritdoc />
    public override void SetLength(long value) => throw new NotSupportedException();

    /// <inheritdoc />
    public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use async methods");

    /// <inheritdoc />
    public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use async methods");

    /// <inheritdoc />
    public override void Flush() { }

    /// <inheritdoc />
    /// <remarks>Disposes the wrapped WebSocket when called with <paramref name="disposing"/> = true; repeated calls are ignored.</remarks>
    protected override void Dispose(bool disposing)
    {
        if (_disposed)
            return;
        _disposed = true;
        if (disposing)
            _ws.Dispose();
        base.Dispose(disposing);
    }

    // -------------------------------------------------------------------------
    // Helpers
    // -------------------------------------------------------------------------

    /// <summary>
    /// Receives one complete WebSocket message into the staging buffer, which must be empty on
    /// entry. On success, the message occupies <c>[0, _readCount)</c> and the read counters are
    /// updated.
    /// </summary>
    /// <returns><see langword="false"/> if a Close frame arrived instead of a data message.</returns>
    private async ValueTask<bool> ReceiveMessageAsync(CancellationToken ct)
    {
        var total = 0;
        while (true)
        {
            // Keep the fragments received so far while making room for the next one.
            EnsureReceiveCapacity(total);
            var result = await _ws.ReceiveAsync(_readBuffer.AsMemory(total), ct).ConfigureAwait(false);

            if (result.MessageType == WebSocketMessageType.Close)
            {
                _closeReceived = true;
                _readOffset = 0;
                _readCount = 0;
                return false;
            }

            total += result.Count;
            if (result.EndOfMessage)
            {
                MessagesRead++;
                BytesRead += total;
                _readOffset = 0;
                _readCount = total;
                return true;
            }
        }
    }

    /// <summary>
    /// Grows the staging buffer so that at least <see cref="MinReceiveSpace"/> bytes are free
    /// after the first <paramref name="preserved"/> bytes. Those bytes are copied into the new
    /// buffer, so a partially received message survives the resize.
    /// </summary>
    private void EnsureReceiveCapacity(int preserved)
    {
        var required = checked(preserved + MinReceiveSpace);
        if (_readBuffer.Length >= required)
            return;

        // Double to amortize growth, but never ask for more than the runtime's array limit.
        var doubled = (int)Math.Min((long)_readBuffer.Length * 2, Array.MaxLength);
        var next = new byte[Math.Max(required, doubled)];
        _readBuffer.AsSpan(0, preserved).CopyTo(next);
        _readBuffer = next;
    }
}