diff --git a/src/NATS.Server/Auth/IAuthenticator.cs b/src/NATS.Server/Auth/IAuthenticator.cs index fb28f0c..abb8db3 100644 --- a/src/NATS.Server/Auth/IAuthenticator.cs +++ b/src/NATS.Server/Auth/IAuthenticator.cs @@ -1,4 +1,5 @@ using System.Security.Cryptography.X509Certificates; +using NATS.Server.Auth.Jwt; using NATS.Server.Protocol; namespace NATS.Server.Auth; @@ -19,5 +20,5 @@ public sealed class ClientAuthContext /// Used by JWT authenticator to enforce allowed_connection_types claims. /// Defaults to "STANDARD" for regular NATS client connections. /// - public string ConnectionType { get; init; } = "STANDARD"; + public string ConnectionType { get; init; } = JwtConnectionTypes.Standard; } diff --git a/src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs b/src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs new file mode 100644 index 0000000..59d2418 --- /dev/null +++ b/src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs @@ -0,0 +1,34 @@ +namespace NATS.Server.Auth.Jwt; + +internal static class JwtConnectionTypes +{ + public const string Standard = "STANDARD"; + public const string Websocket = "WEBSOCKET"; + public const string Leafnode = "LEAFNODE"; + public const string LeafnodeWs = "LEAFNODE_WS"; + public const string Mqtt = "MQTT"; + public const string MqttWs = "MQTT_WS"; + public const string InProcess = "INPROCESS"; + + private static readonly HashSet Known = + [ + Standard, Websocket, Leafnode, LeafnodeWs, Mqtt, MqttWs, InProcess, + ]; + + public static (HashSet Valid, bool HasUnknown) Convert(IEnumerable? values) + { + var valid = new HashSet(StringComparer.Ordinal); + var hasUnknown = false; + if (values is null) return (valid, false); + + foreach (var raw in values) + { + var up = (raw ?? string.Empty).Trim().ToUpperInvariant(); + if (up.Length == 0) continue; + if (Known.Contains(up)) valid.Add(up); + else hasUnknown = true; + } + + return (valid, hasUnknown); + } +} diff --git a/src/NATS.Server/Auth/JwtAuthenticator.cs b/src/NATS.Server/Auth/JwtAuthenticator.cs index f28a155..126fb83 100644 --- a/src/NATS.Server/Auth/JwtAuthenticator.cs +++ b/src/NATS.Server/Auth/JwtAuthenticator.cs @@ -95,6 +95,24 @@ public sealed class JwtAuthenticator : IAuthenticator } } + // 7b. Check allowed connection types + var (allowedTypes, hasUnknown) = JwtConnectionTypes.Convert(userClaims.Nats?.AllowedConnectionTypes); + + if (allowedTypes.Count == 0) + { + if (hasUnknown) + return null; // unknown-only list should reject + } + else + { + var connType = string.IsNullOrWhiteSpace(context.ConnectionType) + ? JwtConnectionTypes.Standard + : context.ConnectionType.ToUpperInvariant(); + + if (!allowedTypes.Contains(connType)) + return null; + } + // 8. Build permissions from JWT claims Permissions? permissions = null; var nats = userClaims.Nats; diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 620f275..5c35cf5 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -8,6 +8,7 @@ using System.Text.Json; using System.Threading.Channels; using Microsoft.Extensions.Logging; using NATS.Server.Auth; +using NATS.Server.Auth.Jwt; using NATS.Server.Protocol; using NATS.Server.Subscriptions; using NATS.Server.Tls; @@ -388,6 +389,7 @@ public sealed class NatsClient : IDisposable Opts = ClientOpts, Nonce = _nonce ?? [], ClientCertificate = TlsState?.PeerCert, + ConnectionType = JwtConnectionTypes.Standard, }; authResult = _authService.Authenticate(context);