Compare commits

...

17 Commits

Author SHA1 Message Date
Joseph Doherty
4d89661e79 feat: add monitoring HTTP endpoints and TLS support
Monitoring HTTP:
- /varz, /connz, /healthz via Kestrel Minimal API
- Pagination, sorting, subscription details on /connz
- ServerStats atomic counters, CPU/memory sampling
- CLI flags: -m, --http_port, --http_base_path, --https_port

TLS Support:
- 4-mode negotiation: no TLS, required, TLS-first, mixed
- Certificate loading, pinning (SHA-256), client cert verification
- PeekableStream for non-destructive TLS detection
- Token-bucket rate limiter for TLS handshakes
- CLI flags: --tls, --tlscert, --tlskey, --tlscacert, --tlsverify

29 new tests (78 → 107 total), all passing.

# Conflicts:
#	src/NATS.Server.Host/Program.cs
#	src/NATS.Server/NATS.Server.csproj
#	src/NATS.Server/NatsClient.cs
#	src/NATS.Server/NatsOptions.cs
#	src/NATS.Server/NatsServer.cs
#	src/NATS.Server/Protocol/NatsProtocol.cs
#	tests/NATS.Server.Tests/ClientTests.cs
2026-02-22 23:13:22 -05:00
Joseph Doherty
3b6bd08248 feat: add TLS mixed mode tests and monitoring TLS field verification
Add TlsMixedModeTests verifying that a server with AllowNonTls=true
accepts both plaintext and TLS clients on the same port. Add
MonitorTlsTests verifying that /connz reports TlsVersion and
TlsCipherSuite for TLS-connected clients.
2026-02-22 22:40:03 -05:00
Joseph Doherty
19f35e6463 feat: add --tls, --tlscert, --tlskey, --tlscacert, --tlsverify CLI flags 2026-02-22 22:36:57 -05:00
Joseph Doherty
9eb108b1df feat: add /connz endpoint with pagination, sorting, and subscription details 2026-02-22 22:36:28 -05:00
Joseph Doherty
87746168ba feat: wire TLS negotiation into NatsServer accept loop
Integrate TLS support into the server's connection accept path:
- Add SslServerAuthenticationOptions and TlsRateLimiter fields to NatsServer
- Extract AcceptClientAsync method for TLS negotiation, rate limiting, and
  TLS state extraction (protocol version, cipher suite, peer certificate)
- Add InfoAlreadySent flag to NatsClient to skip redundant INFO when
  TlsConnectionWrapper already sent it during negotiation
- Add TlsServerTests verifying TLS connect+INFO and TLS pub/sub
2026-02-22 22:35:42 -05:00
Joseph Doherty
818bc0ba1f fix: address MonitorServer review — dispose resources, add cancellation, improve test reliability 2026-02-22 22:30:14 -05:00
Joseph Doherty
63198ef83b fix: address TlsConnectionWrapper review — clone ServerInfo, fix SslStream leak, add TLS-first test 2026-02-22 22:28:19 -05:00
Joseph Doherty
a52db677e2 fix: track HTTP request stats for all monitoring endpoints 2026-02-22 22:25:00 -05:00
Joseph Doherty
0409acc745 feat: add TlsConnectionWrapper with 4-mode TLS negotiation 2026-02-22 22:21:11 -05:00
Joseph Doherty
f2badc3488 feat: add MonitorServer with /healthz and /varz endpoints 2026-02-22 22:20:44 -05:00
Joseph Doherty
f6b38df291 feat: add TlsHelper, PeekableStream, and TlsRateLimiter
Add TLS utility classes for certificate loading, peekable stream for TLS
detection, token-bucket rate limiter for handshake throttling, and
TlsConnectionState for post-handshake info. Add TlsState property to
NatsClient. Fix X509Certificate2 constructor usage for .NET 10 compat.
2026-02-22 22:13:53 -05:00
Joseph Doherty
045c12cce7 feat: add Varz and Connz monitoring JSON models with Go field name parity 2026-02-22 22:13:50 -05:00
Joseph Doherty
b2f7b1b2a0 feat: add -m/--http_port CLI flag for monitoring 2026-02-22 22:10:07 -05:00
Joseph Doherty
a26c1359de refactor: NatsClient accepts Stream parameter for TLS support 2026-02-22 22:09:48 -05:00
Joseph Doherty
1a777e09c9 feat: add ServerStats counters and NatsClient metadata for monitoring 2026-02-22 22:08:30 -05:00
Joseph Doherty
ceaafc48d4 feat: add project setup for monitoring and TLS — csproj, config options, ServerInfo TLS fields
Add FrameworkReference to Microsoft.AspNetCore.App to enable Kestrel
Minimal APIs for the monitoring HTTP server. Remove the now-redundant
Microsoft.Extensions.Logging.Abstractions PackageReference (it is
included transitively via the framework reference).

Add monitoring config properties (MonitorPort, MonitorHost,
MonitorBasePath, MonitorHttpsPort) and TLS config properties (TlsCert,
TlsKey, TlsCaCert, TlsVerify, TlsHandshakeFirst, etc.) to NatsOptions.

Add TlsRequired, TlsVerify, and TlsAvailable fields to ServerInfo so
the server can advertise TLS capability in the INFO protocol message.
2026-02-22 22:00:42 -05:00
Joseph Doherty
d08ce7f6fb chore: add .worktrees/ to .gitignore for isolated development 2026-02-22 21:54:26 -05:00
24 changed files with 2596 additions and 43 deletions

View File

@@ -32,6 +32,20 @@ for (int i = 0; i < args.Length; i++)
case "--https_port" when i + 1 < args.Length:
options.MonitorHttpsPort = int.Parse(args[++i]);
break;
case "--tls":
break;
case "--tlscert" when i + 1 < args.Length:
options.TlsCert = args[++i];
break;
case "--tlskey" when i + 1 < args.Length:
options.TlsKey = args[++i];
break;
case "--tlscacert" when i + 1 < args.Length:
options.TlsCaCert = args[++i];
break;
case "--tlsverify":
options.TlsVerify = true;
break;
}
}

View File

@@ -0,0 +1,207 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Monitoring;
/// <summary>
/// Connection information response. Corresponds to Go server/monitor.go Connz struct.
/// </summary>
public sealed class Connz
{
[JsonPropertyName("server_id")]
public string Id { get; set; } = "";
[JsonPropertyName("now")]
public DateTime Now { get; set; }
[JsonPropertyName("num_connections")]
public int NumConns { get; set; }
[JsonPropertyName("total")]
public int Total { get; set; }
[JsonPropertyName("offset")]
public int Offset { get; set; }
[JsonPropertyName("limit")]
public int Limit { get; set; }
[JsonPropertyName("connections")]
public ConnInfo[] Conns { get; set; } = [];
}
/// <summary>
/// Detailed information on a per-connection basis.
/// Corresponds to Go server/monitor.go ConnInfo struct.
/// </summary>
public sealed class ConnInfo
{
[JsonPropertyName("cid")]
public ulong Cid { get; set; }
[JsonPropertyName("kind")]
public string Kind { get; set; } = "";
[JsonPropertyName("type")]
public string Type { get; set; } = "";
[JsonPropertyName("ip")]
public string Ip { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("start")]
public DateTime Start { get; set; }
[JsonPropertyName("last_activity")]
public DateTime LastActivity { get; set; }
[JsonPropertyName("stop")]
public DateTime? Stop { get; set; }
[JsonPropertyName("reason")]
public string Reason { get; set; } = "";
[JsonPropertyName("rtt")]
public string Rtt { get; set; } = "";
[JsonPropertyName("uptime")]
public string Uptime { get; set; } = "";
[JsonPropertyName("idle")]
public string Idle { get; set; } = "";
[JsonPropertyName("pending_bytes")]
public int Pending { get; set; }
[JsonPropertyName("in_msgs")]
public long InMsgs { get; set; }
[JsonPropertyName("out_msgs")]
public long OutMsgs { get; set; }
[JsonPropertyName("in_bytes")]
public long InBytes { get; set; }
[JsonPropertyName("out_bytes")]
public long OutBytes { get; set; }
[JsonPropertyName("subscriptions")]
public uint NumSubs { get; set; }
[JsonPropertyName("subscriptions_list")]
public string[] Subs { get; set; } = [];
[JsonPropertyName("subscriptions_list_detail")]
public SubDetail[] SubsDetail { get; set; } = [];
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("lang")]
public string Lang { get; set; } = "";
[JsonPropertyName("version")]
public string Version { get; set; } = "";
[JsonPropertyName("authorized_user")]
public string AuthorizedUser { get; set; } = "";
[JsonPropertyName("account")]
public string Account { get; set; } = "";
[JsonPropertyName("tls_version")]
public string TlsVersion { get; set; } = "";
[JsonPropertyName("tls_cipher_suite")]
public string TlsCipherSuite { get; set; } = "";
[JsonPropertyName("tls_first")]
public bool TlsFirst { get; set; }
[JsonPropertyName("mqtt_client")]
public string MqttClient { get; set; } = "";
}
/// <summary>
/// Subscription detail information.
/// Corresponds to Go server/monitor.go SubDetail struct.
/// </summary>
public sealed class SubDetail
{
[JsonPropertyName("account")]
public string Account { get; set; } = "";
[JsonPropertyName("subject")]
public string Subject { get; set; } = "";
[JsonPropertyName("qgroup")]
public string Queue { get; set; } = "";
[JsonPropertyName("sid")]
public string Sid { get; set; } = "";
[JsonPropertyName("msgs")]
public long Msgs { get; set; }
[JsonPropertyName("max")]
public long Max { get; set; }
[JsonPropertyName("cid")]
public ulong Cid { get; set; }
}
/// <summary>
/// Sort options for connection listing.
/// Corresponds to Go server/monitor_sort_opts.go SortOpt type.
/// </summary>
public enum SortOpt
{
ByCid,
ByStart,
BySubs,
ByPending,
ByMsgsTo,
ByMsgsFrom,
ByBytesTo,
ByBytesFrom,
ByLast,
ByIdle,
ByUptime,
}
/// <summary>
/// Connection state filter.
/// Corresponds to Go server/monitor.go ConnState type.
/// </summary>
public enum ConnState
{
Open,
Closed,
All,
}
/// <summary>
/// Options passed to Connz() for filtering and sorting.
/// Corresponds to Go server/monitor.go ConnzOptions struct.
/// </summary>
public sealed class ConnzOptions
{
public SortOpt Sort { get; set; } = SortOpt.ByCid;
public bool Subscriptions { get; set; }
public bool SubscriptionsDetail { get; set; }
public ConnState State { get; set; } = ConnState.Open;
public string User { get; set; } = "";
public string Account { get; set; } = "";
public string FilterSubject { get; set; } = "";
public int Offset { get; set; }
public int Limit { get; set; } = 1024;
}

View File

@@ -0,0 +1,148 @@
using Microsoft.AspNetCore.Http;
namespace NATS.Server.Monitoring;
/// <summary>
/// Handles /connz endpoint requests, returning detailed connection information.
/// Corresponds to Go server/monitor.go handleConnz function.
/// </summary>
public sealed class ConnzHandler(NatsServer server)
{
public Connz HandleConnz(HttpContext ctx)
{
var opts = ParseQueryParams(ctx);
var now = DateTime.UtcNow;
var clients = server.GetClients().ToArray();
var connInfos = clients.Select(c => BuildConnInfo(c, now, opts)).ToList();
// Sort
connInfos = opts.Sort switch
{
SortOpt.ByCid => connInfos.OrderBy(c => c.Cid).ToList(),
SortOpt.ByStart => connInfos.OrderBy(c => c.Start).ToList(),
SortOpt.BySubs => connInfos.OrderByDescending(c => c.NumSubs).ToList(),
SortOpt.ByPending => connInfos.OrderByDescending(c => c.Pending).ToList(),
SortOpt.ByMsgsTo => connInfos.OrderByDescending(c => c.OutMsgs).ToList(),
SortOpt.ByMsgsFrom => connInfos.OrderByDescending(c => c.InMsgs).ToList(),
SortOpt.ByBytesTo => connInfos.OrderByDescending(c => c.OutBytes).ToList(),
SortOpt.ByBytesFrom => connInfos.OrderByDescending(c => c.InBytes).ToList(),
SortOpt.ByLast => connInfos.OrderByDescending(c => c.LastActivity).ToList(),
SortOpt.ByIdle => connInfos.OrderByDescending(c => now - c.LastActivity).ToList(),
SortOpt.ByUptime => connInfos.OrderByDescending(c => now - c.Start).ToList(),
_ => connInfos.OrderBy(c => c.Cid).ToList(),
};
var total = connInfos.Count;
var paged = connInfos.Skip(opts.Offset).Take(opts.Limit).ToArray();
return new Connz
{
Id = server.ServerId,
Now = now,
NumConns = paged.Length,
Total = total,
Offset = opts.Offset,
Limit = opts.Limit,
Conns = paged,
};
}
private static ConnInfo BuildConnInfo(NatsClient client, DateTime now, ConnzOptions opts)
{
var info = new ConnInfo
{
Cid = client.Id,
Kind = "Client",
Type = "Client",
Ip = client.RemoteIp ?? "",
Port = client.RemotePort,
Start = client.StartTime,
LastActivity = client.LastActivity,
Uptime = FormatDuration(now - client.StartTime),
Idle = FormatDuration(now - client.LastActivity),
InMsgs = Interlocked.Read(ref client.InMsgs),
OutMsgs = Interlocked.Read(ref client.OutMsgs),
InBytes = Interlocked.Read(ref client.InBytes),
OutBytes = Interlocked.Read(ref client.OutBytes),
NumSubs = (uint)client.Subscriptions.Count,
Name = client.ClientOpts?.Name ?? "",
Lang = client.ClientOpts?.Lang ?? "",
Version = client.ClientOpts?.Version ?? "",
TlsVersion = client.TlsState?.TlsVersion ?? "",
TlsCipherSuite = client.TlsState?.CipherSuite ?? "",
};
if (opts.Subscriptions)
{
info.Subs = client.Subscriptions.Values.Select(s => s.Subject).ToArray();
}
if (opts.SubscriptionsDetail)
{
info.SubsDetail = client.Subscriptions.Values.Select(s => new SubDetail
{
Subject = s.Subject,
Queue = s.Queue ?? "",
Sid = s.Sid,
Msgs = Interlocked.Read(ref s.MessageCount),
Max = s.MaxMessages,
Cid = client.Id,
}).ToArray();
}
return info;
}
private static ConnzOptions ParseQueryParams(HttpContext ctx)
{
var q = ctx.Request.Query;
var opts = new ConnzOptions();
if (q.TryGetValue("sort", out var sort))
{
opts.Sort = sort.ToString().ToLowerInvariant() switch
{
"cid" => SortOpt.ByCid,
"start" => SortOpt.ByStart,
"subs" => SortOpt.BySubs,
"pending" => SortOpt.ByPending,
"msgs_to" => SortOpt.ByMsgsTo,
"msgs_from" => SortOpt.ByMsgsFrom,
"bytes_to" => SortOpt.ByBytesTo,
"bytes_from" => SortOpt.ByBytesFrom,
"last" => SortOpt.ByLast,
"idle" => SortOpt.ByIdle,
"uptime" => SortOpt.ByUptime,
_ => SortOpt.ByCid,
};
}
if (q.TryGetValue("subs", out var subs))
{
if (subs == "detail")
opts.SubscriptionsDetail = true;
else
opts.Subscriptions = true;
}
if (q.TryGetValue("offset", out var offset) && int.TryParse(offset, out var o))
opts.Offset = o;
if (q.TryGetValue("limit", out var limit) && int.TryParse(limit, out var l))
opts.Limit = l;
return opts;
}
private static string FormatDuration(TimeSpan ts)
{
if (ts.TotalDays >= 1)
return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalHours >= 1)
return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalMinutes >= 1)
return $"{(int)ts.TotalMinutes}m{ts.Seconds}s";
return $"{(int)ts.TotalSeconds}s";
}
}

View File

@@ -0,0 +1,117 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace NATS.Server.Monitoring;
/// <summary>
/// HTTP monitoring server providing /healthz, /varz, and other monitoring endpoints.
/// Corresponds to Go server/monitor.go HTTP server setup.
/// </summary>
public sealed class MonitorServer : IAsyncDisposable
{
private readonly WebApplication _app;
private readonly ILogger<MonitorServer> _logger;
private readonly VarzHandler _varzHandler;
private readonly ConnzHandler _connzHandler;
public MonitorServer(NatsServer server, NatsOptions options, ServerStats stats, ILoggerFactory loggerFactory)
{
_logger = loggerFactory.CreateLogger<MonitorServer>();
var builder = WebApplication.CreateSlimBuilder();
builder.WebHost.UseUrls($"http://{options.MonitorHost}:{options.MonitorPort}");
builder.Logging.ClearProviders();
_app = builder.Build();
var basePath = options.MonitorBasePath ?? "";
_varzHandler = new VarzHandler(server, options);
_connzHandler = new ConnzHandler(server);
_app.MapGet(basePath + "/", () =>
{
stats.HttpReqStats.AddOrUpdate("/", 1, (_, v) => v + 1);
return Results.Ok(new
{
endpoints = new[]
{
"/varz", "/connz", "/healthz", "/routez",
"/gatewayz", "/leafz", "/subz", "/accountz", "/jsz",
},
});
});
_app.MapGet(basePath + "/healthz", () =>
{
stats.HttpReqStats.AddOrUpdate("/healthz", 1, (_, v) => v + 1);
return Results.Ok("ok");
});
_app.MapGet(basePath + "/varz", async (HttpContext ctx) =>
{
stats.HttpReqStats.AddOrUpdate("/varz", 1, (_, v) => v + 1);
return Results.Ok(await _varzHandler.HandleVarzAsync(ctx.RequestAborted));
});
_app.MapGet(basePath + "/connz", (HttpContext ctx) =>
{
stats.HttpReqStats.AddOrUpdate("/connz", 1, (_, v) => v + 1);
return Results.Ok(_connzHandler.HandleConnz(ctx));
});
// Stubs for unimplemented endpoints
_app.MapGet(basePath + "/routez", () =>
{
stats.HttpReqStats.AddOrUpdate("/routez", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/gatewayz", () =>
{
stats.HttpReqStats.AddOrUpdate("/gatewayz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/leafz", () =>
{
stats.HttpReqStats.AddOrUpdate("/leafz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/subz", () =>
{
stats.HttpReqStats.AddOrUpdate("/subz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/subscriptionsz", () =>
{
stats.HttpReqStats.AddOrUpdate("/subscriptionsz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/accountz", () =>
{
stats.HttpReqStats.AddOrUpdate("/accountz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/accstatz", () =>
{
stats.HttpReqStats.AddOrUpdate("/accstatz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/jsz", () =>
{
stats.HttpReqStats.AddOrUpdate("/jsz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
}
public async Task StartAsync(CancellationToken ct)
{
await _app.StartAsync(ct);
_logger.LogInformation("Monitoring listening on {Urls}", string.Join(", ", _app.Urls));
}
public async ValueTask DisposeAsync()
{
await _app.StopAsync();
await _app.DisposeAsync();
_varzHandler.Dispose();
}
}

View File

@@ -0,0 +1,415 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Monitoring;
/// <summary>
/// Server general information. Corresponds to Go server/monitor.go Varz struct.
/// </summary>
public sealed class Varz
{
// Identity
[JsonPropertyName("server_id")]
public string Id { get; set; } = "";
[JsonPropertyName("server_name")]
public string Name { get; set; } = "";
[JsonPropertyName("version")]
public string Version { get; set; } = "";
[JsonPropertyName("proto")]
public int Proto { get; set; }
[JsonPropertyName("git_commit")]
public string GitCommit { get; set; } = "";
[JsonPropertyName("go")]
public string GoVersion { get; set; } = "";
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
// Network
[JsonPropertyName("ip")]
public string Ip { get; set; } = "";
[JsonPropertyName("connect_urls")]
public string[] ConnectUrls { get; set; } = [];
[JsonPropertyName("ws_connect_urls")]
public string[] WsConnectUrls { get; set; } = [];
[JsonPropertyName("http_host")]
public string HttpHost { get; set; } = "";
[JsonPropertyName("http_port")]
public int HttpPort { get; set; }
[JsonPropertyName("http_base_path")]
public string HttpBasePath { get; set; } = "";
[JsonPropertyName("https_port")]
public int HttpsPort { get; set; }
// Security
[JsonPropertyName("auth_required")]
public bool AuthRequired { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("tls_ocsp_peer_verify")]
public bool TlsOcspPeerVerify { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
// Limits
[JsonPropertyName("max_connections")]
public int MaxConnections { get; set; }
[JsonPropertyName("max_subscriptions")]
public int MaxSubscriptions { get; set; }
[JsonPropertyName("max_payload")]
public int MaxPayload { get; set; }
[JsonPropertyName("max_pending")]
public long MaxPending { get; set; }
[JsonPropertyName("max_control_line")]
public int MaxControlLine { get; set; }
[JsonPropertyName("ping_max")]
public int MaxPingsOut { get; set; }
// Timing
[JsonPropertyName("ping_interval")]
public long PingInterval { get; set; }
[JsonPropertyName("write_deadline")]
public long WriteDeadline { get; set; }
[JsonPropertyName("start")]
public DateTime Start { get; set; }
[JsonPropertyName("now")]
public DateTime Now { get; set; }
[JsonPropertyName("uptime")]
public string Uptime { get; set; } = "";
// Runtime
[JsonPropertyName("mem")]
public long Mem { get; set; }
[JsonPropertyName("cpu")]
public double Cpu { get; set; }
[JsonPropertyName("cores")]
public int Cores { get; set; }
[JsonPropertyName("gomaxprocs")]
public int MaxProcs { get; set; }
// Connections
[JsonPropertyName("connections")]
public int Connections { get; set; }
[JsonPropertyName("total_connections")]
public ulong TotalConnections { get; set; }
[JsonPropertyName("routes")]
public int Routes { get; set; }
[JsonPropertyName("remotes")]
public int Remotes { get; set; }
[JsonPropertyName("leafnodes")]
public int Leafnodes { get; set; }
// Messages
[JsonPropertyName("in_msgs")]
public long InMsgs { get; set; }
[JsonPropertyName("out_msgs")]
public long OutMsgs { get; set; }
[JsonPropertyName("in_bytes")]
public long InBytes { get; set; }
[JsonPropertyName("out_bytes")]
public long OutBytes { get; set; }
// Health
[JsonPropertyName("slow_consumers")]
public long SlowConsumers { get; set; }
[JsonPropertyName("slow_consumer_stats")]
public SlowConsumersStats SlowConsumerStats { get; set; } = new();
[JsonPropertyName("subscriptions")]
public uint Subscriptions { get; set; }
// Config
[JsonPropertyName("config_load_time")]
public DateTime ConfigLoadTime { get; set; }
[JsonPropertyName("tags")]
public string[] Tags { get; set; } = [];
[JsonPropertyName("system_account")]
public string SystemAccount { get; set; } = "";
[JsonPropertyName("pinned_account_fails")]
public ulong PinnedAccountFail { get; set; }
[JsonPropertyName("tls_cert_not_after")]
public DateTime TlsCertNotAfter { get; set; }
// HTTP
[JsonPropertyName("http_req_stats")]
public Dictionary<string, ulong> HttpReqStats { get; set; } = new();
// Subsystems
[JsonPropertyName("cluster")]
public ClusterOptsVarz Cluster { get; set; } = new();
[JsonPropertyName("gateway")]
public GatewayOptsVarz Gateway { get; set; } = new();
[JsonPropertyName("leaf")]
public LeafNodeOptsVarz Leaf { get; set; } = new();
[JsonPropertyName("mqtt")]
public MqttOptsVarz Mqtt { get; set; } = new();
[JsonPropertyName("websocket")]
public WebsocketOptsVarz Websocket { get; set; } = new();
[JsonPropertyName("jetstream")]
public JetStreamVarz JetStream { get; set; } = new();
}
/// <summary>
/// Statistics about slow consumers by connection type.
/// Corresponds to Go server/monitor.go SlowConsumersStats struct.
/// </summary>
public sealed class SlowConsumersStats
{
[JsonPropertyName("clients")]
public ulong Clients { get; set; }
[JsonPropertyName("routes")]
public ulong Routes { get; set; }
[JsonPropertyName("gateways")]
public ulong Gateways { get; set; }
[JsonPropertyName("leafs")]
public ulong Leafs { get; set; }
}
/// <summary>
/// Cluster configuration monitoring information.
/// Corresponds to Go server/monitor.go ClusterOptsVarz struct.
/// </summary>
public sealed class ClusterOptsVarz
{
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("addr")]
public string Host { get; set; } = "";
[JsonPropertyName("cluster_port")]
public int Port { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("pool_size")]
public int PoolSize { get; set; }
[JsonPropertyName("urls")]
public string[] Urls { get; set; } = [];
}
/// <summary>
/// Gateway configuration monitoring information.
/// Corresponds to Go server/monitor.go GatewayOptsVarz struct.
/// </summary>
public sealed class GatewayOptsVarz
{
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("advertise")]
public string Advertise { get; set; } = "";
[JsonPropertyName("connect_retries")]
public int ConnectRetries { get; set; }
[JsonPropertyName("reject_unknown")]
public bool RejectUnknown { get; set; }
}
/// <summary>
/// Leaf node configuration monitoring information.
/// Corresponds to Go server/monitor.go LeafNodeOptsVarz struct.
/// </summary>
public sealed class LeafNodeOptsVarz
{
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("tls_ocsp_peer_verify")]
public bool TlsOcspPeerVerify { get; set; }
}
/// <summary>
/// MQTT configuration monitoring information.
/// Corresponds to Go server/monitor.go MQTTOptsVarz struct.
/// </summary>
public sealed class MqttOptsVarz
{
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
}
/// <summary>
/// Websocket configuration monitoring information.
/// Corresponds to Go server/monitor.go WebsocketOptsVarz struct.
/// </summary>
public sealed class WebsocketOptsVarz
{
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
}
/// <summary>
/// JetStream runtime information.
/// Corresponds to Go server/monitor.go JetStreamVarz struct.
/// </summary>
public sealed class JetStreamVarz
{
[JsonPropertyName("config")]
public JetStreamConfig Config { get; set; } = new();
[JsonPropertyName("stats")]
public JetStreamStats Stats { get; set; } = new();
}
/// <summary>
/// JetStream configuration.
/// Corresponds to Go server/jetstream.go JetStreamConfig struct.
/// </summary>
public sealed class JetStreamConfig
{
[JsonPropertyName("max_memory")]
public long MaxMemory { get; set; }
[JsonPropertyName("max_storage")]
public long MaxStorage { get; set; }
[JsonPropertyName("store_dir")]
public string StoreDir { get; set; } = "";
}
/// <summary>
/// JetStream statistics.
/// Corresponds to Go server/jetstream.go JetStreamStats struct.
/// </summary>
public sealed class JetStreamStats
{
[JsonPropertyName("memory")]
public ulong Memory { get; set; }
[JsonPropertyName("storage")]
public ulong Storage { get; set; }
[JsonPropertyName("accounts")]
public int Accounts { get; set; }
[JsonPropertyName("ha_assets")]
public int HaAssets { get; set; }
[JsonPropertyName("api")]
public JetStreamApiStats Api { get; set; } = new();
}
/// <summary>
/// JetStream API statistics.
/// Corresponds to Go server/jetstream.go JetStreamAPIStats struct.
/// </summary>
public sealed class JetStreamApiStats
{
[JsonPropertyName("total")]
public ulong Total { get; set; }
[JsonPropertyName("errors")]
public ulong Errors { get; set; }
}

View File

@@ -0,0 +1,121 @@
using System.Diagnostics;
using System.Runtime.InteropServices;
using NATS.Server.Protocol;
namespace NATS.Server.Monitoring;
/// <summary>
/// Handles building the Varz response from server state and process metrics.
/// Corresponds to Go server/monitor.go handleVarz function.
/// </summary>
public sealed class VarzHandler : IDisposable
{
private readonly NatsServer _server;
private readonly NatsOptions _options;
private readonly SemaphoreSlim _varzMu = new(1, 1);
private DateTime _lastCpuSampleTime;
private TimeSpan _lastCpuUsage;
private double _cachedCpuPercent;
public VarzHandler(NatsServer server, NatsOptions options)
{
_server = server;
_options = options;
using var proc = Process.GetCurrentProcess();
_lastCpuSampleTime = DateTime.UtcNow;
_lastCpuUsage = proc.TotalProcessorTime;
}
public async Task<Varz> HandleVarzAsync(CancellationToken ct = default)
{
await _varzMu.WaitAsync(ct);
try
{
using var proc = Process.GetCurrentProcess();
var now = DateTime.UtcNow;
var uptime = now - _server.StartTime;
var stats = _server.Stats;
// CPU sampling with 1-second cache to avoid excessive sampling
if ((now - _lastCpuSampleTime).TotalSeconds >= 1.0)
{
var currentCpu = proc.TotalProcessorTime;
var elapsed = now - _lastCpuSampleTime;
_cachedCpuPercent = (currentCpu - _lastCpuUsage).TotalMilliseconds
/ elapsed.TotalMilliseconds / Environment.ProcessorCount * 100.0;
_lastCpuSampleTime = now;
_lastCpuUsage = currentCpu;
}
return new Varz
{
Id = _server.ServerId,
Name = _server.ServerName,
Version = NatsProtocol.Version,
Proto = NatsProtocol.ProtoVersion,
GoVersion = $"dotnet {RuntimeInformation.FrameworkDescription}",
Host = _options.Host,
Port = _options.Port,
HttpHost = _options.MonitorHost,
HttpPort = _options.MonitorPort,
HttpBasePath = _options.MonitorBasePath ?? "",
HttpsPort = _options.MonitorHttpsPort,
TlsRequired = _options.HasTls && !_options.AllowNonTls,
TlsVerify = _options.HasTls && _options.TlsVerify,
TlsTimeout = _options.HasTls ? _options.TlsTimeout.TotalSeconds : 0,
MaxConnections = _options.MaxConnections,
MaxPayload = _options.MaxPayload,
MaxControlLine = _options.MaxControlLine,
MaxPingsOut = _options.MaxPingsOut,
PingInterval = (long)_options.PingInterval.TotalNanoseconds,
Start = _server.StartTime,
Now = now,
Uptime = FormatUptime(uptime),
Mem = proc.WorkingSet64,
Cpu = Math.Round(_cachedCpuPercent, 2),
Cores = Environment.ProcessorCount,
MaxProcs = ThreadPool.ThreadCount,
Connections = _server.ClientCount,
TotalConnections = (ulong)Interlocked.Read(ref stats.TotalConnections),
InMsgs = Interlocked.Read(ref stats.InMsgs),
OutMsgs = Interlocked.Read(ref stats.OutMsgs),
InBytes = Interlocked.Read(ref stats.InBytes),
OutBytes = Interlocked.Read(ref stats.OutBytes),
SlowConsumers = Interlocked.Read(ref stats.SlowConsumers),
SlowConsumerStats = new SlowConsumersStats
{
Clients = (ulong)Interlocked.Read(ref stats.SlowConsumerClients),
Routes = (ulong)Interlocked.Read(ref stats.SlowConsumerRoutes),
Gateways = (ulong)Interlocked.Read(ref stats.SlowConsumerGateways),
Leafs = (ulong)Interlocked.Read(ref stats.SlowConsumerLeafs),
},
Subscriptions = _server.SubList.Count,
ConfigLoadTime = _server.StartTime,
HttpReqStats = stats.HttpReqStats.ToDictionary(kv => kv.Key, kv => (ulong)kv.Value),
};
}
finally
{
_varzMu.Release();
}
}
public void Dispose()
{
_varzMu.Dispose();
}
/// <summary>
/// Formats a TimeSpan as a human-readable uptime string matching Go server format.
/// </summary>
private static string FormatUptime(TimeSpan ts)
{
if (ts.TotalDays >= 1)
return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalHours >= 1)
return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalMinutes >= 1)
return $"{(int)ts.TotalMinutes}m{ts.Seconds}s";
return $"{(int)ts.TotalSeconds}s";
}
}

View File

@@ -1,6 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="NATS.NKeys" />
<PackageReference Include="BCrypt.Net-Next" />
</ItemGroup>

View File

@@ -1,5 +1,6 @@
using System.Buffers;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
@@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
namespace NATS.Server;
@@ -26,7 +28,7 @@ public interface ISubListAccess
public sealed class NatsClient : IDisposable
{
private readonly Socket _socket;
private readonly NetworkStream _stream;
private readonly Stream _stream;
private readonly NatsOptions _options;
private readonly ServerInfo _serverInfo;
private readonly AuthService _authService;
@@ -37,6 +39,7 @@ public sealed class NatsClient : IDisposable
private readonly Dictionary<string, Subscription> _subs = new();
private readonly ILogger _logger;
private ClientPermissions? _permissions;
private readonly ServerStats _serverStats;
public ulong Id { get; }
public ClientOptions? ClientOpts { get; private set; }
@@ -47,6 +50,12 @@ public sealed class NatsClient : IDisposable
private int _connectReceived;
public bool ConnectReceived => Volatile.Read(ref _connectReceived) != 0;
public DateTime StartTime { get; }
private long _lastActivityTicks;
public DateTime LastActivity => new(Interlocked.Read(ref _lastActivityTicks), DateTimeKind.Utc);
public string? RemoteIp { get; }
public int RemotePort { get; }
// Stats
public long InMsgs;
public long OutMsgs;
@@ -57,20 +66,31 @@ public sealed class NatsClient : IDisposable
private int _pingsOut;
private long _lastIn;
public TlsConnectionState? TlsState { get; set; }
public bool InfoAlreadySent { get; set; }
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo,
AuthService authService, byte[]? nonce, ILogger logger)
public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,
AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats)
{
Id = id;
_socket = socket;
_stream = new NetworkStream(socket, ownsSocket: false);
_stream = stream;
_options = options;
_serverInfo = serverInfo;
_authService = authService;
_nonce = nonce;
_logger = logger;
_serverStats = serverStats;
_parser = new NatsParser(options.MaxPayload);
StartTime = DateTime.UtcNow;
_lastActivityTicks = StartTime.Ticks;
if (socket.RemoteEndPoint is IPEndPoint ep)
{
RemoteIp = ep.Address.ToString();
RemotePort = ep.Port;
}
}
public async Task RunAsync(CancellationToken ct)
@@ -80,8 +100,9 @@ public sealed class NatsClient : IDisposable
var pipe = new Pipe();
try
{
// Send INFO
await SendInfoAsync(_clientCts.Token);
// Send INFO (skip if already sent during TLS negotiation)
if (!InfoAlreadySent)
await SendInfoAsync(_clientCts.Token);
// Start auth timeout if auth is required
Task? authTimeoutTask = null;
@@ -100,7 +121,7 @@ public sealed class NatsClient : IDisposable
}
catch (OperationCanceledException)
{
// Normal client connected or was cancelled
// Normal -- client connected or was cancelled
}
}, _clientCts.Token);
}
@@ -184,6 +205,8 @@ public sealed class NatsClient : IDisposable
private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
{
Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
// If auth is required and CONNECT hasn't been received yet,
// only allow CONNECT and PING commands
if (_authService.IsAuthRequired && !ConnectReceived)
@@ -266,7 +289,7 @@ public sealed class NatsClient : IDisposable
_logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity);
// Clear nonce after use defense-in-depth against memory dumps
// Clear nonce after use -- defense-in-depth against memory dumps
if (_nonce != null)
CryptographicOperations.ZeroMemory(_nonce);
}
@@ -330,6 +353,8 @@ public sealed class NatsClient : IDisposable
{
Interlocked.Increment(ref InMsgs);
Interlocked.Add(ref InBytes, cmd.Payload.Length);
Interlocked.Increment(ref _serverStats.InMsgs);
Interlocked.Add(ref _serverStats.InBytes, cmd.Payload.Length);
// Max payload validation (always, hard close)
if (cmd.Payload.Length > _options.MaxPayload)
@@ -380,6 +405,8 @@ public sealed class NatsClient : IDisposable
{
Interlocked.Increment(ref OutMsgs);
Interlocked.Add(ref OutBytes, payload.Length + headers.Length);
Interlocked.Increment(ref _serverStats.OutMsgs);
Interlocked.Add(ref _serverStats.OutBytes, payload.Length + headers.Length);
byte[] line;
if (headers.Length > 0)
@@ -470,7 +497,7 @@ public sealed class NatsClient : IDisposable
if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut)
{
_logger.LogDebug("Client {ClientId} stale connection closing", Id);
_logger.LogDebug("Client {ClientId} stale connection -- closing", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection);
return;
}

View File

@@ -1,3 +1,4 @@
using System.Security.Authentication;
using NATS.Server.Auth;
namespace NATS.Server;
@@ -7,7 +8,7 @@ public sealed class NatsOptions
public string Host { get; set; } = "0.0.0.0";
public int Port { get; set; } = 4222;
public string? ServerName { get; set; }
public int MaxPayload { get; set; } = 1024 * 1024; // 1MB
public int MaxPayload { get; set; } = 1024 * 1024;
public int MaxControlLine { get; set; } = 4096;
public int MaxConnections { get; set; } = 65536;
public TimeSpan PingInterval { get; set; } = TimeSpan.FromMinutes(2);
@@ -27,4 +28,27 @@ public sealed class NatsOptions
// Auth timing
public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
// Monitoring (0 = disabled; standard port is 8222)
public int MonitorPort { get; set; }
public string MonitorHost { get; set; } = "0.0.0.0";
public string? MonitorBasePath { get; set; }
// 0 = disabled
public int MonitorHttpsPort { get; set; }
// TLS
public string? TlsCert { get; set; }
public string? TlsKey { get; set; }
public string? TlsCaCert { get; set; }
public bool TlsVerify { get; set; }
public bool TlsMap { get; set; }
public TimeSpan TlsTimeout { get; set; } = TimeSpan.FromSeconds(2);
public bool TlsHandshakeFirst { get; set; }
public TimeSpan TlsHandshakeFirstFallback { get; set; } = TimeSpan.FromMilliseconds(50);
public bool AllowNonTls { get; set; }
public long TlsRateLimit { get; set; }
public HashSet<string>? TlsPinnedCerts { get; set; }
public SslProtocols TlsMinVersion { get; set; } = SslProtocols.Tls12;
public bool HasTls => TlsCert != null && TlsKey != null;
}

View File

@@ -1,11 +1,15 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Monitoring;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
namespace NATS.Server;
@@ -16,14 +20,25 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private readonly ServerInfo _serverInfo;
private readonly ILogger<NatsServer> _logger;
private readonly ILoggerFactory _loggerFactory;
private readonly ServerStats _stats = new();
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly AuthService _authService;
private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
private readonly Account _globalAccount;
private readonly SslServerAuthenticationOptions? _sslOptions;
private readonly TlsRateLimiter? _tlsRateLimiter;
private Socket? _listener;
private MonitorServer? _monitorServer;
private ulong _nextClientId;
private long _startTimeTicks;
public SubList SubList => _globalAccount.SubList;
public ServerStats Stats => _stats;
public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc);
public string ServerId => _serverInfo.ServerId;
public string ServerName => _serverInfo.ServerName;
public int ClientCount => _clients.Count;
public IEnumerable<NatsClient> GetClients() => _clients.Values;
public Task WaitForReadyAsync() => _listeningStarted.Task;
@@ -45,6 +60,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
MaxPayload = options.MaxPayload,
AuthRequired = _authService.IsAuthRequired,
};
if (options.HasTls)
{
_sslOptions = TlsHelper.BuildServerAuthOptions(options);
_serverInfo.TlsRequired = !options.AllowNonTls;
_serverInfo.TlsAvailable = options.AllowNonTls;
_serverInfo.TlsVerify = options.TlsVerify;
if (options.TlsRateLimit > 0)
_tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit);
}
}
public async Task StartAsync(CancellationToken ct)
@@ -54,11 +80,18 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_listener.Bind(new IPEndPoint(
_options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host),
_options.Port));
Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks);
_listener.Listen(128);
_listeningStarted.TrySetResult();
_logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port);
if (_options.MonitorPort > 0)
{
_monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory);
await _monitorServer.StartAsync(ct);
}
try
{
while (!ct.IsCancellationRequested)
@@ -91,37 +124,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
var clientId = Interlocked.Increment(ref _nextClientId);
Interlocked.Increment(ref _stats.TotalConnections);
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
// Build per-client ServerInfo with nonce if NKey auth is configured
byte[]? nonce = null;
var clientInfo = _serverInfo;
if (_authService.NonceRequired)
{
var rawNonce = _authService.GenerateNonce();
var nonceStr = _authService.EncodeNonce(rawNonce);
// The client signs the nonce string (ASCII), not the raw bytes
nonce = Encoding.ASCII.GetBytes(nonceStr);
clientInfo = new ServerInfo
{
ServerId = _serverInfo.ServerId,
ServerName = _serverInfo.ServerName,
Version = _serverInfo.Version,
Host = _serverInfo.Host,
Port = _serverInfo.Port,
MaxPayload = _serverInfo.MaxPayload,
AuthRequired = _serverInfo.AuthRequired,
Nonce = nonceStr,
};
}
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
client.Router = this;
_clients[clientId] = client;
_ = RunClientAsync(client, ct);
_ = AcceptClientAsync(socket, clientId, ct);
}
}
catch (OperationCanceledException)
@@ -130,6 +137,74 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct)
{
try
{
// Rate limit TLS handshakes
if (_tlsRateLimiter != null)
await _tlsRateLimiter.WaitAsync(ct);
var networkStream = new NetworkStream(socket, ownsSocket: false);
// TLS negotiation (no-op if not configured)
var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync(
socket, networkStream, _options, _sslOptions, _serverInfo,
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
// Extract TLS state
TlsConnectionState? tlsState = null;
if (stream is SslStream ssl)
{
tlsState = new TlsConnectionState(
ssl.SslProtocol.ToString(),
ssl.NegotiatedCipherSuite.ToString(),
ssl.RemoteCertificate as X509Certificate2);
}
// Build per-client ServerInfo with nonce if NKey auth is configured
byte[]? nonce = null;
var clientInfo = _serverInfo;
if (_authService.NonceRequired)
{
var rawNonce = _authService.GenerateNonce();
var nonceStr = _authService.EncodeNonce(rawNonce);
// The client signs the nonce string (ASCII), not the raw bytes
nonce = Encoding.ASCII.GetBytes(nonceStr);
clientInfo = new ServerInfo
{
ServerId = _serverInfo.ServerId,
ServerName = _serverInfo.ServerName,
Version = _serverInfo.Version,
Host = _serverInfo.Host,
Port = _serverInfo.Port,
MaxPayload = _serverInfo.MaxPayload,
AuthRequired = _serverInfo.AuthRequired,
TlsRequired = _serverInfo.TlsRequired,
TlsAvailable = _serverInfo.TlsAvailable,
TlsVerify = _serverInfo.TlsVerify,
Nonce = nonceStr,
};
}
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, stream, socket, _options, clientInfo,
_authService, nonce, clientLogger, _stats);
client.Router = this;
client.TlsState = tlsState;
client.InfoAlreadySent = infoAlreadySent;
_clients[clientId] = client;
await RunClientAsync(client, ct);
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId);
try { socket.Shutdown(SocketShutdown.Both); } catch { }
socket.Dispose();
}
}
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
{
try
@@ -215,6 +290,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
public void Dispose()
{
if (_monitorServer != null)
_monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult();
_tlsRateLimiter?.Dispose();
_listener?.Dispose();
foreach (var client in _clients.Values)
client.Dispose();

View File

@@ -73,6 +73,18 @@ public sealed class ServerInfo
[JsonPropertyName("nonce")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Nonce { get; set; }
[JsonPropertyName("tls_required")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool TlsVerify { get; set; }
[JsonPropertyName("tls_available")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool TlsAvailable { get; set; }
}
public sealed class ClientOptions

View File

@@ -0,0 +1,20 @@
using System.Collections.Concurrent;
namespace NATS.Server;
public sealed class ServerStats
{
public long InMsgs;
public long OutMsgs;
public long InBytes;
public long OutBytes;
public long TotalConnections;
public long SlowConsumers;
public long StaleConnections;
public long Stalls;
public long SlowConsumerClients;
public long SlowConsumerRoutes;
public long SlowConsumerLeafs;
public long SlowConsumerGateways;
public readonly ConcurrentDictionary<string, long> HttpReqStats = new();
}

View File

@@ -0,0 +1,71 @@
namespace NATS.Server.Tls;
public sealed class PeekableStream : Stream
{
private readonly Stream _inner;
private byte[]? _peekedBytes;
private int _peekedOffset;
private int _peekedCount;
public PeekableStream(Stream inner) => _inner = inner;
public async Task<byte[]> PeekAsync(int count, CancellationToken ct = default)
{
var buf = new byte[count];
int read = await _inner.ReadAsync(buf.AsMemory(0, count), ct);
if (read < count) Array.Resize(ref buf, read);
_peekedBytes = buf;
_peekedOffset = 0;
_peekedCount = read;
return buf;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
{
if (_peekedBytes != null && _peekedOffset < _peekedCount)
{
int available = _peekedCount - _peekedOffset;
int toCopy = Math.Min(available, buffer.Length);
_peekedBytes.AsMemory(_peekedOffset, toCopy).CopyTo(buffer);
_peekedOffset += toCopy;
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
return toCopy;
}
return await _inner.ReadAsync(buffer, ct);
}
public override int Read(byte[] buffer, int offset, int count)
{
if (_peekedBytes != null && _peekedOffset < _peekedCount)
{
int available = _peekedCount - _peekedOffset;
int toCopy = Math.Min(available, count);
Array.Copy(_peekedBytes, _peekedOffset, buffer, offset, toCopy);
_peekedOffset += toCopy;
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
return toCopy;
}
return _inner.Read(buffer, offset, count);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
=> ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
// Write passthrough
public override void Write(byte[] buffer, int offset, int count) => _inner.Write(buffer, offset, count);
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) => _inner.WriteAsync(buffer, offset, count, ct);
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default) => _inner.WriteAsync(buffer, ct);
public override void Flush() => _inner.Flush();
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
// Required Stream overrides
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => false;
public override bool CanWrite => _inner.CanWrite;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); }
}

View File

@@ -0,0 +1,9 @@
using System.Security.Cryptography.X509Certificates;
namespace NATS.Server.Tls;
public sealed record TlsConnectionState(
string? TlsVersion,
string? CipherSuite,
X509Certificate2? PeerCert
);

View File

@@ -0,0 +1,202 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using NATS.Server.Protocol;
namespace NATS.Server.Tls;
public static class TlsConnectionWrapper
{
private const byte TlsRecordMarker = 0x16;
public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
Socket socket,
Stream networkStream,
NatsOptions options,
SslServerAuthenticationOptions? sslOptions,
ServerInfo serverInfo,
ILogger logger,
CancellationToken ct)
{
// Mode 1: No TLS
if (sslOptions == null || !options.HasTls)
return (networkStream, false);
// Clone to avoid mutating shared instance
serverInfo = new ServerInfo
{
ServerId = serverInfo.ServerId,
ServerName = serverInfo.ServerName,
Version = serverInfo.Version,
Proto = serverInfo.Proto,
Host = serverInfo.Host,
Port = serverInfo.Port,
Headers = serverInfo.Headers,
MaxPayload = serverInfo.MaxPayload,
ClientId = serverInfo.ClientId,
ClientIp = serverInfo.ClientIp,
};
// Mode 3: TLS First
if (options.TlsHandshakeFirst)
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
// Mode 2 & 4: Send INFO first, then decide
serverInfo.TlsRequired = !options.AllowNonTls;
serverInfo.TlsAvailable = options.AllowNonTls;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(networkStream, serverInfo, ct);
// Peek first byte to detect TLS
var peekable = new PeekableStream(networkStream);
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
if (peeked.Length == 0)
{
// Client disconnected or timed out
return (peekable, true);
}
if (peeked[0] == TlsRecordMarker)
{
// Client is starting TLS
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
try
{
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
logger.LogWarning("Certificate pinning check failed");
throw new InvalidOperationException("Certificate pinning check failed");
}
}
}
catch
{
sslStream.Dispose();
throw;
}
return (sslStream, true);
}
// Mode 4: Mixed — client chose plaintext
if (options.AllowNonTls)
{
logger.LogDebug("Client connected without TLS (mixed mode)");
return (peekable, true);
}
// TLS required but client sent plaintext
logger.LogWarning("TLS required but client sent plaintext data");
throw new InvalidOperationException("TLS required");
}
private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
Socket socket,
Stream networkStream,
NatsOptions options,
SslServerAuthenticationOptions sslOptions,
ServerInfo serverInfo,
ILogger logger,
CancellationToken ct)
{
// Wait for data with fallback timeout
var peekable = new PeekableStream(networkStream);
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);
if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
{
// Client started TLS immediately — handshake first, then send INFO
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
try
{
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
throw new InvalidOperationException("Certificate pinning check failed");
}
}
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
}
catch
{
sslStream.Dispose();
throw;
}
return (sslStream, true);
}
// Fallback: timeout expired or non-TLS data — send INFO and negotiate normally
logger.LogDebug("TLS-first fallback: sending INFO");
serverInfo.TlsRequired = !options.AllowNonTls;
serverInfo.TlsAvailable = options.AllowNonTls;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(peekable, serverInfo, ct);
if (peeked.Length == 0)
{
// Timeout — INFO was sent, return stream for normal flow
return (peekable, true);
}
// Non-TLS data received during fallback window
if (options.AllowNonTls)
{
return (peekable, true);
}
// TLS required but got plaintext
throw new InvalidOperationException("TLS required but client sent plaintext");
}
private static async Task<byte[]> PeekWithTimeoutAsync(
PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(timeout);
try
{
return await stream.PeekAsync(count, cts.Token);
}
catch (OperationCanceledException) when (!ct.IsCancellationRequested)
{
// Timeout — not a cancellation of the outer token
return [];
}
}
private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
{
var infoJson = JsonSerializer.Serialize(serverInfo);
var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
await stream.WriteAsync(infoLine, ct);
await stream.FlushAsync(ct);
}
}

View File

@@ -0,0 +1,65 @@
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
namespace NATS.Server.Tls;
public static class TlsHelper
{
public static X509Certificate2 LoadCertificate(string certPath, string? keyPath)
{
if (keyPath != null)
return X509Certificate2.CreateFromPemFile(certPath, keyPath);
return X509CertificateLoader.LoadCertificateFromFile(certPath);
}
public static X509Certificate2Collection LoadCaCertificates(string caPath)
{
var collection = new X509Certificate2Collection();
collection.ImportFromPemFile(caPath);
return collection;
}
public static SslServerAuthenticationOptions BuildServerAuthOptions(NatsOptions opts)
{
var cert = LoadCertificate(opts.TlsCert!, opts.TlsKey);
var authOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
EnabledSslProtocols = opts.TlsMinVersion,
ClientCertificateRequired = opts.TlsVerify,
};
if (opts.TlsVerify && opts.TlsCaCert != null)
{
var caCerts = LoadCaCertificates(opts.TlsCaCert);
authOpts.RemoteCertificateValidationCallback = (_, cert, chain, errors) =>
{
if (cert == null) return false;
using var chain2 = new X509Chain();
chain2.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
foreach (var ca in caCerts)
chain2.ChainPolicy.CustomTrustStore.Add(ca);
chain2.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
var cert2 = cert as X509Certificate2 ?? X509CertificateLoader.LoadCertificate(cert.GetRawCertData());
return chain2.Build(cert2);
};
}
return authOpts;
}
public static string GetCertificateHash(X509Certificate2 cert)
{
var spki = cert.PublicKey.ExportSubjectPublicKeyInfo();
var hash = SHA256.HashData(spki);
return Convert.ToHexStringLower(hash);
}
public static bool MatchesPinnedCert(X509Certificate2 cert, HashSet<string> pinned)
{
var hash = GetCertificateHash(cert);
return pinned.Contains(hash);
}
}

View File

@@ -0,0 +1,25 @@
namespace NATS.Server.Tls;
public sealed class TlsRateLimiter : IDisposable
{
private readonly SemaphoreSlim _semaphore;
private readonly Timer _refillTimer;
private readonly int _tokensPerSecond;
public TlsRateLimiter(long tokensPerSecond)
{
_tokensPerSecond = (int)Math.Max(1, tokensPerSecond);
_semaphore = new SemaphoreSlim(_tokensPerSecond, _tokensPerSecond);
_refillTimer = new Timer(Refill, null, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
}
private void Refill(object? state)
{
int toRelease = _tokensPerSecond - _semaphore.CurrentCount;
if (toRelease > 0) _semaphore.Release(toRelease);
}
public Task WaitAsync(CancellationToken ct) => _semaphore.WaitAsync(ct);
public void Dispose() { _refillTimer.Dispose(); _semaphore.Dispose(); }
}

View File

@@ -41,7 +41,7 @@ public class ClientTests : IAsyncDisposable
};
var authService = AuthService.Build(new NatsOptions());
_natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance);
_natsClient = new NatsClient(1, new NetworkStream(_serverSocket, ownsSocket: false), _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance, new ServerStats());
}
public async ValueTask DisposeAsync()
@@ -56,7 +56,7 @@ public class ClientTests : IAsyncDisposable
{
var runTask = _natsClient.RunAsync(_cts.Token);
// Read from client socket should get INFO
// Read from client socket -- should get INFO
var buf = new byte[4096];
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None);
var response = Encoding.ASCII.GetString(buf, 0, n);
@@ -80,7 +80,7 @@ public class ClientTests : IAsyncDisposable
// Send CONNECT then PING
await _clientSocket.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
// Read response should get PONG
// Read response -- should get PONG
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None);
var response = Encoding.ASCII.GetString(buf, 0, n);
@@ -128,7 +128,7 @@ public class ClientTests : IAsyncDisposable
response.ShouldBe("-ERR 'maximum connections exceeded'\r\n");
// Connection should be closed next read returns 0
// Connection should be closed -- next read returns 0
n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
n.ShouldBe(0);
}

View File

@@ -0,0 +1,51 @@
using System.Text.Json;
using NATS.Server.Monitoring;
namespace NATS.Server.Tests;
public class MonitorModelTests
{
[Fact]
public void Varz_serializes_with_go_field_names()
{
var varz = new Varz
{
Id = "TESTID", Name = "test-server", Version = "0.1.0",
Host = "0.0.0.0", Port = 4222, InMsgs = 100, OutMsgs = 200,
};
var json = JsonSerializer.Serialize(varz);
json.ShouldContain("\"server_id\":");
json.ShouldContain("\"server_name\":");
json.ShouldContain("\"in_msgs\":");
json.ShouldContain("\"out_msgs\":");
json.ShouldNotContain("\"InMsgs\"");
}
[Fact]
public void Connz_serializes_with_go_field_names()
{
var connz = new Connz
{
Id = "TESTID", Now = DateTime.UtcNow, NumConns = 1, Total = 1, Limit = 1024,
Conns = [new ConnInfo { Cid = 1, Ip = "127.0.0.1", Port = 5555,
InMsgs = 10, Uptime = "1s", Idle = "0s",
Start = DateTime.UtcNow, LastActivity = DateTime.UtcNow }],
};
var json = JsonSerializer.Serialize(connz);
json.ShouldContain("\"server_id\":");
json.ShouldContain("\"num_connections\":");
json.ShouldContain("\"in_msgs\":");
json.ShouldContain("\"pending_bytes\":");
}
[Fact]
public void Varz_includes_nested_config_stubs()
{
var varz = new Varz { Id = "X", Name = "X", Version = "X", Host = "X" };
var json = JsonSerializer.Serialize(varz);
json.ShouldContain("\"cluster\":");
json.ShouldContain("\"gateway\":");
json.ShouldContain("\"leaf\":");
json.ShouldContain("\"jetstream\":");
}
}

View File

@@ -0,0 +1,274 @@
using System.Net;
using System.Net.Http.Json;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server.Monitoring;
namespace NATS.Server.Tests;
public class MonitorTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _natsPort;
private readonly int _monitorPort;
private readonly CancellationTokenSource _cts = new();
private readonly HttpClient _http = new();
public MonitorTests()
{
_natsPort = GetFreePort();
_monitorPort = GetFreePort();
_server = new NatsServer(
new NatsOptions { Port = _natsPort, MonitorPort = _monitorPort },
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
// Wait for monitoring HTTP server to be ready
for (int i = 0; i < 50; i++)
{
try
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
if (response.IsSuccessStatusCode) break;
}
catch (HttpRequestException) { }
await Task.Delay(50);
}
}
public async Task DisposeAsync()
{
_http.Dispose();
await _cts.CancelAsync();
_server.Dispose();
}
[Fact]
public async Task Healthz_returns_ok()
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
response.StatusCode.ShouldBe(HttpStatusCode.OK);
}
[Fact]
public async Task Varz_returns_server_identity()
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz");
response.StatusCode.ShouldBe(HttpStatusCode.OK);
var varz = await response.Content.ReadFromJsonAsync<Varz>();
varz.ShouldNotBeNull();
varz.Id.ShouldNotBeNullOrEmpty();
varz.Name.ShouldNotBeNullOrEmpty();
varz.Version.ShouldBe("0.1.0");
varz.Host.ShouldBe("0.0.0.0");
varz.Port.ShouldBe(_natsPort);
varz.MaxPayload.ShouldBe(1024 * 1024);
varz.Uptime.ShouldNotBeNullOrEmpty();
varz.Now.ShouldBeGreaterThan(DateTime.MinValue);
varz.Mem.ShouldBeGreaterThan(0);
varz.Cores.ShouldBeGreaterThan(0);
}
[Fact]
public async Task Varz_tracks_connections_and_messages()
{
// Connect a client and send a message
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
var buf = new byte[4096];
_ = await sock.ReceiveAsync(buf, SocketFlags.None); // Read INFO
var cmd = "CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray();
await sock.SendAsync(cmd, SocketFlags.None);
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz");
var varz = await response.Content.ReadFromJsonAsync<Varz>();
varz.ShouldNotBeNull();
varz.Connections.ShouldBeGreaterThanOrEqualTo(1);
varz.TotalConnections.ShouldBeGreaterThanOrEqualTo(1UL);
varz.InMsgs.ShouldBeGreaterThanOrEqualTo(1L);
varz.InBytes.ShouldBeGreaterThanOrEqualTo(5L);
}
[Fact]
public async Task Connz_returns_connections()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
using var stream = new NetworkStream(sock);
var buf = new byte[4096];
_ = await stream.ReadAsync(buf);
await stream.WriteAsync("CONNECT {\"name\":\"test-client\",\"lang\":\"csharp\",\"version\":\"1.0\"}\r\n"u8.ToArray());
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz");
response.StatusCode.ShouldBe(HttpStatusCode.OK);
var connz = await response.Content.ReadFromJsonAsync<Connz>();
connz.ShouldNotBeNull();
connz.NumConns.ShouldBeGreaterThanOrEqualTo(1);
connz.Conns.Length.ShouldBeGreaterThanOrEqualTo(1);
var conn = connz.Conns.First(c => c.Name == "test-client");
conn.Ip.ShouldNotBeNullOrEmpty();
conn.Port.ShouldBeGreaterThan(0);
conn.Lang.ShouldBe("csharp");
conn.Version.ShouldBe("1.0");
conn.Uptime.ShouldNotBeNullOrEmpty();
}
[Fact]
public async Task Connz_pagination()
{
var sockets = new List<Socket>();
try
{
for (int i = 0; i < 3; i++)
{
var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await s.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
var ns = new NetworkStream(s);
var buf = new byte[4096];
_ = await ns.ReadAsync(buf);
await ns.WriteAsync("CONNECT {}\r\n"u8.ToArray());
sockets.Add(s);
}
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?limit=2&offset=0");
var connz = await response.Content.ReadFromJsonAsync<Connz>();
connz!.Conns.Length.ShouldBe(2);
connz.Total.ShouldBeGreaterThanOrEqualTo(3);
connz.Limit.ShouldBe(2);
connz.Offset.ShouldBe(0);
}
finally
{
foreach (var s in sockets) s.Dispose();
}
}
[Fact]
public async Task Connz_with_subscriptions()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
using var stream = new NetworkStream(sock);
var buf = new byte[4096];
_ = await stream.ReadAsync(buf);
await stream.WriteAsync("CONNECT {}\r\nSUB foo 1\r\nSUB bar 2\r\n"u8.ToArray());
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?subs=true");
var connz = await response.Content.ReadFromJsonAsync<Connz>();
var conn = connz!.Conns.First(c => c.NumSubs >= 2);
conn.Subs.ShouldNotBeNull();
conn.Subs.ShouldContain("foo");
conn.Subs.ShouldContain("bar");
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}
public class MonitorTlsTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _natsPort;
private readonly int _monitorPort;
private readonly CancellationTokenSource _cts = new();
private readonly HttpClient _http = new();
private readonly string _certPath;
private readonly string _keyPath;
public MonitorTlsTests()
{
_natsPort = GetFreePort();
_monitorPort = GetFreePort();
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
_server = new NatsServer(
new NatsOptions
{
Port = _natsPort,
MonitorPort = _monitorPort,
TlsCert = _certPath,
TlsKey = _keyPath,
},
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
// Wait for monitoring HTTP server to be ready
for (int i = 0; i < 50; i++)
{
try
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
if (response.IsSuccessStatusCode) break;
}
catch (HttpRequestException) { }
await Task.Delay(50);
}
}
public async Task DisposeAsync()
{
_http.Dispose();
await _cts.CancelAsync();
_server.Dispose();
File.Delete(_certPath);
File.Delete(_keyPath);
}
[Fact]
public async Task Connz_shows_tls_info_for_tls_client()
{
// Connect and upgrade to TLS
using var tcp = new TcpClient();
await tcp.ConnectAsync(IPAddress.Loopback, _natsPort);
using var netStream = tcp.GetStream();
var buf = new byte[4096];
_ = await netStream.ReadAsync(buf); // Read INFO
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
await ssl.AuthenticateAsClientAsync("localhost");
await ssl.WriteAsync("CONNECT {}\r\n"u8.ToArray());
await ssl.FlushAsync();
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz");
var connz = await response.Content.ReadFromJsonAsync<Connz>();
connz!.Conns.Length.ShouldBeGreaterThanOrEqualTo(1);
var conn = connz.Conns[0];
conn.TlsVersion.ShouldNotBeNullOrEmpty();
conn.TlsCipherSuite.ShouldNotBeNullOrEmpty();
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,84 @@
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
namespace NATS.Server.Tests;
public class ServerStatsTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
public ServerStatsTests()
{
_port = GetFreePort();
_server = new NatsServer(new NatsOptions { Port = _port }, NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public Task DisposeAsync()
{
_cts.Cancel();
_server.Dispose();
return Task.CompletedTask;
}
[Fact]
public void Server_has_start_time()
{
_server.StartTime.ShouldNotBe(default);
_server.StartTime.ShouldBeLessThanOrEqualTo(DateTime.UtcNow);
}
[Fact]
public async Task Server_tracks_total_connections()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
await Task.Delay(100);
_server.Stats.TotalConnections.ShouldBeGreaterThanOrEqualTo(1);
}
[Fact]
public async Task Server_stats_track_messages()
{
using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await pub.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
var buf = new byte[4096];
await pub.ReceiveAsync(buf, SocketFlags.None); // INFO
await pub.SendAsync("CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray());
await Task.Delay(200);
_server.Stats.InMsgs.ShouldBeGreaterThanOrEqualTo(1);
_server.Stats.InBytes.ShouldBeGreaterThanOrEqualTo(5);
}
[Fact]
public async Task Client_has_metadata()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
await Task.Delay(100);
var client = _server.GetClients().First();
client.RemoteIp.ShouldNotBeNullOrEmpty();
client.RemotePort.ShouldBeGreaterThan(0);
client.StartTime.ShouldNotBe(default);
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,254 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Protocol;
using NATS.Server.Tls;
namespace NATS.Server.Tests;
public class TlsConnectionWrapperTests
{
[Fact]
public async Task NoTls_returns_plain_stream()
{
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions(); // No TLS configured
var serverInfo = CreateServerInfo();
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBe(serverStream); // Same stream, no wrapping
infoSent.ShouldBeFalse();
}
[Fact]
public async Task TlsRequired_upgrades_to_ssl()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: read INFO then start TLS
var clientTask = Task.Run(async () =>
{
// Read INFO line
var buf = new byte[4096];
var read = await clientNetStream.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
// Upgrade to TLS
var sslClient = new SslStream(clientNetStream, true,
(_, _, _, _) => true); // Trust all for testing
await sslClient.AuthenticateAsClientAsync("localhost");
return sslClient;
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBeOfType<SslStream>();
infoSent.ShouldBeTrue();
var clientSsl = await clientTask;
// Verify encrypted communication works
await stream.WriteAsync("PING\r\n"u8.ToArray());
await stream.FlushAsync();
var readBuf = new byte[64];
var bytesRead = await clientSsl.ReadAsync(readBuf);
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
msg.ShouldBe("PING\r\n");
stream.Dispose();
clientSsl.Dispose();
}
[Fact]
public async Task MixedMode_allows_plaintext_when_AllowNonTls()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions
{
TlsCert = "dummy",
TlsKey = "dummy",
AllowNonTls = true,
TlsTimeout = TimeSpan.FromSeconds(2),
};
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: read INFO then send plaintext (not TLS)
var clientTask = Task.Run(async () =>
{
var buf = new byte[4096];
var read = await clientNetStream.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
// Send plaintext CONNECT (not a TLS handshake)
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
await clientNetStream.WriteAsync(connectLine);
await clientNetStream.FlushAsync();
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
await clientTask;
// In mixed mode with plaintext client, we get a PeekableStream, not SslStream
stream.ShouldBeOfType<PeekableStream>();
infoSent.ShouldBeTrue();
stream.Dispose();
}
[Fact]
public async Task TlsRequired_rejects_plaintext()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions
{
TlsCert = "dummy",
TlsKey = "dummy",
AllowNonTls = false,
TlsTimeout = TimeSpan.FromSeconds(2),
};
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: read INFO then send plaintext
var clientTask = Task.Run(async () =>
{
var buf = new byte[4096];
var read = await clientNetStream.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
// Send plaintext data (first byte is 'C', not 0x16 TLS marker)
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
await clientNetStream.WriteAsync(connectLine);
await clientNetStream.FlushAsync();
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
await Should.ThrowAsync<InvalidOperationException>(async () =>
{
await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
});
await clientTask;
serverNetStream.Dispose();
}
[Fact]
public async Task TlsFirst_handshakes_before_sending_info()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: immediately start TLS (no INFO first)
var clientTask = Task.Run(async () =>
{
var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true);
await sslClient.AuthenticateAsClientAsync("localhost");
// After TLS, read INFO over encrypted stream
var buf = new byte[4096];
var read = await sslClient.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
return sslClient;
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBeOfType<SslStream>();
infoSent.ShouldBeTrue();
var clientSsl = await clientTask;
// Verify encrypted communication works
await stream.WriteAsync("PING\r\n"u8.ToArray());
await stream.FlushAsync();
var readBuf = new byte[64];
var bytesRead = await clientSsl.ReadAsync(readBuf);
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
msg.ShouldBe("PING\r\n");
stream.Dispose();
clientSsl.Dispose();
}
private static ServerInfo CreateServerInfo() => new()
{
ServerId = "TEST",
ServerName = "test",
Version = NatsProtocol.Version,
Host = "127.0.0.1",
Port = 4222,
};
private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
{
using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
var port = ((IPEndPoint)listener.LocalEndPoint!).Port;
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
var server = await listener.AcceptAsync();
return (server, client);
}
}

View File

@@ -0,0 +1,110 @@
using System.Net;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using NATS.Server;
using NATS.Server.Tls;
namespace NATS.Server.Tests;
public class TlsHelperTests
{
[Fact]
public void LoadCertificate_loads_pem_cert_and_key()
{
var (certPath, keyPath) = GenerateTestCertFiles();
try
{
var cert = TlsHelper.LoadCertificate(certPath, keyPath);
cert.ShouldNotBeNull();
cert.HasPrivateKey.ShouldBeTrue();
}
finally { File.Delete(certPath); File.Delete(keyPath); }
}
[Fact]
public void BuildServerAuthOptions_creates_valid_options()
{
var (certPath, keyPath) = GenerateTestCertFiles();
try
{
var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };
var authOpts = TlsHelper.BuildServerAuthOptions(opts);
authOpts.ShouldNotBeNull();
authOpts.ServerCertificate.ShouldNotBeNull();
}
finally { File.Delete(certPath); File.Delete(keyPath); }
}
[Fact]
public void MatchesPinnedCert_matches_correct_hash()
{
var (cert, _) = GenerateTestCert();
var hash = TlsHelper.GetCertificateHash(cert);
var pinned = new HashSet<string> { hash };
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue();
}
[Fact]
public void MatchesPinnedCert_rejects_wrong_hash()
{
var (cert, _) = GenerateTestCert();
var pinned = new HashSet<string> { "0000000000000000000000000000000000000000000000000000000000000000" };
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse();
}
[Fact]
public async Task PeekableStream_peeks_and_replays()
{
var data = "Hello, World!"u8.ToArray();
using var ms = new MemoryStream(data);
using var peekable = new PeekableStream(ms);
var peeked = await peekable.PeekAsync(1);
peeked.Length.ShouldBe(1);
peeked[0].ShouldBe((byte)'H');
var buf = new byte[data.Length];
int total = 0;
while (total < data.Length)
{
var read = await peekable.ReadAsync(buf.AsMemory(total));
if (read == 0) break;
total += read;
}
total.ShouldBe(data.Length);
buf.ShouldBe(data);
}
[Fact]
public async Task TlsRateLimiter_allows_within_limit()
{
using var limiter = new TlsRateLimiter(10);
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
for (int i = 0; i < 5; i++)
await limiter.WaitAsync(cts.Token);
}
// Public helper methods used by other test classes
public static (string certPath, string keyPath) GenerateTestCertFiles()
{
var (cert, key) = GenerateTestCert();
var certPath = Path.GetTempFileName();
var keyPath = Path.GetTempFileName();
File.WriteAllText(certPath, cert.ExportCertificatePem());
File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem());
return (certPath, keyPath);
}
public static (X509Certificate2 cert, RSA key) GenerateTestCert()
{
var key = RSA.Create(2048);
var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));
var sanBuilder = new SubjectAlternativeNameBuilder();
sanBuilder.AddIpAddress(IPAddress.Loopback);
sanBuilder.AddDnsName("localhost");
req.CertificateExtensions.Add(sanBuilder.Build());
var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
return (cert, key);
}
}

View File

@@ -0,0 +1,225 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
namespace NATS.Server.Tests;
public class TlsServerTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
private readonly string _certPath;
private readonly string _keyPath;
public TlsServerTests()
{
_port = GetFreePort();
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
_server = new NatsServer(
new NatsOptions
{
Port = _port,
TlsCert = _certPath,
TlsKey = _keyPath,
},
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _cts.CancelAsync();
_server.Dispose();
File.Delete(_certPath);
File.Delete(_keyPath);
}
[Fact]
public async Task Tls_client_connects_and_receives_info()
{
using var tcp = new TcpClient();
await tcp.ConnectAsync(IPAddress.Loopback, _port);
using var netStream = tcp.GetStream();
// Read INFO (sent before TLS upgrade in Mode 2)
var buf = new byte[4096];
var read = await netStream.ReadAsync(buf);
var info = Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
info.ShouldContain("\"tls_required\":true");
// Upgrade to TLS
using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
await sslStream.AuthenticateAsClientAsync("localhost");
// Send CONNECT + PING over TLS
await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
await sslStream.FlushAsync();
// Read PONG
var pongBuf = new byte[256];
read = await sslStream.ReadAsync(pongBuf);
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
pong.ShouldContain("PONG");
}
[Fact]
public async Task Tls_pubsub_works_over_encrypted_connection()
{
using var tcp1 = new TcpClient();
await tcp1.ConnectAsync(IPAddress.Loopback, _port);
using var ssl1 = await UpgradeToTlsAsync(tcp1);
using var tcp2 = new TcpClient();
await tcp2.ConnectAsync(IPAddress.Loopback, _port);
using var ssl2 = await UpgradeToTlsAsync(tcp2);
// Sub on client 1
await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
await ssl1.FlushAsync();
// Wait for PONG to confirm subscription is registered
var pongBuf = new byte[256];
var pongRead = await ssl1.ReadAsync(pongBuf);
var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead);
pongStr.ShouldContain("PONG");
// Pub on client 2
await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
await ssl2.FlushAsync();
// Client 1 should receive MSG (may arrive across multiple TLS records)
var msg = await ReadUntilAsync(ssl1, "hello");
msg.ShouldContain("MSG test 1 5");
msg.ShouldContain("hello");
}
private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000)
{
using var cts = new CancellationTokenSource(timeoutMs);
var sb = new StringBuilder();
var buf = new byte[4096];
while (!sb.ToString().Contains(expected))
{
var n = await stream.ReadAsync(buf, cts.Token);
if (n == 0) break;
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
}
return sb.ToString();
}
private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
{
var netStream = tcp.GetStream();
var buf = new byte[4096];
_ = await netStream.ReadAsync(buf); // Read INFO (discard)
var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
await ssl.AuthenticateAsClientAsync("localhost");
return ssl;
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}
public class TlsMixedModeTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
private readonly string _certPath;
private readonly string _keyPath;
public TlsMixedModeTests()
{
_port = GetFreePort();
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
_server = new NatsServer(
new NatsOptions
{
Port = _port,
TlsCert = _certPath,
TlsKey = _keyPath,
AllowNonTls = true,
},
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _cts.CancelAsync();
_server.Dispose();
File.Delete(_certPath);
File.Delete(_keyPath);
}
[Fact]
public async Task Mixed_mode_accepts_plain_client()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
using var stream = new NetworkStream(sock);
var buf = new byte[4096];
var read = await stream.ReadAsync(buf);
var info = Encoding.ASCII.GetString(buf, 0, read);
info.ShouldContain("\"tls_available\":true");
await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
await stream.FlushAsync();
var pongBuf = new byte[64];
read = await stream.ReadAsync(pongBuf);
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
pong.ShouldContain("PONG");
}
[Fact]
public async Task Mixed_mode_accepts_tls_client()
{
using var tcp = new TcpClient();
await tcp.ConnectAsync(IPAddress.Loopback, _port);
using var netStream = tcp.GetStream();
var buf = new byte[4096];
_ = await netStream.ReadAsync(buf); // Read INFO
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
await ssl.AuthenticateAsClientAsync("localhost");
await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
await ssl.FlushAsync();
var pongBuf = new byte[64];
var read = await ssl.ReadAsync(pongBuf);
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
pong.ShouldContain("PONG");
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}