Merge branch 'feature/core-lifecycle' into main

Reconcile close reason tracking: feature branch's MarkClosed() and
ShouldSkipFlush/FlushAndCloseAsync now use main's ClientClosedReason
enum. ClosedState enum retained for forward compatibility.
Joseph Doherty
2026-02-23 00:09:30 -05:00
12 changed files with 2745 additions and 18 deletions

@@ -0,0 +1,16 @@
{
"hooks": {
"PostToolUse": [
{
"matcher": "Write|Edit|MultiEdit",
"hooks": [
{
"type": "command",
"command": "slopwatch analyze -d . --hook",
"timeout": 60000
}
]
}
]
}
}

@@ -0,0 +1,10 @@
{
"suppressions": [
{
"ruleId": "SW002",
"pattern": "**/Generated/**",
"justification": "Generated code from protobuf/gRPC compiler - cannot be modified"
}
],
"globalSuppressions": []
}

@@ -80,7 +80,7 @@
| Write deadline / timeout policies | Y | Y | `WriteDeadline` option with `CancellationTokenSource.CancelAfter` on flush | | Write deadline / timeout policies | Y | Y | `WriteDeadline` option with `CancellationTokenSource.CancelAfter` on flush |
| RTT measurement | Y | N | Go tracks round-trip time per client | | RTT measurement | Y | N | Go tracks round-trip time per client |
| Per-client trace mode | Y | N | | | Per-client trace mode | Y | N | |
| Detailed close reason tracking | Y | Y | 17-value `ClientClosedReason` enum (single-server subset of Go's 37) | | Detailed close reason tracking | Y | Y | 37-value `ClosedState` enum with CAS-based `MarkClosed()` |
| Connection state flags (16 flags) | Y | Y | 7-flag `ClientFlagHolder` with `Interlocked.Or`/`And` | | Connection state flags (16 flags) | Y | Y | 7-flag `ClientFlagHolder` with `Interlocked.Or`/`And` |
### Slow Consumer Handling ### Slow Consumer Handling

@@ -0,0 +1,199 @@
# Section 2: Client/Connection Handling — Design
> Implements all in-scope gaps from differences.md Section 2.
## Scope
8 features, all single-server client-facing (no clustering/routes/gateways/leaf):
1. Close reason tracking (`ClientClosedReason` enum)
2. Connection state flags (bitfield replacing `_connectReceived`)
3. Channel-based write loop with batch flush
4. Slow consumer detection (pending bytes + write deadline)
5. Write deadline / timeout
6. Verbose mode (`+OK` responses)
7. No-responders validation and notification
8. Per-read-cycle stat batching
## A. Close Reasons
New `ClientClosedReason` enum with 16 close reasons scoped to single-server operation:
```
ClientClosed, AuthenticationTimeout, AuthenticationViolation, TLSHandshakeError,
SlowConsumerPendingBytes, SlowConsumerWriteDeadline, WriteError, ReadError,
ParseError, StaleConnection, ProtocolViolation, MaxPayloadExceeded,
MaxSubscriptionsExceeded, ServerShutdown, MsgHeaderViolation, NoRespondersRequiresHeaders
```
Go's `ClosedState` has 37 values; the route, gateway, leaf-node, JWT, and operator-mode values are excluded.
Per-client `CloseReason` property set before closing. Available in monitoring (`/connz`).
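One way to make "the first close reason wins" hold under concurrent callers is a compare-and-swap on an `int` backing field. The following is a minimal sketch under that assumption, not the code in this change: `CloseReasonHolder`, `TrySet`, and the trimmed `Reason` enum are hypothetical names used only for illustration.

```csharp
using System.Threading;

// Hypothetical, trimmed stand-in for the close reason enum; 0 means "not closed yet".
public enum Reason { None = 0, ClientClosed, ReadError, ServerShutdown }

public sealed class CloseReasonHolder
{
    private int _reason; // holds a Reason value; changes away from None at most once

    public Reason Current => (Reason)Volatile.Read(ref _reason);

    // Returns true only for the single caller whose reason was recorded.
    public bool TrySet(Reason reason) =>
        reason != Reason.None &&
        Interlocked.CompareExchange(ref _reason, (int)reason, (int)Reason.None) == (int)Reason.None;
}
```

With this shape, a `finally` block can call `TrySet(Reason.ClientClosed)` unconditionally: if an error path already recorded a more specific reason, the later call returns `false` and leaves it in place.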
## B. Connection State Flags
`ClientFlags` bitfield enum backed by `int`, manipulated via `Interlocked.Or`/`Interlocked.And`:
```
ConnectReceived = 1,
FirstPongSent = 2,
HandshakeComplete = 4,
CloseConnection = 8,
WriteLoopStarted = 16,
IsSlowConsumer = 32,
ConnectProcessFinished = 64
```
Replaces current `_connectReceived` (int with Volatile.Read/Write).
Helper methods: `SetFlag(flag)`, `ClearFlag(flag)`, `HasFlag(flag)`.
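The helper methods could look roughly like the sketch below. The flag names and values come from the list above; the class body is an assumption about how `Interlocked.Or`/`Interlocked.And` (available since .NET 5) would be wired up, not the implementation in this change.

```csharp
using System;
using System.Threading;

[Flags]
public enum ClientFlags
{
    ConnectReceived = 1,
    FirstPongSent = 2,
    HandshakeComplete = 4,
    CloseConnection = 8,
    WriteLoopStarted = 16,
    IsSlowConsumer = 32,
    ConnectProcessFinished = 64
}

public sealed class ClientFlagHolder
{
    private int _flags; // bitfield, only touched through atomic operations

    // Atomically turns the given bit(s) on.
    public void SetFlag(ClientFlags flag) => Interlocked.Or(ref _flags, (int)flag);

    // Atomically turns the given bit(s) off by AND-ing with the inverted mask.
    public void ClearFlag(ClientFlags flag) => Interlocked.And(ref _flags, ~(int)flag);

    // True when every bit in 'flag' is currently set.
    public bool HasFlag(ClientFlags flag) =>
        (Volatile.Read(ref _flags) & (int)flag) == (int)flag;
}
```

Because each update is a single atomic read-modify-write, two threads setting different flags at the same time cannot lose each other's bit, which a plain `Volatile.Write` of the whole field could not guarantee.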
## C. Channel-based Write Loop
### Architecture
Replace inline `_writeLock` + direct stream writes:
```
Producer threads → QueueOutbound(bytes) → Channel<ReadOnlyMemory<byte>> → WriteLoop → Stream
```
### Components
- `Channel<ReadOnlyMemory<byte>>` — bounded (capacity derived from MaxPending / avg message size, or 8192 items)
- `_pendingBytes` (long) — tracks queued but unflushed bytes via `Interlocked.Add`
- `RunWriteLoopAsync` — background task: `WaitToReadAsync` → drain all via `TryRead` → single `FlushAsync`
- `QueueOutbound(ReadOnlyMemory<byte>)` — enqueue, update pending bytes, check slow consumer
### Coalescing
The write loop drains all available items from the channel before flushing:
```
while (await reader.WaitToReadAsync(ct))
{
while (reader.TryRead(out var data))
await stream.WriteAsync(data, ct); // buffered writes, no flush yet
await stream.FlushAsync(ct); // single flush after batch
}
```
### Migration
All existing write paths refactored:
- `SendMessageAsync` → serialize MSG/HMSG to byte array → `QueueOutbound`
- `WriteAsync` → serialize protocol message → `QueueOutbound`
- Remove `_writeLock` SemaphoreSlim
## D. Slow Consumer Detection
### Pending Bytes (Hard Limit)
In `QueueOutbound`, before writing to channel:
```
if (_pendingBytes + data.Length > _maxPending)
{
SetFlag(IsSlowConsumer);
CloseWithReason(SlowConsumerPendingBytes);
return;
}
```
- `MaxPending` default: 64MB (matching Go's `MAX_PENDING_SIZE`)
- New option in `NatsOptions`
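Putting the write loop and the pending-bytes limit together, a self-contained sketch might look like the following. It is illustrative only: `OutboundQueue` and `Complete` are hypothetical names, and it uses an unbounded channel with the byte limit as the only backpressure, which is a simplification of the bounded channel described above. Bytes are reserved with `Interlocked.Add` before the check so that two producers cannot both pass the limit test.

```csharp
using System;
using System.IO;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public sealed class OutboundQueue
{
    private readonly Channel<ReadOnlyMemory<byte>> _channel =
        Channel.CreateUnbounded<ReadOnlyMemory<byte>>(
            new UnboundedChannelOptions { SingleReader = true });
    private readonly long _maxPending;
    private long _pendingBytes;

    public OutboundQueue(long maxPending) => _maxPending = maxPending;

    public long PendingBytes => Interlocked.Read(ref _pendingBytes);

    // Returns false when the data was not accepted (limit exceeded or queue completed).
    // A real client would mark itself a slow consumer and close in the first case.
    public bool QueueOutbound(ReadOnlyMemory<byte> data)
    {
        // Reserve the bytes atomically, then undo the reservation if over the limit.
        if (Interlocked.Add(ref _pendingBytes, data.Length) > _maxPending ||
            !_channel.Writer.TryWrite(data))
        {
            Interlocked.Add(ref _pendingBytes, -data.Length);
            return false;
        }
        return true;
    }

    public void Complete() => _channel.Writer.TryComplete();

    public async Task RunWriteLoopAsync(Stream stream, CancellationToken ct)
    {
        var reader = _channel.Reader;
        while (await reader.WaitToReadAsync(ct))
        {
            long written = 0;
            while (reader.TryRead(out var data))
            {
                await stream.WriteAsync(data, ct); // no flush yet
                written += data.Length;
            }
            await stream.FlushAsync(ct);                  // one flush per drained batch
            Interlocked.Add(ref _pendingBytes, -written); // release only after the flush
        }
    }
}
```

Releasing the pending bytes only after the flush means the counter reflects data that has not yet left the process, which is the quantity the hard limit is meant to cap.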
### Write Deadline (Timeout)
In write loop flush:
```
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(_writeDeadline);
await stream.FlushAsync(cts.Token);
```
On timeout → close with `SlowConsumerWriteDeadline`.
- `WriteDeadline` default: 10 seconds
- New option in `NatsOptions`
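The snippet above does not show how a deadline expiry is told apart from an ordinary shutdown cancellation, since both surface as `OperationCanceledException`. One possible approach, sketched here with the hypothetical helper `DeadlineFlush.TryFlushAsync`, is an exception filter on the outer token:

```csharp
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

public static class DeadlineFlush
{
    // Returns true if the flush completed, false if the write deadline elapsed first.
    // Cancellation of the outer token still propagates as OperationCanceledException.
    public static async Task<bool> TryFlushAsync(
        Stream stream, TimeSpan writeDeadline, CancellationToken ct)
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(writeDeadline);
        try
        {
            await stream.FlushAsync(cts.Token);
            return true;
        }
        catch (OperationCanceledException) when (!ct.IsCancellationRequested)
        {
            // The outer token is untouched, so the linked source was cancelled by the timer.
            return false;
        }
    }
}
```

A caller would treat a `false` result as the `SlowConsumerWriteDeadline` case and let the exception path handle server shutdown.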
### Monitoring
- `IsSlowConsumer` flag readable for `/connz`
- Server-level `SlowConsumerCount` stat incremented
## E. Verbose Mode
After successful command processing (CONNECT, SUB, UNSUB, PUB), check `ClientOpts?.Verbose`:
```
if (ClientOpts?.Verbose == true)
QueueOutbound(OkBytes);
```
`OkBytes` = pre-encoded `+OK\r\n` static byte array in `NatsProtocol`.
## F. No-Responders
### CONNECT Validation
```
if (clientOpts.NoResponders && !clientOpts.Headers)
{
CloseWithReason(NoRespondersRequiresHeaders);
return;
}
```
### Publish-time Notification
In `NatsServer` message delivery, after `Match()` returns zero subscribers:
```
if (!delivered && reply.Length > 0 && publisher.ClientOpts?.NoResponders == true)
{
// Send HMSG with NATS/1.0 503 status back to publisher
var header = $"NATS/1.0 503\r\nNats-Subject: {subject}\r\n\r\n";
publisher.SendNoRespondersAsync(reply, sid, header);
}
```
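For reference, the wire frame for such a notification could be assembled as in the sketch below. It assumes the NATS `HMSG <subject> <sid> <header bytes> <total bytes>` control line with an empty payload, so both sizes equal the header length; `NoRespondersFrame.Build` is a hypothetical helper, not the `SendNoRespondersAsync` method referenced above.

```csharp
using System;
using System.Text;

public static class NoRespondersFrame
{
    // 'header' must already end with the blank line that terminates the header block.
    public static byte[] Build(string replySubject, string sid, string header)
    {
        var headerBytes = Encoding.ASCII.GetBytes(header);
        // No payload follows, so the total size equals the header size.
        var control = Encoding.ASCII.GetBytes(
            $"HMSG {replySubject} {sid} {headerBytes.Length} {headerBytes.Length}\r\n");

        var frame = new byte[control.Length + headerBytes.Length + 2];
        control.CopyTo(frame, 0);
        headerBytes.CopyTo(frame, control.Length);
        frame[^2] = (byte)'\r'; // trailing CRLF after the (empty) payload
        frame[^1] = (byte)'\n';
        return frame;
    }
}
```

The frame is addressed to the request's reply subject, so the publisher receives the 503 status on the inbox it is already waiting on.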
## G. Stat Batching
In read loop, accumulate locally:
```
long localInMsgs = 0, localInBytes = 0;
// ... per message: localInMsgs++; localInBytes += size;
// End of read cycle:
Interlocked.Add(ref _inMsgs, localInMsgs);
Interlocked.Add(ref _inBytes, localInBytes);
// Same for server stats
```
Reduces atomic operations from per-message to per-read-cycle.
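As a self-contained illustration of the same idea, the sketch below accumulates per-message counts in locals and publishes them with two atomic adds per read cycle. `StatCounters` and `RecordReadCycle` are hypothetical names; the real read loop would accumulate while parsing rather than receive a list of sizes.

```csharp
using System;
using System.Threading;

public sealed class StatCounters
{
    private long _inMsgs;
    private long _inBytes;

    public long InMsgs => Interlocked.Read(ref _inMsgs);
    public long InBytes => Interlocked.Read(ref _inBytes);

    // Called once per read cycle with the sizes of the messages parsed in that cycle.
    public void RecordReadCycle(ReadOnlySpan<int> messageSizes)
    {
        long localMsgs = 0, localBytes = 0;
        foreach (var size in messageSizes)
        {
            localMsgs++;          // plain increments, no atomics per message
            localBytes += size;
        }
        if (localMsgs == 0) return; // nothing parsed, skip the atomic adds entirely

        Interlocked.Add(ref _inMsgs, localMsgs);
        Interlocked.Add(ref _inBytes, localBytes);
    }
}
```

The trade-off is that readers of the counters (for example a monitoring endpoint) can lag by at most one read cycle.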
## Files
| File | Change | Size |
|------|--------|------|
| `ClientClosedReason.cs` | New | Small |
| `ClientFlags.cs` | New | Small |
| `NatsClient.cs` | Major rewrite of write path | Large |
| `NatsServer.cs` | No-responders, close reason | Medium |
| `NatsOptions.cs` | MaxPending, WriteDeadline | Small |
| `NatsProtocol.cs` | +OK bytes, NoResponders | Small |
| `ClientTests.cs` | Verbose, close reasons, flags | Medium |
| `ServerTests.cs` | No-responders, slow consumer | Medium |
## Test Plan
- **Verbose mode**: Connect with `verbose:true`, send SUB/PUB, verify `+OK` responses
- **Close reasons**: Trigger each close path, verify reason is set
- **State flags**: Set/clear/check flags concurrently
- **Slow consumer (pending bytes)**: Queue more than MaxPending, verify close
- **Slow consumer (write deadline)**: Use a slow/blocked stream, verify timeout close
- **No-responders**: Publish to empty subject with reply, verify 503 HMSG
- **Write coalescing**: Send multiple messages rapidly, verify batched flush
- **Stat batching**: Send N messages, verify stats match after read cycle

File diff suppressed because it is too large

@@ -0,0 +1,19 @@
{
"planPath": "docs/plans/2026-02-22-section2-client-connection-handling-plan.md",
"tasks": [
{"id": 4, "subject": "Task 1: Add ClientClosedReason enum", "status": "pending"},
{"id": 5, "subject": "Task 2: Add ClientFlags bitfield", "status": "pending"},
{"id": 6, "subject": "Task 3: Add MaxPending and WriteDeadline to NatsOptions", "status": "pending"},
{"id": 7, "subject": "Task 4: Integrate ClientFlags into NatsClient", "status": "pending", "blockedBy": [4, 5, 6]},
{"id": 8, "subject": "Task 5: Implement channel-based write loop", "status": "pending", "blockedBy": [7]},
{"id": 9, "subject": "Task 6: Write tests for write loop and slow consumer", "status": "pending", "blockedBy": [8]},
{"id": 10, "subject": "Task 7: Update NatsServer for SendMessage + no-responders", "status": "pending", "blockedBy": [8]},
{"id": 11, "subject": "Task 8: Implement verbose mode", "status": "pending", "blockedBy": [10]},
{"id": 12, "subject": "Task 9: Implement no-responders CONNECT validation", "status": "pending", "blockedBy": [10]},
{"id": 13, "subject": "Task 10: Implement stat batching in read loop", "status": "pending", "blockedBy": [8]},
{"id": 14, "subject": "Task 11: Update ConnzHandler for close reason + pending bytes", "status": "pending", "blockedBy": [13]},
{"id": 15, "subject": "Task 12: Fix existing tests for new write model", "status": "pending", "blockedBy": [13]},
{"id": 16, "subject": "Task 13: Final verification and differences.md update", "status": "pending", "blockedBy": [14, 15]}
],
"lastUpdated": "2026-02-22T00:00:00Z"
}

@@ -32,6 +32,15 @@ for (int i = 0; i < args.Length; i++)
case "--https_port" when i + 1 < args.Length: case "--https_port" when i + 1 < args.Length:
options.MonitorHttpsPort = int.Parse(args[++i]); options.MonitorHttpsPort = int.Parse(args[++i]);
break; break;
case "-c" when i + 1 < args.Length:
options.ConfigFile = args[++i];
break;
case "--pid" when i + 1 < args.Length:
options.PidFile = args[++i];
break;
case "--ports_file_dir" when i + 1 < args.Length:
options.PortsFileDir = args[++i];
break;
case "--tls": case "--tls":
break; break;
case "--tlscert" when i + 1 < args.Length: case "--tlscert" when i + 1 < args.Length:
@@ -50,18 +59,24 @@ for (int i = 0; i < args.Length; i++)
} }
using var loggerFactory = new Serilog.Extensions.Logging.SerilogLoggerFactory(Log.Logger); using var loggerFactory = new Serilog.Extensions.Logging.SerilogLoggerFactory(Log.Logger);
var server = new NatsServer(options, loggerFactory); using var server = new NatsServer(options, loggerFactory);
var cts = new CancellationTokenSource(); // Register signal handlers
server.HandleSignals();
// Ctrl+C triggers graceful shutdown
Console.CancelKeyPress += (_, e) => Console.CancelKeyPress += (_, e) =>
{ {
e.Cancel = true; e.Cancel = true;
cts.Cancel(); Log.Information("Trapped SIGINT signal");
_ = Task.Run(async () => await server.ShutdownAsync());
}; };
try try
{ {
await server.StartAsync(cts.Token); _ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
server.WaitForShutdown();
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {

@@ -0,0 +1,52 @@
// Ported from Go: server/client.go:188-228
namespace NATS.Server;
/// <summary>
/// Reason a client connection was closed. Stored in connection info for monitoring
/// and passed to close handlers during connection teardown.
/// </summary>
/// <remarks>
/// Values start at 1 (matching Go's <c>iota + 1</c>) so that the default zero value
/// is distinct from any valid close reason.
/// </remarks>
public enum ClosedState
{
ClientClosed = 1,
AuthenticationTimeout,
AuthenticationViolation,
TLSHandshakeError,
SlowConsumerPendingBytes,
SlowConsumerWriteDeadline,
WriteError,
ReadError,
ParseError,
StaleConnection,
ProtocolViolation,
BadClientProtocolVersion,
WrongPort,
MaxAccountConnectionsExceeded,
MaxConnectionsExceeded,
MaxPayloadExceeded,
MaxControlLineExceeded,
MaxSubscriptionsExceeded,
DuplicateRoute,
RouteRemoved,
ServerShutdown,
AuthenticationExpired,
WrongGateway,
MissingAccount,
Revocation,
InternalClient,
MsgHeaderViolation,
NoRespondersRequiresHeaders,
ClusterNameConflict,
DuplicateRemoteLeafnodeConnection,
DuplicateClientID,
DuplicateServerName,
MinimumVersionRequired,
ClusterNamesIdentical,
Kicked,
ProxyNotTrusted,
ProxyRequired,
}

@@ -65,6 +65,10 @@ public sealed class NatsClient : IDisposable
public long InBytes; public long InBytes;
public long OutBytes; public long OutBytes;
// Close reason tracking
private int _skipFlushOnClose;
public bool ShouldSkipFlush => Volatile.Read(ref _skipFlushOnClose) != 0;
// PING keepalive state // PING keepalive state
private int _pingsOut; private int _pingsOut;
private long _lastIn; private long _lastIn;
@@ -174,13 +178,24 @@ public sealed class NatsClient : IDisposable
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
_logger.LogDebug("Client {ClientId} operation cancelled", Id); _logger.LogDebug("Client {ClientId} operation cancelled", Id);
MarkClosed(ClientClosedReason.ServerShutdown);
}
catch (IOException)
{
MarkClosed(ClientClosedReason.ReadError);
}
catch (SocketException)
{
MarkClosed(ClientClosedReason.ReadError);
} }
catch (Exception ex) catch (Exception ex)
{ {
_logger.LogDebug(ex, "Client {ClientId} connection error", Id); _logger.LogDebug(ex, "Client {ClientId} connection error", Id);
MarkClosed(ClientClosedReason.ReadError);
} }
finally finally
{ {
MarkClosed(ClientClosedReason.ClientClosed);
_outbound.Writer.TryComplete(); _outbound.Writer.TryComplete();
try { _socket.Shutdown(SocketShutdown.Both); } try { _socket.Shutdown(SocketShutdown.Both); }
catch (SocketException) { } catch (SocketException) { }
@@ -623,6 +638,57 @@ public sealed class NatsClient : IDisposable
} }
} }
/// <summary>
/// Marks this connection as closed with the given reason.
/// Sets skip-flush flag for error-related reasons.
/// Only the first call sets the reason (subsequent calls are no-ops).
/// </summary>
public void MarkClosed(ClientClosedReason reason)
{
if (CloseReason != ClientClosedReason.None)
return;
CloseReason = reason;
switch (reason)
{
case ClientClosedReason.ReadError:
case ClientClosedReason.WriteError:
case ClientClosedReason.SlowConsumerPendingBytes:
case ClientClosedReason.SlowConsumerWriteDeadline:
case ClientClosedReason.TlsHandshakeError:
Volatile.Write(ref _skipFlushOnClose, 1);
break;
}
_logger.LogDebug("Client {ClientId} connection closed: {CloseReason}", Id, reason);
}
/// <summary>
/// Flushes pending data (unless skip-flush is set) and closes the connection.
/// </summary>
public async Task FlushAndCloseAsync(bool minimalFlush = false)
{
if (!ShouldSkipFlush)
{
try
{
using var flushCts = new CancellationTokenSource(minimalFlush
? TimeSpan.FromMilliseconds(100)
: TimeSpan.FromSeconds(1));
await _stream.FlushAsync(flushCts.Token);
}
catch (Exception)
{
// Best effort flush — don't let it prevent close
}
}
try { _socket.Shutdown(SocketShutdown.Both); }
catch (SocketException) { }
catch (ObjectDisposedException) { }
}
public void RemoveAllSubscriptions(SubList subList) public void RemoveAllSubscriptions(SubList subList)
{ {
foreach (var sub in _subs.Values) foreach (var sub in _subs.Values)

@@ -38,6 +38,18 @@ public sealed class NatsOptions
// 0 = disabled // 0 = disabled
public int MonitorHttpsPort { get; set; } public int MonitorHttpsPort { get; set; }
// Lifecycle / lame-duck mode
public TimeSpan LameDuckDuration { get; set; } = TimeSpan.FromMinutes(2);
public TimeSpan LameDuckGracePeriod { get; set; } = TimeSpan.FromSeconds(10);
// File paths
public string? PidFile { get; set; }
public string? PortsFileDir { get; set; }
public string? ConfigFile { get; set; }
// Profiling (0 = disabled)
public int ProfPort { get; set; }
// TLS // TLS
public string? TlsCert { get; set; } public string? TlsCert { get; set; }
public string? TlsKey { get; set; } public string? TlsKey { get; set; }

@@ -2,9 +2,11 @@ using System.Collections.Concurrent;
using System.Net; using System.Net;
using System.Net.Security; using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using NATS.NKeys;
using NATS.Server.Auth; using NATS.Server.Auth;
using NATS.Server.Monitoring; using NATS.Server.Monitoring;
using NATS.Server.Protocol; using NATS.Server.Protocol;
@@ -25,6 +27,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private readonly AuthService _authService; private readonly AuthService _authService;
private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal); private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
private readonly Account _globalAccount; private readonly Account _globalAccount;
private readonly Account _systemAccount;
private readonly SslServerAuthenticationOptions? _sslOptions; private readonly SslServerAuthenticationOptions? _sslOptions;
private readonly TlsRateLimiter? _tlsRateLimiter; private readonly TlsRateLimiter? _tlsRateLimiter;
private Socket? _listener; private Socket? _listener;
@@ -32,16 +35,205 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private ulong _nextClientId; private ulong _nextClientId;
private long _startTimeTicks; private long _startTimeTicks;
private readonly CancellationTokenSource _quitCts = new();
private readonly TaskCompletionSource _shutdownComplete = new(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource _acceptLoopExited = new(TaskCreationOptions.RunContinuationsAsynchronously);
private int _shutdown;
private int _activeClientCount;
private int _lameDuck;
private readonly List<PosixSignalRegistration> _signalRegistrations = [];
private string? _portsFilePath;
private static readonly TimeSpan AcceptMinSleep = TimeSpan.FromMilliseconds(10);
private static readonly TimeSpan AcceptMaxSleep = TimeSpan.FromSeconds(1);
public SubList SubList => _globalAccount.SubList; public SubList SubList => _globalAccount.SubList;
public ServerStats Stats => _stats; public ServerStats Stats => _stats;
public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc); public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc);
public string ServerId => _serverInfo.ServerId; public string ServerId => _serverInfo.ServerId;
public string ServerName => _serverInfo.ServerName; public string ServerName => _serverInfo.ServerName;
public int ClientCount => _clients.Count; public int ClientCount => _clients.Count;
public int Port => _options.Port;
public Account SystemAccount => _systemAccount;
public string ServerNKey { get; }
public bool IsShuttingDown => Volatile.Read(ref _shutdown) != 0;
public bool IsLameDuckMode => Volatile.Read(ref _lameDuck) != 0;
public IEnumerable<NatsClient> GetClients() => _clients.Values; public IEnumerable<NatsClient> GetClients() => _clients.Values;
public Task WaitForReadyAsync() => _listeningStarted.Task; public Task WaitForReadyAsync() => _listeningStarted.Task;
public void WaitForShutdown() => _shutdownComplete.Task.GetAwaiter().GetResult();
public async Task ShutdownAsync()
{
if (Interlocked.CompareExchange(ref _shutdown, 1, 0) != 0)
return; // Already shutting down
_logger.LogInformation("Initiating Shutdown...");
// Signal all internal loops to stop
await _quitCts.CancelAsync();
// Close listener to stop accept loop
_listener?.Close();
// Wait for accept loop to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
// Close all client connections: mark closed first, then flush and close
var flushTasks = new List<Task>();
foreach (var client in _clients.Values)
{
client.MarkClosed(ClientClosedReason.ServerShutdown);
flushTasks.Add(client.FlushAndCloseAsync(minimalFlush: true));
}
await Task.WhenAll(flushTasks).WaitAsync(TimeSpan.FromSeconds(2)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
// Wait for active client tasks to drain (with timeout)
if (Volatile.Read(ref _activeClientCount) > 0)
{
using var drainCts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
try
{
while (Volatile.Read(ref _activeClientCount) > 0 && !drainCts.IsCancellationRequested)
await Task.Delay(50, drainCts.Token);
}
catch (OperationCanceledException) { }
}
// Stop monitor server
if (_monitorServer != null)
await _monitorServer.DisposeAsync();
DeletePidFile();
DeletePortsFile();
_logger.LogInformation("Server Exiting..");
_shutdownComplete.TrySetResult();
}
public async Task LameDuckShutdownAsync()
{
if (IsShuttingDown || Interlocked.CompareExchange(ref _lameDuck, 1, 0) != 0)
return;
_logger.LogInformation("Entering lame duck mode, stop accepting new clients");
// Close listener to stop accepting new connections
_listener?.Close();
// Wait for accept loop to exit
await _acceptLoopExited.Task.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
var gracePeriod = _options.LameDuckGracePeriod;
if (gracePeriod < TimeSpan.Zero) gracePeriod = -gracePeriod;
// If no clients, go straight to shutdown
if (_clients.IsEmpty)
{
await ShutdownAsync();
return;
}
// Wait grace period for clients to drain naturally
_logger.LogInformation("Waiting {GracePeriod}ms grace period", gracePeriod.TotalMilliseconds);
try
{
await Task.Delay(gracePeriod, _quitCts.Token);
}
catch (OperationCanceledException) { return; }
if (_clients.IsEmpty)
{
await ShutdownAsync();
return;
}
// Stagger-close remaining clients
var dur = _options.LameDuckDuration - gracePeriod;
if (dur <= TimeSpan.Zero) dur = TimeSpan.FromSeconds(1);
var clients = _clients.Values.ToList();
var numClients = clients.Count;
if (numClients > 0)
{
_logger.LogInformation("Closing {Count} existing clients over {Duration}ms",
numClients, dur.TotalMilliseconds);
var sleepInterval = dur.Ticks / numClients;
if (sleepInterval < TimeSpan.TicksPerMillisecond)
sleepInterval = TimeSpan.TicksPerMillisecond;
if (sleepInterval > TimeSpan.TicksPerSecond)
sleepInterval = TimeSpan.TicksPerSecond;
for (int i = 0; i < clients.Count; i++)
{
clients[i].MarkClosed(ClientClosedReason.ServerShutdown);
await clients[i].FlushAndCloseAsync(minimalFlush: true);
if (i < clients.Count - 1)
{
var jitter = Random.Shared.NextInt64(sleepInterval / 2, sleepInterval);
try
{
await Task.Delay(TimeSpan.FromTicks(jitter), _quitCts.Token);
}
catch (OperationCanceledException) { break; }
}
}
}
await ShutdownAsync();
}
/// <summary>
/// Registers Unix signal handlers.
/// SIGTERM and SIGQUIT → shutdown, SIGUSR2 → lame duck, SIGUSR1 → log reopen (stub), SIGHUP → reload (stub).
/// </summary>
public void HandleSignals()
{
_signalRegistrations.Add(PosixSignalRegistration.Create(PosixSignal.SIGTERM, ctx =>
{
ctx.Cancel = true;
_logger.LogInformation("Trapped SIGTERM signal");
_ = Task.Run(async () => await ShutdownAsync());
}));
_signalRegistrations.Add(PosixSignalRegistration.Create(PosixSignal.SIGQUIT, ctx =>
{
ctx.Cancel = true;
_logger.LogInformation("Trapped SIGQUIT signal");
_ = Task.Run(async () => await ShutdownAsync());
}));
_signalRegistrations.Add(PosixSignalRegistration.Create(PosixSignal.SIGHUP, ctx =>
{
ctx.Cancel = true;
_logger.LogWarning("Trapped SIGHUP signal — config reload not yet supported");
}));
// SIGUSR1 and SIGUSR2 only on non-Windows; 10 and 12 are the Linux signal numbers (macOS and BSD use 30 and 31)
if (!OperatingSystem.IsWindows())
{
_signalRegistrations.Add(PosixSignalRegistration.Create((PosixSignal)10, ctx =>
{
ctx.Cancel = true;
_logger.LogWarning("Trapped SIGUSR1 signal — log reopen not yet supported");
}));
_signalRegistrations.Add(PosixSignalRegistration.Create((PosixSignal)12, ctx =>
{
ctx.Cancel = true;
_logger.LogInformation("Trapped SIGUSR2 signal — entering lame duck mode");
_ = Task.Run(async () => await LameDuckShutdownAsync());
}));
}
}
public NatsServer(NatsOptions options, ILoggerFactory loggerFactory) public NatsServer(NatsOptions options, ILoggerFactory loggerFactory)
{ {
_options = options; _options = options;
@@ -50,6 +242,15 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_authService = AuthService.Build(options); _authService = AuthService.Build(options);
_globalAccount = new Account(Account.GlobalAccountName); _globalAccount = new Account(Account.GlobalAccountName);
_accounts[Account.GlobalAccountName] = _globalAccount; _accounts[Account.GlobalAccountName] = _globalAccount;
// Create $SYS system account (stub -- no internal subscriptions yet)
_systemAccount = new Account("$SYS");
_accounts["$SYS"] = _systemAccount;
// Generate Ed25519 server NKey identity
using var serverKeyPair = KeyPair.CreatePair(PrefixByte.Server);
ServerNKey = serverKeyPair.GetPublicKey();
_serverInfo = new ServerInfo _serverInfo = new ServerInfo
{ {
ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(), ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(),
@@ -75,6 +276,8 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
public async Task StartAsync(CancellationToken ct) public async Task StartAsync(CancellationToken ct)
{ {
using var linked = CancellationTokenSource.CreateLinkedTokenSource(ct, _quitCts.Token);
_listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); _listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); _listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
_listener.Bind(new IPEndPoint( _listener.Bind(new IPEndPoint(
@@ -82,23 +285,68 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_options.Port)); _options.Port));
Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks); Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks);
_listener.Listen(128); _listener.Listen(128);
// Resolve ephemeral port if port=0
if (_options.Port == 0)
{
var actualPort = ((IPEndPoint)_listener.LocalEndPoint!).Port;
_options.Port = actualPort;
_serverInfo.Port = actualPort;
}
_listeningStarted.TrySetResult(); _listeningStarted.TrySetResult();
_logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port); _logger.LogInformation("Listening for client connections on {Host}:{Port}", _options.Host, _options.Port);
// Warn about stub features
if (_options.ConfigFile != null)
_logger.LogWarning("Config file parsing not yet supported (file: {ConfigFile})", _options.ConfigFile);
if (_options.ProfPort > 0)
_logger.LogWarning("Profiling endpoint not yet supported (port: {ProfPort})", _options.ProfPort);
if (_options.MonitorPort > 0) if (_options.MonitorPort > 0)
{ {
_monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory); _monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory);
await _monitorServer.StartAsync(ct); await _monitorServer.StartAsync(linked.Token);
} }
WritePidFile();
WritePortsFile();
var tmpDelay = AcceptMinSleep;
try try
{ {
while (!ct.IsCancellationRequested) while (!linked.Token.IsCancellationRequested)
{ {
var socket = await _listener.AcceptAsync(ct); Socket socket;
try
{
socket = await _listener.AcceptAsync(linked.Token);
tmpDelay = AcceptMinSleep; // Reset on success
}
catch (OperationCanceledException)
{
break;
}
catch (ObjectDisposedException)
{
break;
}
catch (SocketException ex)
{
if (IsShuttingDown || IsLameDuckMode)
break;
// Check MaxConnections before creating the client _logger.LogError(ex, "Temporary accept error, sleeping {Delay}ms", tmpDelay.TotalMilliseconds);
try { await Task.Delay(tmpDelay, linked.Token); }
catch (OperationCanceledException) { break; }
tmpDelay = TimeSpan.FromTicks(Math.Min(tmpDelay.Ticks * 2, AcceptMaxSleep.Ticks));
continue;
}
// Check MaxConnections
if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections) if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
{ {
_logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded", _logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded",
@@ -108,13 +356,13 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
var stream = new NetworkStream(socket, ownsSocket: false); var stream = new NetworkStream(socket, ownsSocket: false);
var errBytes = Encoding.ASCII.GetBytes( var errBytes = Encoding.ASCII.GetBytes(
$"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n"); $"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n");
await stream.WriteAsync(errBytes, ct); await stream.WriteAsync(errBytes, linked.Token);
await stream.FlushAsync(ct); await stream.FlushAsync(linked.Token);
stream.Dispose(); stream.Dispose();
} }
catch (Exception ex) catch (Exception ex2)
{ {
_logger.LogDebug(ex, "Failed to send -ERR to rejected client"); _logger.LogDebug(ex2, "Failed to send -ERR to rejected client");
} }
finally finally
{ {
@@ -125,16 +373,21 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
var clientId = Interlocked.Increment(ref _nextClientId); var clientId = Interlocked.Increment(ref _nextClientId);
Interlocked.Increment(ref _stats.TotalConnections); Interlocked.Increment(ref _stats.TotalConnections);
Interlocked.Increment(ref _activeClientCount);
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
_ = AcceptClientAsync(socket, clientId, ct); _ = AcceptClientAsync(socket, clientId, linked.Token);
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
_logger.LogDebug("Accept loop cancelled, server shutting down"); _logger.LogDebug("Accept loop cancelled, server shutting down");
} }
finally
{
_acceptLoopExited.TrySetResult();
}
} }
private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct) private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct)
@@ -217,8 +470,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
} }
finally finally
{ {
_logger.LogDebug("Client {ClientId} disconnected", client.Id); _logger.LogDebug("Client {ClientId} disconnected (reason: {CloseReason})", client.Id, client.CloseReason);
RemoveClient(client); RemoveClient(client);
Interlocked.Decrement(ref _activeClientCount);
} }
} }
@@ -327,10 +581,75 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
client.Account?.RemoveClient(client.Id); client.Account?.RemoveClient(client.Id);
} }
private void WritePidFile()
{
if (string.IsNullOrEmpty(_options.PidFile)) return;
try
{
File.WriteAllText(_options.PidFile, Environment.ProcessId.ToString());
_logger.LogDebug("Wrote PID file {PidFile}", _options.PidFile);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error writing PID file {PidFile}", _options.PidFile);
}
}
private void DeletePidFile()
{
if (string.IsNullOrEmpty(_options.PidFile)) return;
try
{
if (File.Exists(_options.PidFile))
File.Delete(_options.PidFile);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error deleting PID file {PidFile}", _options.PidFile);
}
}
private void WritePortsFile()
{
if (string.IsNullOrEmpty(_options.PortsFileDir)) return;
try
{
var exeName = Path.GetFileNameWithoutExtension(Environment.ProcessPath ?? "nats-server");
var fileName = $"{exeName}_{Environment.ProcessId}.ports";
_portsFilePath = Path.Combine(_options.PortsFileDir, fileName);
var ports = new { client = _options.Port, monitor = _options.MonitorPort > 0 ? _options.MonitorPort : (int?)null };
var json = System.Text.Json.JsonSerializer.Serialize(ports);
File.WriteAllText(_portsFilePath, json);
_logger.LogDebug("Wrote ports file {PortsFile}", _portsFilePath);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error writing ports file to {PortsFileDir}", _options.PortsFileDir);
}
}
private void DeletePortsFile()
{
if (_portsFilePath == null) return;
try
{
if (File.Exists(_portsFilePath))
File.Delete(_portsFilePath);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error deleting ports file {PortsFile}", _portsFilePath);
}
}
public void Dispose()
{
- if (_monitorServer != null)
- _monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult();
+ if (!IsShuttingDown)
+ ShutdownAsync().GetAwaiter().GetResult();
foreach (var reg in _signalRegistrations)
reg.Dispose();
_quitCts.Dispose();
_tlsRateLimiter?.Dispose();
_listener?.Dispose();
foreach (var client in _clients.Values)


@@ -213,6 +213,39 @@ public class ServerTests : IAsyncLifetime
}
}
public class EphemeralPortTests
{
[Fact]
public async Task Server_resolves_ephemeral_port()
{
using var cts = new CancellationTokenSource();
var server = new NatsServer(new NatsOptions { Port = 0 }, NullLoggerFactory.Instance);
_ = server.StartAsync(cts.Token);
await server.WaitForReadyAsync();
try
{
// Port should have been resolved to a real port
server.Port.ShouldBeGreaterThan(0);
// Connect a raw socket to prove the port actually works
using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync(IPAddress.Loopback, server.Port);
var buf = new byte[4096];
var n = await client.ReceiveAsync(buf, SocketFlags.None);
var response = Encoding.ASCII.GetString(buf, 0, n);
response.ShouldStartWith("INFO ");
}
finally
{
await cts.CancelAsync();
server.Dispose();
}
}
}
public class MaxConnectionsTests : IAsyncLifetime
{
private readonly NatsServer _server;
@@ -423,3 +456,448 @@ public class PingKeepaliveTests : IAsyncLifetime
client.Dispose();
}
}
public class CloseReasonTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
public CloseReasonTests()
{
_port = GetFreePort();
_server = new NatsServer(new NatsOptions { Port = _port }, NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _cts.CancelAsync();
_server.Dispose();
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
[Fact]
public async Task Client_close_reason_set_on_normal_disconnect()
{
// Connect a raw TCP client
using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync(IPAddress.Loopback, _port);
// Read INFO
var buf = new byte[4096];
await client.ReceiveAsync(buf, SocketFlags.None);
// Send CONNECT + PING, wait for PONG
await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
var sb = new StringBuilder();
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
while (!sb.ToString().Contains("PONG"))
{
var n = await client.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
if (n == 0) break;
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
}
sb.ToString().ShouldContain("PONG");
// Get the NatsClient from the server
var natsClient = _server.GetClients().First();
// Close the TCP socket (normal client disconnect)
client.Shutdown(SocketShutdown.Both);
client.Close();
// Wait for the server to detect the disconnect
await Task.Delay(500);
// The close reason should be ClientClosed (normal disconnect falls through to finally)
natsClient.CloseReason.ShouldBe(ClientClosedReason.ClientClosed);
}
}
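// Illustrative sketch, not part of this commit: the fixed Task.Delay(500) wait in
// Client_close_reason_set_on_normal_disconnect could instead poll with a deadline,
// which tends to be both faster and less timing-sensitive.
// TestWait and UntilAsync are hypothetical names; nothing in this diff defines them.
public static class TestWait
{
    // Re-evaluates `condition` every `pollMs` milliseconds until it holds or `timeout` elapses.
    public static async Task<bool> UntilAsync(Func<bool> condition, TimeSpan timeout, int pollMs = 25)
    {
        var sw = System.Diagnostics.Stopwatch.StartNew();
        while (sw.Elapsed < timeout)
        {
            if (condition()) return true;
            await Task.Delay(pollMs);
        }
        return condition(); // one final check at the deadline
    }
}
// Possible usage, assuming CloseReason is safe to read from the test thread:
// (await TestWait.UntilAsync(
//     () => natsClient.CloseReason == ClientClosedReason.ClientClosed,
//     TimeSpan.FromSeconds(5))).ShouldBeTrue();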
public class ServerIdentityTests
{
[Fact]
public void Server_creates_system_account()
{
var server = new NatsServer(new NatsOptions { Port = 0 }, NullLoggerFactory.Instance);
server.SystemAccount.ShouldNotBeNull();
server.SystemAccount.Name.ShouldBe("$SYS");
server.Dispose();
}
[Fact]
public void Server_generates_nkey_identity()
{
var server = new NatsServer(new NatsOptions { Port = 0 }, NullLoggerFactory.Instance);
server.ServerNKey.ShouldNotBeNullOrEmpty();
// Server NKey public keys start with 'N'
server.ServerNKey[0].ShouldBe('N');
server.Dispose();
}
}
public class FlushBeforeCloseTests
{
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
private static async Task<string> ReadUntilAsync(Socket sock, string expected, int timeoutMs = 5000)
{
using var cts = new CancellationTokenSource(timeoutMs);
var sb = new StringBuilder();
var buf = new byte[4096];
while (!sb.ToString().Contains(expected))
{
var n = await sock.ReceiveAsync(buf, SocketFlags.None, cts.Token);
if (n == 0) break;
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
}
return sb.ToString();
}
[Fact]
public async Task Shutdown_flushes_pending_data_to_clients()
{
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
try
{
// Connect a subscriber via raw socket
using var sub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sub.ConnectAsync(IPAddress.Loopback, port);
// Read INFO
var buf = new byte[4096];
await sub.ReceiveAsync(buf, SocketFlags.None);
// Subscribe to "foo"
await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo 1\r\nPING\r\n"));
var pong = await ReadUntilAsync(sub, "PONG");
pong.ShouldContain("PONG");
// Connect a publisher
using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await pub.ConnectAsync(IPAddress.Loopback, port);
await pub.ReceiveAsync(buf, SocketFlags.None); // INFO
// Publish "Hello" to "foo"
await pub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPUB foo 5\r\nHello\r\n"));
// Wait briefly for delivery
await Task.Delay(200);
// Read from subscriber to verify MSG was received
var msg = await ReadUntilAsync(sub, "Hello\r\n");
msg.ShouldContain("MSG foo 1 5\r\nHello\r\n");
}
finally
{
await server.ShutdownAsync();
server.Dispose();
}
}
}
public class GracefulShutdownTests
{
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
[Fact]
public async Task ShutdownAsync_disconnects_all_clients()
{
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Connect 2 raw TCP clients
using var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client1.ConnectAsync(IPAddress.Loopback, port);
var buf = new byte[4096];
await client1.ReceiveAsync(buf, SocketFlags.None); // INFO
using var client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client2.ConnectAsync(IPAddress.Loopback, port);
await client2.ReceiveAsync(buf, SocketFlags.None); // INFO
// Send CONNECT so both are registered
await client1.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
await client2.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
// Wait for PONG from both (confirming they are registered)
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
await client1.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
await client2.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
server.ClientCount.ShouldBe(2);
await server.ShutdownAsync();
server.ClientCount.ShouldBe(0);
server.Dispose();
}
[Fact]
public async Task WaitForShutdown_blocks_until_shutdown()
{
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Start WaitForShutdown in background
var waitTask = Task.Run(() => server.WaitForShutdown());
// Give it a moment -- it should NOT complete yet
await Task.Delay(200);
waitTask.IsCompleted.ShouldBeFalse();
// Trigger shutdown
await server.ShutdownAsync();
// WaitForShutdown should complete within 5 seconds
var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5)));
completed.ShouldBe(waitTask);
server.Dispose();
}
[Fact]
public async Task ShutdownAsync_is_idempotent()
{
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Call ShutdownAsync 3 times -- should not throw
await server.ShutdownAsync();
await server.ShutdownAsync();
await server.ShutdownAsync();
server.IsShuttingDown.ShouldBeTrue();
server.Dispose();
}
[Fact]
public async Task Accept_loop_waits_for_active_clients()
{
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
// Connect a client
using var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync(IPAddress.Loopback, port);
var buf = new byte[4096];
await client.ReceiveAsync(buf, SocketFlags.None); // INFO
await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
await client.ReceiveAsync(buf, SocketFlags.None, readCts.Token); // PONG
// ShutdownAsync should complete within 10 seconds (doesn't hang)
var shutdownTask = server.ShutdownAsync();
var completed = await Task.WhenAny(shutdownTask, Task.Delay(TimeSpan.FromSeconds(10)));
completed.ShouldBe(shutdownTask);
server.Dispose();
}
}
public class LameDuckTests
{
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
[Fact]
public async Task LameDuckShutdown_stops_accepting_new_connections()
{
var port = GetFreePort();
var server = new NatsServer(
new NatsOptions
{
Port = port,
LameDuckDuration = TimeSpan.FromSeconds(3),
LameDuckGracePeriod = TimeSpan.FromMilliseconds(500),
},
NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
try
{
// Connect 1 client
using var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client1.ConnectAsync(IPAddress.Loopback, port);
var buf = new byte[4096];
await client1.ReceiveAsync(buf, SocketFlags.None); // INFO
await client1.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
await client1.ReceiveAsync(buf, SocketFlags.None, readCts.Token); // PONG
// Start lame duck (don't await yet)
var lameDuckTask = server.LameDuckShutdownAsync();
// Wait briefly for listener to close
await Task.Delay(300);
// Verify lame duck mode is active
server.IsLameDuckMode.ShouldBeTrue();
// Try connecting a new client -- should fail (connection refused)
using var client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
var connectAction = async () =>
{
await client2.ConnectAsync(IPAddress.Loopback, port);
};
await connectAction.ShouldThrowAsync<SocketException>();
// Await the lame duck task with timeout
var completed = await Task.WhenAny(lameDuckTask, Task.Delay(TimeSpan.FromSeconds(15)));
completed.ShouldBe(lameDuckTask);
}
finally
{
server.Dispose();
}
}
[Fact]
public async Task LameDuckShutdown_eventually_closes_all_clients()
{
var port = GetFreePort();
var server = new NatsServer(
new NatsOptions
{
Port = port,
LameDuckDuration = TimeSpan.FromSeconds(2),
LameDuckGracePeriod = TimeSpan.FromMilliseconds(200),
},
NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
try
{
// Connect 3 clients via raw sockets
var clients = new List<Socket>();
var buf = new byte[4096];
for (int i = 0; i < 3; i++)
{
var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(IPAddress.Loopback, port);
await sock.ReceiveAsync(buf, SocketFlags.None); // INFO
await sock.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
await sock.ReceiveAsync(buf, SocketFlags.None, readCts.Token); // PONG
clients.Add(sock);
}
server.ClientCount.ShouldBe(3);
// Await LameDuckShutdownAsync
var lameDuckTask = server.LameDuckShutdownAsync();
var completed = await Task.WhenAny(lameDuckTask, Task.Delay(TimeSpan.FromSeconds(15)));
completed.ShouldBe(lameDuckTask);
server.ClientCount.ShouldBe(0);
foreach (var sock in clients)
sock.Dispose();
}
finally
{
server.Dispose();
}
}
}
public class PidFileTests : IDisposable
{
private readonly string _tempDir = Path.Combine(Path.GetTempPath(), $"nats-test-{Guid.NewGuid():N}");
public PidFileTests() => Directory.CreateDirectory(_tempDir);
public void Dispose()
{
if (Directory.Exists(_tempDir))
Directory.Delete(_tempDir, recursive: true);
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
[Fact]
public async Task Server_writes_pid_file_on_startup()
{
var pidFile = Path.Combine(_tempDir, "nats.pid");
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port, PidFile = pidFile }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
File.Exists(pidFile).ShouldBeTrue();
var content = await File.ReadAllTextAsync(pidFile);
int.Parse(content).ShouldBe(Environment.ProcessId);
await server.ShutdownAsync();
File.Exists(pidFile).ShouldBeFalse();
server.Dispose();
}
[Fact]
public async Task Server_writes_ports_file_on_startup()
{
var port = GetFreePort();
var server = new NatsServer(new NatsOptions { Port = port, PortsFileDir = _tempDir }, NullLoggerFactory.Instance);
_ = server.StartAsync(CancellationToken.None);
await server.WaitForReadyAsync();
var portsFiles = Directory.GetFiles(_tempDir, "*.ports");
portsFiles.Length.ShouldBe(1);
var content = await File.ReadAllTextAsync(portsFiles[0]);
content.ShouldContain($"\"client\":{port}");
await server.ShutdownAsync();
Directory.GetFiles(_tempDir, "*.ports").Length.ShouldBe(0);
server.Dispose();
}
}