249 lines
8.5 KiB
C#
249 lines
8.5 KiB
C#
using SystemWebSocket = System.Net.WebSockets.WebSocket;
|
|
using System.Net.WebSockets;
|
|
|
|
namespace NATS.Server.LeafNodes;
|
|
|
|
/// <summary>
/// Adapts a System.Net.WebSockets.WebSocket into a <see cref="Stream"/> suitable for use
/// by LeafConnection.
/// <para>
/// Reads: each incoming WebSocket message is staged in an internal buffer and handed to callers
/// as a contiguous byte stream. A single read never returns bytes from more than one message, so
/// short reads are normal. Empty messages are skipped. A received Close frame is reported as end of
/// stream (a return value of 0) on that read and on every later read.
/// </para>
/// <para>
/// Writes: every non-empty write is sent as exactly one complete binary WebSocket message;
/// zero-length writes send nothing. <c>FlushAsync</c> is therefore a no-op.
/// </para>
/// <para>
/// The adapter performs no synchronization of its own. Callers should not issue overlapping reads
/// or overlapping writes.
/// </para>
/// Go reference: leafnode.go wsCreateLeafConnection, client.go wsRead/wsWrite.
/// </summary>
public sealed class WebSocketStreamAdapter : Stream
{
    // Smallest staging buffer ever allocated, regardless of the requested initial size.
    private const int MinBufferSize = 64;

    // Free space guaranteed in the staging buffer before each ReceiveAsync call.
    private const int ReceiveChunkSize = 1024;

    private readonly SystemWebSocket _ws;

    // Staging buffer holding the most recently received message.
    // NOTE(review): it grows to the size of the largest message seen and is never shrunk;
    // there is no upper bound on accepted message size.
    private byte[] _readBuffer;

    // Start index and length of the not-yet-consumed bytes in _readBuffer.
    private int _readOffset;
    private int _readCount;

    // Set once a Close frame has been received; every later read reports end of stream.
    private bool _closeReceived;

    private bool _disposed;

    /// <summary>
    /// Creates a stream adapter for a WebSocket-backed leaf-node transport.
    /// </summary>
    /// <param name="ws">WebSocket transport used for framed binary I/O.</param>
    /// <param name="initialBufferSize">
    /// Initial receive staging-buffer size; values below 64 are raised to 64. The buffer grows on demand.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="ws"/> is null.</exception>
    public WebSocketStreamAdapter(SystemWebSocket ws, int initialBufferSize = 4096)
    {
        _ws = ws ?? throw new ArgumentNullException(nameof(ws));
        _readBuffer = new byte[Math.Max(initialBufferSize, MinBufferSize)];
        _readOffset = 0;
        _readCount = 0;
    }

    // Stream capability overrides. Per the Stream contract, CanRead/CanWrite report false once closed.

    /// <inheritdoc />
    public override bool CanRead => !_disposed;

    /// <inheritdoc />
    public override bool CanWrite => !_disposed;

    /// <inheritdoc />
    public override bool CanSeek => false;

    // Telemetry properties

    /// <summary>Whether the underlying WebSocket is currently open.</summary>
    public bool IsConnected => _ws.State == WebSocketState.Open;

    /// <summary>Total bytes read from received WebSocket messages.</summary>
    public long BytesRead { get; private set; }

    /// <summary>Total bytes written to outbound WebSocket messages.</summary>
    public long BytesWritten { get; private set; }

    /// <summary>Total completed WebSocket messages read (including empty ones).</summary>
    public int MessagesRead { get; private set; }

    /// <summary>Total completed WebSocket messages written.</summary>
    public int MessagesWritten { get; private set; }

    /// <inheritdoc />
    public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
    {
        // Validate up front so invalid arguments can never consume a message from the socket.
        ValidateBufferArguments(buffer, offset, count);
        return ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
    }

    /// <inheritdoc />
    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
    {
        ObjectDisposedException.ThrowIf(_disposed, this);

        // Hand out leftovers from the previous message before touching the socket.
        if (_readCount > 0)
            return DrainStagedBytes(buffer);

        // After a Close frame the socket can no longer receive; keep reporting end of stream.
        if (_closeReceived)
            return 0;

        while (true)
        {
            var messageLength = await ReceiveMessageAsync(ct).ConfigureAwait(false);

            if (messageLength < 0)
            {
                _closeReceived = true;
                return 0;
            }

            // An empty message carries no stream data; returning 0 here would look like EOF.
            if (messageLength == 0)
                continue;

            _readOffset = 0;
            _readCount = messageLength;
            return DrainStagedBytes(buffer);
        }
    }

    /// <inheritdoc />
    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct)
    {
        ValidateBufferArguments(buffer, offset, count);
        return WriteAsync(buffer.AsMemory(offset, count), ct).AsTask();
    }

    /// <inheritdoc />
    public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
    {
        ObjectDisposedException.ThrowIf(_disposed, this);

        // A zero-length Stream write is a no-op. Sending an empty frame would make readers
        // that map "0 bytes" to end-of-stream think the connection ended.
        if (buffer.IsEmpty)
            return;

        await _ws.SendAsync(
            buffer,
            WebSocketMessageType.Binary,
            endOfMessage: true,
            ct).ConfigureAwait(false);

        BytesWritten += buffer.Length;
        MessagesWritten++;
    }

    /// <inheritdoc />
    public override Task FlushAsync(CancellationToken ct) => Task.CompletedTask;

    // Not-supported synchronous and seeking members

    /// <inheritdoc />
    public override long Length => throw new NotSupportedException();

    /// <inheritdoc />
    public override long Position
    {
        get => throw new NotSupportedException();
        set => throw new NotSupportedException();
    }

    /// <inheritdoc />
    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

    /// <inheritdoc />
    public override void SetLength(long value) => throw new NotSupportedException();

    /// <inheritdoc />
    public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use async methods");

    /// <inheritdoc />
    public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use async methods");

    /// <inheritdoc />
    public override void Flush() { }

    /// <inheritdoc />
    /// <remarks>Disposes the underlying WebSocket directly; no close handshake is sent from here.</remarks>
    protected override void Dispose(bool disposing)
    {
        if (_disposed)
            return;

        _disposed = true;

        if (disposing)
            _ws.Dispose();

        base.Dispose(disposing);
    }

    // -------------------------------------------------------------------------
    // Helpers
    // -------------------------------------------------------------------------

    // Copies as many staged bytes as fit into destination and advances the staging cursor.
    private int DrainStagedBytes(Memory<byte> destination)
    {
        var n = Math.Min(_readCount, destination.Length);
        _readBuffer.AsSpan(_readOffset, n).CopyTo(destination.Span);
        _readOffset += n;
        _readCount -= n;

        if (_readCount == 0)
            _readOffset = 0;

        return n;
    }

    // Receives one complete message into _readBuffer[0..length) and returns its length,
    // or -1 if a Close frame arrived instead. It writes from index 0, so it must only be
    // called while no staged bytes remain (_readCount == 0).
    private async ValueTask<int> ReceiveMessageAsync(CancellationToken ct)
    {
        var received = 0;

        while (true)
        {
            // Grow if needed while preserving the fragments already received for this message.
            EnsureReadBufferCapacity(received + ReceiveChunkSize, received);

            var result = await _ws.ReceiveAsync(
                _readBuffer.AsMemory(received),
                ct).ConfigureAwait(false);

            if (result.MessageType == WebSocketMessageType.Close)
                return -1;

            received += result.Count;

            if (result.EndOfMessage)
            {
                MessagesRead++;
                BytesRead += received;
                return received;
            }
        }
    }

    // Ensures _readBuffer holds at least `required` bytes, keeping its first `preservedBytes` bytes.
    private void EnsureReadBufferCapacity(int required, int preservedBytes)
    {
        if (_readBuffer.Length >= required)
            return;

        // Double to amortize growth; computed in long so very large buffers cannot overflow int.
        var doubled = (int)Math.Min((long)_readBuffer.Length * 2, Array.MaxLength);
        var next = new byte[Math.Max(required, doubled)];
        _readBuffer.AsSpan(0, preservedBytes).CopyTo(next);
        _readBuffer = next;
    }
}
|