Compare commits
16 Commits
feature/sy
...
1ebf283a8c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ebf283a8c | ||
|
|
18a6d0f478 | ||
|
|
02a474a91e | ||
|
|
c8a89c9de2 | ||
|
|
5fd2cf040d | ||
|
|
ca88036126 | ||
|
|
6d0a4d259e | ||
|
|
fe304dfe01 | ||
|
|
1c948b5b0f | ||
|
|
bd29c529a8 | ||
|
|
1a1aa9d642 | ||
|
|
d49bc5b0d7 | ||
|
|
8ded10d49b | ||
|
|
6981a38b72 | ||
|
|
72f60054ed | ||
|
|
708e1b4168 |
@@ -67,7 +67,7 @@
|
||||
| SYSTEM (internal) | Y | N | |
|
||||
| JETSTREAM (internal) | Y | N | |
|
||||
| ACCOUNT (internal) | Y | N | |
|
||||
| WebSocket clients | Y | N | |
|
||||
| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth |
|
||||
| MQTT clients | Y | N | |
|
||||
|
||||
### Client Features
|
||||
@@ -267,7 +267,8 @@ Go implements a sophisticated slow consumer detection system:
|
||||
- ~~Advanced limits (MaxSubs, MaxSubTokens, MaxPending, WriteDeadline)~~ — `MaxSubs`, `MaxSubTokens` implemented; MaxPending/WriteDeadline already existed
|
||||
- ~~Tags/metadata~~ — `Tags` dictionary implemented in `NatsOptions`
|
||||
- ~~OCSP configuration~~ — `OcspConfig` with 4 modes (Auto/Always/Must/Never), peer verification, and stapling
|
||||
- WebSocket/MQTT options
|
||||
- ~~WebSocket options~~ — `WebSocketOptions` with port, compression, origin checking, cookie auth, custom headers
|
||||
- MQTT options
|
||||
- ~~Operator mode / account resolver~~ — `JwtAuthenticator` + `IAccountResolver` + `MemAccountResolver` with trusted keys
|
||||
|
||||
---
|
||||
|
||||
141
docs/plans/2026-02-23-jetstream-full-parity-design.md
Normal file
141
docs/plans/2026-02-23-jetstream-full-parity-design.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Full JetStream and Cluster Prerequisite Parity Design
|
||||
|
||||
**Date:** 2026-02-23
|
||||
**Status:** Approved
|
||||
**Scope:** Port JetStream from Go with all prerequisite subsystems required for full Go JetStream test parity, including cluster route/gateway/leaf behaviors and RAFT/meta-cluster semantics.
|
||||
**Verification Gate:** Go JetStream-focused test suites in `golang/nats-server/server/` plus new/updated .NET tests.
|
||||
**Cutover Model:** Single end-to-end cutover (no interim acceptance gates).
|
||||
|
||||
## 1. Architecture
|
||||
|
||||
The implementation uses a full in-process .NET parity architecture that mirrors Go subsystem boundaries while keeping strict internal contracts.
|
||||
|
||||
1. Core Server Layer (`NatsServer`/`NatsClient`)
|
||||
- Extend existing server/client runtime to support full client kinds and inter-server protocol paths.
|
||||
- Preserve responsibility for socket lifecycle, parser integration, auth entry, and local dispatch.
|
||||
|
||||
2. Cluster Fabric Layer
|
||||
- Add route mesh, gateway links, leafnode links, interest propagation, and remote subscription accounting.
|
||||
- Provide transport-neutral contracts consumed by JetStream and RAFT replication services.
|
||||
|
||||
3. JetStream Control Plane
|
||||
- Add account-scoped JetStream managers, API subject handlers (`$JS.API.*`), stream/consumer metadata lifecycle, advisories, and limit enforcement.
|
||||
- Integrate with RAFT/meta services for replicated decisions.
|
||||
|
||||
4. JetStream Data Plane
|
||||
- Add stream ingest path, retention/eviction logic, consumer delivery/ack/redelivery, mirror/source orchestration, and flow-control behavior.
|
||||
- Use pluggable storage abstractions with parity-focused behavior.
|
||||
|
||||
5. RAFT and Replication Layer
|
||||
- Implement meta-group plus per-asset replication groups, election/term logic, log replication, snapshots, and catchup.
|
||||
- Expose deterministic commit/applied hooks to JetStream runtime layers.
|
||||
|
||||
6. Storage Layer
|
||||
- Implement memstore and filestore with sequence indexing, subject indexing, compaction/snapshot support, and recovery semantics.
|
||||
|
||||
7. Observability Layer
|
||||
- Upgrade `/jsz` and `/varz` JetStream blocks from placeholders to live runtime reporting with Go-compatible response shape.
|
||||
|
||||
## 2. Components and Contracts
|
||||
|
||||
### 2.1 New component families
|
||||
|
||||
1. Cluster and interserver subsystem
|
||||
- Add route/gateway/leaf and interserver protocol operations under `src/NATS.Server/`.
|
||||
- Extend parser/dispatcher with route/leaf/account operations currently excluded.
|
||||
- Expand client-kind model and command routing constraints.
|
||||
|
||||
2. JetStream API and domain model
|
||||
- Add `src/NATS.Server/JetStream/` subtree for API payload models, stream/consumer models, and error templates/codes.
|
||||
|
||||
3. JetStream runtime
|
||||
- Add stream manager, consumer manager, ack processor, delivery scheduler, mirror/source orchestration, and flow control handlers.
|
||||
- Integrate publish path with stream capture/store/ack behavior.
|
||||
|
||||
4. RAFT subsystem
|
||||
- Add `src/NATS.Server/Raft/` for replicated logs, elections, snapshots, and membership operations.
|
||||
|
||||
5. Storage subsystem
|
||||
- Add `src/NATS.Server/JetStream/Storage/` for `MemStore` and `FileStore`, sequence/subject indexes, and restart recovery.
|
||||
|
||||
### 2.2 Existing components to upgrade
|
||||
|
||||
1. `src/NATS.Server/NatsOptions.cs`
|
||||
- Add full config surface for clustering, JetStream, storage, placement, and parity-required limits.
|
||||
|
||||
2. `src/NATS.Server/Configuration/ConfigProcessor.cs`
|
||||
- Replace silent ignore behavior for cluster/jetstream keys with parsing, mapping, and validation.
|
||||
|
||||
3. `src/NATS.Server/Protocol/NatsParser.cs` and `src/NATS.Server/NatsClient.cs`
|
||||
- Add missing interserver operations and kind-aware dispatch paths needed for clustered JetStream behavior.
|
||||
|
||||
4. Monitoring components
|
||||
- Upgrade `src/NATS.Server/Monitoring/MonitorServer.cs` and `src/NATS.Server/Monitoring/Varz.cs`.
|
||||
- Add/extend JS monitoring handlers and models for `/jsz` and JetStream runtime fields.
|
||||
|
||||
## 3. Data Flow and Behavioral Semantics
|
||||
|
||||
1. Inbound publish path
|
||||
- Parse client publish commands, apply auth/permission checks, route to local subscribers and JetStream candidates.
|
||||
- For JetStream subjects: apply preconditions, append to store, replicate via RAFT (as required), apply committed state, return Go-compatible pub ack.
|
||||
|
||||
2. Consumer delivery path
|
||||
- Use shared push/pull state model for pending, ack floor, redelivery timers, flow control, and max ack pending.
|
||||
- Enforce retention policy semantics (limits/interest/workqueue), filter subject behavior, replay policy, and eviction behavior.
|
||||
|
||||
3. Replication and control flow
|
||||
- Meta RAFT governs replicated metadata decisions.
|
||||
- Per-stream/per-consumer groups replicate state and snapshots.
|
||||
- Leader changes preserve at-least-once delivery and consumer state invariants.
|
||||
|
||||
4. Recovery flow
|
||||
- Reconstruct stream/consumer/store state on startup.
|
||||
- In clustered mode, rejoin replication groups and catch up before serving full API/delivery workload.
|
||||
- Preserve sequence continuity, subject indexes, delete markers, and pending/redelivery state.
|
||||
|
||||
5. Monitoring flow
|
||||
- `/varz` JetStream fields and `/jsz` return live runtime state.
|
||||
- Advisory and metric surfaces update from control-plane and data-plane events.
|
||||
|
||||
## 4. Error Handling and Operational Constraints
|
||||
|
||||
1. API error parity
|
||||
- Match canonical JetStream codes/messages for validation failures, state conflicts, limits, leadership/quorum issues, and storage failures.
|
||||
|
||||
2. Protocol behavior
|
||||
- Preserve normal client compatibility while adding interserver protocol and internal client-kind restrictions.
|
||||
|
||||
3. Storage and consistency failures
|
||||
- Classify corruption/truncation/checksum/snapshot failures as recoverable vs non-recoverable.
|
||||
- Avoid silent data loss and emit monitoring/advisory signals where parity requires.
|
||||
|
||||
4. Cluster and RAFT fault handling
|
||||
- Explicitly handle no-quorum, stale leader, delayed apply, peer removal, catchup lag, and stepdown transitions.
|
||||
- Return leadership-aware API errors.
|
||||
|
||||
5. Config/reload behavior
|
||||
- Treat JetStream and cluster config as first-class with strict validation.
|
||||
- Mirror Go-like reloadable vs restart-required change boundaries.
|
||||
|
||||
## 5. Testing and Verification Strategy
|
||||
|
||||
1. .NET unit tests
|
||||
- Add focused tests for JetStream API validation, stream and consumer state, RAFT primitives, mem/file store invariants, and config parsing/validation.
|
||||
|
||||
2. .NET integration tests
|
||||
- Add end-to-end tests for publish/store/consume/ack behavior, retention policies, restart recovery, and clustered prerequisites used by JetStream.
|
||||
|
||||
3. Parity harness
|
||||
- Maintain mapping of Go JetStream test categories to .NET feature areas.
|
||||
- Execute JetStream-focused Go tests from `golang/nats-server/server/` as acceptance benchmark.
|
||||
|
||||
4. `differences.md` policy
|
||||
- Update only after verification gate passes.
|
||||
- Remove opening JetStream exclusion scope statement and replace with updated parity scope.
|
||||
|
||||
## 6. Scope Decisions Captured
|
||||
|
||||
- Include all prerequisite non-JetStream subsystems required to satisfy full Go JetStream tests.
|
||||
- Verification target is full Go JetStream-focused parity, not a narrowed subset.
|
||||
- Delivery model is single end-to-end cutover.
|
||||
- `differences.md` top-level scope statement will be updated to include JetStream and clustering parity coverage once verified.
|
||||
@@ -1,18 +1,21 @@
|
||||
# MQTT Connection Type Port Design
|
||||
|
||||
## Goal
|
||||
Port MQTT-related connection type parity from Go into the .NET server for two scoped areas:
|
||||
Port MQTT-related connection type parity from Go into the .NET server for three scoped areas:
|
||||
1. JWT `allowed_connection_types` behavior for `MQTT` / `MQTT_WS` (plus existing known types).
|
||||
2. `/connz` filtering by `mqtt_client`.
|
||||
3. Full MQTT configuration parsing from `mqtt {}` config blocks (all Go `MQTTOpts` fields).
|
||||
|
||||
## Scope
|
||||
- In scope:
|
||||
- JWT allowed connection type normalization and enforcement semantics.
|
||||
- `/connz?mqtt_client=` option parsing and filtering.
|
||||
- MQTT configuration model and config file parsing (all Go `MQTTOpts` fields).
|
||||
- Expanded `MqttOptsVarz` monitoring output.
|
||||
- Unit/integration tests for new and updated behavior.
|
||||
- `differences.md` updates after implementation is verified.
|
||||
- Out of scope:
|
||||
- Full MQTT transport implementation.
|
||||
- Full MQTT transport implementation (listener, protocol parser, sessions).
|
||||
- WebSocket transport implementation.
|
||||
- Leaf/route/gateway transport plumbing.
|
||||
|
||||
@@ -27,6 +30,8 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
|
||||
- Extend connz monitoring options to parse `mqtt_client` and apply exact-match filtering before sort/pagination.
|
||||
|
||||
## Components
|
||||
|
||||
### JWT Connection-Type Enforcement
|
||||
- `src/NATS.Server/Auth/IAuthenticator.cs`
|
||||
- Extend `ClientAuthContext` with a connection-type value.
|
||||
- `src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs` (new)
|
||||
@@ -38,6 +43,8 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
|
||||
- Enforce against current `ClientAuthContext.ConnectionType`.
|
||||
- `src/NATS.Server/NatsClient.cs`
|
||||
- Populate auth context connection type (currently `STANDARD`).
|
||||
|
||||
### Connz MQTT Client Filtering
|
||||
- `src/NATS.Server/Monitoring/Connz.cs`
|
||||
- Add `MqttClient` to `ConnzOptions` with JSON field `mqtt_client`.
|
||||
- `src/NATS.Server/Monitoring/ConnzHandler.cs`
|
||||
@@ -48,6 +55,30 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
|
||||
- `src/NATS.Server/NatsServer.cs`
|
||||
- Persist `MqttClient` into `ClosedClient` snapshot (empty for now).
|
||||
|
||||
### MQTT Configuration Parsing
|
||||
- `src/NATS.Server/MqttOptions.cs` (new)
|
||||
- Full model matching Go `MQTTOpts` struct (opts.go:613-707):
|
||||
- Network: `Host`, `Port`
|
||||
- Auth override: `NoAuthUser`, `Username`, `Password`, `Token`, `AuthTimeout`
|
||||
- TLS: `TlsCert`, `TlsKey`, `TlsCaCert`, `TlsVerify`, `TlsTimeout`, `TlsMap`, `TlsPinnedCerts`
|
||||
- JetStream: `JsDomain`, `StreamReplicas`, `ConsumerReplicas`, `ConsumerMemoryStorage`, `ConsumerInactiveThreshold`
|
||||
- QoS: `AckWait`, `MaxAckPending`, `JsApiTimeout`
|
||||
- `src/NATS.Server/NatsOptions.cs`
|
||||
- Add `Mqtt` property of type `MqttOptions?`.
|
||||
- `src/NATS.Server/Configuration/ConfigProcessor.cs`
|
||||
- Add `ParseMqtt()` for `mqtt {}` config block with Go-compatible key aliases:
|
||||
- `host`/`net` → Host, `listen` → Host+Port
|
||||
- `ack_wait`/`ackwait` → AckWait
|
||||
- `max_ack_pending`/`max_pending`/`max_inflight` → MaxAckPending
|
||||
- `js_domain` → JsDomain
|
||||
- `js_api_timeout`/`api_timeout` → JsApiTimeout
|
||||
- `consumer_inactive_threshold`/`consumer_auto_cleanup` → ConsumerInactiveThreshold
|
||||
- Nested `tls {}` and `authorization {}`/`authentication {}` blocks
|
||||
- `src/NATS.Server/Monitoring/Varz.cs`
|
||||
- Expand `MqttOptsVarz` from 3 fields to full monitoring-visible set.
|
||||
- `src/NATS.Server/Monitoring/VarzHandler.cs`
|
||||
- Populate expanded `MqttOptsVarz` from `NatsOptions.Mqtt`.
|
||||
|
||||
## Data Flow
|
||||
1. Client sends `CONNECT`.
|
||||
2. `NatsClient.ProcessConnectAsync` builds `ClientAuthContext` with `ConnectionType=STANDARD`.
|
||||
@@ -73,6 +104,7 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
|
||||
- MQTT transport is not implemented yet in this repository.
|
||||
- Runtime connection type currently resolves to `STANDARD` in auth context.
|
||||
- `mqtt_client` values remain empty until MQTT path populates them.
|
||||
- MQTT config is parsed and stored but no listener is started.
|
||||
|
||||
## Testing Strategy
|
||||
- `tests/NATS.Server.Tests/JwtAuthenticatorTests.cs`
|
||||
@@ -85,9 +117,16 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
|
||||
- `/connz?mqtt_client=<id>` returns matching connections only.
|
||||
- `/connz?state=closed&mqtt_client=<id>` filters closed snapshots.
|
||||
- non-existing ID yields empty connection set.
|
||||
- `tests/NATS.Server.Tests/ConfigProcessorTests.cs` (or similar)
|
||||
- Parse valid `mqtt {}` block with all fields.
|
||||
- Parse config with aliases (ackwait vs ack_wait, host vs net, etc.).
|
||||
- Parse nested `tls {}` and `authorization {}` blocks within mqtt.
|
||||
- Varz MQTT section populated from config.
|
||||
|
||||
## Success Criteria
|
||||
- JWT `allowed_connection_types` behavior matches Go semantics for known/unknown mixing and unknown-only rejection.
|
||||
- `/connz` supports exact `mqtt_client` filtering for open and closed sets.
|
||||
- `mqtt {}` config block parses all Go `MQTTOpts` fields with aliases.
|
||||
- `MqttOptsVarz` includes full monitoring output.
|
||||
- Added tests pass.
|
||||
- `differences.md` accurately reflects implemented parity.
|
||||
|
||||
@@ -11,6 +11,7 @@ using NATS.Server.Auth;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
using NATS.Server.Tls;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server;
|
||||
|
||||
@@ -93,6 +94,9 @@ public sealed class NatsClient : IDisposable
|
||||
private long _rtt;
|
||||
public TimeSpan Rtt => new(Interlocked.Read(ref _rtt));
|
||||
|
||||
public bool IsWebSocket { get; set; }
|
||||
public WsUpgradeResult? WsInfo { get; set; }
|
||||
|
||||
public TlsConnectionState? TlsState { get; set; }
|
||||
public bool InfoAlreadySent { get; set; }
|
||||
|
||||
|
||||
@@ -116,4 +116,32 @@ public sealed class NatsOptions
|
||||
public Dictionary<string, string>? SubjectMappings { get; set; }
|
||||
|
||||
public bool HasTls => TlsCert != null && TlsKey != null;
|
||||
|
||||
// WebSocket
|
||||
public WebSocketOptions WebSocket { get; set; } = new();
|
||||
}
|
||||
|
||||
public sealed class WebSocketOptions
|
||||
{
|
||||
public string Host { get; set; } = "0.0.0.0";
|
||||
public int Port { get; set; } = -1;
|
||||
public string? Advertise { get; set; }
|
||||
public string? NoAuthUser { get; set; }
|
||||
public string? JwtCookie { get; set; }
|
||||
public string? UsernameCookie { get; set; }
|
||||
public string? PasswordCookie { get; set; }
|
||||
public string? TokenCookie { get; set; }
|
||||
public string? Username { get; set; }
|
||||
public string? Password { get; set; }
|
||||
public string? Token { get; set; }
|
||||
public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
|
||||
public bool NoTls { get; set; }
|
||||
public string? TlsCert { get; set; }
|
||||
public string? TlsKey { get; set; }
|
||||
public bool SameOrigin { get; set; }
|
||||
public List<string>? AllowedOrigins { get; set; }
|
||||
public bool Compression { get; set; }
|
||||
public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2);
|
||||
public TimeSpan? PingInterval { get; set; }
|
||||
public Dictionary<string, string>? Headers { get; set; }
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ using NATS.Server.Monitoring;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
using NATS.Server.Tls;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server;
|
||||
|
||||
@@ -39,6 +40,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
private readonly TlsRateLimiter? _tlsRateLimiter;
|
||||
private readonly SubjectTransform[] _subjectTransforms;
|
||||
private Socket? _listener;
|
||||
private Socket? _wsListener;
|
||||
private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
private MonitorServer? _monitorServer;
|
||||
private ulong _nextClientId;
|
||||
private long _startTimeTicks;
|
||||
@@ -93,11 +96,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
// Signal all internal loops to stop
|
||||
await _quitCts.CancelAsync();
|
||||
|
||||
// Close listener to stop accept loop
|
||||
// Close listeners to stop accept loops
|
||||
_listener?.Close();
|
||||
_wsListener?.Close();
|
||||
|
||||
// Wait for accept loop to exit
|
||||
// Wait for accept loops to exit
|
||||
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
|
||||
await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
|
||||
|
||||
// Close all client connections — flush first, then mark closed
|
||||
var flushTasks = new List<Task>();
|
||||
@@ -138,11 +143,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
|
||||
_logger.LogInformation("Entering lame duck mode, stop accepting new clients");
|
||||
|
||||
// Close listener to stop accepting new connections
|
||||
// Close listeners to stop accepting new connections
|
||||
_listener?.Close();
|
||||
_wsListener?.Close();
|
||||
|
||||
// Wait for accept loop to exit
|
||||
// Wait for accept loops to exit
|
||||
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
|
||||
await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
|
||||
|
||||
var gracePeriod = _options.LameDuckGracePeriod;
|
||||
if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod;
|
||||
@@ -369,8 +376,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
BuildCachedInfo();
|
||||
}
|
||||
|
||||
_listeningStarted.TrySetResult();
|
||||
|
||||
_logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port);
|
||||
|
||||
// Warn about stub features
|
||||
@@ -386,6 +391,31 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
WritePidFile();
|
||||
WritePortsFile();
|
||||
|
||||
if (_options.WebSocket.Port >= 0)
|
||||
{
|
||||
_wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
_wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
|
||||
_wsListener.Bind(new IPEndPoint(
|
||||
_options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host),
|
||||
_options.WebSocket.Port));
|
||||
_wsListener.Listen(128);
|
||||
|
||||
if (_options.WebSocket.Port == 0)
|
||||
{
|
||||
_options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port;
|
||||
}
|
||||
|
||||
_logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}",
|
||||
_options.WebSocket.Host, _options.WebSocket.Port);
|
||||
|
||||
if (_options.WebSocket.NoTls)
|
||||
_logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!");
|
||||
|
||||
_ = RunWebSocketAcceptLoopAsync(linked.Token);
|
||||
}
|
||||
|
||||
_listeningStarted.TrySetResult();
|
||||
|
||||
var tmpDelay = AcceptMinSleep;
|
||||
|
||||
try
|
||||
@@ -531,6 +561,102 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
}
|
||||
}
|
||||
|
||||
private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct)
|
||||
{
|
||||
var tmpDelay = AcceptMinSleep;
|
||||
try
|
||||
{
|
||||
while (!ct.IsCancellationRequested)
|
||||
{
|
||||
Socket socket;
|
||||
try
|
||||
{
|
||||
socket = await _wsListener!.AcceptAsync(ct);
|
||||
tmpDelay = AcceptMinSleep;
|
||||
}
|
||||
catch (OperationCanceledException) { break; }
|
||||
catch (ObjectDisposedException) { break; }
|
||||
catch (SocketException ex)
|
||||
{
|
||||
if (IsShuttingDown || IsLameDuckMode) break;
|
||||
_logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
|
||||
try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; }
|
||||
tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
|
||||
{
|
||||
socket.Dispose();
|
||||
continue;
|
||||
}
|
||||
|
||||
var clientId = Interlocked.Increment(ref _nextClientId);
|
||||
Interlocked.Increment(ref _stats.TotalConnections);
|
||||
Interlocked.Increment(ref _activeClientCount);
|
||||
|
||||
_ = AcceptWebSocketClientAsync(socket, clientId, ct);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
_wsAcceptLoopExited.TrySetResult();
|
||||
}
|
||||
}
|
||||
|
||||
private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
{
|
||||
var networkStream = new NetworkStream(socket, ownsSocket: false);
|
||||
Stream stream = networkStream;
|
||||
|
||||
// TLS negotiation if configured
|
||||
if (_sslOptions != null && !_options.WebSocket.NoTls)
|
||||
{
|
||||
var (tlsStream, _) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
socket, networkStream, _options, _sslOptions, _serverInfo,
|
||||
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
|
||||
stream = tlsStream;
|
||||
}
|
||||
|
||||
// HTTP upgrade handshake
|
||||
var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct);
|
||||
if (!upgradeResult.Success)
|
||||
{
|
||||
_logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
|
||||
socket.Dispose();
|
||||
Interlocked.Decrement(ref _activeClientCount);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create WsConnection wrapper
|
||||
var wsConn = new WsConnection(stream,
|
||||
compress: upgradeResult.Compress,
|
||||
maskRead: upgradeResult.MaskRead,
|
||||
maskWrite: upgradeResult.MaskWrite,
|
||||
browser: upgradeResult.Browser,
|
||||
noCompFrag: upgradeResult.NoCompFrag);
|
||||
|
||||
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
|
||||
var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo,
|
||||
_authService, null, clientLogger, _stats);
|
||||
client.Router = this;
|
||||
client.IsWebSocket = true;
|
||||
client.WsInfo = upgradeResult;
|
||||
_clients[clientId] = client;
|
||||
|
||||
await RunClientAsync(client, ct);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId);
|
||||
try { socket.Shutdown(SocketShutdown.Both); } catch { }
|
||||
socket.Dispose();
|
||||
Interlocked.Decrement(ref _activeClientCount);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
@@ -942,6 +1068,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
_quitCts.Dispose();
|
||||
_tlsRateLimiter?.Dispose();
|
||||
_listener?.Dispose();
|
||||
_wsListener?.Dispose();
|
||||
foreach (var client in _clients.Values)
|
||||
client.Dispose();
|
||||
foreach (var account in _accounts.Values)
|
||||
|
||||
94
src/NATS.Server/WebSocket/WsCompression.cs
Normal file
94
src/NATS.Server/WebSocket/WsCompression.cs
Normal file
@@ -0,0 +1,94 @@
|
||||
using System.IO.Compression;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
|
||||
/// </summary>
|
||||
public static class WsCompression
|
||||
{
|
||||
/// <summary>
|
||||
/// Compresses data using deflate. Removes trailing 4 bytes (sync marker)
|
||||
/// per RFC 7692 Section 7.2.1.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// We call Flush() but intentionally do not Dispose() the DeflateStream before
|
||||
/// reading output, because Dispose writes a final deflate block (0x03 0x00) that
|
||||
/// would be corrupted by the 4-byte tail strip. Flush() alone writes a sync flush
|
||||
/// ending with 0x00 0x00 0xff 0xff, matching Go's flate.Writer.Flush() behavior.
|
||||
/// </remarks>
|
||||
public static byte[] Compress(ReadOnlySpan<byte> data)
|
||||
{
|
||||
var output = new MemoryStream();
|
||||
var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true);
|
||||
try
|
||||
{
|
||||
deflate.Write(data);
|
||||
deflate.Flush();
|
||||
|
||||
var compressed = output.ToArray();
|
||||
|
||||
// Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692
|
||||
if (compressed.Length >= 4)
|
||||
return compressed[..^4];
|
||||
|
||||
return compressed;
|
||||
}
|
||||
finally
|
||||
{
|
||||
deflate.Dispose();
|
||||
output.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Decompresses collected compressed buffers.
|
||||
/// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 403-440.
|
||||
/// The Go code appends compressLastBlock (9 bytes) which includes the sync
|
||||
/// marker plus a final empty stored block to signal end-of-stream to the
|
||||
/// flate reader.
|
||||
/// </summary>
|
||||
public static byte[] Decompress(List<byte[]> compressedBuffers, int maxPayload)
|
||||
{
|
||||
if (maxPayload <= 0)
|
||||
maxPayload = 1024 * 1024; // Default 1MB
|
||||
|
||||
// Concatenate all compressed buffers + trailer.
|
||||
// Per RFC 7692 Section 7.2.2, append the sync flush marker (0x00 0x00 0xff 0xff)
|
||||
// that was stripped during compression. The Go reference appends compressLastBlock
|
||||
// (9 bytes) for Go's flate reader; .NET's DeflateStream only needs the 4-byte trailer.
|
||||
int totalLen = 0;
|
||||
foreach (var buf in compressedBuffers)
|
||||
totalLen += buf.Length;
|
||||
totalLen += WsConstants.DecompressTrailer.Length;
|
||||
|
||||
var combined = new byte[totalLen];
|
||||
int offset = 0;
|
||||
foreach (var buf in compressedBuffers)
|
||||
{
|
||||
buf.CopyTo(combined, offset);
|
||||
offset += buf.Length;
|
||||
}
|
||||
|
||||
WsConstants.DecompressTrailer.CopyTo(combined, offset);
|
||||
|
||||
using var input = new MemoryStream(combined);
|
||||
using var deflate = new DeflateStream(input, CompressionMode.Decompress);
|
||||
using var output = new MemoryStream();
|
||||
|
||||
var readBuf = new byte[4096];
|
||||
int totalRead = 0;
|
||||
int n;
|
||||
while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0)
|
||||
{
|
||||
totalRead += n;
|
||||
if (totalRead > maxPayload)
|
||||
throw new InvalidOperationException("decompressed data exceeds maximum payload size");
|
||||
output.Write(readBuf, 0, n);
|
||||
}
|
||||
|
||||
return output.ToArray();
|
||||
}
|
||||
}
|
||||
202
src/NATS.Server/WebSocket/WsConnection.cs
Normal file
202
src/NATS.Server/WebSocket/WsConnection.cs
Normal file
@@ -0,0 +1,202 @@
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
|
||||
/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged.
|
||||
/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern.
|
||||
/// </summary>
|
||||
public sealed class WsConnection : Stream
|
||||
{
|
||||
private readonly Stream _inner;
|
||||
private readonly bool _compress;
|
||||
private readonly bool _maskRead;
|
||||
private readonly bool _maskWrite;
|
||||
private readonly bool _browser;
|
||||
private readonly bool _noCompFrag;
|
||||
private WsReadInfo _readInfo;
|
||||
// Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed)
|
||||
private readonly Queue<byte[]> _readQueue = new();
|
||||
private int _readOffset;
|
||||
private readonly object _writeLock = new();
|
||||
private readonly List<ControlFrameAction> _pendingControlWrites = [];
|
||||
|
||||
public bool CloseReceived => _readInfo.CloseReceived;
|
||||
public int CloseStatus => _readInfo.CloseStatus;
|
||||
|
||||
public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
|
||||
{
|
||||
_inner = inner;
|
||||
_compress = compress;
|
||||
_maskRead = maskRead;
|
||||
_maskWrite = maskWrite;
|
||||
_browser = browser;
|
||||
_noCompFrag = noCompFrag;
|
||||
_readInfo = new WsReadInfo(expectMask: maskRead);
|
||||
}
|
||||
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
|
||||
{
|
||||
// Drain any buffered decoded payloads first
|
||||
if (_readQueue.Count > 0)
|
||||
return DrainReadQueue(buffer.Span);
|
||||
|
||||
while (true)
|
||||
{
|
||||
// Read raw bytes from inner stream
|
||||
var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
|
||||
int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct);
|
||||
if (bytesRead == 0) return 0;
|
||||
|
||||
// Decode frames
|
||||
var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
|
||||
|
||||
// Collect control frame responses
|
||||
if (_readInfo.PendingControlFrames.Count > 0)
|
||||
{
|
||||
lock (_writeLock)
|
||||
_pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
|
||||
_readInfo.PendingControlFrames.Clear();
|
||||
// Write pending control frames
|
||||
await FlushControlFramesAsync(ct);
|
||||
}
|
||||
|
||||
if (_readInfo.CloseReceived)
|
||||
return 0;
|
||||
|
||||
foreach (var payload in payloads)
|
||||
_readQueue.Enqueue(payload);
|
||||
|
||||
// If no payloads were decoded (e.g. only frame headers were read),
|
||||
// continue reading instead of returning 0 which signals end-of-stream
|
||||
if (_readQueue.Count > 0)
|
||||
return DrainReadQueue(buffer.Span);
|
||||
}
|
||||
}
|
||||
|
||||
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
|
||||
{
|
||||
var data = buffer.Span;
|
||||
|
||||
if (_compress && data.Length > WsConstants.CompressThreshold)
|
||||
{
|
||||
var compressed = WsCompression.Compress(data);
|
||||
await WriteFramedAsync(compressed, compressed: true, ct);
|
||||
}
|
||||
else
|
||||
{
|
||||
await WriteFramedAsync(data.ToArray(), compressed: false, ct);
|
||||
}
|
||||
}
|
||||
|
||||
private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
|
||||
{
|
||||
if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
|
||||
{
|
||||
// Fragment for browsers
|
||||
int offset = 0;
|
||||
bool first = true;
|
||||
while (offset < payload.Length)
|
||||
{
|
||||
int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
|
||||
bool final = offset + chunkLen >= payload.Length;
|
||||
var fh = new byte[WsConstants.MaxFrameHeaderSize];
|
||||
var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
|
||||
first: first, final: final, compressed: first && compressed,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
|
||||
|
||||
var chunk = payload.AsSpan(offset, chunkLen).ToArray();
|
||||
if (_maskWrite && key != null)
|
||||
WsFrameWriter.MaskBuf(key, chunk);
|
||||
|
||||
await _inner.WriteAsync(fh.AsMemory(0, n), ct);
|
||||
await _inner.WriteAsync(chunk.AsMemory(), ct);
|
||||
offset += chunkLen;
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
|
||||
if (_maskWrite && key != null)
|
||||
WsFrameWriter.MaskBuf(key, payload);
|
||||
await _inner.WriteAsync(header.AsMemory(), ct);
|
||||
await _inner.WriteAsync(payload.AsMemory(), ct);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task FlushControlFramesAsync(CancellationToken ct)
|
||||
{
|
||||
List<ControlFrameAction> toWrite;
|
||||
lock (_writeLock)
|
||||
{
|
||||
if (_pendingControlWrites.Count == 0) return;
|
||||
toWrite = [.. _pendingControlWrites];
|
||||
_pendingControlWrites.Clear();
|
||||
}
|
||||
|
||||
foreach (var action in toWrite)
|
||||
{
|
||||
var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
|
||||
await _inner.WriteAsync(frame, ct);
|
||||
}
|
||||
await _inner.FlushAsync(ct);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a WebSocket close frame.
|
||||
/// </summary>
|
||||
public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
|
||||
{
|
||||
var status = WsFrameWriter.MapCloseStatus(reason);
|
||||
var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
|
||||
await _inner.WriteAsync(frame, ct);
|
||||
await _inner.FlushAsync(ct);
|
||||
}
|
||||
|
||||
private int DrainReadQueue(Span<byte> buffer)
|
||||
{
|
||||
int written = 0;
|
||||
while (_readQueue.Count > 0 && written < buffer.Length)
|
||||
{
|
||||
var current = _readQueue.Peek();
|
||||
int available = current.Length - _readOffset;
|
||||
int toCopy = Math.Min(available, buffer.Length - written);
|
||||
current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
|
||||
written += toCopy;
|
||||
_readOffset += toCopy;
|
||||
if (_readOffset >= current.Length)
|
||||
{
|
||||
_readQueue.Dequeue();
|
||||
_readOffset = 0;
|
||||
}
|
||||
}
|
||||
return written;
|
||||
}
|
||||
|
||||
// Stream abstract members
|
||||
public override bool CanRead => true;
|
||||
public override bool CanWrite => true;
|
||||
public override bool CanSeek => false;
|
||||
public override long Length => throw new NotSupportedException();
|
||||
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
|
||||
public override void Flush() => _inner.Flush();
|
||||
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
|
||||
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
|
||||
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
|
||||
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
|
||||
public override void SetLength(long value) => throw new NotSupportedException();
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (disposing)
|
||||
_inner.Dispose();
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
|
||||
public override async ValueTask DisposeAsync()
|
||||
{
|
||||
await _inner.DisposeAsync();
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
}
|
||||
72
src/NATS.Server/WebSocket/WsConstants.cs
Normal file
72
src/NATS.Server/WebSocket/WsConstants.cs
Normal file
@@ -0,0 +1,72 @@
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// WebSocket protocol constants (RFC 6455).
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 41-106.
|
||||
/// </summary>
|
||||
public static class WsConstants
|
||||
{
|
||||
// Opcodes (RFC 6455 Section 5.2)
|
||||
public const int TextMessage = 1;
|
||||
public const int BinaryMessage = 2;
|
||||
public const int CloseMessage = 8;
|
||||
public const int PingMessage = 9;
|
||||
public const int PongMessage = 10;
|
||||
public const int ContinuationFrame = 0;
|
||||
|
||||
// Frame header bits
|
||||
public const byte FinalBit = 0x80; // 1 << 7
|
||||
public const byte Rsv1Bit = 0x40; // 1 << 6 (compression, RFC 7692)
|
||||
public const byte Rsv2Bit = 0x20; // 1 << 5
|
||||
public const byte Rsv3Bit = 0x10; // 1 << 4
|
||||
public const byte MaskBit = 0x80; // 1 << 7 (in second byte)
|
||||
|
||||
// Frame size limits
|
||||
public const int MaxFrameHeaderSize = 14;
|
||||
public const int MaxControlPayloadSize = 125;
|
||||
public const int FrameSizeForBrowsers = 4096;
|
||||
public const int CompressThreshold = 64;
|
||||
public const int CloseStatusSize = 2;
|
||||
|
||||
// Close status codes (RFC 6455 Section 11.7)
|
||||
public const int CloseStatusNormalClosure = 1000;
|
||||
public const int CloseStatusGoingAway = 1001;
|
||||
public const int CloseStatusProtocolError = 1002;
|
||||
public const int CloseStatusUnsupportedData = 1003;
|
||||
public const int CloseStatusNoStatusReceived = 1005;
|
||||
public const int CloseStatusInvalidPayloadData = 1007;
|
||||
public const int CloseStatusPolicyViolation = 1008;
|
||||
public const int CloseStatusMessageTooBig = 1009;
|
||||
public const int CloseStatusInternalSrvError = 1011;
|
||||
public const int CloseStatusTlsHandshake = 1015;
|
||||
|
||||
// Compression constants (RFC 7692)
|
||||
public const string PmcExtension = "permessage-deflate";
|
||||
public const string PmcSrvNoCtx = "server_no_context_takeover";
|
||||
public const string PmcCliNoCtx = "client_no_context_takeover";
|
||||
public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}";
|
||||
public static readonly string PmcFullResponse = $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n";
|
||||
|
||||
// Header names
|
||||
public const string NoMaskingHeader = "Nats-No-Masking";
|
||||
public const string NoMaskingValue = "true";
|
||||
public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n";
|
||||
public const string XForwardedForHeader = "X-Forwarded-For";
|
||||
|
||||
// Path routing
|
||||
public const string ClientPath = "/";
|
||||
public const string LeafNodePath = "/leafnode";
|
||||
public const string MqttPath = "/mqtt";
|
||||
|
||||
// Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2)
|
||||
public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];
|
||||
|
||||
public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;
|
||||
}
|
||||
|
||||
public enum WsClientKind
|
||||
{
|
||||
Client,
|
||||
Leaf,
|
||||
Mqtt,
|
||||
}
|
||||
171
src/NATS.Server/WebSocket/WsFrameWriter.cs
Normal file
171
src/NATS.Server/WebSocket/WsFrameWriter.cs
Normal file
@@ -0,0 +1,171 @@
|
||||
using System.Buffers.Binary;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// WebSocket frame construction, masking, and control message creation.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 543-726.
|
||||
/// </summary>
|
||||
public static class WsFrameWriter
|
||||
{
|
||||
/// <summary>
|
||||
/// Creates a complete frame header for a single-frame message (first=true, final=true).
|
||||
/// Returns (header bytes, mask key or null).
|
||||
/// </summary>
|
||||
public static (byte[] header, byte[]? key) CreateFrameHeader(
|
||||
bool useMasking, bool compressed, int opcode, int payloadLength)
|
||||
{
|
||||
var fh = new byte[WsConstants.MaxFrameHeaderSize];
|
||||
var (n, key) = FillFrameHeader(fh, useMasking,
|
||||
first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength);
|
||||
return (fh[..n], key);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Fills a pre-allocated frame header buffer.
|
||||
/// Returns (bytes written, mask key or null).
|
||||
/// </summary>
|
||||
public static (int written, byte[]? key) FillFrameHeader(
|
||||
Span<byte> fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength)
|
||||
{
|
||||
byte b0 = first ? (byte)opcode : (byte)0;
|
||||
if (final) b0 |= WsConstants.FinalBit;
|
||||
if (compressed) b0 |= WsConstants.Rsv1Bit;
|
||||
|
||||
byte b1 = 0;
|
||||
if (useMasking) b1 |= WsConstants.MaskBit;
|
||||
|
||||
int n;
|
||||
switch (payloadLength)
|
||||
{
|
||||
case <= 125:
|
||||
n = 2;
|
||||
fh[0] = b0;
|
||||
fh[1] = (byte)(b1 | (byte)payloadLength);
|
||||
break;
|
||||
case < 65536:
|
||||
n = 4;
|
||||
fh[0] = b0;
|
||||
fh[1] = (byte)(b1 | 126);
|
||||
BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength);
|
||||
break;
|
||||
default:
|
||||
n = 10;
|
||||
fh[0] = b0;
|
||||
fh[1] = (byte)(b1 | 127);
|
||||
BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength);
|
||||
break;
|
||||
}
|
||||
|
||||
byte[]? key = null;
|
||||
if (useMasking)
|
||||
{
|
||||
key = new byte[4];
|
||||
RandomNumberGenerator.Fill(key);
|
||||
key.CopyTo(fh[n..]);
|
||||
n += 4;
|
||||
}
|
||||
|
||||
return (n, key);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// XOR masks a buffer with a 4-byte key. Applies in-place.
|
||||
/// </summary>
|
||||
public static void MaskBuf(ReadOnlySpan<byte> key, Span<byte> buf)
|
||||
{
|
||||
for (int i = 0; i < buf.Length; i++)
|
||||
buf[i] ^= key[i & 3];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// XOR masks multiple contiguous buffers as if they were one.
|
||||
/// </summary>
|
||||
public static void MaskBufs(ReadOnlySpan<byte> key, List<byte[]> bufs)
|
||||
{
|
||||
int pos = 0;
|
||||
foreach (var buf in bufs)
|
||||
{
|
||||
for (int j = 0; j < buf.Length; j++)
|
||||
{
|
||||
buf[j] ^= key[pos & 3];
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a close message payload: 2-byte status code + optional UTF-8 body.
|
||||
/// Body truncated to fit MaxControlPayloadSize with "..." suffix.
|
||||
/// </summary>
|
||||
public static byte[] CreateCloseMessage(int status, string body)
|
||||
{
|
||||
var bodyBytes = Encoding.UTF8.GetBytes(body);
|
||||
int maxBody = WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize;
|
||||
|
||||
if (bodyBytes.Length > maxBody)
|
||||
{
|
||||
var suffix = "..."u8;
|
||||
int truncLen = maxBody - suffix.Length;
|
||||
// Find a valid UTF-8 boundary by walking back from truncation point
|
||||
while (truncLen > 0 && (bodyBytes[truncLen] & 0xC0) == 0x80)
|
||||
truncLen--;
|
||||
var buf = new byte[WsConstants.CloseStatusSize + truncLen + suffix.Length];
|
||||
BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status);
|
||||
bodyBytes.AsSpan(0, truncLen).CopyTo(buf.AsSpan(WsConstants.CloseStatusSize));
|
||||
suffix.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize + truncLen));
|
||||
return buf;
|
||||
}
|
||||
|
||||
var result = new byte[WsConstants.CloseStatusSize + bodyBytes.Length];
|
||||
BinaryPrimitives.WriteUInt16BigEndian(result, (ushort)status);
|
||||
bodyBytes.CopyTo(result.AsSpan(WsConstants.CloseStatusSize));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Builds a complete control frame (header + payload, optional masking).
|
||||
/// </summary>
|
||||
public static byte[] BuildControlFrame(int opcode, ReadOnlySpan<byte> payload, bool useMasking)
|
||||
{
|
||||
int headerSize = 2 + (useMasking ? 4 : 0);
|
||||
var frame = new byte[headerSize + payload.Length];
|
||||
var span = frame.AsSpan();
|
||||
var (n, key) = FillFrameHeader(span, useMasking,
|
||||
first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length);
|
||||
if (payload.Length > 0)
|
||||
{
|
||||
payload.CopyTo(span[n..]);
|
||||
if (useMasking && key != null)
|
||||
MaskBuf(key, span[n..]);
|
||||
}
|
||||
|
||||
return frame;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Maps a ClientClosedReason to a WebSocket close status code.
|
||||
/// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694.
|
||||
/// </summary>
|
||||
public static int MapCloseStatus(ClientClosedReason reason) => reason switch
|
||||
{
|
||||
ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure,
|
||||
ClientClosedReason.AuthenticationTimeout or
|
||||
ClientClosedReason.AuthenticationViolation or
|
||||
ClientClosedReason.SlowConsumerPendingBytes or
|
||||
ClientClosedReason.SlowConsumerWriteDeadline or
|
||||
ClientClosedReason.MaxSubscriptionsExceeded or
|
||||
ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation,
|
||||
ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake,
|
||||
ClientClosedReason.ParseError or
|
||||
ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError,
|
||||
ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig,
|
||||
ClientClosedReason.WriteError or
|
||||
ClientClosedReason.ReadError or
|
||||
ClientClosedReason.StaleConnection or
|
||||
ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway,
|
||||
_ => WsConstants.CloseStatusInternalSrvError,
|
||||
};
|
||||
}
|
||||
81
src/NATS.Server/WebSocket/WsOriginChecker.cs
Normal file
81
src/NATS.Server/WebSocket/WsOriginChecker.cs
Normal file
@@ -0,0 +1,81 @@
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Validates WebSocket Origin headers per RFC 6455 Section 10.2.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 933-1000.
|
||||
/// </summary>
|
||||
public sealed class WsOriginChecker
|
||||
{
|
||||
private readonly bool _sameOrigin;
|
||||
private readonly Dictionary<string, AllowedOrigin>? _allowedOrigins;
|
||||
|
||||
public WsOriginChecker(bool sameOrigin, List<string>? allowedOrigins)
|
||||
{
|
||||
_sameOrigin = sameOrigin;
|
||||
if (allowedOrigins is { Count: > 0 })
|
||||
{
|
||||
_allowedOrigins = new Dictionary<string, AllowedOrigin>(StringComparer.OrdinalIgnoreCase);
|
||||
foreach (var ao in allowedOrigins)
|
||||
{
|
||||
if (Uri.TryCreate(ao, UriKind.Absolute, out var uri))
|
||||
{
|
||||
var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port);
|
||||
_allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns null if origin is allowed, or an error message if rejected.
|
||||
/// </summary>
|
||||
public string? CheckOrigin(string? origin, string requestHost, bool isTls)
|
||||
{
|
||||
if (!_sameOrigin && _allowedOrigins == null)
|
||||
return null;
|
||||
|
||||
if (string.IsNullOrEmpty(origin))
|
||||
return null;
|
||||
|
||||
if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
|
||||
return $"invalid origin: {origin}";
|
||||
|
||||
var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port);
|
||||
|
||||
if (_sameOrigin)
|
||||
{
|
||||
var (rh, rp) = ParseHostPort(requestHost, isTls);
|
||||
if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp)
|
||||
return "not same origin";
|
||||
}
|
||||
|
||||
if (_allowedOrigins != null)
|
||||
{
|
||||
if (!_allowedOrigins.TryGetValue(oh, out var allowed) ||
|
||||
!string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) ||
|
||||
op != allowed.Port)
|
||||
{
|
||||
return "not in the allowed list";
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static (string host, int port) GetHostAndPort(bool tls, string host, int port)
|
||||
{
|
||||
if (port <= 0)
|
||||
port = tls ? 443 : 80;
|
||||
return (host.ToLowerInvariant(), port);
|
||||
}
|
||||
|
||||
private static (string host, int port) ParseHostPort(string hostPort, bool isTls)
|
||||
{
|
||||
var colonIdx = hostPort.LastIndexOf(':');
|
||||
if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port))
|
||||
return (hostPort[..colonIdx].ToLowerInvariant(), port);
|
||||
return (hostPort.ToLowerInvariant(), isTls ? 443 : 80);
|
||||
}
|
||||
|
||||
private readonly record struct AllowedOrigin(string Scheme, int Port);
|
||||
}
|
||||
322
src/NATS.Server/WebSocket/WsReadInfo.cs
Normal file
322
src/NATS.Server/WebSocket/WsReadInfo.cs
Normal file
@@ -0,0 +1,322 @@
|
||||
using System.Buffers.Binary;
|
||||
using System.Text;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Per-connection WebSocket frame reading state machine.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
|
||||
/// </summary>
|
||||
public class WsReadInfo
|
||||
{
|
||||
public int Remaining;
|
||||
public bool FrameStart;
|
||||
public bool FirstFrame;
|
||||
public bool FrameCompressed;
|
||||
public bool ExpectMask;
|
||||
public byte MaskKeyPos;
|
||||
public byte[] MaskKey;
|
||||
public List<byte[]>? CompressedBuffers;
|
||||
public int CompressedOffset;
|
||||
|
||||
// Control frame outputs
|
||||
public List<ControlFrameAction> PendingControlFrames;
|
||||
public bool CloseReceived;
|
||||
public int CloseStatus;
|
||||
public string? CloseBody;
|
||||
|
||||
public WsReadInfo(bool expectMask)
|
||||
{
|
||||
Remaining = 0;
|
||||
FrameStart = true;
|
||||
FirstFrame = true;
|
||||
FrameCompressed = false;
|
||||
ExpectMask = expectMask;
|
||||
MaskKeyPos = 0;
|
||||
MaskKey = new byte[4];
|
||||
CompressedBuffers = null;
|
||||
CompressedOffset = 0;
|
||||
PendingControlFrames = [];
|
||||
CloseReceived = false;
|
||||
CloseStatus = 0;
|
||||
CloseBody = null;
|
||||
}
|
||||
|
||||
public void SetMaskKey(ReadOnlySpan<byte> key)
|
||||
{
|
||||
key[..4].CopyTo(MaskKey);
|
||||
MaskKeyPos = 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Unmask buffer in-place using current mask key and position.
|
||||
/// Optimized for 8-byte chunks when buffer is large enough.
|
||||
/// Ported from websocket.go lines 509-536.
|
||||
/// </summary>
|
||||
public void Unmask(Span<byte> buf)
|
||||
{
|
||||
int p = MaskKeyPos;
|
||||
if (buf.Length < 16)
|
||||
{
|
||||
for (int i = 0; i < buf.Length; i++)
|
||||
{
|
||||
buf[i] ^= MaskKey[p & 3];
|
||||
p++;
|
||||
}
|
||||
MaskKeyPos = (byte)(p & 3);
|
||||
return;
|
||||
}
|
||||
|
||||
// Build 8-byte key for bulk XOR
|
||||
Span<byte> k = stackalloc byte[8];
|
||||
for (int i = 0; i < 8; i++)
|
||||
k[i] = MaskKey[(p + i) & 3];
|
||||
ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);
|
||||
|
||||
int n = (buf.Length / 8) * 8;
|
||||
for (int i = 0; i < n; i += 8)
|
||||
{
|
||||
ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
|
||||
tmp ^= km;
|
||||
BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
|
||||
}
|
||||
|
||||
// Handle remaining bytes
|
||||
p += n;
|
||||
var tail = buf[n..];
|
||||
for (int i = 0; i < tail.Length; i++)
|
||||
{
|
||||
tail[i] ^= MaskKey[p & 3];
|
||||
p++;
|
||||
}
|
||||
MaskKeyPos = (byte)(p & 3);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Read and decode WebSocket frames from a buffer.
|
||||
/// Returns list of decoded payload byte arrays.
|
||||
/// Ported from websocket.go lines 208-351.
|
||||
/// </summary>
|
||||
public static List<byte[]> ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload)
|
||||
{
|
||||
var bufs = new List<byte[]>();
|
||||
var buf = new byte[available];
|
||||
int bytesRead = 0;
|
||||
|
||||
// Fill the buffer from the stream
|
||||
while (bytesRead < available)
|
||||
{
|
||||
int n = stream.Read(buf, bytesRead, available - bytesRead);
|
||||
if (n == 0) break;
|
||||
bytesRead += n;
|
||||
}
|
||||
|
||||
int pos = 0;
|
||||
int max = bytesRead;
|
||||
|
||||
while (pos < max)
|
||||
{
|
||||
if (r.FrameStart)
|
||||
{
|
||||
if (pos >= max) break;
|
||||
byte b0 = buf[pos];
|
||||
int frameType = b0 & 0x0F;
|
||||
bool final = (b0 & WsConstants.FinalBit) != 0;
|
||||
bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
|
||||
pos++;
|
||||
|
||||
// Read second byte
|
||||
var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
|
||||
pos = newPos;
|
||||
byte b1 = b1Buf[0];
|
||||
|
||||
// Check mask bit
|
||||
if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
|
||||
throw new InvalidOperationException("mask bit missing");
|
||||
|
||||
r.Remaining = b1 & 0x7F;
|
||||
|
||||
// Validate frame types
|
||||
if (WsConstants.IsControlFrame(frameType))
|
||||
{
|
||||
if (r.Remaining > WsConstants.MaxControlPayloadSize)
|
||||
throw new InvalidOperationException("control frame length too large");
|
||||
if (!final)
|
||||
throw new InvalidOperationException("control frame does not have final bit set");
|
||||
}
|
||||
else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
|
||||
{
|
||||
if (!r.FirstFrame)
|
||||
throw new InvalidOperationException("new message before previous finished");
|
||||
r.FirstFrame = final;
|
||||
r.FrameCompressed = compressed;
|
||||
}
|
||||
else if (frameType == WsConstants.ContinuationFrame)
|
||||
{
|
||||
if (r.FirstFrame || compressed)
|
||||
throw new InvalidOperationException("invalid continuation frame");
|
||||
r.FirstFrame = final;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"unknown opcode {frameType}");
|
||||
}
|
||||
|
||||
// Extended payload length
|
||||
switch (r.Remaining)
|
||||
{
|
||||
case 126:
|
||||
{
|
||||
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
|
||||
pos = p2;
|
||||
r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
|
||||
break;
|
||||
}
|
||||
case 127:
|
||||
{
|
||||
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
|
||||
pos = p2;
|
||||
var len64 = BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
|
||||
if (len64 > (ulong)maxPayload)
|
||||
throw new InvalidOperationException($"frame payload length {len64} exceeds max payload {maxPayload}");
|
||||
r.Remaining = (int)len64;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Read mask key (mask bit already validated at line 134)
|
||||
if (r.ExpectMask)
|
||||
{
|
||||
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
|
||||
pos = p2;
|
||||
keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
|
||||
r.MaskKeyPos = 0;
|
||||
}
|
||||
|
||||
// Handle control frames
|
||||
if (WsConstants.IsControlFrame(frameType))
|
||||
{
|
||||
pos = HandleControlFrame(r, frameType, stream, buf, pos, max);
|
||||
continue;
|
||||
}
|
||||
|
||||
r.FrameStart = false;
|
||||
}
|
||||
|
||||
if (pos < max)
|
||||
{
|
||||
int n = r.Remaining;
|
||||
if (pos + n > max) n = max - pos;
|
||||
|
||||
var payloadSlice = buf.AsSpan(pos, n).ToArray();
|
||||
pos += n;
|
||||
r.Remaining -= n;
|
||||
|
||||
if (r.ExpectMask)
|
||||
r.Unmask(payloadSlice);
|
||||
|
||||
bool addToBufs = true;
|
||||
if (r.FrameCompressed)
|
||||
{
|
||||
addToBufs = false;
|
||||
r.CompressedBuffers ??= [];
|
||||
r.CompressedBuffers.Add(payloadSlice);
|
||||
|
||||
if (r.FirstFrame && r.Remaining == 0)
|
||||
{
|
||||
var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
|
||||
r.CompressedBuffers = null;
|
||||
r.FrameCompressed = false;
|
||||
addToBufs = true;
|
||||
payloadSlice = decompressed;
|
||||
}
|
||||
}
|
||||
|
||||
if (addToBufs && payloadSlice.Length > 0)
|
||||
bufs.Add(payloadSlice);
|
||||
|
||||
if (r.Remaining == 0)
|
||||
r.FrameStart = true;
|
||||
}
|
||||
}
|
||||
|
||||
return bufs;
|
||||
}
|
||||
|
||||
private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
|
||||
{
|
||||
byte[]? payload = null;
|
||||
if (r.Remaining > 0)
|
||||
{
|
||||
var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
|
||||
pos = newPos;
|
||||
payload = payloadBuf;
|
||||
if (r.ExpectMask)
|
||||
r.Unmask(payload);
|
||||
r.Remaining = 0;
|
||||
}
|
||||
|
||||
switch (frameType)
|
||||
{
|
||||
case WsConstants.CloseMessage:
|
||||
r.CloseReceived = true;
|
||||
r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
|
||||
if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
|
||||
{
|
||||
r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
|
||||
if (payload.Length > WsConstants.CloseStatusSize)
|
||||
r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
|
||||
}
|
||||
// Per RFC 6455 Section 5.5.1, always send a close response
|
||||
if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
|
||||
{
|
||||
var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
|
||||
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Empty close frame — respond with empty close
|
||||
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, []));
|
||||
}
|
||||
break;
|
||||
|
||||
case WsConstants.PingMessage:
|
||||
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
|
||||
break;
|
||||
|
||||
case WsConstants.PongMessage:
|
||||
// Nothing to do
|
||||
break;
|
||||
}
|
||||
|
||||
return pos;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets needed bytes from buffer or reads from stream.
|
||||
/// Ported from websocket.go lines 178-193.
|
||||
/// </summary>
|
||||
private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
|
||||
{
|
||||
int avail = max - pos;
|
||||
if (avail >= needed)
|
||||
return (buf[pos..(pos + needed)], pos + needed);
|
||||
|
||||
var b = new byte[needed];
|
||||
int start = 0;
|
||||
if (avail > 0)
|
||||
{
|
||||
Buffer.BlockCopy(buf, pos, b, 0, avail);
|
||||
start = avail;
|
||||
}
|
||||
while (start < needed)
|
||||
{
|
||||
int n = stream.Read(b, start, needed - start);
|
||||
if (n == 0) throw new IOException("unexpected end of stream");
|
||||
start += n;
|
||||
}
|
||||
return (b, pos + avail);
|
||||
}
|
||||
}
|
||||
|
||||
public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);
|
||||
268
src/NATS.Server/WebSocket/WsUpgrade.cs
Normal file
268
src/NATS.Server/WebSocket/WsUpgrade.cs
Normal file
@@ -0,0 +1,268 @@
|
||||
using System.Net;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// WebSocket HTTP upgrade handshake handler.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
|
||||
/// </summary>
|
||||
public static class WsUpgrade
|
||||
{
|
||||
public static async Task<WsUpgradeResult> TryUpgradeAsync(
|
||||
Stream inputStream, Stream outputStream, WebSocketOptions options,
|
||||
CancellationToken ct = default)
|
||||
{
|
||||
try
|
||||
{
|
||||
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
cts.CancelAfter(options.HandshakeTimeout);
|
||||
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
|
||||
|
||||
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
|
||||
return await FailAsync(outputStream, 405, "request method must be GET");
|
||||
|
||||
if (!headers.ContainsKey("Host"))
|
||||
return await FailAsync(outputStream, 400, "'Host' missing in request");
|
||||
|
||||
if (!HeaderContains(headers, "Upgrade", "websocket"))
|
||||
return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");
|
||||
|
||||
if (!HeaderContains(headers, "Connection", "Upgrade"))
|
||||
return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");
|
||||
|
||||
if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
|
||||
return await FailAsync(outputStream, 400, "key missing");
|
||||
|
||||
if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
|
||||
return await FailAsync(outputStream, 400, "invalid version");
|
||||
|
||||
var kind = path switch
|
||||
{
|
||||
_ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
|
||||
_ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
|
||||
_ => WsClientKind.Client,
|
||||
};
|
||||
|
||||
// Origin checking
|
||||
if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
|
||||
{
|
||||
var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
|
||||
headers.TryGetValue("Origin", out var origin);
|
||||
if (string.IsNullOrEmpty(origin))
|
||||
headers.TryGetValue("Sec-WebSocket-Origin", out origin);
|
||||
var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
|
||||
if (originErr != null)
|
||||
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
|
||||
}
|
||||
|
||||
// Compression negotiation
|
||||
bool compress = options.Compression;
|
||||
if (compress)
|
||||
{
|
||||
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
|
||||
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
// No-masking support (leaf nodes only — browser clients must always mask)
|
||||
bool noMasking = kind == WsClientKind.Leaf &&
|
||||
headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
|
||||
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
// Browser detection
|
||||
bool browser = false;
|
||||
bool noCompFrag = false;
|
||||
if (kind is WsClientKind.Client or WsClientKind.Mqtt &&
|
||||
headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
|
||||
{
|
||||
browser = true;
|
||||
// Disable fragmentation of compressed frames for Safari browsers.
|
||||
// Safari has both "Version/" and "Safari/" in the user agent string,
|
||||
// while Chrome on macOS has "Safari/" but not "Version/".
|
||||
noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
|
||||
}
|
||||
|
||||
// Cookie extraction
|
||||
string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
|
||||
if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
|
||||
headers.TryGetValue("Cookie", out var cookieHeader))
|
||||
{
|
||||
var cookies = ParseCookies(cookieHeader);
|
||||
if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
|
||||
if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
|
||||
if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
|
||||
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
|
||||
}
|
||||
|
||||
// X-Forwarded-For client IP extraction
|
||||
string? clientIp = null;
|
||||
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
|
||||
{
|
||||
var ip = xff.Split(',')[0].Trim();
|
||||
if (IPAddress.TryParse(ip, out _))
|
||||
clientIp = ip;
|
||||
}
|
||||
|
||||
// Build the 101 Switching Protocols response
|
||||
var response = new StringBuilder();
|
||||
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
|
||||
response.Append(ComputeAcceptKey(key));
|
||||
response.Append("\r\n");
|
||||
if (compress)
|
||||
response.Append(WsConstants.PmcFullResponse);
|
||||
if (noMasking)
|
||||
response.Append(WsConstants.NoMaskingFullResponse);
|
||||
if (options.Headers != null)
|
||||
{
|
||||
foreach (var (k, v) in options.Headers)
|
||||
{
|
||||
response.Append(k);
|
||||
response.Append(": ");
|
||||
response.Append(v);
|
||||
response.Append("\r\n");
|
||||
}
|
||||
}
|
||||
|
||||
response.Append("\r\n");
|
||||
|
||||
var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
|
||||
await outputStream.WriteAsync(responseBytes);
|
||||
await outputStream.FlushAsync();
|
||||
|
||||
return new WsUpgradeResult(
|
||||
Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
|
||||
MaskRead: !noMasking, MaskWrite: false,
|
||||
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
|
||||
CookiePassword: cookiePassword, CookieToken: cookieToken,
|
||||
ClientIp: clientIp, Kind: kind);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
return WsUpgradeResult.Failed;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
|
||||
/// </summary>
|
||||
public static string ComputeAcceptKey(string clientKey)
|
||||
{
|
||||
var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
var hash = SHA1.HashData(combined);
|
||||
return Convert.ToBase64String(hash);
|
||||
}
|
||||
|
||||
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
|
||||
{
|
||||
var statusText = statusCode switch
|
||||
{
|
||||
400 => "Bad Request",
|
||||
403 => "Forbidden",
|
||||
405 => "Method Not Allowed",
|
||||
_ => "Internal Server Error",
|
||||
};
|
||||
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
|
||||
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
|
||||
await output.FlushAsync();
|
||||
return WsUpgradeResult.Failed;
|
||||
}
|
||||
|
||||
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
|
||||
Stream stream, CancellationToken ct)
|
||||
{
|
||||
var headerBytes = new List<byte>(4096);
|
||||
var buf = new byte[512];
|
||||
while (true)
|
||||
{
|
||||
int n = await stream.ReadAsync(buf, ct);
|
||||
if (n == 0) throw new IOException("connection closed during handshake");
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
headerBytes.Add(buf[i]);
|
||||
if (headerBytes.Count >= 4 &&
|
||||
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
|
||||
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
|
||||
goto done;
|
||||
if (headerBytes.Count > 8192)
|
||||
throw new InvalidOperationException("HTTP header too large");
|
||||
}
|
||||
}
|
||||
done:;
|
||||
|
||||
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
|
||||
var lines = text.Split("\r\n", StringSplitOptions.None);
|
||||
if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");
|
||||
|
||||
var parts = lines[0].Split(' ');
|
||||
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
|
||||
var method = parts[0];
|
||||
var path = parts[1];
|
||||
|
||||
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
for (int i = 1; i < lines.Length; i++)
|
||||
{
|
||||
var line = lines[i];
|
||||
if (string.IsNullOrEmpty(line)) break;
|
||||
var colonIdx = line.IndexOf(':');
|
||||
if (colonIdx > 0)
|
||||
{
|
||||
var name = line[..colonIdx].Trim();
|
||||
var value = line[(colonIdx + 1)..].Trim();
|
||||
headers[name] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return (method, path, headers);
|
||||
}
|
||||
|
||||
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
|
||||
{
|
||||
if (!headers.TryGetValue(name, out var headerValue))
|
||||
return false;
|
||||
foreach (var token in headerValue.Split(','))
|
||||
{
|
||||
if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private static Dictionary<string, string> ParseCookies(string cookieHeader)
|
||||
{
|
||||
var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
|
||||
foreach (var pair in cookieHeader.Split(';'))
|
||||
{
|
||||
var trimmed = pair.Trim();
|
||||
var eqIdx = trimmed.IndexOf('=');
|
||||
if (eqIdx > 0)
|
||||
cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
|
||||
}
|
||||
|
||||
return cookies;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of a WebSocket upgrade handshake attempt.
|
||||
/// </summary>
|
||||
public readonly record struct WsUpgradeResult(
|
||||
bool Success,
|
||||
bool Compress,
|
||||
bool Browser,
|
||||
bool NoCompFrag,
|
||||
bool MaskRead,
|
||||
bool MaskWrite,
|
||||
string? CookieJwt,
|
||||
string? CookieUsername,
|
||||
string? CookiePassword,
|
||||
string? CookieToken,
|
||||
string? ClientIp,
|
||||
WsClientKind Kind)
|
||||
{
|
||||
public static readonly WsUpgradeResult Failed = new(
|
||||
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
||||
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
||||
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
||||
}
|
||||
26
tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs
Normal file
26
tests/NATS.Server.Tests/WebSocket/WebSocketOptionsTests.cs
Normal file
@@ -0,0 +1,26 @@
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WebSocketOptionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void DefaultOptions_PortIsNegativeOne_Disabled()
|
||||
{
|
||||
var opts = new WebSocketOptions();
|
||||
opts.Port.ShouldBe(-1);
|
||||
opts.Host.ShouldBe("0.0.0.0");
|
||||
opts.Compression.ShouldBeFalse();
|
||||
opts.NoTls.ShouldBeFalse();
|
||||
opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
|
||||
opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NatsOptions_HasWebSocketProperty()
|
||||
{
|
||||
var opts = new NatsOptions();
|
||||
opts.WebSocket.ShouldNotBeNull();
|
||||
opts.WebSocket.Port.ShouldBe(-1);
|
||||
}
|
||||
}
|
||||
58
tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs
Normal file
58
tests/NATS.Server.Tests/WebSocket/WsCompressionTests.cs
Normal file
@@ -0,0 +1,58 @@
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsCompressionTests
|
||||
{
|
||||
[Fact]
|
||||
public void CompressDecompress_RoundTrip()
|
||||
{
|
||||
var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
|
||||
var compressed = WsCompression.Compress(original);
|
||||
compressed.ShouldNotBeNull();
|
||||
compressed.Length.ShouldBeGreaterThan(0);
|
||||
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
|
||||
decompressed.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Decompress_ExceedsMaxPayload_Throws()
|
||||
{
|
||||
var original = new byte[1000];
|
||||
Random.Shared.NextBytes(original);
|
||||
var compressed = WsCompression.Compress(original);
|
||||
|
||||
Should.Throw<InvalidOperationException>(() =>
|
||||
WsCompression.Decompress([compressed], maxPayload: 100));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Compress_RemovesTrailing4Bytes()
|
||||
{
|
||||
var data = new byte[200];
|
||||
Random.Shared.NextBytes(data);
|
||||
var compressed = WsCompression.Compress(data);
|
||||
|
||||
// The compressed data should be valid for decompression when we add the trailer back
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
|
||||
decompressed.ShouldBe(data);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Decompress_MultipleBuffers()
|
||||
{
|
||||
var original = new byte[500];
|
||||
Random.Shared.NextBytes(original);
|
||||
var compressed = WsCompression.Compress(original);
|
||||
|
||||
// Split compressed data into multiple chunks
|
||||
int mid = compressed.Length / 2;
|
||||
var chunk1 = compressed[..mid];
|
||||
var chunk2 = compressed[mid..];
|
||||
|
||||
var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
|
||||
decompressed.ShouldBe(original);
|
||||
}
|
||||
}
|
||||
124
tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
Normal file
124
tests/NATS.Server.Tests/WebSocket/WsConnectionTests.cs
Normal file
@@ -0,0 +1,124 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsConnectionTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
|
||||
{
|
||||
var payload = "SUB test 1\r\n"u8.ToArray();
|
||||
var frame = BuildUnmaskedFrame(payload);
|
||||
var inner = new MemoryStream(frame);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
|
||||
n.ShouldBe(payload.Length);
|
||||
buf[..n].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_FramesPayload()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// First 2 bytes should be WS frame header
|
||||
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
int len = written[1] & 0x7F;
|
||||
len.ShouldBe(payload.Length);
|
||||
written[2..].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_WithCompression_CompressesLargePayload()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = new byte[200];
|
||||
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// RSV1 bit should be set for compressed frame
|
||||
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||
// Compressed size should be less than original
|
||||
written.Length.ShouldBeLessThan(payload.Length + 10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// RSV1 bit should NOT be set for small payloads
|
||||
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_DecodesMaskedFrame()
|
||||
{
|
||||
var payload = "CONNECT {}\r\n"u8.ToArray();
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
|
||||
var maskKey = header[^4..];
|
||||
WsFrameWriter.MaskBuf(maskKey, payload);
|
||||
|
||||
var frame = new byte[header.Length + payload.Length];
|
||||
header.CopyTo(frame, 0);
|
||||
payload.CopyTo(frame, header.Length);
|
||||
|
||||
var inner = new MemoryStream(frame);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
|
||||
n.ShouldBe("CONNECT {}\r\n".Length);
|
||||
System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_ReturnsZero_OnEndOfStream()
|
||||
{
|
||||
// Empty stream should return 0 (true end of stream)
|
||||
var inner = new MemoryStream([]);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
n.ShouldBe(0);
|
||||
}
|
||||
|
||||
private static byte[] BuildUnmaskedFrame(byte[] payload)
|
||||
{
|
||||
var header = new byte[2];
|
||||
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
|
||||
header[1] = (byte)payload.Length;
|
||||
var frame = new byte[2 + payload.Length];
|
||||
header.CopyTo(frame, 0);
|
||||
payload.CopyTo(frame, 2);
|
||||
return frame;
|
||||
}
|
||||
}
|
||||
53
tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs
Normal file
53
tests/NATS.Server.Tests/WebSocket/WsConstantsTests.cs
Normal file
@@ -0,0 +1,53 @@
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsConstantsTests
|
||||
{
|
||||
[Fact]
|
||||
public void OpCodes_MatchRfc6455()
|
||||
{
|
||||
WsConstants.TextMessage.ShouldBe(1);
|
||||
WsConstants.BinaryMessage.ShouldBe(2);
|
||||
WsConstants.CloseMessage.ShouldBe(8);
|
||||
WsConstants.PingMessage.ShouldBe(9);
|
||||
WsConstants.PongMessage.ShouldBe(10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void FrameBits_MatchRfc6455()
|
||||
{
|
||||
WsConstants.FinalBit.ShouldBe((byte)0x80);
|
||||
WsConstants.Rsv1Bit.ShouldBe((byte)0x40);
|
||||
WsConstants.MaskBit.ShouldBe((byte)0x80);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CloseStatusCodes_MatchRfc6455()
|
||||
{
|
||||
WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
|
||||
WsConstants.CloseStatusGoingAway.ShouldBe(1001);
|
||||
WsConstants.CloseStatusProtocolError.ShouldBe(1002);
|
||||
WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
|
||||
WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(WsConstants.CloseMessage)]
|
||||
[InlineData(WsConstants.PingMessage)]
|
||||
[InlineData(WsConstants.PongMessage)]
|
||||
public void IsControlFrame_True(int opcode)
|
||||
{
|
||||
WsConstants.IsControlFrame(opcode).ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(WsConstants.TextMessage)]
|
||||
[InlineData(WsConstants.BinaryMessage)]
|
||||
[InlineData(0)]
|
||||
public void IsControlFrame_False(int opcode)
|
||||
{
|
||||
WsConstants.IsControlFrame(opcode).ShouldBeFalse();
|
||||
}
|
||||
}
|
||||
163
tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
Normal file
163
tests/NATS.Server.Tests/WebSocket/WsFrameReadTests.cs
Normal file
@@ -0,0 +1,163 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsFrameReadTests
|
||||
{
|
||||
/// <summary>Helper: build a single unmasked binary frame.</summary>
|
||||
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
|
||||
{
|
||||
int payloadLen = payload.Length;
|
||||
byte b0 = (byte)opcode;
|
||||
if (fin) b0 |= WsConstants.FinalBit;
|
||||
if (compressed) b0 |= WsConstants.Rsv1Bit;
|
||||
byte b1 = 0;
|
||||
if (mask) b1 |= WsConstants.MaskBit;
|
||||
|
||||
byte[] lenBytes;
|
||||
if (payloadLen <= 125)
|
||||
{
|
||||
lenBytes = [(byte)(b1 | (byte)payloadLen)];
|
||||
}
|
||||
else if (payloadLen < 65536)
|
||||
{
|
||||
lenBytes = new byte[3];
|
||||
lenBytes[0] = (byte)(b1 | 126);
|
||||
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
lenBytes = new byte[9];
|
||||
lenBytes[0] = (byte)(b1 | 127);
|
||||
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
|
||||
}
|
||||
|
||||
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
|
||||
var frame = new byte[totalLen];
|
||||
frame[0] = b0;
|
||||
lenBytes.CopyTo(frame.AsSpan(1));
|
||||
int pos = 1 + lenBytes.Length;
|
||||
if (mask && maskKey != null)
|
||||
{
|
||||
maskKey.CopyTo(frame.AsSpan(pos));
|
||||
pos += 4;
|
||||
var maskedPayload = payload.ToArray();
|
||||
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
|
||||
maskedPayload.CopyTo(frame.AsSpan(pos));
|
||||
}
|
||||
else
|
||||
{
|
||||
payload.CopyTo(frame.AsSpan(pos));
|
||||
}
|
||||
return frame;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadSingleUnmaskedFrame()
|
||||
{
|
||||
var payload = "Hello"u8.ToArray();
|
||||
var frame = BuildFrame(payload);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMaskedFrame()
|
||||
{
|
||||
var payload = "Hello"u8.ToArray();
|
||||
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
|
||||
var frame = BuildFrame(payload, mask: true, maskKey: key);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: true);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Read16BitLengthFrame()
|
||||
{
|
||||
var payload = new byte[200];
|
||||
Random.Shared.NextBytes(payload);
|
||||
var frame = BuildFrame(payload);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadPingFrame_ReturnsPongAction()
|
||||
{
|
||||
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0); // control frames don't produce payload
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(1);
|
||||
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadCloseFrame_ReturnsCloseAction()
|
||||
{
|
||||
var closePayload = new byte[2];
|
||||
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
|
||||
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.CloseReceived.ShouldBeTrue();
|
||||
readInfo.CloseStatus.ShouldBe(1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadPongFrame_NoAction()
|
||||
{
|
||||
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Unmask_Optimized_8ByteChunks()
|
||||
{
|
||||
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||
var original = new byte[32];
|
||||
Random.Shared.NextBytes(original);
|
||||
var masked = original.ToArray();
|
||||
|
||||
// Mask it
|
||||
for (int i = 0; i < masked.Length; i++)
|
||||
masked[i] ^= key[i & 3];
|
||||
|
||||
// Unmask using the state machine
|
||||
var info = new WsReadInfo(expectMask: true);
|
||||
info.SetMaskKey(key);
|
||||
info.Unmask(masked);
|
||||
|
||||
masked.ShouldBe(original);
|
||||
}
|
||||
}
|
||||
152
tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs
Normal file
152
tests/NATS.Server.Tests/WebSocket/WsFrameWriterTests.cs
Normal file
@@ -0,0 +1,152 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsFrameWriterTests
|
||||
{
|
||||
[Fact]
|
||||
public void CreateFrameHeader_SmallPayload_7BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 100);
|
||||
header.Length.ShouldBe(2);
|
||||
(header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
|
||||
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
(header[1] & 0x7F).ShouldBe(100);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_MediumPayload_16BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
|
||||
header.Length.ShouldBe(4);
|
||||
(header[1] & 0x7F).ShouldBe(126);
|
||||
BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_LargePayload_64BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 70000);
|
||||
header.Length.ShouldBe(10);
|
||||
(header[1] & 0x7F).ShouldBe(127);
|
||||
BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
|
||||
{
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 10);
|
||||
header.Length.ShouldBe(6); // 2 header + 4 mask key
|
||||
(header[1] & WsConstants.MaskBit).ShouldNotBe(0);
|
||||
key.ShouldNotBeNull();
|
||||
key.Length.ShouldBe(4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_Compressed_SetsRsv1Bit()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: true,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 10);
|
||||
(header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_XorsCorrectly()
|
||||
{
|
||||
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||
byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
|
||||
byte[] expected = new byte[data.Length];
|
||||
for (int i = 0; i < data.Length; i++)
|
||||
expected[i] = (byte)(data[i] ^ key[i & 3]);
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(expected);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_RoundTrip()
|
||||
{
|
||||
byte[] key = [0x12, 0x34, 0x56, 0x78];
|
||||
byte[] original = "Hello, WebSocket!"u8.ToArray();
|
||||
var data = original.ToArray();
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldNotBe(original);
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_WithStatusAndBody()
|
||||
{
|
||||
var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
|
||||
msg.Length.ShouldBe(2 + "normal closure".Length);
|
||||
BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_LongBody_Truncated()
|
||||
{
|
||||
var longBody = new string('x', 200);
|
||||
var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
|
||||
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ClientClosed_NormalClosure()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
|
||||
.ShouldBe(WsConstants.CloseStatusNormalClosure);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_AuthTimeout_PolicyViolation()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
|
||||
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ParseError_ProtocolError()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
|
||||
.ShouldBe(WsConstants.CloseStatusProtocolError);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_MaxPayload_MessageTooBig()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
|
||||
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildControlFrame_PingNomask()
|
||||
{
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
|
||||
frame.Length.ShouldBe(2);
|
||||
(frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||
(frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
|
||||
(frame[1] & 0x7F).ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildControlFrame_PongWithPayload()
|
||||
{
|
||||
byte[] payload = [1, 2, 3, 4];
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
|
||||
frame.Length.ShouldBe(2 + 4);
|
||||
frame[2..].ShouldBe(payload);
|
||||
}
|
||||
}
|
||||
162
tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs
Normal file
162
tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs
Normal file
@@ -0,0 +1,162 @@
|
||||
using System.Buffers.Binary;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsIntegrationTests : IAsyncLifetime
|
||||
{
|
||||
private NatsServer _server = null!;
|
||||
private NatsOptions _options = null!;
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_options = new NatsOptions
|
||||
{
|
||||
Port = 0,
|
||||
WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
|
||||
};
|
||||
var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
|
||||
_server = new NatsServer(_options, loggerFactory);
|
||||
_ = _server.StartAsync(CancellationToken.None);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _server.ShutdownAsync();
|
||||
_server.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WebSocket_ConnectAndReceiveInfo()
|
||||
{
|
||||
using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
|
||||
using var stream = new NetworkStream(socket, ownsSocket: false);
|
||||
|
||||
await SendUpgradeRequest(stream);
|
||||
var response = await ReadHttpResponse(stream);
|
||||
response.ShouldContain("101");
|
||||
|
||||
var wsFrame = await ReadWsFrame(stream);
|
||||
var info = Encoding.ASCII.GetString(wsFrame);
|
||||
info.ShouldStartWith("INFO ");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WebSocket_ConnectAndPing()
|
||||
{
|
||||
using var client = await ConnectWsClient();
|
||||
|
||||
// Send CONNECT and PING together
|
||||
await SendWsText(client, "CONNECT {}\r\nPING\r\n");
|
||||
|
||||
// Read PONG WS frame
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||
var pong = await ReadWsFrameAsync(client, cts.Token);
|
||||
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WebSocket_PubSub()
|
||||
{
|
||||
using var sub = await ConnectWsClient();
|
||||
using var pub = await ConnectWsClient();
|
||||
|
||||
await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n");
|
||||
await Task.Delay(200);
|
||||
|
||||
await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
|
||||
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||
var msg = await ReadWsFrameAsync(sub, cts.Token);
|
||||
Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
|
||||
}
|
||||
|
||||
private async Task<NetworkStream> ConnectWsClient()
|
||||
{
|
||||
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
|
||||
var stream = new NetworkStream(socket, ownsSocket: true);
|
||||
|
||||
await SendUpgradeRequest(stream);
|
||||
var response = await ReadHttpResponse(stream);
|
||||
response.ShouldContain("101");
|
||||
|
||||
await ReadWsFrame(stream); // Read INFO frame
|
||||
return stream;
|
||||
}
|
||||
|
||||
private static async Task SendUpgradeRequest(NetworkStream stream)
|
||||
{
|
||||
var keyBytes = new byte[16];
|
||||
RandomNumberGenerator.Fill(keyBytes);
|
||||
var key = Convert.ToBase64String(keyBytes);
|
||||
|
||||
var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
|
||||
private static async Task<string> ReadHttpResponse(NetworkStream stream)
|
||||
{
|
||||
// Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response
|
||||
var sb = new StringBuilder();
|
||||
var buf = new byte[1];
|
||||
while (true)
|
||||
{
|
||||
int n = await stream.ReadAsync(buf);
|
||||
if (n == 0) break;
|
||||
sb.Append((char)buf[0]);
|
||||
if (sb.Length >= 4 &&
|
||||
sb[^4] == '\r' && sb[^3] == '\n' &&
|
||||
sb[^2] == '\r' && sb[^1] == '\n')
|
||||
break;
|
||||
}
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static Task<byte[]> ReadWsFrame(NetworkStream stream)
|
||||
=> ReadWsFrameAsync(stream, CancellationToken.None);
|
||||
|
||||
private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
|
||||
{
|
||||
var header = new byte[2];
|
||||
await stream.ReadExactlyAsync(header, ct);
|
||||
int len = header[1] & 0x7F;
|
||||
if (len == 126)
|
||||
{
|
||||
var extLen = new byte[2];
|
||||
await stream.ReadExactlyAsync(extLen, ct);
|
||||
len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
|
||||
}
|
||||
else if (len == 127)
|
||||
{
|
||||
var extLen = new byte[8];
|
||||
await stream.ReadExactlyAsync(extLen, ct);
|
||||
len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
|
||||
}
|
||||
|
||||
var payload = new byte[len];
|
||||
if (len > 0) await stream.ReadExactlyAsync(payload, ct);
|
||||
return payload;
|
||||
}
|
||||
|
||||
private static async Task SendWsText(NetworkStream stream, string text)
|
||||
{
|
||||
var payload = Encoding.ASCII.GetBytes(text);
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
|
||||
var maskKey = header[^4..];
|
||||
WsFrameWriter.MaskBuf(maskKey, payload);
|
||||
await stream.WriteAsync(header);
|
||||
await stream.WriteAsync(payload);
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
}
|
||||
82
tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs
Normal file
82
tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs
Normal file
@@ -0,0 +1,82 @@
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsOriginCheckerTests
|
||||
{
|
||||
[Fact]
|
||||
public void NoOriginHeader_Accepted()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NeitherSameNorList_AlwaysAccepted()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
|
||||
checker.CheckOrigin("https://evil.com", "localhost:4222", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_Match()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_Mismatch()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://other:4222", "localhost:4222", false)
|
||||
.ShouldNotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_DefaultPort_Http()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://localhost", "localhost:80", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_DefaultPort_Https()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("https://localhost", "localhost:443", true)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AllowedOrigins_Match()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false,
|
||||
allowedOrigins: ["https://app.example.com"]);
|
||||
checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AllowedOrigins_Mismatch()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false,
|
||||
allowedOrigins: ["https://app.example.com"]);
|
||||
checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
|
||||
.ShouldNotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AllowedOrigins_SchemeMismatch()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false,
|
||||
allowedOrigins: ["https://app.example.com"]);
|
||||
checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
|
||||
.ShouldNotBeNull();
|
||||
}
|
||||
}
|
||||
226
tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs
Normal file
226
tests/NATS.Server.Tests/WebSocket/WsUpgradeTests.cs
Normal file
@@ -0,0 +1,226 @@
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Tests.WebSocket;
|
||||
|
||||
public class WsUpgradeTests
|
||||
{
|
||||
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||
sb.Append("Host: localhost:4222\r\n");
|
||||
sb.Append("Upgrade: websocket\r\n");
|
||||
sb.Append("Connection: Upgrade\r\n");
|
||||
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||
if (extraHeaders != null)
|
||||
sb.Append(extraHeaders);
|
||||
sb.Append("\r\n");
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ValidUpgrade_Returns101()
|
||||
{
|
||||
var request = BuildValidRequest();
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Client);
|
||||
var response = ReadResponse(outputStream);
|
||||
response.ShouldContain("HTTP/1.1 101");
|
||||
response.ShouldContain("Upgrade: websocket");
|
||||
response.ShouldContain("Sec-WebSocket-Accept:");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MissingUpgradeHeader_Returns400()
|
||||
{
|
||||
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
ReadResponse(outputStream).ShouldContain("400");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MissingHost_Returns400()
|
||||
{
|
||||
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WrongVersion_Returns400()
|
||||
{
|
||||
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task LeafNodePath_ReturnsLeafKind()
|
||||
{
|
||||
var request = BuildValidRequest("/leafnode");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Leaf);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MqttPath_ReturnsMqttKind()
|
||||
{
|
||||
var request = BuildValidRequest("/mqtt");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Mqtt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompressionNegotiation_WhenEnabled()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeTrue();
|
||||
ReadResponse(outputStream).ShouldContain("permessage-deflate");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompressionNegotiation_WhenDisabled()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task NoMaskingHeader_ForLeaf()
|
||||
{
|
||||
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.MaskRead.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task BrowserDetection_Mozilla()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Browser.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SafariDetection_NoCompFrag()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
|
||||
$"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.NoCompFrag.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AcceptKey_MatchesRfc6455Example()
|
||||
{
|
||||
// RFC 6455 Section 4.2.2 example
|
||||
var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
|
||||
key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CookieExtraction()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions
|
||||
{
|
||||
NoTls = true,
|
||||
JwtCookie = "jwt_token",
|
||||
UsernameCookie = "nats_user",
|
||||
PasswordCookie = "nats_pass",
|
||||
};
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.CookieJwt.ShouldBe("my-jwt");
|
||||
result.CookieUsername.ShouldBe("admin");
|
||||
result.CookiePassword.ShouldBe("secret");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task XForwardedFor_ExtractsClientIp()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.ClientIp.ShouldBe("192.168.1.100");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task PostMethod_Returns405()
|
||||
{
|
||||
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
ReadResponse(outputStream).ShouldContain("405");
|
||||
}
|
||||
|
||||
// Helper: create a readable input stream and writable output stream
|
||||
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||
{
|
||||
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||
}
|
||||
|
||||
private static string ReadResponse(MemoryStream output)
|
||||
{
|
||||
output.Position = 0;
|
||||
return Encoding.ASCII.GetString(output.ToArray());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user