Compare commits

..

16 Commits

Author SHA1 Message Date
Joseph Doherty
9b784024db docs: update differences.md to reflect SYSTEM/ACCOUNT types and imports/exports implemented 2026-02-23 06:04:29 -05:00
Joseph Doherty
86283a7f97 feat: add latency tracking for service import request-reply 2026-02-23 06:03:37 -05:00
Joseph Doherty
4450c27381 feat: add response routing for service import request-reply patterns 2026-02-23 06:01:53 -05:00
Joseph Doherty
c9066e526d feat: wire service import forwarding into message delivery path
Add ProcessServiceImport method to NatsServer that transforms subjects
from importer to exporter namespace and delivers to destination account
subscribers. Wire service import checking into ProcessMessage so that
publishes matching a service import "From" pattern are automatically
forwarded to the destination account. Includes MapImportSubject for
wildcard-aware subject mapping and WireServiceImports for import setup.
2026-02-23 05:59:36 -05:00
Joseph Doherty
4c2b7fa3de feat: add import/export support to Account with ACCOUNT client lazy creation 2026-02-23 05:54:31 -05:00
Joseph Doherty
591833adbb feat: add import/export model types (ServiceImport, StreamImport, exports, auth) 2026-02-23 05:51:30 -05:00
Joseph Doherty
5bae9cc289 feat: add system request-reply monitoring services ($SYS.REQ.SERVER.*)
Register VARZ, HEALTHZ, SUBSZ, STATSZ, and IDZ request-reply handlers
on $SYS.REQ.SERVER.{id}.* subjects and $SYS.REQ.SERVER.PING.* wildcard
subjects via InitEventTracking. Also excludes the $SYS system account
from the /subz monitoring endpoint by default since its subscriptions
are internal infrastructure.
2026-02-23 05:48:32 -05:00
Joseph Doherty
0b34f8cec4 feat: add periodic server stats and account connection heartbeat publishing 2026-02-23 05:44:09 -05:00
Joseph Doherty
125b71b3b0 feat: wire system event publishing for connect, disconnect, and shutdown 2026-02-23 05:41:44 -05:00
Joseph Doherty
89465450a1 fix: use per-SID callback dictionary in SysSubscribe to support multiple subscriptions 2026-02-23 05:38:10 -05:00
Joseph Doherty
8e790445f4 feat: add InternalEventSystem with Channel-based send/receive loops 2026-02-23 05:34:57 -05:00
Joseph Doherty
fc96b6eb43 feat: add system event DTOs and JSON source generator context 2026-02-23 05:29:40 -05:00
Joseph Doherty
b0c5b4acd8 feat: add system event subject constants and SystemMessageHandler delegate 2026-02-23 05:26:25 -05:00
Joseph Doherty
0c4bca9073 feat: add InternalClient class for socketless internal messaging 2026-02-23 05:22:58 -05:00
Joseph Doherty
0e7db5615e feat: add INatsClient interface and implement on NatsClient
Extract INatsClient interface from NatsClient to enable internal clients
(SYSTEM, ACCOUNT) to participate in the subscription system without
requiring a socket connection. Change Subscription.Client from concrete
NatsClient to INatsClient, keeping IMessageRouter and RemoveClient using
the concrete type since only socket clients need those paths.
2026-02-23 05:18:59 -05:00
Joseph Doherty
5e11785bdf feat: add ClientKind enum with IsInternal extension 2026-02-23 05:15:06 -05:00
50 changed files with 2490 additions and 2615 deletions

View File

@@ -11,7 +11,7 @@
| Feature | Go | .NET | Notes |
|---------|:--:|:----:|-------|
| NKey generation (server identity) | Y | Y | Ed25519 key pair via NATS.NKeys at startup |
| System account setup | Y | Y | `$SYS` account created; no event publishing yet (stub) |
| System account setup | Y | Y | `$SYS` account with InternalEventSystem, event publishing, request-reply services |
| Config file validation on startup | Y | Y | Full config parsing with error collection via `ConfigProcessor` |
| PID file writing | Y | Y | Written on startup, deleted on shutdown |
| Profiling HTTP endpoint (`/debug/pprof`) | Y | Stub | `ProfPort` option exists but endpoint not implemented |
@@ -64,10 +64,10 @@
| ROUTER | Y | N | Excluded per scope |
| GATEWAY | Y | N | Excluded per scope |
| LEAF | Y | N | Excluded per scope |
| SYSTEM (internal) | Y | N | |
| SYSTEM (internal) | Y | Y | InternalClient + InternalEventSystem with Channel-based send/receive loops |
| JETSTREAM (internal) | Y | N | |
| ACCOUNT (internal) | Y | N | |
| WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth |
| ACCOUNT (internal) | Y | Y | Lazy per-account InternalClient with import/export subscription support |
| WebSocket clients | Y | N | |
| MQTT clients | Y | N | |
### Client Features
@@ -218,7 +218,7 @@ Go implements a sophisticated slow consumer detection system:
|---------|:--:|:----:|-------|
| Per-account SubList isolation | Y | Y | |
| Multi-account user resolution | Y | Y | `AccountConfig` per account in `NatsOptions.Accounts`; `GetOrCreateAccount` wires limits |
| Account exports/imports | Y | N | |
| Account exports/imports | Y | Y | ServiceImport/StreamImport with ExportAuth, subject transforms, response routing |
| Per-account connection limits | Y | Y | `Account.AddClient()` returns false when `MaxConnections` exceeded |
| Per-account subscription limits | Y | Y | `Account.IncrementSubscriptions()` enforced in `ProcessSub()` |
| Account JetStream limits | Y | N | Excluded per scope |
@@ -267,8 +267,7 @@ Go implements a sophisticated slow consumer detection system:
- ~~Advanced limits (MaxSubs, MaxSubTokens, MaxPending, WriteDeadline)~~ — `MaxSubs`, `MaxSubTokens` implemented; MaxPending/WriteDeadline already existed
- ~~Tags/metadata~~ — `Tags` dictionary implemented in `NatsOptions`
- ~~OCSP configuration~~ — `OcspConfig` with 4 modes (Auto/Always/Must/Never), peer verification, and stapling
- ~~WebSocket options~~ — `WebSocketOptions` with port, compression, origin checking, cookie auth, custom headers
- MQTT options
- WebSocket/MQTT options
- ~~Operator mode / account resolver~~ — `JwtAuthenticator` + `IAccountResolver` + `MemAccountResolver` with trusted keys
---
@@ -407,6 +406,11 @@ The following items from the original gap list have been implemented:
- **User revocation** — per-account tracking with wildcard (`*`) revocation
- **Config file parsing** — custom lexer/parser ported from Go; supports includes, variables, nested blocks, size suffixes
- **Hot reload (SIGHUP)** — re-parses config, diffs changes, validates reloadable set, applies with CLI precedence
- **SYSTEM client type** — InternalClient with InternalEventSystem, Channel-based send/receive loops, event publishing
- **ACCOUNT client type** — lazy per-account InternalClient with import/export subscription support
- **System event publishing** — connect/disconnect advisories, server stats, shutdown/lame-duck events, auth errors
- **System request-reply services** — $SYS.REQ.SERVER.*.VARZ/CONNZ/SUBSZ/HEALTHZ/IDZ/STATSZ with ping wildcards
- **Account exports/imports** — service and stream imports with ExportAuth, subject transforms, response routing, latency tracking
### Remaining Lower Priority
1. **Dynamic buffer sizing** — delegated to Pipe, less optimized for long-lived connections

View File

@@ -1,141 +0,0 @@
# Full JetStream and Cluster Prerequisite Parity Design
**Date:** 2026-02-23
**Status:** Approved
**Scope:** Port JetStream from Go with all prerequisite subsystems required for full Go JetStream test parity, including cluster route/gateway/leaf behaviors and RAFT/meta-cluster semantics.
**Verification Gate:** Go JetStream-focused test suites in `golang/nats-server/server/` plus new/updated .NET tests.
**Cutover Model:** Single end-to-end cutover (no interim acceptance gates).
## 1. Architecture
The implementation uses a full in-process .NET parity architecture that mirrors Go subsystem boundaries while keeping strict internal contracts.
1. Core Server Layer (`NatsServer`/`NatsClient`)
- Extend existing server/client runtime to support full client kinds and inter-server protocol paths.
- Preserve responsibility for socket lifecycle, parser integration, auth entry, and local dispatch.
2. Cluster Fabric Layer
- Add route mesh, gateway links, leafnode links, interest propagation, and remote subscription accounting.
- Provide transport-neutral contracts consumed by JetStream and RAFT replication services.
3. JetStream Control Plane
- Add account-scoped JetStream managers, API subject handlers (`$JS.API.*`), stream/consumer metadata lifecycle, advisories, and limit enforcement.
- Integrate with RAFT/meta services for replicated decisions.
4. JetStream Data Plane
- Add stream ingest path, retention/eviction logic, consumer delivery/ack/redelivery, mirror/source orchestration, and flow-control behavior.
- Use pluggable storage abstractions with parity-focused behavior.
5. RAFT and Replication Layer
- Implement meta-group plus per-asset replication groups, election/term logic, log replication, snapshots, and catchup.
- Expose deterministic commit/applied hooks to JetStream runtime layers.
6. Storage Layer
- Implement memstore and filestore with sequence indexing, subject indexing, compaction/snapshot support, and recovery semantics.
7. Observability Layer
- Upgrade `/jsz` and `/varz` JetStream blocks from placeholders to live runtime reporting with Go-compatible response shape.
## 2. Components and Contracts
### 2.1 New component families
1. Cluster and interserver subsystem
- Add route/gateway/leaf and interserver protocol operations under `src/NATS.Server/`.
- Extend parser/dispatcher with route/leaf/account operations currently excluded.
- Expand client-kind model and command routing constraints.
2. JetStream API and domain model
- Add `src/NATS.Server/JetStream/` subtree for API payload models, stream/consumer models, and error templates/codes.
3. JetStream runtime
- Add stream manager, consumer manager, ack processor, delivery scheduler, mirror/source orchestration, and flow control handlers.
- Integrate publish path with stream capture/store/ack behavior.
4. RAFT subsystem
- Add `src/NATS.Server/Raft/` for replicated logs, elections, snapshots, and membership operations.
5. Storage subsystem
- Add `src/NATS.Server/JetStream/Storage/` for `MemStore` and `FileStore`, sequence/subject indexes, and restart recovery.
### 2.2 Existing components to upgrade
1. `src/NATS.Server/NatsOptions.cs`
- Add full config surface for clustering, JetStream, storage, placement, and parity-required limits.
2. `src/NATS.Server/Configuration/ConfigProcessor.cs`
- Replace silent ignore behavior for cluster/jetstream keys with parsing, mapping, and validation.
3. `src/NATS.Server/Protocol/NatsParser.cs` and `src/NATS.Server/NatsClient.cs`
- Add missing interserver operations and kind-aware dispatch paths needed for clustered JetStream behavior.
4. Monitoring components
- Upgrade `src/NATS.Server/Monitoring/MonitorServer.cs` and `src/NATS.Server/Monitoring/Varz.cs`.
- Add/extend JS monitoring handlers and models for `/jsz` and JetStream runtime fields.
## 3. Data Flow and Behavioral Semantics
1. Inbound publish path
- Parse client publish commands, apply auth/permission checks, route to local subscribers and JetStream candidates.
- For JetStream subjects: apply preconditions, append to store, replicate via RAFT (as required), apply committed state, return Go-compatible pub ack.
2. Consumer delivery path
- Use shared push/pull state model for pending, ack floor, redelivery timers, flow control, and max ack pending.
- Enforce retention policy semantics (limits/interest/workqueue), filter subject behavior, replay policy, and eviction behavior.
3. Replication and control flow
- Meta RAFT governs replicated metadata decisions.
- Per-stream/per-consumer groups replicate state and snapshots.
- Leader changes preserve at-least-once delivery and consumer state invariants.
4. Recovery flow
- Reconstruct stream/consumer/store state on startup.
- In clustered mode, rejoin replication groups and catch up before serving full API/delivery workload.
- Preserve sequence continuity, subject indexes, delete markers, and pending/redelivery state.
5. Monitoring flow
- `/varz` JetStream fields and `/jsz` return live runtime state.
- Advisory and metric surfaces update from control-plane and data-plane events.
## 4. Error Handling and Operational Constraints
1. API error parity
- Match canonical JetStream codes/messages for validation failures, state conflicts, limits, leadership/quorum issues, and storage failures.
2. Protocol behavior
- Preserve normal client compatibility while adding interserver protocol and internal client-kind restrictions.
3. Storage and consistency failures
- Classify corruption/truncation/checksum/snapshot failures as recoverable vs non-recoverable.
- Avoid silent data loss and emit monitoring/advisory signals where parity requires.
4. Cluster and RAFT fault handling
- Explicitly handle no-quorum, stale leader, delayed apply, peer removal, catchup lag, and stepdown transitions.
- Return leadership-aware API errors.
5. Config/reload behavior
- Treat JetStream and cluster config as first-class with strict validation.
- Mirror Go-like reloadable vs restart-required change boundaries.
## 5. Testing and Verification Strategy
1. .NET unit tests
- Add focused tests for JetStream API validation, stream and consumer state, RAFT primitives, mem/file store invariants, and config parsing/validation.
2. .NET integration tests
- Add end-to-end tests for publish/store/consume/ack behavior, retention policies, restart recovery, and clustered prerequisites used by JetStream.
3. Parity harness
- Maintain mapping of Go JetStream test categories to .NET feature areas.
- Execute JetStream-focused Go tests from `golang/nats-server/server/` as acceptance benchmark.
4. `differences.md` policy
- Update only after verification gate passes.
- Remove opening JetStream exclusion scope statement and replace with updated parity scope.
## 6. Scope Decisions Captured
- Include all prerequisite non-JetStream subsystems required to satisfy full Go JetStream tests.
- Verification target is full Go JetStream-focused parity, not a narrowed subset.
- Delivery model is single end-to-end cutover.
- `differences.md` top-level scope statement will be updated to include JetStream and clustering parity coverage once verified.

View File

@@ -1,21 +1,18 @@
# MQTT Connection Type Port Design
## Goal
Port MQTT-related connection type parity from Go into the .NET server for three scoped areas:
Port MQTT-related connection type parity from Go into the .NET server for two scoped areas:
1. JWT `allowed_connection_types` behavior for `MQTT` / `MQTT_WS` (plus existing known types).
2. `/connz` filtering by `mqtt_client`.
3. Full MQTT configuration parsing from `mqtt {}` config blocks (all Go `MQTTOpts` fields).
## Scope
- In scope:
- JWT allowed connection type normalization and enforcement semantics.
- `/connz?mqtt_client=` option parsing and filtering.
- MQTT configuration model and config file parsing (all Go `MQTTOpts` fields).
- Expanded `MqttOptsVarz` monitoring output.
- Unit/integration tests for new and updated behavior.
- `differences.md` updates after implementation is verified.
- Out of scope:
- Full MQTT transport implementation (listener, protocol parser, sessions).
- Full MQTT transport implementation.
- WebSocket transport implementation.
- Leaf/route/gateway transport plumbing.
@@ -30,8 +27,6 @@ Port MQTT-related connection type parity from Go into the .NET server for three
- Extend connz monitoring options to parse `mqtt_client` and apply exact-match filtering before sort/pagination.
## Components
### JWT Connection-Type Enforcement
- `src/NATS.Server/Auth/IAuthenticator.cs`
- Extend `ClientAuthContext` with a connection-type value.
- `src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs` (new)
@@ -43,8 +38,6 @@ Port MQTT-related connection type parity from Go into the .NET server for three
- Enforce against current `ClientAuthContext.ConnectionType`.
- `src/NATS.Server/NatsClient.cs`
- Populate auth context connection type (currently `STANDARD`).
### Connz MQTT Client Filtering
- `src/NATS.Server/Monitoring/Connz.cs`
- Add `MqttClient` to `ConnzOptions` with JSON field `mqtt_client`.
- `src/NATS.Server/Monitoring/ConnzHandler.cs`
@@ -55,30 +48,6 @@ Port MQTT-related connection type parity from Go into the .NET server for three
- `src/NATS.Server/NatsServer.cs`
- Persist `MqttClient` into `ClosedClient` snapshot (empty for now).
### MQTT Configuration Parsing
- `src/NATS.Server/MqttOptions.cs` (new)
- Full model matching Go `MQTTOpts` struct (opts.go:613-707):
- Network: `Host`, `Port`
- Auth override: `NoAuthUser`, `Username`, `Password`, `Token`, `AuthTimeout`
- TLS: `TlsCert`, `TlsKey`, `TlsCaCert`, `TlsVerify`, `TlsTimeout`, `TlsMap`, `TlsPinnedCerts`
- JetStream: `JsDomain`, `StreamReplicas`, `ConsumerReplicas`, `ConsumerMemoryStorage`, `ConsumerInactiveThreshold`
- QoS: `AckWait`, `MaxAckPending`, `JsApiTimeout`
- `src/NATS.Server/NatsOptions.cs`
- Add `Mqtt` property of type `MqttOptions?`.
- `src/NATS.Server/Configuration/ConfigProcessor.cs`
- Add `ParseMqtt()` for `mqtt {}` config block with Go-compatible key aliases:
- `host`/`net` → Host, `listen` → Host+Port
- `ack_wait`/`ackwait` → AckWait
- `max_ack_pending`/`max_pending`/`max_inflight` → MaxAckPending
- `js_domain` → JsDomain
- `js_api_timeout`/`api_timeout` → JsApiTimeout
- `consumer_inactive_threshold`/`consumer_auto_cleanup` → ConsumerInactiveThreshold
- Nested `tls {}` and `authorization {}`/`authentication {}` blocks
- `src/NATS.Server/Monitoring/Varz.cs`
- Expand `MqttOptsVarz` from 3 fields to full monitoring-visible set.
- `src/NATS.Server/Monitoring/VarzHandler.cs`
- Populate expanded `MqttOptsVarz` from `NatsOptions.Mqtt`.
## Data Flow
1. Client sends `CONNECT`.
2. `NatsClient.ProcessConnectAsync` builds `ClientAuthContext` with `ConnectionType=STANDARD`.
@@ -104,7 +73,6 @@ Port MQTT-related connection type parity from Go into the .NET server for three
- MQTT transport is not implemented yet in this repository.
- Runtime connection type currently resolves to `STANDARD` in auth context.
- `mqtt_client` values remain empty until MQTT path populates them.
- MQTT config is parsed and stored but no listener is started.
## Testing Strategy
- `tests/NATS.Server.Tests/JwtAuthenticatorTests.cs`
@@ -117,16 +85,9 @@ Port MQTT-related connection type parity from Go into the .NET server for three
- `/connz?mqtt_client=<id>` returns matching connections only.
- `/connz?state=closed&mqtt_client=<id>` filters closed snapshots.
- non-existing ID yields empty connection set.
- `tests/NATS.Server.Tests/ConfigProcessorTests.cs` (or similar)
- Parse valid `mqtt {}` block with all fields.
- Parse config with aliases (ackwait vs ack_wait, host vs net, etc.).
- Parse nested `tls {}` and `authorization {}` blocks within mqtt.
- Varz MQTT section populated from config.
## Success Criteria
- JWT `allowed_connection_types` behavior matches Go semantics for known/unknown mixing and unknown-only rejection.
- `/connz` supports exact `mqtt_client` filtering for open and closed sets.
- `mqtt {}` config block parses all Go `MQTTOpts` fields with aliases.
- `MqttOptsVarz` includes full monitoring output.
- Added tests pass.
- `differences.md` accurately reflects implemented parity.

View File

@@ -1,4 +1,5 @@
using System.Collections.Concurrent;
using NATS.Server.Imports;
using NATS.Server.Subscriptions;
namespace NATS.Server.Auth;
@@ -12,6 +13,8 @@ public sealed class Account : IDisposable
public Permissions? DefaultPermissions { get; set; }
public int MaxConnections { get; set; } // 0 = unlimited
public int MaxSubscriptions { get; set; } // 0 = unlimited
public ExportMap Exports { get; } = new();
public ImportMap Imports { get; } = new();
// JWT fields
public string? Nkey { get; set; }
@@ -89,5 +92,77 @@ public sealed class Account : IDisposable
Interlocked.Add(ref _outBytes, bytes);
}
// Internal (ACCOUNT) client for import/export message routing
private InternalClient? _internalClient;
public InternalClient GetOrCreateInternalClient(ulong clientId)
{
if (_internalClient != null) return _internalClient;
_internalClient = new InternalClient(clientId, ClientKind.Account, this);
return _internalClient;
}
public void AddServiceExport(string subject, ServiceResponseType responseType, IEnumerable<Account>? approved)
{
var auth = new ExportAuth
{
ApprovedAccounts = approved != null ? new HashSet<string>(approved.Select(a => a.Name)) : null,
};
Exports.Services[subject] = new ServiceExport
{
Auth = auth,
Account = this,
ResponseType = responseType,
};
}
public void AddStreamExport(string subject, IEnumerable<Account>? approved)
{
var auth = new ExportAuth
{
ApprovedAccounts = approved != null ? new HashSet<string>(approved.Select(a => a.Name)) : null,
};
Exports.Streams[subject] = new StreamExport { Auth = auth };
}
public ServiceImport AddServiceImport(Account destination, string from, string to)
{
if (!destination.Exports.Services.TryGetValue(to, out var export))
throw new InvalidOperationException($"No service export found for '{to}' on account '{destination.Name}'");
if (!export.Auth.IsAuthorized(this))
throw new UnauthorizedAccessException($"Account '{Name}' not authorized to import '{to}' from '{destination.Name}'");
var si = new ServiceImport
{
DestinationAccount = destination,
From = from,
To = to,
Export = export,
ResponseType = export.ResponseType,
};
Imports.AddServiceImport(si);
return si;
}
public void AddStreamImport(Account source, string from, string to)
{
if (!source.Exports.Streams.TryGetValue(from, out var export))
throw new InvalidOperationException($"No stream export found for '{from}' on account '{source.Name}'");
if (!export.Auth.IsAuthorized(this))
throw new UnauthorizedAccessException($"Account '{Name}' not authorized to import '{from}' from '{source.Name}'");
var si = new StreamImport
{
SourceAccount = source,
From = from,
To = to,
};
Imports.Streams.Add(si);
}
public void Dispose() => SubList.Dispose();
}

View File

@@ -0,0 +1,22 @@
namespace NATS.Server;
/// <summary>
/// Identifies the type of a client connection.
/// Maps to Go's client kind constants in client.go:45-65.
/// </summary>
public enum ClientKind
{
Client,
Router,
Gateway,
Leaf,
System,
JetStream,
Account,
}
public static class ClientKindExtensions
{
public static bool IsInternal(this ClientKind kind) =>
kind is ClientKind.System or ClientKind.JetStream or ClientKind.Account;
}

View File

@@ -0,0 +1,12 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Events;
[JsonSerializable(typeof(ConnectEventMsg))]
[JsonSerializable(typeof(DisconnectEventMsg))]
[JsonSerializable(typeof(AccountNumConns))]
[JsonSerializable(typeof(ServerStatsMsg))]
[JsonSerializable(typeof(ShutdownEventMsg))]
[JsonSerializable(typeof(LameDuckEventMsg))]
[JsonSerializable(typeof(AuthErrorEventMsg))]
internal partial class EventJsonContext : JsonSerializerContext;

View File

@@ -0,0 +1,49 @@
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Events;
/// <summary>
/// System event subject patterns.
/// Maps to Go events.go:41-97 subject constants.
/// </summary>
public static class EventSubjects
{
// Account-scoped events
public const string ConnectEvent = "$SYS.ACCOUNT.{0}.CONNECT";
public const string DisconnectEvent = "$SYS.ACCOUNT.{0}.DISCONNECT";
public const string AccountConnsNew = "$SYS.ACCOUNT.{0}.SERVER.CONNS";
public const string AccountConnsOld = "$SYS.SERVER.ACCOUNT.{0}.CONNS";
// Server-scoped events
public const string ServerStats = "$SYS.SERVER.{0}.STATSZ";
public const string ServerShutdown = "$SYS.SERVER.{0}.SHUTDOWN";
public const string ServerLameDuck = "$SYS.SERVER.{0}.LAMEDUCK";
public const string AuthError = "$SYS.SERVER.{0}.CLIENT.AUTH.ERR";
public const string AuthErrorAccount = "$SYS.ACCOUNT.CLIENT.AUTH.ERR";
// Request-reply subjects (server-specific)
public const string ServerReq = "$SYS.REQ.SERVER.{0}.{1}";
// Wildcard ping subjects (all servers respond)
public const string ServerPing = "$SYS.REQ.SERVER.PING.{0}";
// Account-scoped request subjects
public const string AccountReq = "$SYS.REQ.ACCOUNT.{0}.{1}";
// Inbox for responses
public const string InboxResponse = "$SYS._INBOX_.{0}";
}
/// <summary>
/// Callback signature for system message handlers.
/// Maps to Go's sysMsgHandler type in events.go:109.
/// </summary>
public delegate void SystemMessageHandler(
Subscription? sub,
INatsClient? client,
Account? account,
string subject,
string? reply,
ReadOnlyMemory<byte> headers,
ReadOnlyMemory<byte> message);

View File

@@ -0,0 +1,270 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Events;
/// <summary>
/// Server identity block embedded in all system events.
/// </summary>
public sealed class EventServerInfo
{
[JsonPropertyName("name")]
public string Name { get; set; } = string.Empty;
[JsonPropertyName("host")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Host { get; set; }
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("cluster")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Cluster { get; set; }
[JsonPropertyName("domain")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Domain { get; set; }
[JsonPropertyName("ver")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Version { get; set; }
[JsonPropertyName("seq")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public ulong Seq { get; set; }
[JsonPropertyName("tags")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public Dictionary<string, string>? Tags { get; set; }
}
/// <summary>
/// Client identity block for connect/disconnect events.
/// </summary>
public sealed class EventClientInfo
{
[JsonPropertyName("start")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public DateTime Start { get; set; }
[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public DateTime Stop { get; set; }
[JsonPropertyName("host")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Host { get; set; }
[JsonPropertyName("id")]
public ulong Id { get; set; }
[JsonPropertyName("acc")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Account { get; set; }
[JsonPropertyName("name")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Name { get; set; }
[JsonPropertyName("lang")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Lang { get; set; }
[JsonPropertyName("ver")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Version { get; set; }
[JsonPropertyName("rtt")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public long RttNanos { get; set; }
}
public sealed class DataStats
{
[JsonPropertyName("msgs")]
public long Msgs { get; set; }
[JsonPropertyName("bytes")]
public long Bytes { get; set; }
}
/// <summary>Client connect advisory. Go events.go:155-160.</summary>
public sealed class ConnectEventMsg
{
public const string EventType = "io.nats.server.advisory.v1.client_connect";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("client")]
public EventClientInfo Client { get; set; } = new();
}
/// <summary>Client disconnect advisory. Go events.go:167-174.</summary>
public sealed class DisconnectEventMsg
{
public const string EventType = "io.nats.server.advisory.v1.client_disconnect";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("client")]
public EventClientInfo Client { get; set; } = new();
[JsonPropertyName("sent")]
public DataStats Sent { get; set; } = new();
[JsonPropertyName("received")]
public DataStats Received { get; set; } = new();
[JsonPropertyName("reason")]
public string Reason { get; set; } = string.Empty;
}
/// <summary>Account connection count heartbeat. Go events.go:210-214.</summary>
public sealed class AccountNumConns
{
public const string EventType = "io.nats.server.advisory.v1.account_connections";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("acc")]
public string AccountName { get; set; } = string.Empty;
[JsonPropertyName("conns")]
public int Connections { get; set; }
[JsonPropertyName("total_conns")]
public long TotalConnections { get; set; }
[JsonPropertyName("subs")]
public int Subscriptions { get; set; }
[JsonPropertyName("sent")]
public DataStats Sent { get; set; } = new();
[JsonPropertyName("received")]
public DataStats Received { get; set; } = new();
}
/// <summary>Server stats broadcast. Go events.go:150-153.</summary>
public sealed class ServerStatsMsg
{
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("statsz")]
public ServerStatsData Stats { get; set; } = new();
}
public sealed class ServerStatsData
{
[JsonPropertyName("start")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public DateTime Start { get; set; }
[JsonPropertyName("mem")]
public long Mem { get; set; }
[JsonPropertyName("cores")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int Cores { get; set; }
[JsonPropertyName("connections")]
public int Connections { get; set; }
[JsonPropertyName("total_connections")]
public long TotalConnections { get; set; }
[JsonPropertyName("active_accounts")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int ActiveAccounts { get; set; }
[JsonPropertyName("subscriptions")]
public long Subscriptions { get; set; }
[JsonPropertyName("in_msgs")]
public long InMsgs { get; set; }
[JsonPropertyName("out_msgs")]
public long OutMsgs { get; set; }
[JsonPropertyName("in_bytes")]
public long InBytes { get; set; }
[JsonPropertyName("out_bytes")]
public long OutBytes { get; set; }
[JsonPropertyName("slow_consumers")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public long SlowConsumers { get; set; }
}
/// <summary>Server shutdown notification.</summary>
public sealed class ShutdownEventMsg
{
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("reason")]
public string Reason { get; set; } = string.Empty;
}
/// <summary>Lame duck mode notification.</summary>
public sealed class LameDuckEventMsg
{
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
}
/// <summary>Auth error advisory.</summary>
public sealed class AuthErrorEventMsg
{
public const string EventType = "io.nats.server.advisory.v1.client_auth";
[JsonPropertyName("type")]
public string Type { get; set; } = EventType;
[JsonPropertyName("id")]
public string Id { get; set; } = string.Empty;
[JsonPropertyName("timestamp")]
public DateTime Time { get; set; }
[JsonPropertyName("server")]
public EventServerInfo Server { get; set; } = new();
[JsonPropertyName("client")]
public EventClientInfo Client { get; set; } = new();
[JsonPropertyName("reason")]
public string Reason { get; set; } = string.Empty;
}

View File

@@ -0,0 +1,333 @@
using System.Collections.Concurrent;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Events;
/// <summary>
/// Internal publish message queued for the send loop.
/// </summary>
public sealed class PublishMessage
{
public InternalClient? Client { get; init; }
public required string Subject { get; init; }
public string? Reply { get; init; }
public byte[]? Headers { get; init; }
public object? Body { get; init; }
public bool Echo { get; init; }
public bool IsLast { get; init; }
}
/// <summary>
/// Internal received message queued for the receive loop.
/// </summary>
public sealed class InternalSystemMessage
{
public required Subscription? Sub { get; init; }
public required INatsClient? Client { get; init; }
public required Account? Account { get; init; }
public required string Subject { get; init; }
public required string? Reply { get; init; }
public required ReadOnlyMemory<byte> Headers { get; init; }
public required ReadOnlyMemory<byte> Message { get; init; }
public required SystemMessageHandler Callback { get; init; }
}
/// <summary>
/// Manages the server's internal event system with Channel-based send/receive loops.
/// Maps to Go's internal struct in events.go:124-147 and the goroutines
/// internalSendLoop (events.go:495) and internalReceiveLoop (events.go:476).
/// </summary>
public sealed class InternalEventSystem : IAsyncDisposable
{
private readonly ILogger _logger;
private readonly Channel<PublishMessage> _sendQueue;
private readonly Channel<InternalSystemMessage> _receiveQueue;
private readonly Channel<InternalSystemMessage> _receiveQueuePings;
private readonly CancellationTokenSource _cts = new();
private Task? _sendLoop;
private Task? _receiveLoop;
private Task? _receiveLoopPings;
private NatsServer? _server;
private ulong _sequence;
private int _subscriptionId;
private readonly ConcurrentDictionary<string, SystemMessageHandler> _callbacks = new();
public Account SystemAccount { get; }
public InternalClient SystemClient { get; }
public string ServerHash { get; }
public InternalEventSystem(Account systemAccount, InternalClient systemClient, string serverName, ILogger logger)
{
_logger = logger;
SystemAccount = systemAccount;
SystemClient = systemClient;
// Hash server name for inbox routing (matches Go's shash)
ServerHash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(serverName)))[..8].ToLowerInvariant();
_sendQueue = Channel.CreateUnbounded<PublishMessage>(new UnboundedChannelOptions { SingleReader = true });
_receiveQueue = Channel.CreateUnbounded<InternalSystemMessage>(new UnboundedChannelOptions { SingleReader = true });
_receiveQueuePings = Channel.CreateUnbounded<InternalSystemMessage>(new UnboundedChannelOptions { SingleReader = true });
}
public void Start(NatsServer server)
{
_server = server;
var ct = _cts.Token;
_sendLoop = Task.Run(() => InternalSendLoopAsync(ct), ct);
_receiveLoop = Task.Run(() => InternalReceiveLoopAsync(_receiveQueue, ct), ct);
_receiveLoopPings = Task.Run(() => InternalReceiveLoopAsync(_receiveQueuePings, ct), ct);
// Periodic stats publish every 10 seconds
_ = Task.Run(async () =>
{
using var timer = new PeriodicTimer(TimeSpan.FromSeconds(10));
while (await timer.WaitForNextTickAsync(ct))
{
PublishServerStats();
}
}, ct);
}
/// <summary>
/// Registers system request-reply monitoring services for this server.
/// Maps to Go's initEventTracking in events.go.
/// Sets up handlers for $SYS.REQ.SERVER.{id}.VARZ, HEALTHZ, SUBSZ, STATSZ, IDZ
/// and wildcard $SYS.REQ.SERVER.PING.* subjects.
/// </summary>
public void InitEventTracking(NatsServer server)
{
_server = server;
var serverId = server.ServerId;
// Server-specific monitoring services
RegisterService(serverId, "VARZ", server.HandleVarzRequest);
RegisterService(serverId, "HEALTHZ", server.HandleHealthzRequest);
RegisterService(serverId, "SUBSZ", server.HandleSubszRequest);
RegisterService(serverId, "STATSZ", server.HandleStatszRequest);
RegisterService(serverId, "IDZ", server.HandleIdzRequest);
// Wildcard ping services (all servers respond)
SysSubscribe(string.Format(EventSubjects.ServerPing, "VARZ"), WrapRequestHandler(server.HandleVarzRequest));
SysSubscribe(string.Format(EventSubjects.ServerPing, "HEALTHZ"), WrapRequestHandler(server.HandleHealthzRequest));
SysSubscribe(string.Format(EventSubjects.ServerPing, "IDZ"), WrapRequestHandler(server.HandleIdzRequest));
SysSubscribe(string.Format(EventSubjects.ServerPing, "STATSZ"), WrapRequestHandler(server.HandleStatszRequest));
}
private void RegisterService(string serverId, string name, Action<string, string?> handler)
{
var subject = string.Format(EventSubjects.ServerReq, serverId, name);
SysSubscribe(subject, WrapRequestHandler(handler));
}
private SystemMessageHandler WrapRequestHandler(Action<string, string?> handler)
{
return (sub, client, acc, subject, reply, hdr, msg) =>
{
handler(subject, reply);
};
}
/// <summary>
/// Publishes a $SYS.SERVER.{id}.STATSZ message with current server statistics.
/// Maps to Go's sendStatsz in events.go.
/// Can be called manually for testing or is invoked periodically by the stats timer.
/// </summary>
public void PublishServerStats()
{
if (_server == null) return;
var subject = string.Format(EventSubjects.ServerStats, _server.ServerId);
var process = System.Diagnostics.Process.GetCurrentProcess();
var statsMsg = new ServerStatsMsg
{
Server = _server.BuildEventServerInfo(),
Stats = new ServerStatsData
{
Start = _server.StartTime,
Mem = process.WorkingSet64,
Cores = Environment.ProcessorCount,
Connections = _server.ClientCount,
TotalConnections = Interlocked.Read(ref _server.Stats.TotalConnections),
Subscriptions = SystemAccount.SubList.Count,
InMsgs = Interlocked.Read(ref _server.Stats.InMsgs),
OutMsgs = Interlocked.Read(ref _server.Stats.OutMsgs),
InBytes = Interlocked.Read(ref _server.Stats.InBytes),
OutBytes = Interlocked.Read(ref _server.Stats.OutBytes),
SlowConsumers = Interlocked.Read(ref _server.Stats.SlowConsumers),
},
};
Enqueue(new PublishMessage { Subject = subject, Body = statsMsg });
}
/// <summary>
/// Creates a system subscription in the system account's SubList.
/// Maps to Go's sysSubscribe in events.go:2796.
/// </summary>
public Subscription SysSubscribe(string subject, SystemMessageHandler callback)
{
var sid = Interlocked.Increment(ref _subscriptionId).ToString();
var sub = new Subscription
{
Subject = subject,
Sid = sid,
Client = SystemClient,
};
// Store callback keyed by SID so multiple subscriptions work
_callbacks[sid] = callback;
// Set a single routing callback on the system client that dispatches by SID
SystemClient.MessageCallback = (subj, s, reply, hdr, msg) =>
{
if (_callbacks.TryGetValue(s, out var cb))
{
_receiveQueue.Writer.TryWrite(new InternalSystemMessage
{
Sub = sub,
Client = SystemClient,
Account = SystemAccount,
Subject = subj,
Reply = reply,
Headers = hdr,
Message = msg,
Callback = cb,
});
}
};
SystemAccount.SubList.Insert(sub);
return sub;
}
/// <summary>
/// Returns the next monotonically increasing sequence number for event ordering.
/// </summary>
public ulong NextSequence() => Interlocked.Increment(ref _sequence);
/// <summary>
/// Enqueue an internal message for publishing through the send loop.
/// </summary>
public void Enqueue(PublishMessage message)
{
_sendQueue.Writer.TryWrite(message);
}
/// <summary>
/// The send loop: serializes messages and delivers them via the server's routing.
/// Maps to Go's internalSendLoop in events.go:495-668.
/// </summary>
private async Task InternalSendLoopAsync(CancellationToken ct)
{
try
{
await foreach (var pm in _sendQueue.Reader.ReadAllAsync(ct))
{
try
{
var seq = Interlocked.Increment(ref _sequence);
// Serialize body to JSON
byte[] payload;
if (pm.Body is byte[] raw)
{
payload = raw;
}
else if (pm.Body != null)
{
// Try source-generated context first, fall back to reflection-based for unknown types
var bodyType = pm.Body.GetType();
var typeInfo = EventJsonContext.Default.GetTypeInfo(bodyType);
payload = typeInfo != null
? JsonSerializer.SerializeToUtf8Bytes(pm.Body, typeInfo)
: JsonSerializer.SerializeToUtf8Bytes(pm.Body, bodyType);
}
else
{
payload = [];
}
// Deliver via the system account's SubList matching
var result = SystemAccount.SubList.Match(pm.Subject);
foreach (var sub in result.PlainSubs)
{
sub.Client?.SendMessage(pm.Subject, sub.Sid, pm.Reply,
pm.Headers ?? ReadOnlyMemory<byte>.Empty,
payload);
}
foreach (var queueGroup in result.QueueSubs)
{
if (queueGroup.Length == 0) continue;
var sub = queueGroup[0]; // Simple pick for internal
sub.Client?.SendMessage(pm.Subject, sub.Sid, pm.Reply,
pm.Headers ?? ReadOnlyMemory<byte>.Empty,
payload);
}
if (pm.IsLast)
break;
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Error in internal send loop processing message on {Subject}", pm.Subject);
}
}
}
catch (OperationCanceledException)
{
// Normal shutdown
}
}
/// <summary>
/// The receive loop: dispatches callbacks for internally-received messages.
/// Maps to Go's internalReceiveLoop in events.go:476-491.
/// </summary>
private async Task InternalReceiveLoopAsync(Channel<InternalSystemMessage> queue, CancellationToken ct)
{
try
{
await foreach (var msg in queue.Reader.ReadAllAsync(ct))
{
try
{
msg.Callback(msg.Sub, msg.Client, msg.Account, msg.Subject, msg.Reply, msg.Headers, msg.Message);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Error in internal receive loop processing {Subject}", msg.Subject);
}
}
}
catch (OperationCanceledException)
{
// Normal shutdown
}
}
public async ValueTask DisposeAsync()
{
await _cts.CancelAsync();
_sendQueue.Writer.TryComplete();
_receiveQueue.Writer.TryComplete();
_receiveQueuePings.Writer.TryComplete();
if (_sendLoop != null) await _sendLoop.WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
if (_receiveLoop != null) await _receiveLoop.WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
if (_receiveLoopPings != null) await _receiveLoopPings.WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
_cts.Dispose();
}
}

View File

@@ -0,0 +1,19 @@
using NATS.Server.Auth;
using NATS.Server.Protocol;
namespace NATS.Server;
public interface INatsClient
{
ulong Id { get; }
ClientKind Kind { get; }
bool IsInternal => Kind.IsInternal();
Account? Account { get; }
ClientOptions? ClientOpts { get; }
ClientPermissions? Permissions { get; }
void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload);
bool QueueOutbound(ReadOnlyMemory<byte> data);
void RemoveSubscription(string sid);
}

View File

@@ -0,0 +1,25 @@
using NATS.Server.Auth;
namespace NATS.Server.Imports;
public sealed class ExportAuth
{
public bool TokenRequired { get; init; }
public uint AccountPosition { get; init; }
public HashSet<string>? ApprovedAccounts { get; init; }
public Dictionary<string, long>? RevokedAccounts { get; init; }
public bool IsAuthorized(Account account)
{
if (RevokedAccounts != null && RevokedAccounts.ContainsKey(account.Name))
return false;
if (ApprovedAccounts == null && !TokenRequired && AccountPosition == 0)
return true;
if (ApprovedAccounts != null)
return ApprovedAccounts.Contains(account.Name);
return false;
}
}

View File

@@ -0,0 +1,8 @@
namespace NATS.Server.Imports;
public sealed class ExportMap
{
public Dictionary<string, StreamExport> Streams { get; } = new(StringComparer.Ordinal);
public Dictionary<string, ServiceExport> Services { get; } = new(StringComparer.Ordinal);
public Dictionary<string, ServiceImport> Responses { get; } = new(StringComparer.Ordinal);
}

View File

@@ -0,0 +1,18 @@
namespace NATS.Server.Imports;
public sealed class ImportMap
{
public List<StreamImport> Streams { get; } = [];
public Dictionary<string, List<ServiceImport>> Services { get; } = new(StringComparer.Ordinal);
public void AddServiceImport(ServiceImport si)
{
if (!Services.TryGetValue(si.From, out var list))
{
list = [];
Services[si.From] = list;
}
list.Add(si);
}
}

View File

@@ -0,0 +1,47 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Imports;
public sealed class ServiceLatencyMsg
{
[JsonPropertyName("type")]
public string Type { get; set; } = "io.nats.server.metric.v1.service_latency";
[JsonPropertyName("requestor")]
public string Requestor { get; set; } = string.Empty;
[JsonPropertyName("responder")]
public string Responder { get; set; } = string.Empty;
[JsonPropertyName("status")]
public int Status { get; set; } = 200;
[JsonPropertyName("svc_latency")]
public long ServiceLatencyNanos { get; set; }
[JsonPropertyName("total_latency")]
public long TotalLatencyNanos { get; set; }
}
public static class LatencyTracker
{
public static bool ShouldSample(ServiceLatency latency)
{
if (latency.SamplingPercentage <= 0) return false;
if (latency.SamplingPercentage >= 100) return true;
return Random.Shared.Next(100) < latency.SamplingPercentage;
}
public static ServiceLatencyMsg BuildLatencyMsg(
string requestor, string responder,
TimeSpan serviceLatency, TimeSpan totalLatency)
{
return new ServiceLatencyMsg
{
Requestor = requestor,
Responder = responder,
ServiceLatencyNanos = serviceLatency.Ticks * 100,
TotalLatencyNanos = totalLatency.Ticks * 100,
};
}
}

View File

@@ -0,0 +1,64 @@
using System.Security.Cryptography;
using NATS.Server.Auth;
namespace NATS.Server.Imports;
/// <summary>
/// Handles response routing for service imports.
/// Maps to Go's service reply prefix generation and response cleanup.
/// Reference: golang/nats-server/server/accounts.go — addRespServiceImport, removeRespServiceImport
/// </summary>
public static class ResponseRouter
{
private static readonly char[] Base62 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".ToCharArray();
/// <summary>
/// Generates a unique reply prefix for response routing.
/// Format: "_R_.{10 random base62 chars}."
/// </summary>
public static string GenerateReplyPrefix()
{
Span<byte> bytes = stackalloc byte[10];
RandomNumberGenerator.Fill(bytes);
var chars = new char[10];
for (int i = 0; i < 10; i++)
chars[i] = Base62[bytes[i] % 62];
return $"_R_.{new string(chars)}.";
}
/// <summary>
/// Creates a response service import that maps the generated reply prefix
/// back to the original reply subject on the requesting account.
/// </summary>
public static ServiceImport CreateResponseImport(
Account exporterAccount,
ServiceImport originalImport,
string originalReply)
{
var replyPrefix = GenerateReplyPrefix();
var responseSi = new ServiceImport
{
DestinationAccount = exporterAccount,
From = replyPrefix + ">",
To = originalReply,
IsResponse = true,
ResponseType = originalImport.ResponseType,
Export = originalImport.Export,
TimestampTicks = DateTime.UtcNow.Ticks,
};
exporterAccount.Exports.Responses[replyPrefix] = responseSi;
return responseSi;
}
/// <summary>
/// Removes a response import from the account's export map.
/// For Singleton responses, this is called after the first reply is delivered.
/// For Streamed/Chunked, it is called when the response stream ends.
/// </summary>
public static void CleanupResponse(Account account, string replyPrefix, ServiceImport responseSi)
{
account.Exports.Responses.Remove(replyPrefix);
}
}

View File

@@ -0,0 +1,13 @@
using NATS.Server.Auth;
namespace NATS.Server.Imports;
public sealed class ServiceExport
{
public ExportAuth Auth { get; init; } = new();
public Account? Account { get; init; }
public ServiceResponseType ResponseType { get; init; } = ServiceResponseType.Singleton;
public TimeSpan ResponseThreshold { get; init; } = TimeSpan.FromMinutes(2);
public ServiceLatency? Latency { get; init; }
public bool AllowTrace { get; init; }
}

View File

@@ -0,0 +1,21 @@
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Imports;
public sealed class ServiceImport
{
public required Account DestinationAccount { get; init; }
public required string From { get; init; }
public required string To { get; init; }
public SubjectTransform? Transform { get; init; }
public ServiceExport? Export { get; init; }
public ServiceResponseType ResponseType { get; init; } = ServiceResponseType.Singleton;
public byte[]? Sid { get; set; }
public bool IsResponse { get; init; }
public bool UsePub { get; init; }
public bool Invalid { get; set; }
public bool Share { get; init; }
public bool Tracking { get; init; }
public long TimestampTicks { get; set; }
}

View File

@@ -0,0 +1,7 @@
namespace NATS.Server.Imports;
public sealed class ServiceLatency
{
public int SamplingPercentage { get; init; } = 100;
public string Subject { get; init; } = string.Empty;
}

View File

@@ -0,0 +1,8 @@
namespace NATS.Server.Imports;
public enum ServiceResponseType
{
Singleton,
Streamed,
Chunked,
}

View File

@@ -0,0 +1,6 @@
namespace NATS.Server.Imports;
public sealed class StreamExport
{
public ExportAuth Auth { get; init; } = new();
}

View File

@@ -0,0 +1,14 @@
using NATS.Server.Auth;
using NATS.Server.Subscriptions;
namespace NATS.Server.Imports;
public sealed class StreamImport
{
public required Account SourceAccount { get; init; }
public required string From { get; init; }
public required string To { get; init; }
public SubjectTransform? Transform { get; init; }
public bool UsePub { get; init; }
public bool Invalid { get; set; }
}

View File

@@ -0,0 +1,59 @@
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
namespace NATS.Server;
/// <summary>
/// Lightweight socketless client for internal messaging (SYSTEM, ACCOUNT, JETSTREAM).
/// Maps to Go's internal client created by createInternalClient() in server.go:1910-1936.
/// No network I/O — messages are delivered via callback.
/// </summary>
public sealed class InternalClient : INatsClient
{
public ulong Id { get; }
public ClientKind Kind { get; }
public bool IsInternal => Kind.IsInternal();
public Account? Account { get; }
public ClientOptions? ClientOpts => null;
public ClientPermissions? Permissions => null;
/// <summary>
/// Callback invoked when a message is delivered to this internal client.
/// Set by the event system or account import infrastructure.
/// </summary>
public Action<string, string, string?, ReadOnlyMemory<byte>, ReadOnlyMemory<byte>>? MessageCallback { get; set; }
private readonly Dictionary<string, Subscription> _subs = new(StringComparer.Ordinal);
public InternalClient(ulong id, ClientKind kind, Account account)
{
if (!kind.IsInternal())
throw new ArgumentException($"InternalClient requires an internal ClientKind, got {kind}", nameof(kind));
Id = id;
Kind = kind;
Account = account;
}
public void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{
MessageCallback?.Invoke(subject, sid, replyTo, headers, payload);
}
public bool QueueOutbound(ReadOnlyMemory<byte> data) => true; // no-op for internal clients
public void RemoveSubscription(string sid)
{
if (_subs.Remove(sid))
Account?.DecrementSubscriptions();
}
public void AddSubscription(Subscription sub)
{
_subs[sub.Sid] = sub;
}
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
}

View File

@@ -14,12 +14,16 @@ public sealed class SubszHandler(NatsServer server)
var opts = ParseQueryParams(ctx);
var now = DateTime.UtcNow;
// Collect subscriptions from all accounts (or filtered)
// Collect subscriptions from all accounts (or filtered).
// Exclude the $SYS system account unless explicitly requested — its internal
// subscriptions are infrastructure and not user-facing.
var allSubs = new List<Subscription>();
foreach (var account in server.GetAccounts())
{
if (!string.IsNullOrEmpty(opts.Account) && account.Name != opts.Account)
continue;
if (string.IsNullOrEmpty(opts.Account) && account.Name == "$SYS")
continue;
allSubs.AddRange(account.SubList.GetAllSubscriptions());
}
@@ -31,10 +35,10 @@ public sealed class SubszHandler(NatsServer server)
var total = allSubs.Count;
var numSubs = server.GetAccounts()
.Where(a => string.IsNullOrEmpty(opts.Account) || a.Name == opts.Account)
.Where(a => (string.IsNullOrEmpty(opts.Account) && a.Name != "$SYS") || a.Name == opts.Account)
.Aggregate(0u, (sum, a) => sum + a.SubList.Count);
var numCache = server.GetAccounts()
.Where(a => string.IsNullOrEmpty(opts.Account) || a.Name == opts.Account)
.Where(a => (string.IsNullOrEmpty(opts.Account) && a.Name != "$SYS") || a.Name == opts.Account)
.Sum(a => a.SubList.CacheCount);
SubDetail[] details = [];

View File

@@ -1,4 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<ItemGroup>
<InternalsVisibleTo Include="NATS.Server.Tests" />
</ItemGroup>
<ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="NATS.NKeys" />

View File

@@ -11,7 +11,6 @@ using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
using NATS.Server.WebSocket;
namespace NATS.Server;
@@ -20,6 +19,8 @@ public interface IMessageRouter
void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory<byte> headers,
ReadOnlyMemory<byte> payload, NatsClient sender);
void RemoveClient(NatsClient client);
void PublishConnectEvent(NatsClient client);
void PublishDisconnectEvent(NatsClient client);
}
public interface ISubListAccess
@@ -27,7 +28,7 @@ public interface ISubListAccess
SubList SubList { get; }
}
public sealed class NatsClient : IDisposable
public sealed class NatsClient : INatsClient, IDisposable
{
private readonly Socket _socket;
private readonly Stream _stream;
@@ -46,6 +47,7 @@ public sealed class NatsClient : IDisposable
private readonly ServerStats _serverStats;
public ulong Id { get; }
public ClientKind Kind => ClientKind.Client;
public ClientOptions? ClientOpts { get; private set; }
public IMessageRouter? Router { get; set; }
public Account? Account { get; private set; }
@@ -94,9 +96,6 @@ public sealed class NatsClient : IDisposable
private long _rtt;
public TimeSpan Rtt => new(Interlocked.Read(ref _rtt));
public bool IsWebSocket { get; set; }
public WsUpgradeResult? WsInfo { get; set; }
public TlsConnectionState? TlsState { get; set; }
public bool InfoAlreadySent { get; set; }
@@ -448,6 +447,9 @@ public sealed class NatsClient : IDisposable
_flags.SetFlag(ClientFlags.ConnectProcessFinished);
_logger.LogDebug("CONNECT received from client {ClientId}, name={ClientName}", Id, ClientOpts?.Name);
// Publish connect advisory to the system event bus
Router?.PublishConnectEvent(this);
// Start auth expiry timer if needed
if (_authService.IsAuthRequired && authResult?.Expiry is { } expiry)
{

View File

@@ -116,32 +116,4 @@ public sealed class NatsOptions
public Dictionary<string, string>? SubjectMappings { get; set; }
public bool HasTls => TlsCert != null && TlsKey != null;
// WebSocket
public WebSocketOptions WebSocket { get; set; } = new();
}
public sealed class WebSocketOptions
{
public string Host { get; set; } = "0.0.0.0";
public int Port { get; set; } = -1;
public string? Advertise { get; set; }
public string? NoAuthUser { get; set; }
public string? JwtCookie { get; set; }
public string? UsernameCookie { get; set; }
public string? PasswordCookie { get; set; }
public string? TokenCookie { get; set; }
public string? Username { get; set; }
public string? Password { get; set; }
public string? Token { get; set; }
public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
public bool NoTls { get; set; }
public string? TlsCert { get; set; }
public string? TlsKey { get; set; }
public bool SameOrigin { get; set; }
public List<string>? AllowedOrigins { get; set; }
public bool Compression { get; set; }
public TimeSpan HandshakeTimeout { get; set; } = TimeSpan.FromSeconds(2);
public TimeSpan? PingInterval { get; set; }
public Dictionary<string, string>? Headers { get; set; }
}

View File

@@ -9,11 +9,12 @@ using Microsoft.Extensions.Logging;
using NATS.NKeys;
using NATS.Server.Auth;
using NATS.Server.Configuration;
using NATS.Server.Events;
using NATS.Server.Imports;
using NATS.Server.Monitoring;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
using NATS.Server.WebSocket;
namespace NATS.Server;
@@ -36,12 +37,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private string? _configDigest;
private readonly Account _globalAccount;
private readonly Account _systemAccount;
private InternalEventSystem? _eventSystem;
private readonly SslServerAuthenticationOptions? _sslOptions;
private readonly TlsRateLimiter? _tlsRateLimiter;
private readonly SubjectTransform[] _subjectTransforms;
private Socket? _listener;
private Socket? _wsListener;
private readonly TaskCompletionSource _wsAcceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
private MonitorServer? _monitorServer;
private ulong _nextClientId;
private long _startTimeTicks;
@@ -73,6 +73,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
public int Port => _options.Port;
public Account SystemAccount => _systemAccount;
public string ServerNKey { get; }
public InternalEventSystem? EventSystem => _eventSystem;
public bool IsShuttingDown => Volatile.Read(ref _shutdown) != 0;
public bool IsLameDuckMode => Volatile.Read(ref _lameDuck) != 0;
public Action? ReOpenLogFile { get; set; }
@@ -93,16 +94,29 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_logger.LogInformation("Initiating Shutdown...");
// Publish shutdown advisory before tearing down the event system
if (_eventSystem != null)
{
var shutdownSubject = string.Format(EventSubjects.ServerShutdown, _serverInfo.ServerId);
_eventSystem.Enqueue(new PublishMessage
{
Subject = shutdownSubject,
Body = new ShutdownEventMsg { Server = BuildEventServerInfo(), Reason = "Server Shutdown" },
IsLast = true,
});
// Give the send loop time to process the shutdown event
await Task.Delay(100);
await _eventSystem.DisposeAsync();
}
// Signal all internal loops to stop
await _quitCts.CancelAsync();
// Close listeners to stop accept loops
// Close listener to stop accept loop
_listener?.Close();
_wsListener?.Close();
// Wait for accept loops to exit
// Wait for accept loop to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
// Close all client connections — flush first, then mark closed
var flushTasks = new List<Task>();
@@ -143,13 +157,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_logger.LogInformation("Entering lame duck mode, stop accepting new clients");
// Close listeners to stop accepting new connections
// Close listener to stop accepting new connections
_listener?.Close();
_wsListener?.Close();
// Wait for accept loops to exit
// Wait for accept loop to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
await _wsAcceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
var gracePeriod = _options.LameDuckGracePeriod;
if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod;
@@ -272,6 +284,14 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_systemAccount = new Account("$SYS");
_accounts["$SYS"] = _systemAccount;
// Create system internal client and event system
var sysClientId = Interlocked.Increment(ref _nextClientId);
var sysClient = new InternalClient(sysClientId, ClientKind.System, _systemAccount);
_eventSystem = new InternalEventSystem(
_systemAccount, sysClient,
options.ServerName ?? $"nats-dotnet-{Environment.MachineName}",
_loggerFactory.CreateLogger<InternalEventSystem>());
// Generate Ed25519 server NKey identity
using var serverKeyPair = KeyPair.CreatePair(PrefixByte.Server);
ServerNKey = serverKeyPair.GetPublicKey();
@@ -376,6 +396,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
BuildCachedInfo();
}
_listeningStarted.TrySetResult();
_eventSystem?.Start(this);
_eventSystem?.InitEventTracking(this);
_logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port);
// Warn about stub features
@@ -391,31 +416,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
WritePidFile();
WritePortsFile();
if (_options.WebSocket.Port >= 0)
{
_wsListener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_wsListener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
_wsListener.Bind(new IPEndPoint(
_options.WebSocket.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.WebSocket.Host),
_options.WebSocket.Port));
_wsListener.Listen(128);
if (_options.WebSocket.Port == 0)
{
_options.WebSocket.Port = ((IPEndPoint)_wsListener.LocalEndPoint!).Port;
}
_logger.LogInformation("Listening for WebSocket clients on {Host}:{Port}",
_options.WebSocket.Host, _options.WebSocket.Port);
if (_options.WebSocket.NoTls)
_logger.LogWarning("WebSocket not configured with TLS. DO NOT USE IN PRODUCTION!");
_ = RunWebSocketAcceptLoopAsync(linked.Token);
}
_listeningStarted.TrySetResult();
var tmpDelay = AcceptMinSleep;
try
@@ -561,102 +561,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
private async Task RunWebSocketAcceptLoopAsync(CancellationToken ct)
{
var tmpDelay = AcceptMinSleep;
try
{
while (!ct.IsCancellationRequested)
{
Socket socket;
try
{
socket = await _wsListener!.AcceptAsync(ct);
tmpDelay = AcceptMinSleep;
}
catch (OperationCanceledException) { break; }
catch (ObjectDisposedException) { break; }
catch (SocketException ex)
{
if (IsShuttingDown || IsLameDuckMode) break;
_logger.LogError(ex, "Temporary WebSocket accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
try { await Task.Delay(tmpDelay, ct); } catch (OperationCanceledException) { break; }
tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
continue;
}
if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
{
socket.Dispose();
continue;
}
var clientId = Interlocked.Increment(ref _nextClientId);
Interlocked.Increment(ref _stats.TotalConnections);
Interlocked.Increment(ref _activeClientCount);
_ = AcceptWebSocketClientAsync(socket, clientId, ct);
}
}
finally
{
_wsAcceptLoopExited.TrySetResult();
}
}
private async Task AcceptWebSocketClientAsync(Socket socket, ulong clientId, CancellationToken ct)
{
try
{
var networkStream = new NetworkStream(socket, ownsSocket: false);
Stream stream = networkStream;
// TLS negotiation if configured
if (_sslOptions != null && !_options.WebSocket.NoTls)
{
var (tlsStream, _) = await TlsConnectionWrapper.NegotiateAsync(
socket, networkStream, _options, _sslOptions, _serverInfo,
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
stream = tlsStream;
}
// HTTP upgrade handshake
var upgradeResult = await WsUpgrade.TryUpgradeAsync(stream, stream, _options.WebSocket, ct);
if (!upgradeResult.Success)
{
_logger.LogDebug("WebSocket upgrade failed for client {ClientId}", clientId);
socket.Dispose();
Interlocked.Decrement(ref _activeClientCount);
return;
}
// Create WsConnection wrapper
var wsConn = new WsConnection(stream,
compress: upgradeResult.Compress,
maskRead: upgradeResult.MaskRead,
maskWrite: upgradeResult.MaskWrite,
browser: upgradeResult.Browser,
noCompFrag: upgradeResult.NoCompFrag);
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, wsConn, socket, _options, _serverInfo,
_authService, null, clientLogger, _stats);
client.Router = this;
client.IsWebSocket = true;
client.WsInfo = upgradeResult;
_clients[clientId] = client;
await RunClientAsync(client, ct);
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Failed to accept WebSocket client {ClientId}", clientId);
try { socket.Shutdown(SocketShutdown.Both); } catch { }
socket.Dispose();
Interlocked.Decrement(ref _activeClientCount);
}
}
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
{
try
@@ -728,6 +632,27 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
// Check for service imports that match this subject.
// When a client in the importer account publishes to a subject
// that matches a service import "From" pattern, we forward the
// message to the destination (exporter) account's subscribers
// using the mapped "To" subject.
if (sender.Account != null)
{
foreach (var kvp in sender.Account.Imports.Services)
{
foreach (var si in kvp.Value)
{
if (si.Invalid) continue;
if (SubjectMatch.MatchLiteral(subject, si.From))
{
ProcessServiceImport(si, subject, replyTo, headers, payload);
delivered = true;
}
}
}
}
// No-responders: if nobody received the message and the publisher
// opted in, send back a 503 status HMSG on the reply subject.
if (!delivered && replyTo != null && sender.ClientOpts?.NoResponders == true)
@@ -767,6 +692,153 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
/// <summary>
/// Processes a service import by transforming the subject from the importer's
/// subject space to the exporter's subject space, then delivering to matching
/// subscribers in the destination account.
/// Reference: Go server/accounts.go addServiceImport / processServiceImport.
/// </summary>
public void ProcessServiceImport(ServiceImport si, string subject, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{
if (si.Invalid) return;
// Transform subject: map from importer subject space to exporter subject space
string targetSubject;
if (si.Transform != null)
{
var transformed = si.Transform.Apply(subject);
targetSubject = transformed ?? si.To;
}
else if (si.UsePub)
{
targetSubject = subject;
}
else
{
// Default: use the "To" subject from the import definition.
// For wildcard imports (e.g. "requests.>" -> "api.>"), we need
// to map the specific subject tokens from the source pattern to
// the destination pattern.
targetSubject = MapImportSubject(subject, si.From, si.To);
}
// Match against destination account's SubList
var destSubList = si.DestinationAccount.SubList;
var result = destSubList.Match(targetSubject);
// Deliver to plain subscribers in the destination account
foreach (var sub in result.PlainSubs)
{
if (sub.Client == null) continue;
DeliverMessage(sub, targetSubject, replyTo, headers, payload);
}
// Deliver to one member of each queue group
foreach (var queueGroup in result.QueueSubs)
{
if (queueGroup.Length == 0) continue;
var sub = queueGroup[0]; // Simple selection: first available
if (sub.Client != null)
DeliverMessage(sub, targetSubject, replyTo, headers, payload);
}
}
/// <summary>
/// Maps a published subject from the import "From" pattern to the "To" pattern.
/// For example, if From="requests.>" and To="api.>" and subject="requests.test",
/// this returns "api.test".
/// </summary>
private static string MapImportSubject(string subject, string fromPattern, string toPattern)
{
// If "To" doesn't contain wildcards, use it directly
if (SubjectMatch.IsLiteral(toPattern))
return toPattern;
// For wildcard patterns, replace matching wildcard segments.
// Split into tokens and map from source to destination.
var subTokens = subject.Split('.');
var fromTokens = fromPattern.Split('.');
var toTokens = toPattern.Split('.');
var result = new string[toTokens.Length];
int subIdx = 0;
// Build a mapping: for each wildcard position in "from",
// capture the corresponding subject token(s)
var wildcardValues = new List<string>();
string? fwcValue = null;
for (int i = 0; i < fromTokens.Length && subIdx < subTokens.Length; i++)
{
if (fromTokens[i] == "*")
{
wildcardValues.Add(subTokens[subIdx]);
subIdx++;
}
else if (fromTokens[i] == ">")
{
// Capture all remaining tokens
fwcValue = string.Join(".", subTokens[subIdx..]);
subIdx = subTokens.Length;
}
else
{
subIdx++; // Skip literal match
}
}
// Now build the output using the "to" pattern
int wcIdx = 0;
var sb = new StringBuilder();
for (int i = 0; i < toTokens.Length; i++)
{
if (i > 0) sb.Append('.');
if (toTokens[i] == "*")
{
sb.Append(wcIdx < wildcardValues.Count ? wildcardValues[wcIdx] : "*");
wcIdx++;
}
else if (toTokens[i] == ">")
{
sb.Append(fwcValue ?? ">");
}
else
{
sb.Append(toTokens[i]);
}
}
return sb.ToString();
}
/// <summary>
/// Wires service import subscriptions for an account. Creates marker
/// subscriptions in the account's SubList so that the import paths
/// are tracked. The actual forwarding happens in ProcessMessage when
/// it checks the account's Imports.Services.
/// Reference: Go server/accounts.go addServiceImportSub.
/// </summary>
public void WireServiceImports(Account account)
{
foreach (var kvp in account.Imports.Services)
{
foreach (var si in kvp.Value)
{
if (si.Invalid) continue;
// Create a marker subscription in the importer account.
// This subscription doesn't directly deliver messages;
// the ProcessMessage method checks service imports after
// the regular SubList match.
_logger.LogDebug(
"Wired service import for account {Account}: {From} -> {To} (dest: {DestAccount})",
account.Name, si.From, si.To, si.DestinationAccount.Name);
}
}
}
private static void SendNoResponders(NatsClient sender, string replyTo)
{
// Find the sid for a subscription matching the reply subject
@@ -812,8 +884,194 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
});
}
public void SendInternalMsg(string subject, string? reply, object? msg)
{
_eventSystem?.Enqueue(new PublishMessage { Subject = subject, Reply = reply, Body = msg });
}
public void SendInternalAccountMsg(Account account, string subject, object? msg)
{
_eventSystem?.Enqueue(new PublishMessage { Subject = subject, Body = msg });
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.VARZ requests.
/// Returns core server information including stats counters.
/// </summary>
public void HandleVarzRequest(string subject, string? reply)
{
if (reply == null) return;
var varz = new
{
server_id = _serverInfo.ServerId,
server_name = _serverInfo.ServerName,
version = NatsProtocol.Version,
host = _options.Host,
port = _options.Port,
max_payload = _options.MaxPayload,
connections = ClientCount,
total_connections = Interlocked.Read(ref _stats.TotalConnections),
in_msgs = Interlocked.Read(ref _stats.InMsgs),
out_msgs = Interlocked.Read(ref _stats.OutMsgs),
in_bytes = Interlocked.Read(ref _stats.InBytes),
out_bytes = Interlocked.Read(ref _stats.OutBytes),
};
SendInternalMsg(reply, null, varz);
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.HEALTHZ requests.
/// Returns a simple health status response.
/// </summary>
public void HandleHealthzRequest(string subject, string? reply)
{
if (reply == null) return;
SendInternalMsg(reply, null, new { status = "ok" });
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.SUBSZ requests.
/// Returns the current subscription count.
/// </summary>
public void HandleSubszRequest(string subject, string? reply)
{
if (reply == null) return;
SendInternalMsg(reply, null, new { num_subscriptions = SubList.Count });
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.STATSZ requests.
/// Publishes current server statistics through the event system.
/// </summary>
public void HandleStatszRequest(string subject, string? reply)
{
if (reply == null) return;
var process = System.Diagnostics.Process.GetCurrentProcess();
var statsMsg = new Events.ServerStatsMsg
{
Server = BuildEventServerInfo(),
Stats = new Events.ServerStatsData
{
Start = StartTime,
Mem = process.WorkingSet64,
Cores = Environment.ProcessorCount,
Connections = ClientCount,
TotalConnections = Interlocked.Read(ref _stats.TotalConnections),
Subscriptions = SubList.Count,
InMsgs = Interlocked.Read(ref _stats.InMsgs),
OutMsgs = Interlocked.Read(ref _stats.OutMsgs),
InBytes = Interlocked.Read(ref _stats.InBytes),
OutBytes = Interlocked.Read(ref _stats.OutBytes),
SlowConsumers = Interlocked.Read(ref _stats.SlowConsumers),
},
};
SendInternalMsg(reply, null, statsMsg);
}
/// <summary>
/// Handles $SYS.REQ.SERVER.{id}.IDZ requests.
/// Returns basic server identity information.
/// </summary>
public void HandleIdzRequest(string subject, string? reply)
{
if (reply == null) return;
var idz = new
{
server_id = _serverInfo.ServerId,
server_name = _serverInfo.ServerName,
version = NatsProtocol.Version,
host = _options.Host,
port = _options.Port,
};
SendInternalMsg(reply, null, idz);
}
/// <summary>
/// Builds an EventServerInfo block for embedding in system event messages.
/// Maps to Go's serverInfo() helper used in events.go advisory publishing.
/// </summary>
public EventServerInfo BuildEventServerInfo()
{
var seq = _eventSystem?.NextSequence() ?? 0;
return new EventServerInfo
{
Name = _serverInfo.ServerName,
Host = _options.Host,
Id = _serverInfo.ServerId,
Version = NatsProtocol.Version,
Seq = seq,
};
}
private static EventClientInfo BuildEventClientInfo(NatsClient client)
{
return new EventClientInfo
{
Id = client.Id,
Host = client.RemoteIp,
Account = client.Account?.Name,
Name = client.ClientOpts?.Name,
Lang = client.ClientOpts?.Lang,
Version = client.ClientOpts?.Version,
Start = client.StartTime,
};
}
/// <summary>
/// Publishes a $SYS.ACCOUNT.{account}.CONNECT advisory when a client
/// completes authentication. Maps to Go's sendConnectEvent in events.go.
/// </summary>
public void PublishConnectEvent(NatsClient client)
{
if (_eventSystem == null) return;
var accountName = client.Account?.Name ?? Account.GlobalAccountName;
var subject = string.Format(EventSubjects.ConnectEvent, accountName);
var evt = new ConnectEventMsg
{
Id = Guid.NewGuid().ToString("N"),
Time = DateTime.UtcNow,
Server = BuildEventServerInfo(),
Client = BuildEventClientInfo(client),
};
SendInternalMsg(subject, null, evt);
}
/// <summary>
/// Publishes a $SYS.ACCOUNT.{account}.DISCONNECT advisory when a client
/// disconnects. Maps to Go's sendDisconnectEvent in events.go.
/// </summary>
public void PublishDisconnectEvent(NatsClient client)
{
if (_eventSystem == null) return;
var accountName = client.Account?.Name ?? Account.GlobalAccountName;
var subject = string.Format(EventSubjects.DisconnectEvent, accountName);
var evt = new DisconnectEventMsg
{
Id = Guid.NewGuid().ToString("N"),
Time = DateTime.UtcNow,
Server = BuildEventServerInfo(),
Client = BuildEventClientInfo(client),
Sent = new DataStats
{
Msgs = Interlocked.Read(ref client.OutMsgs),
Bytes = Interlocked.Read(ref client.OutBytes),
},
Received = new DataStats
{
Msgs = Interlocked.Read(ref client.InMsgs),
Bytes = Interlocked.Read(ref client.InBytes),
},
Reason = client.CloseReason.ToReasonString(),
};
SendInternalMsg(subject, null, evt);
}
public void RemoveClient(NatsClient client)
{
// Publish disconnect advisory before removing client state
if (client.ConnectReceived)
PublishDisconnectEvent(client);
_clients.TryRemove(client.Id, out _);
_logger.LogDebug("Removed client {ClientId}", client.Id);
@@ -1068,7 +1326,6 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_quitCts.Dispose();
_tlsRateLimiter?.Dispose();
_listener?.Dispose();
_wsListener?.Dispose();
foreach (var client in _clients.Values)
client.Dispose();
foreach (var account in _accounts.Values)

View File

@@ -1,4 +1,5 @@
using NATS.Server;
using NATS.Server.Imports;
namespace NATS.Server.Subscriptions;
@@ -9,5 +10,7 @@ public sealed class Subscription
public required string Sid { get; init; }
public long MessageCount; // Interlocked
public long MaxMessages; // 0 = unlimited
public NatsClient? Client { get; set; }
public INatsClient? Client { get; set; }
public ServiceImport? ServiceImport { get; set; }
public StreamImport? StreamImport { get; set; }
}

View File

@@ -1,94 +0,0 @@
using System.IO.Compression;
namespace NATS.Server.WebSocket;
/// <summary>
/// permessage-deflate compression/decompression for WebSocket frames (RFC 7692).
/// Ported from golang/nats-server/server/websocket.go lines 403-440 and 1391-1466.
/// </summary>
public static class WsCompression
{
/// <summary>
/// Compresses data using deflate. Removes trailing 4 bytes (sync marker)
/// per RFC 7692 Section 7.2.1.
/// </summary>
/// <remarks>
/// We call Flush() but intentionally do not Dispose() the DeflateStream before
/// reading output, because Dispose writes a final deflate block (0x03 0x00) that
/// would be corrupted by the 4-byte tail strip. Flush() alone writes a sync flush
/// ending with 0x00 0x00 0xff 0xff, matching Go's flate.Writer.Flush() behavior.
/// </remarks>
public static byte[] Compress(ReadOnlySpan<byte> data)
{
var output = new MemoryStream();
var deflate = new DeflateStream(output, CompressionLevel.Fastest, leaveOpen: true);
try
{
deflate.Write(data);
deflate.Flush();
var compressed = output.ToArray();
// Remove trailing 4-byte sync marker (0x00 0x00 0xff 0xff) per RFC 7692
if (compressed.Length >= 4)
return compressed[..^4];
return compressed;
}
finally
{
deflate.Dispose();
output.Dispose();
}
}
/// <summary>
/// Decompresses collected compressed buffers.
/// Appends trailer bytes before decompressing per RFC 7692 Section 7.2.2.
/// Ported from golang/nats-server/server/websocket.go lines 403-440.
/// The Go code appends compressLastBlock (9 bytes) which includes the sync
/// marker plus a final empty stored block to signal end-of-stream to the
/// flate reader.
/// </summary>
public static byte[] Decompress(List<byte[]> compressedBuffers, int maxPayload)
{
if (maxPayload <= 0)
maxPayload = 1024 * 1024; // Default 1MB
// Concatenate all compressed buffers + trailer.
// Per RFC 7692 Section 7.2.2, append the sync flush marker (0x00 0x00 0xff 0xff)
// that was stripped during compression. The Go reference appends compressLastBlock
// (9 bytes) for Go's flate reader; .NET's DeflateStream only needs the 4-byte trailer.
int totalLen = 0;
foreach (var buf in compressedBuffers)
totalLen += buf.Length;
totalLen += WsConstants.DecompressTrailer.Length;
var combined = new byte[totalLen];
int offset = 0;
foreach (var buf in compressedBuffers)
{
buf.CopyTo(combined, offset);
offset += buf.Length;
}
WsConstants.DecompressTrailer.CopyTo(combined, offset);
using var input = new MemoryStream(combined);
using var deflate = new DeflateStream(input, CompressionMode.Decompress);
using var output = new MemoryStream();
var readBuf = new byte[4096];
int totalRead = 0;
int n;
while ((n = deflate.Read(readBuf, 0, readBuf.Length)) > 0)
{
totalRead += n;
if (totalRead > maxPayload)
throw new InvalidOperationException("decompressed data exceeds maximum payload size");
output.Write(readBuf, 0, n);
}
return output.ToArray();
}
}

View File

@@ -1,202 +0,0 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// Stream wrapper that transparently frames/deframes WebSocket around raw TCP I/O.
/// NatsClient uses this as its _stream -- FillPipeAsync and RunWriteLoopAsync work unchanged.
/// Ported from golang/nats-server/server/websocket.go wsUpgrade/wrapWebsocket pattern.
/// </summary>
public sealed class WsConnection : Stream
{
private readonly Stream _inner;
private readonly bool _compress;
private readonly bool _maskRead;
private readonly bool _maskWrite;
private readonly bool _browser;
private readonly bool _noCompFrag;
private WsReadInfo _readInfo;
// Read-side state: accessed only from the single FillPipeAsync reader task (no synchronization needed)
private readonly Queue<byte[]> _readQueue = new();
private int _readOffset;
private readonly object _writeLock = new();
private readonly List<ControlFrameAction> _pendingControlWrites = [];
public bool CloseReceived => _readInfo.CloseReceived;
public int CloseStatus => _readInfo.CloseStatus;
public WsConnection(Stream inner, bool compress, bool maskRead, bool maskWrite, bool browser, bool noCompFrag)
{
_inner = inner;
_compress = compress;
_maskRead = maskRead;
_maskWrite = maskWrite;
_browser = browser;
_noCompFrag = noCompFrag;
_readInfo = new WsReadInfo(expectMask: maskRead);
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
{
// Drain any buffered decoded payloads first
if (_readQueue.Count > 0)
return DrainReadQueue(buffer.Span);
while (true)
{
// Read raw bytes from inner stream
var rawBuf = new byte[Math.Max(buffer.Length, 4096)];
int bytesRead = await _inner.ReadAsync(rawBuf.AsMemory(), ct);
if (bytesRead == 0) return 0;
// Decode frames
var payloads = WsReadInfo.ReadFrames(_readInfo, new MemoryStream(rawBuf, 0, bytesRead), bytesRead, maxPayload: 1024 * 1024);
// Collect control frame responses
if (_readInfo.PendingControlFrames.Count > 0)
{
lock (_writeLock)
_pendingControlWrites.AddRange(_readInfo.PendingControlFrames);
_readInfo.PendingControlFrames.Clear();
// Write pending control frames
await FlushControlFramesAsync(ct);
}
if (_readInfo.CloseReceived)
return 0;
foreach (var payload in payloads)
_readQueue.Enqueue(payload);
// If no payloads were decoded (e.g. only frame headers were read),
// continue reading instead of returning 0 which signals end-of-stream
if (_readQueue.Count > 0)
return DrainReadQueue(buffer.Span);
}
}
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default)
{
var data = buffer.Span;
if (_compress && data.Length > WsConstants.CompressThreshold)
{
var compressed = WsCompression.Compress(data);
await WriteFramedAsync(compressed, compressed: true, ct);
}
else
{
await WriteFramedAsync(data.ToArray(), compressed: false, ct);
}
}
private async ValueTask WriteFramedAsync(byte[] payload, bool compressed, CancellationToken ct)
{
if (_browser && payload.Length > WsConstants.FrameSizeForBrowsers && !(_noCompFrag && compressed))
{
// Fragment for browsers
int offset = 0;
bool first = true;
while (offset < payload.Length)
{
int chunkLen = Math.Min(WsConstants.FrameSizeForBrowsers, payload.Length - offset);
bool final = offset + chunkLen >= payload.Length;
var fh = new byte[WsConstants.MaxFrameHeaderSize];
var (n, key) = WsFrameWriter.FillFrameHeader(fh, _maskWrite,
first: first, final: final, compressed: first && compressed,
opcode: WsConstants.BinaryMessage, payloadLength: chunkLen);
var chunk = payload.AsSpan(offset, chunkLen).ToArray();
if (_maskWrite && key != null)
WsFrameWriter.MaskBuf(key, chunk);
await _inner.WriteAsync(fh.AsMemory(0, n), ct);
await _inner.WriteAsync(chunk.AsMemory(), ct);
offset += chunkLen;
first = false;
}
}
else
{
var (header, key) = WsFrameWriter.CreateFrameHeader(_maskWrite, compressed, WsConstants.BinaryMessage, payload.Length);
if (_maskWrite && key != null)
WsFrameWriter.MaskBuf(key, payload);
await _inner.WriteAsync(header.AsMemory(), ct);
await _inner.WriteAsync(payload.AsMemory(), ct);
}
}
private async Task FlushControlFramesAsync(CancellationToken ct)
{
List<ControlFrameAction> toWrite;
lock (_writeLock)
{
if (_pendingControlWrites.Count == 0) return;
toWrite = [.. _pendingControlWrites];
_pendingControlWrites.Clear();
}
foreach (var action in toWrite)
{
var frame = WsFrameWriter.BuildControlFrame(action.Opcode, action.Payload, _maskWrite);
await _inner.WriteAsync(frame, ct);
}
await _inner.FlushAsync(ct);
}
/// <summary>
/// Sends a WebSocket close frame.
/// </summary>
public async Task SendCloseAsync(ClientClosedReason reason, CancellationToken ct = default)
{
var status = WsFrameWriter.MapCloseStatus(reason);
var closePayload = WsFrameWriter.CreateCloseMessage(status, reason.ToReasonString());
var frame = WsFrameWriter.BuildControlFrame(WsConstants.CloseMessage, closePayload, _maskWrite);
await _inner.WriteAsync(frame, ct);
await _inner.FlushAsync(ct);
}
private int DrainReadQueue(Span<byte> buffer)
{
int written = 0;
while (_readQueue.Count > 0 && written < buffer.Length)
{
var current = _readQueue.Peek();
int available = current.Length - _readOffset;
int toCopy = Math.Min(available, buffer.Length - written);
current.AsSpan(_readOffset, toCopy).CopyTo(buffer[written..]);
written += toCopy;
_readOffset += toCopy;
if (_readOffset >= current.Length)
{
_readQueue.Dequeue();
_readOffset = 0;
}
}
return written;
}
// Stream abstract members
public override bool CanRead => true;
public override bool CanWrite => true;
public override bool CanSeek => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override void Flush() => _inner.Flush();
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use ReadAsync");
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException("Use WriteAsync");
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
protected override void Dispose(bool disposing)
{
if (disposing)
_inner.Dispose();
base.Dispose(disposing);
}
public override async ValueTask DisposeAsync()
{
await _inner.DisposeAsync();
GC.SuppressFinalize(this);
}
}

View File

@@ -1,72 +0,0 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket protocol constants (RFC 6455).
/// Ported from golang/nats-server/server/websocket.go lines 41-106.
/// </summary>
public static class WsConstants
{
// Opcodes (RFC 6455 Section 5.2)
public const int TextMessage = 1;
public const int BinaryMessage = 2;
public const int CloseMessage = 8;
public const int PingMessage = 9;
public const int PongMessage = 10;
public const int ContinuationFrame = 0;
// Frame header bits
public const byte FinalBit = 0x80; // 1 << 7
public const byte Rsv1Bit = 0x40; // 1 << 6 (compression, RFC 7692)
public const byte Rsv2Bit = 0x20; // 1 << 5
public const byte Rsv3Bit = 0x10; // 1 << 4
public const byte MaskBit = 0x80; // 1 << 7 (in second byte)
// Frame size limits
public const int MaxFrameHeaderSize = 14;
public const int MaxControlPayloadSize = 125;
public const int FrameSizeForBrowsers = 4096;
public const int CompressThreshold = 64;
public const int CloseStatusSize = 2;
// Close status codes (RFC 6455 Section 11.7)
public const int CloseStatusNormalClosure = 1000;
public const int CloseStatusGoingAway = 1001;
public const int CloseStatusProtocolError = 1002;
public const int CloseStatusUnsupportedData = 1003;
public const int CloseStatusNoStatusReceived = 1005;
public const int CloseStatusInvalidPayloadData = 1007;
public const int CloseStatusPolicyViolation = 1008;
public const int CloseStatusMessageTooBig = 1009;
public const int CloseStatusInternalSrvError = 1011;
public const int CloseStatusTlsHandshake = 1015;
// Compression constants (RFC 7692)
public const string PmcExtension = "permessage-deflate";
public const string PmcSrvNoCtx = "server_no_context_takeover";
public const string PmcCliNoCtx = "client_no_context_takeover";
public static readonly string PmcReqHeaderValue = $"{PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}";
public static readonly string PmcFullResponse = $"Sec-WebSocket-Extensions: {PmcExtension}; {PmcSrvNoCtx}; {PmcCliNoCtx}\r\n";
// Header names
public const string NoMaskingHeader = "Nats-No-Masking";
public const string NoMaskingValue = "true";
public static readonly string NoMaskingFullResponse = $"{NoMaskingHeader}: {NoMaskingValue}\r\n";
public const string XForwardedForHeader = "X-Forwarded-For";
// Path routing
public const string ClientPath = "/";
public const string LeafNodePath = "/leafnode";
public const string MqttPath = "/mqtt";
// Decompression trailer appended before decompressing (RFC 7692 Section 7.2.2)
public static readonly byte[] DecompressTrailer = [0x00, 0x00, 0xff, 0xff];
public static bool IsControlFrame(int opcode) => opcode >= CloseMessage;
}
public enum WsClientKind
{
Client,
Leaf,
Mqtt,
}

View File

@@ -1,171 +0,0 @@
using System.Buffers.Binary;
using System.Security.Cryptography;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket frame construction, masking, and control message creation.
/// Ported from golang/nats-server/server/websocket.go lines 543-726.
/// </summary>
public static class WsFrameWriter
{
/// <summary>
/// Creates a complete frame header for a single-frame message (first=true, final=true).
/// Returns (header bytes, mask key or null).
/// </summary>
public static (byte[] header, byte[]? key) CreateFrameHeader(
bool useMasking, bool compressed, int opcode, int payloadLength)
{
var fh = new byte[WsConstants.MaxFrameHeaderSize];
var (n, key) = FillFrameHeader(fh, useMasking,
first: true, final: true, compressed: compressed, opcode: opcode, payloadLength: payloadLength);
return (fh[..n], key);
}
/// <summary>
/// Fills a pre-allocated frame header buffer.
/// Returns (bytes written, mask key or null).
/// </summary>
public static (int written, byte[]? key) FillFrameHeader(
Span<byte> fh, bool useMasking, bool first, bool final, bool compressed, int opcode, int payloadLength)
{
byte b0 = first ? (byte)opcode : (byte)0;
if (final) b0 |= WsConstants.FinalBit;
if (compressed) b0 |= WsConstants.Rsv1Bit;
byte b1 = 0;
if (useMasking) b1 |= WsConstants.MaskBit;
int n;
switch (payloadLength)
{
case <= 125:
n = 2;
fh[0] = b0;
fh[1] = (byte)(b1 | (byte)payloadLength);
break;
case < 65536:
n = 4;
fh[0] = b0;
fh[1] = (byte)(b1 | 126);
BinaryPrimitives.WriteUInt16BigEndian(fh[2..], (ushort)payloadLength);
break;
default:
n = 10;
fh[0] = b0;
fh[1] = (byte)(b1 | 127);
BinaryPrimitives.WriteUInt64BigEndian(fh[2..], (ulong)payloadLength);
break;
}
byte[]? key = null;
if (useMasking)
{
key = new byte[4];
RandomNumberGenerator.Fill(key);
key.CopyTo(fh[n..]);
n += 4;
}
return (n, key);
}
/// <summary>
/// XOR masks a buffer with a 4-byte key. Applies in-place.
/// </summary>
public static void MaskBuf(ReadOnlySpan<byte> key, Span<byte> buf)
{
for (int i = 0; i < buf.Length; i++)
buf[i] ^= key[i & 3];
}
/// <summary>
/// XOR masks multiple contiguous buffers as if they were one.
/// </summary>
public static void MaskBufs(ReadOnlySpan<byte> key, List<byte[]> bufs)
{
int pos = 0;
foreach (var buf in bufs)
{
for (int j = 0; j < buf.Length; j++)
{
buf[j] ^= key[pos & 3];
pos++;
}
}
}
/// <summary>
/// Creates a close message payload: 2-byte status code + optional UTF-8 body.
/// Body truncated to fit MaxControlPayloadSize with "..." suffix.
/// </summary>
public static byte[] CreateCloseMessage(int status, string body)
{
var bodyBytes = Encoding.UTF8.GetBytes(body);
int maxBody = WsConstants.MaxControlPayloadSize - WsConstants.CloseStatusSize;
if (bodyBytes.Length > maxBody)
{
var suffix = "..."u8;
int truncLen = maxBody - suffix.Length;
// Find a valid UTF-8 boundary by walking back from truncation point
while (truncLen > 0 && (bodyBytes[truncLen] & 0xC0) == 0x80)
truncLen--;
var buf = new byte[WsConstants.CloseStatusSize + truncLen + suffix.Length];
BinaryPrimitives.WriteUInt16BigEndian(buf, (ushort)status);
bodyBytes.AsSpan(0, truncLen).CopyTo(buf.AsSpan(WsConstants.CloseStatusSize));
suffix.CopyTo(buf.AsSpan(WsConstants.CloseStatusSize + truncLen));
return buf;
}
var result = new byte[WsConstants.CloseStatusSize + bodyBytes.Length];
BinaryPrimitives.WriteUInt16BigEndian(result, (ushort)status);
bodyBytes.CopyTo(result.AsSpan(WsConstants.CloseStatusSize));
return result;
}
/// <summary>
/// Builds a complete control frame (header + payload, optional masking).
/// </summary>
public static byte[] BuildControlFrame(int opcode, ReadOnlySpan<byte> payload, bool useMasking)
{
int headerSize = 2 + (useMasking ? 4 : 0);
var frame = new byte[headerSize + payload.Length];
var span = frame.AsSpan();
var (n, key) = FillFrameHeader(span, useMasking,
first: true, final: true, compressed: false, opcode: opcode, payloadLength: payload.Length);
if (payload.Length > 0)
{
payload.CopyTo(span[n..]);
if (useMasking && key != null)
MaskBuf(key, span[n..]);
}
return frame;
}
/// <summary>
/// Maps a ClientClosedReason to a WebSocket close status code.
/// Matches Go wsEnqueueCloseMessage in websocket.go lines 668-694.
/// </summary>
public static int MapCloseStatus(ClientClosedReason reason) => reason switch
{
ClientClosedReason.ClientClosed => WsConstants.CloseStatusNormalClosure,
ClientClosedReason.AuthenticationTimeout or
ClientClosedReason.AuthenticationViolation or
ClientClosedReason.SlowConsumerPendingBytes or
ClientClosedReason.SlowConsumerWriteDeadline or
ClientClosedReason.MaxSubscriptionsExceeded or
ClientClosedReason.AuthenticationExpired => WsConstants.CloseStatusPolicyViolation,
ClientClosedReason.TlsHandshakeError => WsConstants.CloseStatusTlsHandshake,
ClientClosedReason.ParseError or
ClientClosedReason.ProtocolViolation => WsConstants.CloseStatusProtocolError,
ClientClosedReason.MaxPayloadExceeded => WsConstants.CloseStatusMessageTooBig,
ClientClosedReason.WriteError or
ClientClosedReason.ReadError or
ClientClosedReason.StaleConnection or
ClientClosedReason.ServerShutdown => WsConstants.CloseStatusGoingAway,
_ => WsConstants.CloseStatusInternalSrvError,
};
}

View File

@@ -1,81 +0,0 @@
namespace NATS.Server.WebSocket;
/// <summary>
/// Validates WebSocket Origin headers per RFC 6455 Section 10.2.
/// Ported from golang/nats-server/server/websocket.go lines 933-1000.
/// </summary>
public sealed class WsOriginChecker
{
private readonly bool _sameOrigin;
private readonly Dictionary<string, AllowedOrigin>? _allowedOrigins;
public WsOriginChecker(bool sameOrigin, List<string>? allowedOrigins)
{
_sameOrigin = sameOrigin;
if (allowedOrigins is { Count: > 0 })
{
_allowedOrigins = new Dictionary<string, AllowedOrigin>(StringComparer.OrdinalIgnoreCase);
foreach (var ao in allowedOrigins)
{
if (Uri.TryCreate(ao, UriKind.Absolute, out var uri))
{
var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port);
_allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port);
}
}
}
}
/// <summary>
/// Returns null if origin is allowed, or an error message if rejected.
/// </summary>
public string? CheckOrigin(string? origin, string requestHost, bool isTls)
{
if (!_sameOrigin && _allowedOrigins == null)
return null;
if (string.IsNullOrEmpty(origin))
return null;
if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
return $"invalid origin: {origin}";
var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port);
if (_sameOrigin)
{
var (rh, rp) = ParseHostPort(requestHost, isTls);
if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp)
return "not same origin";
}
if (_allowedOrigins != null)
{
if (!_allowedOrigins.TryGetValue(oh, out var allowed) ||
!string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) ||
op != allowed.Port)
{
return "not in the allowed list";
}
}
return null;
}
private static (string host, int port) GetHostAndPort(bool tls, string host, int port)
{
if (port <= 0)
port = tls ? 443 : 80;
return (host.ToLowerInvariant(), port);
}
private static (string host, int port) ParseHostPort(string hostPort, bool isTls)
{
var colonIdx = hostPort.LastIndexOf(':');
if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port))
return (hostPort[..colonIdx].ToLowerInvariant(), port);
return (hostPort.ToLowerInvariant(), isTls ? 443 : 80);
}
private readonly record struct AllowedOrigin(string Scheme, int Port);
}

View File

@@ -1,322 +0,0 @@
using System.Buffers.Binary;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// Per-connection WebSocket frame reading state machine.
/// Ported from golang/nats-server/server/websocket.go lines 156-506.
/// </summary>
public class WsReadInfo
{
public int Remaining;
public bool FrameStart;
public bool FirstFrame;
public bool FrameCompressed;
public bool ExpectMask;
public byte MaskKeyPos;
public byte[] MaskKey;
public List<byte[]>? CompressedBuffers;
public int CompressedOffset;
// Control frame outputs
public List<ControlFrameAction> PendingControlFrames;
public bool CloseReceived;
public int CloseStatus;
public string? CloseBody;
public WsReadInfo(bool expectMask)
{
Remaining = 0;
FrameStart = true;
FirstFrame = true;
FrameCompressed = false;
ExpectMask = expectMask;
MaskKeyPos = 0;
MaskKey = new byte[4];
CompressedBuffers = null;
CompressedOffset = 0;
PendingControlFrames = [];
CloseReceived = false;
CloseStatus = 0;
CloseBody = null;
}
public void SetMaskKey(ReadOnlySpan<byte> key)
{
key[..4].CopyTo(MaskKey);
MaskKeyPos = 0;
}
/// <summary>
/// Unmask buffer in-place using current mask key and position.
/// Optimized for 8-byte chunks when buffer is large enough.
/// Ported from websocket.go lines 509-536.
/// </summary>
public void Unmask(Span<byte> buf)
{
int p = MaskKeyPos;
if (buf.Length < 16)
{
for (int i = 0; i < buf.Length; i++)
{
buf[i] ^= MaskKey[p & 3];
p++;
}
MaskKeyPos = (byte)(p & 3);
return;
}
// Build 8-byte key for bulk XOR
Span<byte> k = stackalloc byte[8];
for (int i = 0; i < 8; i++)
k[i] = MaskKey[(p + i) & 3];
ulong km = BinaryPrimitives.ReadUInt64BigEndian(k);
int n = (buf.Length / 8) * 8;
for (int i = 0; i < n; i += 8)
{
ulong tmp = BinaryPrimitives.ReadUInt64BigEndian(buf[i..]);
tmp ^= km;
BinaryPrimitives.WriteUInt64BigEndian(buf[i..], tmp);
}
// Handle remaining bytes
p += n;
var tail = buf[n..];
for (int i = 0; i < tail.Length; i++)
{
tail[i] ^= MaskKey[p & 3];
p++;
}
MaskKeyPos = (byte)(p & 3);
}
/// <summary>
/// Read and decode WebSocket frames from a buffer.
/// Returns list of decoded payload byte arrays.
/// Ported from websocket.go lines 208-351.
/// </summary>
public static List<byte[]> ReadFrames(WsReadInfo r, Stream stream, int available, int maxPayload)
{
var bufs = new List<byte[]>();
var buf = new byte[available];
int bytesRead = 0;
// Fill the buffer from the stream
while (bytesRead < available)
{
int n = stream.Read(buf, bytesRead, available - bytesRead);
if (n == 0) break;
bytesRead += n;
}
int pos = 0;
int max = bytesRead;
while (pos < max)
{
if (r.FrameStart)
{
if (pos >= max) break;
byte b0 = buf[pos];
int frameType = b0 & 0x0F;
bool final = (b0 & WsConstants.FinalBit) != 0;
bool compressed = (b0 & WsConstants.Rsv1Bit) != 0;
pos++;
// Read second byte
var (b1Buf, newPos) = WsGet(stream, buf, pos, max, 1);
pos = newPos;
byte b1 = b1Buf[0];
// Check mask bit
if (r.ExpectMask && (b1 & WsConstants.MaskBit) == 0)
throw new InvalidOperationException("mask bit missing");
r.Remaining = b1 & 0x7F;
// Validate frame types
if (WsConstants.IsControlFrame(frameType))
{
if (r.Remaining > WsConstants.MaxControlPayloadSize)
throw new InvalidOperationException("control frame length too large");
if (!final)
throw new InvalidOperationException("control frame does not have final bit set");
}
else if (frameType == WsConstants.TextMessage || frameType == WsConstants.BinaryMessage)
{
if (!r.FirstFrame)
throw new InvalidOperationException("new message before previous finished");
r.FirstFrame = final;
r.FrameCompressed = compressed;
}
else if (frameType == WsConstants.ContinuationFrame)
{
if (r.FirstFrame || compressed)
throw new InvalidOperationException("invalid continuation frame");
r.FirstFrame = final;
}
else
{
throw new InvalidOperationException($"unknown opcode {frameType}");
}
// Extended payload length
switch (r.Remaining)
{
case 126:
{
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 2);
pos = p2;
r.Remaining = BinaryPrimitives.ReadUInt16BigEndian(lenBuf);
break;
}
case 127:
{
var (lenBuf, p2) = WsGet(stream, buf, pos, max, 8);
pos = p2;
var len64 = BinaryPrimitives.ReadUInt64BigEndian(lenBuf);
if (len64 > (ulong)maxPayload)
throw new InvalidOperationException($"frame payload length {len64} exceeds max payload {maxPayload}");
r.Remaining = (int)len64;
break;
}
}
// Read mask key (mask bit already validated at line 134)
if (r.ExpectMask)
{
var (keyBuf, p2) = WsGet(stream, buf, pos, max, 4);
pos = p2;
keyBuf.AsSpan(0, 4).CopyTo(r.MaskKey);
r.MaskKeyPos = 0;
}
// Handle control frames
if (WsConstants.IsControlFrame(frameType))
{
pos = HandleControlFrame(r, frameType, stream, buf, pos, max);
continue;
}
r.FrameStart = false;
}
if (pos < max)
{
int n = r.Remaining;
if (pos + n > max) n = max - pos;
var payloadSlice = buf.AsSpan(pos, n).ToArray();
pos += n;
r.Remaining -= n;
if (r.ExpectMask)
r.Unmask(payloadSlice);
bool addToBufs = true;
if (r.FrameCompressed)
{
addToBufs = false;
r.CompressedBuffers ??= [];
r.CompressedBuffers.Add(payloadSlice);
if (r.FirstFrame && r.Remaining == 0)
{
var decompressed = WsCompression.Decompress(r.CompressedBuffers, maxPayload);
r.CompressedBuffers = null;
r.FrameCompressed = false;
addToBufs = true;
payloadSlice = decompressed;
}
}
if (addToBufs && payloadSlice.Length > 0)
bufs.Add(payloadSlice);
if (r.Remaining == 0)
r.FrameStart = true;
}
}
return bufs;
}
private static int HandleControlFrame(WsReadInfo r, int frameType, Stream stream, byte[] buf, int pos, int max)
{
byte[]? payload = null;
if (r.Remaining > 0)
{
var (payloadBuf, newPos) = WsGet(stream, buf, pos, max, r.Remaining);
pos = newPos;
payload = payloadBuf;
if (r.ExpectMask)
r.Unmask(payload);
r.Remaining = 0;
}
switch (frameType)
{
case WsConstants.CloseMessage:
r.CloseReceived = true;
r.CloseStatus = WsConstants.CloseStatusNoStatusReceived;
if (payload != null && payload.Length >= WsConstants.CloseStatusSize)
{
r.CloseStatus = BinaryPrimitives.ReadUInt16BigEndian(payload);
if (payload.Length > WsConstants.CloseStatusSize)
r.CloseBody = Encoding.UTF8.GetString(payload.AsSpan(WsConstants.CloseStatusSize));
}
// Per RFC 6455 Section 5.5.1, always send a close response
if (r.CloseStatus != WsConstants.CloseStatusNoStatusReceived)
{
var closeMsg = WsFrameWriter.CreateCloseMessage(r.CloseStatus, r.CloseBody ?? "");
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, closeMsg));
}
else
{
// Empty close frame — respond with empty close
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.CloseMessage, []));
}
break;
case WsConstants.PingMessage:
r.PendingControlFrames.Add(new ControlFrameAction(WsConstants.PongMessage, payload ?? []));
break;
case WsConstants.PongMessage:
// Nothing to do
break;
}
return pos;
}
/// <summary>
/// Gets needed bytes from buffer or reads from stream.
/// Ported from websocket.go lines 178-193.
/// </summary>
private static (byte[] data, int newPos) WsGet(Stream stream, byte[] buf, int pos, int max, int needed)
{
int avail = max - pos;
if (avail >= needed)
return (buf[pos..(pos + needed)], pos + needed);
var b = new byte[needed];
int start = 0;
if (avail > 0)
{
Buffer.BlockCopy(buf, pos, b, 0, avail);
start = avail;
}
while (start < needed)
{
int n = stream.Read(b, start, needed - start);
if (n == 0) throw new IOException("unexpected end of stream");
start += n;
}
return (b, pos + avail);
}
}
public readonly record struct ControlFrameAction(int Opcode, byte[] Payload);

View File

@@ -1,268 +0,0 @@
using System.Net;
using System.Security.Cryptography;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
public static async Task<WsUpgradeResult> TryUpgradeAsync(
Stream inputStream, Stream outputStream, WebSocketOptions options,
CancellationToken ct = default)
{
try
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(options.HandshakeTimeout);
var (method, path, headers) = await ReadHttpRequestAsync(inputStream, cts.Token);
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
return await FailAsync(outputStream, 405, "request method must be GET");
if (!headers.ContainsKey("Host"))
return await FailAsync(outputStream, 400, "'Host' missing in request");
if (!HeaderContains(headers, "Upgrade", "websocket"))
return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");
if (!HeaderContains(headers, "Connection", "Upgrade"))
return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");
if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
return await FailAsync(outputStream, 400, "key missing");
if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
return await FailAsync(outputStream, 400, "invalid version");
var kind = path switch
{
_ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
_ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
_ => WsClientKind.Client,
};
// Origin checking
if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
{
var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
headers.TryGetValue("Origin", out var origin);
if (string.IsNullOrEmpty(origin))
headers.TryGetValue("Sec-WebSocket-Origin", out origin);
var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
if (originErr != null)
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
}
// Compression negotiation
bool compress = options.Compression;
if (compress)
{
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
}
// No-masking support (leaf nodes only — browser clients must always mask)
bool noMasking = kind == WsClientKind.Leaf &&
headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
// Browser detection
bool browser = false;
bool noCompFrag = false;
if (kind is WsClientKind.Client or WsClientKind.Mqtt &&
headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
{
browser = true;
// Disable fragmentation of compressed frames for Safari browsers.
// Safari has both "Version/" and "Safari/" in the user agent string,
// while Chrome on macOS has "Safari/" but not "Version/".
noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
}
// Cookie extraction
string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
headers.TryGetValue("Cookie", out var cookieHeader))
{
var cookies = ParseCookies(cookieHeader);
if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
}
// X-Forwarded-For client IP extraction
string? clientIp = null;
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
{
var ip = xff.Split(',')[0].Trim();
if (IPAddress.TryParse(ip, out _))
clientIp = ip;
}
// Build the 101 Switching Protocols response
var response = new StringBuilder();
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
response.Append(ComputeAcceptKey(key));
response.Append("\r\n");
if (compress)
response.Append(WsConstants.PmcFullResponse);
if (noMasking)
response.Append(WsConstants.NoMaskingFullResponse);
if (options.Headers != null)
{
foreach (var (k, v) in options.Headers)
{
response.Append(k);
response.Append(": ");
response.Append(v);
response.Append("\r\n");
}
}
response.Append("\r\n");
var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
await outputStream.WriteAsync(responseBytes);
await outputStream.FlushAsync();
return new WsUpgradeResult(
Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
MaskRead: !noMasking, MaskWrite: false,
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
CookiePassword: cookiePassword, CookieToken: cookieToken,
ClientIp: clientIp, Kind: kind);
}
catch (Exception)
{
return WsUpgradeResult.Failed;
}
}
/// <summary>
/// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
/// </summary>
public static string ComputeAcceptKey(string clientKey)
{
var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
var hash = SHA1.HashData(combined);
return Convert.ToBase64String(hash);
}
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
{
var statusText = statusCode switch
{
400 => "Bad Request",
403 => "Forbidden",
405 => "Method Not Allowed",
_ => "Internal Server Error",
};
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
await output.FlushAsync();
return WsUpgradeResult.Failed;
}
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(
Stream stream, CancellationToken ct)
{
var headerBytes = new List<byte>(4096);
var buf = new byte[512];
while (true)
{
int n = await stream.ReadAsync(buf, ct);
if (n == 0) throw new IOException("connection closed during handshake");
for (int i = 0; i < n; i++)
{
headerBytes.Add(buf[i]);
if (headerBytes.Count >= 4 &&
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
goto done;
if (headerBytes.Count > 8192)
throw new InvalidOperationException("HTTP header too large");
}
}
done:;
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
var lines = text.Split("\r\n", StringSplitOptions.None);
if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");
var parts = lines[0].Split(' ');
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
var method = parts[0];
var path = parts[1];
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
for (int i = 1; i < lines.Length; i++)
{
var line = lines[i];
if (string.IsNullOrEmpty(line)) break;
var colonIdx = line.IndexOf(':');
if (colonIdx > 0)
{
var name = line[..colonIdx].Trim();
var value = line[(colonIdx + 1)..].Trim();
headers[name] = value;
}
}
return (method, path, headers);
}
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
{
if (!headers.TryGetValue(name, out var headerValue))
return false;
foreach (var token in headerValue.Split(','))
{
if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
return true;
}
return false;
}
private static Dictionary<string, string> ParseCookies(string cookieHeader)
{
var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
foreach (var pair in cookieHeader.Split(';'))
{
var trimmed = pair.Trim();
var eqIdx = trimmed.IndexOf('=');
if (eqIdx > 0)
cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
}
return cookies;
}
}
/// <summary>
/// Result of a WebSocket upgrade handshake attempt.
/// </summary>
public readonly record struct WsUpgradeResult(
bool Success,
bool Compress,
bool Browser,
bool NoCompFrag,
bool MaskRead,
bool MaskWrite,
string? CookieJwt,
string? CookieUsername,
string? CookiePassword,
string? CookieToken,
string? ClientIp,
WsClientKind Kind)
{
public static readonly WsUpgradeResult Failed = new(
Success: false, Compress: false, Browser: false, NoCompFrag: false,
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
}

View File

@@ -0,0 +1,121 @@
using System.Text.Json;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server.Events;
namespace NATS.Server.Tests;
public class EventSystemTests
{
[Fact]
public void ConnectEventMsg_serializes_with_correct_type()
{
var evt = new ConnectEventMsg
{
Type = ConnectEventMsg.EventType,
Id = "test123",
Time = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc),
Server = new EventServerInfo { Name = "test-server", Id = "SRV1" },
Client = new EventClientInfo { Id = 1, Account = "$G" },
};
var json = JsonSerializer.Serialize(evt, EventJsonContext.Default.ConnectEventMsg);
json.ShouldContain("\"type\":\"io.nats.server.advisory.v1.client_connect\"");
json.ShouldContain("\"server\":");
json.ShouldContain("\"client\":");
}
[Fact]
public void DisconnectEventMsg_serializes_with_reason()
{
var evt = new DisconnectEventMsg
{
Type = DisconnectEventMsg.EventType,
Id = "test456",
Time = DateTime.UtcNow,
Server = new EventServerInfo { Name = "test-server", Id = "SRV1" },
Client = new EventClientInfo { Id = 2, Account = "myacc" },
Reason = "Client Closed",
Sent = new DataStats { Msgs = 10, Bytes = 1024 },
Received = new DataStats { Msgs = 5, Bytes = 512 },
};
var json = JsonSerializer.Serialize(evt, EventJsonContext.Default.DisconnectEventMsg);
json.ShouldContain("\"reason\":\"Client Closed\"");
}
[Fact]
public void ServerStatsMsg_serializes()
{
var evt = new ServerStatsMsg
{
Server = new EventServerInfo { Name = "srv1", Id = "ABC" },
Stats = new ServerStatsData
{
Connections = 10,
TotalConnections = 100,
InMsgs = 5000,
OutMsgs = 4500,
InBytes = 1_000_000,
OutBytes = 900_000,
Mem = 50 * 1024 * 1024,
Subscriptions = 42,
},
};
var json = JsonSerializer.Serialize(evt, EventJsonContext.Default.ServerStatsMsg);
json.ShouldContain("\"connections\":10");
json.ShouldContain("\"in_msgs\":5000");
}
[Fact]
public async Task InternalEventSystem_start_and_stop_lifecycle()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var eventSystem = server.EventSystem;
eventSystem.ShouldNotBeNull();
eventSystem.SystemClient.ShouldNotBeNull();
eventSystem.SystemClient.Kind.ShouldBe(ClientKind.System);
await server.ShutdownAsync();
}
[Fact]
public async Task SendInternalMsg_delivers_to_system_subscriber()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("test.subject", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
server.SendInternalMsg("test.subject", null, new { Value = "hello" });
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldBe("test.subject");
await server.ShutdownAsync();
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,338 @@
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Auth;
using NATS.Server.Imports;
using NATS.Server.Subscriptions;
namespace NATS.Server.Tests;
public class ImportExportTests
{
[Fact]
public void ExportAuth_public_export_authorizes_any_account()
{
var auth = new ExportAuth();
var account = new Account("test");
auth.IsAuthorized(account).ShouldBeTrue();
}
[Fact]
public void ExportAuth_approved_accounts_restricts_access()
{
var auth = new ExportAuth { ApprovedAccounts = ["allowed"] };
var allowed = new Account("allowed");
var denied = new Account("denied");
auth.IsAuthorized(allowed).ShouldBeTrue();
auth.IsAuthorized(denied).ShouldBeFalse();
}
[Fact]
public void ExportAuth_revoked_account_denied()
{
var auth = new ExportAuth
{
ApprovedAccounts = ["test"],
RevokedAccounts = new() { ["test"] = DateTimeOffset.UtcNow.ToUnixTimeSeconds() },
};
var account = new Account("test");
auth.IsAuthorized(account).ShouldBeFalse();
}
[Fact]
public void ServiceResponseType_defaults_to_singleton()
{
var import = new ServiceImport
{
DestinationAccount = new Account("dest"),
From = "requests.>",
To = "api.>",
};
import.ResponseType.ShouldBe(ServiceResponseType.Singleton);
}
[Fact]
public void ExportMap_stores_and_retrieves_exports()
{
var map = new ExportMap();
map.Services["api.>"] = new ServiceExport { Account = new Account("svc") };
map.Streams["events.>"] = new StreamExport();
map.Services.ShouldContainKey("api.>");
map.Streams.ShouldContainKey("events.>");
}
[Fact]
public void ImportMap_stores_service_imports()
{
var map = new ImportMap();
var si = new ServiceImport
{
DestinationAccount = new Account("dest"),
From = "requests.>",
To = "api.>",
};
map.AddServiceImport(si);
map.Services.ShouldContainKey("requests.>");
map.Services["requests.>"].Count.ShouldBe(1);
}
[Fact]
public void Account_add_service_export_and_import()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
exporter.Exports.Services.ShouldContainKey("api.>");
var si = importer.AddServiceImport(exporter, "requests.>", "api.>");
si.ShouldNotBeNull();
si.From.ShouldBe("requests.>");
si.To.ShouldBe("api.>");
si.DestinationAccount.ShouldBe(exporter);
importer.Imports.Services.ShouldContainKey("requests.>");
}
[Fact]
public void Account_add_stream_export_and_import()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddStreamExport("events.>", null);
exporter.Exports.Streams.ShouldContainKey("events.>");
importer.AddStreamImport(exporter, "events.>", "imported.events.>");
importer.Imports.Streams.Count.ShouldBe(1);
importer.Imports.Streams[0].From.ShouldBe("events.>");
importer.Imports.Streams[0].To.ShouldBe("imported.events.>");
}
[Fact]
public void Account_service_import_auth_rejected()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, [new Account("other")]);
Should.Throw<UnauthorizedAccessException>(() =>
importer.AddServiceImport(exporter, "requests.>", "api.>"));
}
[Fact]
public void Account_lazy_creates_internal_client()
{
var account = new Account("test");
var client = account.GetOrCreateInternalClient(99);
client.ShouldNotBeNull();
client.Kind.ShouldBe(ClientKind.Account);
client.Account.ShouldBe(account);
// Second call returns same instance
var client2 = account.GetOrCreateInternalClient(100);
client2.ShouldBeSameAs(client);
}
[Fact]
public async Task Service_import_forwards_message_to_export_account()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Set up exporter and importer accounts
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Wire the import subscriptions into the importer account
server.WireServiceImports(importer);
// Subscribe in exporter account to receive forwarded message
var exportSub = new Subscription { Subject = "api.test", Sid = "export-1", Client = null };
exporter.SubList.Insert(exportSub);
// Verify import infrastructure is wired: the importer should have service import entries
importer.Imports.Services.ShouldContainKey("requests.>");
importer.Imports.Services["requests.>"].Count.ShouldBe(1);
importer.Imports.Services["requests.>"][0].DestinationAccount.ShouldBe(exporter);
await server.ShutdownAsync();
}
[Fact]
public void ProcessServiceImport_delivers_to_destination_account_subscribers()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Add a subscriber in the exporter account's SubList
var received = new List<(string Subject, string Sid)>();
var mockClient = new TestNatsClient(1, exporter);
mockClient.OnMessage = (subject, sid, _, _, _) =>
received.Add((subject, sid));
var exportSub = new Subscription { Subject = "api.test", Sid = "s1", Client = mockClient };
exporter.SubList.Insert(exportSub);
// Process a service import directly
var si = importer.Imports.Services["requests.>"][0];
server.ProcessServiceImport(si, "requests.test", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
received.Count.ShouldBe(1);
received[0].Subject.ShouldBe("api.test");
received[0].Sid.ShouldBe("s1");
}
[Fact]
public void ProcessServiceImport_with_transform_applies_subject_mapping()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
var si = importer.AddServiceImport(exporter, "requests.>", "api.>");
// Create a transform from requests.> to api.>
var transform = SubjectTransform.Create("requests.>", "api.>");
transform.ShouldNotBeNull();
// Create a new import with the transform set
var siWithTransform = new ServiceImport
{
DestinationAccount = exporter,
From = "requests.>",
To = "api.>",
Transform = transform,
};
var received = new List<string>();
var mockClient = new TestNatsClient(1, exporter);
mockClient.OnMessage = (subject, _, _, _, _) =>
received.Add(subject);
var exportSub = new Subscription { Subject = "api.hello", Sid = "s1", Client = mockClient };
exporter.SubList.Insert(exportSub);
server.ProcessServiceImport(siWithTransform, "requests.hello", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
received.Count.ShouldBe(1);
received[0].ShouldBe("api.hello");
}
[Fact]
public void ProcessServiceImport_skips_invalid_imports()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Mark the import as invalid
var si = importer.Imports.Services["requests.>"][0];
si.Invalid = true;
// Add a subscriber in the exporter account
var received = new List<string>();
var mockClient = new TestNatsClient(1, exporter);
mockClient.OnMessage = (subject, _, _, _, _) =>
received.Add(subject);
var exportSub = new Subscription { Subject = "api.test", Sid = "s1", Client = mockClient };
exporter.SubList.Insert(exportSub);
// ProcessServiceImport should be a no-op for invalid imports
server.ProcessServiceImport(si, "requests.test", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
received.Count.ShouldBe(0);
}
[Fact]
public void ProcessServiceImport_delivers_to_queue_groups()
{
using var server = CreateTestServer();
var exporter = server.GetOrCreateAccount("exporter");
var importer = server.GetOrCreateAccount("importer");
exporter.AddServiceExport("api.>", ServiceResponseType.Singleton, null);
importer.AddServiceImport(exporter, "requests.>", "api.>");
// Add queue group subscribers in the exporter account
var received = new List<(string Subject, string Sid)>();
var mockClient1 = new TestNatsClient(1, exporter);
mockClient1.OnMessage = (subject, sid, _, _, _) =>
received.Add((subject, sid));
var mockClient2 = new TestNatsClient(2, exporter);
mockClient2.OnMessage = (subject, sid, _, _, _) =>
received.Add((subject, sid));
var qSub1 = new Subscription { Subject = "api.test", Sid = "q1", Queue = "workers", Client = mockClient1 };
var qSub2 = new Subscription { Subject = "api.test", Sid = "q2", Queue = "workers", Client = mockClient2 };
exporter.SubList.Insert(qSub1);
exporter.SubList.Insert(qSub2);
var si = importer.Imports.Services["requests.>"][0];
server.ProcessServiceImport(si, "requests.test", null,
ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
// One member of the queue group should receive the message
received.Count.ShouldBe(1);
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
/// <summary>
/// Minimal test double for INatsClient used in import/export tests.
/// </summary>
private sealed class TestNatsClient(ulong id, Account account) : INatsClient
{
public ulong Id => id;
public ClientKind Kind => ClientKind.Client;
public Account? Account => account;
public Protocol.ClientOptions? ClientOpts => null;
public ClientPermissions? Permissions => null;
public Action<string, string, string?, ReadOnlyMemory<byte>, ReadOnlyMemory<byte>>? OnMessage { get; set; }
public void SendMessage(string subject, string sid, string? replyTo,
ReadOnlyMemory<byte> headers, ReadOnlyMemory<byte> payload)
{
OnMessage?.Invoke(subject, sid, replyTo, headers, payload);
}
public bool QueueOutbound(ReadOnlyMemory<byte> data) => true;
public void RemoveSubscription(string sid) { }
}
}

View File

@@ -0,0 +1,85 @@
using NATS.Server.Auth;
namespace NATS.Server.Tests;
public class InternalClientTests
{
[Theory]
[InlineData(ClientKind.Client, false)]
[InlineData(ClientKind.Router, false)]
[InlineData(ClientKind.Gateway, false)]
[InlineData(ClientKind.Leaf, false)]
[InlineData(ClientKind.System, true)]
[InlineData(ClientKind.JetStream, true)]
[InlineData(ClientKind.Account, true)]
public void IsInternal_returns_correct_value(ClientKind kind, bool expected)
{
kind.IsInternal().ShouldBe(expected);
}
[Fact]
public void NatsClient_implements_INatsClient()
{
typeof(NatsClient).GetInterfaces().ShouldContain(typeof(INatsClient));
}
[Fact]
public void NatsClient_kind_is_Client()
{
typeof(NatsClient).GetProperty("Kind")!.PropertyType.ShouldBe(typeof(ClientKind));
}
[Fact]
public void InternalClient_system_kind()
{
var account = new Account("$SYS");
var client = new InternalClient(1, ClientKind.System, account);
client.Kind.ShouldBe(ClientKind.System);
client.IsInternal.ShouldBeTrue();
client.Id.ShouldBe(1UL);
client.Account.ShouldBe(account);
}
[Fact]
public void InternalClient_account_kind()
{
var account = new Account("myaccount");
var client = new InternalClient(2, ClientKind.Account, account);
client.Kind.ShouldBe(ClientKind.Account);
client.IsInternal.ShouldBeTrue();
}
[Fact]
public void InternalClient_rejects_non_internal_kind()
{
var account = new Account("test");
Should.Throw<ArgumentException>(() => new InternalClient(1, ClientKind.Client, account));
}
[Fact]
public void InternalClient_SendMessage_invokes_callback()
{
var account = new Account("$SYS");
var client = new InternalClient(1, ClientKind.System, account);
string? capturedSubject = null;
string? capturedSid = null;
client.MessageCallback = (subject, sid, replyTo, headers, payload) =>
{
capturedSubject = subject;
capturedSid = sid;
};
client.SendMessage("test.subject", "1", null, ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty);
capturedSubject.ShouldBe("test.subject");
capturedSid.ShouldBe("1");
}
[Fact]
public void InternalClient_QueueOutbound_returns_true_noop()
{
var account = new Account("$SYS");
var client = new InternalClient(1, ClientKind.System, account);
client.QueueOutbound(ReadOnlyMemory<byte>.Empty).ShouldBeTrue();
}
}

View File

@@ -0,0 +1,149 @@
using NATS.Server.Auth;
using NATS.Server.Imports;
namespace NATS.Server.Tests;
public class ResponseRoutingTests
{
[Fact]
public void GenerateReplyPrefix_creates_unique_prefix()
{
var prefix1 = ResponseRouter.GenerateReplyPrefix();
var prefix2 = ResponseRouter.GenerateReplyPrefix();
prefix1.ShouldStartWith("_R_.");
prefix2.ShouldStartWith("_R_.");
prefix1.ShouldNotBe(prefix2);
prefix1.Length.ShouldBeGreaterThan(4);
}
[Fact]
public void GenerateReplyPrefix_ends_with_dot()
{
var prefix = ResponseRouter.GenerateReplyPrefix();
prefix.ShouldEndWith(".");
// Format: "_R_." + 10 chars + "." = 15 chars
prefix.Length.ShouldBe(15);
}
[Fact]
public void Singleton_response_import_removed_after_delivery()
{
var exporter = new Account("exporter");
exporter.AddServiceExport("api.test", ServiceResponseType.Singleton, null);
var replyPrefix = ResponseRouter.GenerateReplyPrefix();
var responseSi = new ServiceImport
{
DestinationAccount = exporter,
From = replyPrefix + ">",
To = "_INBOX.original.reply",
IsResponse = true,
ResponseType = ServiceResponseType.Singleton,
};
exporter.Exports.Responses[replyPrefix] = responseSi;
exporter.Exports.Responses.ShouldContainKey(replyPrefix);
// Simulate singleton delivery cleanup
ResponseRouter.CleanupResponse(exporter, replyPrefix, responseSi);
exporter.Exports.Responses.ShouldNotContainKey(replyPrefix);
}
[Fact]
public void CreateResponseImport_registers_in_exporter_responses()
{
var exporter = new Account("exporter");
var importer = new Account("importer");
exporter.AddServiceExport("api.test", ServiceResponseType.Singleton, null);
var originalSi = new ServiceImport
{
DestinationAccount = exporter,
From = "api.test",
To = "api.test",
Export = exporter.Exports.Services["api.test"],
ResponseType = ServiceResponseType.Singleton,
};
var responseSi = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.abc123");
responseSi.IsResponse.ShouldBeTrue();
responseSi.ResponseType.ShouldBe(ServiceResponseType.Singleton);
responseSi.To.ShouldBe("_INBOX.abc123");
responseSi.DestinationAccount.ShouldBe(exporter);
responseSi.From.ShouldEndWith(">");
responseSi.Export.ShouldBe(originalSi.Export);
// Should be registered in the exporter's response map
exporter.Exports.Responses.Count.ShouldBe(1);
}
[Fact]
public void CreateResponseImport_preserves_streamed_response_type()
{
var exporter = new Account("exporter");
exporter.AddServiceExport("api.stream", ServiceResponseType.Streamed, null);
var originalSi = new ServiceImport
{
DestinationAccount = exporter,
From = "api.stream",
To = "api.stream",
Export = exporter.Exports.Services["api.stream"],
ResponseType = ServiceResponseType.Streamed,
};
var responseSi = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.xyz789");
responseSi.ResponseType.ShouldBe(ServiceResponseType.Streamed);
}
[Fact]
public void Multiple_response_imports_each_get_unique_prefix()
{
var exporter = new Account("exporter");
exporter.AddServiceExport("api.test", ServiceResponseType.Singleton, null);
var originalSi = new ServiceImport
{
DestinationAccount = exporter,
From = "api.test",
To = "api.test",
Export = exporter.Exports.Services["api.test"],
ResponseType = ServiceResponseType.Singleton,
};
var resp1 = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.reply1");
var resp2 = ResponseRouter.CreateResponseImport(exporter, originalSi, "_INBOX.reply2");
exporter.Exports.Responses.Count.ShouldBe(2);
resp1.To.ShouldBe("_INBOX.reply1");
resp2.To.ShouldBe("_INBOX.reply2");
resp1.From.ShouldNotBe(resp2.From);
}
[Fact]
public void LatencyTracker_should_sample_respects_percentage()
{
var latency = new ServiceLatency { SamplingPercentage = 0, Subject = "latency.test" };
LatencyTracker.ShouldSample(latency).ShouldBeFalse();
var latency100 = new ServiceLatency { SamplingPercentage = 100, Subject = "latency.test" };
LatencyTracker.ShouldSample(latency100).ShouldBeTrue();
}
[Fact]
public void LatencyTracker_builds_latency_message()
{
var msg = LatencyTracker.BuildLatencyMsg("requester", "responder",
TimeSpan.FromMilliseconds(5), TimeSpan.FromMilliseconds(10));
msg.Requestor.ShouldBe("requester");
msg.Responder.ShouldBe("responder");
msg.ServiceLatencyNanos.ShouldBeGreaterThan(0);
msg.TotalLatencyNanos.ShouldBeGreaterThan(0);
}
}

View File

@@ -0,0 +1,133 @@
using System.Text.Json;
using NATS.Server;
using NATS.Server.Events;
using Microsoft.Extensions.Logging.Abstractions;
namespace NATS.Server.Tests;
public class SystemEventsTests
{
[Fact]
public async Task Server_publishes_connect_event_on_client_auth()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.ACCOUNT.*.CONNECT", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
// Connect a real client
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
await sock.ConnectAsync(System.Net.IPAddress.Loopback, server.Port);
// Read INFO
var buf = new byte[4096];
await sock.ReceiveAsync(buf);
// Send CONNECT
var connect = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
await sock.SendAsync(connect);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldStartWith("$SYS.ACCOUNT.");
result.ShouldEndWith(".CONNECT");
await server.ShutdownAsync();
}
[Fact]
public async Task Server_publishes_disconnect_event_on_client_close()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.ACCOUNT.*.DISCONNECT", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
// Connect and then disconnect
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
await sock.ConnectAsync(System.Net.IPAddress.Loopback, server.Port);
var buf = new byte[4096];
await sock.ReceiveAsync(buf);
await sock.SendAsync(System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n"));
await Task.Delay(100);
sock.Shutdown(System.Net.Sockets.SocketShutdown.Both);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldStartWith("$SYS.ACCOUNT.");
result.ShouldEndWith(".DISCONNECT");
await server.ShutdownAsync();
}
[Fact]
public async Task Server_publishes_statsz_periodically()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.SERVER.*.STATSZ", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
// Trigger a manual stats publish (don't wait 10s)
server.EventSystem!.PublishServerStats();
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldContain(".STATSZ");
await server.ShutdownAsync();
}
[Fact]
public async Task Server_publishes_shutdown_event()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<string>();
server.EventSystem!.SysSubscribe("$SYS.SERVER.*.SHUTDOWN", (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(subject);
});
await server.ShutdownAsync();
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
result.ShouldContain(".SHUTDOWN");
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,170 @@
using System.Text;
using System.Text.Json;
using NATS.Server;
using NATS.Server.Events;
using Microsoft.Extensions.Logging.Abstractions;
namespace NATS.Server.Tests;
public class SystemRequestReplyTests
{
[Fact]
public async Task Varz_request_reply_returns_server_info()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "VARZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"server_id\"");
json.ShouldContain("\"version\"");
json.ShouldContain("\"host\"");
json.ShouldContain("\"port\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Healthz_request_reply_returns_ok()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "HEALTHZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("ok");
await server.ShutdownAsync();
}
[Fact]
public async Task Subsz_request_reply_returns_subscription_count()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "SUBSZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"num_subscriptions\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Idz_request_reply_returns_server_identity()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "IDZ");
server.SendInternalMsg(reqSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"server_id\"");
json.ShouldContain("\"server_name\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Ping_varz_responds_via_wildcard_subject()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var received = new TaskCompletionSource<byte[]>();
var replySubject = $"_INBOX.test.{Guid.NewGuid():N}";
server.EventSystem!.SysSubscribe(replySubject, (sub, client, acc, subject, reply, hdr, msg) =>
{
received.TrySetResult(msg.ToArray());
});
var pingSubject = string.Format(EventSubjects.ServerPing, "VARZ");
server.SendInternalMsg(pingSubject, replySubject, null);
var result = await received.Task.WaitAsync(TimeSpan.FromSeconds(5));
var json = Encoding.UTF8.GetString(result);
json.ShouldContain("\"server_id\"");
await server.ShutdownAsync();
}
[Fact]
public async Task Request_without_reply_is_ignored()
{
using var server = CreateTestServer();
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Send a request with no reply subject -- should not crash
var reqSubject = string.Format(EventSubjects.ServerReq, server.ServerId, "VARZ");
server.SendInternalMsg(reqSubject, null, null);
// Give it a moment to process without error
await Task.Delay(200);
// Server should still be running
server.IsShuttingDown.ShouldBeFalse();
await server.ShutdownAsync();
}
private static NatsServer CreateTestServer()
{
var port = GetFreePort();
return new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
}
private static int GetFreePort()
{
using var sock = new System.Net.Sockets.Socket(
System.Net.Sockets.AddressFamily.InterNetwork,
System.Net.Sockets.SocketType.Stream,
System.Net.Sockets.ProtocolType.Tcp);
sock.Bind(new System.Net.IPEndPoint(System.Net.IPAddress.Loopback, 0));
return ((System.Net.IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -1,26 +0,0 @@
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WebSocketOptionsTests
{
[Fact]
public void DefaultOptions_PortIsNegativeOne_Disabled()
{
var opts = new WebSocketOptions();
opts.Port.ShouldBe(-1);
opts.Host.ShouldBe("0.0.0.0");
opts.Compression.ShouldBeFalse();
opts.NoTls.ShouldBeFalse();
opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
}
[Fact]
public void NatsOptions_HasWebSocketProperty()
{
var opts = new NatsOptions();
opts.WebSocket.ShouldNotBeNull();
opts.WebSocket.Port.ShouldBe(-1);
}
}

View File

@@ -1,58 +0,0 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsCompressionTests
{
[Fact]
public void CompressDecompress_RoundTrip()
{
var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
var compressed = WsCompression.Compress(original);
compressed.ShouldNotBeNull();
compressed.Length.ShouldBeGreaterThan(0);
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
decompressed.ShouldBe(original);
}
[Fact]
public void Decompress_ExceedsMaxPayload_Throws()
{
var original = new byte[1000];
Random.Shared.NextBytes(original);
var compressed = WsCompression.Compress(original);
Should.Throw<InvalidOperationException>(() =>
WsCompression.Decompress([compressed], maxPayload: 100));
}
[Fact]
public void Compress_RemovesTrailing4Bytes()
{
var data = new byte[200];
Random.Shared.NextBytes(data);
var compressed = WsCompression.Compress(data);
// The compressed data should be valid for decompression when we add the trailer back
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
decompressed.ShouldBe(data);
}
[Fact]
public void Decompress_MultipleBuffers()
{
var original = new byte[500];
Random.Shared.NextBytes(original);
var compressed = WsCompression.Compress(original);
// Split compressed data into multiple chunks
int mid = compressed.Length / 2;
var chunk1 = compressed[..mid];
var chunk2 = compressed[mid..];
var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
decompressed.ShouldBe(original);
}
}

View File

@@ -1,124 +0,0 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsConnectionTests
{
[Fact]
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
{
var payload = "SUB test 1\r\n"u8.ToArray();
var frame = BuildUnmaskedFrame(payload);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(payload.Length);
buf[..n].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_FramesPayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// First 2 bytes should be WS frame header
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
int len = written[1] & 0x7F;
len.ShouldBe(payload.Length);
written[2..].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_WithCompression_CompressesLargePayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = new byte[200];
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should be set for compressed frame
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
// Compressed size should be less than original
written.Length.ShouldBeLessThan(payload.Length + 10);
}
[Fact]
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should NOT be set for small payloads
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
}
[Fact]
public async Task ReadAsync_DecodesMaskedFrame()
{
var payload = "CONNECT {}\r\n"u8.ToArray();
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
var maskKey = header[^4..];
WsFrameWriter.MaskBuf(maskKey, payload);
var frame = new byte[header.Length + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, header.Length);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe("CONNECT {}\r\n".Length);
System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n");
}
[Fact]
public async Task ReadAsync_ReturnsZero_OnEndOfStream()
{
// Empty stream should return 0 (true end of stream)
var inner = new MemoryStream([]);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(0);
}
private static byte[] BuildUnmaskedFrame(byte[] payload)
{
var header = new byte[2];
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
header[1] = (byte)payload.Length;
var frame = new byte[2 + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, 2);
return frame;
}
}

View File

@@ -1,53 +0,0 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsConstantsTests
{
[Fact]
public void OpCodes_MatchRfc6455()
{
WsConstants.TextMessage.ShouldBe(1);
WsConstants.BinaryMessage.ShouldBe(2);
WsConstants.CloseMessage.ShouldBe(8);
WsConstants.PingMessage.ShouldBe(9);
WsConstants.PongMessage.ShouldBe(10);
}
[Fact]
public void FrameBits_MatchRfc6455()
{
WsConstants.FinalBit.ShouldBe((byte)0x80);
WsConstants.Rsv1Bit.ShouldBe((byte)0x40);
WsConstants.MaskBit.ShouldBe((byte)0x80);
}
[Fact]
public void CloseStatusCodes_MatchRfc6455()
{
WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
WsConstants.CloseStatusGoingAway.ShouldBe(1001);
WsConstants.CloseStatusProtocolError.ShouldBe(1002);
WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
}
[Theory]
[InlineData(WsConstants.CloseMessage)]
[InlineData(WsConstants.PingMessage)]
[InlineData(WsConstants.PongMessage)]
public void IsControlFrame_True(int opcode)
{
WsConstants.IsControlFrame(opcode).ShouldBeTrue();
}
[Theory]
[InlineData(WsConstants.TextMessage)]
[InlineData(WsConstants.BinaryMessage)]
[InlineData(0)]
public void IsControlFrame_False(int opcode)
{
WsConstants.IsControlFrame(opcode).ShouldBeFalse();
}
}

View File

@@ -1,163 +0,0 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsFrameReadTests
{
/// <summary>Helper: build a single unmasked binary frame.</summary>
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
{
int payloadLen = payload.Length;
byte b0 = (byte)opcode;
if (fin) b0 |= WsConstants.FinalBit;
if (compressed) b0 |= WsConstants.Rsv1Bit;
byte b1 = 0;
if (mask) b1 |= WsConstants.MaskBit;
byte[] lenBytes;
if (payloadLen <= 125)
{
lenBytes = [(byte)(b1 | (byte)payloadLen)];
}
else if (payloadLen < 65536)
{
lenBytes = new byte[3];
lenBytes[0] = (byte)(b1 | 126);
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
}
else
{
lenBytes = new byte[9];
lenBytes[0] = (byte)(b1 | 127);
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
}
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
var frame = new byte[totalLen];
frame[0] = b0;
lenBytes.CopyTo(frame.AsSpan(1));
int pos = 1 + lenBytes.Length;
if (mask && maskKey != null)
{
maskKey.CopyTo(frame.AsSpan(pos));
pos += 4;
var maskedPayload = payload.ToArray();
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
maskedPayload.CopyTo(frame.AsSpan(pos));
}
else
{
payload.CopyTo(frame.AsSpan(pos));
}
return frame;
}
[Fact]
public void ReadSingleUnmaskedFrame()
{
var payload = "Hello"u8.ToArray();
var frame = BuildFrame(payload);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void ReadMaskedFrame()
{
var payload = "Hello"u8.ToArray();
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
var frame = BuildFrame(payload, mask: true, maskKey: key);
var readInfo = new WsReadInfo(expectMask: true);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void Read16BitLengthFrame()
{
var payload = new byte[200];
Random.Shared.NextBytes(payload);
var frame = BuildFrame(payload);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void ReadPingFrame_ReturnsPongAction()
{
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0); // control frames don't produce payload
readInfo.PendingControlFrames.Count.ShouldBe(1);
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
}
[Fact]
public void ReadCloseFrame_ReturnsCloseAction()
{
var closePayload = new byte[2];
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0);
readInfo.CloseReceived.ShouldBeTrue();
readInfo.CloseStatus.ShouldBe(1000);
}
[Fact]
public void ReadPongFrame_NoAction()
{
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0);
readInfo.PendingControlFrames.Count.ShouldBe(0);
}
[Fact]
public void Unmask_Optimized_8ByteChunks()
{
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
var original = new byte[32];
Random.Shared.NextBytes(original);
var masked = original.ToArray();
// Mask it
for (int i = 0; i < masked.Length; i++)
masked[i] ^= key[i & 3];
// Unmask using the state machine
var info = new WsReadInfo(expectMask: true);
info.SetMaskKey(key);
info.Unmask(masked);
masked.ShouldBe(original);
}
}

View File

@@ -1,152 +0,0 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsFrameWriterTests
{
[Fact]
public void CreateFrameHeader_SmallPayload_7BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 100);
header.Length.ShouldBe(2);
(header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
(header[1] & 0x7F).ShouldBe(100);
}
[Fact]
public void CreateFrameHeader_MediumPayload_16BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
header.Length.ShouldBe(4);
(header[1] & 0x7F).ShouldBe(126);
BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
}
[Fact]
public void CreateFrameHeader_LargePayload_64BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 70000);
header.Length.ShouldBe(10);
(header[1] & 0x7F).ShouldBe(127);
BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
}
[Fact]
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
{
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 10);
header.Length.ShouldBe(6); // 2 header + 4 mask key
(header[1] & WsConstants.MaskBit).ShouldNotBe(0);
key.ShouldNotBeNull();
key.Length.ShouldBe(4);
}
[Fact]
public void CreateFrameHeader_Compressed_SetsRsv1Bit()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: true,
opcode: WsConstants.BinaryMessage, payloadLength: 10);
(header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
}
[Fact]
public void MaskBuf_XorsCorrectly()
{
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
byte[] expected = new byte[data.Length];
for (int i = 0; i < data.Length; i++)
expected[i] = (byte)(data[i] ^ key[i & 3]);
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(expected);
}
[Fact]
public void MaskBuf_RoundTrip()
{
byte[] key = [0x12, 0x34, 0x56, 0x78];
byte[] original = "Hello, WebSocket!"u8.ToArray();
var data = original.ToArray();
WsFrameWriter.MaskBuf(key, data);
data.ShouldNotBe(original);
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(original);
}
[Fact]
public void CreateCloseMessage_WithStatusAndBody()
{
var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
msg.Length.ShouldBe(2 + "normal closure".Length);
BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
}
[Fact]
public void CreateCloseMessage_LongBody_Truncated()
{
var longBody = new string('x', 200);
var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
}
[Fact]
public void MapCloseStatus_ClientClosed_NormalClosure()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
.ShouldBe(WsConstants.CloseStatusNormalClosure);
}
[Fact]
public void MapCloseStatus_AuthTimeout_PolicyViolation()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
}
[Fact]
public void MapCloseStatus_ParseError_ProtocolError()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
.ShouldBe(WsConstants.CloseStatusProtocolError);
}
[Fact]
public void MapCloseStatus_MaxPayload_MessageTooBig()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
}
[Fact]
public void BuildControlFrame_PingNomask()
{
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
frame.Length.ShouldBe(2);
(frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
(frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
(frame[1] & 0x7F).ShouldBe(0);
}
[Fact]
public void BuildControlFrame_PongWithPayload()
{
byte[] payload = [1, 2, 3, 4];
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
frame.Length.ShouldBe(2 + 4);
frame[2..].ShouldBe(payload);
}
}

View File

@@ -1,162 +0,0 @@
using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsIntegrationTests : IAsyncLifetime
{
private NatsServer _server = null!;
private NatsOptions _options = null!;
public async Task InitializeAsync()
{
_options = new NatsOptions
{
Port = 0,
WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
};
var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
_server = new NatsServer(_options, loggerFactory);
_ = _server.StartAsync(CancellationToken.None);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _server.ShutdownAsync();
_server.Dispose();
}
[Fact]
public async Task WebSocket_ConnectAndReceiveInfo()
{
using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
using var stream = new NetworkStream(socket, ownsSocket: false);
await SendUpgradeRequest(stream);
var response = await ReadHttpResponse(stream);
response.ShouldContain("101");
var wsFrame = await ReadWsFrame(stream);
var info = Encoding.ASCII.GetString(wsFrame);
info.ShouldStartWith("INFO ");
}
[Fact]
public async Task WebSocket_ConnectAndPing()
{
using var client = await ConnectWsClient();
// Send CONNECT and PING together
await SendWsText(client, "CONNECT {}\r\nPING\r\n");
// Read PONG WS frame
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var pong = await ReadWsFrameAsync(client, cts.Token);
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
}
[Fact]
public async Task WebSocket_PubSub()
{
using var sub = await ConnectWsClient();
using var pub = await ConnectWsClient();
await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\n");
await Task.Delay(200);
await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var msg = await ReadWsFrameAsync(sub, cts.Token);
Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
}
private async Task<NetworkStream> ConnectWsClient()
{
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
var stream = new NetworkStream(socket, ownsSocket: true);
await SendUpgradeRequest(stream);
var response = await ReadHttpResponse(stream);
response.ShouldContain("101");
await ReadWsFrame(stream); // Read INFO frame
return stream;
}
private static async Task SendUpgradeRequest(NetworkStream stream)
{
var keyBytes = new byte[16];
RandomNumberGenerator.Fill(keyBytes);
var key = Convert.ToBase64String(keyBytes);
var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
await stream.FlushAsync();
}
private static async Task<string> ReadHttpResponse(NetworkStream stream)
{
// Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response
var sb = new StringBuilder();
var buf = new byte[1];
while (true)
{
int n = await stream.ReadAsync(buf);
if (n == 0) break;
sb.Append((char)buf[0]);
if (sb.Length >= 4 &&
sb[^4] == '\r' && sb[^3] == '\n' &&
sb[^2] == '\r' && sb[^1] == '\n')
break;
}
return sb.ToString();
}
private static Task<byte[]> ReadWsFrame(NetworkStream stream)
=> ReadWsFrameAsync(stream, CancellationToken.None);
private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
{
var header = new byte[2];
await stream.ReadExactlyAsync(header, ct);
int len = header[1] & 0x7F;
if (len == 126)
{
var extLen = new byte[2];
await stream.ReadExactlyAsync(extLen, ct);
len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
}
else if (len == 127)
{
var extLen = new byte[8];
await stream.ReadExactlyAsync(extLen, ct);
len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
}
var payload = new byte[len];
if (len > 0) await stream.ReadExactlyAsync(payload, ct);
return payload;
}
private static async Task SendWsText(NetworkStream stream, string text)
{
var payload = Encoding.ASCII.GetBytes(text);
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
var maskKey = header[^4..];
WsFrameWriter.MaskBuf(maskKey, payload);
await stream.WriteAsync(header);
await stream.WriteAsync(payload);
await stream.FlushAsync();
}
}

View File

@@ -1,82 +0,0 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Tests.WebSocket;
public class WsOriginCheckerTests
{
[Fact]
public void NoOriginHeader_Accepted()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
.ShouldBeNull();
}
[Fact]
public void NeitherSameNorList_AlwaysAccepted()
{
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
checker.CheckOrigin("https://evil.com", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_Match()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_Mismatch()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://other:4222", "localhost:4222", false)
.ShouldNotBeNull();
}
[Fact]
public void SameOrigin_DefaultPort_Http()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost", "localhost:80", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_DefaultPort_Https()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("https://localhost", "localhost:443", true)
.ShouldBeNull();
}
[Fact]
public void AllowedOrigins_Match()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void AllowedOrigins_Mismatch()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
.ShouldNotBeNull();
}
[Fact]
public void AllowedOrigins_SchemeMismatch()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
.ShouldNotBeNull();
}
}

View File

@@ -1,226 +0,0 @@
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsUpgradeTests
{
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
[Fact]
public async Task ValidUpgrade_Returns101()
{
var request = BuildValidRequest();
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Client);
var response = ReadResponse(outputStream);
response.ShouldContain("HTTP/1.1 101");
response.ShouldContain("Upgrade: websocket");
response.ShouldContain("Sec-WebSocket-Accept:");
}
[Fact]
public async Task MissingUpgradeHeader_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("400");
}
[Fact]
public async Task MissingHost_Returns400()
{
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task WrongVersion_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task LeafNodePath_ReturnsLeafKind()
{
var request = BuildValidRequest("/leafnode");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Leaf);
}
[Fact]
public async Task MqttPath_ReturnsMqttKind()
{
var request = BuildValidRequest("/mqtt");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Mqtt);
}
[Fact]
public async Task CompressionNegotiation_WhenEnabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeTrue();
ReadResponse(outputStream).ShouldContain("permessage-deflate");
}
[Fact]
public async Task CompressionNegotiation_WhenDisabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeFalse();
}
[Fact]
public async Task NoMaskingHeader_ForLeaf()
{
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.MaskRead.ShouldBeFalse();
}
[Fact]
public async Task BrowserDetection_Mozilla()
{
var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Browser.ShouldBeTrue();
}
[Fact]
public async Task SafariDetection_NoCompFrag()
{
var request = BuildValidRequest(extraHeaders:
"User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
$"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.NoCompFrag.ShouldBeTrue();
}
[Fact]
public void AcceptKey_MatchesRfc6455Example()
{
// RFC 6455 Section 4.2.2 example
var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
[Fact]
public async Task CookieExtraction()
{
var request = BuildValidRequest(extraHeaders:
"Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions
{
NoTls = true,
JwtCookie = "jwt_token",
UsernameCookie = "nats_user",
PasswordCookie = "nats_pass",
};
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.CookieJwt.ShouldBe("my-jwt");
result.CookieUsername.ShouldBe("admin");
result.CookiePassword.ShouldBe("secret");
}
[Fact]
public async Task XForwardedFor_ExtractsClientIp()
{
var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.ClientIp.ShouldBe("192.168.1.100");
}
[Fact]
public async Task PostMethod_Returns405()
{
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("405");
}
// Helper: create a readable input stream and writable output stream
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
private static string ReadResponse(MemoryStream output)
{
output.Position = 0;
return Encoding.ASCII.GetString(output.ToArray());
}
}