From ca88036126cc3e036fa88fc832ead929ffa0126c Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 23 Feb 2026 05:16:57 -0500 Subject: [PATCH] feat: integrate WebSocket accept loop into NatsServer and NatsClient Add WebSocket listener support to NatsServer alongside the existing TCP listener. When WebSocketOptions.Port >= 0, the server binds a second socket, performs HTTP upgrade via WsUpgrade.TryUpgradeAsync, wraps the connection in WsConnection for transparent frame/deframe, and hands it to the standard NatsClient pipeline. Changes: - NatsClient: add IsWebSocket and WsInfo properties - NatsServer: add RunWebSocketAcceptLoopAsync and AcceptWebSocketClientAsync, WS listener lifecycle in StartAsync/ShutdownAsync/Dispose - NatsOptions: change WebSocketOptions.Port default from 0 to -1 (disabled) - WsConnection.ReadAsync: fix premature end-of-stream when ReadFrames returns no payloads by looping until data is available - Add WsIntegration tests (connect, ping, pub/sub over WebSocket) - Add WsConnection masked frame and end-of-stream unit tests --- src/NATS.Server/NatsClient.cs | 4 + src/NATS.Server/NatsOptions.cs | 2 +- src/NATS.Server/NatsServer.cs | 139 ++++++++++++++- src/NATS.Server/WebSocket/WsConnection.cs | 55 +++--- .../WebSocket/WebSocketOptionsTests.cs | 6 +- .../WebSocket/WsConnectionTests.cs | 36 ++++ .../WebSocket/WsIntegrationTests.cs | 162 ++++++++++++++++++ 7 files changed, 368 insertions(+), 36 deletions(-) create mode 100644 tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 1ccfc71..22c406e 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -11,6 +11,7 @@ using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; using NATS.Server.Tls; +using NATS.Server.WebSocket; namespace NATS.Server; @@ -79,6 +80,9 @@ public sealed class NatsClient : IDisposable private long _rtt; public TimeSpan Rtt => new(Interlocked.Read(ref _rtt)); + public bool IsWebSocket { get; set; } + public WsUpgradeResult? WsInfo { get; set; } + public TlsConnectionState? TlsState { get; set; } public bool InfoAlreadySent { get; set; } diff --git a/src/NATS.Server/NatsOptions.cs b/src/NATS.Server/NatsOptions.cs index 97ec196..dc84854 100644 --- a/src/NATS.Server/NatsOptions.cs +++ b/src/NATS.Server/NatsOptions.cs @@ -94,7 +94,7 @@ public sealed class NatsOptions public sealed class WebSocketOptions { public string Host { get; set; } = "0.0.0.0"; - public int Port { get; set; } + public int Port { get; set; } = -1; public string? Advertise { get; set; } public string? NoAuthUser { get; set; } public string? JwtCookie { get; set; } diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 02a0734..079b443 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -12,6 +12,7 @@ using NATS.Server.Monitoring; using NATS.Server.Protocol; using NATS.Server.Subscriptions; using NATS.Server.Tls; +using NATS.Server.WebSocket; namespace NATS.Server; @@ -33,6 +34,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly SslServerAuthenticationOptions? _sslOptions; private readonly TlsRateLimiter? _tlsRateLimiter; private Socket? _listener; + private Socket? _wsListener; + private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously); private MonitorServer? _monitorServer; private ulong _nextClientId; private long _startTimeTicks; @@ -87,11 +90,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable // Signal all internal loops to stop await _quitCts.CancelAsync(); - // Close listener to stop accept loop + // Close listeners to stop accept loops _listener?.Close(); + _wsListener?.Close(); - // Wait for accept loop to exit + // Wait for accept loops to exit await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); // Close all client connections — flush first, then mark closed var flushTasks = new List(); @@ -132,11 +137,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _logger.LogInformation("Entering lame duck mode, stop accepting new clients"); - // Close listener to stop accepting new connections + // Close listeners to stop accepting new connections _listener?.Close(); + _wsListener?.Close(); - // Wait for accept loop to exit + // Wait for accept loops to exit await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); var gracePeriod = _options.LameDuckGracePeriod; if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod; @@ -314,8 +321,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable BuildCachedInfo(); } - _listeningStarted.TrySetResult(); - _logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port); // Warn about stub features @@ -333,6 +338,31 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable WritePidFile(); WritePortsFile(); + if (_options.WebSocket.Port >= 0) + { + _wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + _wsListener.Bind(new IPEndPoint( + _options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host), + _options.WebSocket.Port)); + _wsListener.Listen(128); + + if (_options.WebSocket.Port == 0) + { + _options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port; + } + + _logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}", + _options.WebSocket.Host, _options.WebSocket.Port); + + if (_options.WebSocket.NoTls) + _logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!"); + + _ = RunWebSocketAcceptLoopAsync(linked.Token); + } + + _listeningStarted.TrySetResult(); + var tmpDelay = AcceptMinSleep; try @@ -478,6 +508,102 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } } + private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct) + { + var tmpDelay = AcceptMinSleep; + try + { + while (!ct.IsCancellationRequested) + { + Socket socket; + try + { + socket = await _wsListener!.AcceptAsync(ct); + tmpDelay = AcceptMinSleep; + } + catch (OperationCanceledException) { break; } + catch (ObjectDisposedException) { break; } + catch (SocketException ex) + { + if (IsShuttingDown || IsLameDuckMode) break; + _logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds); + try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; } + tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks)); + continue; + } + + if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections) + { + socket.Dispose(); + continue; + } + + var clientId = Interlocked.Increment(ref _nextClientId); + Interlocked.Increment(ref _stats.TotalConnections); + Interlocked.Increment(ref _activeClientCount); + + _ = AcceptWebSocketClientAsync(socket, clientId, ct); + } + } + finally + { + _wsAcceptLoopExited.TrySetResult(); + } + } + + private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct) + { + try + { + var networkStream = new NetworkStream(socket, ownsSocket: false); + Stream stream = networkStream; + + // TLS negotiation if configured + if (_sslOptions != null && !_options.WebSocket.NoTls) + { + var (tlsStream, _) = await TlsConnectionWrapper.NegotiateAsync( + socket, networkStream, _options, _sslOptions, _serverInfo, + _loggerFactory.CreateLogger("NATS.Server.Tls"), ct); + stream = tlsStream; + } + + // HTTP upgrade handshake + var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket); + if (!upgradeResult.Success) + { + _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId); + socket.Dispose(); + Interlocked.Decrement(ref _activeClientCount); + return; + } + + // Create WsConnection wrapper + var wsConn = new WsConnection(stream, + compress: upgradeResult.Compress, + maskRead: upgradeResult.MaskRead, + maskWrite: upgradeResult.MaskWrite, + browser: upgradeResult.Browser, + noCompFrag: upgradeResult.NoCompFrag); + + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); + var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo, + _authService, null, clientLogger, _stats); + client.Router = this; + client.IsWebSocket = true; + client.WsInfo = upgradeResult; + _clients[clientId] = client; + + await RunClientAsync(client, ct); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId); + try { socket.Shutdown(SocketShutdown.Both); } catch { } + socket.Dispose(); + Interlocked.Decrement(ref _activeClientCount); + } + } + private async Task RunClientAsync(NatsClient client, CancellationToken ct) { try @@ -726,6 +852,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _quitCts.Dispose(); _tlsRateLimiter?.Dispose(); _listener?.Dispose(); + _wsListener?.Dispose(); foreach (var client in _clients.Values) client.Dispose(); foreach (var account in _accounts.Values) diff --git a/src/NATS.Server/WebSocket/WsConnection.cs b/src/NATS.Server/WebSocket/WsConnection.cs index 61e51e7..eb2b13b 100644 --- a/src/NATS.Server/WebSocket/WsConnection.cs +++ b/src/NATS.Server/WebSocket/WsConnection.cs @@ -39,34 +39,37 @@ public sealed class WsConnection : Stream if (_readQueue.Count > 0) return DrainReadQueue(buffer.Span); - // Read raw bytes from inner stream - var rawBuf = new byte[Math.Max(buffer.Length, 4096)]; - int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct); - if (bytesRead == 0) return 0; - - // Decode frames - var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); - - // Collect control frame responses - if (_readInfo.PendingControlFrames.Count > 0) + while (true) { - lock (_writeLock) - _pendingControlWrites.AddRange(_readInfo.PendingControlFrames); - _readInfo.PendingControlFrames.Clear(); - // Write pending control frames - await FlushControlFramesAsync(ct); + // Read raw bytes from inner stream + var rawBuf = new byte[Math.Max(buffer.Length, 4096)]; + int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct); + if (bytesRead == 0) return 0; + + // Decode frames + var payloads = WsReadInfo.ReadFrames(ref _readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); + + // Collect control frame responses + if (_readInfo.PendingControlFrames.Count > 0) + { + lock (_writeLock) + _pendingControlWrites.AddRange(_readInfo.PendingControlFrames); + _readInfo.PendingControlFrames.Clear(); + // Write pending control frames + await FlushControlFramesAsync(ct); + } + + if (_readInfo.CloseReceived) + return 0; + + foreach (var payload in payloads) + _readQueue.Enqueue(payload); + + // If no payloads were decoded (e.g. only frame headers were read), + // continue reading instead of returning 0 which signals end-of-stream + if (_readQueue.Count > 0) + return DrainReadQueue(buffer.Span); } - - if (_readInfo.CloseReceived) - return 0; - - foreach (var payload in payloads) - _readQueue.Enqueue(payload); - - if (_readQueue.Count == 0) - return 0; - - return DrainReadQueue(buffer.Span); } public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ct = default) diff --git a/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs b/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs index dba1d0e..50c43b7 100644 --- a/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs +++ b/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs @@ -5,10 +5,10 @@ namespace NATS.Server.Tests.WebSocket; public class WebSocketOptionsTests { [Fact] - public void DefaultOptions_PortIsZero_Disabled() + public void DefaultOptions_PortIsNegativeOne_Disabled() { var opts = new WebSocketOptions(); - opts.Port.ShouldBe(0); + opts.Port.ShouldBe(-1); opts.Host.ShouldBe("0.0.0.0"); opts.Compression.ShouldBeFalse(); opts.NoTls.ShouldBeFalse(); @@ -21,6 +21,6 @@ public class WebSocketOptionsTests { var opts = new NatsOptions(); opts.WebSocket.ShouldNotBeNull(); - opts.WebSocket.Port.ShouldBe(0); + opts.WebSocket.Port.ShouldBe(-1); } } diff --git a/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs b/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs index 2955b1d..8f30768 100644 --- a/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs +++ b/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs @@ -75,6 +75,42 @@ public class WsConnectionTests (written[0] & WsConstants.Rsv1Bit).ShouldBe(0); } + [Fact] + public async Task ReadAsync_DecodesMaskedFrame() + { + var payload = "CONNECT {}\r\n"u8.ToArray(); + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: payload.Length); + var maskKey = header[^4..]; + WsFrameWriter.MaskBuf(maskKey, payload); + + var frame = new byte[header.Length + payload.Length]; + header.CopyTo(frame, 0); + payload.CopyTo(frame, header.Length); + + var inner = new MemoryStream(frame); + var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + + n.ShouldBe("CONNECT {}\r\n".Length); + System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n"); + } + + [Fact] + public async Task ReadAsync_ReturnsZero_OnEndOfStream() + { + // Empty stream should return 0 (true end of stream) + var inner = new MemoryStream([]); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + n.ShouldBe(0); + } + private static byte[] BuildUnmaskedFrame(byte[] payload) { var header = new byte[2]; diff --git a/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs b/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs new file mode 100644 index 0000000..c20edee --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs @@ -0,0 +1,162 @@ +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using System.Security.Cryptography; +using System.Text; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsIntegrationTests : IAsyncLifetime +{ + private NatsServer _server = null!; + private NatsOptions _options = null!; + + public async Task InitializeAsync() + { + _options = new NatsOptions + { + Port = 0, + WebSocket = new WebSocketOptions { Port = 0, NoTls = true }, + }; + var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { }); + _server = new NatsServer(_options, loggerFactory); + _ = _server.StartAsync(CancellationToken.None); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _server.ShutdownAsync(); + _server.Dispose(); + } + + [Fact] + public async Task WebSocket_ConnectAndReceiveInfo() + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port)); + using var stream = new NetworkStream(socket, ownsSocket: false); + + await SendUpgradeRequest(stream); + var response = await ReadHttpResponse(stream); + response.ShouldContain("101"); + + var wsFrame = await ReadWsFrame(stream); + var info = Encoding.ASCII.GetString(wsFrame); + info.ShouldStartWith("INFO "); + } + + [Fact] + public async Task WebSocket_ConnectAndPing() + { + using var client = await ConnectWsClient(); + + // Send CONNECT and PING together + await SendWsText(client, "CONNECT {}\r\nPING\r\n"); + + // Read PONG WS frame + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var pong = await ReadWsFrameAsync(client, cts.Token); + Encoding.ASCII.GetString(pong).ShouldContain("PONG"); + } + + [Fact] + public async Task WebSocket_PubSub() + { + using var sub = await ConnectWsClient(); + using var pub = await ConnectWsClient(); + + await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n"); + await Task.Delay(200); + + await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n"); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var msg = await ReadWsFrameAsync(sub, cts.Token); + Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5"); + } + + private async Task ConnectWsClient() + { + var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port)); + var stream = new NetworkStream(socket, ownsSocket: true); + + await SendUpgradeRequest(stream); + var response = await ReadHttpResponse(stream); + response.ShouldContain("101"); + + await ReadWsFrame(stream); // Read INFO frame + return stream; + } + + private static async Task SendUpgradeRequest(NetworkStream stream) + { + var keyBytes = new byte[16]; + RandomNumberGenerator.Fill(keyBytes); + var key = Convert.ToBase64String(keyBytes); + + var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n"; + await stream.WriteAsync(Encoding.ASCII.GetBytes(request)); + await stream.FlushAsync(); + } + + private static async Task ReadHttpResponse(NetworkStream stream) + { + // Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response + var sb = new StringBuilder(); + var buf = new byte[1]; + while (true) + { + int n = await stream.ReadAsync(buf); + if (n == 0) break; + sb.Append((char)buf[0]); + if (sb.Length >= 4 && + sb[^4] == '\r' && sb[^3] == '\n' && + sb[^2] == '\r' && sb[^1] == '\n') + break; + } + + return sb.ToString(); + } + + private static Task ReadWsFrame(NetworkStream stream) + => ReadWsFrameAsync(stream, CancellationToken.None); + + private static async Task ReadWsFrameAsync(NetworkStream stream, CancellationToken ct) + { + var header = new byte[2]; + await stream.ReadExactlyAsync(header, ct); + int len = header[1] & 0x7F; + if (len == 126) + { + var extLen = new byte[2]; + await stream.ReadExactlyAsync(extLen, ct); + len = BinaryPrimitives.ReadUInt16BigEndian(extLen); + } + else if (len == 127) + { + var extLen = new byte[8]; + await stream.ReadExactlyAsync(extLen, ct); + len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen); + } + + var payload = new byte[len]; + if (len > 0) await stream.ReadExactlyAsync(payload, ct); + return payload; + } + + private static async Task SendWsText(NetworkStream stream, string text) + { + var payload = Encoding.ASCII.GetBytes(text); + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: payload.Length); + var maskKey = header[^4..]; + WsFrameWriter.MaskBuf(maskKey, payload); + await stream.WriteAsync(header); + await stream.WriteAsync(payload); + await stream.FlushAsync(); + } +}