diff --git a/differences.md b/differences.md index 1a921ba..33a6453 100644 --- a/differences.md +++ b/differences.md @@ -67,7 +67,7 @@ | SYSTEM (internal) | Y | N | | | JETSTREAM (internal) | Y | N | | | ACCOUNT (internal) | Y | N | | -| WebSocket clients | Y | N | | +| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth | | MQTT clients | Y | N | | ### Client Features @@ -267,7 +267,8 @@ Go implements a sophisticated slow consumer detection system: - ~~Advanced limits (MaxSubs, MaxSubTokens, MaxPending, WriteDeadline)~~ — `MaxSubs`, `MaxSubTokens` implemented; MaxPending/WriteDeadline already existed - ~~Tags/metadata~~ — `Tags` dictionary implemented in `NatsOptions` - ~~OCSP configuration~~ — `OcspConfig` with 4 modes (Auto/Always/Must/Never), peer verification, and stapling -- WebSocket/MQTT options +- ~~WebSocket options~~ — `WebSocketOptions` with port, compression, origin checking, cookie auth, custom headers +- MQTT options - ~~Operator mode / account resolver~~ — `JwtAuthenticator` + `IAccountResolver` + `MemAccountResolver` with trusted keys --- diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 5a7e2c8..620f275 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -11,6 +11,7 @@ using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; using NATS.Server.Tls; +using NATS.Server.WebSocket; namespace NATS.Server; @@ -93,6 +94,9 @@ public sealed class NatsClient : IDisposable private long _rtt; public TimeSpan Rtt => new(Interlocked.Read(ref _rtt)); + public bool IsWebSocket { get; set; } + public WsUpgradeResult? WsInfo { get; set; } + public TlsConnectionState? TlsState { get; set; } public bool InfoAlreadySent { get; set; } diff --git a/src/NATS.Server/NatsOptions.cs b/src/NATS.Server/NatsOptions.cs index fb5d831..1e3820e 100644 --- a/src/NATS.Server/NatsOptions.cs +++ b/src/NATS.Server/NatsOptions.cs @@ -116,4 +116,32 @@ public sealed class NatsOptions public Dictionary? SubjectMappings { get; set; } public bool HasTls => TlsCert != null && TlsKey != null; + + // WebSocket + public WebSocketOptions WebSocket { get; set; } = new(); +} + +public sealed class WebSocketOptions +{ + public string Host { get; set; } = "0.0.0.0"; + public int Port { get; set; } = -1; + public string? Advertise { get; set; } + public string? NoAuthUser { get; set; } + public string? JwtCookie { get; set; } + public string? UsernameCookie { get; set; } + public string? PasswordCookie { get; set; } + public string? TokenCookie { get; set; } + public string? Username { get; set; } + public string? Password { get; set; } + public string? Token { get; set; } + public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2); + public bool NoTls { get; set; } + public string? TlsCert { get; set; } + public string? TlsKey { get; set; } + public bool SameOrigin { get; set; } + public List? AllowedOrigins { get; set; } + public bool Compression { get; set; } + public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2); + public TimeSpan? PingInterval { get; set; } + public Dictionary? Headers { get; set; } } diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 89c3bd4..9a1f717 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -13,6 +13,7 @@ using NATS.Server.Monitoring; using NATS.Server.Protocol; using NATS.Server.Subscriptions; using NATS.Server.Tls; +using NATS.Server.WebSocket; namespace NATS.Server; @@ -39,6 +40,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly TlsRateLimiter? _tlsRateLimiter; private readonly SubjectTransform[] _subjectTransforms; private Socket? _listener; + private Socket? _wsListener; + private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously); private MonitorServer? _monitorServer; private ulong _nextClientId; private long _startTimeTicks; @@ -93,11 +96,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable // Signal all internal loops to stop await _quitCts.CancelAsync(); - // Close listener to stop accept loop + // Close listeners to stop accept loops _listener?.Close(); + _wsListener?.Close(); - // Wait for accept loop to exit + // Wait for accept loops to exit await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); // Close all client connections — flush first, then mark closed var flushTasks = new List(); @@ -138,11 +143,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _logger.LogInformation("Entering lame duck mode, stop accepting new clients"); - // Close listener to stop accepting new connections + // Close listeners to stop accepting new connections _listener?.Close(); + _wsListener?.Close(); - // Wait for accept loop to exit + // Wait for accept loops to exit await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); var gracePeriod = _options.LameDuckGracePeriod; if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod; @@ -369,8 +376,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable BuildCachedInfo(); } - _listeningStarted.TrySetResult(); - _logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port); // Warn about stub features @@ -386,6 +391,31 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable WritePidFile(); WritePortsFile(); + if (_options.WebSocket.Port >= 0) + { + _wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + _wsListener.Bind(new IPEndPoint( + _options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host), + _options.WebSocket.Port)); + _wsListener.Listen(128); + + if (_options.WebSocket.Port == 0) + { + _options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port; + } + + _logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}", + _options.WebSocket.Host, _options.WebSocket.Port); + + if (_options.WebSocket.NoTls) + _logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!"); + + _ = RunWebSocketAcceptLoopAsync(linked.Token); + } + + _listeningStarted.TrySetResult(); + var tmpDelay = AcceptMinSleep; try @@ -531,6 +561,102 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } } + private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct) + { + var tmpDelay = AcceptMinSleep; + try + { + while (!ct.IsCancellationRequested) + { + Socket socket; + try + { + socket = await _wsListener!.AcceptAsync(ct); + tmpDelay = AcceptMinSleep; + } + catch (OperationCanceledException) { break; } + catch (ObjectDisposedException) { break; } + catch (SocketException ex) + { + if (IsShuttingDown || IsLameDuckMode) break; + _logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds); + try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; } + tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks)); + continue; + } + + if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections) + { + socket.Dispose(); + continue; + } + + var clientId = Interlocked.Increment(ref _nextClientId); + Interlocked.Increment(ref _stats.TotalConnections); + Interlocked.Increment(ref _activeClientCount); + + _ = AcceptWebSocketClientAsync(socket, clientId, ct); + } + } + finally + { + _wsAcceptLoopExited.TrySetResult(); + } + } + + private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct) + { + try + { + var networkStream = new NetworkStream(socket, ownsSocket: false); + Stream stream = networkStream; + + // TLS negotiation if configured + if (_sslOptions != null && !_options.WebSocket.NoTls) + { + var (tlsStream, _) = await TlsConnectionWrapper.NegotiateAsync( + socket, networkStream, _options, _sslOptions, _serverInfo, + _loggerFactory.CreateLogger("NATS.Server.Tls"), ct); + stream = tlsStream; + } + + // HTTP upgrade handshake + var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct); + if (!upgradeResult.Success) + { + _logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId); + socket.Dispose(); + Interlocked.Decrement(ref _activeClientCount); + return; + } + + // Create WsConnection wrapper + var wsConn = new WsConnection(stream, + compress: upgradeResult.Compress, + maskRead: upgradeResult.MaskRead, + maskWrite: upgradeResult.MaskWrite, + browser: upgradeResult.Browser, + noCompFrag: upgradeResult.NoCompFrag); + + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); + var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo, + _authService, null, clientLogger, _stats); + client.Router = this; + client.IsWebSocket = true; + client.WsInfo = upgradeResult; + _clients[clientId] = client; + + await RunClientAsync(client, ct); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId); + try { socket.Shutdown(SocketShutdown.Both); } catch { } + socket.Dispose(); + Interlocked.Decrement(ref _activeClientCount); + } + } + private async Task RunClientAsync(NatsClient client, CancellationToken ct) { try @@ -942,6 +1068,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _quitCts.Dispose(); _tlsRateLimiter?.Dispose(); _listener?.Dispose(); + _wsListener?.Dispose(); foreach (var client in _clients.Values) client.Dispose(); foreach (var account in _accounts.Values) diff --git a/src/NATS.Server/WebSocket/WsCompression.cs b/src/NATS.Server/WebSocket/WsCompression.cs new file mode 100644 index 0000000..92f0184 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsCompression.cs @@ -0,0 +1,94 @@ +using System.IO.Compression; + +namespace NATS.Server.WebSocket; + +/// +/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692). +/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466. +/// +public static class WsCompression +{ + /// + /// Compresses data using deflate. Removes trailing 4 bytes (sync marker) + /// per RFC 7692 Section 7.2.1. + /// + /// + /// We call Flush() but intentionally do not Dispose() the DeflateStream before + /// reading output, because Dispose writes a final deflate block (0x03 0x00) that + /// would be corrupted by the 4-byte tail strip. Flush() alone writes a sync flush + /// ending with 0x00 0x00 0xff 0xff, matching Go's flate.Writer.Flush() behavior. + /// + public static byte[] Compress(ReadOnlySpan data) + { + var output = new MemoryStream(); + var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true); + try + { + deflate.Write(data); + deflate.Flush(); + + var compressed = output.ToArray(); + + // Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692 + if (compressed.Length >= 4) + return compressed[..^4]; + + return compressed; + } + finally + { + deflate.Dispose(); + output.Dispose(); + } + } + + /// + /// Decompresses collected compressed buffers. + /// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2. + /// Ported from golang/nats-server/server/websocket.go lines 403-440. + /// The Go code appends compressLastBlock (9 bytes) which includes the sync + /// marker plus a final empty stored block to signal end-of-stream to the + /// flate reader. + /// + public static byte[] Decompress(List compressedBuffers, int maxPayload) + { + if (maxPayload <= 0) + maxPayload = 1024 * 1024; // Default 1MB + + // Concatenate all compressed buffers + trailer. + // Per RFC 7692 Section 7.2.2, append the sync flush marker (0x00 0x00 0xff 0xff) + // that was stripped during compression. The Go reference appends compressLastBlock + // (9 bytes) for Go's flate reader; .NET's DeflateStream only needs the 4-byte trailer. + int totalLen = 0; + foreach (var buf in compressedBuffers) + totalLen += buf.Length; + totalLen += WsConstants.DecompressTrailer.Length; + + var combined = new byte[totalLen]; + int offset = 0; + foreach (var buf in compressedBuffers) + { + buf.CopyTo(combined, offset); + offset += buf.Length; + } + + WsConstants.DecompressTrailer.CopyTo(combined, offset); + + using var input = new MemoryStream(combined); + using var deflate = new DeflateStream(input, CompressionMode.Decompress); + using var output = new MemoryStream(); + + var readBuf = new byte[4096]; + int totalRead = 0; + int n; + while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0) + { + totalRead += n; + if (totalRead > maxPayload) + throw new InvalidOperationException("decompressed data exceeds maximum payload size"); + output.Write(readBuf, 0, n); + } + + return output.ToArray(); + } +} diff --git a/src/NATS.Server/WebSocket/WsConnection.cs b/src/NATS.Server/WebSocket/WsConnection.cs new file mode 100644 index 0000000..498a1de --- /dev/null +++ b/src/NATS.Server/WebSocket/WsConnection.cs @@ -0,0 +1,202 @@ +namespace NATS.Server.WebSocket; + +/// +/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O. +/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged. +/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern. +/// +public sealed class WsConnection : Stream +{ + private readonly Stream _inner; + private readonly bool _compress; + private readonly bool _maskRead; + private readonly bool _maskWrite; + private readonly bool _browser; + private readonly bool _noCompFrag; + private WsReadInfo _readInfo; + // Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed) + private readonly Queue _readQueue = new(); + private int _readOffset; + private readonly object _writeLock = new(); + private readonly List _pendingControlWrites = []; + + public bool CloseReceived => _readInfo.CloseReceived; + public int CloseStatus => _readInfo.CloseStatus; + + public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag) + { + _inner = inner; + _compress = compress; + _maskRead = maskRead; + _maskWrite = maskWrite; + _browser = browser; + _noCompFrag = noCompFrag; + _readInfo = new WsReadInfo(expectMask: maskRead); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) + { + // Drain any buffered decoded payloads first + if (_readQueue.Count > 0) + return DrainReadQueue(buffer.Span); + + while (true) + { + // Read raw bytes from inner stream + var rawBuf = new byte[Math.Max(buffer.Length, 4096)]; + int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct); + if (bytesRead == 0) return 0; + + // Decode frames + var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024); + + // Collect control frame responses + if (_readInfo.PendingControlFrames.Count > 0) + { + lock (_writeLock) + _pendingControlWrites.AddRange(_readInfo.PendingControlFrames); + _readInfo.PendingControlFrames.Clear(); + // Write pending control frames + await FlushControlFramesAsync(ct); + } + + if (_readInfo.CloseReceived) + return 0; + + foreach (var payload in payloads) + _readQueue.Enqueue(payload); + + // If no payloads were decoded (e.g. only frame headers were read), + // continue reading instead of returning 0 which signals end-of-stream + if (_readQueue.Count > 0) + return DrainReadQueue(buffer.Span); + } + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ct = default) + { + var data = buffer.Span; + + if (_compress && data.Length > WsConstants.CompressThreshold) + { + var compressed = WsCompression.Compress(data); + await WriteFramedAsync(compressed, compressed: true, ct); + } + else + { + await WriteFramedAsync(data.ToArray(), compressed: false, ct); + } + } + + private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct) + { + if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed)) + { + // Fragment for browsers + int offset = 0; + bool first = true; + while (offset < payload.Length) + { + int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset); + bool final = offset + chunkLen >= payload.Length; + var fh = new byte[WsConstants.MaxFrameHeaderSize]; + var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite, + first: first, final: final, compressed: first && compressed, + opcode: WsConstants.BinaryMessage, payloadLength: chunkLen); + + var chunk = payload.AsSpan(offset, chunkLen).ToArray(); + if (_maskWrite && key != null) + WsFrameWriter.MaskBuf(key, chunk); + + await _inner.WriteAsync(fh.AsMemory(0, n), ct); + await _inner.WriteAsync(chunk.AsMemory(), ct); + offset += chunkLen; + first = false; + } + } + else + { + var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length); + if (_maskWrite && key != null) + WsFrameWriter.MaskBuf(key, payload); + await _inner.WriteAsync(header.AsMemory(), ct); + await _inner.WriteAsync(payload.AsMemory(), ct); + } + } + + private async Task FlushControlFramesAsync(CancellationToken ct) + { + List toWrite; + lock (_writeLock) + { + if (_pendingControlWrites.Count == 0) return; + toWrite = [.. _pendingControlWrites]; + _pendingControlWrites.Clear(); + } + + foreach (var action in toWrite) + { + var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite); + await _inner.WriteAsync(frame, ct); + } + await _inner.FlushAsync(ct); + } + + /// + /// Sends a WebSocket close frame. + /// + public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default) + { + var status = WsFrameWriter.MapCloseStatus(reason); + var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString()); + var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite); + await _inner.WriteAsync(frame, ct); + await _inner.FlushAsync(ct); + } + + private int DrainReadQueue(Span buffer) + { + int written = 0; + while (_readQueue.Count > 0 && written < buffer.Length) + { + var current = _readQueue.Peek(); + int available = current.Length - _readOffset; + int toCopy = Math.Min(available, buffer.Length - written); + current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]); + written += toCopy; + _readOffset += toCopy; + if (_readOffset >= current.Length) + { + _readQueue.Dequeue(); + _readOffset = 0; + } + } + return written; + } + + // Stream abstract members + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override void Flush() => _inner.Flush(); + public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct); + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync"); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync"); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (disposing) + _inner.Dispose(); + base.Dispose(disposing); + } + + public override async ValueTask DisposeAsync() + { + await _inner.DisposeAsync(); + GC.SuppressFinalize(this); + } +} diff --git a/src/NATS.Server/WebSocket/WsConstants.cs b/src/NATS.Server/WebSocket/WsConstants.cs new file mode 100644 index 0000000..8a3a9d3 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsConstants.cs @@ -0,0 +1,72 @@ +namespace NATS.Server.WebSocket; + +/// +/// WebSocket protocol constants (RFC 6455). +/// Ported from golang/nats-server/server/websocket.go lines 41-106. +/// +public static class WsConstants +{ + // Opcodes (RFC 6455 Section 5.2) + public const int TextMessage = 1; + public const int BinaryMessage = 2; + public const int CloseMessage = 8; + public const int PingMessage = 9; + public const int PongMessage = 10; + public const int ContinuationFrame = 0; + + // Frame header bits + public const byte FinalBit = 0x80; // 1 << 7 + public const byte Rsv1Bit = 0x40; // 1 << 6 (compression, RFC 7692) + public const byte Rsv2Bit = 0x20; // 1 << 5 + public const byte Rsv3Bit = 0x10; // 1 << 4 + public const byte MaskBit = 0x80; // 1 << 7 (in second byte) + + // Frame size limits + public const int MaxFrameHeaderSize = 14; + public const int MaxControlPayloadSize = 125; + public const int FrameSizeForBrowsers = 4096; + public const int CompressThreshold = 64; + public const int CloseStatusSize = 2; + + // Close status codes (RFC 6455 Section 11.7) + public const int CloseStatusNormalClosure = 1000; + public const int CloseStatusGoingAway = 1001; + public const int CloseStatusProtocolError = 1002; + public const int CloseStatusUnsupportedData = 1003; + public const int CloseStatusNoStatusReceived = 1005; + public const int CloseStatusInvalidPayloadData = 1007; + public const int CloseStatusPolicyViolation = 1008; + public const int CloseStatusMessageTooBig = 1009; + public const int CloseStatusInternalSrvError = 1011; + public const int CloseStatusTlsHandshake = 1015; + + // Compression constants (RFC 7692) + public const string PmcExtension = "permessage-deflate"; + public const string PmcSrvNoCtx = "server_no_context_takeover"; + public const string PmcCliNoCtx = "client_no_context_takeover"; + public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}"; + public static readonly string PmcFullResponse = $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n"; + + // Header names + public const string NoMaskingHeader = "Nats-No-Masking"; + public const string NoMaskingValue = "true"; + public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n"; + public const string XForwardedForHeader = "X-Forwarded-For"; + + // Path routing + public const string ClientPath = "/"; + public const string LeafNodePath = "/leafnode"; + public const string MqttPath = "/mqtt"; + + // Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2) + public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff]; + + public static bool IsControlFrame(int opcode) => opcode >= CloseMessage; +} + +public enum WsClientKind +{ + Client, + Leaf, + Mqtt, +} diff --git a/src/NATS.Server/WebSocket/WsFrameWriter.cs b/src/NATS.Server/WebSocket/WsFrameWriter.cs new file mode 100644 index 0000000..59ba4f8 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsFrameWriter.cs @@ -0,0 +1,171 @@ +using System.Buffers.Binary; +using System.Security.Cryptography; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// WebSocket frame construction, masking, and control message creation. +/// Ported from golang/nats-server/server/websocket.go lines 543-726. +/// +public static class WsFrameWriter +{ + /// + /// Creates a complete frame header for a single-frame message (first=true, final=true). + /// Returns (header bytes, mask key or null). + /// + public static (byte[] header, byte[]? key) CreateFrameHeader( + bool useMasking, bool compressed, int opcode, int payloadLength) + { + var fh = new byte[WsConstants.MaxFrameHeaderSize]; + var (n, key) = FillFrameHeader(fh, useMasking, + first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength); + return (fh[..n], key); + } + + /// + /// Fills a pre-allocated frame header buffer. + /// Returns (bytes written, mask key or null). + /// + public static (int written, byte[]? key) FillFrameHeader( + Span fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength) + { + byte b0 = first ? (byte)opcode : (byte)0; + if (final) b0 |= WsConstants.FinalBit; + if (compressed) b0 |= WsConstants.Rsv1Bit; + + byte b1 = 0; + if (useMasking) b1 |= WsConstants.MaskBit; + + int n; + switch (payloadLength) + { + case <= 125: + n = 2; + fh[0] = b0; + fh[1] = (byte)(b1 | (byte)payloadLength); + break; + case < 65536: + n = 4; + fh[0] = b0; + fh[1] = (byte)(b1 | 126); + BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength); + break; + default: + n = 10; + fh[0] = b0; + fh[1] = (byte)(b1 | 127); + BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength); + break; + } + + byte[]? key = null; + if (useMasking) + { + key = new byte[4]; + RandomNumberGenerator.Fill(key); + key.CopyTo(fh[n..]); + n += 4; + } + + return (n, key); + } + + /// + /// XOR masks a buffer with a 4-byte key. Applies in-place. + /// + public static void MaskBuf(ReadOnlySpan key, Span buf) + { + for (int i = 0; i < buf.Length; i++) + buf[i] ^= key[i & 3]; + } + + /// + /// XOR masks multiple contiguous buffers as if they were one. + /// + public static void MaskBufs(ReadOnlySpan key, List bufs) + { + int pos = 0; + foreach (var buf in bufs) + { + for (int j = 0; j < buf.Length; j++) + { + buf[j] ^= key[pos & 3]; + pos++; + } + } + } + + /// + /// Creates a close message payload: 2-byte status code + optional UTF-8 body. + /// Body truncated to fit MaxControlPayloadSize with "..." suffix. + /// + public static byte[] CreateCloseMessage(int status, string body) + { + var bodyBytes = Encoding.UTF8.GetBytes(body); + int maxBody = WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize; + + if (bodyBytes.Length > maxBody) + { + var suffix = "..."u8; + int truncLen = maxBody - suffix.Length; + // Find a valid UTF-8 boundary by walking back from truncation point + while (truncLen > 0 && (bodyBytes[truncLen] & 0xC0) == 0x80) + truncLen--; + var buf = new byte[WsConstants.CloseStatusSize + truncLen + suffix.Length]; + BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status); + bodyBytes.AsSpan(0, truncLen).CopyTo(buf.AsSpan(WsConstants.CloseStatusSize)); + suffix.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize + truncLen)); + return buf; + } + + var result = new byte[WsConstants.CloseStatusSize + bodyBytes.Length]; + BinaryPrimitives.WriteUInt16BigEndian(result, (ushort)status); + bodyBytes.CopyTo(result.AsSpan(WsConstants.CloseStatusSize)); + return result; + } + + /// + /// Builds a complete control frame (header + payload, optional masking). + /// + public static byte[] BuildControlFrame(int opcode, ReadOnlySpan payload, bool useMasking) + { + int headerSize = 2 + (useMasking ? 4 : 0); + var frame = new byte[headerSize + payload.Length]; + var span = frame.AsSpan(); + var (n, key) = FillFrameHeader(span, useMasking, + first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length); + if (payload.Length > 0) + { + payload.CopyTo(span[n..]); + if (useMasking && key != null) + MaskBuf(key, span[n..]); + } + + return frame; + } + + /// + /// Maps a ClientClosedReason to a WebSocket close status code. + /// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694. + /// + public static int MapCloseStatus(ClientClosedReason reason) => reason switch + { + ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure, + ClientClosedReason.AuthenticationTimeout or + ClientClosedReason.AuthenticationViolation or + ClientClosedReason.SlowConsumerPendingBytes or + ClientClosedReason.SlowConsumerWriteDeadline or + ClientClosedReason.MaxSubscriptionsExceeded or + ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation, + ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake, + ClientClosedReason.ParseError or + ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError, + ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig, + ClientClosedReason.WriteError or + ClientClosedReason.ReadError or + ClientClosedReason.StaleConnection or + ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway, + _ => WsConstants.CloseStatusInternalSrvError, + }; +} diff --git a/src/NATS.Server/WebSocket/WsOriginChecker.cs b/src/NATS.Server/WebSocket/WsOriginChecker.cs new file mode 100644 index 0000000..c11d1ce --- /dev/null +++ b/src/NATS.Server/WebSocket/WsOriginChecker.cs @@ -0,0 +1,81 @@ +namespace NATS.Server.WebSocket; + +/// +/// Validates WebSocket Origin headers per RFC 6455 Section 10.2. +/// Ported from golang/nats-server/server/websocket.go lines 933-1000. +/// +public sealed class WsOriginChecker +{ + private readonly bool _sameOrigin; + private readonly Dictionary? _allowedOrigins; + + public WsOriginChecker(bool sameOrigin, List? allowedOrigins) + { + _sameOrigin = sameOrigin; + if (allowedOrigins is { Count: > 0 }) + { + _allowedOrigins = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (var ao in allowedOrigins) + { + if (Uri.TryCreate(ao, UriKind.Absolute, out var uri)) + { + var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port); + _allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port); + } + } + } + } + + /// + /// Returns null if origin is allowed, or an error message if rejected. + /// + public string? CheckOrigin(string? origin, string requestHost, bool isTls) + { + if (!_sameOrigin && _allowedOrigins == null) + return null; + + if (string.IsNullOrEmpty(origin)) + return null; + + if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri)) + return $"invalid origin: {origin}"; + + var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port); + + if (_sameOrigin) + { + var (rh, rp) = ParseHostPort(requestHost, isTls); + if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp) + return "not same origin"; + } + + if (_allowedOrigins != null) + { + if (!_allowedOrigins.TryGetValue(oh, out var allowed) || + !string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) || + op != allowed.Port) + { + return "not in the allowed list"; + } + } + + return null; + } + + private static (string host, int port) GetHostAndPort(bool tls, string host, int port) + { + if (port <= 0) + port = tls ? 443 : 80; + return (host.ToLowerInvariant(), port); + } + + private static (string host, int port) ParseHostPort(string hostPort, bool isTls) + { + var colonIdx = hostPort.LastIndexOf(':'); + if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port)) + return (hostPort[..colonIdx].ToLowerInvariant(), port); + return (hostPort.ToLowerInvariant(), isTls ? 443 : 80); + } + + private readonly record struct AllowedOrigin(string Scheme, int Port); +} diff --git a/src/NATS.Server/WebSocket/WsReadInfo.cs b/src/NATS.Server/WebSocket/WsReadInfo.cs new file mode 100644 index 0000000..f6930c2 --- /dev/null +++ b/src/NATS.Server/WebSocket/WsReadInfo.cs @@ -0,0 +1,322 @@ +using System.Buffers.Binary; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// Per-connection WebSocket frame reading state machine. +/// Ported from golang/nats-server/server/websocket.go lines 156-506. +/// +public class WsReadInfo +{ + public int Remaining; + public bool FrameStart; + public bool FirstFrame; + public bool FrameCompressed; + public bool ExpectMask; + public byte MaskKeyPos; + public byte[] MaskKey; + public List? CompressedBuffers; + public int CompressedOffset; + + // Control frame outputs + public List PendingControlFrames; + public bool CloseReceived; + public int CloseStatus; + public string? CloseBody; + + public WsReadInfo(bool expectMask) + { + Remaining = 0; + FrameStart = true; + FirstFrame = true; + FrameCompressed = false; + ExpectMask = expectMask; + MaskKeyPos = 0; + MaskKey = new byte[4]; + CompressedBuffers = null; + CompressedOffset = 0; + PendingControlFrames = []; + CloseReceived = false; + CloseStatus = 0; + CloseBody = null; + } + + public void SetMaskKey(ReadOnlySpan key) + { + key[..4].CopyTo(MaskKey); + MaskKeyPos = 0; + } + + /// + /// Unmask buffer in-place using current mask key and position. + /// Optimized for 8-byte chunks when buffer is large enough. + /// Ported from websocket.go lines 509-536. + /// + public void Unmask(Span buf) + { + int p = MaskKeyPos; + if (buf.Length < 16) + { + for (int i = 0; i < buf.Length; i++) + { + buf[i] ^= MaskKey[p & 3]; + p++; + } + MaskKeyPos = (byte)(p & 3); + return; + } + + // Build 8-byte key for bulk XOR + Span k = stackalloc byte[8]; + for (int i = 0; i < 8; i++) + k[i] = MaskKey[(p + i) & 3]; + ulong km = BinaryPrimitives.ReadUInt64BigEndian(k); + + int n = (buf.Length / 8) * 8; + for (int i = 0; i < n; i += 8) + { + ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]); + tmp ^= km; + BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp); + } + + // Handle remaining bytes + p += n; + var tail = buf[n..]; + for (int i = 0; i < tail.Length; i++) + { + tail[i] ^= MaskKey[p & 3]; + p++; + } + MaskKeyPos = (byte)(p & 3); + } + + /// + /// Read and decode WebSocket frames from a buffer. + /// Returns list of decoded payload byte arrays. + /// Ported from websocket.go lines 208-351. + /// + public static List ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload) + { + var bufs = new List(); + var buf = new byte[available]; + int bytesRead = 0; + + // Fill the buffer from the stream + while (bytesRead < available) + { + int n = stream.Read(buf, bytesRead, available - bytesRead); + if (n == 0) break; + bytesRead += n; + } + + int pos = 0; + int max = bytesRead; + + while (pos < max) + { + if (r.FrameStart) + { + if (pos >= max) break; + byte b0 = buf[pos]; + int frameType = b0 & 0x0F; + bool final = (b0 & WsConstants.FinalBit) != 0; + bool compressed = (b0 & WsConstants.Rsv1Bit) != 0; + pos++; + + // Read second byte + var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1); + pos = newPos; + byte b1 = b1Buf[0]; + + // Check mask bit + if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0) + throw new InvalidOperationException("mask bit missing"); + + r.Remaining = b1 & 0x7F; + + // Validate frame types + if (WsConstants.IsControlFrame(frameType)) + { + if (r.Remaining > WsConstants.MaxControlPayloadSize) + throw new InvalidOperationException("control frame length too large"); + if (!final) + throw new InvalidOperationException("control frame does not have final bit set"); + } + else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage) + { + if (!r.FirstFrame) + throw new InvalidOperationException("new message before previous finished"); + r.FirstFrame = final; + r.FrameCompressed = compressed; + } + else if (frameType == WsConstants.ContinuationFrame) + { + if (r.FirstFrame || compressed) + throw new InvalidOperationException("invalid continuation frame"); + r.FirstFrame = final; + } + else + { + throw new InvalidOperationException($"unknown opcode {frameType}"); + } + + // Extended payload length + switch (r.Remaining) + { + case 126: + { + var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2); + pos = p2; + r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf); + break; + } + case 127: + { + var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8); + pos = p2; + var len64 = BinaryPrimitives.ReadUInt64BigEndian(lenBuf); + if (len64 > (ulong)maxPayload) + throw new InvalidOperationException($"frame payload length {len64} exceeds max payload {maxPayload}"); + r.Remaining = (int)len64; + break; + } + } + + // Read mask key (mask bit already validated at line 134) + if (r.ExpectMask) + { + var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4); + pos = p2; + keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey); + r.MaskKeyPos = 0; + } + + // Handle control frames + if (WsConstants.IsControlFrame(frameType)) + { + pos = HandleControlFrame(r, frameType, stream, buf, pos, max); + continue; + } + + r.FrameStart = false; + } + + if (pos < max) + { + int n = r.Remaining; + if (pos + n > max) n = max - pos; + + var payloadSlice = buf.AsSpan(pos, n).ToArray(); + pos += n; + r.Remaining -= n; + + if (r.ExpectMask) + r.Unmask(payloadSlice); + + bool addToBufs = true; + if (r.FrameCompressed) + { + addToBufs = false; + r.CompressedBuffers ??= []; + r.CompressedBuffers.Add(payloadSlice); + + if (r.FirstFrame && r.Remaining == 0) + { + var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload); + r.CompressedBuffers = null; + r.FrameCompressed = false; + addToBufs = true; + payloadSlice = decompressed; + } + } + + if (addToBufs && payloadSlice.Length > 0) + bufs.Add(payloadSlice); + + if (r.Remaining == 0) + r.FrameStart = true; + } + } + + return bufs; + } + + private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max) + { + byte[]? payload = null; + if (r.Remaining > 0) + { + var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining); + pos = newPos; + payload = payloadBuf; + if (r.ExpectMask) + r.Unmask(payload); + r.Remaining = 0; + } + + switch (frameType) + { + case WsConstants.CloseMessage: + r.CloseReceived = true; + r.CloseStatus = WsConstants.CloseStatusNoStatusReceived; + if (payload != null && payload.Length >= WsConstants.CloseStatusSize) + { + r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload); + if (payload.Length > WsConstants.CloseStatusSize) + r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize)); + } + // Per RFC 6455 Section 5.5.1, always send a close response + if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived) + { + var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? ""); + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg)); + } + else + { + // Empty close frame — respond with empty close + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, [])); + } + break; + + case WsConstants.PingMessage: + r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? [])); + break; + + case WsConstants.PongMessage: + // Nothing to do + break; + } + + return pos; + } + + /// + /// Gets needed bytes from buffer or reads from stream. + /// Ported from websocket.go lines 178-193. + /// + private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed) + { + int avail = max - pos; + if (avail >= needed) + return (buf[pos..(pos + needed)], pos + needed); + + var b = new byte[needed]; + int start = 0; + if (avail > 0) + { + Buffer.BlockCopy(buf, pos, b, 0, avail); + start = avail; + } + while (start < needed) + { + int n = stream.Read(b, start, needed - start); + if (n == 0) throw new IOException("unexpected end of stream"); + start += n; + } + return (b, pos + avail); + } +} + +public readonly record struct ControlFrameAction(int Opcode, byte[] Payload); diff --git a/src/NATS.Server/WebSocket/WsUpgrade.cs b/src/NATS.Server/WebSocket/WsUpgrade.cs new file mode 100644 index 0000000..d2fddbc --- /dev/null +++ b/src/NATS.Server/WebSocket/WsUpgrade.cs @@ -0,0 +1,268 @@ +using System.Net; +using System.Security.Cryptography; +using System.Text; + +namespace NATS.Server.WebSocket; + +/// +/// WebSocket HTTP upgrade handshake handler. +/// Ported from golang/nats-server/server/websocket.go lines 731-917. +/// +public static class WsUpgrade +{ + public static async Task TryUpgradeAsync( + Stream inputStream, Stream outputStream, WebSocketOptions options, + CancellationToken ct = default) + { + try + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(options.HandshakeTimeout); + var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token); + + if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase)) + return await FailAsync(outputStream, 405, "request method must be GET"); + + if (!headers.ContainsKey("Host")) + return await FailAsync(outputStream, 400, "'Host' missing in request"); + + if (!HeaderContains(headers, "Upgrade", "websocket")) + return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'"); + + if (!HeaderContains(headers, "Connection", "Upgrade")) + return await FailAsync(outputStream, 400, "invalid value for header 'Connection'"); + + if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key)) + return await FailAsync(outputStream, 400, "key missing"); + + if (!HeaderContains(headers, "Sec-WebSocket-Version", "13")) + return await FailAsync(outputStream, 400, "invalid version"); + + var kind = path switch + { + _ when path.EndsWith("/leafnode") => WsClientKind.Leaf, + _ when path.EndsWith("/mqtt") => WsClientKind.Mqtt, + _ => WsClientKind.Client, + }; + + // Origin checking + if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 }) + { + var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins); + headers.TryGetValue("Origin", out var origin); + if (string.IsNullOrEmpty(origin)) + headers.TryGetValue("Sec-WebSocket-Origin", out origin); + var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false); + if (originErr != null) + return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}"); + } + + // Compression negotiation + bool compress = options.Compression; + if (compress) + { + compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) && + ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase); + } + + // No-masking support (leaf nodes only — browser clients must always mask) + bool noMasking = kind == WsClientKind.Leaf && + headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) && + string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase); + + // Browser detection + bool browser = false; + bool noCompFrag = false; + if (kind is WsClientKind.Client or WsClientKind.Mqtt && + headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/")) + { + browser = true; + // Disable fragmentation of compressed frames for Safari browsers. + // Safari has both "Version/" and "Safari/" in the user agent string, + // while Chrome on macOS has "Safari/" but not "Version/". + noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/"); + } + + // Cookie extraction + string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null; + if ((kind is WsClientKind.Client or WsClientKind.Mqtt) && + headers.TryGetValue("Cookie", out var cookieHeader)) + { + var cookies = ParseCookies(cookieHeader); + if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt); + if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername); + if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword); + if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken); + } + + // X-Forwarded-For client IP extraction + string? clientIp = null; + if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff)) + { + var ip = xff.Split(',')[0].Trim(); + if (IPAddress.TryParse(ip, out _)) + clientIp = ip; + } + + // Build the 101 Switching Protocols response + var response = new StringBuilder(); + response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "); + response.Append(ComputeAcceptKey(key)); + response.Append("\r\n"); + if (compress) + response.Append(WsConstants.PmcFullResponse); + if (noMasking) + response.Append(WsConstants.NoMaskingFullResponse); + if (options.Headers != null) + { + foreach (var (k, v) in options.Headers) + { + response.Append(k); + response.Append(": "); + response.Append(v); + response.Append("\r\n"); + } + } + + response.Append("\r\n"); + + var responseBytes = Encoding.ASCII.GetBytes(response.ToString()); + await outputStream.WriteAsync(responseBytes); + await outputStream.FlushAsync(); + + return new WsUpgradeResult( + Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag, + MaskRead: !noMasking, MaskWrite: false, + CookieJwt: cookieJwt, CookieUsername: cookieUsername, + CookiePassword: cookiePassword, CookieToken: cookieToken, + ClientIp: clientIp, Kind: kind); + } + catch (Exception) + { + return WsUpgradeResult.Failed; + } + } + + /// + /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2. + /// + public static string ComputeAcceptKey(string clientKey) + { + var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + var hash = SHA1.HashData(combined); + return Convert.ToBase64String(hash); + } + + private static async Task FailAsync(Stream output, int statusCode, string reason) + { + var statusText = statusCode switch + { + 400 => "Bad Request", + 403 => "Forbidden", + 405 => "Method Not Allowed", + _ => "Internal Server Error", + }; + var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}"; + await output.WriteAsync(Encoding.ASCII.GetBytes(response)); + await output.FlushAsync(); + return WsUpgradeResult.Failed; + } + + private static async Task<(string method, string path, Dictionary headers)> ReadHttpRequestAsync( + Stream stream, CancellationToken ct) + { + var headerBytes = new List(4096); + var buf = new byte[512]; + while (true) + { + int n = await stream.ReadAsync(buf, ct); + if (n == 0) throw new IOException("connection closed during handshake"); + for (int i = 0; i < n; i++) + { + headerBytes.Add(buf[i]); + if (headerBytes.Count >= 4 && + headerBytes[^4] == '\r' && headerBytes[^3] == '\n' && + headerBytes[^2] == '\r' && headerBytes[^1] == '\n') + goto done; + if (headerBytes.Count > 8192) + throw new InvalidOperationException("HTTP header too large"); + } + } + done:; + + var text = Encoding.ASCII.GetString(headerBytes.ToArray()); + var lines = text.Split("\r\n", StringSplitOptions.None); + if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request"); + + var parts = lines[0].Split(' '); + if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line"); + var method = parts[0]; + var path = parts[1]; + + var headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + for (int i = 1; i < lines.Length; i++) + { + var line = lines[i]; + if (string.IsNullOrEmpty(line)) break; + var colonIdx = line.IndexOf(':'); + if (colonIdx > 0) + { + var name = line[..colonIdx].Trim(); + var value = line[(colonIdx + 1)..].Trim(); + headers[name] = value; + } + } + + return (method, path, headers); + } + + private static bool HeaderContains(Dictionary headers, string name, string value) + { + if (!headers.TryGetValue(name, out var headerValue)) + return false; + foreach (var token in headerValue.Split(',')) + { + if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase)) + return true; + } + + return false; + } + + private static Dictionary ParseCookies(string cookieHeader) + { + var cookies = new Dictionary(StringComparer.Ordinal); + foreach (var pair in cookieHeader.Split(';')) + { + var trimmed = pair.Trim(); + var eqIdx = trimmed.IndexOf('='); + if (eqIdx > 0) + cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim(); + } + + return cookies; + } +} + +/// +/// Result of a WebSocket upgrade handshake attempt. +/// +public readonly record struct WsUpgradeResult( + bool Success, + bool Compress, + bool Browser, + bool NoCompFrag, + bool MaskRead, + bool MaskWrite, + string? CookieJwt, + string? CookieUsername, + string? CookiePassword, + string? CookieToken, + string? ClientIp, + WsClientKind Kind) +{ + public static readonly WsUpgradeResult Failed = new( + Success: false, Compress: false, Browser: false, NoCompFrag: false, + MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null, + CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client); +} diff --git a/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs b/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs new file mode 100644 index 0000000..50c43b7 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs @@ -0,0 +1,26 @@ +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WebSocketOptionsTests +{ + [Fact] + public void DefaultOptions_PortIsNegativeOne_Disabled() + { + var opts = new WebSocketOptions(); + opts.Port.ShouldBe(-1); + opts.Host.ShouldBe("0.0.0.0"); + opts.Compression.ShouldBeFalse(); + opts.NoTls.ShouldBeFalse(); + opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2)); + opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2)); + } + + [Fact] + public void NatsOptions_HasWebSocketProperty() + { + var opts = new NatsOptions(); + opts.WebSocket.ShouldNotBeNull(); + opts.WebSocket.Port.ShouldBe(-1); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs b/tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs new file mode 100644 index 0000000..425534c --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs @@ -0,0 +1,58 @@ +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsCompressionTests +{ + [Fact] + public void CompressDecompress_RoundTrip() + { + var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray(); + var compressed = WsCompression.Compress(original); + compressed.ShouldNotBeNull(); + compressed.Length.ShouldBeGreaterThan(0); + + var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096); + decompressed.ShouldBe(original); + } + + [Fact] + public void Decompress_ExceedsMaxPayload_Throws() + { + var original = new byte[1000]; + Random.Shared.NextBytes(original); + var compressed = WsCompression.Compress(original); + + Should.Throw(() => + WsCompression.Decompress([compressed], maxPayload: 100)); + } + + [Fact] + public void Compress_RemovesTrailing4Bytes() + { + var data = new byte[200]; + Random.Shared.NextBytes(data); + var compressed = WsCompression.Compress(data); + + // The compressed data should be valid for decompression when we add the trailer back + var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096); + decompressed.ShouldBe(data); + } + + [Fact] + public void Decompress_MultipleBuffers() + { + var original = new byte[500]; + Random.Shared.NextBytes(original); + var compressed = WsCompression.Compress(original); + + // Split compressed data into multiple chunks + int mid = compressed.Length / 2; + var chunk1 = compressed[..mid]; + var chunk2 = compressed[mid..]; + + var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096); + decompressed.ShouldBe(original); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs b/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs new file mode 100644 index 0000000..8f30768 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs @@ -0,0 +1,124 @@ +using System.Buffers.Binary; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsConnectionTests +{ + [Fact] + public async Task ReadAsync_DecodesFrameAndReturnsPayload() + { + var payload = "SUB test 1\r\n"u8.ToArray(); + var frame = BuildUnmaskedFrame(payload); + var inner = new MemoryStream(frame); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + + n.ShouldBe(payload.Length); + buf[..n].ShouldBe(payload); + } + + [Fact] + public async Task WriteAsync_FramesPayload() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray(); + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // First 2 bytes should be WS frame header + (written[0] & WsConstants.FinalBit).ShouldNotBe(0); + (written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage); + int len = written[1] & 0x7F; + len.ShouldBe(payload.Length); + written[2..].ShouldBe(payload); + } + + [Fact] + public async Task WriteAsync_WithCompression_CompressesLargePayload() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = new byte[200]; + Array.Fill(payload, 0x41); // 'A' repeated - very compressible + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // RSV1 bit should be set for compressed frame + (written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0); + // Compressed size should be less than original + written.Length.ShouldBeLessThan(payload.Length + 10); + } + + [Fact] + public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled() + { + var inner = new MemoryStream(); + var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var payload = "Hi"u8.ToArray(); // Below CompressThreshold + await ws.WriteAsync(payload); + await ws.FlushAsync(); + + inner.Position = 0; + var written = inner.ToArray(); + // RSV1 bit should NOT be set for small payloads + (written[0] & WsConstants.Rsv1Bit).ShouldBe(0); + } + + [Fact] + public async Task ReadAsync_DecodesMaskedFrame() + { + var payload = "CONNECT {}\r\n"u8.ToArray(); + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: payload.Length); + var maskKey = header[^4..]; + WsFrameWriter.MaskBuf(maskKey, payload); + + var frame = new byte[header.Length + payload.Length]; + header.CopyTo(frame, 0); + payload.CopyTo(frame, header.Length); + + var inner = new MemoryStream(frame); + var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + + n.ShouldBe("CONNECT {}\r\n".Length); + System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n"); + } + + [Fact] + public async Task ReadAsync_ReturnsZero_OnEndOfStream() + { + // Empty stream should return 0 (true end of stream) + var inner = new MemoryStream([]); + var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false); + + var buf = new byte[256]; + int n = await ws.ReadAsync(buf); + n.ShouldBe(0); + } + + private static byte[] BuildUnmaskedFrame(byte[] payload) + { + var header = new byte[2]; + header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage); + header[1] = (byte)payload.Length; + var frame = new byte[2 + payload.Length]; + header.CopyTo(frame, 0); + payload.CopyTo(frame, 2); + return frame; + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs b/tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs new file mode 100644 index 0000000..3dd0b33 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs @@ -0,0 +1,53 @@ +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsConstantsTests +{ + [Fact] + public void OpCodes_MatchRfc6455() + { + WsConstants.TextMessage.ShouldBe(1); + WsConstants.BinaryMessage.ShouldBe(2); + WsConstants.CloseMessage.ShouldBe(8); + WsConstants.PingMessage.ShouldBe(9); + WsConstants.PongMessage.ShouldBe(10); + } + + [Fact] + public void FrameBits_MatchRfc6455() + { + WsConstants.FinalBit.ShouldBe((byte)0x80); + WsConstants.Rsv1Bit.ShouldBe((byte)0x40); + WsConstants.MaskBit.ShouldBe((byte)0x80); + } + + [Fact] + public void CloseStatusCodes_MatchRfc6455() + { + WsConstants.CloseStatusNormalClosure.ShouldBe(1000); + WsConstants.CloseStatusGoingAway.ShouldBe(1001); + WsConstants.CloseStatusProtocolError.ShouldBe(1002); + WsConstants.CloseStatusPolicyViolation.ShouldBe(1008); + WsConstants.CloseStatusMessageTooBig.ShouldBe(1009); + } + + [Theory] + [InlineData(WsConstants.CloseMessage)] + [InlineData(WsConstants.PingMessage)] + [InlineData(WsConstants.PongMessage)] + public void IsControlFrame_True(int opcode) + { + WsConstants.IsControlFrame(opcode).ShouldBeTrue(); + } + + [Theory] + [InlineData(WsConstants.TextMessage)] + [InlineData(WsConstants.BinaryMessage)] + [InlineData(0)] + public void IsControlFrame_False(int opcode) + { + WsConstants.IsControlFrame(opcode).ShouldBeFalse(); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs b/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs new file mode 100644 index 0000000..7e0e9df --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs @@ -0,0 +1,163 @@ +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsFrameReadTests +{ + /// Helper: build a single unmasked binary frame. + private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null) + { + int payloadLen = payload.Length; + byte b0 = (byte)opcode; + if (fin) b0 |= WsConstants.FinalBit; + if (compressed) b0 |= WsConstants.Rsv1Bit; + byte b1 = 0; + if (mask) b1 |= WsConstants.MaskBit; + + byte[] lenBytes; + if (payloadLen <= 125) + { + lenBytes = [(byte)(b1 | (byte)payloadLen)]; + } + else if (payloadLen < 65536) + { + lenBytes = new byte[3]; + lenBytes[0] = (byte)(b1 | 126); + BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen); + } + else + { + lenBytes = new byte[9]; + lenBytes[0] = (byte)(b1 | 127); + BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen); + } + + int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen; + var frame = new byte[totalLen]; + frame[0] = b0; + lenBytes.CopyTo(frame.AsSpan(1)); + int pos = 1 + lenBytes.Length; + if (mask && maskKey != null) + { + maskKey.CopyTo(frame.AsSpan(pos)); + pos += 4; + var maskedPayload = payload.ToArray(); + WsFrameWriter.MaskBuf(maskKey, maskedPayload); + maskedPayload.CopyTo(frame.AsSpan(pos)); + } + else + { + payload.CopyTo(frame.AsSpan(pos)); + } + return frame; + } + + [Fact] + public void ReadSingleUnmaskedFrame() + { + var payload = "Hello"u8.ToArray(); + var frame = BuildFrame(payload); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void ReadMaskedFrame() + { + var payload = "Hello"u8.ToArray(); + byte[] key = [0x37, 0xFA, 0x21, 0x3D]; + var frame = BuildFrame(payload, mask: true, maskKey: key); + + var readInfo = new WsReadInfo(expectMask: true); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void Read16BitLengthFrame() + { + var payload = new byte[200]; + Random.Shared.NextBytes(payload); + var frame = BuildFrame(payload); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(1); + result[0].ShouldBe(payload); + } + + [Fact] + public void ReadPingFrame_ReturnsPongAction() + { + var frame = BuildFrame([], opcode: WsConstants.PingMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); // control frames don't produce payload + readInfo.PendingControlFrames.Count.ShouldBe(1); + readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage); + } + + [Fact] + public void ReadCloseFrame_ReturnsCloseAction() + { + var closePayload = new byte[2]; + BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000); + var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); + readInfo.CloseReceived.ShouldBeTrue(); + readInfo.CloseStatus.ShouldBe(1000); + } + + [Fact] + public void ReadPongFrame_NoAction() + { + var frame = BuildFrame([], opcode: WsConstants.PongMessage); + + var readInfo = new WsReadInfo(expectMask: false); + var stream = new MemoryStream(frame); + var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024); + + result.Count.ShouldBe(0); + readInfo.PendingControlFrames.Count.ShouldBe(0); + } + + [Fact] + public void Unmask_Optimized_8ByteChunks() + { + byte[] key = [0xAA, 0xBB, 0xCC, 0xDD]; + var original = new byte[32]; + Random.Shared.NextBytes(original); + var masked = original.ToArray(); + + // Mask it + for (int i = 0; i < masked.Length; i++) + masked[i] ^= key[i & 3]; + + // Unmask using the state machine + var info = new WsReadInfo(expectMask: true); + info.SetMaskKey(key); + info.Unmask(masked); + + masked.ShouldBe(original); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs b/tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs new file mode 100644 index 0000000..153b120 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs @@ -0,0 +1,152 @@ +using System.Buffers.Binary; +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsFrameWriterTests +{ + [Fact] + public void CreateFrameHeader_SmallPayload_7BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 100); + header.Length.ShouldBe(2); + (header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set + (header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage); + (header[1] & 0x7F).ShouldBe(100); + } + + [Fact] + public void CreateFrameHeader_MediumPayload_16BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 1000); + header.Length.ShouldBe(4); + (header[1] & 0x7F).ShouldBe(126); + BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000); + } + + [Fact] + public void CreateFrameHeader_LargePayload_64BitLength() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 70000); + header.Length.ShouldBe(10); + (header[1] & 0x7F).ShouldBe(127); + BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL); + } + + [Fact] + public void CreateFrameHeader_WithMasking_Adds4ByteKey() + { + var (header, key) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: 10); + header.Length.ShouldBe(6); // 2 header + 4 mask key + (header[1] & WsConstants.MaskBit).ShouldNotBe(0); + key.ShouldNotBeNull(); + key.Length.ShouldBe(4); + } + + [Fact] + public void CreateFrameHeader_Compressed_SetsRsv1Bit() + { + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: false, compressed: true, + opcode: WsConstants.BinaryMessage, payloadLength: 10); + (header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0); + } + + [Fact] + public void MaskBuf_XorsCorrectly() + { + byte[] key = [0xAA, 0xBB, 0xCC, 0xDD]; + byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]; + byte[] expected = new byte[data.Length]; + for (int i = 0; i < data.Length; i++) + expected[i] = (byte)(data[i] ^ key[i & 3]); + + WsFrameWriter.MaskBuf(key, data); + data.ShouldBe(expected); + } + + [Fact] + public void MaskBuf_RoundTrip() + { + byte[] key = [0x12, 0x34, 0x56, 0x78]; + byte[] original = "Hello, WebSocket!"u8.ToArray(); + var data = original.ToArray(); + + WsFrameWriter.MaskBuf(key, data); + data.ShouldNotBe(original); + WsFrameWriter.MaskBuf(key, data); + data.ShouldBe(original); + } + + [Fact] + public void CreateCloseMessage_WithStatusAndBody() + { + var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure"); + msg.Length.ShouldBe(2 + "normal closure".Length); + BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000); + } + + [Fact] + public void CreateCloseMessage_LongBody_Truncated() + { + var longBody = new string('x', 200); + var msg = WsFrameWriter.CreateCloseMessage(1000, longBody); + msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize); + } + + [Fact] + public void MapCloseStatus_ClientClosed_NormalClosure() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed) + .ShouldBe(WsConstants.CloseStatusNormalClosure); + } + + [Fact] + public void MapCloseStatus_AuthTimeout_PolicyViolation() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout) + .ShouldBe(WsConstants.CloseStatusPolicyViolation); + } + + [Fact] + public void MapCloseStatus_ParseError_ProtocolError() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError) + .ShouldBe(WsConstants.CloseStatusProtocolError); + } + + [Fact] + public void MapCloseStatus_MaxPayload_MessageTooBig() + { + WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded) + .ShouldBe(WsConstants.CloseStatusMessageTooBig); + } + + [Fact] + public void BuildControlFrame_PingNomask() + { + var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false); + frame.Length.ShouldBe(2); + (frame[0] & WsConstants.FinalBit).ShouldNotBe(0); + (frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage); + (frame[1] & 0x7F).ShouldBe(0); + } + + [Fact] + public void BuildControlFrame_PongWithPayload() + { + byte[] payload = [1, 2, 3, 4]; + var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false); + frame.Length.ShouldBe(2 + 4); + frame[2..].ShouldBe(payload); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs b/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs new file mode 100644 index 0000000..c20edee --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs @@ -0,0 +1,162 @@ +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using System.Security.Cryptography; +using System.Text; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsIntegrationTests : IAsyncLifetime +{ + private NatsServer _server = null!; + private NatsOptions _options = null!; + + public async Task InitializeAsync() + { + _options = new NatsOptions + { + Port = 0, + WebSocket = new WebSocketOptions { Port = 0, NoTls = true }, + }; + var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { }); + _server = new NatsServer(_options, loggerFactory); + _ = _server.StartAsync(CancellationToken.None); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _server.ShutdownAsync(); + _server.Dispose(); + } + + [Fact] + public async Task WebSocket_ConnectAndReceiveInfo() + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port)); + using var stream = new NetworkStream(socket, ownsSocket: false); + + await SendUpgradeRequest(stream); + var response = await ReadHttpResponse(stream); + response.ShouldContain("101"); + + var wsFrame = await ReadWsFrame(stream); + var info = Encoding.ASCII.GetString(wsFrame); + info.ShouldStartWith("INFO "); + } + + [Fact] + public async Task WebSocket_ConnectAndPing() + { + using var client = await ConnectWsClient(); + + // Send CONNECT and PING together + await SendWsText(client, "CONNECT {}\r\nPING\r\n"); + + // Read PONG WS frame + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var pong = await ReadWsFrameAsync(client, cts.Token); + Encoding.ASCII.GetString(pong).ShouldContain("PONG"); + } + + [Fact] + public async Task WebSocket_PubSub() + { + using var sub = await ConnectWsClient(); + using var pub = await ConnectWsClient(); + + await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n"); + await Task.Delay(200); + + await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n"); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var msg = await ReadWsFrameAsync(sub, cts.Token); + Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5"); + } + + private async Task ConnectWsClient() + { + var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port)); + var stream = new NetworkStream(socket, ownsSocket: true); + + await SendUpgradeRequest(stream); + var response = await ReadHttpResponse(stream); + response.ShouldContain("101"); + + await ReadWsFrame(stream); // Read INFO frame + return stream; + } + + private static async Task SendUpgradeRequest(NetworkStream stream) + { + var keyBytes = new byte[16]; + RandomNumberGenerator.Fill(keyBytes); + var key = Convert.ToBase64String(keyBytes); + + var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n"; + await stream.WriteAsync(Encoding.ASCII.GetBytes(request)); + await stream.FlushAsync(); + } + + private static async Task ReadHttpResponse(NetworkStream stream) + { + // Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response + var sb = new StringBuilder(); + var buf = new byte[1]; + while (true) + { + int n = await stream.ReadAsync(buf); + if (n == 0) break; + sb.Append((char)buf[0]); + if (sb.Length >= 4 && + sb[^4] == '\r' && sb[^3] == '\n' && + sb[^2] == '\r' && sb[^1] == '\n') + break; + } + + return sb.ToString(); + } + + private static Task ReadWsFrame(NetworkStream stream) + => ReadWsFrameAsync(stream, CancellationToken.None); + + private static async Task ReadWsFrameAsync(NetworkStream stream, CancellationToken ct) + { + var header = new byte[2]; + await stream.ReadExactlyAsync(header, ct); + int len = header[1] & 0x7F; + if (len == 126) + { + var extLen = new byte[2]; + await stream.ReadExactlyAsync(extLen, ct); + len = BinaryPrimitives.ReadUInt16BigEndian(extLen); + } + else if (len == 127) + { + var extLen = new byte[8]; + await stream.ReadExactlyAsync(extLen, ct); + len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen); + } + + var payload = new byte[len]; + if (len > 0) await stream.ReadExactlyAsync(payload, ct); + return payload; + } + + private static async Task SendWsText(NetworkStream stream, string text) + { + var payload = Encoding.ASCII.GetBytes(text); + var (header, _) = WsFrameWriter.CreateFrameHeader( + useMasking: true, compressed: false, + opcode: WsConstants.BinaryMessage, payloadLength: payload.Length); + var maskKey = header[^4..]; + WsFrameWriter.MaskBuf(maskKey, payload); + await stream.WriteAsync(header); + await stream.WriteAsync(payload); + await stream.FlushAsync(); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs b/tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs new file mode 100644 index 0000000..ebd3531 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs @@ -0,0 +1,82 @@ +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsOriginCheckerTests +{ + [Fact] + public void NoOriginHeader_Accepted() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false) + .ShouldBeNull(); + } + + [Fact] + public void NeitherSameNorList_AlwaysAccepted() + { + var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null); + checker.CheckOrigin("https://evil.com", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_Match() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://localhost:4222", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_Mismatch() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://other:4222", "localhost:4222", false) + .ShouldNotBeNull(); + } + + [Fact] + public void SameOrigin_DefaultPort_Http() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://localhost", "localhost:80", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_DefaultPort_Https() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("https://localhost", "localhost:443", true) + .ShouldBeNull(); + } + + [Fact] + public void AllowedOrigins_Match() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("https://app.example.com", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void AllowedOrigins_Mismatch() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("https://evil.example.com", "localhost:4222", false) + .ShouldNotBeNull(); + } + + [Fact] + public void AllowedOrigins_SchemeMismatch() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("http://app.example.com", "localhost:4222", false) + .ShouldNotBeNull(); + } +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs b/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs new file mode 100644 index 0000000..a5e1168 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs @@ -0,0 +1,226 @@ +using System.Text; +using NATS.Server.WebSocket; + +namespace NATS.Server.Tests.WebSocket; + +public class WsUpgradeTests +{ + private static string BuildValidRequest(string path = "/", string? extraHeaders = null) + { + var sb = new StringBuilder(); + sb.Append($"GET {path} HTTP/1.1\r\n"); + sb.Append("Host: localhost:4222\r\n"); + sb.Append("Upgrade: websocket\r\n"); + sb.Append("Connection: Upgrade\r\n"); + sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"); + sb.Append("Sec-WebSocket-Version: 13\r\n"); + if (extraHeaders != null) + sb.Append(extraHeaders); + sb.Append("\r\n"); + return sb.ToString(); + } + + [Fact] + public async Task ValidUpgrade_Returns101() + { + var request = BuildValidRequest(); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Client); + var response = ReadResponse(outputStream); + response.ShouldContain("HTTP/1.1 101"); + response.ShouldContain("Upgrade: websocket"); + response.ShouldContain("Sec-WebSocket-Accept:"); + } + + [Fact] + public async Task MissingUpgradeHeader_Returns400() + { + var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + ReadResponse(outputStream).ShouldContain("400"); + } + + [Fact] + public async Task MissingHost_Returns400() + { + var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + } + + [Fact] + public async Task WrongVersion_Returns400() + { + var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + } + + [Fact] + public async Task LeafNodePath_ReturnsLeafKind() + { + var request = BuildValidRequest("/leafnode"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Leaf); + } + + [Fact] + public async Task MqttPath_ReturnsMqttKind() + { + var request = BuildValidRequest("/mqtt"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Kind.ShouldBe(WsClientKind.Mqtt); + } + + [Fact] + public async Task CompressionNegotiation_WhenEnabled() + { + var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true }); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeTrue(); + ReadResponse(outputStream).ShouldContain("permessage-deflate"); + } + + [Fact] + public async Task CompressionNegotiation_WhenDisabled() + { + var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false }); + + result.Success.ShouldBeTrue(); + result.Compress.ShouldBeFalse(); + } + + [Fact] + public async Task NoMaskingHeader_ForLeaf() + { + var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.MaskRead.ShouldBeFalse(); + } + + [Fact] + public async Task BrowserDetection_Mozilla() + { + var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.Browser.ShouldBeTrue(); + } + + [Fact] + public async Task SafariDetection_NoCompFrag() + { + var request = BuildValidRequest(extraHeaders: + "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" + + $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true }); + + result.Success.ShouldBeTrue(); + result.NoCompFrag.ShouldBeTrue(); + } + + [Fact] + public void AcceptKey_MatchesRfc6455Example() + { + // RFC 6455 Section 4.2.2 example + var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + [Fact] + public async Task CookieExtraction() + { + var request = BuildValidRequest(extraHeaders: + "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var opts = new WebSocketOptions + { + NoTls = true, + JwtCookie = "jwt_token", + UsernameCookie = "nats_user", + PasswordCookie = "nats_pass", + }; + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts); + + result.Success.ShouldBeTrue(); + result.CookieJwt.ShouldBe("my-jwt"); + result.CookieUsername.ShouldBe("admin"); + result.CookiePassword.ShouldBe("secret"); + } + + [Fact] + public async Task XForwardedFor_ExtractsClientIp() + { + var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n"); + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeTrue(); + result.ClientIp.ShouldBe("192.168.1.100"); + } + + [Fact] + public async Task PostMethod_Returns405() + { + var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + var (inputStream, outputStream) = CreateStreamPair(request); + + var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true }); + + result.Success.ShouldBeFalse(); + ReadResponse(outputStream).ShouldContain("405"); + } + + // Helper: create a readable input stream and writable output stream + private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest) + { + var inputBytes = Encoding.ASCII.GetBytes(httpRequest); + return (new MemoryStream(inputBytes), new MemoryStream()); + } + + private static string ReadResponse(MemoryStream output) + { + output.Position = 0; + return Encoding.ASCII.GetString(output.ToArray()); + } +}