Compare commits

..

16 Commits

Author SHA1 Message Date
Joseph Doherty
1ebf283a8c Merge branch 'feature/websocket'
# Conflicts:
#	differences.md
2026-02-23 05:28:34 -05:00
Joseph Doherty
18a6d0f478 fix: address code review findings for WebSocket implementation
- Convert WsReadInfo from mutable struct to class (prevents silent copy bugs)
- Add handshake timeout enforcement via CancellationToken in WsUpgrade
- Use buffered reading (512 bytes) in ReadHttpRequestAsync instead of byte-at-a-time
- Add IAsyncDisposable to WsConnection for proper async cleanup
- Simplify redundant mask bit check in WsReadInfo
- Remove unused WsGuid and CompressLastBlock dead code from WsConstants
- Document single-reader assumption on WsConnection read-side state
2026-02-23 05:27:36 -05:00
Joseph Doherty
02a474a91e docs: add JetStream full parity design 2026-02-23 05:25:09 -05:00
Joseph Doherty
c8a89c9de2 docs: update mqtt connection type design with config parsing scope 2026-02-23 05:18:47 -05:00
Joseph Doherty
5fd2cf040d docs: update differences.md to reflect WebSocket implementation 2026-02-23 05:18:03 -05:00
Joseph Doherty
ca88036126 feat: integrate WebSocket accept loop into NatsServer and NatsClient
Add WebSocket listener support to NatsServer alongside the existing TCP
listener. When WebSocketOptions.Port >= 0, the server binds a second
socket, performs HTTP upgrade via WsUpgrade.TryUpgradeAsync, wraps the
connection in WsConnection for transparent frame/deframe, and hands it
to the standard NatsClient pipeline.

Changes:
- NatsClient: add IsWebSocket and WsInfo properties
- NatsServer: add RunWebSocketAcceptLoopAsync and AcceptWebSocketClientAsync,
  WS listener lifecycle in StartAsync/ShutdownAsync/Dispose
- NatsOptions: change WebSocketOptions.Port default from 0 to -1 (disabled)
- WsConnection.ReadAsync: fix premature end-of-stream when ReadFrames
  returns no payloads by looping until data is available
- Add WsIntegration tests (connect, ping, pub/sub over WebSocket)
- Add WsConnection masked frame and end-of-stream unit tests
2026-02-23 05:16:57 -05:00
Joseph Doherty
6d0a4d259e feat: add WsConnection Stream wrapper for transparent framing 2026-02-23 04:58:56 -05:00
Joseph Doherty
fe304dfe01 fix: review fixes for WsReadInfo and WsUpgrade
- WsReadInfo: validate 64-bit frame payload length against maxPayload
  before casting to int (prevents overflow/memory exhaustion)
- WsReadInfo: always send close response per RFC 6455 Section 5.5.1,
  including for empty close frames
- WsUpgrade: restrict no-masking to leaf node connections only (browser
  clients must always mask frames)
2026-02-23 04:55:53 -05:00
Joseph Doherty
1c948b5b0f feat: add WebSocket HTTP upgrade handshake 2026-02-23 04:53:21 -05:00
Joseph Doherty
bd29c529a8 feat: add WebSocket frame reader state machine 2026-02-23 04:51:54 -05:00
Joseph Doherty
1a1aa9d642 fix: use byte-length for close message truncation, add exception-safe disposal
- CreateCloseMessage now operates on UTF-8 byte length (matching Go's
  len(body) behavior) instead of character length, with proper UTF-8
  boundary detection during truncation
- WsCompression.Compress now uses try/finally for exception-safe disposal
  of DeflateStream and MemoryStream
2026-02-23 04:47:57 -05:00
Joseph Doherty
d49bc5b0d7 feat: add WebSocket permessage-deflate compression
Implement WsCompression with Compress/Decompress methods per RFC 7692.
Key .NET adaptation: Flush() without Dispose() on DeflateStream to produce
the correct sync flush marker that can be stripped and re-appended.
2026-02-23 04:42:31 -05:00
Joseph Doherty
8ded10d49b feat: add WebSocket frame writer with masking and close status mapping 2026-02-23 04:40:44 -05:00
Joseph Doherty
6981a38b72 feat: add WebSocket origin checker 2026-02-23 04:35:06 -05:00
Joseph Doherty
72f60054ed feat: add WebSocket protocol constants (RFC 6455)
Port WsConstants from golang/nats-server/server/websocket.go lines 41-106.
Includes opcodes, frame header bits, close status codes, compression
constants, header names, path routing, and the WsClientKind enum.
2026-02-23 04:33:04 -05:00
Joseph Doherty
708e1b4168 feat: add WebSocketOptions configuration class 2026-02-23 04:29:45 -05:00
50 changed files with 2615 additions and 2490 deletions

View File

@@ -11,7 +11,7 @@
| Feature | Go | .NET | Notes |
|---------|:--:|:----:|-------|
| NKey generation (server identity) | Y | Y | Ed25519 key pair via NATS.NKeys at startup |
| System account setup | Y | Y | `$SYS` account with InternalEventSystem, event publishing, request-reply services |
| System account setup | Y | Y | `$SYS` account created; no event publishing yet (stub) |
| Config file validation on startup | Y | Y | Full config parsing with error collection via `ConfigProcessor` |
| PID file writing | Y | Y | Written on startup, deleted on shutdown |
| Profiling HTTP endpoint (`/debug/pprof`) | Y | Stub | `ProfPort` option exists but endpoint not implemented |
@@ -64,10 +64,10 @@
| ROUTER | Y | N | Excluded per scope |
| GATEWAY | Y | N | Excluded per scope |
| LEAF | Y | N | Excluded per scope |
| SYSTEM (internal) | Y | Y | InternalClient + InternalEventSystem with Channel-based send/receive loops |
| SYSTEM (internal) | Y | N | |
| JETSTREAM (internal) | Y | N | |
| ACCOUNT (internal) | Y | Y | Lazy per-account InternalClient with import/export subscription support |
| WebSocket clients | Y | N | |
| ACCOUNT (internal) | Y | N | |
| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth |
| MQTT clients | Y | N | |
### Client Features
@@ -218,7 +218,7 @@ Go implements a sophisticated slow consumer detection system:
|---------|:--:|:----:|-------|
| Per-account SubList isolation | Y | Y | |
| Multi-account user resolution | Y | Y | `AccountConfig` per account in `NatsOptions.Accounts`; `GetOrCreateAccount` wires limits |
| Account exports/imports | Y | Y | ServiceImport/StreamImport with ExportAuth, subject transforms, response routing |
| Account exports/imports | Y | N | |
| Per-account connection limits | Y | Y | `Account.AddClient()` returns false when `MaxConnections` exceeded |
| Per-account subscription limits | Y | Y | `Account.IncrementSubscriptions()` enforced in `ProcessSub()` |
| Account JetStream limits | Y | N | Excluded per scope |
@@ -267,7 +267,8 @@ Go implements a sophisticated slow consumer detection system:
- ~~Advanced limits (MaxSubs, MaxSubTokens, MaxPending, WriteDeadline)~~ — `MaxSubs`, `MaxSubTokens` implemented; MaxPending/WriteDeadline already existed
- ~~Tags/metadata~~ — `Tags` dictionary implemented in `NatsOptions`
- ~~OCSP configuration~~ — `OcspConfig` with 4 modes (Auto/Always/Must/Never), peer verification, and stapling
- WebSocket/MQTT options
- ~~WebSocket options~~ — `WebSocketOptions` with port, compression, origin checking, cookie auth, custom headers
- MQTT options
- ~~Operator mode / account resolver~~ — `JwtAuthenticator` + `IAccountResolver` + `MemAccountResolver` with trusted keys
---
@@ -406,11 +407,6 @@ The following items from the original gap list have been implemented:
- **User revocation** — per-account tracking with wildcard (`*`) revocation
- **Config file parsing** — custom lexer/parser ported from Go; supports includes, variables, nested blocks, size suffixes
- **Hot reload (SIGHUP)** — re-parses config, diffs changes, validates reloadable set, applies with CLI precedence
- **SYSTEM client type** — InternalClient with InternalEventSystem, Channel-based send/receive loops, event publishing
- **ACCOUNT client type** — lazy per-account InternalClient with import/export subscription support
- **System event publishing** — connect/disconnect advisories, server stats, shutdown/lame-duck events, auth errors
- **System request-reply services** — $SYS.REQ.SERVER.*.VARZ/CONNZ/SUBSZ/HEALTHZ/IDZ/STATSZ with ping wildcards
- **Account exports/imports** — service and stream imports with ExportAuth, subject transforms, response routing, latency tracking
### Remaining Lower Priority
1. **Dynamic buffer sizing** — delegated to Pipe, less optimized for long-lived connections

View File

@@ -0,0 +1,141 @@
# Full JetStream and Cluster Prerequisite Parity Design
**Date:** 2026-02-23
**Status:** Approved
**Scope:** Port JetStream from Go with all prerequisite subsystems required for full Go JetStream test parity, including cluster route/gateway/leaf behaviors and RAFT/meta-cluster semantics.
**Verification Gate:** Go JetStream-focused test suites in `golang/nats-server/server/` plus new/updated .NET tests.
**Cutover Model:** Single end-to-end cutover (no interim acceptance gates).
## 1. Architecture
The implementation uses a full in-process .NET parity architecture that mirrors Go subsystem boundaries while keeping strict internal contracts.
1. Core Server Layer (`NatsServer`/`NatsClient`)
- Extend existing server/client runtime to support full client kinds and inter-server protocol paths.
- Preserve responsibility for socket lifecycle, parser integration, auth entry, and local dispatch.
2. Cluster Fabric Layer
- Add route mesh, gateway links, leafnode links, interest propagation, and remote subscription accounting.
- Provide transport-neutral contracts consumed by JetStream and RAFT replication services.
3. JetStream Control Plane
- Add account-scoped JetStream managers, API subject handlers (`$JS.API.*`), stream/consumer metadata lifecycle, advisories, and limit enforcement.
- Integrate with RAFT/meta services for replicated decisions.
4. JetStream Data Plane
- Add stream ingest path, retention/eviction logic, consumer delivery/ack/redelivery, mirror/source orchestration, and flow-control behavior.
- Use pluggable storage abstractions with parity-focused behavior.
5. RAFT and Replication Layer
- Implement meta-group plus per-asset replication groups, election/term logic, log replication, snapshots, and catchup.
- Expose deterministic commit/applied hooks to JetStream runtime layers.
6. Storage Layer
- Implement memstore and filestore with sequence indexing, subject indexing, compaction/snapshot support, and recovery semantics.
7. Observability Layer
- Upgrade `/jsz` and `/varz` JetStream blocks from placeholders to live runtime reporting with Go-compatible response shape.
## 2. Components and Contracts
### 2.1 New component families
1. Cluster and interserver subsystem
- Add route/gateway/leaf and interserver protocol operations under `src/NATS.Server/`.
- Extend parser/dispatcher with route/leaf/account operations currently excluded.
- Expand client-kind model and command routing constraints.
2. JetStream API and domain model
- Add `src/NATS.Server/JetStream/` subtree for API payload models, stream/consumer models, and error templates/codes.
3. JetStream runtime
- Add stream manager, consumer manager, ack processor, delivery scheduler, mirror/source orchestration, and flow control handlers.
- Integrate publish path with stream capture/store/ack behavior.
4. RAFT subsystem
- Add `src/NATS.Server/Raft/` for replicated logs, elections, snapshots, and membership operations.
5. Storage subsystem
- Add `src/NATS.Server/JetStream/Storage/` for `MemStore` and `FileStore`, sequence/subject indexes, and restart recovery.
### 2.2 Existing components to upgrade
1. `src/NATS.Server/NatsOptions.cs`
- Add full config surface for clustering, JetStream, storage, placement, and parity-required limits.
2. `src/NATS.Server/Configuration/ConfigProcessor.cs`
- Replace silent ignore behavior for cluster/jetstream keys with parsing, mapping, and validation.
3. `src/NATS.Server/Protocol/NatsParser.cs` and `src/NATS.Server/NatsClient.cs`
- Add missing interserver operations and kind-aware dispatch paths needed for clustered JetStream behavior.
4. Monitoring components
- Upgrade `src/NATS.Server/Monitoring/MonitorServer.cs` and `src/NATS.Server/Monitoring/Varz.cs`.
- Add/extend JS monitoring handlers and models for `/jsz` and JetStream runtime fields.
## 3. Data Flow and Behavioral Semantics
1. Inbound publish path
- Parse client publish commands, apply auth/permission checks, route to local subscribers and JetStream candidates.
- For JetStream subjects: apply preconditions, append to store, replicate via RAFT (as required), apply committed state, return Go-compatible pub ack.
2. Consumer delivery path
- Use shared push/pull state model for pending, ack floor, redelivery timers, flow control, and max ack pending.
- Enforce retention policy semantics (limits/interest/workqueue), filter subject behavior, replay policy, and eviction behavior.
3. Replication and control flow
- Meta RAFT governs replicated metadata decisions.
- Per-stream/per-consumer groups replicate state and snapshots.
- Leader changes preserve at-least-once delivery and consumer state invariants.
4. Recovery flow
- Reconstruct stream/consumer/store state on startup.
- In clustered mode, rejoin replication groups and catch up before serving full API/delivery workload.
- Preserve sequence continuity, subject indexes, delete markers, and pending/redelivery state.
5. Monitoring flow
- `/varz` JetStream fields and `/jsz` return live runtime state.
- Advisory and metric surfaces update from control-plane and data-plane events.
## 4. Error Handling and Operational Constraints
1. API error parity
- Match canonical JetStream codes/messages for validation failures, state conflicts, limits, leadership/quorum issues, and storage failures.
2. Protocol behavior
- Preserve normal client compatibility while adding interserver protocol and internal client-kind restrictions.
3. Storage and consistency failures
- Classify corruption/truncation/checksum/snapshot failures as recoverable vs non-recoverable.
- Avoid silent data loss and emit monitoring/advisory signals where parity requires.
4. Cluster and RAFT fault handling
- Explicitly handle no-quorum, stale leader, delayed apply, peer removal, catchup lag, and stepdown transitions.
- Return leadership-aware API errors.
5. Config/reload behavior
- Treat JetStream and cluster config as first-class with strict validation.
- Mirror Go-like reloadable vs restart-required change boundaries.
## 5. Testing and Verification Strategy
1. .NET unit tests
- Add focused tests for JetStream API validation, stream and consumer state, RAFT primitives, mem/file store invariants, and config parsing/validation.
2. .NET integration tests
- Add end-to-end tests for publish/store/consume/ack behavior, retention policies, restart recovery, and clustered prerequisites used by JetStream.
3. Parity harness
- Maintain mapping of Go JetStream test categories to .NET feature areas.
- Execute JetStream-focused Go tests from `golang/nats-server/server/` as acceptance benchmark.
4. `differences.md` policy
- Update only after verification gate passes.
- Remove opening JetStream exclusion scope statement and replace with updated parity scope.
## 6. Scope Decisions Captured
- Include all prerequisite non-JetStream subsystems required to satisfy full Go JetStream tests.
- Verification target is full Go JetStream-focused parity, not a narrowed subset.
- Delivery model is single end-to-end cutover.
- `differences.md` top-level scope statement will be updated to include JetStream and clustering parity coverage once verified.

View File

@@ -1,18 +1,21 @@
# MQTT Connection Type Port Design
## Goal
Port MQTT-related connection type parity from Go into the .NET server for two scoped areas:
Port MQTT-related connection type parity from Go into the .NET server for three scoped areas:
1. JWT `allowed_connection_types` behavior for `MQTT` / `MQTT_WS` (plus existing known types).
2. `/connz` filtering by `mqtt_client`.
3. Full MQTT configuration parsing from `mqtt {}` config blocks (all Go `MQTTOpts` fields).
## Scope
- In scope:
- JWT allowed connection type normalization and enforcement semantics.
- `/connz?mqtt_client=` option parsing and filtering.
- MQTT configuration model and config file parsing (all Go `MQTTOpts` fields).
- Expanded `MqttOptsVarz` monitoring output.
- Unit/integration tests for new and updated behavior.
- `differences.md` updates after implementation is verified.
- Out of scope:
- Full MQTT transport implementation.
- Full MQTT transport implementation (listener, protocol parser, sessions).
- WebSocket transport implementation.
- Leaf/route/gateway transport plumbing.
@@ -27,6 +30,8 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
- Extend connz monitoring options to parse `mqtt_client` and apply exact-match filtering before sort/pagination.
## Components
### JWT Connection-Type Enforcement
- `src/NATS.Server/Auth/IAuthenticator.cs`
- Extend `ClientAuthContext` with a connection-type value.
- `src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs` (new)
@@ -38,6 +43,8 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
- Enforce against current `ClientAuthContext.ConnectionType`.
- `src/NATS.Server/NatsClient.cs`
- Populate auth context connection type (currently `STANDARD`).
### Connz MQTT Client Filtering
- `src/NATS.Server/Monitoring/Connz.cs`
- Add `MqttClient` to `ConnzOptions` with JSON field `mqtt_client`.
- `src/NATS.Server/Monitoring/ConnzHandler.cs`
@@ -48,6 +55,30 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
- `src/NATS.Server/NatsServer.cs`
- Persist `MqttClient` into `ClosedClient` snapshot (empty for now).
### MQTT Configuration Parsing
- `src/NATS.Server/MqttOptions.cs` (new)
- Full model matching Go `MQTTOpts` struct (opts.go:613-707):
- Network: `Host`, `Port`
- Auth override: `NoAuthUser`, `Username`, `Password`, `Token`, `AuthTimeout`
- TLS: `TlsCert`, `TlsKey`, `TlsCaCert`, `TlsVerify`, `TlsTimeout`, `TlsMap`, `TlsPinnedCerts`
- JetStream: `JsDomain`, `StreamReplicas`, `ConsumerReplicas`, `ConsumerMemoryStorage`, `ConsumerInactiveThreshold`
- QoS: `AckWait`, `MaxAckPending`, `JsApiTimeout`
- `src/NATS.Server/NatsOptions.cs`
- Add `Mqtt` property of type `MqttOptions?`.
- `src/NATS.Server/Configuration/ConfigProcessor.cs`
- Add `ParseMqtt()` for `mqtt {}` config block with Go-compatible key aliases:
- `host`/`net` → Host, `listen` → Host+Port
- `ack_wait`/`ackwait` → AckWait
- `max_ack_pending`/`max_pending`/`max_inflight` → MaxAckPending
- `js_domain` → JsDomain
- `js_api_timeout`/`api_timeout` → JsApiTimeout
- `consumer_inactive_threshold`/`consumer_auto_cleanup` → ConsumerInactiveThreshold
- Nested `tls {}` and `authorization {}`/`authentication {}` blocks
- `src/NATS.Server/Monitoring/Varz.cs`
- Expand `MqttOptsVarz` from 3 fields to full monitoring-visible set.
- `src/NATS.Server/Monitoring/VarzHandler.cs`
- Populate expanded `MqttOptsVarz` from `NatsOptions.Mqtt`.
## Data Flow
1. Client sends `CONNECT`.
2. `NatsClient.ProcessConnectAsync` builds `ClientAuthContext` with `ConnectionType=STANDARD`.
@@ -73,6 +104,7 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
- MQTT transport is not implemented yet in this repository.
- Runtime connection type currently resolves to `STANDARD` in auth context.
- `mqtt_client` values remain empty until MQTT path populates them.
- MQTT config is parsed and stored but no listener is started.
## Testing Strategy
- `tests/NATS.Server.Tests/JwtAuthenticatorTests.cs`
@@ -85,9 +117,16 @@ Port MQTT-related connection type parity from Go into the .NET server for two sc
- `/connz?mqtt_client=<id>` returns matching connections only.
- `/connz?state=closed&mqtt_client=<id>` filters closed snapshots.
- non-existing ID yields empty connection set.
- `tests/NATS.Server.Tests/ConfigProcessorTests.cs` (or similar)
- Parse valid `mqtt {}` block with all fields.
- Parse config with aliases (ackwait vs ack_wait, host vs net, etc.).
- Parse nested `tls {}` and `authorization {}` blocks within mqtt.
- Varz MQTT section populated from config.
## Success Criteria
- JWT `allowed_connection_types` behavior matches Go semantics for known/unknown mixing and unknown-only rejection.
- `/connz` supports exact `mqtt_client` filtering for open and closed sets.
- `mqtt {}` config block parses all Go `MQTTOpts` fields with aliases.
- `MqttOptsVarz` includes full monitoring output.
- Added tests pass.
- `differences.md` accurately reflects implemented parity.

View File

@@ -1,5 +1,4 @@
using System.Collections.Concurrent;
using NATS.Server.Imports;
using NATS.Server.Subscriptions;
namespace NATS.Server.Auth;
@@ -13,8 +12,6 @@ public sealed class Account : IDisposable
public Permissions? DefaultPermissions { get; set; }
public int MaxConnections { get; set; } // 0 = unlimited
public int MaxSubscriptions { get; set; } // 0 = unlimited
public ExportMap Exports { get; } = new();
public ImportMap Imports { get; } = new();
// JWT fields
public string? Nkey { get; set; }
@@ -92,77 +89,5 @@ public sealed class Account : IDisposable
Interlocked.Add(ref _outBytes, bytes);
}
// Internal (ACCOUNT) client for import/export message routing
private InternalClient? _internalClient;
public InternalClient GetOrCreateInternalClient(ulong clientId)
{
if (_internalClient != null) return _internalClient;
_internalClient = new InternalClient(clientId, ClientKind.Account, this);
return _internalClient;
}
public void AddServiceExport(string subject, ServiceResponseType responseType, IEnumerable<Account>? approved)
{
var auth = new ExportAuth
{
ApprovedAccounts = approved != null ? new HashSet<string>(approved.Select(a => a.Name)) : null,
};
Exports.Services[subject] = new ServiceExport
{
Auth = auth,
Account = this,
ResponseType = responseType,
};
}
public void AddStreamExport(string subject, IEnumerable<Account>? approved)
{
var auth = new ExportAuth
{
ApprovedAccounts = approved != null ? new HashSet<string>(approved.Select(a => a.Name)) : null,
};
Exports.Streams[subject] = new StreamExport { Auth = auth };
}
public ServiceImport AddServiceImport(Account destination, string from, string to)
{
if (!destination.Exports.Services.TryGetValue(to, out var export))
throw new InvalidOperationException($"No service export found for '{to}' on account '{destination.Name}'");
if (!export.Auth.IsAuthorized(this))
throw new UnauthorizedAccessException($"Account '{Name}' not authorized to import '{to}' from '{destination.Name}'");
var si = new ServiceImport
{
DestinationAccount = destination,
From = from,
To = to,
Export = export,
ResponseType = export.ResponseType,
};
Imports.AddServiceImport(si);
return si;
}
public void AddStreamImport(Account source, string from, string to)
{
if (!source.Exports.Streams.TryGetValue(from, out var export))
throw new InvalidOperationException($"No stream export found for '{from}' on account '{source.Name}'");
if (!export.Auth.IsAuthorized(this))
throw new UnauthorizedAccessException($"Account '{Name}' not authorized to import '{from}' from '{source.Name}'");
var si = new StreamImport
{
SourceAccount = source,
From = from,
To = to,
};
Imports.Streams.Add(si);
}
public void Dispose() => SubList.Dispose();
}

View File

@@ -1,22 +0,0 @@
namespace NATS.Server;
/// <summary>
/// Identifies the type of a client connection.
/// Maps to Go's client kind constants in client.go:45-65.
/// </summary>
public enum ClientKind
{
Client,
Router,
Gateway,
Leaf,
System,
JetStream,
Account,
}
public static class ClientKindExtensions
{
public static bool IsInternal(this ClientKind kind) =>
kind is ClientKind.System or ClientKind.JetStream or ClientKind.Account;
}

View File

@@ -1,12 +0,0 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Events;
[JsonSerializable(typeof(ConnectEventMsg))]
[JsonSerializable(typeof(DisconnectEventMsg))]
[JsonSerializable(typeof(AccountNumConns))]
[JsonSerializable(typeof(ServerStatsMsg))]
[JsonSerializable(typeof(ShutdownEventMsg))]
[JsonSerializable(typeof(LameDuckEventMsg))]
[JsonSerializable(typeof(AuthErrorEventMsg))]
internal partial class EventJsonContext : JsonSerializerContext;

View File

@@ -1,49 +0,0 @@
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Events;
/// <summary>
/// System event subject patterns.
/// Maps to Go events.go:41-97 subject constants.
/// </summary>
public static class EventSubjects
{
// Account-scoped events
public const string ConnectEvent = "$SYS.ACCOUNT.{0}.CONNECT";
public const string DisconnectEvent = "$SYS.ACCOUNT.{0}.DISCONNECT";
public const string AccountConnsNew = "$SYS.ACCOUNT.{0}.SERVER.CONNS";
public const string AccountConnsOld = "$SYS.SERVER.ACCOUNT.{0}.CONNS";
// Server-scoped events
public const string ServerStats = "$SYS.SERVER.{0}.STATSZ";
public const string ServerShutdown = "$SYS.SERVER.{0}.SHUTDOWN";
public const string ServerLameDuck = "$SYS.SERVER.{0}.LAMEDUCK";
public const string AuthError = "$SYS.SERVER.{0}.CLIENT.AUTH.ERR";
public const string AuthErrorAccount = "$SYS.ACCOUNT.CLIENT.AUTH.ERR";
// Request-reply subjects (server-specific)
public const string ServerReq = "$SYS.REQ.SERVER.{0}.{1}";
// Wildcard ping subjects (all servers respond)
public const string ServerPing = "$SYS.REQ.SERVER.PING.{0}";
// Account-scoped request subjects
public const string AccountReq = "$SYS.REQ.ACCOUNT.{0}.{1}";
// Inbox for responses
public const string InboxResponse = "$SYS._INBOX_.{0}";
}
/// <summary>
/// Callback signature for system message handlers.
/// Maps to Go's sysMsgHandler type in events.go:109.
/// </summary>
public delegate void SystemMessageHandler(
Subscription? sub,
INatsClient? client,
Account? account,
string subject,
string? reply,
ReadOnlyMemory<byte> headers,
ReadOnlyMemory<byte> message);

View File

@@ -1,270 +0,0 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Events;
/// <summary>
/// Server identity block embedded in all system events.
/// </summary>
public sealed class EventServerInfo
{
[JsonPropertyName("name")]
public string Name { get; set; } = string.Empty;
[JsonPropertyName("host")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Host { get; set; }
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("cluster")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Cluster { get; set; }
[JsonPropertyName("domain")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Domain { get; set; }
[JsonPropertyName("ver")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Version { get; set; }
[JsonPropertyName("seq")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public ulong Seq { get; set; }
[JsonPropertyName("tags")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public Dictionary<string, string>? Tags { get; set; }
}
/// <summary>
/// Client identity block for connect/disconnect events.
/// </summary>
public sealed class EventClientInfo
{
[JsonPropertyName("start")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public DateTime Start { get; set; }
[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public DateTime Stop { get; set; }
[JsonPropertyName("host")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Host { get; set; }
[JsonPropertyName("id")]
public ulong Id { get; set; }
[JsonPropertyName("acc")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Account { get; set; }
[JsonPropertyName("name")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Name { get; set; }
[JsonPropertyName("lang")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Lang { get; set; }
[JsonPropertyName("ver")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Version { get; set; }
[JsonPropertyName("rtt")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public long RttNanos { get; set; }
}
public sealed class DataStats
{
[JsonPropertyName("msgs")]
public long Msgs { get; set; }
[JsonPropertyName("bytes")]
public long Bytes { get; set; }
}
/// <summary>Client connect advisory. Go events.go:155-160.</summary>
public sealed class ConnectEventMsg
{
public const string EventType = "io.nats.server.advisory.v1.client_connect";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("client")]
public EventClientInfo Client { get; set; } = new();
}
/// <summary>Client disconnect advisory. Go events.go:167-174.</summary>
public sealed class DisconnectEventMsg
{
public const string EventType = "io.nats.server.advisory.v1.client_disconnect";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("client")]
public EventClientInfo Client { get; set; } = new();
[JsonPropertyName("sent")]
public DataStats Sent { get; set; } = new();
[JsonPropertyName("received")]
public DataStats Received { get; set; } = new();
[JsonPropertyName("reason")]
public string Reason { get; set; } = string.Empty;
}
/// <summary>Account connection count heartbeat. Go events.go:210-214.</summary>
public sealed class AccountNumConns
{
public const string EventType = "io.nats.server.advisory.v1.account_connections";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("acc")]
public string AccountName { get; set; } = string.Empty;
[JsonPropertyName("conns")]
public int Connections { get; set; }
[JsonPropertyName("total_conns")]
public long TotalConnections { get; set; }
[JsonPropertyName("subs")]
public int Subscriptions { get; set; }
[JsonPropertyName("sent")]
public DataStats Sent { get; set; } = new();
[JsonPropertyName("received")]
public DataStats Received { get; set; } = new();
}
/// <summary>Server stats broadcast. Go events.go:150-153.</summary>
public sealed class ServerStatsMsg
{
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("statsz")]
public ServerStatsData Stats { get; set; } = new();
}
public sealed class ServerStatsData
{
[JsonPropertyName("start")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public DateTime Start { get; set; }
[JsonPropertyName("mem")]
public long Mem { get; set; }
[JsonPropertyName("cores")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int Cores { get; set; }
[JsonPropertyName("connections")]
public int Connections { get; set; }
[JsonPropertyName("total_connections")]
public long TotalConnections { get; set; }
[JsonPropertyName("active_accounts")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int ActiveAccounts { get; set; }
[JsonPropertyName("subscriptions")]
public long Subscriptions { get; set; }
[JsonPropertyName("in_msgs")]
public long InMsgs { get; set; }
[JsonPropertyName("out_msgs")]
public long OutMsgs { get; set; }
[JsonPropertyName("in_bytes")]
public long InBytes { get; set; }
[JsonPropertyName("out_bytes")]
public long OutBytes { get; set; }
[JsonPropertyName("slow_consumers")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public long SlowConsumers { get; set; }
}
/// <summary>Server shutdown notification.</summary>
public sealed class ShutdownEventMsg
{
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("reason")]
public string Reason { get; set; } = string.Empty;
}
/// <summary>Lame duck mode notification.</summary>
public sealed class LameDuckEventMsg
{
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
}
/// <summary>Auth error advisory.</summary>
public sealed class AuthErrorEventMsg
{
public const string EventType = "io.nats.server.advisory.v1.client_auth";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("client")]
public EventClientInfo Client { get; set; } = new();
[JsonPropertyName("reason")]
public string Reason { get; set; } = string.Empty;
}

View File

@@ -1,333 +0,0 @@
using System.Collections.Concurrent;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Events;
/// <summary>
/// Internal publish message queued for the send loop.
/// </summary>
public sealed class PublishMessage
{
public InternalClient? Client { get; init; }
public required string Subject { get; init; }
public string? Reply { get; init; }
public byte[]? Headers { get; init; }
public object? Body { get; init; }
public bool Echo { get; init; }
public bool IsLast { get; init; }
}
/// <summary>
/// Internal received message queued for the receive loop.
/// </summary>
public sealed class InternalSystemMessage
{
public required Subscription? Sub { get; init; }
public required INatsClient? Client { get; init; }
public required Account? Account { get; init; }
public required string Subject { get; init; }
public required string? Reply { get; init; }
public required ReadOnlyMemory<byte> Headers { get; init; }
public required ReadOnlyMemory<byte> Message { get; init; }
public required SystemMessageHandler Callback { get; init; }
}
/// <summary>
/// Manages the server's internal event system with Channel-based send/receive loops.
/// Maps to Go's internal struct in events.go:124-147 and the goroutines
/// internalSendLoop (events.go:495) and internalReceiveLoop (events.go:476).
/// </summary>
public sealed class InternalEventSystem : IAsyncDisposable
{
private readonly ILogger _logger;
private readonly Channel<PublishMessage> _sendQueue;
private readonly Channel<InternalSystemMessage> _receiveQueue;
private readonly Channel<InternalSystemMessage> _receiveQueuePings;
private readonly CancellationTokenSource _cts = new();
private Task? _sendLoop;
private Task? _receiveLoop;
private Task? _receiveLoopPings;
private NatsServer? _server;
private ulong _sequence;
private int _subscriptionId;
private readonly ConcurrentDictionary<string, SystemMessageHandler> _callbacks = new();
public Account SystemAccount { get; }
public InternalClient SystemClient { get; }
public string ServerHash { get; }
public InternalEventSystem(Account systemAccount, InternalClient systemClient, string serverName, ILogger logger)
{
_logger = logger;
SystemAccount = systemAccount;
SystemClient = systemClient;
// Hash server name for inbox routing (matches Go's shash)
ServerHash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(serverName)))[..8].ToLowerInvariant();
_sendQueue = Channel.CreateUnbounded<PublishMessage>(new UnboundedChannelOptions { SingleReader = true });
_receiveQueue = Channel.CreateUnbounded<InternalSystemMessage>(new UnboundedChannelOptions { SingleReader = true });
_receiveQueuePings = Channel.CreateUnbounded<InternalSystemMessage>(new UnboundedChannelOptions { SingleReader = true });
}
public void Start(NatsServer server)
{
_server = server;
var ct = _cts.Token;
_sendLoop = Task.Run(() => InternalSendLoopAsync(ct), ct);
_receiveLoop = Task.Run(() => InternalReceiveLoopAsync(_receiveQueue, ct), ct);
_receiveLoopPings = Task.Run(() => InternalReceiveLoopAsync(_receiveQueuePings, ct), ct);
// Periodic stats publish every 10 seconds
_ = Task.Run(async () =>
{
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(10));
while (await timer.WaitForNextTickAsync(ct))
{
PublishServerStats();
}
}, ct);
}
/// <summary>
/// Registers system request-reply monitoring services for this server.
/// Maps to Go's initEventTracking in events.go.
/// Sets up handlers for $SYS.REQ.SERVER.{id}.VARZ, HEALTHZ, SUBSZ, STATSZ, IDZ
/// and wildcard $SYS.REQ.SERVER.PING.* subjects.
/// </summary>
public void InitEventTracking(NatsServer server)
{
_server = server;
var serverId = server.ServerId;
// Server-specific monitoring services
RegisterService(serverId, "VARZ", server.HandleVarzRequest);
RegisterService(serverId, "HEALTHZ", server.HandleHealthzRequest);
RegisterService(serverId, "SUBSZ", server.HandleSubszRequest);
RegisterService(serverId, "STATSZ", server.HandleStatszRequest);
RegisterService(serverId, "IDZ", server.HandleIdzRequest);
// Wildcard ping services (all servers respond)
SysSubscribe(string.Format(EventSubjects.ServerPing, "VARZ"), WrapRequestHandler(server.HandleVarzRequest));
SysSubscribe(string.Format(EventSubjects.ServerPing, "HEALTHZ"), WrapRequestHandler(server.HandleHealthzRequest));
SysSubscribe(string.Format(EventSubjects.ServerPing, "IDZ"), WrapRequestHandler(server.HandleIdzRequest));
SysSubscribe(string.Format(EventSubjects.ServerPing, "STATSZ"), WrapRequestHandler(server.HandleStatszRequest));
}
private void RegisterService(string serverId, string name, Action<string, string?> handler)
{
var subject = string.Format(EventSubjects.ServerReq, serverId, name);
SysSubscribe(subject, WrapRequestHandler(handler));
}
private SystemMessageHandler WrapRequestHandler(Action<string, string?> handler)
{
return (sub, client, acc, subject, reply, hdr, msg) =>
{
handler(subject, reply);
};
}
/// <summary>
/// Publishes a $SYS.SERVER.{id}.STATSZ message with current server statistics.
/// Maps to Go's sendStatsz in events.go.
/// Can be called manually for testing or is invoked periodically by the stats timer.
/// </summary>
public void PublishServerStats()
{
if (_server == null) return;
var subject = string.Format(EventSubjects.ServerStats, _server.ServerId);
var process = System.Diagnostics.Process.GetCurrentProcess();
var statsMsg = new ServerStatsMsg
{
Server = _server.BuildEventServerInfo(),
Stats = new ServerStatsData
{
Start = _server.StartTime,
Mem = process.WorkingSet64,
Cores = Environment.ProcessorCount,
Connections = _server.ClientCount,
TotalConnections = Interlocked.Read(ref _server.Stats.TotalConnections),
Subscriptions = SystemAccount.SubList.Count,
InMsgs = Interlocked.Read(ref _server.Stats.InMsgs),
OutMsgs = Interlocked.Read(ref _server.Stats.OutMsgs),
InBytes = Interlocked.Read(ref _server.Stats.InBytes),
OutBytes = Interlocked.Read(ref _server.Stats.OutBytes),
SlowConsumers = Interlocked.Read(ref _server.Stats.SlowConsumers),
},
};
Enqueue(new PublishMessage { Subject = subject, Body = statsMsg });
}
/// <summary>
/// Creates a system subscription in the system account's SubList.
/// Maps to Go's sysSubscribe in events.go:2796.
/// </summary>
public Subscription SysSubscribe(string subject, SystemMessageHandler callback)
{
var sid = Interlocked.Increment(ref _subscriptionId).ToString();
var sub = new Subscription
{
Subject = subject,
Sid = sid,
Client = SystemClient,
};
// Store callback keyed by SID so multiple subscriptions work
_callbacks[sid] = callback;
// Set a single routing callback on the system client that dispatches by SID
SystemClient.MessageCallback = (subj, s, reply, hdr, msg) =>
{
if (_callbacks.TryGetValue(s, out var cb))
{
_receiveQueue.Writer.TryWrite(new InternalSystemMessage
{
Sub = sub,
Client = SystemClient,
Account = SystemAccount,
Subject = subj,
Reply = reply,
Headers = hdr,
Message = msg,
Callback = cb,
});
}
};
SystemAccount.SubList.Insert(sub);
return sub;
}
/// <summary>
/// Returns the next monotonically increasing sequence number for event ordering.
/// </summary>
public ulong NextSequence() => Interlocked.Increment(ref _sequence);
/// <summary>
/// Enqueue an internal message for publishing through the send loop.
/// </summary>
public void Enqueue(PublishMessage message)
{
_sendQueue.Writer.TryWrite(message);
}
/// <summary>
/// The send loop: serializes messages and delivers them via the server's routing.
/// Maps to Go's internalSendLoop in events.go:495-668.
/// </summary>
private async Task InternalSendLoopAsync(CancellationToken ct)
{
try
{
await foreach (var pm in _sendQueue.Reader.ReadAllAsync(ct))
{
try
{
var seq = Interlocked.Increment(ref _sequence);
// Serialize body to JSON
byte[] payload;
if (pm.Body is byte[] raw)
{
payload = raw;
}
else if (pm.Body != null)
{
// Try source-generated context first, fall back to reflection-based for unknown types
var bodyType = pm.Body.GetType();
var typeInfo = EventJsonContext.Default.GetTypeInfo(bodyType);
payload = typeInfo != null
? JsonSerializer.SerializeToUtf8Bytes(pm.Body, typeInfo)
: JsonSerializer.SerializeToUtf8Bytes(pm.Body, bodyType);
}
else
{
payload = [];
}
// Deliver via the system account's SubList matching
var result = SystemAccount.SubList.Match(pm.Subject);
foreach (var sub in result.PlainSubs)
{
sub.Client?.SendMessage(pm.Subject, sub.Sid, pm.Reply,
pm.Headers ?? ReadOnlyMemory<byte>.Empty,
payload);
}
foreach (var queueGroup in result.QueueSubs)
{
if (queueGroup.Length == 0) continue;
var sub = queueGroup[0]; // Simple pick for internal
sub.Client?.SendMessage(pm.Subject, sub.Sid, pm.Reply,
pm.Headers ?? ReadOnlyMemory<byte>.Empty,
payload);
}
if (pm.IsLast)
break;
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Error in internal send loop processing message on {Subject}", pm.Subject);
}
}
}
catch (OperationCanceledException)
{
// Normal shutdown
}
}
/// <summary>
/// The receive loop: dispatches callbacks for internally-received messages.
/// Maps to Go's internalReceiveLoop in events.go:476-491.
/// </summary>
private async Task InternalReceiveLoopAsync(Channel<InternalSystemMessage> queue, CancellationToken ct)
{
try
{
await foreach (var msg in queue.Reader.ReadAllAsync(ct))
{
try
{
msg.Callback(msg.Sub, msg.Client, msg.Account, msg.Subject, msg.Reply, msg.Headers, msg.Message);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Error in internal receive loop processing {Subject}", msg.Subject);
}
}
}
catch (OperationCanceledException)
{
// Normal shutdown
}
}
public async ValueTask DisposeAsync()
{
await _cts.CancelAsync();
_sendQueue.Writer.TryComplete();
_receiveQueue.Writer.TryComplete();
_receiveQueuePings.Writer.TryComplete();
if (_sendLoop != null) await _sendLoop.WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
if (_receiveLoop != null) await _receiveLoop.WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
if (_receiveLoopPings != null) await _receiveLoopPings.WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
_cts.Dispose();
}
}

View File

@@ -1,19 +0,0 @@
using NATS.Server.Auth;
using NATS.Server.Protocol;
namespace NATS.Server;
public interface INatsClient
{
ulong Id { get; }
ClientKind Kind { get; }
bool IsInternal => Kind.IsInternal();
Account? Account { get; }
ClientOptions? ClientOpts { get; }
ClientPermissions? Permissions { get; }
void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload);
bool QueueOutbound(ReadOnlyMemory<byte> data);
void RemoveSubscription(string sid);
}

View File

@@ -1,25 +0,0 @@
using NATS.Server.Auth;
namespace NATS.Server.Imports;
public sealed class ExportAuth
{
public bool TokenRequired { get; init; }
public uint AccountPosition { get; init; }
public HashSet<string>? ApprovedAccounts { get; init; }
public Dictionary<string, long>? RevokedAccounts { get; init; }
public bool IsAuthorized(Account account)
{
if (RevokedAccounts != null && RevokedAccounts.ContainsKey(account.Name))
return false;
if (ApprovedAccounts == null && !TokenRequired && AccountPosition == 0)
return true;
if (ApprovedAccounts != null)
return ApprovedAccounts.Contains(account.Name);
return false;
}
}

View File

@@ -1,8 +0,0 @@
namespace NATS.Server.Imports;
public sealed class ExportMap
{
public Dictionary<string, StreamExport> Streams { get; } = new(StringComparer.Ordinal);
public Dictionary<string, ServiceExport> Services { get; } = new(StringComparer.Ordinal);
public Dictionary<string, ServiceImport> Responses { get; } = new(StringComparer.Ordinal);
}

View File

@@ -1,18 +0,0 @@
namespace NATS.Server.Imports;
public sealed class ImportMap
{
public List<StreamImport> Streams { get; } = [];
public Dictionary<string, List<ServiceImport>> Services { get; } = new(StringComparer.Ordinal);
public void AddServiceImport(ServiceImport si)
{
if (!Services.TryGetValue(si.From, out var list))
{
list = [];
Services[si.From] = list;
}
list.Add(si);
}
}

View File

@@ -1,47 +0,0 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Imports;
public sealed class ServiceLatencyMsg
{
[JsonPropertyName("type")]
public string Type { get; set; } = "io.nats.server.metric.v1.service_latency";
[JsonPropertyName("requestor")]
public string Requestor { get; set; } = string.Empty;
[JsonPropertyName("responder")]
public string Responder { get; set; } = string.Empty;
[JsonPropertyName("status")]
public int Status { get; set; } = 200;
[JsonPropertyName("svc_latency")]
public long ServiceLatencyNanos { get; set; }
[JsonPropertyName("total_latency")]
public long TotalLatencyNanos { get; set; }
}
public static class LatencyTracker
{
public static bool ShouldSample(ServiceLatency latency)
{
if (latency.SamplingPercentage <= 0) return false;
if (latency.SamplingPercentage >= 100) return true;
return Random.Shared.Next(100) < latency.SamplingPercentage;
}
public static ServiceLatencyMsg BuildLatencyMsg(
string requestor, string responder,
TimeSpan serviceLatency, TimeSpan totalLatency)
{
return new ServiceLatencyMsg
{
Requestor = requestor,
Responder = responder,
ServiceLatencyNanos = serviceLatency.Ticks * 100,
TotalLatencyNanos = totalLatency.Ticks * 100,
};
}
}

View File

@@ -1,64 +0,0 @@
using System.Security.Cryptography;
using NATS.Server.Auth;
namespace NATS.Server.Imports;
/// <summary>
/// Handles response routing for service imports.
/// Maps to Go's service reply prefix generation and response cleanup.
/// Reference: golang/nats-server/server/accounts.go — addRespServiceImport, removeRespServiceImport
/// </summary>
public static class ResponseRouter
{
private static readonly char[] Base62 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".ToCharArray();
/// <summary>
/// Generates a unique reply prefix for response routing.
/// Format: "_R_.{10 random base62 chars}."
/// </summary>
public static string GenerateReplyPrefix()
{
Span<byte> bytes = stackalloc byte[10];
RandomNumberGenerator.Fill(bytes);
var chars = new char[10];
for (int i = 0; i < 10; i++)
chars[i] = Base62[bytes[i] % 62];
return $"_R_.{new string(chars)}.";
}
/// <summary>
/// Creates a response service import that maps the generated reply prefix
/// back to the original reply subject on the requesting account.
/// </summary>
public static ServiceImport CreateResponseImport(
Account exporterAccount,
ServiceImport originalImport,
string originalReply)
{
var replyPrefix = GenerateReplyPrefix();
var responseSi = new ServiceImport
{
DestinationAccount = exporterAccount,
From = replyPrefix + ">",
To = originalReply,
IsResponse = true,
ResponseType = originalImport.ResponseType,
Export = originalImport.Export,
TimestampTicks = DateTime.UtcNow.Ticks,
};
exporterAccount.Exports.Responses[replyPrefix] = responseSi;
return responseSi;
}
/// <summary>
/// Removes a response import from the account's export map.
/// For Singleton responses, this is called after the first reply is delivered.
/// For Streamed/Chunked, it is called when the response stream ends.
/// </summary>
public static void CleanupResponse(Account account, string replyPrefix, ServiceImport responseSi)
{
account.Exports.Responses.Remove(replyPrefix);
}
}

View File

@@ -1,13 +0,0 @@
using NATS.Server.Auth;
namespace NATS.Server.Imports;
public sealed class ServiceExport
{
public ExportAuth Auth { get; init; } = new();
public Account? Account { get; init; }
public ServiceResponseType ResponseType { get; init; } = ServiceResponseType.Singleton;
public TimeSpan ResponseThreshold { get; init; } = TimeSpan.FromMinutes(2);
public ServiceLatency? Latency { get; init; }
public bool AllowTrace { get; init; }
}

View File

@@ -1,21 +0,0 @@
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Imports;
public sealed class ServiceImport
{
public required Account DestinationAccount { get; init; }
public required string From { get; init; }
public required string To { get; init; }
public SubjectTransform? Transform { get; init; }
public ServiceExport? Export { get; init; }
public ServiceResponseType ResponseType { get; init; } = ServiceResponseType.Singleton;
public byte[]? Sid { get; set; }
public bool IsResponse { get; init; }
public bool UsePub { get; init; }
public bool Invalid { get; set; }
public bool Share { get; init; }
public bool Tracking { get; init; }
public long TimestampTicks { get; set; }
}

View File

@@ -1,7 +0,0 @@
namespace NATS.Server.Imports;
public sealed class ServiceLatency
{
public int SamplingPercentage { get; init; } = 100;
public string Subject { get; init; } = string.Empty;
}

View File

@@ -1,8 +0,0 @@
namespace NATS.Server.Imports;
public enum ServiceResponseType
{
Singleton,
Streamed,
Chunked,
}

View File

@@ -1,6 +0,0 @@
namespace NATS.Server.Imports;
public sealed class StreamExport
{
public ExportAuth Auth { get; init; } = new();
}

View File

@@ -1,14 +0,0 @@
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Imports;
public sealed class StreamImport
{
public required Account SourceAccount { get; init; }
public required string From { get; init; }
public required string To { get; init; }
public SubjectTransform? Transform { get; init; }
public bool UsePub { get; init; }
public bool Invalid { get; set; }
}

View File

@@ -1,59 +0,0 @@
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
namespace NATS.Server;
/// <summary>
/// Lightweight socketless client for internal messaging (SYSTEM, ACCOUNT, JETSTREAM).
/// Maps to Go's internal client created by createInternalClient() in server.go:1910-1936.
/// No network I/O — messages are delivered via callback.
/// </summary>
public sealed class InternalClient : INatsClient
{
public ulong Id { get; }
public ClientKind Kind { get; }
public bool IsInternal => Kind.IsInternal();
public Account? Account { get; }
public ClientOptions? ClientOpts => null;
public ClientPermissions? Permissions => null;
/// <summary>
/// Callback invoked when a message is delivered to this internal client.
/// Set by the event system or account import infrastructure.
/// </summary>
public Action<string, string, string?, ReadOnlyMemory<byte>, ReadOnlyMemory<byte>>? MessageCallback { get; set; }
private readonly Dictionary<string, Subscription> _subs = new(StringComparer.Ordinal);
public InternalClient(ulong id, ClientKind kind, Account account)
{
if (!kind.IsInternal())
throw new ArgumentException($"InternalClient requires an internal ClientKind, got {kind}", nameof(kind));
Id = id;
Kind = kind;
Account = account;
}
public void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{
MessageCallback?.Invoke(subject, sid, replyTo, headers, payload);
}
public bool QueueOutbound(ReadOnlyMemory<byte> data) => true; // no-op for internal clients
public void RemoveSubscription(string sid)
{
if (_subs.Remove(sid))
Account?.DecrementSubscriptions();
}
public void AddSubscription(Subscription sub)
{
_subs[sub.Sid] = sub;
}
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
}

View File

@@ -14,16 +14,12 @@ public sealed class SubszHandler(NatsServer server)
var opts = ParseQueryParams(ctx);
var now = DateTime.UtcNow;
// Collect subscriptions from all accounts (or filtered).
// Exclude the $SYS system account unless explicitly requested — its internal
// subscriptions are infrastructure and not user-facing.
// Collect subscriptions from all accounts (or filtered)
var allSubs = new List<Subscription>();
foreach (var account in server.GetAccounts())
{
if (!string.IsNullOrEmpty(opts.Account) && account.Name != opts.Account)
continue;
if (string.IsNullOrEmpty(opts.Account) && account.Name == "$SYS")
continue;
allSubs.AddRange(account.SubList.GetAllSubscriptions());
}
@@ -35,10 +31,10 @@ public sealed class SubszHandler(NatsServer server)
var total = allSubs.Count;
var numSubs = server.GetAccounts()
.Where(a => (string.IsNullOrEmpty(opts.Account) && a.Name != "$SYS") || a.Name == opts.Account)
.Where(a => string.IsNullOrEmpty(opts.Account) || a.Name == opts.Account)
.Aggregate(0u, (sum, a) => sum + a.SubList.Count);
var numCache = server.GetAccounts()
.Where(a => (string.IsNullOrEmpty(opts.Account) && a.Name != "$SYS") || a.Name == opts.Account)
.Where(a => string.IsNullOrEmpty(opts.Account) || a.Name == opts.Account)
.Sum(a => a.SubList.CacheCount);
SubDetail[] details = [];

View File

@@ -1,7 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<ItemGroup>
<InternalsVisibleTo Include="NATS.Server.Tests" />
</ItemGroup>
<ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="NATS.NKeys" />

View File

@@ -11,6 +11,7 @@ using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
using NATS.Server.WebSocket;
namespace NATS.Server;
@@ -19,8 +20,6 @@ public interface IMessageRouter
void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory<byte> headers,
ReadOnlyMemory<byte> payload, NatsClient sender);
void RemoveClient(NatsClient client);
void PublishConnectEvent(NatsClient client);
void PublishDisconnectEvent(NatsClient client);
}
public interface ISubListAccess
@@ -28,7 +27,7 @@ public interface ISubListAccess
SubList SubList { get; }
}
public sealed class NatsClient : INatsClient, IDisposable
public sealed class NatsClient : IDisposable
{
private readonly Socket _socket;
private readonly Stream _stream;
@@ -47,7 +46,6 @@ public sealed class NatsClient : INatsClient, IDisposable
private readonly ServerStats _serverStats;
public ulong Id { get; }
public ClientKind Kind => ClientKind.Client;
public ClientOptions? ClientOpts { get; private set; }
public IMessageRouter? Router { get; set; }
public Account? Account { get; private set; }
@@ -96,6 +94,9 @@ public sealed class NatsClient : INatsClient, IDisposable
private long _rtt;
public TimeSpan Rtt => new(Interlocked.Read(ref _rtt));
public bool IsWebSocket { get; set; }
public WsUpgradeResult? WsInfo { get; set; }
public TlsConnectionState? TlsState { get; set; }
public bool InfoAlreadySent { get; set; }
@@ -447,9 +448,6 @@ public sealed class NatsClient : INatsClient, IDisposable
_flags.SetFlag(ClientFlags.ConnectProcessFinished);
_logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);
// Publish connect advisory to the system event bus
Router?.PublishConnectEvent(this);
// Start auth expiry timer if needed
if (_authService.IsAuthRequired && authResult?.Expiry is { } expiry)
{

View File

@@ -116,4 +116,32 @@ public sealed class NatsOptions
public Dictionary<string, string>? SubjectMappings { get; set; }
public bool HasTls => TlsCert != null && TlsKey != null;
// WebSocket
public WebSocketOptions WebSocket { get; set; } = new();
}
public sealed class WebSocketOptions
{
public string Host { get; set; } = "0.0.0.0";
public int Port { get; set; } = -1;
public string? Advertise { get; set; }
public string? NoAuthUser { get; set; }
public string? JwtCookie { get; set; }
public string? UsernameCookie { get; set; }
public string? PasswordCookie { get; set; }
public string? TokenCookie { get; set; }
public string? Username { get; set; }
public string? Password { get; set; }
public string? Token { get; set; }
public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
public bool NoTls { get; set; }
public string? TlsCert { get; set; }
public string? TlsKey { get; set; }
public bool SameOrigin { get; set; }
public List<string>? AllowedOrigins { get; set; }
public bool Compression { get; set; }
public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2);
public TimeSpan? PingInterval { get; set; }
public Dictionary<string, string>? Headers { get; set; }
}

View File

@@ -9,12 +9,11 @@ using Microsoft.Extensions.Logging;
using NATS.NKeys;
using NATS.Server.Auth;
using NATS.Server.Configuration;
using NATS.Server.Events;
using NATS.Server.Imports;
using NATS.Server.Monitoring;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
using NATS.Server.WebSocket;
namespace NATS.Server;
@@ -37,11 +36,12 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private string? _configDigest;
private readonly Account _globalAccount;
private readonly Account _systemAccount;
private InternalEventSystem? _eventSystem;
private readonly SslServerAuthenticationOptions? _sslOptions;
private readonly TlsRateLimiter? _tlsRateLimiter;
private readonly SubjectTransform[] _subjectTransforms;
private Socket? _listener;
private Socket? _wsListener;
private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
private MonitorServer? _monitorServer;
private ulong _nextClientId;
private long _startTimeTicks;
@@ -73,7 +73,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
public int Port => _options.Port;
public Account SystemAccount => _systemAccount;
public string ServerNKey { get; }
public InternalEventSystem? EventSystem => _eventSystem;
public bool IsShuttingDown => Volatile.Read(ref _shutdown) != 0;
public bool IsLameDuckMode => Volatile.Read(ref _lameDuck) != 0;
public Action? ReOpenLogFile { get; set; }
@@ -94,29 +93,16 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_logger.LogInformation("Initiating Shutdown...");
// Publish shutdown advisory before tearing down the event system
if (_eventSystem != null)
{
var shutdownSubject = string.Format(EventSubjects.ServerShutdown, _serverInfo.ServerId);
_eventSystem.Enqueue(new PublishMessage
{
Subject = shutdownSubject,
Body = new ShutdownEventMsg { Server = BuildEventServerInfo(), Reason = "Server Shutdown" },
IsLast = true,
});
// Give the send loop time to process the shutdown event
await Task.Delay(100);
await _eventSystem.DisposeAsync();
}
// Signal all internal loops to stop
await _quitCts.CancelAsync();
// Close listener to stop accept loop
// Close listeners to stop accept loops
_listener?.Close();
_wsListener?.Close();
// Wait for accept loop to exit
// Wait for accept loops to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
// Close all client connections — flush first, then mark closed
var flushTasks = new List<Task>();
@@ -157,11 +143,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_logger.LogInformation("Entering lame duck mode, stop accepting new clients");
// Close listener to stop accepting new connections
// Close listeners to stop accepting new connections
_listener?.Close();
_wsListener?.Close();
// Wait for accept loop to exit
// Wait for accept loops to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
var gracePeriod = _options.LameDuckGracePeriod;
if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod;
@@ -284,14 +272,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_systemAccount = new Account("$SYS");
_accounts["$SYS"] = _systemAccount;
// Create system internal client and event system
var sysClientId = Interlocked.Increment(ref _nextClientId);
var sysClient = new InternalClient(sysClientId, ClientKind.System, _systemAccount);
_eventSystem = new InternalEventSystem(
_systemAccount, sysClient,
options.ServerName ?? $"nats-dotnet-{Environment.MachineName}",
_loggerFactory.CreateLogger<InternalEventSystem>());
// Generate Ed25519 server NKey identity
using var serverKeyPair = KeyPair.CreatePair(PrefixByte.Server);
ServerNKey = serverKeyPair.GetPublicKey();
@@ -396,11 +376,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
BuildCachedInfo();
}
_listeningStarted.TrySetResult();
_eventSystem?.Start(this);
_eventSystem?.InitEventTracking(this);
_logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port);
// Warn about stub features
@@ -416,6 +391,31 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
WritePidFile();
WritePortsFile();
if (_options.WebSocket.Port >= 0)
{
_wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
_wsListener.Bind(new IPEndPoint(
_options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host),
_options.WebSocket.Port));
_wsListener.Listen(128);
if (_options.WebSocket.Port == 0)
{
_options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port;
}
_logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}",
_options.WebSocket.Host, _options.WebSocket.Port);
if (_options.WebSocket.NoTls)
_logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!");
_ = RunWebSocketAcceptLoopAsync(linked.Token);
}
_listeningStarted.TrySetResult();
var tmpDelay = AcceptMinSleep;
try
@@ -561,6 +561,102 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct)
{
var tmpDelay = AcceptMinSleep;
try
{
while (!ct.IsCancellationRequested)
{
Socket socket;
try
{
socket = await _wsListener!.AcceptAsync(ct);
tmpDelay = AcceptMinSleep;
}
catch (OperationCanceledException) { break; }
catch (ObjectDisposedException) { break; }
catch (SocketException ex)
{
if (IsShuttingDown || IsLameDuckMode) break;
_logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; }
tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
continue;
}
if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
{
socket.Dispose();
continue;
}
var clientId = Interlocked.Increment(ref _nextClientId);
Interlocked.Increment(ref _stats.TotalConnections);
Interlocked.Increment(ref _activeClientCount);
_ = AcceptWebSocketClientAsync(socket, clientId, ct);
}
}
finally
{
_wsAcceptLoopExited.TrySetResult();
}
}
private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct)
{
try
{
var networkStream = new NetworkStream(socket, ownsSocket: false);
Stream stream = networkStream;
// TLS negotiation if configured
if (_sslOptions != null && !_options.WebSocket.NoTls)
{
var (tlsStream, _) = await TlsConnectionWrapper.NegotiateAsync(
socket, networkStream, _options, _sslOptions, _serverInfo,
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
stream = tlsStream;
}
// HTTP upgrade handshake
var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct);
if (!upgradeResult.Success)
{
_logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
socket.Dispose();
Interlocked.Decrement(ref _activeClientCount);
return;
}
// Create WsConnection wrapper
var wsConn = new WsConnection(stream,
compress: upgradeResult.Compress,
maskRead: upgradeResult.MaskRead,
maskWrite: upgradeResult.MaskWrite,
browser: upgradeResult.Browser,
noCompFrag: upgradeResult.NoCompFrag);
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo,
_authService, null, clientLogger, _stats);
client.Router = this;
client.IsWebSocket = true;
client.WsInfo = upgradeResult;
_clients[clientId] = client;
await RunClientAsync(client, ct);
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId);
try { socket.Shutdown(SocketShutdown.Both); } catch { }
socket.Dispose();
Interlocked.Decrement(ref _activeClientCount);
}
}
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
{
try
@@ -632,27 +728,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
// Check for service imports that match this subject.
// When a client in the importer account publishes to a subject
// that matches a service import "From" pattern, we forward the
// message to the destination (exporter) account's subscribers
// using the mapped "To" subject.
if (sender.Account != null)
{
foreach (var kvp in sender.Account.Imports.Services)
{
foreach (var si in kvp.Value)
{
if (si.Invalid) continue;
if (SubjectMatch.MatchLiteral(subject, si.From))
{
ProcessServiceImport(si, subject, replyTo, headers, payload);
delivered = true;
}
}
}
}
// No-responders: if nobody received the message and the publisher
// opted in, send back a 503 status HMSG on the reply subject.
if (!delivered && replyTo != null && sender.ClientOpts?.NoResponders == true)
@@ -692,153 +767,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
/// <summary>
/// Processes a service import by transforming the subject from the importer's
/// subject space to the exporter's subject space, then delivering to matching
/// subscribers in the destination account.
/// Reference: Go server/accounts.go addServiceImport / processServiceImport.
/// </summary>
public void ProcessServiceImport(ServiceImport si, string subject, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{
if (si.Invalid) return;
// Transform subject: map from importer subject space to exporter subject space
string targetSubject;
if (si.Transform != null)
{
var transformed = si.Transform.Apply(subject);
targetSubject = transformed ?? si.To;
}
else if (si.UsePub)
{
targetSubject = subject;
}
else
{
// Default: use the "To" subject from the import definition.
// For wildcard imports (e.g. "requests.>" -> "api.>"), we need
// to map the specific subject tokens from the source pattern to
// the destination pattern.
targetSubject = MapImportSubject(subject, si.From, si.To);
}
// Match against destination account's SubList
var destSubList = si.DestinationAccount.SubList;
var result = destSubList.Match(targetSubject);
// Deliver to plain subscribers in the destination account
foreach (var sub in result.PlainSubs)
{
if (sub.Client == null) continue;
DeliverMessage(sub, targetSubject, replyTo, headers, payload);
}
// Deliver to one member of each queue group
foreach (var queueGroup in result.QueueSubs)
{
if (queueGroup.Length == 0) continue;
var sub = queueGroup[0]; // Simple selection: first available
if (sub.Client != null)
DeliverMessage(sub, targetSubject, replyTo, headers, payload);
}
}
/// <summary>
/// Maps a published subject from the import "From" pattern to the "To" pattern.
/// For example, if From="requests.>" and To="api.>" and subject="requests.test",
/// this returns "api.test".
/// </summary>
private static string MapImportSubject(string subject, string fromPattern, string toPattern)
{
// If "To" doesn't contain wildcards, use it directly
if (SubjectMatch.IsLiteral(toPattern))
return toPattern;
// For wildcard patterns, replace matching wildcard segments.
// Split into tokens and map from source to destination.
var subTokens = subject.Split('.');
var fromTokens = fromPattern.Split('.');
var toTokens = toPattern.Split('.');
var result = new string[toTokens.Length];
int subIdx = 0;
// Build a mapping: for each wildcard position in "from",
// capture the corresponding subject token(s)
var wildcardValues = new List<string>();
string? fwcValue = null;
for (int i = 0; i < fromTokens.Length && subIdx < subTokens.Length; i++)
{
if (fromTokens[i] == "*")
{
wildcardValues.Add(subTokens[subIdx]);
subIdx++;
}
else if (fromTokens[i] == ">")
{
// Capture all remaining tokens
fwcValue = string.Join(".", subTokens[subIdx..]);
subIdx = subTokens.Length;
}
else
{
subIdx++; // Skip literal match
}
}
// Now build the output using the "to" pattern
int wcIdx = 0;
var sb = new StringBuilder();
for (int i = 0; i < toTokens.Length; i++)
{
if (i > 0) sb.Append('.');
if (toTokens[i] == "*")
{
sb.Append(wcIdx < wildcardValues.Count ? wildcardValues[wcIdx] : "*");
wcIdx++;
}
else if (toTokens[i] == ">")
{
sb.Append(fwcValue ?? ">");
}
else
{
sb.Append(toTokens[i]);
}
}
return sb.ToString();
}
/// <summary>
/// Wires service import subscriptions for an account. Creates marker
/// subscriptions in the account's SubList so that the import paths
/// are tracked. The actual forwarding happens in ProcessMessage when
/// it checks the account's Imports.Services.
/// Reference: Go server/accounts.go addServiceImportSub.
/// </summary>
public void WireServiceImports(Account account)
{
foreach (var kvp in account.Imports.Services)
{
foreach (var si in kvp.Value)
{
if (si.Invalid) continue;
// Create a marker subscription in the importer account.
// This subscription doesn't directly deliver messages;
// the ProcessMessage method checks service imports after
// the regular SubList match.
_logger.LogDebug(
"Wired service import for account {Account}: {From} -> {To} (dest: {DestAccount})",
account.Name, si.From, si.To, si.DestinationAccount.Name);
}
}
}
private static void SendNoResponders(NatsClient sender, string replyTo)
{
// Find the sid for a subscription matching the reply subject
@@ -884,194 +812,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
});
}
public void SendInternalMsg(string subject, string? reply, object? msg)
{
_eventSystem?.Enqueue(new PublishMessage { Subject = subject, Reply = reply, Body = msg });
}
public void SendInternalAccountMsg(Account account, string subject, object? msg)
{
_eventSystem?.Enqueue(new PublishMessage { Subject = subject, Body = msg });
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.VARZ requests.
/// Returns core server information including stats counters.
/// </summary>
public void HandleVarzRequest(string subject, string? reply)
{
if (reply == null) return;
var varz = new
{
server_id = _serverInfo.ServerId,
server_name = _serverInfo.ServerName,
version = NatsProtocol.Version,
host = _options.Host,
port = _options.Port,
max_payload = _options.MaxPayload,
connections = ClientCount,
total_connections = Interlocked.Read(ref _stats.TotalConnections),
in_msgs = Interlocked.Read(ref _stats.InMsgs),
out_msgs = Interlocked.Read(ref _stats.OutMsgs),
in_bytes = Interlocked.Read(ref _stats.InBytes),
out_bytes = Interlocked.Read(ref _stats.OutBytes),
};
SendInternalMsg(reply, null, varz);
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.HEALTHZ requests.
/// Returns a simple health status response.
/// </summary>
public void HandleHealthzRequest(string subject, string? reply)
{
if (reply == null) return;
SendInternalMsg(reply, null, new { status = "ok" });
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.SUBSZ requests.
/// Returns the current subscription count.
/// </summary>
public void HandleSubszRequest(string subject, string? reply)
{
if (reply == null) return;
SendInternalMsg(reply, null, new { num_subscriptions = SubList.Count });
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.STATSZ requests.
/// Publishes current server statistics through the event system.
/// </summary>
public void HandleStatszRequest(string subject, string? reply)
{
if (reply == null) return;
var process = System.Diagnostics.Process.GetCurrentProcess();
var statsMsg = new Events.ServerStatsMsg
{
Server = BuildEventServerInfo(),
Stats = new Events.ServerStatsData
{
Start = StartTime,
Mem = process.WorkingSet64,
Cores = Environment.ProcessorCount,
Connections = ClientCount,
TotalConnections = Interlocked.Read(ref _stats.TotalConnections),
Subscriptions = SubList.Count,
InMsgs = Interlocked.Read(ref _stats.InMsgs),
OutMsgs = Interlocked.Read(ref _stats.OutMsgs),
InBytes = Interlocked.Read(ref _stats.InBytes),
OutBytes = Interlocked.Read(ref _stats.OutBytes),
SlowConsumers = Interlocked.Read(ref _stats.SlowConsumers),
},
};
SendInternalMsg(reply, null, statsMsg);
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.IDZ requests.
/// Returns basic server identity information.
/// </summary>
public void HandleIdzRequest(string subject, string? reply)
{
if (reply == null) return;
var idz = new
{
server_id = _serverInfo.ServerId,
server_name = _serverInfo.ServerName,
version = NatsProtocol.Version,
host = _options.Host,
port = _options.Port,
};
SendInternalMsg(reply, null, idz);
}
/// <summary>
/// Builds an EventServerInfo block for embedding in system event messages.
/// Maps to Go's serverInfo() helper used in events.go advisory publishing.
/// </summary>
public EventServerInfo BuildEventServerInfo()
{
var seq = _eventSystem?.NextSequence() ?? 0;
return new EventServerInfo
{
Name = _serverInfo.ServerName,
Host = _options.Host,
Id = _serverInfo.ServerId,
Version = NatsProtocol.Version,
Seq = seq,
};
}
private static EventClientInfo BuildEventClientInfo(NatsClient client)
{
return new EventClientInfo
{
Id = client.Id,
Host = client.RemoteIp,
Account = client.Account?.Name,
Name = client.ClientOpts?.Name,
Lang = client.ClientOpts?.Lang,
Version = client.ClientOpts?.Version,
Start = client.StartTime,
};
}
/// <summary>
/// Publishes a $SYS.ACCOUNT.{account}.CONNECT advisory when a client
/// completes authentication. Maps to Go's sendConnectEvent in events.go.
/// </summary>
public void PublishConnectEvent(NatsClient client)
{
if (_eventSystem == null) return;
var accountName = client.Account?.Name ?? Account.GlobalAccountName;
var subject = string.Format(EventSubjects.ConnectEvent, accountName);
var evt = new ConnectEventMsg
{
Id = Guid.NewGuid().ToString("N"),
Time = DateTime.UtcNow,
Server = BuildEventServerInfo(),
Client = BuildEventClientInfo(client),
};
SendInternalMsg(subject, null, evt);
}
/// <summary>
/// Publishes a $SYS.ACCOUNT.{account}.DISCONNECT advisory when a client
/// disconnects. Maps to Go's sendDisconnectEvent in events.go.
/// </summary>
public void PublishDisconnectEvent(NatsClient client)
{
if (_eventSystem == null) return;
var accountName = client.Account?.Name ?? Account.GlobalAccountName;
var subject = string.Format(EventSubjects.DisconnectEvent, accountName);
var evt = new DisconnectEventMsg
{
Id = Guid.NewGuid().ToString("N"),
Time = DateTime.UtcNow,
Server = BuildEventServerInfo(),
Client = BuildEventClientInfo(client),
Sent = new DataStats
{
Msgs = Interlocked.Read(ref client.OutMsgs),
Bytes = Interlocked.Read(ref client.OutBytes),
},
Received = new DataStats
{
Msgs = Interlocked.Read(ref client.InMsgs),
Bytes = Interlocked.Read(ref client.InBytes),
},
Reason = client.CloseReason.ToReasonString(),
};
SendInternalMsg(subject, null, evt);
}
public void RemoveClient(NatsClient client)
{
// Publish disconnect advisory before removing client state
if (client.ConnectReceived)
PublishDisconnectEvent(client);
_clients.TryRemove(client.Id, out _);
_logger.LogDebug("Removed client {ClientId}", client.Id);
@@ -1326,6 +1068,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_quitCts.Dispose();
_tlsRateLimiter?.Dispose();
_listener?.Dispose();
_wsListener?.Dispose();
foreach (var client in _clients.Values)
client.Dispose();
foreach (var account in _accounts.Values)

View File

@@ -1,5 +1,4 @@
using NATS.Server;
using NATS.Server.Imports;
namespace NATS.Server.Subscriptions;
@@ -10,7 +9,5 @@ public sealed class Subscription
public required string Sid { get; init; }
public long MessageCount; // Interlocked
public long MaxMessages; // 0 = unlimited
public INatsClient? Client { get; set; }
public ServiceImport? ServiceImport { get; set; }
public StreamImport? StreamImport { get; set; }
public NatsClient? Client { get; set; }
}

View File

@@ -0,0 +1,94 @@
using System.IO.Compression;
namespace NATS.Server.WebSocket;
/// <summary>
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
/// </summary>
public static class WsCompression
{
/// <summary>
/// Compresses data using deflate. Removes trailing 4 bytes (sync marker)
/// per RFC 7692 Section 7.2.1.
/// </summary>
/// <remarks>
/// We call Flush() but intentionally do not Dispose() the DeflateStream before
/// reading output, because Dispose writes a final deflate block (0x03 0x00) that
/// would be corrupted by the 4-byte tail strip. Flush() alone writes a sync flush
/// ending with 0x00 0x00 0xff 0xff, matching Go's flate.Writer.Flush() behavior.
/// </remarks>
public static byte[] Compress(ReadOnlySpan<byte> data)
{
var output = new MemoryStream();
var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true);
try
{
deflate.Write(data);
deflate.Flush();
var compressed = output.ToArray();
// Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692
if (compressed.Length >= 4)
return compressed[..^4];
return compressed;
}
finally
{
deflate.Dispose();
output.Dispose();
}
}
/// <summary>
/// Decompresses collected compressed buffers.
/// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2.
/// Ported from golang/nats-server/server/websocket.go lines 403-440.
/// The Go code appends compressLastBlock (9 bytes) which includes the sync
/// marker plus a final empty stored block to signal end-of-stream to the
/// flate reader.
/// </summary>
public static byte[] Decompress(List<byte[]> compressedBuffers, int maxPayload)
{
if (maxPayload <= 0)
maxPayload = 1024 * 1024; // Default 1MB
// Concatenate all compressed buffers + trailer.
// Per RFC 7692 Section 7.2.2, append the sync flush marker (0x00 0x00 0xff 0xff)
// that was stripped during compression. The Go reference appends compressLastBlock
// (9 bytes) for Go's flate reader; .NET's DeflateStream only needs the 4-byte trailer.
int totalLen = 0;
foreach (var buf in compressedBuffers)
totalLen += buf.Length;
totalLen += WsConstants.DecompressTrailer.Length;
var combined = new byte[totalLen];
int offset = 0;
foreach (var buf in compressedBuffers)
{
buf.CopyTo(combined, offset);
offset += buf.Length;
}
WsConstants.DecompressTrailer.CopyTo(combined, offset);
using var input = new MemoryStream(combined);
using var deflate = new DeflateStream(input, CompressionMode.Decompress);
using var output = new MemoryStream();
var readBuf = new byte[4096];
int totalRead = 0;
int n;
while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0)
{
totalRead += n;
if (totalRead > maxPayload)
throw new InvalidOperationException("decompressed data exceeds maximum payload size");
output.Write(readBuf, 0, n);
}
return output.ToArray();
}
}

View File

@@ -0,0 +1,202 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged.
/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern.
/// </summary>
public sealed class WsConnection : Stream
{
private readonly Stream _inner;
private readonly bool _compress;
private readonly bool _maskRead;
private readonly bool _maskWrite;
private readonly bool _browser;
private readonly bool _noCompFrag;
private WsReadInfo _readInfo;
// Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed)
private readonly Queue<byte[]> _readQueue = new();
private int _readOffset;
private readonly object _writeLock = new();
private readonly List<ControlFrameAction> _pendingControlWrites = [];
public bool CloseReceived => _readInfo.CloseReceived;
public int CloseStatus => _readInfo.CloseStatus;
public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
{
_inner = inner;
_compress = compress;
_maskRead = maskRead;
_maskWrite = maskWrite;
_browser = browser;
_noCompFrag = noCompFrag;
_readInfo = new WsReadInfo(expectMask: maskRead);
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
{
// Drain any buffered decoded payloads first
if (_readQueue.Count > 0)
return DrainReadQueue(buffer.Span);
while (true)
{
// Read raw bytes from inner stream
var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct);
if (bytesRead == 0) return 0;
// Decode frames
var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
// Collect control frame responses
if (_readInfo.PendingControlFrames.Count > 0)
{
lock (_writeLock)
_pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
_readInfo.PendingControlFrames.Clear();
// Write pending control frames
await FlushControlFramesAsync(ct);
}
if (_readInfo.CloseReceived)
return 0;
foreach (var payload in payloads)
_readQueue.Enqueue(payload);
// If no payloads were decoded (e.g. only frame headers were read),
// continue reading instead of returning 0 which signals end-of-stream
if (_readQueue.Count > 0)
return DrainReadQueue(buffer.Span);
}
}
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
{
var data = buffer.Span;
if (_compress && data.Length > WsConstants.CompressThreshold)
{
var compressed = WsCompression.Compress(data);
await WriteFramedAsync(compressed, compressed: true, ct);
}
else
{
await WriteFramedAsync(data.ToArray(), compressed: false, ct);
}
}
private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
{
if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
{
// Fragment for browsers
int offset = 0;
bool first = true;
while (offset < payload.Length)
{
int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
bool final = offset + chunkLen >= payload.Length;
var fh = new byte[WsConstants.MaxFrameHeaderSize];
var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
first: first, final: final, compressed: first && compressed,
opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
var chunk = payload.AsSpan(offset, chunkLen).ToArray();
if (_maskWrite && key != null)
WsFrameWriter.MaskBuf(key, chunk);
await _inner.WriteAsync(fh.AsMemory(0, n), ct);
await _inner.WriteAsync(chunk.AsMemory(), ct);
offset += chunkLen;
first = false;
}
}
else
{
var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
if (_maskWrite && key != null)
WsFrameWriter.MaskBuf(key, payload);
await _inner.WriteAsync(header.AsMemory(), ct);
await _inner.WriteAsync(payload.AsMemory(), ct);
}
}
private async Task FlushControlFramesAsync(CancellationToken ct)
{
List<ControlFrameAction> toWrite;
lock (_writeLock)
{
if (_pendingControlWrites.Count == 0) return;
toWrite = [.. _pendingControlWrites];
_pendingControlWrites.Clear();
}
foreach (var action in toWrite)
{
var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
await _inner.WriteAsync(frame, ct);
}
await _inner.FlushAsync(ct);
}
/// <summary>
/// Sends a WebSocket close frame.
/// </summary>
public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
{
var status = WsFrameWriter.MapCloseStatus(reason);
var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
await _inner.WriteAsync(frame, ct);
await _inner.FlushAsync(ct);
}
private int DrainReadQueue(Span<byte> buffer)
{
int written = 0;
while (_readQueue.Count > 0 && written < buffer.Length)
{
var current = _readQueue.Peek();
int available = current.Length - _readOffset;
int toCopy = Math.Min(available, buffer.Length - written);
current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
written += toCopy;
_readOffset += toCopy;
if (_readOffset >= current.Length)
{
_readQueue.Dequeue();
_readOffset = 0;
}
}
return written;
}
// Stream abstract members
public override bool CanRead => true;
public override bool CanWrite => true;
public override bool CanSeek => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override void Flush() => _inner.Flush();
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
protected override void Dispose(bool disposing)
{
if (disposing)
_inner.Dispose();
base.Dispose(disposing);
}
public override async ValueTask DisposeAsync()
{
await _inner.DisposeAsync();
GC.SuppressFinalize(this);
}
}

View File

@@ -0,0 +1,72 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket protocol constants (RFC 6455).
/// Ported from golang/nats-server/server/websocket.go lines 41-106.
/// </summary>
public static class WsConstants
{
// Opcodes (RFC 6455 Section 5.2)
public const int TextMessage = 1;
public const int BinaryMessage = 2;
public const int CloseMessage = 8;
public const int PingMessage = 9;
public const int PongMessage = 10;
public const int ContinuationFrame = 0;
// Frame header bits
public const byte FinalBit = 0x80; // 1 << 7
public const byte Rsv1Bit = 0x40; // 1 << 6 (compression, RFC 7692)
public const byte Rsv2Bit = 0x20; // 1 << 5
public const byte Rsv3Bit = 0x10; // 1 << 4
public const byte MaskBit = 0x80; // 1 << 7 (in second byte)
// Frame size limits
public const int MaxFrameHeaderSize = 14;
public const int MaxControlPayloadSize = 125;
public const int FrameSizeForBrowsers = 4096;
public const int CompressThreshold = 64;
public const int CloseStatusSize = 2;
// Close status codes (RFC 6455 Section 11.7)
public const int CloseStatusNormalClosure = 1000;
public const int CloseStatusGoingAway = 1001;
public const int CloseStatusProtocolError = 1002;
public const int CloseStatusUnsupportedData = 1003;
public const int CloseStatusNoStatusReceived = 1005;
public const int CloseStatusInvalidPayloadData = 1007;
public const int CloseStatusPolicyViolation = 1008;
public const int CloseStatusMessageTooBig = 1009;
public const int CloseStatusInternalSrvError = 1011;
public const int CloseStatusTlsHandshake = 1015;
// Compression constants (RFC 7692)
public const string PmcExtension = "permessage-deflate";
public const string PmcSrvNoCtx = "server_no_context_takeover";
public const string PmcCliNoCtx = "client_no_context_takeover";
public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}";
public static readonly string PmcFullResponse = $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n";
// Header names
public const string NoMaskingHeader = "Nats-No-Masking";
public const string NoMaskingValue = "true";
public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n";
public const string XForwardedForHeader = "X-Forwarded-For";
// Path routing
public const string ClientPath = "/";
public const string LeafNodePath = "/leafnode";
public const string MqttPath = "/mqtt";
// Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2)
public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];
public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;
}
public enum WsClientKind
{
Client,
Leaf,
Mqtt,
}

View File

@@ -0,0 +1,171 @@
using System.Buffers.Binary;
using System.Security.Cryptography;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket frame construction, masking, and control message creation.
/// Ported from golang/nats-server/server/websocket.go lines 543-726.
/// </summary>
public static class WsFrameWriter
{
/// <summary>
/// Creates a complete frame header for a single-frame message (first=true, final=true).
/// Returns (header bytes, mask key or null).
/// </summary>
public static (byte[] header, byte[]? key) CreateFrameHeader(
bool useMasking, bool compressed, int opcode, int payloadLength)
{
var fh = new byte[WsConstants.MaxFrameHeaderSize];
var (n, key) = FillFrameHeader(fh, useMasking,
first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength);
return (fh[..n], key);
}
/// <summary>
/// Fills a pre-allocated frame header buffer.
/// Returns (bytes written, mask key or null).
/// </summary>
public static (int written, byte[]? key) FillFrameHeader(
Span<byte> fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength)
{
byte b0 = first ? (byte)opcode : (byte)0;
if (final) b0 |= WsConstants.FinalBit;
if (compressed) b0 |= WsConstants.Rsv1Bit;
byte b1 = 0;
if (useMasking) b1 |= WsConstants.MaskBit;
int n;
switch (payloadLength)
{
case <= 125:
n = 2;
fh[0] = b0;
fh[1] = (byte)(b1 | (byte)payloadLength);
break;
case < 65536:
n = 4;
fh[0] = b0;
fh[1] = (byte)(b1 | 126);
BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength);
break;
default:
n = 10;
fh[0] = b0;
fh[1] = (byte)(b1 | 127);
BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength);
break;
}
byte[]? key = null;
if (useMasking)
{
key = new byte[4];
RandomNumberGenerator.Fill(key);
key.CopyTo(fh[n..]);
n += 4;
}
return (n, key);
}
/// <summary>
/// XOR masks a buffer with a 4-byte key. Applies in-place.
/// </summary>
public static void MaskBuf(ReadOnlySpan<byte> key, Span<byte> buf)
{
for (int i = 0; i < buf.Length; i++)
buf[i] ^= key[i & 3];
}
/// <summary>
/// XOR masks multiple contiguous buffers as if they were one.
/// </summary>
public static void MaskBufs(ReadOnlySpan<byte> key, List<byte[]> bufs)
{
int pos = 0;
foreach (var buf in bufs)
{
for (int j = 0; j < buf.Length; j++)
{
buf[j] ^= key[pos & 3];
pos++;
}
}
}
/// <summary>
/// Creates a close message payload: 2-byte status code + optional UTF-8 body.
/// Body truncated to fit MaxControlPayloadSize with "..." suffix.
/// </summary>
public static byte[] CreateCloseMessage(int status, string body)
{
var bodyBytes = Encoding.UTF8.GetBytes(body);
int maxBody = WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize;
if (bodyBytes.Length > maxBody)
{
var suffix = "..."u8;
int truncLen = maxBody - suffix.Length;
// Find a valid UTF-8 boundary by walking back from truncation point
while (truncLen > 0 && (bodyBytes[truncLen] & 0xC0) == 0x80)
truncLen--;
var buf = new byte[WsConstants.CloseStatusSize + truncLen + suffix.Length];
BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status);
bodyBytes.AsSpan(0, truncLen).CopyTo(buf.AsSpan(WsConstants.CloseStatusSize));
suffix.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize + truncLen));
return buf;
}
var result = new byte[WsConstants.CloseStatusSize + bodyBytes.Length];
BinaryPrimitives.WriteUInt16BigEndian(result, (ushort)status);
bodyBytes.CopyTo(result.AsSpan(WsConstants.CloseStatusSize));
return result;
}
/// <summary>
/// Builds a complete control frame (header + payload, optional masking).
/// </summary>
public static byte[] BuildControlFrame(int opcode, ReadOnlySpan<byte> payload, bool useMasking)
{
int headerSize = 2 + (useMasking ? 4 : 0);
var frame = new byte[headerSize + payload.Length];
var span = frame.AsSpan();
var (n, key) = FillFrameHeader(span, useMasking,
first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length);
if (payload.Length > 0)
{
payload.CopyTo(span[n..]);
if (useMasking && key != null)
MaskBuf(key, span[n..]);
}
return frame;
}
/// <summary>
/// Maps a ClientClosedReason to a WebSocket close status code.
/// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694.
/// </summary>
public static int MapCloseStatus(ClientClosedReason reason) => reason switch
{
ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure,
ClientClosedReason.AuthenticationTimeout or
ClientClosedReason.AuthenticationViolation or
ClientClosedReason.SlowConsumerPendingBytes or
ClientClosedReason.SlowConsumerWriteDeadline or
ClientClosedReason.MaxSubscriptionsExceeded or
ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation,
ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake,
ClientClosedReason.ParseError or
ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError,
ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig,
ClientClosedReason.WriteError or
ClientClosedReason.ReadError or
ClientClosedReason.StaleConnection or
ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway,
_ => WsConstants.CloseStatusInternalSrvError,
};
}

View File

@@ -0,0 +1,81 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// Validates WebSocket Origin headers per RFC 6455 Section 10.2.
/// Ported from golang/nats-server/server/websocket.go lines 933-1000.
/// </summary>
public sealed class WsOriginChecker
{
private readonly bool _sameOrigin;
private readonly Dictionary<string, AllowedOrigin>? _allowedOrigins;
public WsOriginChecker(bool sameOrigin, List<string>? allowedOrigins)
{
_sameOrigin = sameOrigin;
if (allowedOrigins is { Count: > 0 })
{
_allowedOrigins = new Dictionary<string, AllowedOrigin>(StringComparer.OrdinalIgnoreCase);
foreach (var ao in allowedOrigins)
{
if (Uri.TryCreate(ao, UriKind.Absolute, out var uri))
{
var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port);
_allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port);
}
}
}
}
/// <summary>
/// Returns null if origin is allowed, or an error message if rejected.
/// </summary>
public string? CheckOrigin(string? origin, string requestHost, bool isTls)
{
if (!_sameOrigin && _allowedOrigins == null)
return null;
if (string.IsNullOrEmpty(origin))
return null;
if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
return $"invalid origin: {origin}";
var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port);
if (_sameOrigin)
{
var (rh, rp) = ParseHostPort(requestHost, isTls);
if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp)
return "not same origin";
}
if (_allowedOrigins != null)
{
if (!_allowedOrigins.TryGetValue(oh, out var allowed) ||
!string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) ||
op != allowed.Port)
{
return "not in the allowed list";
}
}
return null;
}
private static (string host, int port) GetHostAndPort(bool tls, string host, int port)
{
if (port <= 0)
port = tls ? 443 : 80;
return (host.ToLowerInvariant(), port);
}
private static (string host, int port) ParseHostPort(string hostPort, bool isTls)
{
var colonIdx = hostPort.LastIndexOf(':');
if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port))
return (hostPort[..colonIdx].ToLowerInvariant(), port);
return (hostPort.ToLowerInvariant(), isTls ? 443 : 80);
}
private readonly record struct AllowedOrigin(string Scheme, int Port);
}

View File

@@ -0,0 +1,322 @@
using System.Buffers.Binary;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// Per-connection WebSocket frame reading state machine.
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
/// </summary>
public class WsReadInfo
{
public int Remaining;
public bool FrameStart;
public bool FirstFrame;
public bool FrameCompressed;
public bool ExpectMask;
public byte MaskKeyPos;
public byte[] MaskKey;
public List<byte[]>? CompressedBuffers;
public int CompressedOffset;
// Control frame outputs
public List<ControlFrameAction> PendingControlFrames;
public bool CloseReceived;
public int CloseStatus;
public string? CloseBody;
public WsReadInfo(bool expectMask)
{
Remaining = 0;
FrameStart = true;
FirstFrame = true;
FrameCompressed = false;
ExpectMask = expectMask;
MaskKeyPos = 0;
MaskKey = new byte[4];
CompressedBuffers = null;
CompressedOffset = 0;
PendingControlFrames = [];
CloseReceived = false;
CloseStatus = 0;
CloseBody = null;
}
public void SetMaskKey(ReadOnlySpan<byte> key)
{
key[..4].CopyTo(MaskKey);
MaskKeyPos = 0;
}
/// <summary>
/// Unmask buffer in-place using current mask key and position.
/// Optimized for 8-byte chunks when buffer is large enough.
/// Ported from websocket.go lines 509-536.
/// </summary>
public void Unmask(Span<byte> buf)
{
int p = MaskKeyPos;
if (buf.Length < 16)
{
for (int i = 0; i < buf.Length; i++)
{
buf[i] ^= MaskKey[p & 3];
p++;
}
MaskKeyPos = (byte)(p & 3);
return;
}
// Build 8-byte key for bulk XOR
Span<byte> k = stackalloc byte[8];
for (int i = 0; i < 8; i++)
k[i] = MaskKey[(p + i) & 3];
ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);
int n = (buf.Length / 8) * 8;
for (int i = 0; i < n; i += 8)
{
ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
tmp ^= km;
BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
}
// Handle remaining bytes
p += n;
var tail = buf[n..];
for (int i = 0; i < tail.Length; i++)
{
tail[i] ^= MaskKey[p & 3];
p++;
}
MaskKeyPos = (byte)(p & 3);
}
/// <summary>
/// Read and decode WebSocket frames from a buffer.
/// Returns list of decoded payload byte arrays.
/// Ported from websocket.go lines 208-351.
/// </summary>
public static List<byte[]> ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload)
{
var bufs = new List<byte[]>();
var buf = new byte[available];
int bytesRead = 0;
// Fill the buffer from the stream
while (bytesRead < available)
{
int n = stream.Read(buf, bytesRead, available - bytesRead);
if (n == 0) break;
bytesRead += n;
}
int pos = 0;
int max = bytesRead;
while (pos < max)
{
if (r.FrameStart)
{
if (pos >= max) break;
byte b0 = buf[pos];
int frameType = b0 & 0x0F;
bool final = (b0 & WsConstants.FinalBit) != 0;
bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
pos++;
// Read second byte
var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
pos = newPos;
byte b1 = b1Buf[0];
// Check mask bit
if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
throw new InvalidOperationException("mask bit missing");
r.Remaining = b1 & 0x7F;
// Validate frame types
if (WsConstants.IsControlFrame(frameType))
{
if (r.Remaining > WsConstants.MaxControlPayloadSize)
throw new InvalidOperationException("control frame length too large");
if (!final)
throw new InvalidOperationException("control frame does not have final bit set");
}
else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
{
if (!r.FirstFrame)
throw new InvalidOperationException("new message before previous finished");
r.FirstFrame = final;
r.FrameCompressed = compressed;
}
else if (frameType == WsConstants.ContinuationFrame)
{
if (r.FirstFrame || compressed)
throw new InvalidOperationException("invalid continuation frame");
r.FirstFrame = final;
}
else
{
throw new InvalidOperationException($"unknown opcode {frameType}");
}
// Extended payload length
switch (r.Remaining)
{
case 126:
{
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
pos = p2;
r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
break;
}
case 127:
{
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
pos = p2;
var len64 = BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
if (len64 > (ulong)maxPayload)
throw new InvalidOperationException($"frame payload length {len64} exceeds max payload {maxPayload}");
r.Remaining = (int)len64;
break;
}
}
// Read mask key (mask bit already validated at line 134)
if (r.ExpectMask)
{
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
pos = p2;
keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
r.MaskKeyPos = 0;
}
// Handle control frames
if (WsConstants.IsControlFrame(frameType))
{
pos = HandleControlFrame(r, frameType, stream, buf, pos, max);
continue;
}
r.FrameStart = false;
}
if (pos < max)
{
int n = r.Remaining;
if (pos + n > max) n = max - pos;
var payloadSlice = buf.AsSpan(pos, n).ToArray();
pos += n;
r.Remaining -= n;
if (r.ExpectMask)
r.Unmask(payloadSlice);
bool addToBufs = true;
if (r.FrameCompressed)
{
addToBufs = false;
r.CompressedBuffers ??= [];
r.CompressedBuffers.Add(payloadSlice);
if (r.FirstFrame && r.Remaining == 0)
{
var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
r.CompressedBuffers = null;
r.FrameCompressed = false;
addToBufs = true;
payloadSlice = decompressed;
}
}
if (addToBufs && payloadSlice.Length > 0)
bufs.Add(payloadSlice);
if (r.Remaining == 0)
r.FrameStart = true;
}
}
return bufs;
}
private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
{
byte[]? payload = null;
if (r.Remaining > 0)
{
var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
pos = newPos;
payload = payloadBuf;
if (r.ExpectMask)
r.Unmask(payload);
r.Remaining = 0;
}
switch (frameType)
{
case WsConstants.CloseMessage:
r.CloseReceived = true;
r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
{
r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
if (payload.Length > WsConstants.CloseStatusSize)
r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
}
// Per RFC 6455 Section 5.5.1, always send a close response
if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
{
var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
}
else
{
// Empty close frame — respond with empty close
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, []));
}
break;
case WsConstants.PingMessage:
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
break;
case WsConstants.PongMessage:
// Nothing to do
break;
}
return pos;
}
/// <summary>
/// Gets needed bytes from buffer or reads from stream.
/// Ported from websocket.go lines 178-193.
/// </summary>
private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
{
int avail = max - pos;
if (avail >= needed)
return (buf[pos..(pos + needed)], pos + needed);
var b = new byte[needed];
int start = 0;
if (avail > 0)
{
Buffer.BlockCopy(buf, pos, b, 0, avail);
start = avail;
}
while (start < needed)
{
int n = stream.Read(b, start, needed - start);
if (n == 0) throw new IOException("unexpected end of stream");
start += n;
}
return (b, pos + avail);
}
}
public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);

View File

@@ -0,0 +1,268 @@
using System.Net;
using System.Security.Cryptography;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
public static async Task<WsUpgradeResult> TryUpgradeAsync(
Stream inputStream, Stream outputStream, WebSocketOptions options,
CancellationToken ct = default)
{
try
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(options.HandshakeTimeout);
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
return await FailAsync(outputStream, 405, "request method must be GET");
if (!headers.ContainsKey("Host"))
return await FailAsync(outputStream, 400, "'Host' missing in request");
if (!HeaderContains(headers, "Upgrade", "websocket"))
return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");
if (!HeaderContains(headers, "Connection", "Upgrade"))
return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");
if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
return await FailAsync(outputStream, 400, "key missing");
if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
return await FailAsync(outputStream, 400, "invalid version");
var kind = path switch
{
_ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
_ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
_ => WsClientKind.Client,
};
// Origin checking
if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
{
var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
headers.TryGetValue("Origin", out var origin);
if (string.IsNullOrEmpty(origin))
headers.TryGetValue("Sec-WebSocket-Origin", out origin);
var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
if (originErr != null)
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
}
// Compression negotiation
bool compress = options.Compression;
if (compress)
{
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
}
// No-masking support (leaf nodes only — browser clients must always mask)
bool noMasking = kind == WsClientKind.Leaf &&
headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
// Browser detection
bool browser = false;
bool noCompFrag = false;
if (kind is WsClientKind.Client or WsClientKind.Mqtt &&
headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
{
browser = true;
// Disable fragmentation of compressed frames for Safari browsers.
// Safari has both "Version/" and "Safari/" in the user agent string,
// while Chrome on macOS has "Safari/" but not "Version/".
noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
}
// Cookie extraction
string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
headers.TryGetValue("Cookie", out var cookieHeader))
{
var cookies = ParseCookies(cookieHeader);
if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
}
// X-Forwarded-For client IP extraction
string? clientIp = null;
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
{
var ip = xff.Split(',')[0].Trim();
if (IPAddress.TryParse(ip, out _))
clientIp = ip;
}
// Build the 101 Switching Protocols response
var response = new StringBuilder();
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
response.Append(ComputeAcceptKey(key));
response.Append("\r\n");
if (compress)
response.Append(WsConstants.PmcFullResponse);
if (noMasking)
response.Append(WsConstants.NoMaskingFullResponse);
if (options.Headers != null)
{
foreach (var (k, v) in options.Headers)
{
response.Append(k);
response.Append(": ");
response.Append(v);
response.Append("\r\n");
}
}
response.Append("\r\n");
var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
await outputStream.WriteAsync(responseBytes);
await outputStream.FlushAsync();
return new WsUpgradeResult(
Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
MaskRead: !noMasking, MaskWrite: false,
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
CookiePassword: cookiePassword, CookieToken: cookieToken,
ClientIp: clientIp, Kind: kind);
}
catch (Exception)
{
return WsUpgradeResult.Failed;
}
}
/// <summary>
/// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
/// </summary>
public static string ComputeAcceptKey(string clientKey)
{
var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
var hash = SHA1.HashData(combined);
return Convert.ToBase64String(hash);
}
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
{
var statusText = statusCode switch
{
400 => "Bad Request",
403 => "Forbidden",
405 => "Method Not Allowed",
_ => "Internal Server Error",
};
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
await output.FlushAsync();
return WsUpgradeResult.Failed;
}
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
Stream stream, CancellationToken ct)
{
var headerBytes = new List<byte>(4096);
var buf = new byte[512];
while (true)
{
int n = await stream.ReadAsync(buf, ct);
if (n == 0) throw new IOException("connection closed during handshake");
for (int i = 0; i < n; i++)
{
headerBytes.Add(buf[i]);
if (headerBytes.Count >= 4 &&
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
goto done;
if (headerBytes.Count > 8192)
throw new InvalidOperationException("HTTP header too large");
}
}
done:;
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
var lines = text.Split("\r\n", StringSplitOptions.None);
if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");
var parts = lines[0].Split(' ');
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
var method = parts[0];
var path = parts[1];
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
for (int i = 1; i < lines.Length; i++)
{
var line = lines[i];
if (string.IsNullOrEmpty(line)) break;
var colonIdx = line.IndexOf(':');
if (colonIdx > 0)
{
var name = line[..colonIdx].Trim();
var value = line[(colonIdx + 1)..].Trim();
headers[name] = value;
}
}
return (method, path, headers);
}
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
{
if (!headers.TryGetValue(name, out var headerValue))
return false;
foreach (var token in headerValue.Split(','))
{
if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
return true;
}
return false;
}
private static Dictionary<string, string> ParseCookies(string cookieHeader)
{
var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
foreach (var pair in cookieHeader.Split(';'))
{
var trimmed = pair.Trim();
var eqIdx = trimmed.IndexOf('=');
if (eqIdx > 0)
cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
}
return cookies;
}
}
/// <summary>
/// Result of a WebSocket upgrade handshake attempt.
/// </summary>
public readonly record struct WsUpgradeResult(
bool Success,
bool Compress,
bool Browser,
bool NoCompFrag,
bool MaskRead,
bool MaskWrite,
string? CookieJwt,
string? CookieUsername,
string? CookiePassword,
string? CookieToken,
string? ClientIp,
WsClientKind Kind)
{
public static readonly WsUpgradeResult Failed = new(
Success: false, Compress: false, Browser: false, NoCompFrag: false,
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
}

View File

@@ -1,121 +0,0 @@
using System.Text.Json;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server.Events;
namespace NATS.Server.Tests;
public class EventSystemTests
{
[Fact]
public void ConnectEventMsg_serializes_with_correct_type()
{
var evt = new ConnectEventMsg
{
Type = ConnectEventMsg.EventType,
Id = "test123",
Time = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc),
Server = new EventServerInfo { Name = "test-server", Id = "SRV1" },
Client = new EventClientInfo { Id = 1, Account = "$G" },
};
var json = JsonSerializer.Serialize(evt, EventJsonContext.Default.ConnectEventMsg);
json.ShouldContain("\"type\":\"io.nats.server.advisory.v1.client_connect\"");
json.ShouldContain("\"server\":");
json.ShouldContain("\"client\":");
}
[Fact]
public void DisconnectEventMsg_serializes_with_reason()
{
var evt = new DisconnectEventMsg
{
Type = DisconnectEventMsg.EventType,
Id = "test456",
Time = DateTime.UtcNow,
Server = new EventServerInfo { Name = "test-server", Id = "SRV1" },
Client = new EventClientInfo { Id = 2, Account = "myacc" },
Reason = "Client Closed",
Sent = new DataStats { Msgs = 10, Bytes = 1024 },
Received = new DataStats { Msgs = 5, Bytes = 512 },
};
var json = JsonSerializer.Serialize(evt, EventJsonContext.Default.DisconnectEventMsg);
json.ShouldContain("\"reason\":\"Client Closed\"");
}
[Fact]
public void ServerStatsMsg_serializes()
{
var evt = new ServerStatsMsg
{
Server = new EventServerInfo { Name = "srv1", Id = "ABC" },
Stats = new ServerStatsData
{
Connections = 10,
TotalConnections = 100,
InMsgs = 5000,
OutMsgs = 4500,
InBytes = 1_000_000,
OutBytes = 900_000,
Mem = 50 * 1024 * 1024,
Subscriptions = 42,
},
};
var json = JsonSerializer.Serialize(evt, EventJsonContext.Default.ServerStatsMsg);
json.ShouldContain("\"connections\":10");
json.ShouldContain("\"in_msgs\":5000");
}
[Fact]
public async Task InternalEventSystem_start_and_stop_lifecycle()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var eventSystem = server.EventSystem;
eventSystem.ShouldNotBeNull();
eventSystem.SystemClient.ShouldNotBeNull();
eventSystem.SystemClient.Kind.ShouldBe(ClientKind.System);
await server.ShutdownAsync();
}
[Fact]
public async Task SendInternalMsg_delivers_to_system_subscriber()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("test.subject", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
server.SendInternalMsg("test.subject", null, new { Value = "hello" });
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldBe("test.subject");
await server.ShutdownAsync();
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -1,338 +0,0 @@
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Auth;
using NATS.Server.Imports;
using NATS.Server.Subscriptions;
namespace NATS.Server.Tests;
public class ImportExportTests
{
[Fact]
public void ExportAuth_public_export_authorizes_any_account()
{
var auth = new ExportAuth();
var account = new Account("test");
auth.IsAuthorized(account).ShouldBeTrue();
}
[Fact]
public void ExportAuth_approved_accounts_restricts_access()
{
var auth = new ExportAuth { ApprovedAccounts = ["allowed"] };
var allowed = new Account("allowed");
var denied = new Account("denied");
auth.IsAuthorized(allowed).ShouldBeTrue();
auth.IsAuthorized(denied).ShouldBeFalse();
}
[Fact]
public void ExportAuth_revoked_account_denied()
{
var auth = new ExportAuth
{
ApprovedAccounts = ["test"],
RevokedAccounts = new() { ["test"] = DateTimeOffset.UtcNow.ToUnixTimeSeconds() },
};
var account = new Account("test");
auth.IsAuthorized(account).ShouldBeFalse();
}
[Fact]
public void ServiceResponseType_defaults_to_singleton()
{
var import = new ServiceImport
{
DestinationAccount = new Account("dest"),
From = "requests.>",
To = "api.>",
};
import.ResponseType.ShouldBe(ServiceResponseType.Singleton);
}
[Fact]
public void ExportMap_stores_and_retrieves_exports()
{
var map = new ExportMap();
map.Services["api.>"] = new ServiceExport { Account = new Account("svc") };
map.Streams["events.>"] = new StreamExport();
map.Services.ShouldContainKey("api.>");
map.Streams.ShouldContainKey("events.>");
}
[Fact]
public void ImportMap_stores_service_imports()
{
var map = new ImportMap();
var si = new ServiceImport
{
DestinationAccount = new Account("dest"),
From = "requests.>",
To = "api.>",
};
map.AddServiceImport(si);
map.Services.ShouldContainKey("requests.>");
map.Services["requests.>"].Count.ShouldBe(1);
}
[Fact]
public void Account_add_service_export_and_import()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
exporter.Exports.Services.ShouldContainKey("api.>");
var si = importer.AddServiceImport(exporter, "requests.>", "api.>");
si.ShouldNotBeNull();
si.From.ShouldBe("requests.>");
si.To.ShouldBe("api.>");
si.DestinationAccount.ShouldBe(exporter);
importer.Imports.Services.ShouldContainKey("requests.>");
}
[Fact]
public void Account_add_stream_export_and_import()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddStreamExport("events.>", null);
exporter.Exports.Streams.ShouldContainKey("events.>");
importer.AddStreamImport(exporter, "events.>", "imported.events.>");
importer.Imports.Streams.Count.ShouldBe(1);
importer.Imports.Streams[0].From.ShouldBe("events.>");
importer.Imports.Streams[0].To.ShouldBe("imported.events.>");
}
[Fact]
public void Account_service_import_auth_rejected()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, [new Account("other")]);
Should.Throw<UnauthorizedAccessException>(() =>
importer.AddServiceImport(exporter, "requests.>", "api.>"));
}
[Fact]
public void Account_lazy_creates_internal_client()
{
var account = new Account("test");
var client = account.GetOrCreateInternalClient(99);
client.ShouldNotBeNull();
client.Kind.ShouldBe(ClientKind.Account);
client.Account.ShouldBe(account);
// Second call returns same instance
var client2 = account.GetOrCreateInternalClient(100);
client2.ShouldBeSameAs(client);
}
[Fact]
public async Task Service_import_forwards_message_to_export_account()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Set up exporter and importer accounts
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Wire the import subscriptions into the importer account
server.WireServiceImports(importer);
// Subscribe in exporter account to receive forwarded message
var exportSub = new Subscription { Subject = "api.test", Sid = "export-1", Client = null };
exporter.SubList.Insert(exportSub);
// Verify import infrastructure is wired: the importer should have service import entries
importer.Imports.Services.ShouldContainKey("requests.>");
importer.Imports.Services["requests.>"].Count.ShouldBe(1);
importer.Imports.Services["requests.>"][0].DestinationAccount.ShouldBe(exporter);
await server.ShutdownAsync();
}
[Fact]
public void ProcessServiceImport_delivers_to_destination_account_subscribers()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Add a subscriber in the exporter account's SubList
var received = new List<(string Subject, string Sid)>();
var mockClient = new TestNatsClient(1, exporter);
mockClient.OnMessage = (subject, sid, _, _, _) =>
received.Add((subject, sid));
var exportSub = new Subscription { Subject = "api.test", Sid = "s1", Client = mockClient };
exporter.SubList.Insert(exportSub);
// Process a service import directly
var si = importer.Imports.Services["requests.>"][0];
server.ProcessServiceImport(si, "requests.test", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
received.Count.ShouldBe(1);
received[0].Subject.ShouldBe("api.test");
received[0].Sid.ShouldBe("s1");
}
[Fact]
public void ProcessServiceImport_with_transform_applies_subject_mapping()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
var si = importer.AddServiceImport(exporter, "requests.>", "api.>");
// Create a transform from requests.> to api.>
var transform = SubjectTransform.Create("requests.>", "api.>");
transform.ShouldNotBeNull();
// Create a new import with the transform set
var siWithTransform = new ServiceImport
{
DestinationAccount = exporter,
From = "requests.>",
To = "api.>",
Transform = transform,
};
var received = new List<string>();
var mockClient = new TestNatsClient(1, exporter);
mockClient.OnMessage = (subject, _, _, _, _) =>
received.Add(subject);
var exportSub = new Subscription { Subject = "api.hello", Sid = "s1", Client = mockClient };
exporter.SubList.Insert(exportSub);
server.ProcessServiceImport(siWithTransform, "requests.hello", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
received.Count.ShouldBe(1);
received[0].ShouldBe("api.hello");
}
[Fact]
public void ProcessServiceImport_skips_invalid_imports()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Mark the import as invalid
var si = importer.Imports.Services["requests.>"][0];
si.Invalid = true;
// Add a subscriber in the exporter account
var received = new List<string>();
var mockClient = new TestNatsClient(1, exporter);
mockClient.OnMessage = (subject, _, _, _, _) =>
received.Add(subject);
var exportSub = new Subscription { Subject = "api.test", Sid = "s1", Client = mockClient };
exporter.SubList.Insert(exportSub);
// ProcessServiceImport should be a no-op for invalid imports
server.ProcessServiceImport(si, "requests.test", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
received.Count.ShouldBe(0);
}
[Fact]
public void ProcessServiceImport_delivers_to_queue_groups()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Add queue group subscribers in the exporter account
var received = new List<(string Subject, string Sid)>();
var mockClient1 = new TestNatsClient(1, exporter);
mockClient1.OnMessage = (subject, sid, _, _, _) =>
received.Add((subject, sid));
var mockClient2 = new TestNatsClient(2, exporter);
mockClient2.OnMessage = (subject, sid, _, _, _) =>
received.Add((subject, sid));
var qSub1 = new Subscription { Subject = "api.test", Sid = "q1", Queue = "workers", Client = mockClient1 };
var qSub2 = new Subscription { Subject = "api.test", Sid = "q2", Queue = "workers", Client = mockClient2 };
exporter.SubList.Insert(qSub1);
exporter.SubList.Insert(qSub2);
var si = importer.Imports.Services["requests.>"][0];
server.ProcessServiceImport(si, "requests.test", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
// One member of the queue group should receive the message
received.Count.ShouldBe(1);
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
/// <summary>
/// Minimal test double for INatsClient used in import/export tests.
/// </summary>
private sealed class TestNatsClient(ulong id, Account account) : INatsClient
{
public ulong Id => id;
public ClientKind Kind => ClientKind.Client;
public Account? Account => account;
public Protocol.ClientOptions? ClientOpts => null;
public ClientPermissions? Permissions => null;
public Action<string, string, string?, ReadOnlyMemory<byte>, ReadOnlyMemory<byte>>? OnMessage { get; set; }
public void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{
OnMessage?.Invoke(subject, sid, replyTo, headers, payload);
}
public bool QueueOutbound(ReadOnlyMemory<byte> data) => true;
public void RemoveSubscription(string sid) { }
}
}

View File

@@ -1,85 +0,0 @@
using NATS.Server.Auth;
namespace NATS.Server.Tests;
public class InternalClientTests
{
[Theory]
[InlineData(ClientKind.Client, false)]
[InlineData(ClientKind.Router, false)]
[InlineData(ClientKind.Gateway, false)]
[InlineData(ClientKind.Leaf, false)]
[InlineData(ClientKind.System, true)]
[InlineData(ClientKind.JetStream, true)]
[InlineData(ClientKind.Account, true)]
public void IsInternal_returns_correct_value(ClientKind kind, bool expected)
{
kind.IsInternal().ShouldBe(expected);
}
[Fact]
public void NatsClient_implements_INatsClient()
{
typeof(NatsClient).GetInterfaces().ShouldContain(typeof(INatsClient));
}
[Fact]
public void NatsClient_kind_is_Client()
{
typeof(NatsClient).GetProperty("Kind")!.PropertyType.ShouldBe(typeof(ClientKind));
}
[Fact]
public void InternalClient_system_kind()
{
var account = new Account("$SYS");
var client = new InternalClient(1, ClientKind.System, account);
client.Kind.ShouldBe(ClientKind.System);
client.IsInternal.ShouldBeTrue();
client.Id.ShouldBe(1UL);
client.Account.ShouldBe(account);
}
[Fact]
public void InternalClient_account_kind()
{
var account = new Account("myaccount");
var client = new InternalClient(2, ClientKind.Account, account);
client.Kind.ShouldBe(ClientKind.Account);
client.IsInternal.ShouldBeTrue();
}
[Fact]
public void InternalClient_rejects_non_internal_kind()
{
var account = new Account("test");
Should.Throw<ArgumentException>(() => new InternalClient(1, ClientKind.Client, account));
}
[Fact]
public void InternalClient_SendMessage_invokes_callback()
{
var account = new Account("$SYS");
var client = new InternalClient(1, ClientKind.System, account);
string? capturedSubject = null;
string? capturedSid = null;
client.MessageCallback = (subject, sid, replyTo, headers, payload) =>
{
capturedSubject = subject;
capturedSid = sid;
};
client.SendMessage("test.subject", "1", null, ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
capturedSubject.ShouldBe("test.subject");
capturedSid.ShouldBe("1");
}
[Fact]
public void InternalClient_QueueOutbound_returns_true_noop()
{
var account = new Account("$SYS");
var client = new InternalClient(1, ClientKind.System, account);
client.QueueOutbound(ReadOnlyMemory<byte>.Empty).ShouldBeTrue();
}
}

View File

@@ -1,149 +0,0 @@
using NATS.Server.Auth;
using NATS.Server.Imports;
namespace NATS.Server.Tests;
public class ResponseRoutingTests
{
[Fact]
public void GenerateReplyPrefix_creates_unique_prefix()
{
var prefix1 = ResponseRouter.GenerateReplyPrefix();
var prefix2 = ResponseRouter.GenerateReplyPrefix();
prefix1.ShouldStartWith("_R_.");
prefix2.ShouldStartWith("_R_.");
prefix1.ShouldNotBe(prefix2);
prefix1.Length.ShouldBeGreaterThan(4);
}
[Fact]
public void GenerateReplyPrefix_ends_with_dot()
{
var prefix = ResponseRouter.GenerateReplyPrefix();
prefix.ShouldEndWith(".");
// Format: "_R_." + 10 chars + "." = 15 chars
prefix.Length.ShouldBe(15);
}
[Fact]
public void Singleton_response_import_removed_after_delivery()
{
var exporter = new Account("exporter");
exporter.AddServiceExport("api.test", ServiceResponseType.Singleton, null);
var replyPrefix = ResponseRouter.GenerateReplyPrefix();
var responseSi = new ServiceImport
{
DestinationAccount = exporter,
From = replyPrefix + ">",
To = "_INBOX.original.reply",
IsResponse = true,
ResponseType = ServiceResponseType.Singleton,
};
exporter.Exports.Responses[replyPrefix] = responseSi;
exporter.Exports.Responses.ShouldContainKey(replyPrefix);
// Simulate singleton delivery cleanup
ResponseRouter.CleanupResponse(exporter, replyPrefix, responseSi);
exporter.Exports.Responses.ShouldNotContainKey(replyPrefix);
}
[Fact]
public void CreateResponseImport_registers_in_exporter_responses()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddServiceExport("api.test", ServiceResponseType.Singleton, null);
var originalSi = new ServiceImport
{
DestinationAccount = exporter,
From = "api.test",
To = "api.test",
Export = exporter.Exports.Services["api.test"],
ResponseType = ServiceResponseType.Singleton,
};
var responseSi = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.abc123");
responseSi.IsResponse.ShouldBeTrue();
responseSi.ResponseType.ShouldBe(ServiceResponseType.Singleton);
responseSi.To.ShouldBe("_INBOX.abc123");
responseSi.DestinationAccount.ShouldBe(exporter);
responseSi.From.ShouldEndWith(">");
responseSi.Export.ShouldBe(originalSi.Export);
// Should be registered in the exporter's response map
exporter.Exports.Responses.Count.ShouldBe(1);
}
[Fact]
public void CreateResponseImport_preserves_streamed_response_type()
{
var exporter = new Account("exporter");
exporter.AddServiceExport("api.stream", ServiceResponseType.Streamed, null);
var originalSi = new ServiceImport
{
DestinationAccount = exporter,
From = "api.stream",
To = "api.stream",
Export = exporter.Exports.Services["api.stream"],
ResponseType = ServiceResponseType.Streamed,
};
var responseSi = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.xyz789");
responseSi.ResponseType.ShouldBe(ServiceResponseType.Streamed);
}
[Fact]
public void Multiple_response_imports_each_get_unique_prefix()
{
var exporter = new Account("exporter");
exporter.AddServiceExport("api.test", ServiceResponseType.Singleton, null);
var originalSi = new ServiceImport
{
DestinationAccount = exporter,
From = "api.test",
To = "api.test",
Export = exporter.Exports.Services["api.test"],
ResponseType = ServiceResponseType.Singleton,
};
var resp1 = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.reply1");
var resp2 = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.reply2");
exporter.Exports.Responses.Count.ShouldBe(2);
resp1.To.ShouldBe("_INBOX.reply1");
resp2.To.ShouldBe("_INBOX.reply2");
resp1.From.ShouldNotBe(resp2.From);
}
[Fact]
public void LatencyTracker_should_sample_respects_percentage()
{
var latency = new ServiceLatency { SamplingPercentage = 0, Subject = "latency.test" };
LatencyTracker.ShouldSample(latency).ShouldBeFalse();
var latency100 = new ServiceLatency { SamplingPercentage = 100, Subject = "latency.test" };
LatencyTracker.ShouldSample(latency100).ShouldBeTrue();
}
[Fact]
public void LatencyTracker_builds_latency_message()
{
var msg = LatencyTracker.BuildLatencyMsg("requester", "responder",
TimeSpan.FromMilliseconds(5), TimeSpan.FromMilliseconds(10));
msg.Requestor.ShouldBe("requester");
msg.Responder.ShouldBe("responder");
msg.ServiceLatencyNanos.ShouldBeGreaterThan(0);
msg.TotalLatencyNanos.ShouldBeGreaterThan(0);
}
}

View File

@@ -1,133 +0,0 @@
using System.Text.Json;
using NATS.Server;
using NATS.Server.Events;
using Microsoft.Extensions.Logging.Abstractions;
namespace NATS.Server.Tests;
public class SystemEventsTests
{
[Fact]
public async Task Server_publishes_connect_event_on_client_auth()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.ACCOUNT.*.CONNECT", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
// Connect a real client
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
await sock.ConnectAsync(System.Net.IPAddress.Loopback, server.Port);
// Read INFO
var buf = new byte[4096];
await sock.ReceiveAsync(buf);
// Send CONNECT
var connect = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
await sock.SendAsync(connect);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldStartWith("$SYS.ACCOUNT.");
result.ShouldEndWith(".CONNECT");
await server.ShutdownAsync();
}
[Fact]
public async Task Server_publishes_disconnect_event_on_client_close()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.ACCOUNT.*.DISCONNECT", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
// Connect and then disconnect
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
await sock.ConnectAsync(System.Net.IPAddress.Loopback, server.Port);
var buf = new byte[4096];
await sock.ReceiveAsync(buf);
await sock.SendAsync(System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n"));
await Task.Delay(100);
sock.Shutdown(System.Net.Sockets.SocketShutdown.Both);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldStartWith("$SYS.ACCOUNT.");
result.ShouldEndWith(".DISCONNECT");
await server.ShutdownAsync();
}
[Fact]
public async Task Server_publishes_statsz_periodically()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.SERVER.*.STATSZ", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
// Trigger a manual stats publish (don't wait 10s)
server.EventSystem!.PublishServerStats();
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldContain(".STATSZ");
await server.ShutdownAsync();
}
[Fact]
public async Task Server_publishes_shutdown_event()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.SERVER.*.SHUTDOWN", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
await server.ShutdownAsync();
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldContain(".SHUTDOWN");
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -1,170 +0,0 @@
using System.Text;
using System.Text.Json;
using NATS.Server;
using NATS.Server.Events;
using Microsoft.Extensions.Logging.Abstractions;
namespace NATS.Server.Tests;
public class SystemRequestReplyTests
{
[Fact]
public async Task Varz_request_reply_returns_server_info()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "VARZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"server_id\"");
json.ShouldContain("\"version\"");
json.ShouldContain("\"host\"");
json.ShouldContain("\"port\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Healthz_request_reply_returns_ok()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "HEALTHZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("ok");
await server.ShutdownAsync();
}
[Fact]
public async Task Subsz_request_reply_returns_subscription_count()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "SUBSZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"num_subscriptions\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Idz_request_reply_returns_server_identity()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "IDZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"server_id\"");
json.ShouldContain("\"server_name\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Ping_varz_responds_via_wildcard_subject()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var pingSubject = string.Format(EventSubjects.ServerPing, "VARZ");
server.SendInternalMsg(pingSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"server_id\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Request_without_reply_is_ignored()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Send a request with no reply subject -- should not crash
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "VARZ");
server.SendInternalMsg(reqSubject, null, null);
// Give it a moment to process without error
await Task.Delay(200);
// Server should still be running
server.IsShuttingDown.ShouldBeFalse();
await server.ShutdownAsync();
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,26 @@
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WebSocketOptionsTests
{
[Fact]
public void DefaultOptions_PortIsNegativeOne_Disabled()
{
var opts = new WebSocketOptions();
opts.Port.ShouldBe(-1);
opts.Host.ShouldBe("0.0.0.0");
opts.Compression.ShouldBeFalse();
opts.NoTls.ShouldBeFalse();
opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
}
[Fact]
public void NatsOptions_HasWebSocketProperty()
{
var opts = new NatsOptions();
opts.WebSocket.ShouldNotBeNull();
opts.WebSocket.Port.ShouldBe(-1);
}
}

View File

@@ -0,0 +1,58 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsCompressionTests
{
[Fact]
public void CompressDecompress_RoundTrip()
{
var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
var compressed = WsCompression.Compress(original);
compressed.ShouldNotBeNull();
compressed.Length.ShouldBeGreaterThan(0);
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
decompressed.ShouldBe(original);
}
[Fact]
public void Decompress_ExceedsMaxPayload_Throws()
{
var original = new byte[1000];
Random.Shared.NextBytes(original);
var compressed = WsCompression.Compress(original);
Should.Throw<InvalidOperationException>(() =>
WsCompression.Decompress([compressed], maxPayload: 100));
}
[Fact]
public void Compress_RemovesTrailing4Bytes()
{
var data = new byte[200];
Random.Shared.NextBytes(data);
var compressed = WsCompression.Compress(data);
// The compressed data should be valid for decompression when we add the trailer back
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
decompressed.ShouldBe(data);
}
[Fact]
public void Decompress_MultipleBuffers()
{
var original = new byte[500];
Random.Shared.NextBytes(original);
var compressed = WsCompression.Compress(original);
// Split compressed data into multiple chunks
int mid = compressed.Length / 2;
var chunk1 = compressed[..mid];
var chunk2 = compressed[mid..];
var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
decompressed.ShouldBe(original);
}
}

View File

@@ -0,0 +1,124 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsConnectionTests
{
[Fact]
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
{
var payload = "SUB test 1\r\n"u8.ToArray();
var frame = BuildUnmaskedFrame(payload);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(payload.Length);
buf[..n].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_FramesPayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// First 2 bytes should be WS frame header
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
int len = written[1] & 0x7F;
len.ShouldBe(payload.Length);
written[2..].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_WithCompression_CompressesLargePayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = new byte[200];
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should be set for compressed frame
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
// Compressed size should be less than original
written.Length.ShouldBeLessThan(payload.Length + 10);
}
[Fact]
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should NOT be set for small payloads
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
}
[Fact]
public async Task ReadAsync_DecodesMaskedFrame()
{
var payload = "CONNECT {}\r\n"u8.ToArray();
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
var maskKey = header[^4..];
WsFrameWriter.MaskBuf(maskKey, payload);
var frame = new byte[header.Length + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, header.Length);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe("CONNECT {}\r\n".Length);
System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n");
}
[Fact]
public async Task ReadAsync_ReturnsZero_OnEndOfStream()
{
// Empty stream should return 0 (true end of stream)
var inner = new MemoryStream([]);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(0);
}
private static byte[] BuildUnmaskedFrame(byte[] payload)
{
var header = new byte[2];
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
header[1] = (byte)payload.Length;
var frame = new byte[2 + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, 2);
return frame;
}
}

View File

@@ -0,0 +1,53 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsConstantsTests
{
[Fact]
public void OpCodes_MatchRfc6455()
{
WsConstants.TextMessage.ShouldBe(1);
WsConstants.BinaryMessage.ShouldBe(2);
WsConstants.CloseMessage.ShouldBe(8);
WsConstants.PingMessage.ShouldBe(9);
WsConstants.PongMessage.ShouldBe(10);
}
[Fact]
public void FrameBits_MatchRfc6455()
{
WsConstants.FinalBit.ShouldBe((byte)0x80);
WsConstants.Rsv1Bit.ShouldBe((byte)0x40);
WsConstants.MaskBit.ShouldBe((byte)0x80);
}
[Fact]
public void CloseStatusCodes_MatchRfc6455()
{
WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
WsConstants.CloseStatusGoingAway.ShouldBe(1001);
WsConstants.CloseStatusProtocolError.ShouldBe(1002);
WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
}
[Theory]
[InlineData(WsConstants.CloseMessage)]
[InlineData(WsConstants.PingMessage)]
[InlineData(WsConstants.PongMessage)]
public void IsControlFrame_True(int opcode)
{
WsConstants.IsControlFrame(opcode).ShouldBeTrue();
}
[Theory]
[InlineData(WsConstants.TextMessage)]
[InlineData(WsConstants.BinaryMessage)]
[InlineData(0)]
public void IsControlFrame_False(int opcode)
{
WsConstants.IsControlFrame(opcode).ShouldBeFalse();
}
}

View File

@@ -0,0 +1,163 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsFrameReadTests
{
/// <summary>Helper: build a single unmasked binary frame.</summary>
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
{
int payloadLen = payload.Length;
byte b0 = (byte)opcode;
if (fin) b0 |= WsConstants.FinalBit;
if (compressed) b0 |= WsConstants.Rsv1Bit;
byte b1 = 0;
if (mask) b1 |= WsConstants.MaskBit;
byte[] lenBytes;
if (payloadLen <= 125)
{
lenBytes = [(byte)(b1 | (byte)payloadLen)];
}
else if (payloadLen < 65536)
{
lenBytes = new byte[3];
lenBytes[0] = (byte)(b1 | 126);
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
}
else
{
lenBytes = new byte[9];
lenBytes[0] = (byte)(b1 | 127);
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
}
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
var frame = new byte[totalLen];
frame[0] = b0;
lenBytes.CopyTo(frame.AsSpan(1));
int pos = 1 + lenBytes.Length;
if (mask && maskKey != null)
{
maskKey.CopyTo(frame.AsSpan(pos));
pos += 4;
var maskedPayload = payload.ToArray();
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
maskedPayload.CopyTo(frame.AsSpan(pos));
}
else
{
payload.CopyTo(frame.AsSpan(pos));
}
return frame;
}
[Fact]
public void ReadSingleUnmaskedFrame()
{
var payload = "Hello"u8.ToArray();
var frame = BuildFrame(payload);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void ReadMaskedFrame()
{
var payload = "Hello"u8.ToArray();
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
var frame = BuildFrame(payload, mask: true, maskKey: key);
var readInfo = new WsReadInfo(expectMask: true);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void Read16BitLengthFrame()
{
var payload = new byte[200];
Random.Shared.NextBytes(payload);
var frame = BuildFrame(payload);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void ReadPingFrame_ReturnsPongAction()
{
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0); // control frames don't produce payload
readInfo.PendingControlFrames.Count.ShouldBe(1);
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
}
[Fact]
public void ReadCloseFrame_ReturnsCloseAction()
{
var closePayload = new byte[2];
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0);
readInfo.CloseReceived.ShouldBeTrue();
readInfo.CloseStatus.ShouldBe(1000);
}
[Fact]
public void ReadPongFrame_NoAction()
{
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0);
readInfo.PendingControlFrames.Count.ShouldBe(0);
}
[Fact]
public void Unmask_Optimized_8ByteChunks()
{
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
var original = new byte[32];
Random.Shared.NextBytes(original);
var masked = original.ToArray();
// Mask it
for (int i = 0; i < masked.Length; i++)
masked[i] ^= key[i & 3];
// Unmask using the state machine
var info = new WsReadInfo(expectMask: true);
info.SetMaskKey(key);
info.Unmask(masked);
masked.ShouldBe(original);
}
}

View File

@@ -0,0 +1,152 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsFrameWriterTests
{
[Fact]
public void CreateFrameHeader_SmallPayload_7BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 100);
header.Length.ShouldBe(2);
(header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
(header[1] & 0x7F).ShouldBe(100);
}
[Fact]
public void CreateFrameHeader_MediumPayload_16BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
header.Length.ShouldBe(4);
(header[1] & 0x7F).ShouldBe(126);
BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
}
[Fact]
public void CreateFrameHeader_LargePayload_64BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 70000);
header.Length.ShouldBe(10);
(header[1] & 0x7F).ShouldBe(127);
BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
}
[Fact]
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
{
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 10);
header.Length.ShouldBe(6); // 2 header + 4 mask key
(header[1] & WsConstants.MaskBit).ShouldNotBe(0);
key.ShouldNotBeNull();
key.Length.ShouldBe(4);
}
[Fact]
public void CreateFrameHeader_Compressed_SetsRsv1Bit()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: true,
opcode: WsConstants.BinaryMessage, payloadLength: 10);
(header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
}
[Fact]
public void MaskBuf_XorsCorrectly()
{
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
byte[] expected = new byte[data.Length];
for (int i = 0; i < data.Length; i++)
expected[i] = (byte)(data[i] ^ key[i & 3]);
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(expected);
}
[Fact]
public void MaskBuf_RoundTrip()
{
byte[] key = [0x12, 0x34, 0x56, 0x78];
byte[] original = "Hello, WebSocket!"u8.ToArray();
var data = original.ToArray();
WsFrameWriter.MaskBuf(key, data);
data.ShouldNotBe(original);
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(original);
}
[Fact]
public void CreateCloseMessage_WithStatusAndBody()
{
var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
msg.Length.ShouldBe(2 + "normal closure".Length);
BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
}
[Fact]
public void CreateCloseMessage_LongBody_Truncated()
{
var longBody = new string('x', 200);
var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
}
[Fact]
public void MapCloseStatus_ClientClosed_NormalClosure()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
.ShouldBe(WsConstants.CloseStatusNormalClosure);
}
[Fact]
public void MapCloseStatus_AuthTimeout_PolicyViolation()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
}
[Fact]
public void MapCloseStatus_ParseError_ProtocolError()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
.ShouldBe(WsConstants.CloseStatusProtocolError);
}
[Fact]
public void MapCloseStatus_MaxPayload_MessageTooBig()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
}
[Fact]
public void BuildControlFrame_PingNomask()
{
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
frame.Length.ShouldBe(2);
(frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
(frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
(frame[1] & 0x7F).ShouldBe(0);
}
[Fact]
public void BuildControlFrame_PongWithPayload()
{
byte[] payload = [1, 2, 3, 4];
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
frame.Length.ShouldBe(2 + 4);
frame[2..].ShouldBe(payload);
}
}

View File

@@ -0,0 +1,162 @@
using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsIntegrationTests : IAsyncLifetime
{
private NatsServer _server = null!;
private NatsOptions _options = null!;
public async Task InitializeAsync()
{
_options = new NatsOptions
{
Port = 0,
WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
};
var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
_server = new NatsServer(_options, loggerFactory);
_ = _server.StartAsync(CancellationToken.None);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _server.ShutdownAsync();
_server.Dispose();
}
[Fact]
public async Task WebSocket_ConnectAndReceiveInfo()
{
using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
using var stream = new NetworkStream(socket, ownsSocket: false);
await SendUpgradeRequest(stream);
var response = await ReadHttpResponse(stream);
response.ShouldContain("101");
var wsFrame = await ReadWsFrame(stream);
var info = Encoding.ASCII.GetString(wsFrame);
info.ShouldStartWith("INFO ");
}
[Fact]
public async Task WebSocket_ConnectAndPing()
{
using var client = await ConnectWsClient();
// Send CONNECT and PING together
await SendWsText(client, "CONNECT {}\r\nPING\r\n");
// Read PONG WS frame
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var pong = await ReadWsFrameAsync(client, cts.Token);
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
}
[Fact]
public async Task WebSocket_PubSub()
{
using var sub = await ConnectWsClient();
using var pub = await ConnectWsClient();
await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n");
await Task.Delay(200);
await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var msg = await ReadWsFrameAsync(sub, cts.Token);
Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
}
private async Task<NetworkStream> ConnectWsClient()
{
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
var stream = new NetworkStream(socket, ownsSocket: true);
await SendUpgradeRequest(stream);
var response = await ReadHttpResponse(stream);
response.ShouldContain("101");
await ReadWsFrame(stream); // Read INFO frame
return stream;
}
private static async Task SendUpgradeRequest(NetworkStream stream)
{
var keyBytes = new byte[16];
RandomNumberGenerator.Fill(keyBytes);
var key = Convert.ToBase64String(keyBytes);
var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
await stream.FlushAsync();
}
private static async Task<string> ReadHttpResponse(NetworkStream stream)
{
// Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response
var sb = new StringBuilder();
var buf = new byte[1];
while (true)
{
int n = await stream.ReadAsync(buf);
if (n == 0) break;
sb.Append((char)buf[0]);
if (sb.Length >= 4 &&
sb[^4] == '\r' && sb[^3] == '\n' &&
sb[^2] == '\r' && sb[^1] == '\n')
break;
}
return sb.ToString();
}
private static Task<byte[]> ReadWsFrame(NetworkStream stream)
=> ReadWsFrameAsync(stream, CancellationToken.None);
private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
{
var header = new byte[2];
await stream.ReadExactlyAsync(header, ct);
int len = header[1] & 0x7F;
if (len == 126)
{
var extLen = new byte[2];
await stream.ReadExactlyAsync(extLen, ct);
len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
}
else if (len == 127)
{
var extLen = new byte[8];
await stream.ReadExactlyAsync(extLen, ct);
len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
}
var payload = new byte[len];
if (len > 0) await stream.ReadExactlyAsync(payload, ct);
return payload;
}
private static async Task SendWsText(NetworkStream stream, string text)
{
var payload = Encoding.ASCII.GetBytes(text);
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
var maskKey = header[^4..];
WsFrameWriter.MaskBuf(maskKey, payload);
await stream.WriteAsync(header);
await stream.WriteAsync(payload);
await stream.FlushAsync();
}
}

View File

@@ -0,0 +1,82 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsOriginCheckerTests
{
[Fact]
public void NoOriginHeader_Accepted()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
.ShouldBeNull();
}
[Fact]
public void NeitherSameNorList_AlwaysAccepted()
{
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
checker.CheckOrigin("https://evil.com", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_Match()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_Mismatch()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://other:4222", "localhost:4222", false)
.ShouldNotBeNull();
}
[Fact]
public void SameOrigin_DefaultPort_Http()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost", "localhost:80", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_DefaultPort_Https()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("https://localhost", "localhost:443", true)
.ShouldBeNull();
}
[Fact]
public void AllowedOrigins_Match()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void AllowedOrigins_Mismatch()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
.ShouldNotBeNull();
}
[Fact]
public void AllowedOrigins_SchemeMismatch()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
.ShouldNotBeNull();
}
}

View File

@@ -0,0 +1,226 @@
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsUpgradeTests
{
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
[Fact]
public async Task ValidUpgrade_Returns101()
{
var request = BuildValidRequest();
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Client);
var response = ReadResponse(outputStream);
response.ShouldContain("HTTP/1.1 101");
response.ShouldContain("Upgrade: websocket");
response.ShouldContain("Sec-WebSocket-Accept:");
}
[Fact]
public async Task MissingUpgradeHeader_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("400");
}
[Fact]
public async Task MissingHost_Returns400()
{
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task WrongVersion_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task LeafNodePath_ReturnsLeafKind()
{
var request = BuildValidRequest("/leafnode");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Leaf);
}
[Fact]
public async Task MqttPath_ReturnsMqttKind()
{
var request = BuildValidRequest("/mqtt");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Mqtt);
}
[Fact]
public async Task CompressionNegotiation_WhenEnabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeTrue();
ReadResponse(outputStream).ShouldContain("permessage-deflate");
}
[Fact]
public async Task CompressionNegotiation_WhenDisabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeFalse();
}
[Fact]
public async Task NoMaskingHeader_ForLeaf()
{
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.MaskRead.ShouldBeFalse();
}
[Fact]
public async Task BrowserDetection_Mozilla()
{
var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Browser.ShouldBeTrue();
}
[Fact]
public async Task SafariDetection_NoCompFrag()
{
var request = BuildValidRequest(extraHeaders:
"User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
$"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.NoCompFrag.ShouldBeTrue();
}
[Fact]
public void AcceptKey_MatchesRfc6455Example()
{
// RFC 6455 Section 4.2.2 example
var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
[Fact]
public async Task CookieExtraction()
{
var request = BuildValidRequest(extraHeaders:
"Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions
{
NoTls = true,
JwtCookie = "jwt_token",
UsernameCookie = "nats_user",
PasswordCookie = "nats_pass",
};
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.CookieJwt.ShouldBe("my-jwt");
result.CookieUsername.ShouldBe("admin");
result.CookiePassword.ShouldBe("secret");
}
[Fact]
public async Task XForwardedFor_ExtractsClientIp()
{
var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.ClientIp.ShouldBe("192.168.1.100");
}
[Fact]
public async Task PostMethod_Returns405()
{
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("405");
}
// Helper: create a readable input stream and writable output stream
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
private static string ReadResponse(MemoryStream output)
{
output.Position = 0;
return Encoding.ASCII.GetString(output.ToArray());
}
}