diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 079b443..b69fc0c 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -568,7 +568,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } // HTTP upgrade handshake - var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket); + var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct); if (!upgradeResult.Success) { _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId); diff --git a/src/NATS.Server/WebSocket/WsConnection.cs b/src/NATS.Server/WebSocket/WsConnection.cs index eb2b13b..498a1de 100644 --- a/src/NATS.Server/WebSocket/WsConnection.cs +++ b/src/NATS.Server/WebSocket/WsConnection.cs @@ -14,6 +14,7 @@ public sealed class WsConnection : Stream private readonly bool _browser; private readonly bool _noCompFrag; private WsReadInfo _readInfo; + // Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed) private readonly Queue _readQueue = new(); private int _readOffset; private readonly object _writeLock = new(); @@ -47,7 +48,7 @@ public sealed class WsConnection : Stream if (bytesRead == 0) return 0; // Decode frames - var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); + var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); // Collect control frame responses if (_readInfo.PendingControlFrames.Count > 0) @@ -192,4 +193,10 @@ public sealed class WsConnection : Stream _inner.Dispose(); base.Dispose(disposing); } + + public override async ValueTask DisposeAsync() + { + await _inner.DisposeAsync(); + GC.SuppressFinalize(this); + } } diff --git a/src/NATS.Server/WebSocket/WsConstants.cs b/src/NATS.Server/WebSocket/WsConstants.cs index f0d392d..8a3a9d3 100644 --- a/src/NATS.Server/WebSocket/WsConstants.cs +++ b/src/NATS.Server/WebSocket/WsConstants.cs @@ -58,13 +58,7 @@ public static class WsConstants public const string LeafNodePath = "/leafnode"; public const string MqttPath = "/mqtt"; - // WebSocket GUID (RFC 6455 Section 1.3) - public static readonly byte[] WsGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"u8.ToArray(); - - // Compression trailer (RFC 7692 Section 7.2.2) - public static readonly byte[] CompressLastBlock = [0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff]; - - // Decompression trailer appended before decompressing + // Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2) public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff]; public static bool IsControlFrame(int opcode) => opcode >= CloseMessage; diff --git a/src/NATS.Server/WebSocket/WsReadInfo.cs b/src/NATS.Server/WebSocket/WsReadInfo.cs index 9dbbc29..f6930c2 100644 --- a/src/NATS.Server/WebSocket/WsReadInfo.cs +++ b/src/NATS.Server/WebSocket/WsReadInfo.cs @@ -7,7 +7,7 @@ namespace NATS.Server.WebSocket; /// Per-connection WebSocket frame reading state machine. /// Ported from golang/nats-server/server/websocket.go lines 156-506. /// -public struct WsReadInfo +public class WsReadInfo { public int Remaining; public bool FrameStart; @@ -97,7 +97,7 @@ public struct WsReadInfo /// Returns list of decoded payload byte arrays. /// Ported from websocket.go lines 208-351. /// - public static List ReadFrames(ref WsReadInfo r, Stream stream, int available, int maxPayload) + public static List ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload) { var bufs = new List(); var buf = new byte[available]; @@ -184,8 +184,8 @@ public struct WsReadInfo } } - // Read mask key - if (r.ExpectMask && (b1 & WsConstants.MaskBit) != 0) + // Read mask key (mask bit already validated at line 134) + if (r.ExpectMask) { var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4); pos = p2; @@ -196,7 +196,7 @@ public struct WsReadInfo // Handle control frames if (WsConstants.IsControlFrame(frameType)) { - pos = HandleControlFrame(ref r, frameType, stream, buf, pos, max); + pos = HandleControlFrame(r, frameType, stream, buf, pos, max); continue; } @@ -243,7 +243,7 @@ public struct WsReadInfo return bufs; } - private static int HandleControlFrame(ref WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max) + private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max) { byte[]? payload = null; if (r.Remaining > 0) diff --git a/src/NATS.Server/WebSocket/WsUpgrade.cs b/src/NATS.Server/WebSocket/WsUpgrade.cs index 662065f..d2fddbc 100644 --- a/src/NATS.Server/WebSocket/WsUpgrade.cs +++ b/src/NATS.Server/WebSocket/WsUpgrade.cs @@ -11,11 +11,14 @@ namespace NATS.Server.WebSocket; public static class WsUpgrade { public static async Task TryUpgradeAsync( - Stream inputStream, Stream outputStream, WebSocketOptions options) + Stream inputStream, Stream outputStream, WebSocketOptions options, + CancellationToken ct = default) { try { - var (method, path, headers) = await ReadHttpRequestAsync(inputStream); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(options.HandshakeTimeout); + var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token); if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase)) return await FailAsync(outputStream, 405, "request method must be GET"); @@ -165,22 +168,27 @@ public static class WsUpgrade return WsUpgradeResult.Failed; } - private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync(Stream stream) + private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync( + Stream stream, CancellationToken ct) { var headerBytes = new List(4096); - var buf = new byte[1]; + var buf = new byte[512]; while (true) { - int n = await stream.ReadAsync(buf); + int n = await stream.ReadAsync(buf, ct); if (n == 0) throw new IOException("connection closed during handshake"); - headerBytes.Add(buf[0]); - if (headerBytes.Count >= 4 && - headerBytes[^4] == '\r' && headerBytes[^3] == '\n' && - headerBytes[^2] == '\r' && headerBytes[^1] == '\n') - break; - if (headerBytes.Count > 8192) - throw new InvalidOperationException("HTTP header too large"); + for (int i = 0; i < n; i++) + { + headerBytes.Add(buf[i]); + if (headerBytes.Count >= 4 && + headerBytes[^4] == '\r' && headerBytes[^3] == '\n' && + headerBytes[^2] == '\r' && headerBytes[^1] == '\n') + goto done; + if (headerBytes.Count > 8192) + throw new InvalidOperationException("HTTP header too large"); + } } + done:; var text = Encoding.ASCII.GetString(headerBytes.ToArray()); var lines = text.Split("\r\n", StringSplitOptions.None); diff --git a/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs b/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs index 9a21b53..7e0e9df 100644 --- a/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs +++ b/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs @@ -62,7 +62,7 @@ public class WsFrameReadTests var readInfo = new WsReadInfo(expectMask: false); var stream = new MemoryStream(frame); - var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); result.Count.ShouldBe(1); result[0].ShouldBe(payload); @@ -77,7 +77,7 @@ public class WsFrameReadTests var readInfo = new WsReadInfo(expectMask: true); var stream = new MemoryStream(frame); - var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); result.Count.ShouldBe(1); result[0].ShouldBe(payload); @@ -92,7 +92,7 @@ public class WsFrameReadTests var readInfo = new WsReadInfo(expectMask: false); var stream = new MemoryStream(frame); - var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); result.Count.ShouldBe(1); result[0].ShouldBe(payload); @@ -105,7 +105,7 @@ public class WsFrameReadTests var readInfo = new WsReadInfo(expectMask: false); var stream = new MemoryStream(frame); - var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); result.Count.ShouldBe(0); // control frames don't produce payload readInfo.PendingControlFrames.Count.ShouldBe(1); @@ -121,7 +121,7 @@ public class WsFrameReadTests var readInfo = new WsReadInfo(expectMask: false); var stream = new MemoryStream(frame); - var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); result.Count.ShouldBe(0); readInfo.CloseReceived.ShouldBeTrue(); @@ -135,7 +135,7 @@ public class WsFrameReadTests var readInfo = new WsReadInfo(expectMask: false); var stream = new MemoryStream(frame); - var result = WsReadInfo.ReadFrames(ref readInfo, stream, frame.Length, maxPayload: 1024); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); result.Count.ShouldBe(0); readInfo.PendingControlFrames.Count.ShouldBe(0);