fix: address code review findings for WebSocket implementation

- Convert WsReadInfo from mutable struct to class (prevents silent copy bugs)
- Add handshake timeout enforcement via CancellationToken in WsUpgrade
- Use buffered reading (512 bytes) in ReadHttpRequestAsync instead of byte-at-a-time
- Add IAsyncDisposable to WsConnection for proper async cleanup
- Simplify redundant mask bit check in WsReadInfo
- Remove unused WsGuid and CompressLastBlock dead code from WsConstants
- Document single-reader assumption on WsConnection read-side state
This commit is contained in:
Joseph Doherty
2026-02-23 05:27:36 -05:00
parent 5fd2cf040d
commit 18a6d0f478
6 changed files with 42 additions and 33 deletions

View File

@@ -14,6 +14,7 @@ public sealed class WsConnection : Stream
private readonly bool _browser;
private readonly bool _noCompFrag;
private WsReadInfo _readInfo;
// Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed)
private readonly Queue<byte[]> _readQueue = new();
private int _readOffset;
private readonly object _writeLock = new();
@@ -47,7 +48,7 @@ public sealed class WsConnection : Stream
if (bytesRead == 0) return 0;
// Decode frames
var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
// Collect control frame responses
if (_readInfo.PendingControlFrames.Count > 0)
@@ -192,4 +193,10 @@ public sealed class WsConnection : Stream
_inner.Dispose();
base.Dispose(disposing);
}
public override async ValueTask DisposeAsync()
{
await _inner.DisposeAsync();
GC.SuppressFinalize(this);
}
}

View File

@@ -58,13 +58,7 @@ public static class WsConstants
public const string LeafNodePath = "/leafnode";
public const string MqttPath = "/mqtt";
// WebSocket GUID (RFC 6455 Section 1.3)
public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray();
// Compression trailer (RFC 7692 Section 7.2.2)
public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff];
// Decompression trailer appended before decompressing
// Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2)
public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];
public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;

View File

@@ -7,7 +7,7 @@ namespace NATS.Server.WebSocket;
/// Per-connection WebSocket frame reading state machine.
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
/// </summary>
public struct WsReadInfo
public class WsReadInfo
{
public int Remaining;
public bool FrameStart;
@@ -97,7 +97,7 @@ public struct WsReadInfo
/// Returns list of decoded payload byte arrays.
/// Ported from websocket.go lines 208-351.
/// </summary>
public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload)
public static List<byte[]> ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload)
{
var bufs = new List<byte[]>();
var buf = new byte[available];
@@ -184,8 +184,8 @@ public struct WsReadInfo
}
}
// Read mask key
if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0)
// Read mask key (mask bit already validated at line 134)
if (r.ExpectMask)
{
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
pos = p2;
@@ -196,7 +196,7 @@ public struct WsReadInfo
// Handle control frames
if (WsConstants.IsControlFrame(frameType))
{
pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max);
pos = HandleControlFrame(r, frameType, stream, buf, pos, max);
continue;
}
@@ -243,7 +243,7 @@ public struct WsReadInfo
return bufs;
}
private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
{
byte[]? payload = null;
if (r.Remaining > 0)

View File

@@ -11,11 +11,14 @@ namespace NATS.Server.WebSocket;
public static class WsUpgrade
{
public static async Task<WsUpgradeResult> TryUpgradeAsync(
Stream inputStream, Stream outputStream, WebSocketOptions options)
Stream inputStream, Stream outputStream, WebSocketOptions options,
CancellationToken ct = default)
{
try
{
var (method, path, headers) = await ReadHttpRequestAsync(inputStream);
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(options.HandshakeTimeout);
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
return await FailAsync(outputStream, 405, "request method must be GET");
@@ -165,22 +168,27 @@ public static class WsUpgrade
return WsUpgradeResult.Failed;
}
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
Stream stream, CancellationToken ct)
{
var headerBytes = new List<byte>(4096);
var buf = new byte[1];
var buf = new byte[512];
while (true)
{
int n = await stream.ReadAsync(buf);
int n = await stream.ReadAsync(buf, ct);
if (n == 0) throw new IOException("connection closed during handshake");
headerBytes.Add(buf[0]);
if (headerBytes.Count >= 4 &&
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
break;
if (headerBytes.Count > 8192)
throw new InvalidOperationException("HTTP header too large");
for (int i = 0; i < n; i++)
{
headerBytes.Add(buf[i]);
if (headerBytes.Count >= 4 &&
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
goto done;
if (headerBytes.Count > 8192)
throw new InvalidOperationException("HTTP header too large");
}
}
done:;
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
var lines = text.Split("\r\n", StringSplitOptions.None);