diff --git a/src/NATS.Server.Host/Program.cs b/src/NATS.Server.Host/Program.cs index 90aadea..bdb6827 100644 --- a/src/NATS.Server.Host/Program.cs +++ b/src/NATS.Server.Host/Program.cs @@ -32,6 +32,20 @@ for (int i = 0; i < args.Length; i++) case "--https_port" when i + 1 < args.Length: options.MonitorHttpsPort = int.Parse(args[++i]); break; + case "--tls": + break; + case "--tlscert" when i + 1 < args.Length: + options.TlsCert = args[++i]; + break; + case "--tlskey" when i + 1 < args.Length: + options.TlsKey = args[++i]; + break; + case "--tlscacert" when i + 1 < args.Length: + options.TlsCaCert = args[++i]; + break; + case "--tlsverify": + options.TlsVerify = true; + break; } } diff --git a/src/NATS.Server/Monitoring/Connz.cs b/src/NATS.Server/Monitoring/Connz.cs new file mode 100644 index 0000000..d2a6f49 --- /dev/null +++ b/src/NATS.Server/Monitoring/Connz.cs @@ -0,0 +1,207 @@ +using System.Text.Json.Serialization; + +namespace NATS.Server.Monitoring; + +/// +/// Connection information response. Corresponds to Go server/monitor.go Connz struct. +/// +public sealed class Connz +{ + [JsonPropertyName("server_id")] + public string Id { get; set; } = ""; + + [JsonPropertyName("now")] + public DateTime Now { get; set; } + + [JsonPropertyName("num_connections")] + public int NumConns { get; set; } + + [JsonPropertyName("total")] + public int Total { get; set; } + + [JsonPropertyName("offset")] + public int Offset { get; set; } + + [JsonPropertyName("limit")] + public int Limit { get; set; } + + [JsonPropertyName("connections")] + public ConnInfo[] Conns { get; set; } = []; +} + +/// +/// Detailed information on a per-connection basis. +/// Corresponds to Go server/monitor.go ConnInfo struct. +/// +public sealed class ConnInfo +{ + [JsonPropertyName("cid")] + public ulong Cid { get; set; } + + [JsonPropertyName("kind")] + public string Kind { get; set; } = ""; + + [JsonPropertyName("type")] + public string Type { get; set; } = ""; + + [JsonPropertyName("ip")] + public string Ip { get; set; } = ""; + + [JsonPropertyName("port")] + public int Port { get; set; } + + [JsonPropertyName("start")] + public DateTime Start { get; set; } + + [JsonPropertyName("last_activity")] + public DateTime LastActivity { get; set; } + + [JsonPropertyName("stop")] + public DateTime? Stop { get; set; } + + [JsonPropertyName("reason")] + public string Reason { get; set; } = ""; + + [JsonPropertyName("rtt")] + public string Rtt { get; set; } = ""; + + [JsonPropertyName("uptime")] + public string Uptime { get; set; } = ""; + + [JsonPropertyName("idle")] + public string Idle { get; set; } = ""; + + [JsonPropertyName("pending_bytes")] + public int Pending { get; set; } + + [JsonPropertyName("in_msgs")] + public long InMsgs { get; set; } + + [JsonPropertyName("out_msgs")] + public long OutMsgs { get; set; } + + [JsonPropertyName("in_bytes")] + public long InBytes { get; set; } + + [JsonPropertyName("out_bytes")] + public long OutBytes { get; set; } + + [JsonPropertyName("subscriptions")] + public uint NumSubs { get; set; } + + [JsonPropertyName("subscriptions_list")] + public string[] Subs { get; set; } = []; + + [JsonPropertyName("subscriptions_list_detail")] + public SubDetail[] SubsDetail { get; set; } = []; + + [JsonPropertyName("name")] + public string Name { get; set; } = ""; + + [JsonPropertyName("lang")] + public string Lang { get; set; } = ""; + + [JsonPropertyName("version")] + public string Version { get; set; } = ""; + + [JsonPropertyName("authorized_user")] + public string AuthorizedUser { get; set; } = ""; + + [JsonPropertyName("account")] + public string Account { get; set; } = ""; + + [JsonPropertyName("tls_version")] + public string TlsVersion { get; set; } = ""; + + [JsonPropertyName("tls_cipher_suite")] + public string TlsCipherSuite { get; set; } = ""; + + [JsonPropertyName("tls_first")] + public bool TlsFirst { get; set; } + + [JsonPropertyName("mqtt_client")] + public string MqttClient { get; set; } = ""; +} + +/// +/// Subscription detail information. +/// Corresponds to Go server/monitor.go SubDetail struct. +/// +public sealed class SubDetail +{ + [JsonPropertyName("account")] + public string Account { get; set; } = ""; + + [JsonPropertyName("subject")] + public string Subject { get; set; } = ""; + + [JsonPropertyName("qgroup")] + public string Queue { get; set; } = ""; + + [JsonPropertyName("sid")] + public string Sid { get; set; } = ""; + + [JsonPropertyName("msgs")] + public long Msgs { get; set; } + + [JsonPropertyName("max")] + public long Max { get; set; } + + [JsonPropertyName("cid")] + public ulong Cid { get; set; } +} + +/// +/// Sort options for connection listing. +/// Corresponds to Go server/monitor_sort_opts.go SortOpt type. +/// +public enum SortOpt +{ + ByCid, + ByStart, + BySubs, + ByPending, + ByMsgsTo, + ByMsgsFrom, + ByBytesTo, + ByBytesFrom, + ByLast, + ByIdle, + ByUptime, +} + +/// +/// Connection state filter. +/// Corresponds to Go server/monitor.go ConnState type. +/// +public enum ConnState +{ + Open, + Closed, + All, +} + +/// +/// Options passed to Connz() for filtering and sorting. +/// Corresponds to Go server/monitor.go ConnzOptions struct. +/// +public sealed class ConnzOptions +{ + public SortOpt Sort { get; set; } = SortOpt.ByCid; + + public bool Subscriptions { get; set; } + + public bool SubscriptionsDetail { get; set; } + + public ConnState State { get; set; } = ConnState.Open; + + public string User { get; set; } = ""; + + public string Account { get; set; } = ""; + + public string FilterSubject { get; set; } = ""; + + public int Offset { get; set; } + + public int Limit { get; set; } = 1024; +} diff --git a/src/NATS.Server/Monitoring/ConnzHandler.cs b/src/NATS.Server/Monitoring/ConnzHandler.cs new file mode 100644 index 0000000..e5cd7b0 --- /dev/null +++ b/src/NATS.Server/Monitoring/ConnzHandler.cs @@ -0,0 +1,148 @@ +using Microsoft.AspNetCore.Http; + +namespace NATS.Server.Monitoring; + +/// +/// Handles /connz endpoint requests, returning detailed connection information. +/// Corresponds to Go server/monitor.go handleConnz function. +/// +public sealed class ConnzHandler(NatsServer server) +{ + public Connz HandleConnz(HttpContext ctx) + { + var opts = ParseQueryParams(ctx); + var now = DateTime.UtcNow; + var clients = server.GetClients().ToArray(); + + var connInfos = clients.Select(c => BuildConnInfo(c, now, opts)).ToList(); + + // Sort + connInfos = opts.Sort switch + { + SortOpt.ByCid => connInfos.OrderBy(c => c.Cid).ToList(), + SortOpt.ByStart => connInfos.OrderBy(c => c.Start).ToList(), + SortOpt.BySubs => connInfos.OrderByDescending(c => c.NumSubs).ToList(), + SortOpt.ByPending => connInfos.OrderByDescending(c => c.Pending).ToList(), + SortOpt.ByMsgsTo => connInfos.OrderByDescending(c => c.OutMsgs).ToList(), + SortOpt.ByMsgsFrom => connInfos.OrderByDescending(c => c.InMsgs).ToList(), + SortOpt.ByBytesTo => connInfos.OrderByDescending(c => c.OutBytes).ToList(), + SortOpt.ByBytesFrom => connInfos.OrderByDescending(c => c.InBytes).ToList(), + SortOpt.ByLast => connInfos.OrderByDescending(c => c.LastActivity).ToList(), + SortOpt.ByIdle => connInfos.OrderByDescending(c => now - c.LastActivity).ToList(), + SortOpt.ByUptime => connInfos.OrderByDescending(c => now - c.Start).ToList(), + _ => connInfos.OrderBy(c => c.Cid).ToList(), + }; + + var total = connInfos.Count; + var paged = connInfos.Skip(opts.Offset).Take(opts.Limit).ToArray(); + + return new Connz + { + Id = server.ServerId, + Now = now, + NumConns = paged.Length, + Total = total, + Offset = opts.Offset, + Limit = opts.Limit, + Conns = paged, + }; + } + + private static ConnInfo BuildConnInfo(NatsClient client, DateTime now, ConnzOptions opts) + { + var info = new ConnInfo + { + Cid = client.Id, + Kind = "Client", + Type = "Client", + Ip = client.RemoteIp ?? "", + Port = client.RemotePort, + Start = client.StartTime, + LastActivity = client.LastActivity, + Uptime = FormatDuration(now - client.StartTime), + Idle = FormatDuration(now - client.LastActivity), + InMsgs = Interlocked.Read(ref client.InMsgs), + OutMsgs = Interlocked.Read(ref client.OutMsgs), + InBytes = Interlocked.Read(ref client.InBytes), + OutBytes = Interlocked.Read(ref client.OutBytes), + NumSubs = (uint)client.Subscriptions.Count, + Name = client.ClientOpts?.Name ?? "", + Lang = client.ClientOpts?.Lang ?? "", + Version = client.ClientOpts?.Version ?? "", + TlsVersion = client.TlsState?.TlsVersion ?? "", + TlsCipherSuite = client.TlsState?.CipherSuite ?? "", + }; + + if (opts.Subscriptions) + { + info.Subs = client.Subscriptions.Values.Select(s => s.Subject).ToArray(); + } + + if (opts.SubscriptionsDetail) + { + info.SubsDetail = client.Subscriptions.Values.Select(s => new SubDetail + { + Subject = s.Subject, + Queue = s.Queue ?? "", + Sid = s.Sid, + Msgs = Interlocked.Read(ref s.MessageCount), + Max = s.MaxMessages, + Cid = client.Id, + }).ToArray(); + } + + return info; + } + + private static ConnzOptions ParseQueryParams(HttpContext ctx) + { + var q = ctx.Request.Query; + var opts = new ConnzOptions(); + + if (q.TryGetValue("sort", out var sort)) + { + opts.Sort = sort.ToString().ToLowerInvariant() switch + { + "cid" => SortOpt.ByCid, + "start" => SortOpt.ByStart, + "subs" => SortOpt.BySubs, + "pending" => SortOpt.ByPending, + "msgs_to" => SortOpt.ByMsgsTo, + "msgs_from" => SortOpt.ByMsgsFrom, + "bytes_to" => SortOpt.ByBytesTo, + "bytes_from" => SortOpt.ByBytesFrom, + "last" => SortOpt.ByLast, + "idle" => SortOpt.ByIdle, + "uptime" => SortOpt.ByUptime, + _ => SortOpt.ByCid, + }; + } + + if (q.TryGetValue("subs", out var subs)) + { + if (subs == "detail") + opts.SubscriptionsDetail = true; + else + opts.Subscriptions = true; + } + + if (q.TryGetValue("offset", out var offset) && int.TryParse(offset, out var o)) + opts.Offset = o; + + if (q.TryGetValue("limit", out var limit) && int.TryParse(limit, out var l)) + opts.Limit = l; + + return opts; + } + + private static string FormatDuration(TimeSpan ts) + { + if (ts.TotalDays >= 1) + return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s"; + if (ts.TotalHours >= 1) + return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s"; + if (ts.TotalMinutes >= 1) + return $"{(int)ts.TotalMinutes}m{ts.Seconds}s"; + return $"{(int)ts.TotalSeconds}s"; + } +} diff --git a/src/NATS.Server/Monitoring/MonitorServer.cs b/src/NATS.Server/Monitoring/MonitorServer.cs new file mode 100644 index 0000000..af23506 --- /dev/null +++ b/src/NATS.Server/Monitoring/MonitorServer.cs @@ -0,0 +1,117 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; + +namespace NATS.Server.Monitoring; + +/// +/// HTTP monitoring server providing /healthz, /varz, and other monitoring endpoints. +/// Corresponds to Go server/monitor.go HTTP server setup. +/// +public sealed class MonitorServer : IAsyncDisposable +{ + private readonly WebApplication _app; + private readonly ILogger _logger; + private readonly VarzHandler _varzHandler; + private readonly ConnzHandler _connzHandler; + + public MonitorServer(NatsServer server, NatsOptions options, ServerStats stats, ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + + var builder = WebApplication.CreateSlimBuilder(); + builder.WebHost.UseUrls($"http://{options.MonitorHost}:{options.MonitorPort}"); + builder.Logging.ClearProviders(); + + _app = builder.Build(); + var basePath = options.MonitorBasePath ?? ""; + + _varzHandler = new VarzHandler(server, options); + _connzHandler = new ConnzHandler(server); + + _app.MapGet(basePath + "/", () => + { + stats.HttpReqStats.AddOrUpdate("/", 1, (_, v) => v + 1); + return Results.Ok(new + { + endpoints = new[] + { + "/varz", "/connz", "/healthz", "/routez", + "/gatewayz", "/leafz", "/subz", "/accountz", "/jsz", + }, + }); + }); + _app.MapGet(basePath + "/healthz", () => + { + stats.HttpReqStats.AddOrUpdate("/healthz", 1, (_, v) => v + 1); + return Results.Ok("ok"); + }); + _app.MapGet(basePath + "/varz", async (HttpContext ctx) => + { + stats.HttpReqStats.AddOrUpdate("/varz", 1, (_, v) => v + 1); + return Results.Ok(await _varzHandler.HandleVarzAsync(ctx.RequestAborted)); + }); + + _app.MapGet(basePath + "/connz", (HttpContext ctx) => + { + stats.HttpReqStats.AddOrUpdate("/connz", 1, (_, v) => v + 1); + return Results.Ok(_connzHandler.HandleConnz(ctx)); + }); + + // Stubs for unimplemented endpoints + _app.MapGet(basePath + "/routez", () => + { + stats.HttpReqStats.AddOrUpdate("/routez", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/gatewayz", () => + { + stats.HttpReqStats.AddOrUpdate("/gatewayz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/leafz", () => + { + stats.HttpReqStats.AddOrUpdate("/leafz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/subz", () => + { + stats.HttpReqStats.AddOrUpdate("/subz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/subscriptionsz", () => + { + stats.HttpReqStats.AddOrUpdate("/subscriptionsz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/accountz", () => + { + stats.HttpReqStats.AddOrUpdate("/accountz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/accstatz", () => + { + stats.HttpReqStats.AddOrUpdate("/accstatz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + _app.MapGet(basePath + "/jsz", () => + { + stats.HttpReqStats.AddOrUpdate("/jsz", 1, (_, v) => v + 1); + return Results.Ok(new { }); + }); + } + + public async Task StartAsync(CancellationToken ct) + { + await _app.StartAsync(ct); + _logger.LogInformation("Monitoring listening on {Urls}", string.Join(", ", _app.Urls)); + } + + public async ValueTask DisposeAsync() + { + await _app.StopAsync(); + await _app.DisposeAsync(); + _varzHandler.Dispose(); + } +} diff --git a/src/NATS.Server/Monitoring/Varz.cs b/src/NATS.Server/Monitoring/Varz.cs new file mode 100644 index 0000000..847bdc2 --- /dev/null +++ b/src/NATS.Server/Monitoring/Varz.cs @@ -0,0 +1,415 @@ +using System.Text.Json.Serialization; + +namespace NATS.Server.Monitoring; + +/// +/// Server general information. Corresponds to Go server/monitor.go Varz struct. +/// +public sealed class Varz +{ + // Identity + [JsonPropertyName("server_id")] + public string Id { get; set; } = ""; + + [JsonPropertyName("server_name")] + public string Name { get; set; } = ""; + + [JsonPropertyName("version")] + public string Version { get; set; } = ""; + + [JsonPropertyName("proto")] + public int Proto { get; set; } + + [JsonPropertyName("git_commit")] + public string GitCommit { get; set; } = ""; + + [JsonPropertyName("go")] + public string GoVersion { get; set; } = ""; + + [JsonPropertyName("host")] + public string Host { get; set; } = ""; + + [JsonPropertyName("port")] + public int Port { get; set; } + + // Network + [JsonPropertyName("ip")] + public string Ip { get; set; } = ""; + + [JsonPropertyName("connect_urls")] + public string[] ConnectUrls { get; set; } = []; + + [JsonPropertyName("ws_connect_urls")] + public string[] WsConnectUrls { get; set; } = []; + + [JsonPropertyName("http_host")] + public string HttpHost { get; set; } = ""; + + [JsonPropertyName("http_port")] + public int HttpPort { get; set; } + + [JsonPropertyName("http_base_path")] + public string HttpBasePath { get; set; } = ""; + + [JsonPropertyName("https_port")] + public int HttpsPort { get; set; } + + // Security + [JsonPropertyName("auth_required")] + public bool AuthRequired { get; set; } + + [JsonPropertyName("tls_required")] + public bool TlsRequired { get; set; } + + [JsonPropertyName("tls_verify")] + public bool TlsVerify { get; set; } + + [JsonPropertyName("tls_ocsp_peer_verify")] + public bool TlsOcspPeerVerify { get; set; } + + [JsonPropertyName("auth_timeout")] + public double AuthTimeout { get; set; } + + [JsonPropertyName("tls_timeout")] + public double TlsTimeout { get; set; } + + // Limits + [JsonPropertyName("max_connections")] + public int MaxConnections { get; set; } + + [JsonPropertyName("max_subscriptions")] + public int MaxSubscriptions { get; set; } + + [JsonPropertyName("max_payload")] + public int MaxPayload { get; set; } + + [JsonPropertyName("max_pending")] + public long MaxPending { get; set; } + + [JsonPropertyName("max_control_line")] + public int MaxControlLine { get; set; } + + [JsonPropertyName("ping_max")] + public int MaxPingsOut { get; set; } + + // Timing + [JsonPropertyName("ping_interval")] + public long PingInterval { get; set; } + + [JsonPropertyName("write_deadline")] + public long WriteDeadline { get; set; } + + [JsonPropertyName("start")] + public DateTime Start { get; set; } + + [JsonPropertyName("now")] + public DateTime Now { get; set; } + + [JsonPropertyName("uptime")] + public string Uptime { get; set; } = ""; + + // Runtime + [JsonPropertyName("mem")] + public long Mem { get; set; } + + [JsonPropertyName("cpu")] + public double Cpu { get; set; } + + [JsonPropertyName("cores")] + public int Cores { get; set; } + + [JsonPropertyName("gomaxprocs")] + public int MaxProcs { get; set; } + + // Connections + [JsonPropertyName("connections")] + public int Connections { get; set; } + + [JsonPropertyName("total_connections")] + public ulong TotalConnections { get; set; } + + [JsonPropertyName("routes")] + public int Routes { get; set; } + + [JsonPropertyName("remotes")] + public int Remotes { get; set; } + + [JsonPropertyName("leafnodes")] + public int Leafnodes { get; set; } + + // Messages + [JsonPropertyName("in_msgs")] + public long InMsgs { get; set; } + + [JsonPropertyName("out_msgs")] + public long OutMsgs { get; set; } + + [JsonPropertyName("in_bytes")] + public long InBytes { get; set; } + + [JsonPropertyName("out_bytes")] + public long OutBytes { get; set; } + + // Health + [JsonPropertyName("slow_consumers")] + public long SlowConsumers { get; set; } + + [JsonPropertyName("slow_consumer_stats")] + public SlowConsumersStats SlowConsumerStats { get; set; } = new(); + + [JsonPropertyName("subscriptions")] + public uint Subscriptions { get; set; } + + // Config + [JsonPropertyName("config_load_time")] + public DateTime ConfigLoadTime { get; set; } + + [JsonPropertyName("tags")] + public string[] Tags { get; set; } = []; + + [JsonPropertyName("system_account")] + public string SystemAccount { get; set; } = ""; + + [JsonPropertyName("pinned_account_fails")] + public ulong PinnedAccountFail { get; set; } + + [JsonPropertyName("tls_cert_not_after")] + public DateTime TlsCertNotAfter { get; set; } + + // HTTP + [JsonPropertyName("http_req_stats")] + public Dictionary HttpReqStats { get; set; } = new(); + + // Subsystems + [JsonPropertyName("cluster")] + public ClusterOptsVarz Cluster { get; set; } = new(); + + [JsonPropertyName("gateway")] + public GatewayOptsVarz Gateway { get; set; } = new(); + + [JsonPropertyName("leaf")] + public LeafNodeOptsVarz Leaf { get; set; } = new(); + + [JsonPropertyName("mqtt")] + public MqttOptsVarz Mqtt { get; set; } = new(); + + [JsonPropertyName("websocket")] + public WebsocketOptsVarz Websocket { get; set; } = new(); + + [JsonPropertyName("jetstream")] + public JetStreamVarz JetStream { get; set; } = new(); +} + +/// +/// Statistics about slow consumers by connection type. +/// Corresponds to Go server/monitor.go SlowConsumersStats struct. +/// +public sealed class SlowConsumersStats +{ + [JsonPropertyName("clients")] + public ulong Clients { get; set; } + + [JsonPropertyName("routes")] + public ulong Routes { get; set; } + + [JsonPropertyName("gateways")] + public ulong Gateways { get; set; } + + [JsonPropertyName("leafs")] + public ulong Leafs { get; set; } +} + +/// +/// Cluster configuration monitoring information. +/// Corresponds to Go server/monitor.go ClusterOptsVarz struct. +/// +public sealed class ClusterOptsVarz +{ + [JsonPropertyName("name")] + public string Name { get; set; } = ""; + + [JsonPropertyName("addr")] + public string Host { get; set; } = ""; + + [JsonPropertyName("cluster_port")] + public int Port { get; set; } + + [JsonPropertyName("auth_timeout")] + public double AuthTimeout { get; set; } + + [JsonPropertyName("tls_timeout")] + public double TlsTimeout { get; set; } + + [JsonPropertyName("tls_required")] + public bool TlsRequired { get; set; } + + [JsonPropertyName("tls_verify")] + public bool TlsVerify { get; set; } + + [JsonPropertyName("pool_size")] + public int PoolSize { get; set; } + + [JsonPropertyName("urls")] + public string[] Urls { get; set; } = []; +} + +/// +/// Gateway configuration monitoring information. +/// Corresponds to Go server/monitor.go GatewayOptsVarz struct. +/// +public sealed class GatewayOptsVarz +{ + [JsonPropertyName("name")] + public string Name { get; set; } = ""; + + [JsonPropertyName("host")] + public string Host { get; set; } = ""; + + [JsonPropertyName("port")] + public int Port { get; set; } + + [JsonPropertyName("auth_timeout")] + public double AuthTimeout { get; set; } + + [JsonPropertyName("tls_timeout")] + public double TlsTimeout { get; set; } + + [JsonPropertyName("tls_required")] + public bool TlsRequired { get; set; } + + [JsonPropertyName("tls_verify")] + public bool TlsVerify { get; set; } + + [JsonPropertyName("advertise")] + public string Advertise { get; set; } = ""; + + [JsonPropertyName("connect_retries")] + public int ConnectRetries { get; set; } + + [JsonPropertyName("reject_unknown")] + public bool RejectUnknown { get; set; } +} + +/// +/// Leaf node configuration monitoring information. +/// Corresponds to Go server/monitor.go LeafNodeOptsVarz struct. +/// +public sealed class LeafNodeOptsVarz +{ + [JsonPropertyName("host")] + public string Host { get; set; } = ""; + + [JsonPropertyName("port")] + public int Port { get; set; } + + [JsonPropertyName("auth_timeout")] + public double AuthTimeout { get; set; } + + [JsonPropertyName("tls_timeout")] + public double TlsTimeout { get; set; } + + [JsonPropertyName("tls_required")] + public bool TlsRequired { get; set; } + + [JsonPropertyName("tls_verify")] + public bool TlsVerify { get; set; } + + [JsonPropertyName("tls_ocsp_peer_verify")] + public bool TlsOcspPeerVerify { get; set; } +} + +/// +/// MQTT configuration monitoring information. +/// Corresponds to Go server/monitor.go MQTTOptsVarz struct. +/// +public sealed class MqttOptsVarz +{ + [JsonPropertyName("host")] + public string Host { get; set; } = ""; + + [JsonPropertyName("port")] + public int Port { get; set; } + + [JsonPropertyName("tls_timeout")] + public double TlsTimeout { get; set; } +} + +/// +/// Websocket configuration monitoring information. +/// Corresponds to Go server/monitor.go WebsocketOptsVarz struct. +/// +public sealed class WebsocketOptsVarz +{ + [JsonPropertyName("host")] + public string Host { get; set; } = ""; + + [JsonPropertyName("port")] + public int Port { get; set; } + + [JsonPropertyName("tls_timeout")] + public double TlsTimeout { get; set; } +} + +/// +/// JetStream runtime information. +/// Corresponds to Go server/monitor.go JetStreamVarz struct. +/// +public sealed class JetStreamVarz +{ + [JsonPropertyName("config")] + public JetStreamConfig Config { get; set; } = new(); + + [JsonPropertyName("stats")] + public JetStreamStats Stats { get; set; } = new(); +} + +/// +/// JetStream configuration. +/// Corresponds to Go server/jetstream.go JetStreamConfig struct. +/// +public sealed class JetStreamConfig +{ + [JsonPropertyName("max_memory")] + public long MaxMemory { get; set; } + + [JsonPropertyName("max_storage")] + public long MaxStorage { get; set; } + + [JsonPropertyName("store_dir")] + public string StoreDir { get; set; } = ""; +} + +/// +/// JetStream statistics. +/// Corresponds to Go server/jetstream.go JetStreamStats struct. +/// +public sealed class JetStreamStats +{ + [JsonPropertyName("memory")] + public ulong Memory { get; set; } + + [JsonPropertyName("storage")] + public ulong Storage { get; set; } + + [JsonPropertyName("accounts")] + public int Accounts { get; set; } + + [JsonPropertyName("ha_assets")] + public int HaAssets { get; set; } + + [JsonPropertyName("api")] + public JetStreamApiStats Api { get; set; } = new(); +} + +/// +/// JetStream API statistics. +/// Corresponds to Go server/jetstream.go JetStreamAPIStats struct. +/// +public sealed class JetStreamApiStats +{ + [JsonPropertyName("total")] + public ulong Total { get; set; } + + [JsonPropertyName("errors")] + public ulong Errors { get; set; } +} diff --git a/src/NATS.Server/Monitoring/VarzHandler.cs b/src/NATS.Server/Monitoring/VarzHandler.cs new file mode 100644 index 0000000..036fb92 --- /dev/null +++ b/src/NATS.Server/Monitoring/VarzHandler.cs @@ -0,0 +1,121 @@ +using System.Diagnostics; +using System.Runtime.InteropServices; +using NATS.Server.Protocol; + +namespace NATS.Server.Monitoring; + +/// +/// Handles building the Varz response from server state and process metrics. +/// Corresponds to Go server/monitor.go handleVarz function. +/// +public sealed class VarzHandler : IDisposable +{ + private readonly NatsServer _server; + private readonly NatsOptions _options; + private readonly SemaphoreSlim _varzMu = new(1, 1); + private DateTime _lastCpuSampleTime; + private TimeSpan _lastCpuUsage; + private double _cachedCpuPercent; + + public VarzHandler(NatsServer server, NatsOptions options) + { + _server = server; + _options = options; + using var proc = Process.GetCurrentProcess(); + _lastCpuSampleTime = DateTime.UtcNow; + _lastCpuUsage = proc.TotalProcessorTime; + } + + public async Task HandleVarzAsync(CancellationToken ct = default) + { + await _varzMu.WaitAsync(ct); + try + { + using var proc = Process.GetCurrentProcess(); + var now = DateTime.UtcNow; + var uptime = now - _server.StartTime; + var stats = _server.Stats; + + // CPU sampling with 1-second cache to avoid excessive sampling + if ((now - _lastCpuSampleTime).TotalSeconds >= 1.0) + { + var currentCpu = proc.TotalProcessorTime; + var elapsed = now - _lastCpuSampleTime; + _cachedCpuPercent = (currentCpu - _lastCpuUsage).TotalMilliseconds + / elapsed.TotalMilliseconds / Environment.ProcessorCount * 100.0; + _lastCpuSampleTime = now; + _lastCpuUsage = currentCpu; + } + + return new Varz + { + Id = _server.ServerId, + Name = _server.ServerName, + Version = NatsProtocol.Version, + Proto = NatsProtocol.ProtoVersion, + GoVersion = $"dotnet {RuntimeInformation.FrameworkDescription}", + Host = _options.Host, + Port = _options.Port, + HttpHost = _options.MonitorHost, + HttpPort = _options.MonitorPort, + HttpBasePath = _options.MonitorBasePath ?? "", + HttpsPort = _options.MonitorHttpsPort, + TlsRequired = _options.HasTls && !_options.AllowNonTls, + TlsVerify = _options.HasTls && _options.TlsVerify, + TlsTimeout = _options.HasTls ? _options.TlsTimeout.TotalSeconds : 0, + MaxConnections = _options.MaxConnections, + MaxPayload = _options.MaxPayload, + MaxControlLine = _options.MaxControlLine, + MaxPingsOut = _options.MaxPingsOut, + PingInterval = (long)_options.PingInterval.TotalNanoseconds, + Start = _server.StartTime, + Now = now, + Uptime = FormatUptime(uptime), + Mem = proc.WorkingSet64, + Cpu = Math.Round(_cachedCpuPercent, 2), + Cores = Environment.ProcessorCount, + MaxProcs = ThreadPool.ThreadCount, + Connections = _server.ClientCount, + TotalConnections = (ulong)Interlocked.Read(ref stats.TotalConnections), + InMsgs = Interlocked.Read(ref stats.InMsgs), + OutMsgs = Interlocked.Read(ref stats.OutMsgs), + InBytes = Interlocked.Read(ref stats.InBytes), + OutBytes = Interlocked.Read(ref stats.OutBytes), + SlowConsumers = Interlocked.Read(ref stats.SlowConsumers), + SlowConsumerStats = new SlowConsumersStats + { + Clients = (ulong)Interlocked.Read(ref stats.SlowConsumerClients), + Routes = (ulong)Interlocked.Read(ref stats.SlowConsumerRoutes), + Gateways = (ulong)Interlocked.Read(ref stats.SlowConsumerGateways), + Leafs = (ulong)Interlocked.Read(ref stats.SlowConsumerLeafs), + }, + Subscriptions = _server.SubList.Count, + ConfigLoadTime = _server.StartTime, + HttpReqStats = stats.HttpReqStats.ToDictionary(kv => kv.Key, kv => (ulong)kv.Value), + }; + } + finally + { + _varzMu.Release(); + } + } + + public void Dispose() + { + _varzMu.Dispose(); + } + + /// + /// Formats a TimeSpan as a human-readable uptime string matching Go server format. + /// + private static string FormatUptime(TimeSpan ts) + { + if (ts.TotalDays >= 1) + return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s"; + if (ts.TotalHours >= 1) + return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s"; + if (ts.TotalMinutes >= 1) + return $"{(int)ts.TotalMinutes}m{ts.Seconds}s"; + return $"{(int)ts.TotalSeconds}s"; + } +} diff --git a/src/NATS.Server/NATS.Server.csproj b/src/NATS.Server/NATS.Server.csproj index 4a9060d..d85e688 100644 --- a/src/NATS.Server/NATS.Server.csproj +++ b/src/NATS.Server/NATS.Server.csproj @@ -1,6 +1,6 @@ - + diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 473c52c..69cde4c 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -1,5 +1,6 @@ using System.Buffers; using System.IO.Pipelines; +using System.Net; using System.Net.Sockets; using System.Security.Cryptography; using System.Text; @@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging; using NATS.Server.Auth; using NATS.Server.Protocol; using NATS.Server.Subscriptions; +using NATS.Server.Tls; namespace NATS.Server; @@ -26,7 +28,7 @@ public interface ISubListAccess public sealed class NatsClient : IDisposable { private readonly Socket _socket; - private readonly NetworkStream _stream; + private readonly Stream _stream; private readonly NatsOptions _options; private readonly ServerInfo _serverInfo; private readonly AuthService _authService; @@ -37,6 +39,7 @@ public sealed class NatsClient : IDisposable private readonly Dictionary _subs = new(); private readonly ILogger _logger; private ClientPermissions? _permissions; + private readonly ServerStats _serverStats; public ulong Id { get; } public ClientOptions? ClientOpts { get; private set; } @@ -47,6 +50,12 @@ public sealed class NatsClient : IDisposable private int _connectReceived; public bool ConnectReceived => Volatile.Read(ref _connectReceived) != 0; + public DateTime StartTime { get; } + private long _lastActivityTicks; + public DateTime LastActivity => new(Interlocked.Read(ref _lastActivityTicks), DateTimeKind.Utc); + public string? RemoteIp { get; } + public int RemotePort { get; } + // Stats public long InMsgs; public long OutMsgs; @@ -57,20 +66,31 @@ public sealed class NatsClient : IDisposable private int _pingsOut; private long _lastIn; + public TlsConnectionState? TlsState { get; set; } + public bool InfoAlreadySent { get; set; } + public IReadOnlyDictionary Subscriptions => _subs; - public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo, - AuthService authService, byte[]? nonce, ILogger logger) + public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo, + AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats) { Id = id; _socket = socket; - _stream = new NetworkStream(socket, ownsSocket: false); + _stream = stream; _options = options; _serverInfo = serverInfo; _authService = authService; _nonce = nonce; _logger = logger; + _serverStats = serverStats; _parser = new NatsParser(options.MaxPayload); + StartTime = DateTime.UtcNow; + _lastActivityTicks = StartTime.Ticks; + if (socket.RemoteEndPoint is IPEndPoint ep) + { + RemoteIp = ep.Address.ToString(); + RemotePort = ep.Port; + } } public async Task RunAsync(CancellationToken ct) @@ -80,8 +100,9 @@ public sealed class NatsClient : IDisposable var pipe = new Pipe(); try { - // Send INFO - await SendInfoAsync(_clientCts.Token); + // Send INFO (skip if already sent during TLS negotiation) + if (!InfoAlreadySent) + await SendInfoAsync(_clientCts.Token); // Start auth timeout if auth is required Task? authTimeoutTask = null; @@ -100,7 +121,7 @@ public sealed class NatsClient : IDisposable } catch (OperationCanceledException) { - // Normal — client connected or was cancelled + // Normal -- client connected or was cancelled } }, _clientCts.Token); } @@ -184,6 +205,8 @@ public sealed class NatsClient : IDisposable private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct) { + Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks); + // If auth is required and CONNECT hasn't been received yet, // only allow CONNECT and PING commands if (_authService.IsAuthRequired && !ConnectReceived) @@ -266,7 +289,7 @@ public sealed class NatsClient : IDisposable _logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity); - // Clear nonce after use — defense-in-depth against memory dumps + // Clear nonce after use -- defense-in-depth against memory dumps if (_nonce != null) CryptographicOperations.ZeroMemory(_nonce); } @@ -330,6 +353,8 @@ public sealed class NatsClient : IDisposable { Interlocked.Increment(ref InMsgs); Interlocked.Add(ref InBytes, cmd.Payload.Length); + Interlocked.Increment(ref _serverStats.InMsgs); + Interlocked.Add(ref _serverStats.InBytes, cmd.Payload.Length); // Max payload validation (always, hard close) if (cmd.Payload.Length > _options.MaxPayload) @@ -380,6 +405,8 @@ public sealed class NatsClient : IDisposable { Interlocked.Increment(ref OutMsgs); Interlocked.Add(ref OutBytes, payload.Length + headers.Length); + Interlocked.Increment(ref _serverStats.OutMsgs); + Interlocked.Add(ref _serverStats.OutBytes, payload.Length + headers.Length); byte[] line; if (headers.Length > 0) @@ -470,7 +497,7 @@ public sealed class NatsClient : IDisposable if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut) { - _logger.LogDebug("Client {ClientId} stale connection — closing", Id); + _logger.LogDebug("Client {ClientId} stale connection -- closing", Id); await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection); return; } diff --git a/src/NATS.Server/NatsOptions.cs b/src/NATS.Server/NatsOptions.cs index c641fe3..1aced93 100644 --- a/src/NATS.Server/NatsOptions.cs +++ b/src/NATS.Server/NatsOptions.cs @@ -1,3 +1,4 @@ +using System.Security.Authentication; using NATS.Server.Auth; namespace NATS.Server; @@ -7,7 +8,7 @@ public sealed class NatsOptions public string Host { get; set; } = "0.0.0.0"; public int Port { get; set; } = 4222; public string? ServerName { get; set; } - public int MaxPayload { get; set; } = 1024 * 1024; // 1MB + public int MaxPayload { get; set; } = 1024 * 1024; public int MaxControlLine { get; set; } = 4096; public int MaxConnections { get; set; } = 65536; public TimeSpan PingInterval { get; set; } = TimeSpan.FromMinutes(2); @@ -27,4 +28,27 @@ public sealed class NatsOptions // Auth timing public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2); + + // Monitoring (0 = disabled; standard port is 8222) + public int MonitorPort { get; set; } + public string MonitorHost { get; set; } = "0.0.0.0"; + public string? MonitorBasePath { get; set; } + // 0 = disabled + public int MonitorHttpsPort { get; set; } + + // TLS + public string? TlsCert { get; set; } + public string? TlsKey { get; set; } + public string? TlsCaCert { get; set; } + public bool TlsVerify { get; set; } + public bool TlsMap { get; set; } + public TimeSpan TlsTimeout { get; set; } = TimeSpan.FromSeconds(2); + public bool TlsHandshakeFirst { get; set; } + public TimeSpan TlsHandshakeFirstFallback { get; set; } = TimeSpan.FromMilliseconds(50); + public bool AllowNonTls { get; set; } + public long TlsRateLimit { get; set; } + public HashSet? TlsPinnedCerts { get; set; } + public SslProtocols TlsMinVersion { get; set; } = SslProtocols.Tls12; + + public bool HasTls => TlsCert != null && TlsKey != null; } diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index fbf540a..87f6b97 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -1,11 +1,15 @@ using System.Collections.Concurrent; using System.Net; +using System.Net.Security; using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; using System.Text; using Microsoft.Extensions.Logging; using NATS.Server.Auth; +using NATS.Server.Monitoring; using NATS.Server.Protocol; using NATS.Server.Subscriptions; +using NATS.Server.Tls; namespace NATS.Server; @@ -16,14 +20,25 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable private readonly ServerInfo _serverInfo; private readonly ILogger _logger; private readonly ILoggerFactory _loggerFactory; + private readonly ServerStats _stats = new(); private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly AuthService _authService; private readonly ConcurrentDictionary _accounts = new(StringComparer.Ordinal); private readonly Account _globalAccount; + private readonly SslServerAuthenticationOptions? _sslOptions; + private readonly TlsRateLimiter? _tlsRateLimiter; private Socket? _listener; + private MonitorServer? _monitorServer; private ulong _nextClientId; + private long _startTimeTicks; public SubList SubList => _globalAccount.SubList; + public ServerStats Stats => _stats; + public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc); + public string ServerId => _serverInfo.ServerId; + public string ServerName => _serverInfo.ServerName; + public int ClientCount => _clients.Count; + public IEnumerable GetClients() => _clients.Values; public Task WaitForReadyAsync() => _listeningStarted.Task; @@ -45,6 +60,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable MaxPayload = options.MaxPayload, AuthRequired = _authService.IsAuthRequired, }; + + if (options.HasTls) + { + _sslOptions = TlsHelper.BuildServerAuthOptions(options); + _serverInfo.TlsRequired = !options.AllowNonTls; + _serverInfo.TlsAvailable = options.AllowNonTls; + _serverInfo.TlsVerify = options.TlsVerify; + + if (options.TlsRateLimit > 0) + _tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit); + } } public async Task StartAsync(CancellationToken ct) @@ -54,11 +80,18 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable _listener.Bind(new IPEndPoint( _options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host), _options.Port)); + Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks); _listener.Listen(128); _listeningStarted.TrySetResult(); _logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port); + if (_options.MonitorPort > 0) + { + _monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory); + await _monitorServer.StartAsync(ct); + } + try { while (!ct.IsCancellationRequested) @@ -91,37 +124,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } var clientId = Interlocked.Increment(ref _nextClientId); + Interlocked.Increment(ref _stats.TotalConnections); _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); - // Build per-client ServerInfo with nonce if NKey auth is configured - byte[]? nonce = null; - var clientInfo = _serverInfo; - if (_authService.NonceRequired) - { - var rawNonce = _authService.GenerateNonce(); - var nonceStr = _authService.EncodeNonce(rawNonce); - // The client signs the nonce string (ASCII), not the raw bytes - nonce = Encoding.ASCII.GetBytes(nonceStr); - clientInfo = new ServerInfo - { - ServerId = _serverInfo.ServerId, - ServerName = _serverInfo.ServerName, - Version = _serverInfo.Version, - Host = _serverInfo.Host, - Port = _serverInfo.Port, - MaxPayload = _serverInfo.MaxPayload, - AuthRequired = _serverInfo.AuthRequired, - Nonce = nonceStr, - }; - } - - var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); - var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger); - client.Router = this; - _clients[clientId] = client; - - _ = RunClientAsync(client, ct); + _ = AcceptClientAsync(socket, clientId, ct); } } catch (OperationCanceledException) @@ -130,6 +137,74 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable } } + private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct) + { + try + { + // Rate limit TLS handshakes + if (_tlsRateLimiter != null) + await _tlsRateLimiter.WaitAsync(ct); + + var networkStream = new NetworkStream(socket, ownsSocket: false); + + // TLS negotiation (no-op if not configured) + var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync( + socket, networkStream, _options, _sslOptions, _serverInfo, + _loggerFactory.CreateLogger("NATS.Server.Tls"), ct); + + // Extract TLS state + TlsConnectionState? tlsState = null; + if (stream is SslStream ssl) + { + tlsState = new TlsConnectionState( + ssl.SslProtocol.ToString(), + ssl.NegotiatedCipherSuite.ToString(), + ssl.RemoteCertificate as X509Certificate2); + } + + // Build per-client ServerInfo with nonce if NKey auth is configured + byte[]? nonce = null; + var clientInfo = _serverInfo; + if (_authService.NonceRequired) + { + var rawNonce = _authService.GenerateNonce(); + var nonceStr = _authService.EncodeNonce(rawNonce); + // The client signs the nonce string (ASCII), not the raw bytes + nonce = Encoding.ASCII.GetBytes(nonceStr); + clientInfo = new ServerInfo + { + ServerId = _serverInfo.ServerId, + ServerName = _serverInfo.ServerName, + Version = _serverInfo.Version, + Host = _serverInfo.Host, + Port = _serverInfo.Port, + MaxPayload = _serverInfo.MaxPayload, + AuthRequired = _serverInfo.AuthRequired, + TlsRequired = _serverInfo.TlsRequired, + TlsAvailable = _serverInfo.TlsAvailable, + TlsVerify = _serverInfo.TlsVerify, + Nonce = nonceStr, + }; + } + + var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]"); + var client = new NatsClient(clientId, stream, socket, _options, clientInfo, + _authService, nonce, clientLogger, _stats); + client.Router = this; + client.TlsState = tlsState; + client.InfoAlreadySent = infoAlreadySent; + _clients[clientId] = client; + + await RunClientAsync(client, ct); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId); + try { socket.Shutdown(SocketShutdown.Both); } catch { } + socket.Dispose(); + } + } + private async Task RunClientAsync(NatsClient client, CancellationToken ct) { try @@ -215,6 +290,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable public void Dispose() { + if (_monitorServer != null) + _monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult(); + _tlsRateLimiter?.Dispose(); _listener?.Dispose(); foreach (var client in _clients.Values) client.Dispose(); diff --git a/src/NATS.Server/Protocol/NatsProtocol.cs b/src/NATS.Server/Protocol/NatsProtocol.cs index fbfa1bb..44a603a 100644 --- a/src/NATS.Server/Protocol/NatsProtocol.cs +++ b/src/NATS.Server/Protocol/NatsProtocol.cs @@ -73,6 +73,18 @@ public sealed class ServerInfo [JsonPropertyName("nonce")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public string? Nonce { get; set; } + + [JsonPropertyName("tls_required")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public bool TlsRequired { get; set; } + + [JsonPropertyName("tls_verify")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public bool TlsVerify { get; set; } + + [JsonPropertyName("tls_available")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public bool TlsAvailable { get; set; } } public sealed class ClientOptions diff --git a/src/NATS.Server/ServerStats.cs b/src/NATS.Server/ServerStats.cs new file mode 100644 index 0000000..b737dee --- /dev/null +++ b/src/NATS.Server/ServerStats.cs @@ -0,0 +1,20 @@ +using System.Collections.Concurrent; + +namespace NATS.Server; + +public sealed class ServerStats +{ + public long InMsgs; + public long OutMsgs; + public long InBytes; + public long OutBytes; + public long TotalConnections; + public long SlowConsumers; + public long StaleConnections; + public long Stalls; + public long SlowConsumerClients; + public long SlowConsumerRoutes; + public long SlowConsumerLeafs; + public long SlowConsumerGateways; + public readonly ConcurrentDictionary HttpReqStats = new(); +} diff --git a/src/NATS.Server/Tls/PeekableStream.cs b/src/NATS.Server/Tls/PeekableStream.cs new file mode 100644 index 0000000..29abf07 --- /dev/null +++ b/src/NATS.Server/Tls/PeekableStream.cs @@ -0,0 +1,71 @@ +namespace NATS.Server.Tls; + +public sealed class PeekableStream : Stream +{ + private readonly Stream _inner; + private byte[]? _peekedBytes; + private int _peekedOffset; + private int _peekedCount; + + public PeekableStream(Stream inner) => _inner = inner; + + public async Task PeekAsync(int count, CancellationToken ct = default) + { + var buf = new byte[count]; + int read = await _inner.ReadAsync(buf.AsMemory(0, count), ct); + if (read < count) Array.Resize(ref buf, read); + _peekedBytes = buf; + _peekedOffset = 0; + _peekedCount = read; + return buf; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) + { + if (_peekedBytes != null && _peekedOffset < _peekedCount) + { + int available = _peekedCount - _peekedOffset; + int toCopy = Math.Min(available, buffer.Length); + _peekedBytes.AsMemory(_peekedOffset, toCopy).CopyTo(buffer); + _peekedOffset += toCopy; + if (_peekedOffset >= _peekedCount) _peekedBytes = null; + return toCopy; + } + return await _inner.ReadAsync(buffer, ct); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_peekedBytes != null && _peekedOffset < _peekedCount) + { + int available = _peekedCount - _peekedOffset; + int toCopy = Math.Min(available, count); + Array.Copy(_peekedBytes, _peekedOffset, buffer, offset, toCopy); + _peekedOffset += toCopy; + if (_peekedOffset >= _peekedCount) _peekedBytes = null; + return toCopy; + } + return _inner.Read(buffer, offset, count); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) + => ReadAsync(buffer.AsMemory(offset, count), ct).AsTask(); + + // Write passthrough + public override void Write(byte[] buffer, int offset, int count) => _inner.Write(buffer, offset, count); + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) => _inner.WriteAsync(buffer, offset, count, ct); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ct = default) => _inner.WriteAsync(buffer, ct); + public override void Flush() => _inner.Flush(); + public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct); + + // Required Stream overrides + public override bool CanRead => _inner.CanRead; + public override bool CanSeek => false; + public override bool CanWrite => _inner.CanWrite; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); } +} diff --git a/src/NATS.Server/Tls/TlsConnectionState.cs b/src/NATS.Server/Tls/TlsConnectionState.cs new file mode 100644 index 0000000..0fe788a --- /dev/null +++ b/src/NATS.Server/Tls/TlsConnectionState.cs @@ -0,0 +1,9 @@ +using System.Security.Cryptography.X509Certificates; + +namespace NATS.Server.Tls; + +public sealed record TlsConnectionState( + string? TlsVersion, + string? CipherSuite, + X509Certificate2? PeerCert +); diff --git a/src/NATS.Server/Tls/TlsConnectionWrapper.cs b/src/NATS.Server/Tls/TlsConnectionWrapper.cs new file mode 100644 index 0000000..0ca0961 --- /dev/null +++ b/src/NATS.Server/Tls/TlsConnectionWrapper.cs @@ -0,0 +1,202 @@ +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.Logging; +using NATS.Server.Protocol; + +namespace NATS.Server.Tls; + +public static class TlsConnectionWrapper +{ + private const byte TlsRecordMarker = 0x16; + + public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync( + Socket socket, + Stream networkStream, + NatsOptions options, + SslServerAuthenticationOptions? sslOptions, + ServerInfo serverInfo, + ILogger logger, + CancellationToken ct) + { + // Mode 1: No TLS + if (sslOptions == null || !options.HasTls) + return (networkStream, false); + + // Clone to avoid mutating shared instance + serverInfo = new ServerInfo + { + ServerId = serverInfo.ServerId, + ServerName = serverInfo.ServerName, + Version = serverInfo.Version, + Proto = serverInfo.Proto, + Host = serverInfo.Host, + Port = serverInfo.Port, + Headers = serverInfo.Headers, + MaxPayload = serverInfo.MaxPayload, + ClientId = serverInfo.ClientId, + ClientIp = serverInfo.ClientIp, + }; + + // Mode 3: TLS First + if (options.TlsHandshakeFirst) + return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct); + + // Mode 2 & 4: Send INFO first, then decide + serverInfo.TlsRequired = !options.AllowNonTls; + serverInfo.TlsAvailable = options.AllowNonTls; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(networkStream, serverInfo, ct); + + // Peek first byte to detect TLS + var peekable = new PeekableStream(networkStream); + var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct); + + if (peeked.Length == 0) + { + // Client disconnected or timed out + return (peekable, true); + } + + if (peeked[0] == TlsRecordMarker) + { + // Client is starting TLS + var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); + try + { + using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + handshakeCts.CancelAfter(options.TlsTimeout); + + await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); + logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}", + sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); + + // Validate pinned certs + if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) + { + if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + { + logger.LogWarning("Certificate pinning check failed"); + throw new InvalidOperationException("Certificate pinning check failed"); + } + } + } + catch + { + sslStream.Dispose(); + throw; + } + + return (sslStream, true); + } + + // Mode 4: Mixed — client chose plaintext + if (options.AllowNonTls) + { + logger.LogDebug("Client connected without TLS (mixed mode)"); + return (peekable, true); + } + + // TLS required but client sent plaintext + logger.LogWarning("TLS required but client sent plaintext data"); + throw new InvalidOperationException("TLS required"); + } + + private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync( + Socket socket, + Stream networkStream, + NatsOptions options, + SslServerAuthenticationOptions sslOptions, + ServerInfo serverInfo, + ILogger logger, + CancellationToken ct) + { + // Wait for data with fallback timeout + var peekable = new PeekableStream(networkStream); + var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct); + + if (peeked.Length > 0 && peeked[0] == TlsRecordMarker) + { + // Client started TLS immediately — handshake first, then send INFO + var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false); + try + { + using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + handshakeCts.CancelAfter(options.TlsTimeout); + + await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token); + logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}", + sslStream.SslProtocol, sslStream.NegotiatedCipherSuite); + + // Validate pinned certs + if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert) + { + if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts)) + { + throw new InvalidOperationException("Certificate pinning check failed"); + } + } + + // Now send INFO over encrypted stream + serverInfo.TlsRequired = true; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(sslStream, serverInfo, ct); + } + catch + { + sslStream.Dispose(); + throw; + } + + return (sslStream, true); + } + + // Fallback: timeout expired or non-TLS data — send INFO and negotiate normally + logger.LogDebug("TLS-first fallback: sending INFO"); + serverInfo.TlsRequired = !options.AllowNonTls; + serverInfo.TlsAvailable = options.AllowNonTls; + serverInfo.TlsVerify = options.TlsVerify; + await SendInfoAsync(peekable, serverInfo, ct); + + if (peeked.Length == 0) + { + // Timeout — INFO was sent, return stream for normal flow + return (peekable, true); + } + + // Non-TLS data received during fallback window + if (options.AllowNonTls) + { + return (peekable, true); + } + + // TLS required but got plaintext + throw new InvalidOperationException("TLS required but client sent plaintext"); + } + + private static async Task PeekWithTimeoutAsync( + PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(timeout); + try + { + return await stream.PeekAsync(count, cts.Token); + } + catch (OperationCanceledException) when (!ct.IsCancellationRequested) + { + // Timeout — not a cancellation of the outer token + return []; + } + } + + private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct) + { + var infoJson = JsonSerializer.Serialize(serverInfo); + var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n"); + await stream.WriteAsync(infoLine, ct); + await stream.FlushAsync(ct); + } +} diff --git a/src/NATS.Server/Tls/TlsHelper.cs b/src/NATS.Server/Tls/TlsHelper.cs new file mode 100644 index 0000000..cdc5ef6 --- /dev/null +++ b/src/NATS.Server/Tls/TlsHelper.cs @@ -0,0 +1,65 @@ +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace NATS.Server.Tls; + +public static class TlsHelper +{ + public static X509Certificate2 LoadCertificate(string certPath, string? keyPath) + { + if (keyPath != null) + return X509Certificate2.CreateFromPemFile(certPath, keyPath); + return X509CertificateLoader.LoadCertificateFromFile(certPath); + } + + public static X509Certificate2Collection LoadCaCertificates(string caPath) + { + var collection = new X509Certificate2Collection(); + collection.ImportFromPemFile(caPath); + return collection; + } + + public static SslServerAuthenticationOptions BuildServerAuthOptions(NatsOptions opts) + { + var cert = LoadCertificate(opts.TlsCert!, opts.TlsKey); + var authOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + EnabledSslProtocols = opts.TlsMinVersion, + ClientCertificateRequired = opts.TlsVerify, + }; + + if (opts.TlsVerify && opts.TlsCaCert != null) + { + var caCerts = LoadCaCertificates(opts.TlsCaCert); + authOpts.RemoteCertificateValidationCallback = (_, cert, chain, errors) => + { + if (cert == null) return false; + using var chain2 = new X509Chain(); + chain2.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; + foreach (var ca in caCerts) + chain2.ChainPolicy.CustomTrustStore.Add(ca); + chain2.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck; + var cert2 = cert as X509Certificate2 ?? X509CertificateLoader.LoadCertificate(cert.GetRawCertData()); + return chain2.Build(cert2); + }; + } + + return authOpts; + } + + public static string GetCertificateHash(X509Certificate2 cert) + { + var spki = cert.PublicKey.ExportSubjectPublicKeyInfo(); + var hash = SHA256.HashData(spki); + return Convert.ToHexStringLower(hash); + } + + public static bool MatchesPinnedCert(X509Certificate2 cert, HashSet pinned) + { + var hash = GetCertificateHash(cert); + return pinned.Contains(hash); + } +} diff --git a/src/NATS.Server/Tls/TlsRateLimiter.cs b/src/NATS.Server/Tls/TlsRateLimiter.cs new file mode 100644 index 0000000..75741a5 --- /dev/null +++ b/src/NATS.Server/Tls/TlsRateLimiter.cs @@ -0,0 +1,25 @@ +namespace NATS.Server.Tls; + +public sealed class TlsRateLimiter : IDisposable +{ + private readonly SemaphoreSlim _semaphore; + private readonly Timer _refillTimer; + private readonly int _tokensPerSecond; + + public TlsRateLimiter(long tokensPerSecond) + { + _tokensPerSecond = (int)Math.Max(1, tokensPerSecond); + _semaphore = new SemaphoreSlim(_tokensPerSecond, _tokensPerSecond); + _refillTimer = new Timer(Refill, null, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); + } + + private void Refill(object? state) + { + int toRelease = _tokensPerSecond - _semaphore.CurrentCount; + if (toRelease > 0) _semaphore.Release(toRelease); + } + + public Task WaitAsync(CancellationToken ct) => _semaphore.WaitAsync(ct); + + public void Dispose() { _refillTimer.Dispose(); _semaphore.Dispose(); } +} diff --git a/tests/NATS.Server.Tests/ClientTests.cs b/tests/NATS.Server.Tests/ClientTests.cs index 096877a..92bcdd8 100644 --- a/tests/NATS.Server.Tests/ClientTests.cs +++ b/tests/NATS.Server.Tests/ClientTests.cs @@ -41,7 +41,7 @@ public class ClientTests : IAsyncDisposable }; var authService = AuthService.Build(new NatsOptions()); - _natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance); + _natsClient = new NatsClient(1, new NetworkStream(_serverSocket, ownsSocket: false), _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance, new ServerStats()); } public async ValueTask DisposeAsync() @@ -56,7 +56,7 @@ public class ClientTests : IAsyncDisposable { var runTask = _natsClient.RunAsync(_cts.Token); - // Read from client socket — should get INFO + // Read from client socket -- should get INFO var buf = new byte[4096]; var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None); var response = Encoding.ASCII.GetString(buf, 0, n); @@ -80,7 +80,7 @@ public class ClientTests : IAsyncDisposable // Send CONNECT then PING await _clientSocket.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); - // Read response — should get PONG + // Read response -- should get PONG var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None); var response = Encoding.ASCII.GetString(buf, 0, n); @@ -128,7 +128,7 @@ public class ClientTests : IAsyncDisposable response.ShouldBe("-ERR 'maximum connections exceeded'\r\n"); - // Connection should be closed — next read returns 0 + // Connection should be closed -- next read returns 0 n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token); n.ShouldBe(0); } diff --git a/tests/NATS.Server.Tests/MonitorModelTests.cs b/tests/NATS.Server.Tests/MonitorModelTests.cs new file mode 100644 index 0000000..690afd5 --- /dev/null +++ b/tests/NATS.Server.Tests/MonitorModelTests.cs @@ -0,0 +1,51 @@ +using System.Text.Json; +using NATS.Server.Monitoring; + +namespace NATS.Server.Tests; + +public class MonitorModelTests +{ + [Fact] + public void Varz_serializes_with_go_field_names() + { + var varz = new Varz + { + Id = "TESTID", Name = "test-server", Version = "0.1.0", + Host = "0.0.0.0", Port = 4222, InMsgs = 100, OutMsgs = 200, + }; + var json = JsonSerializer.Serialize(varz); + json.ShouldContain("\"server_id\":"); + json.ShouldContain("\"server_name\":"); + json.ShouldContain("\"in_msgs\":"); + json.ShouldContain("\"out_msgs\":"); + json.ShouldNotContain("\"InMsgs\""); + } + + [Fact] + public void Connz_serializes_with_go_field_names() + { + var connz = new Connz + { + Id = "TESTID", Now = DateTime.UtcNow, NumConns = 1, Total = 1, Limit = 1024, + Conns = [new ConnInfo { Cid = 1, Ip = "127.0.0.1", Port = 5555, + InMsgs = 10, Uptime = "1s", Idle = "0s", + Start = DateTime.UtcNow, LastActivity = DateTime.UtcNow }], + }; + var json = JsonSerializer.Serialize(connz); + json.ShouldContain("\"server_id\":"); + json.ShouldContain("\"num_connections\":"); + json.ShouldContain("\"in_msgs\":"); + json.ShouldContain("\"pending_bytes\":"); + } + + [Fact] + public void Varz_includes_nested_config_stubs() + { + var varz = new Varz { Id = "X", Name = "X", Version = "X", Host = "X" }; + var json = JsonSerializer.Serialize(varz); + json.ShouldContain("\"cluster\":"); + json.ShouldContain("\"gateway\":"); + json.ShouldContain("\"leaf\":"); + json.ShouldContain("\"jetstream\":"); + } +} diff --git a/tests/NATS.Server.Tests/MonitorTests.cs b/tests/NATS.Server.Tests/MonitorTests.cs new file mode 100644 index 0000000..65a1399 --- /dev/null +++ b/tests/NATS.Server.Tests/MonitorTests.cs @@ -0,0 +1,274 @@ +using System.Net; +using System.Net.Http.Json; +using System.Net.Security; +using System.Net.Sockets; +using System.Text; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server.Monitoring; + +namespace NATS.Server.Tests; + +public class MonitorTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _natsPort; + private readonly int _monitorPort; + private readonly CancellationTokenSource _cts = new(); + private readonly HttpClient _http = new(); + + public MonitorTests() + { + _natsPort = GetFreePort(); + _monitorPort = GetFreePort(); + _server = new NatsServer( + new NatsOptions { Port = _natsPort, MonitorPort = _monitorPort }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + // Wait for monitoring HTTP server to be ready + for (int i = 0; i < 50; i++) + { + try + { + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz"); + if (response.IsSuccessStatusCode) break; + } + catch (HttpRequestException) { } + await Task.Delay(50); + } + } + + public async Task DisposeAsync() + { + _http.Dispose(); + await _cts.CancelAsync(); + _server.Dispose(); + } + + [Fact] + public async Task Healthz_returns_ok() + { + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz"); + response.StatusCode.ShouldBe(HttpStatusCode.OK); + } + + [Fact] + public async Task Varz_returns_server_identity() + { + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz"); + response.StatusCode.ShouldBe(HttpStatusCode.OK); + + var varz = await response.Content.ReadFromJsonAsync(); + varz.ShouldNotBeNull(); + varz.Id.ShouldNotBeNullOrEmpty(); + varz.Name.ShouldNotBeNullOrEmpty(); + varz.Version.ShouldBe("0.1.0"); + varz.Host.ShouldBe("0.0.0.0"); + varz.Port.ShouldBe(_natsPort); + varz.MaxPayload.ShouldBe(1024 * 1024); + varz.Uptime.ShouldNotBeNullOrEmpty(); + varz.Now.ShouldBeGreaterThan(DateTime.MinValue); + varz.Mem.ShouldBeGreaterThan(0); + varz.Cores.ShouldBeGreaterThan(0); + } + + [Fact] + public async Task Varz_tracks_connections_and_messages() + { + // Connect a client and send a message + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort)); + + var buf = new byte[4096]; + _ = await sock.ReceiveAsync(buf, SocketFlags.None); // Read INFO + + var cmd = "CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray(); + await sock.SendAsync(cmd, SocketFlags.None); + await Task.Delay(200); + + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz"); + var varz = await response.Content.ReadFromJsonAsync(); + + varz.ShouldNotBeNull(); + varz.Connections.ShouldBeGreaterThanOrEqualTo(1); + varz.TotalConnections.ShouldBeGreaterThanOrEqualTo(1UL); + varz.InMsgs.ShouldBeGreaterThanOrEqualTo(1L); + varz.InBytes.ShouldBeGreaterThanOrEqualTo(5L); + } + + [Fact] + public async Task Connz_returns_connections() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort)); + using var stream = new NetworkStream(sock); + var buf = new byte[4096]; + _ = await stream.ReadAsync(buf); + await stream.WriteAsync("CONNECT {\"name\":\"test-client\",\"lang\":\"csharp\",\"version\":\"1.0\"}\r\n"u8.ToArray()); + await Task.Delay(200); + + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz"); + response.StatusCode.ShouldBe(HttpStatusCode.OK); + + var connz = await response.Content.ReadFromJsonAsync(); + connz.ShouldNotBeNull(); + connz.NumConns.ShouldBeGreaterThanOrEqualTo(1); + connz.Conns.Length.ShouldBeGreaterThanOrEqualTo(1); + + var conn = connz.Conns.First(c => c.Name == "test-client"); + conn.Ip.ShouldNotBeNullOrEmpty(); + conn.Port.ShouldBeGreaterThan(0); + conn.Lang.ShouldBe("csharp"); + conn.Version.ShouldBe("1.0"); + conn.Uptime.ShouldNotBeNullOrEmpty(); + } + + [Fact] + public async Task Connz_pagination() + { + var sockets = new List(); + try + { + for (int i = 0; i < 3; i++) + { + var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await s.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort)); + var ns = new NetworkStream(s); + var buf = new byte[4096]; + _ = await ns.ReadAsync(buf); + await ns.WriteAsync("CONNECT {}\r\n"u8.ToArray()); + sockets.Add(s); + } + await Task.Delay(200); + + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?limit=2&offset=0"); + var connz = await response.Content.ReadFromJsonAsync(); + + connz!.Conns.Length.ShouldBe(2); + connz.Total.ShouldBeGreaterThanOrEqualTo(3); + connz.Limit.ShouldBe(2); + connz.Offset.ShouldBe(0); + } + finally + { + foreach (var s in sockets) s.Dispose(); + } + } + + [Fact] + public async Task Connz_with_subscriptions() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort)); + using var stream = new NetworkStream(sock); + var buf = new byte[4096]; + _ = await stream.ReadAsync(buf); + await stream.WriteAsync("CONNECT {}\r\nSUB foo 1\r\nSUB bar 2\r\n"u8.ToArray()); + await Task.Delay(200); + + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?subs=true"); + var connz = await response.Content.ReadFromJsonAsync(); + + var conn = connz!.Conns.First(c => c.NumSubs >= 2); + conn.Subs.ShouldNotBeNull(); + conn.Subs.ShouldContain("foo"); + conn.Subs.ShouldContain("bar"); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +} + +public class MonitorTlsTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _natsPort; + private readonly int _monitorPort; + private readonly CancellationTokenSource _cts = new(); + private readonly HttpClient _http = new(); + private readonly string _certPath; + private readonly string _keyPath; + + public MonitorTlsTests() + { + _natsPort = GetFreePort(); + _monitorPort = GetFreePort(); + (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles(); + _server = new NatsServer( + new NatsOptions + { + Port = _natsPort, + MonitorPort = _monitorPort, + TlsCert = _certPath, + TlsKey = _keyPath, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + // Wait for monitoring HTTP server to be ready + for (int i = 0; i < 50; i++) + { + try + { + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz"); + if (response.IsSuccessStatusCode) break; + } + catch (HttpRequestException) { } + await Task.Delay(50); + } + } + + public async Task DisposeAsync() + { + _http.Dispose(); + await _cts.CancelAsync(); + _server.Dispose(); + File.Delete(_certPath); + File.Delete(_keyPath); + } + + [Fact] + public async Task Connz_shows_tls_info_for_tls_client() + { + // Connect and upgrade to TLS + using var tcp = new TcpClient(); + await tcp.ConnectAsync(IPAddress.Loopback, _natsPort); + using var netStream = tcp.GetStream(); + var buf = new byte[4096]; + _ = await netStream.ReadAsync(buf); // Read INFO + + using var ssl = new SslStream(netStream, false, (_, _, _, _) => true); + await ssl.AuthenticateAsClientAsync("localhost"); + + await ssl.WriteAsync("CONNECT {}\r\n"u8.ToArray()); + await ssl.FlushAsync(); + await Task.Delay(200); + + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz"); + var connz = await response.Content.ReadFromJsonAsync(); + + connz!.Conns.Length.ShouldBeGreaterThanOrEqualTo(1); + var conn = connz.Conns[0]; + conn.TlsVersion.ShouldNotBeNullOrEmpty(); + conn.TlsCipherSuite.ShouldNotBeNullOrEmpty(); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +} diff --git a/tests/NATS.Server.Tests/ServerStatsTests.cs b/tests/NATS.Server.Tests/ServerStatsTests.cs new file mode 100644 index 0000000..6b45bd3 --- /dev/null +++ b/tests/NATS.Server.Tests/ServerStatsTests.cs @@ -0,0 +1,84 @@ +using System.Net; +using System.Net.Sockets; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server; + +namespace NATS.Server.Tests; + +public class ServerStatsTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + + public ServerStatsTests() + { + _port = GetFreePort(); + _server = new NatsServer(new NatsOptions { Port = _port }, NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public Task DisposeAsync() + { + _cts.Cancel(); + _server.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public void Server_has_start_time() + { + _server.StartTime.ShouldNotBe(default); + _server.StartTime.ShouldBeLessThanOrEqualTo(DateTime.UtcNow); + } + + [Fact] + public async Task Server_tracks_total_connections() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port)); + await Task.Delay(100); + _server.Stats.TotalConnections.ShouldBeGreaterThanOrEqualTo(1); + } + + [Fact] + public async Task Server_stats_track_messages() + { + using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await pub.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port)); + + var buf = new byte[4096]; + await pub.ReceiveAsync(buf, SocketFlags.None); // INFO + + await pub.SendAsync("CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray()); + await Task.Delay(200); + + _server.Stats.InMsgs.ShouldBeGreaterThanOrEqualTo(1); + _server.Stats.InBytes.ShouldBeGreaterThanOrEqualTo(5); + } + + [Fact] + public async Task Client_has_metadata() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port)); + await Task.Delay(100); + + var client = _server.GetClients().First(); + client.RemoteIp.ShouldNotBeNullOrEmpty(); + client.RemotePort.ShouldBeGreaterThan(0); + client.StartTime.ShouldNotBe(default); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +} diff --git a/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs b/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs new file mode 100644 index 0000000..55df6cc --- /dev/null +++ b/tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs @@ -0,0 +1,254 @@ +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server; +using NATS.Server.Protocol; +using NATS.Server.Tls; + +namespace NATS.Server.Tests; + +public class TlsConnectionWrapperTests +{ + [Fact] + public async Task NoTls_returns_plain_stream() + { + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var serverStream = new NetworkStream(serverSocket, ownsSocket: true); + using var clientStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions(); // No TLS configured + var serverInfo = CreateServerInfo(); + + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None); + + stream.ShouldBe(serverStream); // Same stream, no wrapping + infoSent.ShouldBeFalse(); + } + + [Fact] + public async Task TlsRequired_upgrades_to_ssl() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: read INFO then start TLS + var clientTask = Task.Run(async () => + { + // Read INFO line + var buf = new byte[4096]; + var read = await clientNetStream.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + // Upgrade to TLS + var sslClient = new SslStream(clientNetStream, true, + (_, _, _, _) => true); // Trust all for testing + await sslClient.AuthenticateAsClientAsync("localhost"); + return sslClient; + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + + stream.ShouldBeOfType(); + infoSent.ShouldBeTrue(); + + var clientSsl = await clientTask; + + // Verify encrypted communication works + await stream.WriteAsync("PING\r\n"u8.ToArray()); + await stream.FlushAsync(); + + var readBuf = new byte[64]; + var bytesRead = await clientSsl.ReadAsync(readBuf); + var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead); + msg.ShouldBe("PING\r\n"); + + stream.Dispose(); + clientSsl.Dispose(); + } + + [Fact] + public async Task MixedMode_allows_plaintext_when_AllowNonTls() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions + { + TlsCert = "dummy", + TlsKey = "dummy", + AllowNonTls = true, + TlsTimeout = TimeSpan.FromSeconds(2), + }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: read INFO then send plaintext (not TLS) + var clientTask = Task.Run(async () => + { + var buf = new byte[4096]; + var read = await clientNetStream.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + // Send plaintext CONNECT (not a TLS handshake) + var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n"); + await clientNetStream.WriteAsync(connectLine); + await clientNetStream.FlushAsync(); + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + + await clientTask; + + // In mixed mode with plaintext client, we get a PeekableStream, not SslStream + stream.ShouldBeOfType(); + infoSent.ShouldBeTrue(); + + stream.Dispose(); + } + + [Fact] + public async Task TlsRequired_rejects_plaintext() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions + { + TlsCert = "dummy", + TlsKey = "dummy", + AllowNonTls = false, + TlsTimeout = TimeSpan.FromSeconds(2), + }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: read INFO then send plaintext + var clientTask = Task.Run(async () => + { + var buf = new byte[4096]; + var read = await clientNetStream.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + // Send plaintext data (first byte is 'C', not 0x16 TLS marker) + var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n"); + await clientNetStream.WriteAsync(connectLine); + await clientNetStream.FlushAsync(); + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + + await Should.ThrowAsync(async () => + { + await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + }); + + await clientTask; + serverNetStream.Dispose(); + } + + [Fact] + public async Task TlsFirst_handshakes_before_sending_info() + { + var (cert, _) = TlsHelperTests.GenerateTestCert(); + + var (serverSocket, clientSocket) = await CreateSocketPairAsync(); + using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true); + + var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true }; + var sslOpts = new SslServerAuthenticationOptions + { + ServerCertificate = cert, + }; + var serverInfo = CreateServerInfo(); + + // Client side: immediately start TLS (no INFO first) + var clientTask = Task.Run(async () => + { + var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true); + await sslClient.AuthenticateAsClientAsync("localhost"); + + // After TLS, read INFO over encrypted stream + var buf = new byte[4096]; + var read = await sslClient.ReadAsync(buf); + var info = System.Text.Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + + return sslClient; + }); + + var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true); + var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync( + serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None); + + stream.ShouldBeOfType(); + infoSent.ShouldBeTrue(); + + var clientSsl = await clientTask; + + // Verify encrypted communication works + await stream.WriteAsync("PING\r\n"u8.ToArray()); + await stream.FlushAsync(); + + var readBuf = new byte[64]; + var bytesRead = await clientSsl.ReadAsync(readBuf); + var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead); + msg.ShouldBe("PING\r\n"); + + stream.Dispose(); + clientSsl.Dispose(); + } + + private static ServerInfo CreateServerInfo() => new() + { + ServerId = "TEST", + ServerName = "test", + Version = NatsProtocol.Version, + Host = "127.0.0.1", + Port = 4222, + }; + + private static async Task<(Socket server, Socket client)> CreateSocketPairAsync() + { + using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + var port = ((IPEndPoint)listener.LocalEndPoint!).Port; + + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port)); + var server = await listener.AcceptAsync(); + + return (server, client); + } +} diff --git a/tests/NATS.Server.Tests/TlsHelperTests.cs b/tests/NATS.Server.Tests/TlsHelperTests.cs new file mode 100644 index 0000000..c8d1cfa --- /dev/null +++ b/tests/NATS.Server.Tests/TlsHelperTests.cs @@ -0,0 +1,110 @@ +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using NATS.Server; +using NATS.Server.Tls; + +namespace NATS.Server.Tests; + +public class TlsHelperTests +{ + [Fact] + public void LoadCertificate_loads_pem_cert_and_key() + { + var (certPath, keyPath) = GenerateTestCertFiles(); + try + { + var cert = TlsHelper.LoadCertificate(certPath, keyPath); + cert.ShouldNotBeNull(); + cert.HasPrivateKey.ShouldBeTrue(); + } + finally { File.Delete(certPath); File.Delete(keyPath); } + } + + [Fact] + public void BuildServerAuthOptions_creates_valid_options() + { + var (certPath, keyPath) = GenerateTestCertFiles(); + try + { + var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath }; + var authOpts = TlsHelper.BuildServerAuthOptions(opts); + authOpts.ShouldNotBeNull(); + authOpts.ServerCertificate.ShouldNotBeNull(); + } + finally { File.Delete(certPath); File.Delete(keyPath); } + } + + [Fact] + public void MatchesPinnedCert_matches_correct_hash() + { + var (cert, _) = GenerateTestCert(); + var hash = TlsHelper.GetCertificateHash(cert); + var pinned = new HashSet { hash }; + TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue(); + } + + [Fact] + public void MatchesPinnedCert_rejects_wrong_hash() + { + var (cert, _) = GenerateTestCert(); + var pinned = new HashSet { "0000000000000000000000000000000000000000000000000000000000000000" }; + TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse(); + } + + [Fact] + public async Task PeekableStream_peeks_and_replays() + { + var data = "Hello, World!"u8.ToArray(); + using var ms = new MemoryStream(data); + using var peekable = new PeekableStream(ms); + + var peeked = await peekable.PeekAsync(1); + peeked.Length.ShouldBe(1); + peeked[0].ShouldBe((byte)'H'); + + var buf = new byte[data.Length]; + int total = 0; + while (total < data.Length) + { + var read = await peekable.ReadAsync(buf.AsMemory(total)); + if (read == 0) break; + total += read; + } + total.ShouldBe(data.Length); + buf.ShouldBe(data); + } + + [Fact] + public async Task TlsRateLimiter_allows_within_limit() + { + using var limiter = new TlsRateLimiter(10); + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + for (int i = 0; i < 5; i++) + await limiter.WaitAsync(cts.Token); + } + + // Public helper methods used by other test classes + public static (string certPath, string keyPath) GenerateTestCertFiles() + { + var (cert, key) = GenerateTestCert(); + var certPath = Path.GetTempFileName(); + var keyPath = Path.GetTempFileName(); + File.WriteAllText(certPath, cert.ExportCertificatePem()); + File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem()); + return (certPath, keyPath); + } + + public static (X509Certificate2 cert, RSA key) GenerateTestCert() + { + var key = RSA.Create(2048); + var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false)); + var sanBuilder = new SubjectAlternativeNameBuilder(); + sanBuilder.AddIpAddress(IPAddress.Loopback); + sanBuilder.AddDnsName("localhost"); + req.CertificateExtensions.Add(sanBuilder.Build()); + var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1)); + return (cert, key); + } +} diff --git a/tests/NATS.Server.Tests/TlsServerTests.cs b/tests/NATS.Server.Tests/TlsServerTests.cs new file mode 100644 index 0000000..703b1bf --- /dev/null +++ b/tests/NATS.Server.Tests/TlsServerTests.cs @@ -0,0 +1,225 @@ +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Text; +using Microsoft.Extensions.Logging.Abstractions; +using NATS.Server; + +namespace NATS.Server.Tests; + +public class TlsServerTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + private readonly string _certPath; + private readonly string _keyPath; + + public TlsServerTests() + { + _port = GetFreePort(); + (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles(); + _server = new NatsServer( + new NatsOptions + { + Port = _port, + TlsCert = _certPath, + TlsKey = _keyPath, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + File.Delete(_certPath); + File.Delete(_keyPath); + } + + [Fact] + public async Task Tls_client_connects_and_receives_info() + { + using var tcp = new TcpClient(); + await tcp.ConnectAsync(IPAddress.Loopback, _port); + using var netStream = tcp.GetStream(); + + // Read INFO (sent before TLS upgrade in Mode 2) + var buf = new byte[4096]; + var read = await netStream.ReadAsync(buf); + var info = Encoding.ASCII.GetString(buf, 0, read); + info.ShouldStartWith("INFO "); + info.ShouldContain("\"tls_required\":true"); + + // Upgrade to TLS + using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true); + await sslStream.AuthenticateAsClientAsync("localhost"); + + // Send CONNECT + PING over TLS + await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray()); + await sslStream.FlushAsync(); + + // Read PONG + var pongBuf = new byte[256]; + read = await sslStream.ReadAsync(pongBuf); + var pong = Encoding.ASCII.GetString(pongBuf, 0, read); + pong.ShouldContain("PONG"); + } + + [Fact] + public async Task Tls_pubsub_works_over_encrypted_connection() + { + using var tcp1 = new TcpClient(); + await tcp1.ConnectAsync(IPAddress.Loopback, _port); + using var ssl1 = await UpgradeToTlsAsync(tcp1); + + using var tcp2 = new TcpClient(); + await tcp2.ConnectAsync(IPAddress.Loopback, _port); + using var ssl2 = await UpgradeToTlsAsync(tcp2); + + // Sub on client 1 + await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray()); + await ssl1.FlushAsync(); + + // Wait for PONG to confirm subscription is registered + var pongBuf = new byte[256]; + var pongRead = await ssl1.ReadAsync(pongBuf); + var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead); + pongStr.ShouldContain("PONG"); + + // Pub on client 2 + await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray()); + await ssl2.FlushAsync(); + + // Client 1 should receive MSG (may arrive across multiple TLS records) + var msg = await ReadUntilAsync(ssl1, "hello"); + msg.ShouldContain("MSG test 1 5"); + msg.ShouldContain("hello"); + } + + private static async Task ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000) + { + using var cts = new CancellationTokenSource(timeoutMs); + var sb = new StringBuilder(); + var buf = new byte[4096]; + while (!sb.ToString().Contains(expected)) + { + var n = await stream.ReadAsync(buf, cts.Token); + if (n == 0) break; + sb.Append(Encoding.ASCII.GetString(buf, 0, n)); + } + return sb.ToString(); + } + + private static async Task UpgradeToTlsAsync(TcpClient tcp) + { + var netStream = tcp.GetStream(); + var buf = new byte[4096]; + _ = await netStream.ReadAsync(buf); // Read INFO (discard) + + var ssl = new SslStream(netStream, false, (_, _, _, _) => true); + await ssl.AuthenticateAsClientAsync("localhost"); + return ssl; + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +} + +public class TlsMixedModeTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + private readonly string _certPath; + private readonly string _keyPath; + + public TlsMixedModeTests() + { + _port = GetFreePort(); + (_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles(); + _server = new NatsServer( + new NatsOptions + { + Port = _port, + TlsCert = _certPath, + TlsKey = _keyPath, + AllowNonTls = true, + }, + NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + File.Delete(_certPath); + File.Delete(_keyPath); + } + + [Fact] + public async Task Mixed_mode_accepts_plain_client() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port)); + using var stream = new NetworkStream(sock); + + var buf = new byte[4096]; + var read = await stream.ReadAsync(buf); + var info = Encoding.ASCII.GetString(buf, 0, read); + info.ShouldContain("\"tls_available\":true"); + + await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray()); + await stream.FlushAsync(); + + var pongBuf = new byte[64]; + read = await stream.ReadAsync(pongBuf); + var pong = Encoding.ASCII.GetString(pongBuf, 0, read); + pong.ShouldContain("PONG"); + } + + [Fact] + public async Task Mixed_mode_accepts_tls_client() + { + using var tcp = new TcpClient(); + await tcp.ConnectAsync(IPAddress.Loopback, _port); + using var netStream = tcp.GetStream(); + + var buf = new byte[4096]; + _ = await netStream.ReadAsync(buf); // Read INFO + + using var ssl = new SslStream(netStream, false, (_, _, _, _) => true); + await ssl.AuthenticateAsClientAsync("localhost"); + + await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray()); + await ssl.FlushAsync(); + + var pongBuf = new byte[64]; + var read = await ssl.ReadAsync(pongBuf); + var pong = Encoding.ASCII.GetString(pongBuf, 0, read); + pong.ShouldContain("PONG"); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } +}