feat: add monitoring HTTP endpoints and TLS support
Monitoring HTTP: - /varz, /connz, /healthz via Kestrel Minimal API - Pagination, sorting, subscription details on /connz - ServerStats atomic counters, CPU/memory sampling - CLI flags: -m, --http_port, --http_base_path, --https_port TLS Support: - 4-mode negotiation: no TLS, required, TLS-first, mixed - Certificate loading, pinning (SHA-256), client cert verification - PeekableStream for non-destructive TLS detection - Token-bucket rate limiter for TLS handshakes - CLI flags: --tls, --tlscert, --tlskey, --tlscacert, --tlsverify 29 new tests (78 → 107 total), all passing. # Conflicts: # src/NATS.Server.Host/Program.cs # src/NATS.Server/NATS.Server.csproj # src/NATS.Server/NatsClient.cs # src/NATS.Server/NatsOptions.cs # src/NATS.Server/NatsServer.cs # src/NATS.Server/Protocol/NatsProtocol.cs # tests/NATS.Server.Tests/ClientTests.cs
This commit is contained in:
@@ -32,6 +32,20 @@ for (int i = 0; i < args.Length; i++)
|
||||
case "--https_port" when i + 1 < args.Length:
|
||||
options.MonitorHttpsPort = int.Parse(args[++i]);
|
||||
break;
|
||||
case "--tls":
|
||||
break;
|
||||
case "--tlscert" when i + 1 < args.Length:
|
||||
options.TlsCert = args[++i];
|
||||
break;
|
||||
case "--tlskey" when i + 1 < args.Length:
|
||||
options.TlsKey = args[++i];
|
||||
break;
|
||||
case "--tlscacert" when i + 1 < args.Length:
|
||||
options.TlsCaCert = args[++i];
|
||||
break;
|
||||
case "--tlsverify":
|
||||
options.TlsVerify = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
207
src/NATS.Server/Monitoring/Connz.cs
Normal file
207
src/NATS.Server/Monitoring/Connz.cs
Normal file
@@ -0,0 +1,207 @@
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace NATS.Server.Monitoring;
|
||||
|
||||
/// <summary>
|
||||
/// Connection information response. Corresponds to Go server/monitor.go Connz struct.
|
||||
/// </summary>
|
||||
public sealed class Connz
|
||||
{
|
||||
[JsonPropertyName("server_id")]
|
||||
public string Id { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("now")]
|
||||
public DateTime Now { get; set; }
|
||||
|
||||
[JsonPropertyName("num_connections")]
|
||||
public int NumConns { get; set; }
|
||||
|
||||
[JsonPropertyName("total")]
|
||||
public int Total { get; set; }
|
||||
|
||||
[JsonPropertyName("offset")]
|
||||
public int Offset { get; set; }
|
||||
|
||||
[JsonPropertyName("limit")]
|
||||
public int Limit { get; set; }
|
||||
|
||||
[JsonPropertyName("connections")]
|
||||
public ConnInfo[] Conns { get; set; } = [];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Detailed information on a per-connection basis.
|
||||
/// Corresponds to Go server/monitor.go ConnInfo struct.
|
||||
/// </summary>
|
||||
public sealed class ConnInfo
|
||||
{
|
||||
[JsonPropertyName("cid")]
|
||||
public ulong Cid { get; set; }
|
||||
|
||||
[JsonPropertyName("kind")]
|
||||
public string Kind { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("type")]
|
||||
public string Type { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("ip")]
|
||||
public string Ip { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("start")]
|
||||
public DateTime Start { get; set; }
|
||||
|
||||
[JsonPropertyName("last_activity")]
|
||||
public DateTime LastActivity { get; set; }
|
||||
|
||||
[JsonPropertyName("stop")]
|
||||
public DateTime? Stop { get; set; }
|
||||
|
||||
[JsonPropertyName("reason")]
|
||||
public string Reason { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("rtt")]
|
||||
public string Rtt { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("uptime")]
|
||||
public string Uptime { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("idle")]
|
||||
public string Idle { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("pending_bytes")]
|
||||
public int Pending { get; set; }
|
||||
|
||||
[JsonPropertyName("in_msgs")]
|
||||
public long InMsgs { get; set; }
|
||||
|
||||
[JsonPropertyName("out_msgs")]
|
||||
public long OutMsgs { get; set; }
|
||||
|
||||
[JsonPropertyName("in_bytes")]
|
||||
public long InBytes { get; set; }
|
||||
|
||||
[JsonPropertyName("out_bytes")]
|
||||
public long OutBytes { get; set; }
|
||||
|
||||
[JsonPropertyName("subscriptions")]
|
||||
public uint NumSubs { get; set; }
|
||||
|
||||
[JsonPropertyName("subscriptions_list")]
|
||||
public string[] Subs { get; set; } = [];
|
||||
|
||||
[JsonPropertyName("subscriptions_list_detail")]
|
||||
public SubDetail[] SubsDetail { get; set; } = [];
|
||||
|
||||
[JsonPropertyName("name")]
|
||||
public string Name { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("lang")]
|
||||
public string Lang { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("version")]
|
||||
public string Version { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("authorized_user")]
|
||||
public string AuthorizedUser { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("account")]
|
||||
public string Account { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("tls_version")]
|
||||
public string TlsVersion { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("tls_cipher_suite")]
|
||||
public string TlsCipherSuite { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("tls_first")]
|
||||
public bool TlsFirst { get; set; }
|
||||
|
||||
[JsonPropertyName("mqtt_client")]
|
||||
public string MqttClient { get; set; } = "";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Subscription detail information.
|
||||
/// Corresponds to Go server/monitor.go SubDetail struct.
|
||||
/// </summary>
|
||||
public sealed class SubDetail
|
||||
{
|
||||
[JsonPropertyName("account")]
|
||||
public string Account { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("subject")]
|
||||
public string Subject { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("qgroup")]
|
||||
public string Queue { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("sid")]
|
||||
public string Sid { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("msgs")]
|
||||
public long Msgs { get; set; }
|
||||
|
||||
[JsonPropertyName("max")]
|
||||
public long Max { get; set; }
|
||||
|
||||
[JsonPropertyName("cid")]
|
||||
public ulong Cid { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sort options for connection listing.
|
||||
/// Corresponds to Go server/monitor_sort_opts.go SortOpt type.
|
||||
/// </summary>
|
||||
public enum SortOpt
|
||||
{
|
||||
ByCid,
|
||||
ByStart,
|
||||
BySubs,
|
||||
ByPending,
|
||||
ByMsgsTo,
|
||||
ByMsgsFrom,
|
||||
ByBytesTo,
|
||||
ByBytesFrom,
|
||||
ByLast,
|
||||
ByIdle,
|
||||
ByUptime,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Connection state filter.
|
||||
/// Corresponds to Go server/monitor.go ConnState type.
|
||||
/// </summary>
|
||||
public enum ConnState
|
||||
{
|
||||
Open,
|
||||
Closed,
|
||||
All,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Options passed to Connz() for filtering and sorting.
|
||||
/// Corresponds to Go server/monitor.go ConnzOptions struct.
|
||||
/// </summary>
|
||||
public sealed class ConnzOptions
|
||||
{
|
||||
public SortOpt Sort { get; set; } = SortOpt.ByCid;
|
||||
|
||||
public bool Subscriptions { get; set; }
|
||||
|
||||
public bool SubscriptionsDetail { get; set; }
|
||||
|
||||
public ConnState State { get; set; } = ConnState.Open;
|
||||
|
||||
public string User { get; set; } = "";
|
||||
|
||||
public string Account { get; set; } = "";
|
||||
|
||||
public string FilterSubject { get; set; } = "";
|
||||
|
||||
public int Offset { get; set; }
|
||||
|
||||
public int Limit { get; set; } = 1024;
|
||||
}
|
||||
148
src/NATS.Server/Monitoring/ConnzHandler.cs
Normal file
148
src/NATS.Server/Monitoring/ConnzHandler.cs
Normal file
@@ -0,0 +1,148 @@
|
||||
using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace NATS.Server.Monitoring;
|
||||
|
||||
/// <summary>
|
||||
/// Handles /connz endpoint requests, returning detailed connection information.
|
||||
/// Corresponds to Go server/monitor.go handleConnz function.
|
||||
/// </summary>
|
||||
public sealed class ConnzHandler(NatsServer server)
|
||||
{
|
||||
public Connz HandleConnz(HttpContext ctx)
|
||||
{
|
||||
var opts = ParseQueryParams(ctx);
|
||||
var now = DateTime.UtcNow;
|
||||
var clients = server.GetClients().ToArray();
|
||||
|
||||
var connInfos = clients.Select(c => BuildConnInfo(c, now, opts)).ToList();
|
||||
|
||||
// Sort
|
||||
connInfos = opts.Sort switch
|
||||
{
|
||||
SortOpt.ByCid => connInfos.OrderBy(c => c.Cid).ToList(),
|
||||
SortOpt.ByStart => connInfos.OrderBy(c => c.Start).ToList(),
|
||||
SortOpt.BySubs => connInfos.OrderByDescending(c => c.NumSubs).ToList(),
|
||||
SortOpt.ByPending => connInfos.OrderByDescending(c => c.Pending).ToList(),
|
||||
SortOpt.ByMsgsTo => connInfos.OrderByDescending(c => c.OutMsgs).ToList(),
|
||||
SortOpt.ByMsgsFrom => connInfos.OrderByDescending(c => c.InMsgs).ToList(),
|
||||
SortOpt.ByBytesTo => connInfos.OrderByDescending(c => c.OutBytes).ToList(),
|
||||
SortOpt.ByBytesFrom => connInfos.OrderByDescending(c => c.InBytes).ToList(),
|
||||
SortOpt.ByLast => connInfos.OrderByDescending(c => c.LastActivity).ToList(),
|
||||
SortOpt.ByIdle => connInfos.OrderByDescending(c => now - c.LastActivity).ToList(),
|
||||
SortOpt.ByUptime => connInfos.OrderByDescending(c => now - c.Start).ToList(),
|
||||
_ => connInfos.OrderBy(c => c.Cid).ToList(),
|
||||
};
|
||||
|
||||
var total = connInfos.Count;
|
||||
var paged = connInfos.Skip(opts.Offset).Take(opts.Limit).ToArray();
|
||||
|
||||
return new Connz
|
||||
{
|
||||
Id = server.ServerId,
|
||||
Now = now,
|
||||
NumConns = paged.Length,
|
||||
Total = total,
|
||||
Offset = opts.Offset,
|
||||
Limit = opts.Limit,
|
||||
Conns = paged,
|
||||
};
|
||||
}
|
||||
|
||||
private static ConnInfo BuildConnInfo(NatsClient client, DateTime now, ConnzOptions opts)
|
||||
{
|
||||
var info = new ConnInfo
|
||||
{
|
||||
Cid = client.Id,
|
||||
Kind = "Client",
|
||||
Type = "Client",
|
||||
Ip = client.RemoteIp ?? "",
|
||||
Port = client.RemotePort,
|
||||
Start = client.StartTime,
|
||||
LastActivity = client.LastActivity,
|
||||
Uptime = FormatDuration(now - client.StartTime),
|
||||
Idle = FormatDuration(now - client.LastActivity),
|
||||
InMsgs = Interlocked.Read(ref client.InMsgs),
|
||||
OutMsgs = Interlocked.Read(ref client.OutMsgs),
|
||||
InBytes = Interlocked.Read(ref client.InBytes),
|
||||
OutBytes = Interlocked.Read(ref client.OutBytes),
|
||||
NumSubs = (uint)client.Subscriptions.Count,
|
||||
Name = client.ClientOpts?.Name ?? "",
|
||||
Lang = client.ClientOpts?.Lang ?? "",
|
||||
Version = client.ClientOpts?.Version ?? "",
|
||||
TlsVersion = client.TlsState?.TlsVersion ?? "",
|
||||
TlsCipherSuite = client.TlsState?.CipherSuite ?? "",
|
||||
};
|
||||
|
||||
if (opts.Subscriptions)
|
||||
{
|
||||
info.Subs = client.Subscriptions.Values.Select(s => s.Subject).ToArray();
|
||||
}
|
||||
|
||||
if (opts.SubscriptionsDetail)
|
||||
{
|
||||
info.SubsDetail = client.Subscriptions.Values.Select(s => new SubDetail
|
||||
{
|
||||
Subject = s.Subject,
|
||||
Queue = s.Queue ?? "",
|
||||
Sid = s.Sid,
|
||||
Msgs = Interlocked.Read(ref s.MessageCount),
|
||||
Max = s.MaxMessages,
|
||||
Cid = client.Id,
|
||||
}).ToArray();
|
||||
}
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
private static ConnzOptions ParseQueryParams(HttpContext ctx)
|
||||
{
|
||||
var q = ctx.Request.Query;
|
||||
var opts = new ConnzOptions();
|
||||
|
||||
if (q.TryGetValue("sort", out var sort))
|
||||
{
|
||||
opts.Sort = sort.ToString().ToLowerInvariant() switch
|
||||
{
|
||||
"cid" => SortOpt.ByCid,
|
||||
"start" => SortOpt.ByStart,
|
||||
"subs" => SortOpt.BySubs,
|
||||
"pending" => SortOpt.ByPending,
|
||||
"msgs_to" => SortOpt.ByMsgsTo,
|
||||
"msgs_from" => SortOpt.ByMsgsFrom,
|
||||
"bytes_to" => SortOpt.ByBytesTo,
|
||||
"bytes_from" => SortOpt.ByBytesFrom,
|
||||
"last" => SortOpt.ByLast,
|
||||
"idle" => SortOpt.ByIdle,
|
||||
"uptime" => SortOpt.ByUptime,
|
||||
_ => SortOpt.ByCid,
|
||||
};
|
||||
}
|
||||
|
||||
if (q.TryGetValue("subs", out var subs))
|
||||
{
|
||||
if (subs == "detail")
|
||||
opts.SubscriptionsDetail = true;
|
||||
else
|
||||
opts.Subscriptions = true;
|
||||
}
|
||||
|
||||
if (q.TryGetValue("offset", out var offset) && int.TryParse(offset, out var o))
|
||||
opts.Offset = o;
|
||||
|
||||
if (q.TryGetValue("limit", out var limit) && int.TryParse(limit, out var l))
|
||||
opts.Limit = l;
|
||||
|
||||
return opts;
|
||||
}
|
||||
|
||||
private static string FormatDuration(TimeSpan ts)
|
||||
{
|
||||
if (ts.TotalDays >= 1)
|
||||
return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s";
|
||||
if (ts.TotalHours >= 1)
|
||||
return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s";
|
||||
if (ts.TotalMinutes >= 1)
|
||||
return $"{(int)ts.TotalMinutes}m{ts.Seconds}s";
|
||||
return $"{(int)ts.TotalSeconds}s";
|
||||
}
|
||||
}
|
||||
117
src/NATS.Server/Monitoring/MonitorServer.cs
Normal file
117
src/NATS.Server/Monitoring/MonitorServer.cs
Normal file
@@ -0,0 +1,117 @@
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace NATS.Server.Monitoring;
|
||||
|
||||
/// <summary>
|
||||
/// HTTP monitoring server providing /healthz, /varz, and other monitoring endpoints.
|
||||
/// Corresponds to Go server/monitor.go HTTP server setup.
|
||||
/// </summary>
|
||||
public sealed class MonitorServer : IAsyncDisposable
|
||||
{
|
||||
private readonly WebApplication _app;
|
||||
private readonly ILogger<MonitorServer> _logger;
|
||||
private readonly VarzHandler _varzHandler;
|
||||
private readonly ConnzHandler _connzHandler;
|
||||
|
||||
public MonitorServer(NatsServer server, NatsOptions options, ServerStats stats, ILoggerFactory loggerFactory)
|
||||
{
|
||||
_logger = loggerFactory.CreateLogger<MonitorServer>();
|
||||
|
||||
var builder = WebApplication.CreateSlimBuilder();
|
||||
builder.WebHost.UseUrls($"http://{options.MonitorHost}:{options.MonitorPort}");
|
||||
builder.Logging.ClearProviders();
|
||||
|
||||
_app = builder.Build();
|
||||
var basePath = options.MonitorBasePath ?? "";
|
||||
|
||||
_varzHandler = new VarzHandler(server, options);
|
||||
_connzHandler = new ConnzHandler(server);
|
||||
|
||||
_app.MapGet(basePath + "/", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new
|
||||
{
|
||||
endpoints = new[]
|
||||
{
|
||||
"/varz", "/connz", "/healthz", "/routez",
|
||||
"/gatewayz", "/leafz", "/subz", "/accountz", "/jsz",
|
||||
},
|
||||
});
|
||||
});
|
||||
_app.MapGet(basePath + "/healthz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/healthz", 1, (_, v) => v + 1);
|
||||
return Results.Ok("ok");
|
||||
});
|
||||
_app.MapGet(basePath + "/varz", async (HttpContext ctx) =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/varz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(await _varzHandler.HandleVarzAsync(ctx.RequestAborted));
|
||||
});
|
||||
|
||||
_app.MapGet(basePath + "/connz", (HttpContext ctx) =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/connz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(_connzHandler.HandleConnz(ctx));
|
||||
});
|
||||
|
||||
// Stubs for unimplemented endpoints
|
||||
_app.MapGet(basePath + "/routez", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/routez", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/gatewayz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/gatewayz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/leafz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/leafz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/subz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/subz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/subscriptionsz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/subscriptionsz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/accountz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/accountz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/accstatz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/accstatz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
_app.MapGet(basePath + "/jsz", () =>
|
||||
{
|
||||
stats.HttpReqStats.AddOrUpdate("/jsz", 1, (_, v) => v + 1);
|
||||
return Results.Ok(new { });
|
||||
});
|
||||
}
|
||||
|
||||
public async Task StartAsync(CancellationToken ct)
|
||||
{
|
||||
await _app.StartAsync(ct);
|
||||
_logger.LogInformation("Monitoring listening on {Urls}", string.Join(", ", _app.Urls));
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
await _app.StopAsync();
|
||||
await _app.DisposeAsync();
|
||||
_varzHandler.Dispose();
|
||||
}
|
||||
}
|
||||
415
src/NATS.Server/Monitoring/Varz.cs
Normal file
415
src/NATS.Server/Monitoring/Varz.cs
Normal file
@@ -0,0 +1,415 @@
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace NATS.Server.Monitoring;
|
||||
|
||||
/// <summary>
|
||||
/// Server general information. Corresponds to Go server/monitor.go Varz struct.
|
||||
/// </summary>
|
||||
public sealed class Varz
|
||||
{
|
||||
// Identity
|
||||
[JsonPropertyName("server_id")]
|
||||
public string Id { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("server_name")]
|
||||
public string Name { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("version")]
|
||||
public string Version { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("proto")]
|
||||
public int Proto { get; set; }
|
||||
|
||||
[JsonPropertyName("git_commit")]
|
||||
public string GitCommit { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("go")]
|
||||
public string GoVersion { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("host")]
|
||||
public string Host { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
// Network
|
||||
[JsonPropertyName("ip")]
|
||||
public string Ip { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("connect_urls")]
|
||||
public string[] ConnectUrls { get; set; } = [];
|
||||
|
||||
[JsonPropertyName("ws_connect_urls")]
|
||||
public string[] WsConnectUrls { get; set; } = [];
|
||||
|
||||
[JsonPropertyName("http_host")]
|
||||
public string HttpHost { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("http_port")]
|
||||
public int HttpPort { get; set; }
|
||||
|
||||
[JsonPropertyName("http_base_path")]
|
||||
public string HttpBasePath { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("https_port")]
|
||||
public int HttpsPort { get; set; }
|
||||
|
||||
// Security
|
||||
[JsonPropertyName("auth_required")]
|
||||
public bool AuthRequired { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_required")]
|
||||
public bool TlsRequired { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_verify")]
|
||||
public bool TlsVerify { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_ocsp_peer_verify")]
|
||||
public bool TlsOcspPeerVerify { get; set; }
|
||||
|
||||
[JsonPropertyName("auth_timeout")]
|
||||
public double AuthTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
|
||||
// Limits
|
||||
[JsonPropertyName("max_connections")]
|
||||
public int MaxConnections { get; set; }
|
||||
|
||||
[JsonPropertyName("max_subscriptions")]
|
||||
public int MaxSubscriptions { get; set; }
|
||||
|
||||
[JsonPropertyName("max_payload")]
|
||||
public int MaxPayload { get; set; }
|
||||
|
||||
[JsonPropertyName("max_pending")]
|
||||
public long MaxPending { get; set; }
|
||||
|
||||
[JsonPropertyName("max_control_line")]
|
||||
public int MaxControlLine { get; set; }
|
||||
|
||||
[JsonPropertyName("ping_max")]
|
||||
public int MaxPingsOut { get; set; }
|
||||
|
||||
// Timing
|
||||
[JsonPropertyName("ping_interval")]
|
||||
public long PingInterval { get; set; }
|
||||
|
||||
[JsonPropertyName("write_deadline")]
|
||||
public long WriteDeadline { get; set; }
|
||||
|
||||
[JsonPropertyName("start")]
|
||||
public DateTime Start { get; set; }
|
||||
|
||||
[JsonPropertyName("now")]
|
||||
public DateTime Now { get; set; }
|
||||
|
||||
[JsonPropertyName("uptime")]
|
||||
public string Uptime { get; set; } = "";
|
||||
|
||||
// Runtime
|
||||
[JsonPropertyName("mem")]
|
||||
public long Mem { get; set; }
|
||||
|
||||
[JsonPropertyName("cpu")]
|
||||
public double Cpu { get; set; }
|
||||
|
||||
[JsonPropertyName("cores")]
|
||||
public int Cores { get; set; }
|
||||
|
||||
[JsonPropertyName("gomaxprocs")]
|
||||
public int MaxProcs { get; set; }
|
||||
|
||||
// Connections
|
||||
[JsonPropertyName("connections")]
|
||||
public int Connections { get; set; }
|
||||
|
||||
[JsonPropertyName("total_connections")]
|
||||
public ulong TotalConnections { get; set; }
|
||||
|
||||
[JsonPropertyName("routes")]
|
||||
public int Routes { get; set; }
|
||||
|
||||
[JsonPropertyName("remotes")]
|
||||
public int Remotes { get; set; }
|
||||
|
||||
[JsonPropertyName("leafnodes")]
|
||||
public int Leafnodes { get; set; }
|
||||
|
||||
// Messages
|
||||
[JsonPropertyName("in_msgs")]
|
||||
public long InMsgs { get; set; }
|
||||
|
||||
[JsonPropertyName("out_msgs")]
|
||||
public long OutMsgs { get; set; }
|
||||
|
||||
[JsonPropertyName("in_bytes")]
|
||||
public long InBytes { get; set; }
|
||||
|
||||
[JsonPropertyName("out_bytes")]
|
||||
public long OutBytes { get; set; }
|
||||
|
||||
// Health
|
||||
[JsonPropertyName("slow_consumers")]
|
||||
public long SlowConsumers { get; set; }
|
||||
|
||||
[JsonPropertyName("slow_consumer_stats")]
|
||||
public SlowConsumersStats SlowConsumerStats { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("subscriptions")]
|
||||
public uint Subscriptions { get; set; }
|
||||
|
||||
// Config
|
||||
[JsonPropertyName("config_load_time")]
|
||||
public DateTime ConfigLoadTime { get; set; }
|
||||
|
||||
[JsonPropertyName("tags")]
|
||||
public string[] Tags { get; set; } = [];
|
||||
|
||||
[JsonPropertyName("system_account")]
|
||||
public string SystemAccount { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("pinned_account_fails")]
|
||||
public ulong PinnedAccountFail { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_cert_not_after")]
|
||||
public DateTime TlsCertNotAfter { get; set; }
|
||||
|
||||
// HTTP
|
||||
[JsonPropertyName("http_req_stats")]
|
||||
public Dictionary<string, ulong> HttpReqStats { get; set; } = new();
|
||||
|
||||
// Subsystems
|
||||
[JsonPropertyName("cluster")]
|
||||
public ClusterOptsVarz Cluster { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("gateway")]
|
||||
public GatewayOptsVarz Gateway { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("leaf")]
|
||||
public LeafNodeOptsVarz Leaf { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("mqtt")]
|
||||
public MqttOptsVarz Mqtt { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("websocket")]
|
||||
public WebsocketOptsVarz Websocket { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("jetstream")]
|
||||
public JetStreamVarz JetStream { get; set; } = new();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Statistics about slow consumers by connection type.
|
||||
/// Corresponds to Go server/monitor.go SlowConsumersStats struct.
|
||||
/// </summary>
|
||||
public sealed class SlowConsumersStats
|
||||
{
|
||||
[JsonPropertyName("clients")]
|
||||
public ulong Clients { get; set; }
|
||||
|
||||
[JsonPropertyName("routes")]
|
||||
public ulong Routes { get; set; }
|
||||
|
||||
[JsonPropertyName("gateways")]
|
||||
public ulong Gateways { get; set; }
|
||||
|
||||
[JsonPropertyName("leafs")]
|
||||
public ulong Leafs { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cluster configuration monitoring information.
|
||||
/// Corresponds to Go server/monitor.go ClusterOptsVarz struct.
|
||||
/// </summary>
|
||||
public sealed class ClusterOptsVarz
|
||||
{
|
||||
[JsonPropertyName("name")]
|
||||
public string Name { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("addr")]
|
||||
public string Host { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("cluster_port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("auth_timeout")]
|
||||
public double AuthTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_required")]
|
||||
public bool TlsRequired { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_verify")]
|
||||
public bool TlsVerify { get; set; }
|
||||
|
||||
[JsonPropertyName("pool_size")]
|
||||
public int PoolSize { get; set; }
|
||||
|
||||
[JsonPropertyName("urls")]
|
||||
public string[] Urls { get; set; } = [];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gateway configuration monitoring information.
|
||||
/// Corresponds to Go server/monitor.go GatewayOptsVarz struct.
|
||||
/// </summary>
|
||||
public sealed class GatewayOptsVarz
|
||||
{
|
||||
[JsonPropertyName("name")]
|
||||
public string Name { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("host")]
|
||||
public string Host { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("auth_timeout")]
|
||||
public double AuthTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_required")]
|
||||
public bool TlsRequired { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_verify")]
|
||||
public bool TlsVerify { get; set; }
|
||||
|
||||
[JsonPropertyName("advertise")]
|
||||
public string Advertise { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("connect_retries")]
|
||||
public int ConnectRetries { get; set; }
|
||||
|
||||
[JsonPropertyName("reject_unknown")]
|
||||
public bool RejectUnknown { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Leaf node configuration monitoring information.
|
||||
/// Corresponds to Go server/monitor.go LeafNodeOptsVarz struct.
|
||||
/// </summary>
|
||||
public sealed class LeafNodeOptsVarz
|
||||
{
|
||||
[JsonPropertyName("host")]
|
||||
public string Host { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("auth_timeout")]
|
||||
public double AuthTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_required")]
|
||||
public bool TlsRequired { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_verify")]
|
||||
public bool TlsVerify { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_ocsp_peer_verify")]
|
||||
public bool TlsOcspPeerVerify { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// MQTT configuration monitoring information.
|
||||
/// Corresponds to Go server/monitor.go MQTTOptsVarz struct.
|
||||
/// </summary>
|
||||
public sealed class MqttOptsVarz
|
||||
{
|
||||
[JsonPropertyName("host")]
|
||||
public string Host { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Websocket configuration monitoring information.
|
||||
/// Corresponds to Go server/monitor.go WebsocketOptsVarz struct.
|
||||
/// </summary>
|
||||
public sealed class WebsocketOptsVarz
|
||||
{
|
||||
[JsonPropertyName("host")]
|
||||
public string Host { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JetStream runtime information.
|
||||
/// Corresponds to Go server/monitor.go JetStreamVarz struct.
|
||||
/// </summary>
|
||||
public sealed class JetStreamVarz
|
||||
{
|
||||
[JsonPropertyName("config")]
|
||||
public JetStreamConfig Config { get; set; } = new();
|
||||
|
||||
[JsonPropertyName("stats")]
|
||||
public JetStreamStats Stats { get; set; } = new();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JetStream configuration.
|
||||
/// Corresponds to Go server/jetstream.go JetStreamConfig struct.
|
||||
/// </summary>
|
||||
public sealed class JetStreamConfig
|
||||
{
|
||||
[JsonPropertyName("max_memory")]
|
||||
public long MaxMemory { get; set; }
|
||||
|
||||
[JsonPropertyName("max_storage")]
|
||||
public long MaxStorage { get; set; }
|
||||
|
||||
[JsonPropertyName("store_dir")]
|
||||
public string StoreDir { get; set; } = "";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JetStream statistics.
|
||||
/// Corresponds to Go server/jetstream.go JetStreamStats struct.
|
||||
/// </summary>
|
||||
public sealed class JetStreamStats
|
||||
{
|
||||
[JsonPropertyName("memory")]
|
||||
public ulong Memory { get; set; }
|
||||
|
||||
[JsonPropertyName("storage")]
|
||||
public ulong Storage { get; set; }
|
||||
|
||||
[JsonPropertyName("accounts")]
|
||||
public int Accounts { get; set; }
|
||||
|
||||
[JsonPropertyName("ha_assets")]
|
||||
public int HaAssets { get; set; }
|
||||
|
||||
[JsonPropertyName("api")]
|
||||
public JetStreamApiStats Api { get; set; } = new();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// JetStream API statistics.
|
||||
/// Corresponds to Go server/jetstream.go JetStreamAPIStats struct.
|
||||
/// </summary>
|
||||
public sealed class JetStreamApiStats
|
||||
{
|
||||
[JsonPropertyName("total")]
|
||||
public ulong Total { get; set; }
|
||||
|
||||
[JsonPropertyName("errors")]
|
||||
public ulong Errors { get; set; }
|
||||
}
|
||||
121
src/NATS.Server/Monitoring/VarzHandler.cs
Normal file
121
src/NATS.Server/Monitoring/VarzHandler.cs
Normal file
@@ -0,0 +1,121 @@
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.InteropServices;
|
||||
using NATS.Server.Protocol;
|
||||
|
||||
namespace NATS.Server.Monitoring;
|
||||
|
||||
/// <summary>
|
||||
/// Handles building the Varz response from server state and process metrics.
|
||||
/// Corresponds to Go server/monitor.go handleVarz function.
|
||||
/// </summary>
|
||||
public sealed class VarzHandler : IDisposable
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly NatsOptions _options;
|
||||
private readonly SemaphoreSlim _varzMu = new(1, 1);
|
||||
private DateTime _lastCpuSampleTime;
|
||||
private TimeSpan _lastCpuUsage;
|
||||
private double _cachedCpuPercent;
|
||||
|
||||
public VarzHandler(NatsServer server, NatsOptions options)
|
||||
{
|
||||
_server = server;
|
||||
_options = options;
|
||||
using var proc = Process.GetCurrentProcess();
|
||||
_lastCpuSampleTime = DateTime.UtcNow;
|
||||
_lastCpuUsage = proc.TotalProcessorTime;
|
||||
}
|
||||
|
||||
public async Task<Varz> HandleVarzAsync(CancellationToken ct = default)
|
||||
{
|
||||
await _varzMu.WaitAsync(ct);
|
||||
try
|
||||
{
|
||||
using var proc = Process.GetCurrentProcess();
|
||||
var now = DateTime.UtcNow;
|
||||
var uptime = now - _server.StartTime;
|
||||
var stats = _server.Stats;
|
||||
|
||||
// CPU sampling with 1-second cache to avoid excessive sampling
|
||||
if ((now - _lastCpuSampleTime).TotalSeconds >= 1.0)
|
||||
{
|
||||
var currentCpu = proc.TotalProcessorTime;
|
||||
var elapsed = now - _lastCpuSampleTime;
|
||||
_cachedCpuPercent = (currentCpu - _lastCpuUsage).TotalMilliseconds
|
||||
/ elapsed.TotalMilliseconds / Environment.ProcessorCount * 100.0;
|
||||
_lastCpuSampleTime = now;
|
||||
_lastCpuUsage = currentCpu;
|
||||
}
|
||||
|
||||
return new Varz
|
||||
{
|
||||
Id = _server.ServerId,
|
||||
Name = _server.ServerName,
|
||||
Version = NatsProtocol.Version,
|
||||
Proto = NatsProtocol.ProtoVersion,
|
||||
GoVersion = $"dotnet {RuntimeInformation.FrameworkDescription}",
|
||||
Host = _options.Host,
|
||||
Port = _options.Port,
|
||||
HttpHost = _options.MonitorHost,
|
||||
HttpPort = _options.MonitorPort,
|
||||
HttpBasePath = _options.MonitorBasePath ?? "",
|
||||
HttpsPort = _options.MonitorHttpsPort,
|
||||
TlsRequired = _options.HasTls && !_options.AllowNonTls,
|
||||
TlsVerify = _options.HasTls && _options.TlsVerify,
|
||||
TlsTimeout = _options.HasTls ? _options.TlsTimeout.TotalSeconds : 0,
|
||||
MaxConnections = _options.MaxConnections,
|
||||
MaxPayload = _options.MaxPayload,
|
||||
MaxControlLine = _options.MaxControlLine,
|
||||
MaxPingsOut = _options.MaxPingsOut,
|
||||
PingInterval = (long)_options.PingInterval.TotalNanoseconds,
|
||||
Start = _server.StartTime,
|
||||
Now = now,
|
||||
Uptime = FormatUptime(uptime),
|
||||
Mem = proc.WorkingSet64,
|
||||
Cpu = Math.Round(_cachedCpuPercent, 2),
|
||||
Cores = Environment.ProcessorCount,
|
||||
MaxProcs = ThreadPool.ThreadCount,
|
||||
Connections = _server.ClientCount,
|
||||
TotalConnections = (ulong)Interlocked.Read(ref stats.TotalConnections),
|
||||
InMsgs = Interlocked.Read(ref stats.InMsgs),
|
||||
OutMsgs = Interlocked.Read(ref stats.OutMsgs),
|
||||
InBytes = Interlocked.Read(ref stats.InBytes),
|
||||
OutBytes = Interlocked.Read(ref stats.OutBytes),
|
||||
SlowConsumers = Interlocked.Read(ref stats.SlowConsumers),
|
||||
SlowConsumerStats = new SlowConsumersStats
|
||||
{
|
||||
Clients = (ulong)Interlocked.Read(ref stats.SlowConsumerClients),
|
||||
Routes = (ulong)Interlocked.Read(ref stats.SlowConsumerRoutes),
|
||||
Gateways = (ulong)Interlocked.Read(ref stats.SlowConsumerGateways),
|
||||
Leafs = (ulong)Interlocked.Read(ref stats.SlowConsumerLeafs),
|
||||
},
|
||||
Subscriptions = _server.SubList.Count,
|
||||
ConfigLoadTime = _server.StartTime,
|
||||
HttpReqStats = stats.HttpReqStats.ToDictionary(kv => kv.Key, kv => (ulong)kv.Value),
|
||||
};
|
||||
}
|
||||
finally
|
||||
{
|
||||
_varzMu.Release();
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_varzMu.Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Formats a TimeSpan as a human-readable uptime string matching Go server format.
|
||||
/// </summary>
|
||||
private static string FormatUptime(TimeSpan ts)
|
||||
{
|
||||
if (ts.TotalDays >= 1)
|
||||
return $"{(int)ts.TotalDays}d{ts.Hours}h{ts.Minutes}m{ts.Seconds}s";
|
||||
if (ts.TotalHours >= 1)
|
||||
return $"{(int)ts.TotalHours}h{ts.Minutes}m{ts.Seconds}s";
|
||||
if (ts.TotalMinutes >= 1)
|
||||
return $"{(int)ts.TotalMinutes}m{ts.Seconds}s";
|
||||
return $"{(int)ts.TotalSeconds}s";
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
|
||||
<FrameworkReference Include="Microsoft.AspNetCore.App" />
|
||||
<PackageReference Include="NATS.NKeys" />
|
||||
<PackageReference Include="BCrypt.Net-Next" />
|
||||
</ItemGroup>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
using System.Buffers;
|
||||
using System.IO.Pipelines;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
@@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Auth;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server;
|
||||
|
||||
@@ -26,7 +28,7 @@ public interface ISubListAccess
|
||||
public sealed class NatsClient : IDisposable
|
||||
{
|
||||
private readonly Socket _socket;
|
||||
private readonly NetworkStream _stream;
|
||||
private readonly Stream _stream;
|
||||
private readonly NatsOptions _options;
|
||||
private readonly ServerInfo _serverInfo;
|
||||
private readonly AuthService _authService;
|
||||
@@ -37,6 +39,7 @@ public sealed class NatsClient : IDisposable
|
||||
private readonly Dictionary<string, Subscription> _subs = new();
|
||||
private readonly ILogger _logger;
|
||||
private ClientPermissions? _permissions;
|
||||
private readonly ServerStats _serverStats;
|
||||
|
||||
public ulong Id { get; }
|
||||
public ClientOptions? ClientOpts { get; private set; }
|
||||
@@ -47,6 +50,12 @@ public sealed class NatsClient : IDisposable
|
||||
private int _connectReceived;
|
||||
public bool ConnectReceived => Volatile.Read(ref _connectReceived) != 0;
|
||||
|
||||
public DateTime StartTime { get; }
|
||||
private long _lastActivityTicks;
|
||||
public DateTime LastActivity => new(Interlocked.Read(ref _lastActivityTicks), DateTimeKind.Utc);
|
||||
public string? RemoteIp { get; }
|
||||
public int RemotePort { get; }
|
||||
|
||||
// Stats
|
||||
public long InMsgs;
|
||||
public long OutMsgs;
|
||||
@@ -57,20 +66,31 @@ public sealed class NatsClient : IDisposable
|
||||
private int _pingsOut;
|
||||
private long _lastIn;
|
||||
|
||||
public TlsConnectionState? TlsState { get; set; }
|
||||
public bool InfoAlreadySent { get; set; }
|
||||
|
||||
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
|
||||
|
||||
public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo,
|
||||
AuthService authService, byte[]? nonce, ILogger logger)
|
||||
public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,
|
||||
AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats)
|
||||
{
|
||||
Id = id;
|
||||
_socket = socket;
|
||||
_stream = new NetworkStream(socket, ownsSocket: false);
|
||||
_stream = stream;
|
||||
_options = options;
|
||||
_serverInfo = serverInfo;
|
||||
_authService = authService;
|
||||
_nonce = nonce;
|
||||
_logger = logger;
|
||||
_serverStats = serverStats;
|
||||
_parser = new NatsParser(options.MaxPayload);
|
||||
StartTime = DateTime.UtcNow;
|
||||
_lastActivityTicks = StartTime.Ticks;
|
||||
if (socket.RemoteEndPoint is IPEndPoint ep)
|
||||
{
|
||||
RemoteIp = ep.Address.ToString();
|
||||
RemotePort = ep.Port;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task RunAsync(CancellationToken ct)
|
||||
@@ -80,8 +100,9 @@ public sealed class NatsClient : IDisposable
|
||||
var pipe = new Pipe();
|
||||
try
|
||||
{
|
||||
// Send INFO
|
||||
await SendInfoAsync(_clientCts.Token);
|
||||
// Send INFO (skip if already sent during TLS negotiation)
|
||||
if (!InfoAlreadySent)
|
||||
await SendInfoAsync(_clientCts.Token);
|
||||
|
||||
// Start auth timeout if auth is required
|
||||
Task? authTimeoutTask = null;
|
||||
@@ -100,7 +121,7 @@ public sealed class NatsClient : IDisposable
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Normal — client connected or was cancelled
|
||||
// Normal -- client connected or was cancelled
|
||||
}
|
||||
}, _clientCts.Token);
|
||||
}
|
||||
@@ -184,6 +205,8 @@ public sealed class NatsClient : IDisposable
|
||||
|
||||
private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct)
|
||||
{
|
||||
Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
|
||||
|
||||
// If auth is required and CONNECT hasn't been received yet,
|
||||
// only allow CONNECT and PING commands
|
||||
if (_authService.IsAuthRequired && !ConnectReceived)
|
||||
@@ -266,7 +289,7 @@ public sealed class NatsClient : IDisposable
|
||||
|
||||
_logger.LogDebug("Client {ClientId} authenticated as {Identity}", Id, result.Identity);
|
||||
|
||||
// Clear nonce after use — defense-in-depth against memory dumps
|
||||
// Clear nonce after use -- defense-in-depth against memory dumps
|
||||
if (_nonce != null)
|
||||
CryptographicOperations.ZeroMemory(_nonce);
|
||||
}
|
||||
@@ -330,6 +353,8 @@ public sealed class NatsClient : IDisposable
|
||||
{
|
||||
Interlocked.Increment(ref InMsgs);
|
||||
Interlocked.Add(ref InBytes, cmd.Payload.Length);
|
||||
Interlocked.Increment(ref _serverStats.InMsgs);
|
||||
Interlocked.Add(ref _serverStats.InBytes, cmd.Payload.Length);
|
||||
|
||||
// Max payload validation (always, hard close)
|
||||
if (cmd.Payload.Length > _options.MaxPayload)
|
||||
@@ -380,6 +405,8 @@ public sealed class NatsClient : IDisposable
|
||||
{
|
||||
Interlocked.Increment(ref OutMsgs);
|
||||
Interlocked.Add(ref OutBytes, payload.Length + headers.Length);
|
||||
Interlocked.Increment(ref _serverStats.OutMsgs);
|
||||
Interlocked.Add(ref _serverStats.OutBytes, payload.Length + headers.Length);
|
||||
|
||||
byte[] line;
|
||||
if (headers.Length > 0)
|
||||
@@ -470,7 +497,7 @@ public sealed class NatsClient : IDisposable
|
||||
|
||||
if (Volatile.Read(ref _pingsOut) + 1 > _options.MaxPingsOut)
|
||||
{
|
||||
_logger.LogDebug("Client {ClientId} stale connection — closing", Id);
|
||||
_logger.LogDebug("Client {ClientId} stale connection -- closing", Id);
|
||||
await SendErrAndCloseAsync(NatsProtocol.ErrStaleConnection);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
using System.Security.Authentication;
|
||||
using NATS.Server.Auth;
|
||||
|
||||
namespace NATS.Server;
|
||||
@@ -7,7 +8,7 @@ public sealed class NatsOptions
|
||||
public string Host { get; set; } = "0.0.0.0";
|
||||
public int Port { get; set; } = 4222;
|
||||
public string? ServerName { get; set; }
|
||||
public int MaxPayload { get; set; } = 1024 * 1024; // 1MB
|
||||
public int MaxPayload { get; set; } = 1024 * 1024;
|
||||
public int MaxControlLine { get; set; } = 4096;
|
||||
public int MaxConnections { get; set; } = 65536;
|
||||
public TimeSpan PingInterval { get; set; } = TimeSpan.FromMinutes(2);
|
||||
@@ -27,4 +28,27 @@ public sealed class NatsOptions
|
||||
|
||||
// Auth timing
|
||||
public TimeSpan AuthTimeout { get; set; } = TimeSpan.FromSeconds(2);
|
||||
|
||||
// Monitoring (0 = disabled; standard port is 8222)
|
||||
public int MonitorPort { get; set; }
|
||||
public string MonitorHost { get; set; } = "0.0.0.0";
|
||||
public string? MonitorBasePath { get; set; }
|
||||
// 0 = disabled
|
||||
public int MonitorHttpsPort { get; set; }
|
||||
|
||||
// TLS
|
||||
public string? TlsCert { get; set; }
|
||||
public string? TlsKey { get; set; }
|
||||
public string? TlsCaCert { get; set; }
|
||||
public bool TlsVerify { get; set; }
|
||||
public bool TlsMap { get; set; }
|
||||
public TimeSpan TlsTimeout { get; set; } = TimeSpan.FromSeconds(2);
|
||||
public bool TlsHandshakeFirst { get; set; }
|
||||
public TimeSpan TlsHandshakeFirstFallback { get; set; } = TimeSpan.FromMilliseconds(50);
|
||||
public bool AllowNonTls { get; set; }
|
||||
public long TlsRateLimit { get; set; }
|
||||
public HashSet<string>? TlsPinnedCerts { get; set; }
|
||||
public SslProtocols TlsMinVersion { get; set; } = SslProtocols.Tls12;
|
||||
|
||||
public bool HasTls => TlsCert != null && TlsKey != null;
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Auth;
|
||||
using NATS.Server.Monitoring;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server;
|
||||
|
||||
@@ -16,14 +20,25 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
private readonly ServerInfo _serverInfo;
|
||||
private readonly ILogger<NatsServer> _logger;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly ServerStats _stats = new();
|
||||
private readonly TaskCompletionSource _listeningStarted = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
private readonly AuthService _authService;
|
||||
private readonly ConcurrentDictionary<string, Account> _accounts = new(StringComparer.Ordinal);
|
||||
private readonly Account _globalAccount;
|
||||
private readonly SslServerAuthenticationOptions? _sslOptions;
|
||||
private readonly TlsRateLimiter? _tlsRateLimiter;
|
||||
private Socket? _listener;
|
||||
private MonitorServer? _monitorServer;
|
||||
private ulong _nextClientId;
|
||||
private long _startTimeTicks;
|
||||
|
||||
public SubList SubList => _globalAccount.SubList;
|
||||
public ServerStats Stats => _stats;
|
||||
public DateTime StartTime => new(Interlocked.Read(ref _startTimeTicks), DateTimeKind.Utc);
|
||||
public string ServerId => _serverInfo.ServerId;
|
||||
public string ServerName => _serverInfo.ServerName;
|
||||
public int ClientCount => _clients.Count;
|
||||
public IEnumerable<NatsClient> GetClients() => _clients.Values;
|
||||
|
||||
public Task WaitForReadyAsync() => _listeningStarted.Task;
|
||||
|
||||
@@ -45,6 +60,17 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
MaxPayload = options.MaxPayload,
|
||||
AuthRequired = _authService.IsAuthRequired,
|
||||
};
|
||||
|
||||
if (options.HasTls)
|
||||
{
|
||||
_sslOptions = TlsHelper.BuildServerAuthOptions(options);
|
||||
_serverInfo.TlsRequired = !options.AllowNonTls;
|
||||
_serverInfo.TlsAvailable = options.AllowNonTls;
|
||||
_serverInfo.TlsVerify = options.TlsVerify;
|
||||
|
||||
if (options.TlsRateLimit > 0)
|
||||
_tlsRateLimiter = new TlsRateLimiter(options.TlsRateLimit);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task StartAsync(CancellationToken ct)
|
||||
@@ -54,11 +80,18 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
_listener.Bind(new IPEndPoint(
|
||||
_options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host),
|
||||
_options.Port));
|
||||
Interlocked.Exchange(ref _startTimeTicks, DateTime.UtcNow.Ticks);
|
||||
_listener.Listen(128);
|
||||
_listeningStarted.TrySetResult();
|
||||
|
||||
_logger.LogInformation("Listening on {Host}:{Port}", _options.Host, _options.Port);
|
||||
|
||||
if (_options.MonitorPort > 0)
|
||||
{
|
||||
_monitorServer = new MonitorServer(this, _options, _stats, _loggerFactory);
|
||||
await _monitorServer.StartAsync(ct);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
while (!ct.IsCancellationRequested)
|
||||
@@ -91,37 +124,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
}
|
||||
|
||||
var clientId = Interlocked.Increment(ref _nextClientId);
|
||||
Interlocked.Increment(ref _stats.TotalConnections);
|
||||
|
||||
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
|
||||
|
||||
// Build per-client ServerInfo with nonce if NKey auth is configured
|
||||
byte[]? nonce = null;
|
||||
var clientInfo = _serverInfo;
|
||||
if (_authService.NonceRequired)
|
||||
{
|
||||
var rawNonce = _authService.GenerateNonce();
|
||||
var nonceStr = _authService.EncodeNonce(rawNonce);
|
||||
// The client signs the nonce string (ASCII), not the raw bytes
|
||||
nonce = Encoding.ASCII.GetBytes(nonceStr);
|
||||
clientInfo = new ServerInfo
|
||||
{
|
||||
ServerId = _serverInfo.ServerId,
|
||||
ServerName = _serverInfo.ServerName,
|
||||
Version = _serverInfo.Version,
|
||||
Host = _serverInfo.Host,
|
||||
Port = _serverInfo.Port,
|
||||
MaxPayload = _serverInfo.MaxPayload,
|
||||
AuthRequired = _serverInfo.AuthRequired,
|
||||
Nonce = nonceStr,
|
||||
};
|
||||
}
|
||||
|
||||
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
|
||||
var client = new NatsClient(clientId, socket, _options, clientInfo, _authService, nonce, clientLogger);
|
||||
client.Router = this;
|
||||
_clients[clientId] = client;
|
||||
|
||||
_ = RunClientAsync(client, ct);
|
||||
_ = AcceptClientAsync(socket, clientId, ct);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
@@ -130,6 +137,74 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
}
|
||||
}
|
||||
|
||||
private async Task AcceptClientAsync(Socket socket, ulong clientId, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Rate limit TLS handshakes
|
||||
if (_tlsRateLimiter != null)
|
||||
await _tlsRateLimiter.WaitAsync(ct);
|
||||
|
||||
var networkStream = new NetworkStream(socket, ownsSocket: false);
|
||||
|
||||
// TLS negotiation (no-op if not configured)
|
||||
var (stream, infoAlreadySent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
socket, networkStream, _options, _sslOptions, _serverInfo,
|
||||
_loggerFactory.CreateLogger("NATS.Server.Tls"), ct);
|
||||
|
||||
// Extract TLS state
|
||||
TlsConnectionState? tlsState = null;
|
||||
if (stream is SslStream ssl)
|
||||
{
|
||||
tlsState = new TlsConnectionState(
|
||||
ssl.SslProtocol.ToString(),
|
||||
ssl.NegotiatedCipherSuite.ToString(),
|
||||
ssl.RemoteCertificate as X509Certificate2);
|
||||
}
|
||||
|
||||
// Build per-client ServerInfo with nonce if NKey auth is configured
|
||||
byte[]? nonce = null;
|
||||
var clientInfo = _serverInfo;
|
||||
if (_authService.NonceRequired)
|
||||
{
|
||||
var rawNonce = _authService.GenerateNonce();
|
||||
var nonceStr = _authService.EncodeNonce(rawNonce);
|
||||
// The client signs the nonce string (ASCII), not the raw bytes
|
||||
nonce = Encoding.ASCII.GetBytes(nonceStr);
|
||||
clientInfo = new ServerInfo
|
||||
{
|
||||
ServerId = _serverInfo.ServerId,
|
||||
ServerName = _serverInfo.ServerName,
|
||||
Version = _serverInfo.Version,
|
||||
Host = _serverInfo.Host,
|
||||
Port = _serverInfo.Port,
|
||||
MaxPayload = _serverInfo.MaxPayload,
|
||||
AuthRequired = _serverInfo.AuthRequired,
|
||||
TlsRequired = _serverInfo.TlsRequired,
|
||||
TlsAvailable = _serverInfo.TlsAvailable,
|
||||
TlsVerify = _serverInfo.TlsVerify,
|
||||
Nonce = nonceStr,
|
||||
};
|
||||
}
|
||||
|
||||
var clientLogger = _loggerFactory.CreateLogger($"NATS.Server.NatsClient[{clientId}]");
|
||||
var client = new NatsClient(clientId, stream, socket, _options, clientInfo,
|
||||
_authService, nonce, clientLogger, _stats);
|
||||
client.Router = this;
|
||||
client.TlsState = tlsState;
|
||||
client.InfoAlreadySent = infoAlreadySent;
|
||||
_clients[clientId] = client;
|
||||
|
||||
await RunClientAsync(client, ct);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Failed to accept client {ClientId}", clientId);
|
||||
try { socket.Shutdown(SocketShutdown.Both); } catch { }
|
||||
socket.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
private async Task RunClientAsync(NatsClient client, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
@@ -215,6 +290,9 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (_monitorServer != null)
|
||||
_monitorServer.DisposeAsync().AsTask().GetAwaiter().GetResult();
|
||||
_tlsRateLimiter?.Dispose();
|
||||
_listener?.Dispose();
|
||||
foreach (var client in _clients.Values)
|
||||
client.Dispose();
|
||||
|
||||
@@ -73,6 +73,18 @@ public sealed class ServerInfo
|
||||
[JsonPropertyName("nonce")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public string? Nonce { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_required")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public bool TlsRequired { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_verify")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public bool TlsVerify { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_available")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public bool TlsAvailable { get; set; }
|
||||
}
|
||||
|
||||
public sealed class ClientOptions
|
||||
|
||||
20
src/NATS.Server/ServerStats.cs
Normal file
20
src/NATS.Server/ServerStats.cs
Normal file
@@ -0,0 +1,20 @@
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
namespace NATS.Server;
|
||||
|
||||
public sealed class ServerStats
|
||||
{
|
||||
public long InMsgs;
|
||||
public long OutMsgs;
|
||||
public long InBytes;
|
||||
public long OutBytes;
|
||||
public long TotalConnections;
|
||||
public long SlowConsumers;
|
||||
public long StaleConnections;
|
||||
public long Stalls;
|
||||
public long SlowConsumerClients;
|
||||
public long SlowConsumerRoutes;
|
||||
public long SlowConsumerLeafs;
|
||||
public long SlowConsumerGateways;
|
||||
public readonly ConcurrentDictionary<string, long> HttpReqStats = new();
|
||||
}
|
||||
71
src/NATS.Server/Tls/PeekableStream.cs
Normal file
71
src/NATS.Server/Tls/PeekableStream.cs
Normal file
@@ -0,0 +1,71 @@
|
||||
namespace NATS.Server.Tls;
|
||||
|
||||
public sealed class PeekableStream : Stream
|
||||
{
|
||||
private readonly Stream _inner;
|
||||
private byte[]? _peekedBytes;
|
||||
private int _peekedOffset;
|
||||
private int _peekedCount;
|
||||
|
||||
public PeekableStream(Stream inner) => _inner = inner;
|
||||
|
||||
public async Task<byte[]> PeekAsync(int count, CancellationToken ct = default)
|
||||
{
|
||||
var buf = new byte[count];
|
||||
int read = await _inner.ReadAsync(buf.AsMemory(0, count), ct);
|
||||
if (read < count) Array.Resize(ref buf, read);
|
||||
_peekedBytes = buf;
|
||||
_peekedOffset = 0;
|
||||
_peekedCount = read;
|
||||
return buf;
|
||||
}
|
||||
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken ct = default)
|
||||
{
|
||||
if (_peekedBytes != null && _peekedOffset < _peekedCount)
|
||||
{
|
||||
int available = _peekedCount - _peekedOffset;
|
||||
int toCopy = Math.Min(available, buffer.Length);
|
||||
_peekedBytes.AsMemory(_peekedOffset, toCopy).CopyTo(buffer);
|
||||
_peekedOffset += toCopy;
|
||||
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
|
||||
return toCopy;
|
||||
}
|
||||
return await _inner.ReadAsync(buffer, ct);
|
||||
}
|
||||
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
if (_peekedBytes != null && _peekedOffset < _peekedCount)
|
||||
{
|
||||
int available = _peekedCount - _peekedOffset;
|
||||
int toCopy = Math.Min(available, count);
|
||||
Array.Copy(_peekedBytes, _peekedOffset, buffer, offset, toCopy);
|
||||
_peekedOffset += toCopy;
|
||||
if (_peekedOffset >= _peekedCount) _peekedBytes = null;
|
||||
return toCopy;
|
||||
}
|
||||
return _inner.Read(buffer, offset, count);
|
||||
}
|
||||
|
||||
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct)
|
||||
=> ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
|
||||
|
||||
// Write passthrough
|
||||
public override void Write(byte[] buffer, int offset, int count) => _inner.Write(buffer, offset, count);
|
||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) => _inner.WriteAsync(buffer, offset, count, ct);
|
||||
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken ct = default) => _inner.WriteAsync(buffer, ct);
|
||||
public override void Flush() => _inner.Flush();
|
||||
public override Task FlushAsync(CancellationToken ct) => _inner.FlushAsync(ct);
|
||||
|
||||
// Required Stream overrides
|
||||
public override bool CanRead => _inner.CanRead;
|
||||
public override bool CanSeek => false;
|
||||
public override bool CanWrite => _inner.CanWrite;
|
||||
public override long Length => throw new NotSupportedException();
|
||||
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
|
||||
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
|
||||
public override void SetLength(long value) => throw new NotSupportedException();
|
||||
|
||||
protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); }
|
||||
}
|
||||
9
src/NATS.Server/Tls/TlsConnectionState.cs
Normal file
9
src/NATS.Server/Tls/TlsConnectionState.cs
Normal file
@@ -0,0 +1,9 @@
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
namespace NATS.Server.Tls;
|
||||
|
||||
public sealed record TlsConnectionState(
|
||||
string? TlsVersion,
|
||||
string? CipherSuite,
|
||||
X509Certificate2? PeerCert
|
||||
);
|
||||
202
src/NATS.Server/Tls/TlsConnectionWrapper.cs
Normal file
202
src/NATS.Server/Tls/TlsConnectionWrapper.cs
Normal file
@@ -0,0 +1,202 @@
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Protocol;
|
||||
|
||||
namespace NATS.Server.Tls;
|
||||
|
||||
public static class TlsConnectionWrapper
|
||||
{
|
||||
private const byte TlsRecordMarker = 0x16;
|
||||
|
||||
public static async Task<(Stream stream, bool infoAlreadySent)> NegotiateAsync(
|
||||
Socket socket,
|
||||
Stream networkStream,
|
||||
NatsOptions options,
|
||||
SslServerAuthenticationOptions? sslOptions,
|
||||
ServerInfo serverInfo,
|
||||
ILogger logger,
|
||||
CancellationToken ct)
|
||||
{
|
||||
// Mode 1: No TLS
|
||||
if (sslOptions == null || !options.HasTls)
|
||||
return (networkStream, false);
|
||||
|
||||
// Clone to avoid mutating shared instance
|
||||
serverInfo = new ServerInfo
|
||||
{
|
||||
ServerId = serverInfo.ServerId,
|
||||
ServerName = serverInfo.ServerName,
|
||||
Version = serverInfo.Version,
|
||||
Proto = serverInfo.Proto,
|
||||
Host = serverInfo.Host,
|
||||
Port = serverInfo.Port,
|
||||
Headers = serverInfo.Headers,
|
||||
MaxPayload = serverInfo.MaxPayload,
|
||||
ClientId = serverInfo.ClientId,
|
||||
ClientIp = serverInfo.ClientIp,
|
||||
};
|
||||
|
||||
// Mode 3: TLS First
|
||||
if (options.TlsHandshakeFirst)
|
||||
return await NegotiateTlsFirstAsync(socket, networkStream, options, sslOptions, serverInfo, logger, ct);
|
||||
|
||||
// Mode 2 & 4: Send INFO first, then decide
|
||||
serverInfo.TlsRequired = !options.AllowNonTls;
|
||||
serverInfo.TlsAvailable = options.AllowNonTls;
|
||||
serverInfo.TlsVerify = options.TlsVerify;
|
||||
await SendInfoAsync(networkStream, serverInfo, ct);
|
||||
|
||||
// Peek first byte to detect TLS
|
||||
var peekable = new PeekableStream(networkStream);
|
||||
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsTimeout, ct);
|
||||
|
||||
if (peeked.Length == 0)
|
||||
{
|
||||
// Client disconnected or timed out
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
if (peeked[0] == TlsRecordMarker)
|
||||
{
|
||||
// Client is starting TLS
|
||||
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
||||
try
|
||||
{
|
||||
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
handshakeCts.CancelAfter(options.TlsTimeout);
|
||||
|
||||
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
|
||||
logger.LogDebug("TLS handshake complete: {Protocol} {CipherSuite}",
|
||||
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
|
||||
|
||||
// Validate pinned certs
|
||||
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
|
||||
{
|
||||
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
||||
{
|
||||
logger.LogWarning("Certificate pinning check failed");
|
||||
throw new InvalidOperationException("Certificate pinning check failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
sslStream.Dispose();
|
||||
throw;
|
||||
}
|
||||
|
||||
return (sslStream, true);
|
||||
}
|
||||
|
||||
// Mode 4: Mixed — client chose plaintext
|
||||
if (options.AllowNonTls)
|
||||
{
|
||||
logger.LogDebug("Client connected without TLS (mixed mode)");
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
// TLS required but client sent plaintext
|
||||
logger.LogWarning("TLS required but client sent plaintext data");
|
||||
throw new InvalidOperationException("TLS required");
|
||||
}
|
||||
|
||||
private static async Task<(Stream stream, bool infoAlreadySent)> NegotiateTlsFirstAsync(
|
||||
Socket socket,
|
||||
Stream networkStream,
|
||||
NatsOptions options,
|
||||
SslServerAuthenticationOptions sslOptions,
|
||||
ServerInfo serverInfo,
|
||||
ILogger logger,
|
||||
CancellationToken ct)
|
||||
{
|
||||
// Wait for data with fallback timeout
|
||||
var peekable = new PeekableStream(networkStream);
|
||||
var peeked = await PeekWithTimeoutAsync(peekable, 1, options.TlsHandshakeFirstFallback, ct);
|
||||
|
||||
if (peeked.Length > 0 && peeked[0] == TlsRecordMarker)
|
||||
{
|
||||
// Client started TLS immediately — handshake first, then send INFO
|
||||
var sslStream = new SslStream(peekable, leaveInnerStreamOpen: false);
|
||||
try
|
||||
{
|
||||
using var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
handshakeCts.CancelAfter(options.TlsTimeout);
|
||||
|
||||
await sslStream.AuthenticateAsServerAsync(sslOptions, handshakeCts.Token);
|
||||
logger.LogDebug("TLS-first handshake complete: {Protocol} {CipherSuite}",
|
||||
sslStream.SslProtocol, sslStream.NegotiatedCipherSuite);
|
||||
|
||||
// Validate pinned certs
|
||||
if (options.TlsPinnedCerts != null && sslStream.RemoteCertificate is X509Certificate2 remoteCert)
|
||||
{
|
||||
if (!TlsHelper.MatchesPinnedCert(remoteCert, options.TlsPinnedCerts))
|
||||
{
|
||||
throw new InvalidOperationException("Certificate pinning check failed");
|
||||
}
|
||||
}
|
||||
|
||||
// Now send INFO over encrypted stream
|
||||
serverInfo.TlsRequired = true;
|
||||
serverInfo.TlsVerify = options.TlsVerify;
|
||||
await SendInfoAsync(sslStream, serverInfo, ct);
|
||||
}
|
||||
catch
|
||||
{
|
||||
sslStream.Dispose();
|
||||
throw;
|
||||
}
|
||||
|
||||
return (sslStream, true);
|
||||
}
|
||||
|
||||
// Fallback: timeout expired or non-TLS data — send INFO and negotiate normally
|
||||
logger.LogDebug("TLS-first fallback: sending INFO");
|
||||
serverInfo.TlsRequired = !options.AllowNonTls;
|
||||
serverInfo.TlsAvailable = options.AllowNonTls;
|
||||
serverInfo.TlsVerify = options.TlsVerify;
|
||||
await SendInfoAsync(peekable, serverInfo, ct);
|
||||
|
||||
if (peeked.Length == 0)
|
||||
{
|
||||
// Timeout — INFO was sent, return stream for normal flow
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
// Non-TLS data received during fallback window
|
||||
if (options.AllowNonTls)
|
||||
{
|
||||
return (peekable, true);
|
||||
}
|
||||
|
||||
// TLS required but got plaintext
|
||||
throw new InvalidOperationException("TLS required but client sent plaintext");
|
||||
}
|
||||
|
||||
private static async Task<byte[]> PeekWithTimeoutAsync(
|
||||
PeekableStream stream, int count, TimeSpan timeout, CancellationToken ct)
|
||||
{
|
||||
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
cts.CancelAfter(timeout);
|
||||
try
|
||||
{
|
||||
return await stream.PeekAsync(count, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException) when (!ct.IsCancellationRequested)
|
||||
{
|
||||
// Timeout — not a cancellation of the outer token
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task SendInfoAsync(Stream stream, ServerInfo serverInfo, CancellationToken ct)
|
||||
{
|
||||
var infoJson = JsonSerializer.Serialize(serverInfo);
|
||||
var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n");
|
||||
await stream.WriteAsync(infoLine, ct);
|
||||
await stream.FlushAsync(ct);
|
||||
}
|
||||
}
|
||||
65
src/NATS.Server/Tls/TlsHelper.cs
Normal file
65
src/NATS.Server/Tls/TlsHelper.cs
Normal file
@@ -0,0 +1,65 @@
|
||||
using System.Net.Security;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
namespace NATS.Server.Tls;
|
||||
|
||||
public static class TlsHelper
|
||||
{
|
||||
public static X509Certificate2 LoadCertificate(string certPath, string? keyPath)
|
||||
{
|
||||
if (keyPath != null)
|
||||
return X509Certificate2.CreateFromPemFile(certPath, keyPath);
|
||||
return X509CertificateLoader.LoadCertificateFromFile(certPath);
|
||||
}
|
||||
|
||||
public static X509Certificate2Collection LoadCaCertificates(string caPath)
|
||||
{
|
||||
var collection = new X509Certificate2Collection();
|
||||
collection.ImportFromPemFile(caPath);
|
||||
return collection;
|
||||
}
|
||||
|
||||
public static SslServerAuthenticationOptions BuildServerAuthOptions(NatsOptions opts)
|
||||
{
|
||||
var cert = LoadCertificate(opts.TlsCert!, opts.TlsKey);
|
||||
var authOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
EnabledSslProtocols = opts.TlsMinVersion,
|
||||
ClientCertificateRequired = opts.TlsVerify,
|
||||
};
|
||||
|
||||
if (opts.TlsVerify && opts.TlsCaCert != null)
|
||||
{
|
||||
var caCerts = LoadCaCertificates(opts.TlsCaCert);
|
||||
authOpts.RemoteCertificateValidationCallback = (_, cert, chain, errors) =>
|
||||
{
|
||||
if (cert == null) return false;
|
||||
using var chain2 = new X509Chain();
|
||||
chain2.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
|
||||
foreach (var ca in caCerts)
|
||||
chain2.ChainPolicy.CustomTrustStore.Add(ca);
|
||||
chain2.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
|
||||
var cert2 = cert as X509Certificate2 ?? X509CertificateLoader.LoadCertificate(cert.GetRawCertData());
|
||||
return chain2.Build(cert2);
|
||||
};
|
||||
}
|
||||
|
||||
return authOpts;
|
||||
}
|
||||
|
||||
public static string GetCertificateHash(X509Certificate2 cert)
|
||||
{
|
||||
var spki = cert.PublicKey.ExportSubjectPublicKeyInfo();
|
||||
var hash = SHA256.HashData(spki);
|
||||
return Convert.ToHexStringLower(hash);
|
||||
}
|
||||
|
||||
public static bool MatchesPinnedCert(X509Certificate2 cert, HashSet<string> pinned)
|
||||
{
|
||||
var hash = GetCertificateHash(cert);
|
||||
return pinned.Contains(hash);
|
||||
}
|
||||
}
|
||||
25
src/NATS.Server/Tls/TlsRateLimiter.cs
Normal file
25
src/NATS.Server/Tls/TlsRateLimiter.cs
Normal file
@@ -0,0 +1,25 @@
|
||||
namespace NATS.Server.Tls;
|
||||
|
||||
public sealed class TlsRateLimiter : IDisposable
|
||||
{
|
||||
private readonly SemaphoreSlim _semaphore;
|
||||
private readonly Timer _refillTimer;
|
||||
private readonly int _tokensPerSecond;
|
||||
|
||||
public TlsRateLimiter(long tokensPerSecond)
|
||||
{
|
||||
_tokensPerSecond = (int)Math.Max(1, tokensPerSecond);
|
||||
_semaphore = new SemaphoreSlim(_tokensPerSecond, _tokensPerSecond);
|
||||
_refillTimer = new Timer(Refill, null, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
|
||||
}
|
||||
|
||||
private void Refill(object? state)
|
||||
{
|
||||
int toRelease = _tokensPerSecond - _semaphore.CurrentCount;
|
||||
if (toRelease > 0) _semaphore.Release(toRelease);
|
||||
}
|
||||
|
||||
public Task WaitAsync(CancellationToken ct) => _semaphore.WaitAsync(ct);
|
||||
|
||||
public void Dispose() { _refillTimer.Dispose(); _semaphore.Dispose(); }
|
||||
}
|
||||
@@ -41,7 +41,7 @@ public class ClientTests : IAsyncDisposable
|
||||
};
|
||||
|
||||
var authService = AuthService.Build(new NatsOptions());
|
||||
_natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance);
|
||||
_natsClient = new NatsClient(1, new NetworkStream(_serverSocket, ownsSocket: false), _serverSocket, new NatsOptions(), serverInfo, authService, null, NullLogger.Instance, new ServerStats());
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
@@ -56,7 +56,7 @@ public class ClientTests : IAsyncDisposable
|
||||
{
|
||||
var runTask = _natsClient.RunAsync(_cts.Token);
|
||||
|
||||
// Read from client socket — should get INFO
|
||||
// Read from client socket -- should get INFO
|
||||
var buf = new byte[4096];
|
||||
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None);
|
||||
var response = Encoding.ASCII.GetString(buf, 0, n);
|
||||
@@ -80,7 +80,7 @@ public class ClientTests : IAsyncDisposable
|
||||
// Send CONNECT then PING
|
||||
await _clientSocket.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));
|
||||
|
||||
// Read response — should get PONG
|
||||
// Read response -- should get PONG
|
||||
var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None);
|
||||
var response = Encoding.ASCII.GetString(buf, 0, n);
|
||||
|
||||
@@ -128,7 +128,7 @@ public class ClientTests : IAsyncDisposable
|
||||
|
||||
response.ShouldBe("-ERR 'maximum connections exceeded'\r\n");
|
||||
|
||||
// Connection should be closed — next read returns 0
|
||||
// Connection should be closed -- next read returns 0
|
||||
n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
|
||||
n.ShouldBe(0);
|
||||
}
|
||||
|
||||
51
tests/NATS.Server.Tests/MonitorModelTests.cs
Normal file
51
tests/NATS.Server.Tests/MonitorModelTests.cs
Normal file
@@ -0,0 +1,51 @@
|
||||
using System.Text.Json;
|
||||
using NATS.Server.Monitoring;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class MonitorModelTests
|
||||
{
|
||||
[Fact]
|
||||
public void Varz_serializes_with_go_field_names()
|
||||
{
|
||||
var varz = new Varz
|
||||
{
|
||||
Id = "TESTID", Name = "test-server", Version = "0.1.0",
|
||||
Host = "0.0.0.0", Port = 4222, InMsgs = 100, OutMsgs = 200,
|
||||
};
|
||||
var json = JsonSerializer.Serialize(varz);
|
||||
json.ShouldContain("\"server_id\":");
|
||||
json.ShouldContain("\"server_name\":");
|
||||
json.ShouldContain("\"in_msgs\":");
|
||||
json.ShouldContain("\"out_msgs\":");
|
||||
json.ShouldNotContain("\"InMsgs\"");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Connz_serializes_with_go_field_names()
|
||||
{
|
||||
var connz = new Connz
|
||||
{
|
||||
Id = "TESTID", Now = DateTime.UtcNow, NumConns = 1, Total = 1, Limit = 1024,
|
||||
Conns = [new ConnInfo { Cid = 1, Ip = "127.0.0.1", Port = 5555,
|
||||
InMsgs = 10, Uptime = "1s", Idle = "0s",
|
||||
Start = DateTime.UtcNow, LastActivity = DateTime.UtcNow }],
|
||||
};
|
||||
var json = JsonSerializer.Serialize(connz);
|
||||
json.ShouldContain("\"server_id\":");
|
||||
json.ShouldContain("\"num_connections\":");
|
||||
json.ShouldContain("\"in_msgs\":");
|
||||
json.ShouldContain("\"pending_bytes\":");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Varz_includes_nested_config_stubs()
|
||||
{
|
||||
var varz = new Varz { Id = "X", Name = "X", Version = "X", Host = "X" };
|
||||
var json = JsonSerializer.Serialize(varz);
|
||||
json.ShouldContain("\"cluster\":");
|
||||
json.ShouldContain("\"gateway\":");
|
||||
json.ShouldContain("\"leaf\":");
|
||||
json.ShouldContain("\"jetstream\":");
|
||||
}
|
||||
}
|
||||
274
tests/NATS.Server.Tests/MonitorTests.cs
Normal file
274
tests/NATS.Server.Tests/MonitorTests.cs
Normal file
@@ -0,0 +1,274 @@
|
||||
using System.Net;
|
||||
using System.Net.Http.Json;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server.Monitoring;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class MonitorTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _natsPort;
|
||||
private readonly int _monitorPort;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly HttpClient _http = new();
|
||||
|
||||
public MonitorTests()
|
||||
{
|
||||
_natsPort = GetFreePort();
|
||||
_monitorPort = GetFreePort();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions { Port = _natsPort, MonitorPort = _monitorPort },
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
// Wait for monitoring HTTP server to be ready
|
||||
for (int i = 0; i < 50; i++)
|
||||
{
|
||||
try
|
||||
{
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
|
||||
if (response.IsSuccessStatusCode) break;
|
||||
}
|
||||
catch (HttpRequestException) { }
|
||||
await Task.Delay(50);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
_http.Dispose();
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Healthz_returns_ok()
|
||||
{
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
|
||||
response.StatusCode.ShouldBe(HttpStatusCode.OK);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Varz_returns_server_identity()
|
||||
{
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz");
|
||||
response.StatusCode.ShouldBe(HttpStatusCode.OK);
|
||||
|
||||
var varz = await response.Content.ReadFromJsonAsync<Varz>();
|
||||
varz.ShouldNotBeNull();
|
||||
varz.Id.ShouldNotBeNullOrEmpty();
|
||||
varz.Name.ShouldNotBeNullOrEmpty();
|
||||
varz.Version.ShouldBe("0.1.0");
|
||||
varz.Host.ShouldBe("0.0.0.0");
|
||||
varz.Port.ShouldBe(_natsPort);
|
||||
varz.MaxPayload.ShouldBe(1024 * 1024);
|
||||
varz.Uptime.ShouldNotBeNullOrEmpty();
|
||||
varz.Now.ShouldBeGreaterThan(DateTime.MinValue);
|
||||
varz.Mem.ShouldBeGreaterThan(0);
|
||||
varz.Cores.ShouldBeGreaterThan(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Varz_tracks_connections_and_messages()
|
||||
{
|
||||
// Connect a client and send a message
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
|
||||
|
||||
var buf = new byte[4096];
|
||||
_ = await sock.ReceiveAsync(buf, SocketFlags.None); // Read INFO
|
||||
|
||||
var cmd = "CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray();
|
||||
await sock.SendAsync(cmd, SocketFlags.None);
|
||||
await Task.Delay(200);
|
||||
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz");
|
||||
var varz = await response.Content.ReadFromJsonAsync<Varz>();
|
||||
|
||||
varz.ShouldNotBeNull();
|
||||
varz.Connections.ShouldBeGreaterThanOrEqualTo(1);
|
||||
varz.TotalConnections.ShouldBeGreaterThanOrEqualTo(1UL);
|
||||
varz.InMsgs.ShouldBeGreaterThanOrEqualTo(1L);
|
||||
varz.InBytes.ShouldBeGreaterThanOrEqualTo(5L);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Connz_returns_connections()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
|
||||
using var stream = new NetworkStream(sock);
|
||||
var buf = new byte[4096];
|
||||
_ = await stream.ReadAsync(buf);
|
||||
await stream.WriteAsync("CONNECT {\"name\":\"test-client\",\"lang\":\"csharp\",\"version\":\"1.0\"}\r\n"u8.ToArray());
|
||||
await Task.Delay(200);
|
||||
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz");
|
||||
response.StatusCode.ShouldBe(HttpStatusCode.OK);
|
||||
|
||||
var connz = await response.Content.ReadFromJsonAsync<Connz>();
|
||||
connz.ShouldNotBeNull();
|
||||
connz.NumConns.ShouldBeGreaterThanOrEqualTo(1);
|
||||
connz.Conns.Length.ShouldBeGreaterThanOrEqualTo(1);
|
||||
|
||||
var conn = connz.Conns.First(c => c.Name == "test-client");
|
||||
conn.Ip.ShouldNotBeNullOrEmpty();
|
||||
conn.Port.ShouldBeGreaterThan(0);
|
||||
conn.Lang.ShouldBe("csharp");
|
||||
conn.Version.ShouldBe("1.0");
|
||||
conn.Uptime.ShouldNotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Connz_pagination()
|
||||
{
|
||||
var sockets = new List<Socket>();
|
||||
try
|
||||
{
|
||||
for (int i = 0; i < 3; i++)
|
||||
{
|
||||
var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await s.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
|
||||
var ns = new NetworkStream(s);
|
||||
var buf = new byte[4096];
|
||||
_ = await ns.ReadAsync(buf);
|
||||
await ns.WriteAsync("CONNECT {}\r\n"u8.ToArray());
|
||||
sockets.Add(s);
|
||||
}
|
||||
await Task.Delay(200);
|
||||
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?limit=2&offset=0");
|
||||
var connz = await response.Content.ReadFromJsonAsync<Connz>();
|
||||
|
||||
connz!.Conns.Length.ShouldBe(2);
|
||||
connz.Total.ShouldBeGreaterThanOrEqualTo(3);
|
||||
connz.Limit.ShouldBe(2);
|
||||
connz.Offset.ShouldBe(0);
|
||||
}
|
||||
finally
|
||||
{
|
||||
foreach (var s in sockets) s.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Connz_with_subscriptions()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort));
|
||||
using var stream = new NetworkStream(sock);
|
||||
var buf = new byte[4096];
|
||||
_ = await stream.ReadAsync(buf);
|
||||
await stream.WriteAsync("CONNECT {}\r\nSUB foo 1\r\nSUB bar 2\r\n"u8.ToArray());
|
||||
await Task.Delay(200);
|
||||
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?subs=true");
|
||||
var connz = await response.Content.ReadFromJsonAsync<Connz>();
|
||||
|
||||
var conn = connz!.Conns.First(c => c.NumSubs >= 2);
|
||||
conn.Subs.ShouldNotBeNull();
|
||||
conn.Subs.ShouldContain("foo");
|
||||
conn.Subs.ShouldContain("bar");
|
||||
}
|
||||
|
||||
private static int GetFreePort()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||
}
|
||||
}
|
||||
|
||||
public class MonitorTlsTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _natsPort;
|
||||
private readonly int _monitorPort;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly HttpClient _http = new();
|
||||
private readonly string _certPath;
|
||||
private readonly string _keyPath;
|
||||
|
||||
public MonitorTlsTests()
|
||||
{
|
||||
_natsPort = GetFreePort();
|
||||
_monitorPort = GetFreePort();
|
||||
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions
|
||||
{
|
||||
Port = _natsPort,
|
||||
MonitorPort = _monitorPort,
|
||||
TlsCert = _certPath,
|
||||
TlsKey = _keyPath,
|
||||
},
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
// Wait for monitoring HTTP server to be ready
|
||||
for (int i = 0; i < 50; i++)
|
||||
{
|
||||
try
|
||||
{
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/healthz");
|
||||
if (response.IsSuccessStatusCode) break;
|
||||
}
|
||||
catch (HttpRequestException) { }
|
||||
await Task.Delay(50);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
_http.Dispose();
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
File.Delete(_certPath);
|
||||
File.Delete(_keyPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Connz_shows_tls_info_for_tls_client()
|
||||
{
|
||||
// Connect and upgrade to TLS
|
||||
using var tcp = new TcpClient();
|
||||
await tcp.ConnectAsync(IPAddress.Loopback, _natsPort);
|
||||
using var netStream = tcp.GetStream();
|
||||
var buf = new byte[4096];
|
||||
_ = await netStream.ReadAsync(buf); // Read INFO
|
||||
|
||||
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await ssl.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
await ssl.WriteAsync("CONNECT {}\r\n"u8.ToArray());
|
||||
await ssl.FlushAsync();
|
||||
await Task.Delay(200);
|
||||
|
||||
var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz");
|
||||
var connz = await response.Content.ReadFromJsonAsync<Connz>();
|
||||
|
||||
connz!.Conns.Length.ShouldBeGreaterThanOrEqualTo(1);
|
||||
var conn = connz.Conns[0];
|
||||
conn.TlsVersion.ShouldNotBeNullOrEmpty();
|
||||
conn.TlsCipherSuite.ShouldNotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
private static int GetFreePort()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||
}
|
||||
}
|
||||
84
tests/NATS.Server.Tests/ServerStatsTests.cs
Normal file
84
tests/NATS.Server.Tests/ServerStatsTests.cs
Normal file
@@ -0,0 +1,84 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class ServerStatsTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _port;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
|
||||
public ServerStatsTests()
|
||||
{
|
||||
_port = GetFreePort();
|
||||
_server = new NatsServer(new NatsOptions { Port = _port }, NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public Task DisposeAsync()
|
||||
{
|
||||
_cts.Cancel();
|
||||
_server.Dispose();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Server_has_start_time()
|
||||
{
|
||||
_server.StartTime.ShouldNotBe(default);
|
||||
_server.StartTime.ShouldBeLessThanOrEqualTo(DateTime.UtcNow);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Server_tracks_total_connections()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
|
||||
await Task.Delay(100);
|
||||
_server.Stats.TotalConnections.ShouldBeGreaterThanOrEqualTo(1);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Server_stats_track_messages()
|
||||
{
|
||||
using var pub = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await pub.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
|
||||
|
||||
var buf = new byte[4096];
|
||||
await pub.ReceiveAsync(buf, SocketFlags.None); // INFO
|
||||
|
||||
await pub.SendAsync("CONNECT {}\r\nSUB test 1\r\nPUB test 5\r\nhello\r\n"u8.ToArray());
|
||||
await Task.Delay(200);
|
||||
|
||||
_server.Stats.InMsgs.ShouldBeGreaterThanOrEqualTo(1);
|
||||
_server.Stats.InBytes.ShouldBeGreaterThanOrEqualTo(5);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Client_has_metadata()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
|
||||
await Task.Delay(100);
|
||||
|
||||
var client = _server.GetClients().First();
|
||||
client.RemoteIp.ShouldNotBeNullOrEmpty();
|
||||
client.RemotePort.ShouldBeGreaterThan(0);
|
||||
client.StartTime.ShouldNotBe(default);
|
||||
}
|
||||
|
||||
private static int GetFreePort()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||
}
|
||||
}
|
||||
254
tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs
Normal file
254
tests/NATS.Server.Tests/TlsConnectionWrapperTests.cs
Normal file
@@ -0,0 +1,254 @@
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class TlsConnectionWrapperTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task NoTls_returns_plain_stream()
|
||||
{
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions(); // No TLS configured
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBe(serverStream); // Same stream, no wrapping
|
||||
infoSent.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRequired_upgrades_to_ssl()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then start TLS
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
// Read INFO line
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Upgrade to TLS
|
||||
var sslClient = new SslStream(clientNetStream, true,
|
||||
(_, _, _, _) => true); // Trust all for testing
|
||||
await sslClient.AuthenticateAsClientAsync("localhost");
|
||||
return sslClient;
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBeOfType<SslStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
var clientSsl = await clientTask;
|
||||
|
||||
// Verify encrypted communication works
|
||||
await stream.WriteAsync("PING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var readBuf = new byte[64];
|
||||
var bytesRead = await clientSsl.ReadAsync(readBuf);
|
||||
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
|
||||
msg.ShouldBe("PING\r\n");
|
||||
|
||||
stream.Dispose();
|
||||
clientSsl.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MixedMode_allows_plaintext_when_AllowNonTls()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsCert = "dummy",
|
||||
TlsKey = "dummy",
|
||||
AllowNonTls = true,
|
||||
TlsTimeout = TimeSpan.FromSeconds(2),
|
||||
};
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then send plaintext (not TLS)
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Send plaintext CONNECT (not a TLS handshake)
|
||||
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
|
||||
await clientNetStream.WriteAsync(connectLine);
|
||||
await clientNetStream.FlushAsync();
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
await clientTask;
|
||||
|
||||
// In mixed mode with plaintext client, we get a PeekableStream, not SslStream
|
||||
stream.ShouldBeOfType<PeekableStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
stream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRequired_rejects_plaintext()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsCert = "dummy",
|
||||
TlsKey = "dummy",
|
||||
AllowNonTls = false,
|
||||
TlsTimeout = TimeSpan.FromSeconds(2),
|
||||
};
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then send plaintext
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Send plaintext data (first byte is 'C', not 0x16 TLS marker)
|
||||
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
|
||||
await clientNetStream.WriteAsync(connectLine);
|
||||
await clientNetStream.FlushAsync();
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
|
||||
await Should.ThrowAsync<InvalidOperationException>(async () =>
|
||||
{
|
||||
await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
});
|
||||
|
||||
await clientTask;
|
||||
serverNetStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsFirst_handshakes_before_sending_info()
|
||||
{
|
||||
var (cert, _) = TlsHelperTests.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: immediately start TLS (no INFO first)
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true);
|
||||
await sslClient.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
// After TLS, read INFO over encrypted stream
|
||||
var buf = new byte[4096];
|
||||
var read = await sslClient.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
return sslClient;
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBeOfType<SslStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
var clientSsl = await clientTask;
|
||||
|
||||
// Verify encrypted communication works
|
||||
await stream.WriteAsync("PING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var readBuf = new byte[64];
|
||||
var bytesRead = await clientSsl.ReadAsync(readBuf);
|
||||
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
|
||||
msg.ShouldBe("PING\r\n");
|
||||
|
||||
stream.Dispose();
|
||||
clientSsl.Dispose();
|
||||
}
|
||||
|
||||
private static ServerInfo CreateServerInfo() => new()
|
||||
{
|
||||
ServerId = "TEST",
|
||||
ServerName = "test",
|
||||
Version = NatsProtocol.Version,
|
||||
Host = "127.0.0.1",
|
||||
Port = 4222,
|
||||
};
|
||||
|
||||
private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
|
||||
{
|
||||
using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
listener.Listen(1);
|
||||
var port = ((IPEndPoint)listener.LocalEndPoint!).Port;
|
||||
|
||||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
|
||||
var server = await listener.AcceptAsync();
|
||||
|
||||
return (server, client);
|
||||
}
|
||||
}
|
||||
110
tests/NATS.Server.Tests/TlsHelperTests.cs
Normal file
110
tests/NATS.Server.Tests/TlsHelperTests.cs
Normal file
@@ -0,0 +1,110 @@
|
||||
using System.Net;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using NATS.Server;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class TlsHelperTests
|
||||
{
|
||||
[Fact]
|
||||
public void LoadCertificate_loads_pem_cert_and_key()
|
||||
{
|
||||
var (certPath, keyPath) = GenerateTestCertFiles();
|
||||
try
|
||||
{
|
||||
var cert = TlsHelper.LoadCertificate(certPath, keyPath);
|
||||
cert.ShouldNotBeNull();
|
||||
cert.HasPrivateKey.ShouldBeTrue();
|
||||
}
|
||||
finally { File.Delete(certPath); File.Delete(keyPath); }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildServerAuthOptions_creates_valid_options()
|
||||
{
|
||||
var (certPath, keyPath) = GenerateTestCertFiles();
|
||||
try
|
||||
{
|
||||
var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };
|
||||
var authOpts = TlsHelper.BuildServerAuthOptions(opts);
|
||||
authOpts.ShouldNotBeNull();
|
||||
authOpts.ServerCertificate.ShouldNotBeNull();
|
||||
}
|
||||
finally { File.Delete(certPath); File.Delete(keyPath); }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatchesPinnedCert_matches_correct_hash()
|
||||
{
|
||||
var (cert, _) = GenerateTestCert();
|
||||
var hash = TlsHelper.GetCertificateHash(cert);
|
||||
var pinned = new HashSet<string> { hash };
|
||||
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatchesPinnedCert_rejects_wrong_hash()
|
||||
{
|
||||
var (cert, _) = GenerateTestCert();
|
||||
var pinned = new HashSet<string> { "0000000000000000000000000000000000000000000000000000000000000000" };
|
||||
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task PeekableStream_peeks_and_replays()
|
||||
{
|
||||
var data = "Hello, World!"u8.ToArray();
|
||||
using var ms = new MemoryStream(data);
|
||||
using var peekable = new PeekableStream(ms);
|
||||
|
||||
var peeked = await peekable.PeekAsync(1);
|
||||
peeked.Length.ShouldBe(1);
|
||||
peeked[0].ShouldBe((byte)'H');
|
||||
|
||||
var buf = new byte[data.Length];
|
||||
int total = 0;
|
||||
while (total < data.Length)
|
||||
{
|
||||
var read = await peekable.ReadAsync(buf.AsMemory(total));
|
||||
if (read == 0) break;
|
||||
total += read;
|
||||
}
|
||||
total.ShouldBe(data.Length);
|
||||
buf.ShouldBe(data);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRateLimiter_allows_within_limit()
|
||||
{
|
||||
using var limiter = new TlsRateLimiter(10);
|
||||
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
|
||||
for (int i = 0; i < 5; i++)
|
||||
await limiter.WaitAsync(cts.Token);
|
||||
}
|
||||
|
||||
// Public helper methods used by other test classes
|
||||
public static (string certPath, string keyPath) GenerateTestCertFiles()
|
||||
{
|
||||
var (cert, key) = GenerateTestCert();
|
||||
var certPath = Path.GetTempFileName();
|
||||
var keyPath = Path.GetTempFileName();
|
||||
File.WriteAllText(certPath, cert.ExportCertificatePem());
|
||||
File.WriteAllText(keyPath, key.ExportPkcs8PrivateKeyPem());
|
||||
return (certPath, keyPath);
|
||||
}
|
||||
|
||||
public static (X509Certificate2 cert, RSA key) GenerateTestCert()
|
||||
{
|
||||
var key = RSA.Create(2048);
|
||||
var req = new CertificateRequest("CN=localhost", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));
|
||||
var sanBuilder = new SubjectAlternativeNameBuilder();
|
||||
sanBuilder.AddIpAddress(IPAddress.Loopback);
|
||||
sanBuilder.AddDnsName("localhost");
|
||||
req.CertificateExtensions.Add(sanBuilder.Build());
|
||||
var cert = req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
|
||||
return (cert, key);
|
||||
}
|
||||
}
|
||||
225
tests/NATS.Server.Tests/TlsServerTests.cs
Normal file
225
tests/NATS.Server.Tests/TlsServerTests.cs
Normal file
@@ -0,0 +1,225 @@
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class TlsServerTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _port;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly string _certPath;
|
||||
private readonly string _keyPath;
|
||||
|
||||
public TlsServerTests()
|
||||
{
|
||||
_port = GetFreePort();
|
||||
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions
|
||||
{
|
||||
Port = _port,
|
||||
TlsCert = _certPath,
|
||||
TlsKey = _keyPath,
|
||||
},
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
File.Delete(_certPath);
|
||||
File.Delete(_keyPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Tls_client_connects_and_receives_info()
|
||||
{
|
||||
using var tcp = new TcpClient();
|
||||
await tcp.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var netStream = tcp.GetStream();
|
||||
|
||||
// Read INFO (sent before TLS upgrade in Mode 2)
|
||||
var buf = new byte[4096];
|
||||
var read = await netStream.ReadAsync(buf);
|
||||
var info = Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
info.ShouldContain("\"tls_required\":true");
|
||||
|
||||
// Upgrade to TLS
|
||||
using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await sslStream.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
// Send CONNECT + PING over TLS
|
||||
await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await sslStream.FlushAsync();
|
||||
|
||||
// Read PONG
|
||||
var pongBuf = new byte[256];
|
||||
read = await sslStream.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Tls_pubsub_works_over_encrypted_connection()
|
||||
{
|
||||
using var tcp1 = new TcpClient();
|
||||
await tcp1.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var ssl1 = await UpgradeToTlsAsync(tcp1);
|
||||
|
||||
using var tcp2 = new TcpClient();
|
||||
await tcp2.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var ssl2 = await UpgradeToTlsAsync(tcp2);
|
||||
|
||||
// Sub on client 1
|
||||
await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
|
||||
await ssl1.FlushAsync();
|
||||
|
||||
// Wait for PONG to confirm subscription is registered
|
||||
var pongBuf = new byte[256];
|
||||
var pongRead = await ssl1.ReadAsync(pongBuf);
|
||||
var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead);
|
||||
pongStr.ShouldContain("PONG");
|
||||
|
||||
// Pub on client 2
|
||||
await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
|
||||
await ssl2.FlushAsync();
|
||||
|
||||
// Client 1 should receive MSG (may arrive across multiple TLS records)
|
||||
var msg = await ReadUntilAsync(ssl1, "hello");
|
||||
msg.ShouldContain("MSG test 1 5");
|
||||
msg.ShouldContain("hello");
|
||||
}
|
||||
|
||||
private static async Task<string> ReadUntilAsync(Stream stream, string expected, int timeoutMs = 5000)
|
||||
{
|
||||
using var cts = new CancellationTokenSource(timeoutMs);
|
||||
var sb = new StringBuilder();
|
||||
var buf = new byte[4096];
|
||||
while (!sb.ToString().Contains(expected))
|
||||
{
|
||||
var n = await stream.ReadAsync(buf, cts.Token);
|
||||
if (n == 0) break;
|
||||
sb.Append(Encoding.ASCII.GetString(buf, 0, n));
|
||||
}
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
|
||||
{
|
||||
var netStream = tcp.GetStream();
|
||||
var buf = new byte[4096];
|
||||
_ = await netStream.ReadAsync(buf); // Read INFO (discard)
|
||||
|
||||
var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await ssl.AuthenticateAsClientAsync("localhost");
|
||||
return ssl;
|
||||
}
|
||||
|
||||
private static int GetFreePort()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||
}
|
||||
}
|
||||
|
||||
public class TlsMixedModeTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _port;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly string _certPath;
|
||||
private readonly string _keyPath;
|
||||
|
||||
public TlsMixedModeTests()
|
||||
{
|
||||
_port = GetFreePort();
|
||||
(_certPath, _keyPath) = TlsHelperTests.GenerateTestCertFiles();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions
|
||||
{
|
||||
Port = _port,
|
||||
TlsCert = _certPath,
|
||||
TlsKey = _keyPath,
|
||||
AllowNonTls = true,
|
||||
},
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
File.Delete(_certPath);
|
||||
File.Delete(_keyPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Mixed_mode_accepts_plain_client()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
|
||||
using var stream = new NetworkStream(sock);
|
||||
|
||||
var buf = new byte[4096];
|
||||
var read = await stream.ReadAsync(buf);
|
||||
var info = Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldContain("\"tls_available\":true");
|
||||
|
||||
await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var pongBuf = new byte[64];
|
||||
read = await stream.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Mixed_mode_accepts_tls_client()
|
||||
{
|
||||
using var tcp = new TcpClient();
|
||||
await tcp.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var netStream = tcp.GetStream();
|
||||
|
||||
var buf = new byte[4096];
|
||||
_ = await netStream.ReadAsync(buf); // Read INFO
|
||||
|
||||
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await ssl.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await ssl.FlushAsync();
|
||||
|
||||
var pongBuf = new byte[64];
|
||||
var read = await ssl.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
|
||||
private static int GetFreePort()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user