fix: address code review findings for WebSocket implementation

- Convert WsReadInfo from mutable struct to class (prevents silent copy bugs)
- Add handshake timeout enforcement via CancellationToken in WsUpgrade
- Use buffered reading (512 bytes) in ReadHttpRequestAsync instead of byte-at-a-time
- Add IAsyncDisposable to WsConnection for proper async cleanup
- Simplify redundant mask bit check in WsReadInfo
- Remove unused WsGuid and CompressLastBlock dead code from WsConstants
- Document single-reader assumption on WsConnection read-side state
This commit is contained in:
Joseph Doherty
2026-02-23 05:27:36 -05:00
parent 5fd2cf040d
commit 18a6d0f478
6 changed files with 42 additions and 33 deletions

View File

@@ -568,7 +568,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
} }
// HTTP upgrade handshake // HTTP upgrade handshake
var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket); var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct);
if (!upgradeResult.Success) if (!upgradeResult.Success)
{ {
_logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId); _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);

View File

@@ -14,6 +14,7 @@ public sealed class WsConnection : Stream
private readonly bool _browser; private readonly bool _browser;
private readonly bool _noCompFrag; private readonly bool _noCompFrag;
private WsReadInfo _readInfo; private WsReadInfo _readInfo;
// Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed)
private readonly Queue<byte[]> _readQueue = new(); private readonly Queue<byte[]> _readQueue = new();
private int _readOffset; private int _readOffset;
private readonly object _writeLock = new(); private readonly object _writeLock = new();
@@ -47,7 +48,7 @@ public sealed class WsConnection : Stream
if (bytesRead == 0) return 0; if (bytesRead == 0) return 0;
// Decode frames // Decode frames
var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
// Collect control frame responses // Collect control frame responses
if (_readInfo.PendingControlFrames.Count > 0) if (_readInfo.PendingControlFrames.Count > 0)
@@ -192,4 +193,10 @@ public sealed class WsConnection : Stream
_inner.Dispose(); _inner.Dispose();
base.Dispose(disposing); base.Dispose(disposing);
} }
public override async ValueTask DisposeAsync()
{
await _inner.DisposeAsync();
GC.SuppressFinalize(this);
}
} }

View File

@@ -58,13 +58,7 @@ public static class WsConstants
public const string LeafNodePath = "/leafnode"; public const string LeafNodePath = "/leafnode";
public const string MqttPath = "/mqtt"; public const string MqttPath = "/mqtt";
// WebSocket GUID (RFC 6455 Section 1.3) // Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2)
public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray();
// Compression trailer (RFC 7692 Section 7.2.2)
public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff];
// Decompression trailer appended before decompressing
public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff]; public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];
public static bool IsControlFrame(int opcode) => opcode >= CloseMessage; public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;

View File

@@ -7,7 +7,7 @@ namespace NATS.Server.WebSocket;
/// Per-connection WebSocket frame reading state machine. /// Per-connection WebSocket frame reading state machine.
/// Ported from golang/nats-server/server/websocket.go lines 156-506. /// Ported from golang/nats-server/server/websocket.go lines 156-506.
/// </summary> /// </summary>
public struct WsReadInfo public class WsReadInfo
{ {
public int Remaining; public int Remaining;
public bool FrameStart; public bool FrameStart;
@@ -97,7 +97,7 @@ public struct WsReadInfo
/// Returns list of decoded payload byte arrays. /// Returns list of decoded payload byte arrays.
/// Ported from websocket.go lines 208-351. /// Ported from websocket.go lines 208-351.
/// </summary> /// </summary>
public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload) public static List<byte[]> ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload)
{ {
var bufs = new List<byte[]>(); var bufs = new List<byte[]>();
var buf = new byte[available]; var buf = new byte[available];
@@ -184,8 +184,8 @@ public struct WsReadInfo
} }
} }
// Read mask key // Read mask key (mask bit already validated at line 134)
if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0) if (r.ExpectMask)
{ {
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4); var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
pos = p2; pos = p2;
@@ -196,7 +196,7 @@ public struct WsReadInfo
// Handle control frames // Handle control frames
if (WsConstants.IsControlFrame(frameType)) if (WsConstants.IsControlFrame(frameType))
{ {
pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max); pos = HandleControlFrame(r, frameType, stream, buf, pos, max);
continue; continue;
} }
@@ -243,7 +243,7 @@ public struct WsReadInfo
return bufs; return bufs;
} }
private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max) private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
{ {
byte[]? payload = null; byte[]? payload = null;
if (r.Remaining > 0) if (r.Remaining > 0)

View File

@@ -11,11 +11,14 @@ namespace NATS.Server.WebSocket;
public static class WsUpgrade public static class WsUpgrade
{ {
public static async Task<WsUpgradeResult> TryUpgradeAsync( public static async Task<WsUpgradeResult> TryUpgradeAsync(
Stream inputStream, Stream outputStream, WebSocketOptions options) Stream inputStream, Stream outputStream, WebSocketOptions options,
CancellationToken ct = default)
{ {
try try
{ {
var (method, path, headers) = await ReadHttpRequestAsync(inputStream); using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(options.HandshakeTimeout);
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase)) if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
return await FailAsync(outputStream, 405, "request method must be GET"); return await FailAsync(outputStream, 405, "request method must be GET");
@@ -165,22 +168,27 @@ public static class WsUpgrade
return WsUpgradeResult.Failed; return WsUpgradeResult.Failed;
} }
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream) private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
Stream stream, CancellationToken ct)
{ {
var headerBytes = new List<byte>(4096); var headerBytes = new List<byte>(4096);
var buf = new byte[1]; var buf = new byte[512];
while (true) while (true)
{ {
int n = await stream.ReadAsync(buf); int n = await stream.ReadAsync(buf, ct);
if (n == 0) throw new IOException("connection closed during handshake"); if (n == 0) throw new IOException("connection closed during handshake");
headerBytes.Add(buf[0]); for (int i = 0; i < n; i++)
if (headerBytes.Count >= 4 && {
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' && headerBytes.Add(buf[i]);
headerBytes[^2] == '\r' && headerBytes[^1] == '\n') if (headerBytes.Count >= 4 &&
break; headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
if (headerBytes.Count > 8192) headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
throw new InvalidOperationException("HTTP header too large"); goto done;
if (headerBytes.Count > 8192)
throw new InvalidOperationException("HTTP header too large");
}
} }
done:;
var text = Encoding.ASCII.GetString(headerBytes.ToArray()); var text = Encoding.ASCII.GetString(headerBytes.ToArray());
var lines = text.Split("\r\n", StringSplitOptions.None); var lines = text.Split("\r\n", StringSplitOptions.None);

View File

@@ -62,7 +62,7 @@ public class WsFrameReadTests
var readInfo = new WsReadInfo(expectMask: false); var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame); var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1); result.Count.ShouldBe(1);
result[0].ShouldBe(payload); result[0].ShouldBe(payload);
@@ -77,7 +77,7 @@ public class WsFrameReadTests
var readInfo = new WsReadInfo(expectMask: true); var readInfo = new WsReadInfo(expectMask: true);
var stream = new MemoryStream(frame); var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1); result.Count.ShouldBe(1);
result[0].ShouldBe(payload); result[0].ShouldBe(payload);
@@ -92,7 +92,7 @@ public class WsFrameReadTests
var readInfo = new WsReadInfo(expectMask: false); var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame); var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1); result.Count.ShouldBe(1);
result[0].ShouldBe(payload); result[0].ShouldBe(payload);
@@ -105,7 +105,7 @@ public class WsFrameReadTests
var readInfo = new WsReadInfo(expectMask: false); var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame); var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0); // control frames don't produce payload result.Count.ShouldBe(0); // control frames don't produce payload
readInfo.PendingControlFrames.Count.ShouldBe(1); readInfo.PendingControlFrames.Count.ShouldBe(1);
@@ -121,7 +121,7 @@ public class WsFrameReadTests
var readInfo = new WsReadInfo(expectMask: false); var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame); var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0); result.Count.ShouldBe(0);
readInfo.CloseReceived.ShouldBeTrue(); readInfo.CloseReceived.ShouldBeTrue();
@@ -135,7 +135,7 @@ public class WsFrameReadTests
var readInfo = new WsReadInfo(expectMask: false); var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame); var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0); result.Count.ShouldBe(0);
readInfo.PendingControlFrames.Count.ShouldBe(0); readInfo.PendingControlFrames.Count.ShouldBe(0);