diff --git a/differences.md b/differences.md index 08eeef1..a0c5fa0 100644 --- a/differences.md +++ b/differences.md @@ -68,7 +68,7 @@ | JETSTREAM (internal) | Y | N | | | ACCOUNT (internal) | Y | Y | Lazy per-account InternalClient with import/export subscription support | | WebSocket clients | Y | Y | Custom frame parser, permessage-deflate compression, origin checking, cookie auth | -| MQTT clients | Y | N | | +| MQTT clients | Y | Partial | JWT connection-type constants + config parsing; no MQTT transport yet | ### Client Features | Feature | Go | .NET | Notes | @@ -204,7 +204,7 @@ Go implements a sophisticated slow consumer detection system: | Username/password | Y | Y | | | Token | Y | Y | | | NKeys (Ed25519) | Y | Y | .NET has framework but integration is basic | -| JWT validation | Y | Y | `NatsJwt` decode/verify, `JwtAuthenticator` with account resolution + revocation | +| JWT validation | Y | Y | `NatsJwt` decode/verify, `JwtAuthenticator` with account resolution + revocation + `allowed_connection_types` enforcement | | Bcrypt password hashing | Y | Y | .NET supports bcrypt (`$2*` prefix) with constant-time fallback | | TLS certificate mapping | Y | Y | X500DistinguishedName with full DN match and CN fallback | | Custom auth interface | Y | N | | @@ -268,7 +268,7 @@ Go implements a sophisticated slow consumer detection system: - ~~Tags/metadata~~ — `Tags` dictionary implemented in `NatsOptions` - ~~OCSP configuration~~ — `OcspConfig` with 4 modes (Auto/Always/Must/Never), peer verification, and stapling - ~~WebSocket options~~ — `WebSocketOptions` with port, compression, origin checking, cookie auth, custom headers -- MQTT options +- ~~MQTT options~~ — `mqtt {}` config block parsed with all Go `MQTTOpts` fields; no listener yet - ~~Operator mode / account resolver~~ — `JwtAuthenticator` + `IAccountResolver` + `MemAccountResolver` with trusted keys --- @@ -317,7 +317,7 @@ Go implements a sophisticated slow consumer detection system: | Subscription detail mode | Y | N | | | TLS peer certificate info | Y | N | | | JWT/IssuerKey/Tags fields | Y | N | | -| MQTT client ID filtering | Y | N | | +| MQTT client ID filtering | Y | Y | `mqtt_client` query param filters open and closed connections | | Proxy info | Y | N | | --- diff --git a/src/NATS.Server/Auth/IAuthenticator.cs b/src/NATS.Server/Auth/IAuthenticator.cs index 3783c88..abb8db3 100644 --- a/src/NATS.Server/Auth/IAuthenticator.cs +++ b/src/NATS.Server/Auth/IAuthenticator.cs @@ -1,4 +1,5 @@ using System.Security.Cryptography.X509Certificates; +using NATS.Server.Auth.Jwt; using NATS.Server.Protocol; namespace NATS.Server.Auth; @@ -13,4 +14,11 @@ public sealed class ClientAuthContext public required ClientOptions Opts { get; init; } public required byte[] Nonce { get; init; } public X509Certificate2? ClientCertificate { get; init; } + + /// + /// The type of connection (e.g., "STANDARD", "WEBSOCKET", "MQTT", "LEAFNODE"). + /// Used by JWT authenticator to enforce allowed_connection_types claims. + /// Defaults to "STANDARD" for regular NATS client connections. + /// + public string ConnectionType { get; init; } = JwtConnectionTypes.Standard; } diff --git a/src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs b/src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs new file mode 100644 index 0000000..59d2418 --- /dev/null +++ b/src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs @@ -0,0 +1,34 @@ +namespace NATS.Server.Auth.Jwt; + +internal static class JwtConnectionTypes +{ + public const string Standard = "STANDARD"; + public const string Websocket = "WEBSOCKET"; + public const string Leafnode = "LEAFNODE"; + public const string LeafnodeWs = "LEAFNODE_WS"; + public const string Mqtt = "MQTT"; + public const string MqttWs = "MQTT_WS"; + public const string InProcess = "INPROCESS"; + + private static readonly HashSet Known = + [ + Standard, Websocket, Leafnode, LeafnodeWs, Mqtt, MqttWs, InProcess, + ]; + + public static (HashSet Valid, bool HasUnknown) Convert(IEnumerable? values) + { + var valid = new HashSet(StringComparer.Ordinal); + var hasUnknown = false; + if (values is null) return (valid, false); + + foreach (var raw in values) + { + var up = (raw ?? string.Empty).Trim().ToUpperInvariant(); + if (up.Length == 0) continue; + if (Known.Contains(up)) valid.Add(up); + else hasUnknown = true; + } + + return (valid, hasUnknown); + } +} diff --git a/src/NATS.Server/Auth/JwtAuthenticator.cs b/src/NATS.Server/Auth/JwtAuthenticator.cs index f28a155..126fb83 100644 --- a/src/NATS.Server/Auth/JwtAuthenticator.cs +++ b/src/NATS.Server/Auth/JwtAuthenticator.cs @@ -95,6 +95,24 @@ public sealed class JwtAuthenticator : IAuthenticator } } + // 7b. Check allowed connection types + var (allowedTypes, hasUnknown) = JwtConnectionTypes.Convert(userClaims.Nats?.AllowedConnectionTypes); + + if (allowedTypes.Count == 0) + { + if (hasUnknown) + return null; // unknown-only list should reject + } + else + { + var connType = string.IsNullOrWhiteSpace(context.ConnectionType) + ? JwtConnectionTypes.Standard + : context.ConnectionType.ToUpperInvariant(); + + if (!allowedTypes.Contains(connType)) + return null; + } + // 8. Build permissions from JWT claims Permissions? permissions = null; var nats = userClaims.Nats; diff --git a/src/NATS.Server/Configuration/ConfigProcessor.cs b/src/NATS.Server/Configuration/ConfigProcessor.cs index 88b36ae..ae593b1 100644 --- a/src/NATS.Server/Configuration/ConfigProcessor.cs +++ b/src/NATS.Server/Configuration/ConfigProcessor.cs @@ -245,6 +245,12 @@ public static class ConfigProcessor opts.ReconnectErrorReports = ToInt(value); break; + // MQTT + case "mqtt": + if (value is Dictionary mqttDict) + ParseMqtt(mqttDict, opts, errors); + break; + // Unknown keys silently ignored (cluster, jetstream, gateway, leafnode, etc.) default: break; @@ -620,6 +626,145 @@ public static class ConfigProcessor opts.Tags = tags; } + // ─── MQTT parsing ──────────────────────────────────────────────── + // Reference: Go server/opts.go parseMQTT (lines ~5443-5541) + + private static void ParseMqtt(Dictionary dict, NatsOptions opts, List errors) + { + var mqtt = opts.Mqtt ?? new MqttOptions(); + + foreach (var (key, value) in dict) + { + switch (key.ToLowerInvariant()) + { + case "listen": + var (host, port) = ParseHostPort(value); + if (host is not null) mqtt.Host = host; + if (port is not null) mqtt.Port = port.Value; + break; + case "port": + mqtt.Port = ToInt(value); + break; + case "host" or "net": + mqtt.Host = ToString(value); + break; + case "no_auth_user": + mqtt.NoAuthUser = ToString(value); + break; + case "tls": + if (value is Dictionary tlsDict) + ParseMqttTls(tlsDict, mqtt, errors); + break; + case "authorization" or "authentication": + if (value is Dictionary authDict) + ParseMqttAuth(authDict, mqtt, errors); + break; + case "ack_wait" or "ackwait": + mqtt.AckWait = ParseDuration(value); + break; + case "js_api_timeout" or "api_timeout": + mqtt.JsApiTimeout = ParseDuration(value); + break; + case "max_ack_pending" or "max_pending" or "max_inflight": + var pending = ToInt(value); + if (pending < 0 || pending > 0xFFFF) + errors.Add($"mqtt max_ack_pending invalid value {pending}, should be in [0..{0xFFFF}] range"); + else + mqtt.MaxAckPending = (ushort)pending; + break; + case "js_domain": + mqtt.JsDomain = ToString(value); + break; + case "stream_replicas": + mqtt.StreamReplicas = ToInt(value); + break; + case "consumer_replicas": + mqtt.ConsumerReplicas = ToInt(value); + break; + case "consumer_memory_storage": + mqtt.ConsumerMemoryStorage = ToBool(value); + break; + case "consumer_inactive_threshold" or "consumer_auto_cleanup": + mqtt.ConsumerInactiveThreshold = ParseDuration(value); + break; + default: + break; + } + } + + opts.Mqtt = mqtt; + } + + private static void ParseMqttAuth(Dictionary dict, MqttOptions mqtt, List errors) + { + foreach (var (key, value) in dict) + { + switch (key.ToLowerInvariant()) + { + case "user" or "username": + mqtt.Username = ToString(value); + break; + case "pass" or "password": + mqtt.Password = ToString(value); + break; + case "token": + mqtt.Token = ToString(value); + break; + case "timeout": + mqtt.AuthTimeout = ToDouble(value); + break; + default: + break; + } + } + } + + private static void ParseMqttTls(Dictionary dict, MqttOptions mqtt, List errors) + { + foreach (var (key, value) in dict) + { + switch (key.ToLowerInvariant()) + { + case "cert_file": + mqtt.TlsCert = ToString(value); + break; + case "key_file": + mqtt.TlsKey = ToString(value); + break; + case "ca_file": + mqtt.TlsCaCert = ToString(value); + break; + case "verify": + mqtt.TlsVerify = ToBool(value); + break; + case "verify_and_map": + var map = ToBool(value); + mqtt.TlsMap = map; + if (map) mqtt.TlsVerify = true; + break; + case "timeout": + mqtt.TlsTimeout = ToDouble(value); + break; + case "pinned_certs": + if (value is List pinnedList) + { + var certs = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var item in pinnedList) + { + if (item is string s) + certs.Add(s.ToLowerInvariant()); + } + + mqtt.TlsPinnedCerts = certs; + } + + break; + default: + break; + } + } + } + // ─── Type conversion helpers ─────────────────────────────────── private static int ToInt(object? value) => value switch @@ -653,6 +798,15 @@ public static class ConfigProcessor _ => throw new FormatException($"Cannot convert {value?.GetType().Name ?? "null"} to string"), }; + private static double ToDouble(object? value) => value switch + { + double d => d, + long l => l, + int i => i, + string s when double.TryParse(s, NumberStyles.Float, CultureInfo.InvariantCulture, out var d) => d, + _ => throw new FormatException($"Cannot convert {value?.GetType().Name ?? "null"} to double"), + }; + private static IReadOnlyList ToStringList(object? value) { if (value is List list) diff --git a/src/NATS.Server/Monitoring/ClosedClient.cs b/src/NATS.Server/Monitoring/ClosedClient.cs index 0710d19..277fba8 100644 --- a/src/NATS.Server/Monitoring/ClosedClient.cs +++ b/src/NATS.Server/Monitoring/ClosedClient.cs @@ -22,4 +22,5 @@ public sealed record ClosedClient public TimeSpan Rtt { get; init; } public string TlsVersion { get; init; } = ""; public string TlsCipherSuite { get; init; } = ""; + public string MqttClient { get; init; } = ""; } diff --git a/src/NATS.Server/Monitoring/Connz.cs b/src/NATS.Server/Monitoring/Connz.cs index aae62ed..cea93ca 100644 --- a/src/NATS.Server/Monitoring/Connz.cs +++ b/src/NATS.Server/Monitoring/Connz.cs @@ -204,6 +204,8 @@ public sealed class ConnzOptions public string FilterSubject { get; set; } = ""; + public string MqttClient { get; set; } = ""; + public int Offset { get; set; } public int Limit { get; set; } = 1024; diff --git a/src/NATS.Server/Monitoring/ConnzHandler.cs b/src/NATS.Server/Monitoring/ConnzHandler.cs index 8ecf512..96c96de 100644 --- a/src/NATS.Server/Monitoring/ConnzHandler.cs +++ b/src/NATS.Server/Monitoring/ConnzHandler.cs @@ -28,6 +28,10 @@ public sealed class ConnzHandler(NatsServer server) connInfos.AddRange(server.GetClosedClients().Select(c => BuildClosedConnInfo(c, now, opts))); } + // Filter by MQTT client ID + if (!string.IsNullOrEmpty(opts.MqttClient)) + connInfos = connInfos.Where(c => c.MqttClient == opts.MqttClient).ToList(); + // Validate sort options that require closed state if (opts.Sort is SortOpt.ByStop or SortOpt.ByReason && opts.State == ConnState.Open) opts.Sort = SortOpt.ByCid; // Fallback @@ -142,6 +146,7 @@ public sealed class ConnzHandler(NatsServer server) Rtt = FormatRtt(closed.Rtt), TlsVersion = closed.TlsVersion, TlsCipherSuite = closed.TlsCipherSuite, + MqttClient = closed.MqttClient, }; } @@ -197,6 +202,9 @@ public sealed class ConnzHandler(NatsServer server) if (q.TryGetValue("limit", out var limit) && int.TryParse(limit, out var l)) opts.Limit = l; + if (q.TryGetValue("mqtt_client", out var mqttClient)) + opts.MqttClient = mqttClient.ToString(); + return opts; } diff --git a/src/NATS.Server/Monitoring/Varz.cs b/src/NATS.Server/Monitoring/Varz.cs index 3e85374..dda7234 100644 --- a/src/NATS.Server/Monitoring/Varz.cs +++ b/src/NATS.Server/Monitoring/Varz.cs @@ -355,8 +355,29 @@ public sealed class MqttOptsVarz [JsonPropertyName("port")] public int Port { get; set; } + [JsonPropertyName("no_auth_user")] + public string NoAuthUser { get; set; } = ""; + + [JsonPropertyName("auth_timeout")] + public double AuthTimeout { get; set; } + + [JsonPropertyName("tls_map")] + public bool TlsMap { get; set; } + [JsonPropertyName("tls_timeout")] public double TlsTimeout { get; set; } + + [JsonPropertyName("tls_pinned_certs")] + public string[] TlsPinnedCerts { get; set; } = []; + + [JsonPropertyName("js_domain")] + public string JsDomain { get; set; } = ""; + + [JsonPropertyName("ack_wait")] + public long AckWait { get; set; } + + [JsonPropertyName("max_ack_pending")] + public ushort MaxAckPending { get; set; } } /// diff --git a/src/NATS.Server/Monitoring/VarzHandler.cs b/src/NATS.Server/Monitoring/VarzHandler.cs index 3bdbe6d..290139c 100644 --- a/src/NATS.Server/Monitoring/VarzHandler.cs +++ b/src/NATS.Server/Monitoring/VarzHandler.cs @@ -121,6 +121,7 @@ public sealed class VarzHandler : IDisposable Subscriptions = _server.SubList.Count, ConfigLoadTime = _server.StartTime, HttpReqStats = stats.HttpReqStats.ToDictionary(kv => kv.Key, kv => (ulong)kv.Value), + Mqtt = BuildMqttVarz(), }; } finally @@ -134,6 +135,27 @@ public sealed class VarzHandler : IDisposable _varzMu.Dispose(); } + private MqttOptsVarz BuildMqttVarz() + { + var mqtt = _options.Mqtt; + if (mqtt is null) + return new MqttOptsVarz(); + + return new MqttOptsVarz + { + Host = mqtt.Host, + Port = mqtt.Port, + NoAuthUser = mqtt.NoAuthUser ?? "", + AuthTimeout = mqtt.AuthTimeout, + TlsMap = mqtt.TlsMap, + TlsTimeout = mqtt.TlsTimeout, + TlsPinnedCerts = mqtt.TlsPinnedCerts?.ToArray() ?? [], + JsDomain = mqtt.JsDomain ?? "", + AckWait = (long)mqtt.AckWait.TotalNanoseconds, + MaxAckPending = mqtt.MaxAckPending, + }; + } + /// /// Formats a TimeSpan as a human-readable uptime string matching Go server format. /// diff --git a/src/NATS.Server/MqttOptions.cs b/src/NATS.Server/MqttOptions.cs new file mode 100644 index 0000000..c47e15e --- /dev/null +++ b/src/NATS.Server/MqttOptions.cs @@ -0,0 +1,43 @@ +namespace NATS.Server; + +/// +/// MQTT protocol configuration options. +/// Corresponds to Go server/opts.go MQTTOpts struct. +/// Config is parsed and stored but no MQTT listener is started yet. +/// +public sealed class MqttOptions +{ + // Network + public string Host { get; set; } = ""; + public int Port { get; set; } + + // Auth override (MQTT-specific, separate from global auth) + public string? NoAuthUser { get; set; } + public string? Username { get; set; } + public string? Password { get; set; } + public string? Token { get; set; } + public double AuthTimeout { get; set; } + + // TLS + public string? TlsCert { get; set; } + public string? TlsKey { get; set; } + public string? TlsCaCert { get; set; } + public bool TlsVerify { get; set; } + public double TlsTimeout { get; set; } = 2.0; + public bool TlsMap { get; set; } + public HashSet? TlsPinnedCerts { get; set; } + + // JetStream integration + public string? JsDomain { get; set; } + public int StreamReplicas { get; set; } + public int ConsumerReplicas { get; set; } + public bool ConsumerMemoryStorage { get; set; } + public TimeSpan ConsumerInactiveThreshold { get; set; } + + // QoS + public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30); + public ushort MaxAckPending { get; set; } + public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5); + + public bool HasTls => TlsCert != null && TlsKey != null; +} diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 0f6a47a..b1a7c52 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -8,6 +8,7 @@ using System.Text.Json; using System.Threading.Channels; using Microsoft.Extensions.Logging; using NATS.Server.Auth; +using NATS.Server.Auth.Jwt; using NATS.Server.Protocol; using NATS.Server.Subscriptions; using NATS.Server.Tls; @@ -391,6 +392,7 @@ public sealed class NatsClient : INatsClient, IDisposable Opts = ClientOpts, Nonce = _nonce ?? [], ClientCertificate = TlsState?.PeerCert, + ConnectionType = JwtConnectionTypes.Standard, }; authResult = _authService.Authenticate(context); diff --git a/src/NATS.Server/NatsOptions.cs b/src/NATS.Server/NatsOptions.cs index 1e3820e..981f6f7 100644 --- a/src/NATS.Server/NatsOptions.cs +++ b/src/NATS.Server/NatsOptions.cs @@ -115,6 +115,9 @@ public sealed class NatsOptions // Subject mapping / transforms (source pattern -> destination template) public Dictionary? SubjectMappings { get; set; } + // MQTT configuration (parsed from config, no listener yet) + public MqttOptions? Mqtt { get; set; } + public bool HasTls => TlsCert != null && TlsKey != null; // WebSocket diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 01d1a16..4a773c7 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -1221,6 +1221,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable Rtt = client.Rtt, TlsVersion = client.TlsState?.TlsVersion ?? "", TlsCipherSuite = client.TlsState?.CipherSuite ?? "", + MqttClient = "", // populated when MQTT transport is implemented }); // Cap closed clients list diff --git a/tests/NATS.Server.Tests/ConfigProcessorTests.cs b/tests/NATS.Server.Tests/ConfigProcessorTests.cs index 0ee2f39..6f7a7e5 100644 --- a/tests/NATS.Server.Tests/ConfigProcessorTests.cs +++ b/tests/NATS.Server.Tests/ConfigProcessorTests.cs @@ -501,4 +501,112 @@ public class ConfigProcessorTests var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("tls.conf")); opts.HasTls.ShouldBeTrue(); } + + // ─── MQTT config ──────────────────────────────────────────── + + [Fact] + public void MqttConf_ListenHostAndPort() + { + var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("mqtt.conf")); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.Host.ShouldBe("10.0.0.1"); + opts.Mqtt.Port.ShouldBe(1883); + } + + [Fact] + public void MqttConf_NoAuthUser() + { + var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("mqtt.conf")); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.NoAuthUser.ShouldBe("mqtt_default"); + } + + [Fact] + public void MqttConf_Authorization() + { + var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("mqtt.conf")); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.Username.ShouldBe("mqtt_user"); + opts.Mqtt.Password.ShouldBe("mqtt_pass"); + opts.Mqtt.Token.ShouldBe("mqtt_token"); + opts.Mqtt.AuthTimeout.ShouldBe(3.0); + } + + [Fact] + public void MqttConf_Tls() + { + var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("mqtt.conf")); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.TlsCert.ShouldBe("/path/to/mqtt-cert.pem"); + opts.Mqtt.TlsKey.ShouldBe("/path/to/mqtt-key.pem"); + opts.Mqtt.TlsCaCert.ShouldBe("/path/to/mqtt-ca.pem"); + opts.Mqtt.TlsVerify.ShouldBeTrue(); + opts.Mqtt.TlsTimeout.ShouldBe(5.0); + opts.Mqtt.HasTls.ShouldBeTrue(); + } + + [Fact] + public void MqttConf_QosSettings() + { + var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("mqtt.conf")); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.AckWait.ShouldBe(TimeSpan.FromSeconds(60)); + opts.Mqtt.MaxAckPending.ShouldBe((ushort)2048); + opts.Mqtt.JsApiTimeout.ShouldBe(TimeSpan.FromSeconds(10)); + } + + [Fact] + public void MqttConf_JetStreamSettings() + { + var opts = ConfigProcessor.ProcessConfigFile(TestDataPath("mqtt.conf")); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.JsDomain.ShouldBe("mqtt-domain"); + opts.Mqtt.StreamReplicas.ShouldBe(3); + opts.Mqtt.ConsumerReplicas.ShouldBe(1); + opts.Mqtt.ConsumerMemoryStorage.ShouldBeTrue(); + opts.Mqtt.ConsumerInactiveThreshold.ShouldBe(TimeSpan.FromMinutes(5)); + } + + [Fact] + public void MqttConf_MaxAckPendingValidation_ReportsError() + { + var ex = Should.Throw(() => + ConfigProcessor.ProcessConfig(""" + mqtt { + max_ack_pending: 70000 + } + """)); + ex.Errors.ShouldContain(e => e.Contains("max_ack_pending")); + } + + [Fact] + public void MqttConf_Aliases() + { + // Test alias keys: "ackwait" (alias for "ack_wait"), "net" (alias for "host"), + // "max_inflight" (alias for "max_ack_pending"), "consumer_auto_cleanup" (alias) + var opts = ConfigProcessor.ProcessConfig(""" + mqtt { + net: "127.0.0.1" + port: 1884 + ackwait: "45s" + max_inflight: 500 + api_timeout: "8s" + consumer_auto_cleanup: "10m" + } + """); + opts.Mqtt.ShouldNotBeNull(); + opts.Mqtt!.Host.ShouldBe("127.0.0.1"); + opts.Mqtt.Port.ShouldBe(1884); + opts.Mqtt.AckWait.ShouldBe(TimeSpan.FromSeconds(45)); + opts.Mqtt.MaxAckPending.ShouldBe((ushort)500); + opts.Mqtt.JsApiTimeout.ShouldBe(TimeSpan.FromSeconds(8)); + opts.Mqtt.ConsumerInactiveThreshold.ShouldBe(TimeSpan.FromMinutes(10)); + } + + [Fact] + public void MqttConf_Absent_ReturnsNull() + { + var opts = ConfigProcessor.ProcessConfig("port: 4222"); + opts.Mqtt.ShouldBeNull(); + } } diff --git a/tests/NATS.Server.Tests/JwtAuthenticatorTests.cs b/tests/NATS.Server.Tests/JwtAuthenticatorTests.cs index 7cb0eaf..7e60f76 100644 --- a/tests/NATS.Server.Tests/JwtAuthenticatorTests.cs +++ b/tests/NATS.Server.Tests/JwtAuthenticatorTests.cs @@ -588,4 +588,279 @@ public class JwtAuthenticatorTests auth.Authenticate(ctx).ShouldBeNull(); } + + // ========================================================================= + // allowed_connection_types tests + // ========================================================================= + + [Fact] + public async Task Allowed_connection_types_allows_standard_context() + { + var operatorKp = KeyPair.CreatePair(PrefixByte.Operator); + var accountKp = KeyPair.CreatePair(PrefixByte.Account); + var userKp = KeyPair.CreatePair(PrefixByte.User); + + var operatorPub = operatorKp.GetPublicKey(); + var accountPub = accountKp.GetPublicKey(); + var userPub = userKp.GetPublicKey(); + + var accountPayload = $$""" + { + "sub":"{{accountPub}}", + "iss":"{{operatorPub}}", + "iat":1700000000, + "nats":{"type":"account","version":2} + } + """; + var accountJwt = BuildSignedToken(accountPayload, operatorKp); + + var userPayload = $$""" + { + "sub":"{{userPub}}", + "iss":"{{accountPub}}", + "iat":1700000000, + "nats":{ + "type":"user","version":2, + "bearer_token":true, + "issuer_account":"{{accountPub}}", + "allowed_connection_types":["STANDARD"] + } + } + """; + var userJwt = BuildSignedToken(userPayload, accountKp); + + var resolver = new MemAccountResolver(); + await resolver.StoreAsync(accountPub, accountJwt); + + var auth = new JwtAuthenticator([operatorPub], resolver); + + var ctx = new ClientAuthContext + { + Opts = new ClientOptions { JWT = userJwt }, + Nonce = "nonce"u8.ToArray(), + ConnectionType = "STANDARD", + }; + + var result = auth.Authenticate(ctx); + + result.ShouldNotBeNull(); + result.Identity.ShouldBe(userPub); + } + + [Fact] + public async Task Allowed_connection_types_rejects_mqtt_only_for_standard_context() + { + var operatorKp = KeyPair.CreatePair(PrefixByte.Operator); + var accountKp = KeyPair.CreatePair(PrefixByte.Account); + var userKp = KeyPair.CreatePair(PrefixByte.User); + + var operatorPub = operatorKp.GetPublicKey(); + var accountPub = accountKp.GetPublicKey(); + var userPub = userKp.GetPublicKey(); + + var accountPayload = $$""" + { + "sub":"{{accountPub}}", + "iss":"{{operatorPub}}", + "iat":1700000000, + "nats":{"type":"account","version":2} + } + """; + var accountJwt = BuildSignedToken(accountPayload, operatorKp); + + // User JWT only allows MQTT connections + var userPayload = $$""" + { + "sub":"{{userPub}}", + "iss":"{{accountPub}}", + "iat":1700000000, + "nats":{ + "type":"user","version":2, + "bearer_token":true, + "issuer_account":"{{accountPub}}", + "allowed_connection_types":["MQTT"] + } + } + """; + var userJwt = BuildSignedToken(userPayload, accountKp); + + var resolver = new MemAccountResolver(); + await resolver.StoreAsync(accountPub, accountJwt); + + var auth = new JwtAuthenticator([operatorPub], resolver); + + var ctx = new ClientAuthContext + { + Opts = new ClientOptions { JWT = userJwt }, + Nonce = "nonce"u8.ToArray(), + ConnectionType = "STANDARD", + }; + + // Should reject: STANDARD is not in allowed_connection_types + auth.Authenticate(ctx).ShouldBeNull(); + } + + [Fact] + public async Task Allowed_connection_types_allows_known_even_with_unknown_values() + { + var operatorKp = KeyPair.CreatePair(PrefixByte.Operator); + var accountKp = KeyPair.CreatePair(PrefixByte.Account); + var userKp = KeyPair.CreatePair(PrefixByte.User); + + var operatorPub = operatorKp.GetPublicKey(); + var accountPub = accountKp.GetPublicKey(); + var userPub = userKp.GetPublicKey(); + + var accountPayload = $$""" + { + "sub":"{{accountPub}}", + "iss":"{{operatorPub}}", + "iat":1700000000, + "nats":{"type":"account","version":2} + } + """; + var accountJwt = BuildSignedToken(accountPayload, operatorKp); + + // User JWT allows STANDARD and an unknown type + var userPayload = $$""" + { + "sub":"{{userPub}}", + "iss":"{{accountPub}}", + "iat":1700000000, + "nats":{ + "type":"user","version":2, + "bearer_token":true, + "issuer_account":"{{accountPub}}", + "allowed_connection_types":["STANDARD","SOME_NEW_TYPE"] + } + } + """; + var userJwt = BuildSignedToken(userPayload, accountKp); + + var resolver = new MemAccountResolver(); + await resolver.StoreAsync(accountPub, accountJwt); + + var auth = new JwtAuthenticator([operatorPub], resolver); + + var ctx = new ClientAuthContext + { + Opts = new ClientOptions { JWT = userJwt }, + Nonce = "nonce"u8.ToArray(), + ConnectionType = "STANDARD", + }; + + var result = auth.Authenticate(ctx); + + result.ShouldNotBeNull(); + result.Identity.ShouldBe(userPub); + } + + [Fact] + public async Task Allowed_connection_types_rejects_when_only_unknown_values_present() + { + var operatorKp = KeyPair.CreatePair(PrefixByte.Operator); + var accountKp = KeyPair.CreatePair(PrefixByte.Account); + var userKp = KeyPair.CreatePair(PrefixByte.User); + + var operatorPub = operatorKp.GetPublicKey(); + var accountPub = accountKp.GetPublicKey(); + var userPub = userKp.GetPublicKey(); + + var accountPayload = $$""" + { + "sub":"{{accountPub}}", + "iss":"{{operatorPub}}", + "iat":1700000000, + "nats":{"type":"account","version":2} + } + """; + var accountJwt = BuildSignedToken(accountPayload, operatorKp); + + // User JWT only allows an unknown connection type + var userPayload = $$""" + { + "sub":"{{userPub}}", + "iss":"{{accountPub}}", + "iat":1700000000, + "nats":{ + "type":"user","version":2, + "bearer_token":true, + "issuer_account":"{{accountPub}}", + "allowed_connection_types":["SOME_NEW_TYPE"] + } + } + """; + var userJwt = BuildSignedToken(userPayload, accountKp); + + var resolver = new MemAccountResolver(); + await resolver.StoreAsync(accountPub, accountJwt); + + var auth = new JwtAuthenticator([operatorPub], resolver); + + var ctx = new ClientAuthContext + { + Opts = new ClientOptions { JWT = userJwt }, + Nonce = "nonce"u8.ToArray(), + ConnectionType = "STANDARD", + }; + + // Should reject: STANDARD is not in allowed_connection_types + auth.Authenticate(ctx).ShouldBeNull(); + } + + [Fact] + public async Task Allowed_connection_types_is_case_insensitive_for_input_values() + { + var operatorKp = KeyPair.CreatePair(PrefixByte.Operator); + var accountKp = KeyPair.CreatePair(PrefixByte.Account); + var userKp = KeyPair.CreatePair(PrefixByte.User); + + var operatorPub = operatorKp.GetPublicKey(); + var accountPub = accountKp.GetPublicKey(); + var userPub = userKp.GetPublicKey(); + + var accountPayload = $$""" + { + "sub":"{{accountPub}}", + "iss":"{{operatorPub}}", + "iat":1700000000, + "nats":{"type":"account","version":2} + } + """; + var accountJwt = BuildSignedToken(accountPayload, operatorKp); + + // User JWT allows "standard" (lowercase) + var userPayload = $$""" + { + "sub":"{{userPub}}", + "iss":"{{accountPub}}", + "iat":1700000000, + "nats":{ + "type":"user","version":2, + "bearer_token":true, + "issuer_account":"{{accountPub}}", + "allowed_connection_types":["standard"] + } + } + """; + var userJwt = BuildSignedToken(userPayload, accountKp); + + var resolver = new MemAccountResolver(); + await resolver.StoreAsync(accountPub, accountJwt); + + var auth = new JwtAuthenticator([operatorPub], resolver); + + var ctx = new ClientAuthContext + { + Opts = new ClientOptions { JWT = userJwt }, + Nonce = "nonce"u8.ToArray(), + ConnectionType = "STANDARD", + }; + + // Should allow: case-insensitive match of "standard" == "STANDARD" + var result = auth.Authenticate(ctx); + + result.ShouldNotBeNull(); + result.Identity.ShouldBe(userPub); + } } diff --git a/tests/NATS.Server.Tests/MonitorTests.cs b/tests/NATS.Server.Tests/MonitorTests.cs index e89a0db..bb1356a 100644 --- a/tests/NATS.Server.Tests/MonitorTests.cs +++ b/tests/NATS.Server.Tests/MonitorTests.cs @@ -203,6 +203,51 @@ public class MonitorTests : IAsyncLifetime closed.Reason.ShouldNotBeNullOrEmpty(); } + [Fact] + public async Task Connz_filters_by_mqtt_client_for_open_connections() + { + // Connect a regular NATS client (no MQTT ID) + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort)); + using var stream = new NetworkStream(sock); + var buf = new byte[4096]; + _ = await stream.ReadAsync(buf); + await stream.WriteAsync("CONNECT {}\r\n"u8.ToArray()); + await Task.Delay(200); + + // Query for an MQTT client ID that no connection has + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?mqtt_client=some-id"); + response.StatusCode.ShouldBe(HttpStatusCode.OK); + + var connz = await response.Content.ReadFromJsonAsync(); + connz.ShouldNotBeNull(); + connz.NumConns.ShouldBe(0); + } + + [Fact] + public async Task Connz_filters_by_mqtt_client_for_closed_connections() + { + // Connect then disconnect a client so it appears in closed list + var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _natsPort)); + using var stream = new NetworkStream(sock); + var buf = new byte[4096]; + _ = await stream.ReadAsync(buf); + await stream.WriteAsync("CONNECT {}\r\n"u8.ToArray()); + await Task.Delay(200); + sock.Shutdown(SocketShutdown.Both); + sock.Dispose(); + await Task.Delay(500); + + // Query closed connections with an MQTT client ID that no connection has + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/connz?state=closed&mqtt_client=missing-id"); + response.StatusCode.ShouldBe(HttpStatusCode.OK); + + var connz = await response.Content.ReadFromJsonAsync(); + connz.ShouldNotBeNull(); + connz.NumConns.ShouldBe(0); + } + [Fact] public async Task Connz_sort_by_stop_requires_closed_state() { @@ -226,6 +271,23 @@ public class MonitorTests : IAsyncLifetime response.StatusCode.ShouldBe(HttpStatusCode.OK); } + [Fact] + public async Task Varz_includes_mqtt_section() + { + var response = await _http.GetAsync($"http://127.0.0.1:{_monitorPort}/varz"); + response.StatusCode.ShouldBe(HttpStatusCode.OK); + + var varz = await response.Content.ReadFromJsonAsync(); + varz.ShouldNotBeNull(); + varz.Mqtt.ShouldNotBeNull(); + varz.Mqtt.Host.ShouldBe(""); + varz.Mqtt.Port.ShouldBe(0); + varz.Mqtt.NoAuthUser.ShouldBe(""); + varz.Mqtt.JsDomain.ShouldBe(""); + varz.Mqtt.AckWait.ShouldBe(0L); + varz.Mqtt.MaxAckPending.ShouldBe((ushort)0); + } + private static int GetFreePort() { using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); diff --git a/tests/NATS.Server.Tests/TestData/mqtt.conf b/tests/NATS.Server.Tests/TestData/mqtt.conf new file mode 100644 index 0000000..e0692fa --- /dev/null +++ b/tests/NATS.Server.Tests/TestData/mqtt.conf @@ -0,0 +1,28 @@ +mqtt { + listen: "10.0.0.1:1883" + no_auth_user: "mqtt_default" + + authorization { + user: "mqtt_user" + pass: "mqtt_pass" + token: "mqtt_token" + timeout: 3.0 + } + + tls { + cert_file: "/path/to/mqtt-cert.pem" + key_file: "/path/to/mqtt-key.pem" + ca_file: "/path/to/mqtt-ca.pem" + verify: true + timeout: 5.0 + } + + ack_wait: "60s" + max_ack_pending: 2048 + js_domain: "mqtt-domain" + js_api_timeout: "10s" + stream_replicas: 3 + consumer_replicas: 1 + consumer_memory_storage: true + consumer_inactive_threshold: "5m" +}