feat: add monitoring HTTP endpoints and TLS support

Monitoring HTTP:
- /varz, /connz, /healthz via Kestrel Minimal API
- Pagination, sorting, subscription details on /connz
- ServerStats atomic counters, CPU/memory sampling
- CLI flags: -m, --http_port, --http_base_path, --https_port

TLS Support:
- 4-mode negotiation: no TLS, required, TLS-first, mixed
- Certificate loading, pinning (SHA-256), client cert verification
- PeekableStream for non-destructive TLS detection
- Token-bucket rate limiter for TLS handshakes
- CLI flags: --tls, --tlscert, --tlskey, --tlscacert, --tlsverify

29 new tests (78 → 107 total), all passing.

# Conflicts:
#	src/NATS.Server.Host/Program.cs
#	src/NATS.Server/NATS.Server.csproj
#	src/NATS.Server/NatsClient.cs
#	src/NATS.Server/NatsOptions.cs
#	src/NATS.Server/NatsServer.cs
#	src/NATS.Server/Protocol/NatsProtocol.cs
#	tests/NATS.Server.Tests/ClientTests.cs
This commit is contained in:
Joseph Doherty
2026-02-22 23:13:22 -05:00
24 changed files with 2596 additions and 43 deletions

View File

@@ -32,6 +32,20 @@ for (int i = 0; i < args.Length; i++)
case "--https_port" when i + 1 < args.Length:
options.MonitorHttpsPort = int.Parse(args[++i]);
break;
case "--tls":
break;
case "--tlscert" when i + 1 < args.Length:
options.TlsCert = args[++i];
break;
case "--tlskey" when i + 1 < args.Length:
options.TlsKey = args[++i];
break;
case "--tlscacert" when i + 1 < args.Length:
options.TlsCaCert = args[++i];
break;
case "--tlsverify":
options.TlsVerify = true;
break;
}
}

View File

@@ -0,0 +1,207 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Monitoring;
/// <summary>
/// Connection information response. Corresponds to Go server/monitor.go Connz struct.
/// </summary>
public sealed class Connz
{
[JsonPropertyName("server_id")]
public string Id { get; set; } = "";
[JsonPropertyName("now")]
public DateTime Now { get; set; }
[JsonPropertyName("num_connections")]
public int NumConns { get; set; }
[JsonPropertyName("total")]
public int Total { get; set; }
[JsonPropertyName("offset")]
public int Offset { get; set; }
[JsonPropertyName("limit")]
public int Limit { get; set; }
[JsonPropertyName("connections")]
public ConnInfo[] Conns { get; set; } = [];
}
/// <summary>
/// Detailed information on a per-connection basis.
/// Corresponds to Go server/monitor.go ConnInfo struct.
/// </summary>
public sealed class ConnInfo
{
[JsonPropertyName("cid")]
public ulong Cid { get; set; }
[JsonPropertyName("kind")]
public string Kind { get; set; } = "";
[JsonPropertyName("type")]
public string Type { get; set; } = "";
[JsonPropertyName("ip")]
public string Ip { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("start")]
public DateTime Start { get; set; }
[JsonPropertyName("last_activity")]
public DateTime LastActivity { get; set; }
[JsonPropertyName("stop")]
public DateTime? Stop { get; set; }
[JsonPropertyName("reason")]
public string Reason { get; set; } = "";
[JsonPropertyName("rtt")]
public string Rtt { get; set; } = "";
[JsonPropertyName("uptime")]
public string Uptime { get; set; } = "";
[JsonPropertyName("idle")]
public string Idle { get; set; } = "";
[JsonPropertyName("pending_bytes")]
public int Pending { get; set; }
[JsonPropertyName("in_msgs")]
public long InMsgs { get; set; }
[JsonPropertyName("out_msgs")]
public long OutMsgs { get; set; }
[JsonPropertyName("in_bytes")]
public long InBytes { get; set; }
[JsonPropertyName("out_bytes")]
public long OutBytes { get; set; }
[JsonPropertyName("subscriptions")]
public uint NumSubs { get; set; }
[JsonPropertyName("subscriptions_list")]
public string[] Subs { get; set; } = [];
[JsonPropertyName("subscriptions_list_detail")]
public SubDetail[] SubsDetail { get; set; } = [];
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("lang")]
public string Lang { get; set; } = "";
[JsonPropertyName("version")]
public string Version { get; set; } = "";
[JsonPropertyName("authorized_user")]
public string AuthorizedUser { get; set; } = "";
[JsonPropertyName("account")]
public string Account { get; set; } = "";
[JsonPropertyName("tls_version")]
public string TlsVersion { get; set; } = "";
[JsonPropertyName("tls_cipher_suite")]
public string TlsCipherSuite { get; set; } = "";
[JsonPropertyName("tls_first")]
public bool TlsFirst { get; set; }
[JsonPropertyName("mqtt_client")]
public string MqttClient { get; set; } = "";
}
/// <summary>
/// Subscription detail information.
/// Corresponds to Go server/monitor.go SubDetail struct.
/// </summary>
public sealed class SubDetail
{
[JsonPropertyName("account")]
public string Account { get; set; } = "";
[JsonPropertyName("subject")]
public string Subject { get; set; } = "";
[JsonPropertyName("qgroup")]
public string Queue { get; set; } = "";
[JsonPropertyName("sid")]
public string Sid { get; set; } = "";
[JsonPropertyName("msgs")]
public long Msgs { get; set; }
[JsonPropertyName("max")]
public long Max { get; set; }
[JsonPropertyName("cid")]
public ulong Cid { get; set; }
}
/// <summary>
/// Sort options for connection listing.
/// Corresponds to Go server/monitor_sort_opts.go SortOpt type.
/// </summary>
public enum SortOpt
{
ByCid,
ByStart,
BySubs,
ByPending,
ByMsgsTo,
ByMsgsFrom,
ByBytesTo,
ByBytesFrom,
ByLast,
ByIdle,
ByUptime,
}
/// <summary>
/// Connection state filter.
/// Corresponds to Go server/monitor.go ConnState type.
/// </summary>
public enum ConnState
{
Open,
Closed,
All,
}
/// <summary>
/// Options passed to Connz() for filtering and sorting.
/// Corresponds to Go server/monitor.go ConnzOptions struct.
/// </summary>
public sealed class ConnzOptions
{
public SortOpt Sort { get; set; } = SortOpt.ByCid;
public bool Subscriptions { get; set; }
public bool SubscriptionsDetail { get; set; }
public ConnState State { get; set; } = ConnState.Open;
public string User { get; set; } = "";
public string Account { get; set; } = "";
public string FilterSubject { get; set; } = "";
public int Offset { get; set; }
public int Limit { get; set; } = 1024;
}

View File

@@ -0,0 +1,148 @@
using Microsoft.AspNetCore.Http;
namespace NATS.Server.Monitoring;
/// <summary>
/// Handles /connz endpoint requests, returning detailed connection information.
/// Corresponds to Go server/monitor.go handleConnz function.
/// </summary>
public sealed class ConnzHandler(NatsServer server)
{
public Connz HandleConnz(HttpContext ctx)
{
var opts = ParseQueryParams(ctx);
var now = DateTime.UtcNow;
var clients = server.GetClients().ToArray();
var connInfos = clients.Select(c => BuildConnInfo(c, now, opts)).ToList();
// Sort
connInfos = opts.Sort switch
{
SortOpt.ByCid => connInfos.OrderBy(c => c.Cid).ToList(),
SortOpt.ByStart => connInfos.OrderBy(c => c.Start).ToList(),
SortOpt.BySubs => connInfos.OrderByDescending(c => c.NumSubs).ToList(),
SortOpt.ByPending => connInfos.OrderByDescending(c => c.Pending).ToList(),
SortOpt.ByMsgsTo => connInfos.OrderByDescending(c => c.OutMsgs).ToList(),
SortOpt.ByMsgsFrom => connInfos.OrderByDescending(c => c.InMsgs).ToList(),
SortOpt.ByBytesTo => connInfos.OrderByDescending(c => c.OutBytes).ToList(),
SortOpt.ByBytesFrom => connInfos.OrderByDescending(c => c.InBytes).ToList(),
SortOpt.ByLast => connInfos.OrderByDescending(c => c.LastActivity).ToList(),
SortOpt.ByIdle => connInfos.OrderByDescending(c => now - c.LastActivity).ToList(),
SortOpt.ByUptime => connInfos.OrderByDescending(c => now - c.Start).ToList(),
_ => connInfos.OrderBy(c => c.Cid).ToList(),
};
var total = connInfos.Count;
var paged = connInfos.Skip(opts.Offset).Take(opts.Limit).ToArray();
return new Connz
{
Id = server.ServerId,
Now = now,
NumConns = paged.Length,
Total = total,
Offset = opts.Offset,
Limit = opts.Limit,
Conns = paged,
};
}
private static ConnInfo BuildConnInfo(NatsClient client, DateTime now, ConnzOptions opts)
{
var info = new ConnInfo
{
Cid = client.Id,
Kind = "Client",
Type = "Client",
Ip = client.RemoteIp ?? "",
Port = client.RemotePort,
Start = client.StartTime,
LastActivity = client.LastActivity,
Uptime = FormatDuration(now - client.StartTime),
Idle = FormatDuration(now - client.LastActivity),
InMsgs = Interlocked.Read(ref client.InMsgs),
OutMsgs = Interlocked.Read(ref client.OutMsgs),
InBytes = Interlocked.Read(ref client.InBytes),
OutBytes = Interlocked.Read(ref client.OutBytes),
NumSubs = (uint)client.Subscriptions.Count,
Name = client.ClientOpts?.Name ?? "",
Lang = client.ClientOpts?.Lang ?? "",
Version = client.ClientOpts?.Version ?? "",
TlsVersion = client.TlsState?.TlsVersion ?? "",
TlsCipherSuite = client.TlsState?.CipherSuite ?? "",
};
if (opts.Subscriptions)
{
info.Subs = client.Subscriptions.Values.Select(s => s.Subject).ToArray();
}
if (opts.SubscriptionsDetail)
{
info.SubsDetail = client.Subscriptions.Values.Select(s => new SubDetail
{
Subject = s.Subject,
Queue = s.Queue ?? "",
Sid = s.Sid,
Msgs = Interlocked.Read(ref s.MessageCount),
Max = s.MaxMessages,
Cid = client.Id,
}).ToArray();
}
return info;
}
private static ConnzOptions ParseQueryParams(HttpContext ctx)
{
var q = ctx.Request.Query;
var opts = new ConnzOptions();
if (q.TryGetValue("sort", out var sort))
{
opts.Sort = sort.ToString().ToLowerInvariant() switch
{
"cid" => SortOpt.ByCid,
"start" => SortOpt.ByStart,
"subs" => SortOpt.BySubs,
"pending" => SortOpt.ByPending,
"msgs_to" => SortOpt.ByMsgsTo,
"msgs_from" => SortOpt.ByMsgsFrom,
"bytes_to" => SortOpt.ByBytesTo,
"bytes_from" => SortOpt.ByBytesFrom,
"last" => SortOpt.ByLast,
"idle" => SortOpt.ByIdle,
"uptime" => SortOpt.ByUptime,
_ => SortOpt.ByCid,
};
}
if (q.TryGetValue("subs", out var subs))
{
if (subs == "detail")
opts.SubscriptionsDetail = true;
else
opts.Subscriptions = true;
}
if (q.TryGetValue("offset", out var offset) && int.TryParse(offset, out var o))
opts.Offset = o;
if (q.TryGetValue("limit", out var limit) && int.TryParse(limit, out var l))
opts.Limit = l;
return opts;
}
private static string FormatDuration(TimeSpan ts)
{
if (ts.TotalDays >= 1)
return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalHours >= 1)
return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalMinutes >= 1)
return $"{(int)ts.TotalMinutes}m{ts.Seconds}s";
return $"{(int)ts.TotalSeconds}s";
}
}

View File

@@ -0,0 +1,117 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace NATS.Server.Monitoring;
/// <summary>
/// HTTP monitoring server providing /healthz, /varz, and other monitoring endpoints.
/// Corresponds to Go server/monitor.go HTTP server setup.
/// </summary>
public sealed class MonitorServer : IAsyncDisposable
{
private readonly WebApplication _app;
private readonly ILogger<MonitorServer> _logger;
private readonly VarzHandler _varzHandler;
private readonly ConnzHandler _connzHandler;
public MonitorServer(NatsServer server, NatsOptions options, ServerStats stats, ILoggerFactory loggerFactory)
{
_logger = loggerFactory.CreateLogger<MonitorServer>();
var builder = WebApplication.CreateSlimBuilder();
builder.WebHost.UseUrls($"http://{options.MonitorHost}:{options.MonitorPort}");
builder.Logging.ClearProviders();
_app = builder.Build();
var basePath = options.MonitorBasePath ?? "";
_varzHandler = new VarzHandler(server, options);
_connzHandler = new ConnzHandler(server);
_app.MapGet(basePath + "/", () =>
{
stats.HttpReqStats.AddOrUpdate("/", 1, (_, v) => v + 1);
return Results.Ok(new
{
endpoints = new[]
{
"/varz", "/connz", "/healthz", "/routez",
"/gatewayz", "/leafz", "/subz", "/accountz", "/jsz",
},
});
});
_app.MapGet(basePath + "/healthz", () =>
{
stats.HttpReqStats.AddOrUpdate("/healthz", 1, (_, v) => v + 1);
return Results.Ok("ok");
});
_app.MapGet(basePath + "/varz", async (HttpContext ctx) =>
{
stats.HttpReqStats.AddOrUpdate("/varz", 1, (_, v) => v + 1);
return Results.Ok(await _varzHandler.HandleVarzAsync(ctx.RequestAborted));
});
_app.MapGet(basePath + "/connz", (HttpContext ctx) =>
{
stats.HttpReqStats.AddOrUpdate("/connz", 1, (_, v) => v + 1);
return Results.Ok(_connzHandler.HandleConnz(ctx));
});
// Stubs for unimplemented endpoints
_app.MapGet(basePath + "/routez", () =>
{
stats.HttpReqStats.AddOrUpdate("/routez", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/gatewayz", () =>
{
stats.HttpReqStats.AddOrUpdate("/gatewayz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/leafz", () =>
{
stats.HttpReqStats.AddOrUpdate("/leafz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/subz", () =>
{
stats.HttpReqStats.AddOrUpdate("/subz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/subscriptionsz", () =>
{
stats.HttpReqStats.AddOrUpdate("/subscriptionsz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/accountz", () =>
{
stats.HttpReqStats.AddOrUpdate("/accountz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/accstatz", () =>
{
stats.HttpReqStats.AddOrUpdate("/accstatz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
_app.MapGet(basePath + "/jsz", () =>
{
stats.HttpReqStats.AddOrUpdate("/jsz", 1, (_, v) => v + 1);
return Results.Ok(new { });
});
}
public async Task StartAsync(CancellationToken ct)
{
await _app.StartAsync(ct);
_logger.LogInformation("Monitoring listening on {Urls}", string.Join(", ", _app.Urls));
}
public async ValueTask DisposeAsync()
{
await _app.StopAsync();
await _app.DisposeAsync();
_varzHandler.Dispose();
}
}

View File

@@ -0,0 +1,415 @@
using System.Text.Json.Serialization;
namespace NATS.Server.Monitoring;
/// <summary>
/// Server general information. Corresponds to Go server/monitor.go Varz struct.
/// </summary>
public sealed class Varz
{
// Identity
[JsonPropertyName("server_id")]
public string Id { get; set; } = "";
[JsonPropertyName("server_name")]
public string Name { get; set; } = "";
[JsonPropertyName("version")]
public string Version { get; set; } = "";
[JsonPropertyName("proto")]
public int Proto { get; set; }
[JsonPropertyName("git_commit")]
public string GitCommit { get; set; } = "";
[JsonPropertyName("go")]
public string GoVersion { get; set; } = "";
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
// Network
[JsonPropertyName("ip")]
public string Ip { get; set; } = "";
[JsonPropertyName("connect_urls")]
public string[] ConnectUrls { get; set; } = [];
[JsonPropertyName("ws_connect_urls")]
public string[] WsConnectUrls { get; set; } = [];
[JsonPropertyName("http_host")]
public string HttpHost { get; set; } = "";
[JsonPropertyName("http_port")]
public int HttpPort { get; set; }
[JsonPropertyName("http_base_path")]
public string HttpBasePath { get; set; } = "";
[JsonPropertyName("https_port")]
public int HttpsPort { get; set; }
// Security
[JsonPropertyName("auth_required")]
public bool AuthRequired { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("tls_ocsp_peer_verify")]
public bool TlsOcspPeerVerify { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
// Limits
[JsonPropertyName("max_connections")]
public int MaxConnections { get; set; }
[JsonPropertyName("max_subscriptions")]
public int MaxSubscriptions { get; set; }
[JsonPropertyName("max_payload")]
public int MaxPayload { get; set; }
[JsonPropertyName("max_pending")]
public long MaxPending { get; set; }
[JsonPropertyName("max_control_line")]
public int MaxControlLine { get; set; }
[JsonPropertyName("ping_max")]
public int MaxPingsOut { get; set; }
// Timing
[JsonPropertyName("ping_interval")]
public long PingInterval { get; set; }
[JsonPropertyName("write_deadline")]
public long WriteDeadline { get; set; }
[JsonPropertyName("start")]
public DateTime Start { get; set; }
[JsonPropertyName("now")]
public DateTime Now { get; set; }
[JsonPropertyName("uptime")]
public string Uptime { get; set; } = "";
// Runtime
[JsonPropertyName("mem")]
public long Mem { get; set; }
[JsonPropertyName("cpu")]
public double Cpu { get; set; }
[JsonPropertyName("cores")]
public int Cores { get; set; }
[JsonPropertyName("gomaxprocs")]
public int MaxProcs { get; set; }
// Connections
[JsonPropertyName("connections")]
public int Connections { get; set; }
[JsonPropertyName("total_connections")]
public ulong TotalConnections { get; set; }
[JsonPropertyName("routes")]
public int Routes { get; set; }
[JsonPropertyName("remotes")]
public int Remotes { get; set; }
[JsonPropertyName("leafnodes")]
public int Leafnodes { get; set; }
// Messages
[JsonPropertyName("in_msgs")]
public long InMsgs { get; set; }
[JsonPropertyName("out_msgs")]
public long OutMsgs { get; set; }
[JsonPropertyName("in_bytes")]
public long InBytes { get; set; }
[JsonPropertyName("out_bytes")]
public long OutBytes { get; set; }
// Health
[JsonPropertyName("slow_consumers")]
public long SlowConsumers { get; set; }
[JsonPropertyName("slow_consumer_stats")]
public SlowConsumersStats SlowConsumerStats { get; set; } = new();
[JsonPropertyName("subscriptions")]
public uint Subscriptions { get; set; }
// Config
[JsonPropertyName("config_load_time")]
public DateTime ConfigLoadTime { get; set; }
[JsonPropertyName("tags")]
public string[] Tags { get; set; } = [];
[JsonPropertyName("system_account")]
public string SystemAccount { get; set; } = "";
[JsonPropertyName("pinned_account_fails")]
public ulong PinnedAccountFail { get; set; }
[JsonPropertyName("tls_cert_not_after")]
public DateTime TlsCertNotAfter { get; set; }
// HTTP
[JsonPropertyName("http_req_stats")]
public Dictionary<string, ulong> HttpReqStats { get; set; } = new();
// Subsystems
[JsonPropertyName("cluster")]
public ClusterOptsVarz Cluster { get; set; } = new();
[JsonPropertyName("gateway")]
public GatewayOptsVarz Gateway { get; set; } = new();
[JsonPropertyName("leaf")]
public LeafNodeOptsVarz Leaf { get; set; } = new();
[JsonPropertyName("mqtt")]
public MqttOptsVarz Mqtt { get; set; } = new();
[JsonPropertyName("websocket")]
public WebsocketOptsVarz Websocket { get; set; } = new();
[JsonPropertyName("jetstream")]
public JetStreamVarz JetStream { get; set; } = new();
}
/// <summary>
/// Statistics about slow consumers by connection type.
/// Corresponds to Go server/monitor.go SlowConsumersStats struct.
/// </summary>
public sealed class SlowConsumersStats
{
[JsonPropertyName("clients")]
public ulong Clients { get; set; }
[JsonPropertyName("routes")]
public ulong Routes { get; set; }
[JsonPropertyName("gateways")]
public ulong Gateways { get; set; }
[JsonPropertyName("leafs")]
public ulong Leafs { get; set; }
}
/// <summary>
/// Cluster configuration monitoring information.
/// Corresponds to Go server/monitor.go ClusterOptsVarz struct.
/// </summary>
public sealed class ClusterOptsVarz
{
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("addr")]
public string Host { get; set; } = "";
[JsonPropertyName("cluster_port")]
public int Port { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("pool_size")]
public int PoolSize { get; set; }
[JsonPropertyName("urls")]
public string[] Urls { get; set; } = [];
}
/// <summary>
/// Gateway configuration monitoring information.
/// Corresponds to Go server/monitor.go GatewayOptsVarz struct.
/// </summary>
public sealed class GatewayOptsVarz
{
[JsonPropertyName("name")]
public string Name { get; set; } = "";
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("advertise")]
public string Advertise { get; set; } = "";
[JsonPropertyName("connect_retries")]
public int ConnectRetries { get; set; }
[JsonPropertyName("reject_unknown")]
public bool RejectUnknown { get; set; }
}
/// <summary>
/// Leaf node configuration monitoring information.
/// Corresponds to Go server/monitor.go LeafNodeOptsVarz struct.
/// </summary>
public sealed class LeafNodeOptsVarz
{
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("auth_timeout")]
public double AuthTimeout { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
[JsonPropertyName("tls_required")]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
public bool TlsVerify { get; set; }
[JsonPropertyName("tls_ocsp_peer_verify")]
public bool TlsOcspPeerVerify { get; set; }
}
/// <summary>
/// MQTT configuration monitoring information.
/// Corresponds to Go server/monitor.go MQTTOptsVarz struct.
/// </summary>
public sealed class MqttOptsVarz
{
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
}
/// <summary>
/// Websocket configuration monitoring information.
/// Corresponds to Go server/monitor.go WebsocketOptsVarz struct.
/// </summary>
public sealed class WebsocketOptsVarz
{
[JsonPropertyName("host")]
public string Host { get; set; } = "";
[JsonPropertyName("port")]
public int Port { get; set; }
[JsonPropertyName("tls_timeout")]
public double TlsTimeout { get; set; }
}
/// <summary>
/// JetStream runtime information.
/// Corresponds to Go server/monitor.go JetStreamVarz struct.
/// </summary>
public sealed class JetStreamVarz
{
[JsonPropertyName("config")]
public JetStreamConfig Config { get; set; } = new();
[JsonPropertyName("stats")]
public JetStreamStats Stats { get; set; } = new();
}
/// <summary>
/// JetStream configuration.
/// Corresponds to Go server/jetstream.go JetStreamConfig struct.
/// </summary>
public sealed class JetStreamConfig
{
[JsonPropertyName("max_memory")]
public long MaxMemory { get; set; }
[JsonPropertyName("max_storage")]
public long MaxStorage { get; set; }
[JsonPropertyName("store_dir")]
public string StoreDir { get; set; } = "";
}
/// <summary>
/// JetStream statistics.
/// Corresponds to Go server/jetstream.go JetStreamStats struct.
/// </summary>
public sealed class JetStreamStats
{
[JsonPropertyName("memory")]
public ulong Memory { get; set; }
[JsonPropertyName("storage")]
public ulong Storage { get; set; }
[JsonPropertyName("accounts")]
public int Accounts { get; set; }
[JsonPropertyName("ha_assets")]
public int HaAssets { get; set; }
[JsonPropertyName("api")]
public JetStreamApiStats Api { get; set; } = new();
}
/// <summary>
/// JetStream API statistics.
/// Corresponds to Go server/jetstream.go JetStreamAPIStats struct.
/// </summary>
public sealed class JetStreamApiStats
{
[JsonPropertyName("total")]
public ulong Total { get; set; }
[JsonPropertyName("errors")]
public ulong Errors { get; set; }
}

View File

@@ -0,0 +1,121 @@
using System.Diagnostics;
using System.Runtime.InteropServices;
using NATS.Server.Protocol;
namespace NATS.Server.Monitoring;
/// <summary>
/// Handles building the Varz response from server state and process metrics.
/// Corresponds to Go server/monitor.go handleVarz function.
/// </summary>
public sealed class VarzHandler : IDisposable
{
private readonly NatsServer _server;
private readonly NatsOptions _options;
private readonly SemaphoreSlim _varzMu = new(1, 1);
private DateTime _lastCpuSampleTime;
private TimeSpan _lastCpuUsage;
private double _cachedCpuPercent;
public VarzHandler(NatsServer server, NatsOptions options)
{
_server = server;
_options = options;
using var proc = Process.GetCurrentProcess();
_lastCpuSampleTime = DateTime.UtcNow;
_lastCpuUsage = proc.TotalProcessorTime;
}
public async Task<Varz> HandleVarzAsync(CancellationToken ct = default)
{
await _varzMu.WaitAsync(ct);
try
{
using var proc = Process.GetCurrentProcess();
var now = DateTime.UtcNow;
var uptime = now - _server.StartTime;
var stats = _server.Stats;
// CPU sampling with 1-second cache to avoid excessive sampling
if ((now - _lastCpuSampleTime).TotalSeconds >= 1.0)
{
var currentCpu = proc.TotalProcessorTime;
var elapsed = now - _lastCpuSampleTime;
_cachedCpuPercent = (currentCpu - _lastCpuUsage).TotalMilliseconds
/ elapsed.TotalMilliseconds / Environment.ProcessorCount * 100.0;
_lastCpuSampleTime = now;
_lastCpuUsage = currentCpu;
}
return new Varz
{
Id = _server.ServerId,
Name = _server.ServerName,
Version = NatsProtocol.Version,
Proto = NatsProtocol.ProtoVersion,
GoVersion = $"dotnet {RuntimeInformation.FrameworkDescription}",
Host = _options.Host,
Port = _options.Port,
HttpHost = _options.MonitorHost,
HttpPort = _options.MonitorPort,
HttpBasePath = _options.MonitorBasePath ?? "",
HttpsPort = _options.MonitorHttpsPort,
TlsRequired = _options.HasTls && !_options.AllowNonTls,
TlsVerify = _options.HasTls && _options.TlsVerify,
TlsTimeout = _options.HasTls ? _options.TlsTimeout.TotalSeconds : 0,
MaxConnections = _options.MaxConnections,
MaxPayload = _options.MaxPayload,
MaxControlLine = _options.MaxControlLine,
MaxPingsOut = _options.MaxPingsOut,
PingInterval = (long)_options.PingInterval.TotalNanoseconds,
Start = _server.StartTime,
Now = now,
Uptime = FormatUptime(uptime),
Mem = proc.WorkingSet64,
Cpu = Math.Round(_cachedCpuPercent, 2),
Cores = Environment.ProcessorCount,
MaxProcs = ThreadPool.ThreadCount,
Connections = _server.ClientCount,
TotalConnections = (ulong)Interlocked.Read(ref stats.TotalConnections),
InMsgs = Interlocked.Read(ref stats.InMsgs),
OutMsgs = Interlocked.Read(ref stats.OutMsgs),
InBytes = Interlocked.Read(ref stats.InBytes),
OutBytes = Interlocked.Read(ref stats.OutBytes),
SlowConsumers = Interlocked.Read(ref stats.SlowConsumers),
SlowConsumerStats = new SlowConsumersStats
{
Clients = (ulong)Interlocked.Read(ref stats.SlowConsumerClients),
Routes = (ulong)Interlocked.Read(ref stats.SlowConsumerRoutes),
Gateways = (ulong)Interlocked.Read(ref stats.SlowConsumerGateways),
Leafs = (ulong)Interlocked.Read(ref stats.SlowConsumerLeafs),
},
Subscriptions = _server.SubList.Count,
ConfigLoadTime = _server.StartTime,
HttpReqStats = stats.HttpReqStats.ToDictionary(kv => kv.Key, kv => (ulong)kv.Value),
};
}
finally
{
_varzMu.Release();
}
}
public void Dispose()
{
_varzMu.Dispose();
}
/// <summary>
/// Formats a TimeSpan as a human-readable uptime string matching Go server format.
/// </summary>
private static string FormatUptime(TimeSpan ts)
{
if (ts.TotalDays >= 1)
return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalHours >= 1)
return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s";
if (ts.TotalMinutes >= 1)
return $"{(int)ts.TotalMinutes}m{ts.Seconds}s";
return $"{(int)ts.TotalSeconds}s";
}
}

View File

@@ -1,6 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="NATS.NKeys" />
<PackageReference Include="BCrypt.Net-Next" />
</ItemGroup>

View File

@@ -1,5 +1,6 @@
using System.Buffers;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
@@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
namespace NATS.Server;
@@ -26,7 +28,7 @@ public interface ISubListAccess
public sealed class NatsClient : IDisposable
{
private readonly Socket _socket;
private readonly NetworkStream _stream;
private readonly Stream _stream;
private readonly NatsOptions _options;
private readonly ServerInfo _serverInfo;
private readonly AuthService _authService;
@@ -37,6 +39,7 @@ public sealed class NatsClient : IDisposable
private readonly Dictionary<string, Subscription> _subs = new();
private readonly ILogger _logger;
private ClientPermissions? _permissions;
private readonly ServerStats _serverStats;
public ulong Id { get; }
public ClientOptions? ClientOpts { get; private set; }
@@ -47,6 +50,12 @@ public sealed class NatsClient : IDisposable
private int _connectReceived;
public bool ConnectReceived => Volatile.Read(ref _connectReceived) != 0;
public DateTime StartTime { get; }
private long _lastActivityTicks;
public DateTime LastActivity => new(Interlocked.Read(ref _lastActivityTicks), DateTimeKind.Utc);
public string? RemoteIp { get; }
public int RemotePort { get; }
// Stats
public long InMsgs;
public long OutMsgs;
@@ -57,20 +66,31 @@ public sealed class NatsClient : IDisposable
private int _pingsOut;
private long _lastIn;
public TlsConnectionState? TlsState { get; set; }
public bool InfoAlreadySent { get; set; }
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo,
AuthService authService, byte[]? nonce, ILogger logger)
public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,
AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats)
{
Id = id;
_socket = socket;
_stream = new NetworkStream(socket, ownsSocket: false);
_stream = stream;
_options = options;
_serverInfo = serverInfo;
_authService = authService;
_nonce = nonce;
_logger = logger;
_serverStats = serverStats;
_parser = new NatsParser(options.MaxPayload);
StartTime = DateTime.UtcNow;
_lastActivityTicks = StartTime.Ticks;
if (socket.RemoteEndPoint is IPEndPoint ep)
{
RemoteIp = ep.Address.ToString();
RemotePort = ep.Port;
}
}
public async Task RunAsync(CancellationToken ct)
@@ -80,8 +100,9 @@ public sealed class NatsClient : IDisposable
var pipe = new Pipe();
try
{
// Send INFO
await SendInfoAsync(_clientCts.Token);
// Send INFO (skip if already sent during TLS negotiation)
if (!InfoAlreadySent)
await SendInfoAsync(_clientCts.Token);
// Start auth timeout if auth is required
Task? authTimeoutTask = null;
@@ -100,7 +121,7 @@ public sealed class NatsClient : IDisposable
}
catch (OperationCanceledException)
{
// Normal client connected or was cancelled
// Normal -- client connected or was cancelled
}
}, _clientCts.Token);
}
@@ -184,6 +205,8 @@ public sealed class NatsClient : IDisposable
private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
{
Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
// If auth is required and CONNECT hasn't been received yet,
// only allow CONNECT and PING commands
if (_authService.IsAuthRequired && !ConnectReceived)
@@ -266,7 +289,7 @@ public sealed class NatsClient : IDisposable
_logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity);
// Clear nonce after use defense-in-depth against memory dumps
// Clear nonce after use -- defense-in-depth against memory dumps
if (_nonce != null)
CryptographicOperations.ZeroMemory(_nonce);
}
@@ -330,6 +353,8 @@ public sealed class NatsClient : IDisposable
{
Interlocked.Increment(ref InMsgs);
Interlocked.Add(ref InBytes, cmd.Payload.Length);
Interlocked.Increment(ref _serverStats.InMsgs);
Interlocked.Add(ref _serverStats.InBytes, cmd.Payload.Length);
// Max payload validation (always, hard close)
if (cmd.Payload.Length > _options.MaxPayload)
@@ -380,6 +405,8 @@ public sealed class NatsClient : IDisposable
{
Interlocked.Increment(ref OutMsgs);
Interlocked.Add(ref OutBytes, payload.Length + headers.Length);
Interlocked.Increment(ref _serverStats.OutMsgs);
Interlocked.Add(ref _serverStats.OutBytes, payload.Length + headers.Length);
byte[] line;
if (headers.Length > 0)
@@ -470,7 +497,7 @@ public sealed class NatsClient : IDisposable
if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut)
{
_logger.LogDebug("Client {ClientId} stale connection closing", Id);
_logger.LogDebug("Client {ClientId} stale connection -- closing", Id);
await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection);
return;
}

View File

@@ -1,3 +1,4 @@
using System.Security.Authentication;
using NATS.Server.Auth;
namespace NATS.Server;
@@ -7,7 +8,7 @@ public sealed class NatsOptions
public string Host { get; set; } = "0.0.0.0";
public int Port { get; set; } = 4222;
public string? ServerName { get; set; }
public int MaxPayload { get; set; } = 1024 * 1024; // 1MB
public int MaxPayload { get; set; } = 1024 * 1024;
public int MaxControlLine { get; set; } = 4096;
public int MaxConnections { get; set; } = 65536;
public TimeSpan PingInterval { get; set; } = TimeSpan.FromMinutes(2);
@@ -27,4 +28,27 @@ public sealed class NatsOptions
// Auth timing
public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
// Monitoring (0 = disabled; standard port is 8222)
public int MonitorPort { get; set; }
public string MonitorHost { get; set; } = "0.0.0.0";
public string? MonitorBasePath { get; set; }
// 0 = disabled
public int MonitorHttpsPort { get; set; }
// TLS
public string? TlsCert { get; set; }
public string? TlsKey { get; set; }
public string? TlsCaCert { get; set; }
public bool TlsVerify { get; set; }
public bool TlsMap { get; set; }
public TimeSpan TlsTimeout { get; set; } = TimeSpan.FromSeconds(2);
public bool TlsHandshakeFirst { get; set; }
public TimeSpan TlsHandshakeFirstFallback { get; set; } = TimeSpan.FromMilliseconds(50);
public bool AllowNonTls { get; set; }
public long TlsRateLimit { get; set; }
public HashSet<string>? TlsPinnedCerts { get; set; }
public SslProtocols TlsMinVersion { get; set; } = SslProtocols.Tls12;
public bool HasTls => TlsCert != null && TlsKey != null;
}

View File

@@ -1,11 +1,15 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Microsoft.Extensions.Logging;
using NATS.Server.Auth;
using NATS.Server.Monitoring;
using NATS.Server.Protocol;
using NATS.Server.Subscriptions;
using NATS.Server.Tls;
namespace NATS.Server;
@@ -16,14 +20,25 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
private readonly ServerInfo _serverInfo;
private readonly ILogger<NatsServer> _logger;
private readonly ILoggerFactory _loggerFactory;
private readonly ServerStats _stats = new();
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly AuthService _authService;
private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
private readonly Account _globalAccount;
private readonly SslServerAuthenticationOptions? _sslOptions;
private readonly TlsRateLimiter? _tlsRateLimiter;
private Socket? _listener;
private MonitorServer? _monitorServer;
private ulong _nextClientId;
private long _startTimeTicks;
public SubList SubList => _globalAccount.SubList;
public ServerStats Stats => _stats;
public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc);
public string ServerId => _serverInfo.ServerId;
public string ServerName => _serverInfo.ServerName;
public int ClientCount => _clients.Count;
public IEnumerable<NatsClient> GetClients() => _clients.Values;
public Task WaitForReadyAsync() => _listeningStarted.Task;
@@ -45,6 +60,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
MaxPayload = options.MaxPayload,
AuthRequired = _authService.IsAuthRequired,
};
if (options.HasTls)
{
_sslOptions = TlsHelper.BuildServerAuthOptions(options);
_serverInfo.TlsRequired = !options.AllowNonTls;
_serverInfo.TlsAvailable = options.AllowNonTls;
_serverInfo.TlsVerify = options.TlsVerify;
if (options.TlsRateLimit > 0)
_tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit);
}
}
public async Task StartAsync(CancellationToken ct)
@@ -54,11 +80,18 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
_listener.Bind(new IPEndPoint(
_options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host),
_options.Port));
Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks);
_listener.Listen(128);
_listeningStarted.TrySetResult();
_logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port);
if (_options.MonitorPort > 0)
{
_monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory);
await _monitorServer.StartAsync(ct);
}
try
{
while (!ct.IsCancellationRequested)
@@ -91,37 +124,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
var clientId = Interlocked.Increment(ref _nextClientId);
Interlocked.Increment(ref _stats.TotalConnections);
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
// Build per-client ServerInfo with nonce if NKey auth is configured
byte[]? nonce = null;
var clientInfo = _serverInfo;
if (_authService.NonceRequired)
{
var rawNonce = _authService.GenerateNonce();
var nonceStr = _authService.EncodeNonce(rawNonce);
// The client signs the nonce string (ASCII), not the raw bytes
nonce = Encoding.ASCII.GetBytes(nonceStr);
clientInfo = new ServerInfo
{
ServerId = _serverInfo.ServerId,
ServerName = _serverInfo.ServerName,
Version = _serverInfo.Version,
Host = _serverInfo.Host,
Port = _serverInfo.Port,
MaxPayload = _serverInfo.MaxPayload,
AuthRequired = _serverInfo.AuthRequired,
Nonce = nonceStr,
};
}
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
client.Router = this;
_clients[clientId] = client;
_ = RunClientAsync(client, ct);
_ = AcceptClientAsync(socket, clientId, ct);
}
}
catch (OperationCanceledException)
@@ -130,6 +137,74 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
}
}
private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct)
{
try
{
// Rate limit TLS handshakes
if (_tlsRateLimiter != null)
await _tlsRateLimiter.WaitAsync(ct);
var networkStream = new NetworkStream(socket, ownsSocket: false);
// TLS negotiation (no-op if not configured)
var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync(
socket, networkStream, _options, _sslOptions, _serverInfo,
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
// Extract TLS state
TlsConnectionState? tlsState = null;
if (stream is SslStream ssl)
{
tlsState = new TlsConnectionState(
ssl.SslProtocol.ToString(),
ssl.NegotiatedCipherSuite.ToString(),
ssl.RemoteCertificate as X509Certificate2);
}
// Build per-client ServerInfo with nonce if NKey auth is configured
byte[]? nonce = null;
var clientInfo = _serverInfo;
if (_authService.NonceRequired)
{
var rawNonce = _authService.GenerateNonce();
var nonceStr = _authService.EncodeNonce(rawNonce);
// The client signs the nonce string (ASCII), not the raw bytes
nonce = Encoding.ASCII.GetBytes(nonceStr);
clientInfo = new ServerInfo
{
ServerId = _serverInfo.ServerId,
ServerName = _serverInfo.ServerName,
Version = _serverInfo.Version,
Host = _serverInfo.Host,
Port = _serverInfo.Port,
MaxPayload = _serverInfo.MaxPayload,
AuthRequired = _serverInfo.AuthRequired,
TlsRequired = _serverInfo.TlsRequired,
TlsAvailable = _serverInfo.TlsAvailable,
TlsVerify = _serverInfo.TlsVerify,
Nonce = nonceStr,
};
}
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
var client = new NatsClient(clientId, stream, socket, _options, clientInfo,
_authService, nonce, clientLogger, _stats);
client.Router = this;
client.TlsState = tlsState;
client.InfoAlreadySent = infoAlreadySent;
_clients[clientId] = client;
await RunClientAsync(client, ct);
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId);
try { socket.Shutdown(SocketShutdown.Both); } catch { }
socket.Dispose();
}
}
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
{
try
@@ -215,6 +290,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
public void Dispose()
{
if (_monitorServer != null)
_monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult();
_tlsRateLimiter?.Dispose();
_listener?.Dispose();
foreach (var client in _clients.Values)
client.Dispose();

View File

@@ -73,6 +73,18 @@ public sealed class ServerInfo
[JsonPropertyName("nonce")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Nonce { get; set; }
[JsonPropertyName("tls_required")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool TlsRequired { get; set; }
[JsonPropertyName("tls_verify")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool TlsVerify { get; set; }
[JsonPropertyName("tls_available")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool TlsAvailable { get; set; }
}
public sealed class ClientOptions

View File

@@ -0,0 +1,20 @@
using System.Collections.Concurrent;
namespace NATS.Server;
public sealed class ServerStats
{
public long InMsgs;
public long OutMsgs;
public long InBytes;
public long OutBytes;
public long TotalConnections;
public long SlowConsumers;
public long StaleConnections;
public long Stalls;
public long SlowConsumerClients;
public long SlowConsumerRoutes;
public long SlowConsumerLeafs;
public long SlowConsumerGateways;
public readonly ConcurrentDictionary<string, long> HttpReqStats = new();
}

View File

@@ -0,0 +1,71 @@
namespace NATS.Server.Tls;
public sealed class PeekableStream : Stream
{
private readonly Stream _inner;
private byte[]? _peekedBytes;
private int _peekedOffset;
private int _peekedCount;
public PeekableStream(Stream inner) => _inner = inner;
public async Task<byte[]> PeekAsync(int count, CancellationToken ct = default)
{
var buf = new byte[count];
int read = await _inner.ReadAsync(buf.AsMemory(0, count), ct);
if (read < count) Array.Resize(ref buf, read);
_peekedBytes = buf;
_peekedOffset = 0;
_peekedCount = read;
return buf;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
{
if (_peekedBytes != null && _peekedOffset < _peekedCount)
{
int available = _peekedCount - _peekedOffset;
int toCopy = Math.Min(available, buffer.Length);
_peekedBytes.AsMemory(_peekedOffset, toCopy).CopyTo(buffer);
_peekedOffset += toCopy;
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
return toCopy;
}
return await _inner.ReadAsync(buffer, ct);
}
public override int Read(byte[] buffer, int offset, int count)
{
if (_peekedBytes != null && _peekedOffset < _peekedCount)
{
int available = _peekedCount - _peekedOffset;
int toCopy = Math.Min(available, count);
Array.Copy(_peekedBytes, _peekedOffset, buffer, offset, toCopy);
_peekedOffset += toCopy;
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
return toCopy;
}
return _inner.Read(buffer, offset, count);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
=> ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
// Write passthrough
public override void Write(byte[] buffer, int offset, int count) => _inner.Write(buffer, offset, count);
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) => _inner.WriteAsync(buffer, offset, count, ct);
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default) => _inner.WriteAsync(buffer, ct);
public override void Flush() => _inner.Flush();
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
// Required Stream overrides
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => false;
public override bool CanWrite => _inner.CanWrite;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); }
}

View File

@@ -0,0 +1,9 @@
using System.Security.Cryptography.X509Certificates;
namespace NATS.Server.Tls;
public sealed record TlsConnectionState(
string? TlsVersion,
string? CipherSuite,
X509Certificate2? PeerCert
);

View File

@@ -0,0 +1,202 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using NATS.Server.Protocol;
namespace NATS.Server.Tls;
public static class TlsConnectionWrapper
{
private const byte TlsRecordMarker = 0x16;
public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
Socket socket,
Stream networkStream,
NatsOptions options,
SslServerAuthenticationOptions? sslOptions,
ServerInfo serverInfo,
ILogger logger,
CancellationToken ct)
{
// Mode 1: No TLS
if (sslOptions == null || !options.HasTls)
return (networkStream, false);
// Clone to avoid mutating shared instance
serverInfo = new ServerInfo
{
ServerId = serverInfo.ServerId,
ServerName = serverInfo.ServerName,
Version = serverInfo.Version,
Proto = serverInfo.Proto,
Host = serverInfo.Host,
Port = serverInfo.Port,
Headers = serverInfo.Headers,
MaxPayload = serverInfo.MaxPayload,
ClientId = serverInfo.ClientId,
ClientIp = serverInfo.ClientIp,
};
// Mode 3: TLS First
if (options.TlsHandshakeFirst)
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
// Mode 2 & 4: Send INFO first, then decide
serverInfo.TlsRequired = !options.AllowNonTls;
serverInfo.TlsAvailable = options.AllowNonTls;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(networkStream, serverInfo, ct);
// Peek first byte to detect TLS
var peekable = new PeekableStream(networkStream);
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
if (peeked.Length == 0)
{
// Client disconnected or timed out
return (peekable, true);
}
if (peeked[0] == TlsRecordMarker)
{
// Client is starting TLS
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
try
{
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
logger.LogWarning("Certificate pinning check failed");
throw new InvalidOperationException("Certificate pinning check failed");
}
}
}
catch
{
sslStream.Dispose();
throw;
}
return (sslStream, true);
}
// Mode 4: Mixed — client chose plaintext
if (options.AllowNonTls)
{
logger.LogDebug("Client connected without TLS (mixed mode)");
return (peekable, true);
}
// TLS required but client sent plaintext
logger.LogWarning("TLS required but client sent plaintext data");
throw new InvalidOperationException("TLS required");
}
private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
Socket socket,
Stream networkStream,
NatsOptions options,
SslServerAuthenticationOptions sslOptions,
ServerInfo serverInfo,
ILogger logger,
CancellationToken ct)
{
// Wait for data with fallback timeout
var peekable = new PeekableStream(networkStream);
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);
if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
{
// Client started TLS immediately — handshake first, then send INFO
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
try
{
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
handshakeCts.CancelAfter(options.TlsTimeout);
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
// Validate pinned certs
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
{
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
{
throw new InvalidOperationException("Certificate pinning check failed");
}
}
// Now send INFO over encrypted stream
serverInfo.TlsRequired = true;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(sslStream, serverInfo, ct);
}
catch
{
sslStream.Dispose();
throw;
}
return (sslStream, true);
}
// Fallback: timeout expired or non-TLS data — send INFO and negotiate normally
logger.LogDebug("TLS-first fallback: sending INFO");
serverInfo.TlsRequired = !options.AllowNonTls;
serverInfo.TlsAvailable = options.AllowNonTls;
serverInfo.TlsVerify = options.TlsVerify;
await SendInfoAsync(peekable, serverInfo, ct);
if (peeked.Length == 0)
{
// Timeout — INFO was sent, return stream for normal flow
return (peekable, true);
}
// Non-TLS data received during fallback window
if (options.AllowNonTls)
{
return (peekable, true);
}
// TLS required but got plaintext
throw new InvalidOperationException("TLS required but client sent plaintext");
}
private static async Task<byte[]> PeekWithTimeoutAsync(
PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(timeout);
try
{
return await stream.PeekAsync(count, cts.Token);
}
catch (OperationCanceledException) when (!ct.IsCancellationRequested)
{
// Timeout — not a cancellation of the outer token
return [];
}
}
private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
{
var infoJson = JsonSerializer.Serialize(serverInfo);
var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
await stream.WriteAsync(infoLine, ct);
await stream.FlushAsync(ct);
}
}

View File

@@ -0,0 +1,65 @@
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
namespace NATS.Server.Tls;
public static class TlsHelper
{
public static X509Certificate2 LoadCertificate(string certPath, string? keyPath)
{
if (keyPath != null)
return X509Certificate2.CreateFromPemFile(certPath, keyPath);
return X509CertificateLoader.LoadCertificateFromFile(certPath);
}
public static X509Certificate2Collection LoadCaCertificates(string caPath)
{
var collection = new X509Certificate2Collection();
collection.ImportFromPemFile(caPath);
return collection;
}
public static SslServerAuthenticationOptions BuildServerAuthOptions(NatsOptions opts)
{
var cert = LoadCertificate(opts.TlsCert!, opts.TlsKey);
var authOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
EnabledSslProtocols = opts.TlsMinVersion,
ClientCertificateRequired = opts.TlsVerify,
};
if (opts.TlsVerify && opts.TlsCaCert != null)
{
var caCerts = LoadCaCertificates(opts.TlsCaCert);
authOpts.RemoteCertificateValidationCallback = (_, cert, chain, errors) =>
{
if (cert == null) return false;
using var chain2 = new X509Chain();
chain2.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
foreach (var ca in caCerts)
chain2.ChainPolicy.CustomTrustStore.Add(ca);
chain2.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
var cert2 = cert as X509Certificate2 ?? X509CertificateLoader.LoadCertificate(cert.GetRawCertData());
return chain2.Build(cert2);
};
}
return authOpts;
}
public static string GetCertificateHash(X509Certificate2 cert)
{
var spki = cert.PublicKey.ExportSubjectPublicKeyInfo();
var hash = SHA256.HashData(spki);
return Convert.ToHexStringLower(hash);
}
public static bool MatchesPinnedCert(X509Certificate2 cert, HashSet<string> pinned)
{
var hash = GetCertificateHash(cert);
return pinned.Contains(hash);
}
}

View File

@@ -0,0 +1,25 @@
namespace NATS.Server.Tls;
public sealed class TlsRateLimiter : IDisposable
{
private readonly SemaphoreSlim _semaphore;
private readonly Timer _refillTimer;
private readonly int _tokensPerSecond;
public TlsRateLimiter(long tokensPerSecond)
{
_tokensPerSecond = (int)Math.Max(1, tokensPerSecond);
_semaphore = new SemaphoreSlim(_tokensPerSecond, _tokensPerSecond);
_refillTimer = new Timer(Refill, null, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
}
private void Refill(object? state)
{
int toRelease = _tokensPerSecond - _semaphore.CurrentCount;
if (toRelease > 0) _semaphore.Release(toRelease);
}
public Task WaitAsync(CancellationToken ct) => _semaphore.WaitAsync(ct);
public void Dispose() { _refillTimer.Dispose(); _semaphore.Dispose(); }
}

View File

@@ -41,7 +41,7 @@ public class ClientTests : IAsyncDisposable
};
var authService = AuthService.Build(new NatsOptions());
_natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance);
_natsClient = new NatsClient(1, new NetworkStream(_serverSocket, ownsSocket: false), _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance, new ServerStats());
}
public async ValueTask DisposeAsync()
@@ -56,7 +56,7 @@ public class ClientTests : IAsyncDisposable
{
var runTask = _natsClient.RunAsync(_cts.Token);
// Read from client socket should get INFO
// Read from client socket -- should get INFO
var buf = new byte[4096];
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None);
var response = Encoding.ASCII.GetString(buf, 0, n);
@@ -80,7 +80,7 @@ public class ClientTests : IAsyncDisposable
// Send CONNECT then PING
await _clientSocket.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
// Read response should get PONG
// Read response -- should get PONG
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None);
var response = Encoding.ASCII.GetString(buf, 0, n);
@@ -128,7 +128,7 @@ public class ClientTests : IAsyncDisposable
response.ShouldBe("-ERR 'maximum connections exceeded'\r\n");
// Connection should be closed next read returns 0
// Connection should be closed -- next read returns 0
n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
n.ShouldBe(0);
}

View File

@@ -0,0 +1,51 @@
using System.Text.Json;
using NATS.Server.Monitoring;
namespace NATS.Server.Tests;
public class MonitorModelTests
{
[Fact]
public void Varz_serializes_with_go_field_names()
{
var varz = new Varz
{
Id = "TESTID", Name = "test-server", Version = "0.1.0",
Host = "0.0.0.0", Port = 4222, InMsgs = 100, OutMsgs = 200,
};
var json = JsonSerializer.Serialize(varz);
json.ShouldContain("\"server_id\":");
json.ShouldContain("\"server_name\":");
json.ShouldContain("\"in_msgs\":");
json.ShouldContain("\"out_msgs\":");
json.ShouldNotContain("\"InMsgs\"");
}
[Fact]
public void Connz_serializes_with_go_field_names()
{
var connz = new Connz
{
Id = "TESTID", Now = DateTime.UtcNow, NumConns = 1, Total = 1, Limit = 1024,
Conns = [new ConnInfo { Cid = 1, Ip = "127.0.0.1", Port = 5555,
InMsgs = 10, Uptime = "1s", Idle = "0s",
Start = DateTime.UtcNow, LastActivity = DateTime.UtcNow }],
};
var json = JsonSerializer.Serialize(connz);
json.ShouldContain("\"server_id\":");
json.ShouldContain("\"num_connections\":");
json.ShouldContain("\"in_msgs\":");
json.ShouldContain("\"pending_bytes\":");
}
[Fact]
public void Varz_includes_nested_config_stubs()
{
var varz = new Varz { Id = "X", Name = "X", Version = "X", Host = "X" };
var json = JsonSerializer.Serialize(varz);
json.ShouldContain("\"cluster\":");
json.ShouldContain("\"gateway\":");
json.ShouldContain("\"leaf\":");
json.ShouldContain("\"jetstream\":");
}
}

View File

@@ -0,0 +1,274 @@
using System.Net;
using System.Net.Http.Json;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server.Monitoring;
namespace NATS.Server.Tests;
public class MonitorTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _natsPort;
private readonly int _monitorPort;
private readonly CancellationTokenSource _cts = new();
private readonly HttpClient _http = new();
public MonitorTests()
{
_natsPort = GetFreePort();
_monitorPort = GetFreePort();
_server = new NatsServer(
new NatsOptions { Port = _natsPort, MonitorPort = _monitorPort },
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
// Wait for monitoring HTTP server to be ready
for (int i = 0; i < 50; i++)
{
try
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
if (response.IsSuccessStatusCode) break;
}
catch (HttpRequestException) { }
await Task.Delay(50);
}
}
public async Task DisposeAsync()
{
_http.Dispose();
await _cts.CancelAsync();
_server.Dispose();
}
[Fact]
public async Task Healthz_returns_ok()
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
response.StatusCode.ShouldBe(HttpStatusCode.OK);
}
[Fact]
public async Task Varz_returns_server_identity()
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz");
response.StatusCode.ShouldBe(HttpStatusCode.OK);
var varz = await response.Content.ReadFromJsonAsync<Varz>();
varz.ShouldNotBeNull();
varz.Id.ShouldNotBeNullOrEmpty();
varz.Name.ShouldNotBeNullOrEmpty();
varz.Version.ShouldBe("0.1.0");
varz.Host.ShouldBe("0.0.0.0");
varz.Port.ShouldBe(_natsPort);
varz.MaxPayload.ShouldBe(1024 * 1024);
varz.Uptime.ShouldNotBeNullOrEmpty();
varz.Now.ShouldBeGreaterThan(DateTime.MinValue);
varz.Mem.ShouldBeGreaterThan(0);
varz.Cores.ShouldBeGreaterThan(0);
}
[Fact]
public async Task Varz_tracks_connections_and_messages()
{
// Connect a client and send a message
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
var buf = new byte[4096];
_ = await sock.ReceiveAsync(buf, SocketFlags.None); // Read INFO
var cmd = "CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray();
await sock.SendAsync(cmd, SocketFlags.None);
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz");
var varz = await response.Content.ReadFromJsonAsync<Varz>();
varz.ShouldNotBeNull();
varz.Connections.ShouldBeGreaterThanOrEqualTo(1);
varz.TotalConnections.ShouldBeGreaterThanOrEqualTo(1UL);
varz.InMsgs.ShouldBeGreaterThanOrEqualTo(1L);
varz.InBytes.ShouldBeGreaterThanOrEqualTo(5L);
}
[Fact]
public async Task Connz_returns_connections()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
using var stream = new NetworkStream(sock);
var buf = new byte[4096];
_ = await stream.ReadAsync(buf);
await stream.WriteAsync("CONNECT {\"name\":\"test-client\",\"lang\":\"csharp\",\"version\":\"1.0\"}\r\n"u8.ToArray());
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz");
response.StatusCode.ShouldBe(HttpStatusCode.OK);
var connz = await response.Content.ReadFromJsonAsync<Connz>();
connz.ShouldNotBeNull();
connz.NumConns.ShouldBeGreaterThanOrEqualTo(1);
connz.Conns.Length.ShouldBeGreaterThanOrEqualTo(1);
var conn = connz.Conns.First(c => c.Name == "test-client");
conn.Ip.ShouldNotBeNullOrEmpty();
conn.Port.ShouldBeGreaterThan(0);
conn.Lang.ShouldBe("csharp");
conn.Version.ShouldBe("1.0");
conn.Uptime.ShouldNotBeNullOrEmpty();
}
[Fact]
public async Task Connz_pagination()
{
var sockets = new List<Socket>();
try
{
for (int i = 0; i < 3; i++)
{
var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await s.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
var ns = new NetworkStream(s);
var buf = new byte[4096];
_ = await ns.ReadAsync(buf);
await ns.WriteAsync("CONNECT {}\r\n"u8.ToArray());
sockets.Add(s);
}
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?limit=2&offset=0");
var connz = await response.Content.ReadFromJsonAsync<Connz>();
connz!.Conns.Length.ShouldBe(2);
connz.Total.ShouldBeGreaterThanOrEqualTo(3);
connz.Limit.ShouldBe(2);
connz.Offset.ShouldBe(0);
}
finally
{
foreach (var s in sockets) s.Dispose();
}
}
[Fact]
public async Task Connz_with_subscriptions()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
using var stream = new NetworkStream(sock);
var buf = new byte[4096];
_ = await stream.ReadAsync(buf);
await stream.WriteAsync("CONNECT {}\r\nSUB foo 1\r\nSUB bar 2\r\n"u8.ToArray());
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?subs=true");
var connz = await response.Content.ReadFromJsonAsync<Connz>();
var conn = connz!.Conns.First(c => c.NumSubs >= 2);
conn.Subs.ShouldNotBeNull();
conn.Subs.ShouldContain("foo");
conn.Subs.ShouldContain("bar");
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}
public class MonitorTlsTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _natsPort;
private readonly int _monitorPort;
private readonly CancellationTokenSource _cts = new();
private readonly HttpClient _http = new();
private readonly string _certPath;
private readonly string _keyPath;
public MonitorTlsTests()
{
_natsPort = GetFreePort();
_monitorPort = GetFreePort();
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
_server = new NatsServer(
new NatsOptions
{
Port = _natsPort,
MonitorPort = _monitorPort,
TlsCert = _certPath,
TlsKey = _keyPath,
},
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
// Wait for monitoring HTTP server to be ready
for (int i = 0; i < 50; i++)
{
try
{
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
if (response.IsSuccessStatusCode) break;
}
catch (HttpRequestException) { }
await Task.Delay(50);
}
}
public async Task DisposeAsync()
{
_http.Dispose();
await _cts.CancelAsync();
_server.Dispose();
File.Delete(_certPath);
File.Delete(_keyPath);
}
[Fact]
public async Task Connz_shows_tls_info_for_tls_client()
{
// Connect and upgrade to TLS
using var tcp = new TcpClient();
await tcp.ConnectAsync(IPAddress.Loopback, _natsPort);
using var netStream = tcp.GetStream();
var buf = new byte[4096];
_ = await netStream.ReadAsync(buf); // Read INFO
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
await ssl.AuthenticateAsClientAsync("localhost");
await ssl.WriteAsync("CONNECT {}\r\n"u8.ToArray());
await ssl.FlushAsync();
await Task.Delay(200);
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz");
var connz = await response.Content.ReadFromJsonAsync<Connz>();
connz!.Conns.Length.ShouldBeGreaterThanOrEqualTo(1);
var conn = connz.Conns[0];
conn.TlsVersion.ShouldNotBeNullOrEmpty();
conn.TlsCipherSuite.ShouldNotBeNullOrEmpty();
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,84 @@
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
namespace NATS.Server.Tests;
public class ServerStatsTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
public ServerStatsTests()
{
_port = GetFreePort();
_server = new NatsServer(new NatsOptions { Port = _port }, NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public Task DisposeAsync()
{
_cts.Cancel();
_server.Dispose();
return Task.CompletedTask;
}
[Fact]
public void Server_has_start_time()
{
_server.StartTime.ShouldNotBe(default);
_server.StartTime.ShouldBeLessThanOrEqualTo(DateTime.UtcNow);
}
[Fact]
public async Task Server_tracks_total_connections()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
await Task.Delay(100);
_server.Stats.TotalConnections.ShouldBeGreaterThanOrEqualTo(1);
}
[Fact]
public async Task Server_stats_track_messages()
{
using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await pub.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
var buf = new byte[4096];
await pub.ReceiveAsync(buf, SocketFlags.None); // INFO
await pub.SendAsync("CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray());
await Task.Delay(200);
_server.Stats.InMsgs.ShouldBeGreaterThanOrEqualTo(1);
_server.Stats.InBytes.ShouldBeGreaterThanOrEqualTo(5);
}
[Fact]
public async Task Client_has_metadata()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
await Task.Delay(100);
var client = _server.GetClients().First();
client.RemoteIp.ShouldNotBeNullOrEmpty();
client.RemotePort.ShouldBeGreaterThan(0);
client.StartTime.ShouldNotBe(default);
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}

View File

@@ -0,0 +1,254 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Protocol;
using NATS.Server.Tls;
namespace NATS.Server.Tests;
public class TlsConnectionWrapperTests
{
[Fact]
public async Task NoTls_returns_plain_stream()
{
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions(); // No TLS configured
var serverInfo = CreateServerInfo();
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBe(serverStream); // Same stream, no wrapping
infoSent.ShouldBeFalse();
}
[Fact]
public async Task TlsRequired_upgrades_to_ssl()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: read INFO then start TLS
var clientTask = Task.Run(async () =>
{
// Read INFO line
var buf = new byte[4096];
var read = await clientNetStream.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
// Upgrade to TLS
var sslClient = new SslStream(clientNetStream, true,
(_, _, _, _) => true); // Trust all for testing
await sslClient.AuthenticateAsClientAsync("localhost");
return sslClient;
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBeOfType<SslStream>();
infoSent.ShouldBeTrue();
var clientSsl = await clientTask;
// Verify encrypted communication works
await stream.WriteAsync("PING\r\n"u8.ToArray());
await stream.FlushAsync();
var readBuf = new byte[64];
var bytesRead = await clientSsl.ReadAsync(readBuf);
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
msg.ShouldBe("PING\r\n");
stream.Dispose();
clientSsl.Dispose();
}
[Fact]
public async Task MixedMode_allows_plaintext_when_AllowNonTls()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions
{
TlsCert = "dummy",
TlsKey = "dummy",
AllowNonTls = true,
TlsTimeout = TimeSpan.FromSeconds(2),
};
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: read INFO then send plaintext (not TLS)
var clientTask = Task.Run(async () =>
{
var buf = new byte[4096];
var read = await clientNetStream.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
// Send plaintext CONNECT (not a TLS handshake)
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
await clientNetStream.WriteAsync(connectLine);
await clientNetStream.FlushAsync();
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
await clientTask;
// In mixed mode with plaintext client, we get a PeekableStream, not SslStream
stream.ShouldBeOfType<PeekableStream>();
infoSent.ShouldBeTrue();
stream.Dispose();
}
[Fact]
public async Task TlsRequired_rejects_plaintext()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions
{
TlsCert = "dummy",
TlsKey = "dummy",
AllowNonTls = false,
TlsTimeout = TimeSpan.FromSeconds(2),
};
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: read INFO then send plaintext
var clientTask = Task.Run(async () =>
{
var buf = new byte[4096];
var read = await clientNetStream.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
// Send plaintext data (first byte is 'C', not 0x16 TLS marker)
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
await clientNetStream.WriteAsync(connectLine);
await clientNetStream.FlushAsync();
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
await Should.ThrowAsync<InvalidOperationException>(async () =>
{
await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
});
await clientTask;
serverNetStream.Dispose();
}
[Fact]
public async Task TlsFirst_handshakes_before_sending_info()
{
var (cert, _) = TlsHelperTests.GenerateTestCert();
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
var sslOpts = new SslServerAuthenticationOptions
{
ServerCertificate = cert,
};
var serverInfo = CreateServerInfo();
// Client side: immediately start TLS (no INFO first)
var clientTask = Task.Run(async () =>
{
var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true);
await sslClient.AuthenticateAsClientAsync("localhost");
// After TLS, read INFO over encrypted stream
var buf = new byte[4096];
var read = await sslClient.ReadAsync(buf);
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
return sslClient;
});
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
stream.ShouldBeOfType<SslStream>();
infoSent.ShouldBeTrue();
var clientSsl = await clientTask;
// Verify encrypted communication works
await stream.WriteAsync("PING\r\n"u8.ToArray());
await stream.FlushAsync();
var readBuf = new byte[64];
var bytesRead = await clientSsl.ReadAsync(readBuf);
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
msg.ShouldBe("PING\r\n");
stream.Dispose();
clientSsl.Dispose();
}
private static ServerInfo CreateServerInfo() => new()
{
ServerId = "TEST",
ServerName = "test",
Version = NatsProtocol.Version,
Host = "127.0.0.1",
Port = 4222,
};
private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
{
using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
var port = ((IPEndPoint)listener.LocalEndPoint!).Port;
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
var server = await listener.AcceptAsync();
return (server, client);
}
}

View File

@@ -0,0 +1,110 @@
using System.Net;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using NATS.Server;
using NATS.Server.Tls;
namespace NATS.Server.Tests;
public class TlsHelperTests
{
[Fact]
public void LoadCertificate_loads_pem_cert_and_key()
{
var (certPath, keyPath) = GenerateTestCertFiles();
try
{
var cert = TlsHelper.LoadCertificate(certPath, keyPath);
cert.ShouldNotBeNull();
cert.HasPrivateKey.ShouldBeTrue();
}
finally { File.Delete(certPath); File.Delete(keyPath); }
}
[Fact]
public void BuildServerAuthOptions_creates_valid_options()
{
var (certPath, keyPath) = GenerateTestCertFiles();
try
{
var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };
var authOpts = TlsHelper.BuildServerAuthOptions(opts);
authOpts.ShouldNotBeNull();
authOpts.ServerCertificate.ShouldNotBeNull();
}
finally { File.Delete(certPath); File.Delete(keyPath); }
}
[Fact]
public void MatchesPinnedCert_matches_correct_hash()
{
var (cert, _) = GenerateTestCert();
var hash = TlsHelper.GetCertificateHash(cert);
var pinned = new HashSet<string> { hash };
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue();
}
[Fact]
public void MatchesPinnedCert_rejects_wrong_hash()
{
var (cert, _) = GenerateTestCert();
var pinned = new HashSet<string> { "0000000000000000000000000000000000000000000000000000000000000000" };
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse();
}
[Fact]
public async Task PeekableStream_peeks_and_replays()
{
var data = "Hello, World!"u8.ToArray();
using var ms = new MemoryStream(data);
using var peekable = new PeekableStream(ms);
var peeked = await peekable.PeekAsync(1);
peeked.Length.ShouldBe(1);
peeked[0].ShouldBe((byte)'H');
var buf = new byte[data.Length];
int total = 0;
while (total < data.Length)
{
var read = await peekable.ReadAsync(buf.AsMemory(total));
if (read == 0) break;
total += read;
}
total.ShouldBe(data.Length);
buf.ShouldBe(data);
}
[Fact]
public async Task TlsRateLimiter_allows_within_limit()
{
using var limiter = new TlsRateLimiter(10);
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
for (int i = 0; i < 5; i++)
await limiter.WaitAsync(cts.Token);
}
// Public helper methods used by other test classes
public static (string certPath, string keyPath) GenerateTestCertFiles()
{
var (cert, key) = GenerateTestCert();
var certPath = Path.GetTempFileName();
var keyPath = Path.GetTempFileName();
File.WriteAllText(certPath, cert.ExportCertificatePem());
File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem());
return (certPath, keyPath);
}
public static (X509Certificate2 cert, RSA key) GenerateTestCert()
{
var key = RSA.Create(2048);
var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));
var sanBuilder = new SubjectAlternativeNameBuilder();
sanBuilder.AddIpAddress(IPAddress.Loopback);
sanBuilder.AddDnsName("localhost");
req.CertificateExtensions.Add(sanBuilder.Build());
var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
return (cert, key);
}
}

View File

@@ -0,0 +1,225 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
namespace NATS.Server.Tests;
public class TlsServerTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
private readonly string _certPath;
private readonly string _keyPath;
public TlsServerTests()
{
_port = GetFreePort();
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
_server = new NatsServer(
new NatsOptions
{
Port = _port,
TlsCert = _certPath,
TlsKey = _keyPath,
},
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _cts.CancelAsync();
_server.Dispose();
File.Delete(_certPath);
File.Delete(_keyPath);
}
[Fact]
public async Task Tls_client_connects_and_receives_info()
{
using var tcp = new TcpClient();
await tcp.ConnectAsync(IPAddress.Loopback, _port);
using var netStream = tcp.GetStream();
// Read INFO (sent before TLS upgrade in Mode 2)
var buf = new byte[4096];
var read = await netStream.ReadAsync(buf);
var info = Encoding.ASCII.GetString(buf, 0, read);
info.ShouldStartWith("INFO ");
info.ShouldContain("\"tls_required\":true");
// Upgrade to TLS
using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
await sslStream.AuthenticateAsClientAsync("localhost");
// Send CONNECT + PING over TLS
await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
await sslStream.FlushAsync();
// Read PONG
var pongBuf = new byte[256];
read = await sslStream.ReadAsync(pongBuf);
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
pong.ShouldContain("PONG");
}
[Fact]
public async Task Tls_pubsub_works_over_encrypted_connection()
{
using var tcp1 = new TcpClient();
await tcp1.ConnectAsync(IPAddress.Loopback, _port);
using var ssl1 = await UpgradeToTlsAsync(tcp1);
using var tcp2 = new TcpClient();
await tcp2.ConnectAsync(IPAddress.Loopback, _port);
using var ssl2 = await UpgradeToTlsAsync(tcp2);
// Sub on client 1
await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
await ssl1.FlushAsync();
// Wait for PONG to confirm subscription is registered
var pongBuf = new byte[256];
var pongRead = await ssl1.ReadAsync(pongBuf);
var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead);
pongStr.ShouldContain("PONG");
// Pub on client 2
await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
await ssl2.FlushAsync();
// Client 1 should receive MSG (may arrive across multiple TLS records)
var msg = await ReadUntilAsync(ssl1, "hello");
msg.ShouldContain("MSG test 1 5");
msg.ShouldContain("hello");
}
private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000)
{
using var cts = new CancellationTokenSource(timeoutMs);
var sb = new StringBuilder();
var buf = new byte[4096];
while (!sb.ToString().Contains(expected))
{
var n = await stream.ReadAsync(buf, cts.Token);
if (n == 0) break;
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
}
return sb.ToString();
}
private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
{
var netStream = tcp.GetStream();
var buf = new byte[4096];
_ = await netStream.ReadAsync(buf); // Read INFO (discard)
var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
await ssl.AuthenticateAsClientAsync("localhost");
return ssl;
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}
public class TlsMixedModeTests : IAsyncLifetime
{
private readonly NatsServer _server;
private readonly int _port;
private readonly CancellationTokenSource _cts = new();
private readonly string _certPath;
private readonly string _keyPath;
public TlsMixedModeTests()
{
_port = GetFreePort();
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
_server = new NatsServer(
new NatsOptions
{
Port = _port,
TlsCert = _certPath,
TlsKey = _keyPath,
AllowNonTls = true,
},
NullLoggerFactory.Instance);
}
public async Task InitializeAsync()
{
_ = _server.StartAsync(_cts.Token);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _cts.CancelAsync();
_server.Dispose();
File.Delete(_certPath);
File.Delete(_keyPath);
}
[Fact]
public async Task Mixed_mode_accepts_plain_client()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
using var stream = new NetworkStream(sock);
var buf = new byte[4096];
var read = await stream.ReadAsync(buf);
var info = Encoding.ASCII.GetString(buf, 0, read);
info.ShouldContain("\"tls_available\":true");
await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
await stream.FlushAsync();
var pongBuf = new byte[64];
read = await stream.ReadAsync(pongBuf);
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
pong.ShouldContain("PONG");
}
[Fact]
public async Task Mixed_mode_accepts_tls_client()
{
using var tcp = new TcpClient();
await tcp.ConnectAsync(IPAddress.Loopback, _port);
using var netStream = tcp.GetStream();
var buf = new byte[4096];
_ = await netStream.ReadAsync(buf); // Read INFO
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
await ssl.AuthenticateAsClientAsync("localhost");
await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
await ssl.FlushAsync();
var pongBuf = new byte[64];
var read = await ssl.ReadAsync(pongBuf);
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
pong.ShouldContain("PONG");
}
private static int GetFreePort()
{
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
return ((IPEndPoint)sock.LocalEndPoint!).Port;
}
}