fix: address code review findings for WebSocket implementation
- Convert WsReadInfo from mutable struct to class (prevents silent copy bugs) - Add handshake timeout enforcement via CancellationToken in WsUpgrade - Use buffered reading (512 bytes) in ReadHttpRequestAsync instead of byte-at-a-time - Add IAsyncDisposable to WsConnection for proper async cleanup - Simplify redundant mask bit check in WsReadInfo - Remove unused WsGuid and CompressLastBlock dead code from WsConstants - Document single-reader assumption on WsConnection read-side state
This commit is contained in:
@@ -568,7 +568,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
}
|
||||
|
||||
// HTTP upgrade handshake
|
||||
var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket);
|
||||
var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct);
|
||||
if (!upgradeResult.Success)
|
||||
{
|
||||
_logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
|
||||
|
||||
@@ -14,6 +14,7 @@ public sealed class WsConnection : Stream
|
||||
private readonly bool _browser;
|
||||
private readonly bool _noCompFrag;
|
||||
private WsReadInfo _readInfo;
|
||||
// Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed)
|
||||
private readonly Queue<byte[]> _readQueue = new();
|
||||
private int _readOffset;
|
||||
private readonly object _writeLock = new();
|
||||
@@ -47,7 +48,7 @@ public sealed class WsConnection : Stream
|
||||
if (bytesRead == 0) return 0;
|
||||
|
||||
// Decode frames
|
||||
var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
|
||||
var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
|
||||
|
||||
// Collect control frame responses
|
||||
if (_readInfo.PendingControlFrames.Count > 0)
|
||||
@@ -192,4 +193,10 @@ public sealed class WsConnection : Stream
|
||||
_inner.Dispose();
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
|
||||
public override async ValueTask DisposeAsync()
|
||||
{
|
||||
await _inner.DisposeAsync();
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,13 +58,7 @@ public static class WsConstants
|
||||
public const string LeafNodePath = "/leafnode";
|
||||
public const string MqttPath = "/mqtt";
|
||||
|
||||
// WebSocket GUID (RFC 6455 Section 1.3)
|
||||
public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray();
|
||||
|
||||
// Compression trailer (RFC 7692 Section 7.2.2)
|
||||
public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff];
|
||||
|
||||
// Decompression trailer appended before decompressing
|
||||
// Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2)
|
||||
public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];
|
||||
|
||||
public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;
|
||||
|
||||
@@ -7,7 +7,7 @@ namespace NATS.Server.WebSocket;
|
||||
/// Per-connection WebSocket frame reading state machine.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
|
||||
/// </summary>
|
||||
public struct WsReadInfo
|
||||
public class WsReadInfo
|
||||
{
|
||||
public int Remaining;
|
||||
public bool FrameStart;
|
||||
@@ -97,7 +97,7 @@ public struct WsReadInfo
|
||||
/// Returns list of decoded payload byte arrays.
|
||||
/// Ported from websocket.go lines 208-351.
|
||||
/// </summary>
|
||||
public static List<byte[]> ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload)
|
||||
public static List<byte[]> ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload)
|
||||
{
|
||||
var bufs = new List<byte[]>();
|
||||
var buf = new byte[available];
|
||||
@@ -184,8 +184,8 @@ public struct WsReadInfo
|
||||
}
|
||||
}
|
||||
|
||||
// Read mask key
|
||||
if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0)
|
||||
// Read mask key (mask bit already validated at line 134)
|
||||
if (r.ExpectMask)
|
||||
{
|
||||
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
|
||||
pos = p2;
|
||||
@@ -196,7 +196,7 @@ public struct WsReadInfo
|
||||
// Handle control frames
|
||||
if (WsConstants.IsControlFrame(frameType))
|
||||
{
|
||||
pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max);
|
||||
pos = HandleControlFrame(r, frameType, stream, buf, pos, max);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -243,7 +243,7 @@ public struct WsReadInfo
|
||||
return bufs;
|
||||
}
|
||||
|
||||
private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
|
||||
private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
|
||||
{
|
||||
byte[]? payload = null;
|
||||
if (r.Remaining > 0)
|
||||
|
||||
@@ -11,11 +11,14 @@ namespace NATS.Server.WebSocket;
|
||||
public static class WsUpgrade
|
||||
{
|
||||
public static async Task<WsUpgradeResult> TryUpgradeAsync(
|
||||
Stream inputStream, Stream outputStream, WebSocketOptions options)
|
||||
Stream inputStream, Stream outputStream, WebSocketOptions options,
|
||||
CancellationToken ct = default)
|
||||
{
|
||||
try
|
||||
{
|
||||
var (method, path, headers) = await ReadHttpRequestAsync(inputStream);
|
||||
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
cts.CancelAfter(options.HandshakeTimeout);
|
||||
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
|
||||
|
||||
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
|
||||
return await FailAsync(outputStream, 405, "request method must be GET");
|
||||
@@ -165,22 +168,27 @@ public static class WsUpgrade
|
||||
return WsUpgradeResult.Failed;
|
||||
}
|
||||
|
||||
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
|
||||
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
|
||||
Stream stream, CancellationToken ct)
|
||||
{
|
||||
var headerBytes = new List<byte>(4096);
|
||||
var buf = new byte[1];
|
||||
var buf = new byte[512];
|
||||
while (true)
|
||||
{
|
||||
int n = await stream.ReadAsync(buf);
|
||||
int n = await stream.ReadAsync(buf, ct);
|
||||
if (n == 0) throw new IOException("connection closed during handshake");
|
||||
headerBytes.Add(buf[0]);
|
||||
if (headerBytes.Count >= 4 &&
|
||||
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
|
||||
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
|
||||
break;
|
||||
if (headerBytes.Count > 8192)
|
||||
throw new InvalidOperationException("HTTP header too large");
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
headerBytes.Add(buf[i]);
|
||||
if (headerBytes.Count >= 4 &&
|
||||
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
|
||||
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
|
||||
goto done;
|
||||
if (headerBytes.Count > 8192)
|
||||
throw new InvalidOperationException("HTTP header too large");
|
||||
}
|
||||
}
|
||||
done:;
|
||||
|
||||
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
|
||||
var lines = text.Split("\r\n", StringSplitOptions.None);
|
||||
|
||||
@@ -62,7 +62,7 @@ public class WsFrameReadTests
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
@@ -77,7 +77,7 @@ public class WsFrameReadTests
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: true);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
@@ -92,7 +92,7 @@ public class WsFrameReadTests
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
@@ -105,7 +105,7 @@ public class WsFrameReadTests
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0); // control frames don't produce payload
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(1);
|
||||
@@ -121,7 +121,7 @@ public class WsFrameReadTests
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.CloseReceived.ShouldBeTrue();
|
||||
@@ -135,7 +135,7 @@ public class WsFrameReadTests
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(0);
|
||||
|
||||
Reference in New Issue
Block a user