Merge branch 'feature/mqtt-connection-type'
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using NATS.Server.Auth.Jwt;
|
||||
using NATS.Server.Protocol;
|
||||
|
||||
namespace NATS.Server.Auth;
|
||||
@@ -13,4 +14,11 @@ public sealed class ClientAuthContext
|
||||
public required ClientOptions Opts { get; init; }
|
||||
public required byte[] Nonce { get; init; }
|
||||
public X509Certificate2? ClientCertificate { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// The type of connection (e.g., "STANDARD", "WEBSOCKET", "MQTT", "LEAFNODE").
|
||||
/// Used by JWT authenticator to enforce allowed_connection_types claims.
|
||||
/// Defaults to "STANDARD" for regular NATS client connections.
|
||||
/// </summary>
|
||||
public string ConnectionType { get; init; } = JwtConnectionTypes.Standard;
|
||||
}
|
||||
|
||||
34
src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs
Normal file
34
src/NATS.Server/Auth/Jwt/JwtConnectionTypes.cs
Normal file
@@ -0,0 +1,34 @@
|
||||
namespace NATS.Server.Auth.Jwt;
|
||||
|
||||
internal static class JwtConnectionTypes
|
||||
{
|
||||
public const string Standard = "STANDARD";
|
||||
public const string Websocket = "WEBSOCKET";
|
||||
public const string Leafnode = "LEAFNODE";
|
||||
public const string LeafnodeWs = "LEAFNODE_WS";
|
||||
public const string Mqtt = "MQTT";
|
||||
public const string MqttWs = "MQTT_WS";
|
||||
public const string InProcess = "INPROCESS";
|
||||
|
||||
private static readonly HashSet<string> Known =
|
||||
[
|
||||
Standard, Websocket, Leafnode, LeafnodeWs, Mqtt, MqttWs, InProcess,
|
||||
];
|
||||
|
||||
public static (HashSet<string> Valid, bool HasUnknown) Convert(IEnumerable<string>? values)
|
||||
{
|
||||
var valid = new HashSet<string>(StringComparer.Ordinal);
|
||||
var hasUnknown = false;
|
||||
if (values is null) return (valid, false);
|
||||
|
||||
foreach (var raw in values)
|
||||
{
|
||||
var up = (raw ?? string.Empty).Trim().ToUpperInvariant();
|
||||
if (up.Length == 0) continue;
|
||||
if (Known.Contains(up)) valid.Add(up);
|
||||
else hasUnknown = true;
|
||||
}
|
||||
|
||||
return (valid, hasUnknown);
|
||||
}
|
||||
}
|
||||
@@ -95,6 +95,24 @@ public sealed class JwtAuthenticator : IAuthenticator
|
||||
}
|
||||
}
|
||||
|
||||
// 7b. Check allowed connection types
|
||||
var (allowedTypes, hasUnknown) = JwtConnectionTypes.Convert(userClaims.Nats?.AllowedConnectionTypes);
|
||||
|
||||
if (allowedTypes.Count == 0)
|
||||
{
|
||||
if (hasUnknown)
|
||||
return null; // unknown-only list should reject
|
||||
}
|
||||
else
|
||||
{
|
||||
var connType = string.IsNullOrWhiteSpace(context.ConnectionType)
|
||||
? JwtConnectionTypes.Standard
|
||||
: context.ConnectionType.ToUpperInvariant();
|
||||
|
||||
if (!allowedTypes.Contains(connType))
|
||||
return null;
|
||||
}
|
||||
|
||||
// 8. Build permissions from JWT claims
|
||||
Permissions? permissions = null;
|
||||
var nats = userClaims.Nats;
|
||||
|
||||
@@ -245,6 +245,12 @@ public static class ConfigProcessor
|
||||
opts.ReconnectErrorReports = ToInt(value);
|
||||
break;
|
||||
|
||||
// MQTT
|
||||
case "mqtt":
|
||||
if (value is Dictionary<string, object?> mqttDict)
|
||||
ParseMqtt(mqttDict, opts, errors);
|
||||
break;
|
||||
|
||||
// Unknown keys silently ignored (cluster, jetstream, gateway, leafnode, etc.)
|
||||
default:
|
||||
break;
|
||||
@@ -620,6 +626,145 @@ public static class ConfigProcessor
|
||||
opts.Tags = tags;
|
||||
}
|
||||
|
||||
// ─── MQTT parsing ────────────────────────────────────────────────
|
||||
// Reference: Go server/opts.go parseMQTT (lines ~5443-5541)
|
||||
|
||||
private static void ParseMqtt(Dictionary<string, object?> dict, NatsOptions opts, List<string> errors)
|
||||
{
|
||||
var mqtt = opts.Mqtt ?? new MqttOptions();
|
||||
|
||||
foreach (var (key, value) in dict)
|
||||
{
|
||||
switch (key.ToLowerInvariant())
|
||||
{
|
||||
case "listen":
|
||||
var (host, port) = ParseHostPort(value);
|
||||
if (host is not null) mqtt.Host = host;
|
||||
if (port is not null) mqtt.Port = port.Value;
|
||||
break;
|
||||
case "port":
|
||||
mqtt.Port = ToInt(value);
|
||||
break;
|
||||
case "host" or "net":
|
||||
mqtt.Host = ToString(value);
|
||||
break;
|
||||
case "no_auth_user":
|
||||
mqtt.NoAuthUser = ToString(value);
|
||||
break;
|
||||
case "tls":
|
||||
if (value is Dictionary<string, object?> tlsDict)
|
||||
ParseMqttTls(tlsDict, mqtt, errors);
|
||||
break;
|
||||
case "authorization" or "authentication":
|
||||
if (value is Dictionary<string, object?> authDict)
|
||||
ParseMqttAuth(authDict, mqtt, errors);
|
||||
break;
|
||||
case "ack_wait" or "ackwait":
|
||||
mqtt.AckWait = ParseDuration(value);
|
||||
break;
|
||||
case "js_api_timeout" or "api_timeout":
|
||||
mqtt.JsApiTimeout = ParseDuration(value);
|
||||
break;
|
||||
case "max_ack_pending" or "max_pending" or "max_inflight":
|
||||
var pending = ToInt(value);
|
||||
if (pending < 0 || pending > 0xFFFF)
|
||||
errors.Add($"mqtt max_ack_pending invalid value {pending}, should be in [0..{0xFFFF}] range");
|
||||
else
|
||||
mqtt.MaxAckPending = (ushort)pending;
|
||||
break;
|
||||
case "js_domain":
|
||||
mqtt.JsDomain = ToString(value);
|
||||
break;
|
||||
case "stream_replicas":
|
||||
mqtt.StreamReplicas = ToInt(value);
|
||||
break;
|
||||
case "consumer_replicas":
|
||||
mqtt.ConsumerReplicas = ToInt(value);
|
||||
break;
|
||||
case "consumer_memory_storage":
|
||||
mqtt.ConsumerMemoryStorage = ToBool(value);
|
||||
break;
|
||||
case "consumer_inactive_threshold" or "consumer_auto_cleanup":
|
||||
mqtt.ConsumerInactiveThreshold = ParseDuration(value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
opts.Mqtt = mqtt;
|
||||
}
|
||||
|
||||
private static void ParseMqttAuth(Dictionary<string, object?> dict, MqttOptions mqtt, List<string> errors)
|
||||
{
|
||||
foreach (var (key, value) in dict)
|
||||
{
|
||||
switch (key.ToLowerInvariant())
|
||||
{
|
||||
case "user" or "username":
|
||||
mqtt.Username = ToString(value);
|
||||
break;
|
||||
case "pass" or "password":
|
||||
mqtt.Password = ToString(value);
|
||||
break;
|
||||
case "token":
|
||||
mqtt.Token = ToString(value);
|
||||
break;
|
||||
case "timeout":
|
||||
mqtt.AuthTimeout = ToDouble(value);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void ParseMqttTls(Dictionary<string, object?> dict, MqttOptions mqtt, List<string> errors)
|
||||
{
|
||||
foreach (var (key, value) in dict)
|
||||
{
|
||||
switch (key.ToLowerInvariant())
|
||||
{
|
||||
case "cert_file":
|
||||
mqtt.TlsCert = ToString(value);
|
||||
break;
|
||||
case "key_file":
|
||||
mqtt.TlsKey = ToString(value);
|
||||
break;
|
||||
case "ca_file":
|
||||
mqtt.TlsCaCert = ToString(value);
|
||||
break;
|
||||
case "verify":
|
||||
mqtt.TlsVerify = ToBool(value);
|
||||
break;
|
||||
case "verify_and_map":
|
||||
var map = ToBool(value);
|
||||
mqtt.TlsMap = map;
|
||||
if (map) mqtt.TlsVerify = true;
|
||||
break;
|
||||
case "timeout":
|
||||
mqtt.TlsTimeout = ToDouble(value);
|
||||
break;
|
||||
case "pinned_certs":
|
||||
if (value is List<object?> pinnedList)
|
||||
{
|
||||
var certs = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
|
||||
foreach (var item in pinnedList)
|
||||
{
|
||||
if (item is string s)
|
||||
certs.Add(s.ToLowerInvariant());
|
||||
}
|
||||
|
||||
mqtt.TlsPinnedCerts = certs;
|
||||
}
|
||||
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Type conversion helpers ───────────────────────────────────
|
||||
|
||||
private static int ToInt(object? value) => value switch
|
||||
@@ -653,6 +798,15 @@ public static class ConfigProcessor
|
||||
_ => throw new FormatException($"Cannot convert {value?.GetType().Name ?? "null"} to string"),
|
||||
};
|
||||
|
||||
private static double ToDouble(object? value) => value switch
|
||||
{
|
||||
double d => d,
|
||||
long l => l,
|
||||
int i => i,
|
||||
string s when double.TryParse(s, NumberStyles.Float, CultureInfo.InvariantCulture, out var d) => d,
|
||||
_ => throw new FormatException($"Cannot convert {value?.GetType().Name ?? "null"} to double"),
|
||||
};
|
||||
|
||||
private static IReadOnlyList<string> ToStringList(object? value)
|
||||
{
|
||||
if (value is List<object?> list)
|
||||
|
||||
@@ -22,4 +22,5 @@ public sealed record ClosedClient
|
||||
public TimeSpan Rtt { get; init; }
|
||||
public string TlsVersion { get; init; } = "";
|
||||
public string TlsCipherSuite { get; init; } = "";
|
||||
public string MqttClient { get; init; } = "";
|
||||
}
|
||||
|
||||
@@ -204,6 +204,8 @@ public sealed class ConnzOptions
|
||||
|
||||
public string FilterSubject { get; set; } = "";
|
||||
|
||||
public string MqttClient { get; set; } = "";
|
||||
|
||||
public int Offset { get; set; }
|
||||
|
||||
public int Limit { get; set; } = 1024;
|
||||
|
||||
@@ -28,6 +28,10 @@ public sealed class ConnzHandler(NatsServer server)
|
||||
connInfos.AddRange(server.GetClosedClients().Select(c => BuildClosedConnInfo(c, now, opts)));
|
||||
}
|
||||
|
||||
// Filter by MQTT client ID
|
||||
if (!string.IsNullOrEmpty(opts.MqttClient))
|
||||
connInfos = connInfos.Where(c => c.MqttClient == opts.MqttClient).ToList();
|
||||
|
||||
// Validate sort options that require closed state
|
||||
if (opts.Sort is SortOpt.ByStop or SortOpt.ByReason && opts.State == ConnState.Open)
|
||||
opts.Sort = SortOpt.ByCid; // Fallback
|
||||
@@ -142,6 +146,7 @@ public sealed class ConnzHandler(NatsServer server)
|
||||
Rtt = FormatRtt(closed.Rtt),
|
||||
TlsVersion = closed.TlsVersion,
|
||||
TlsCipherSuite = closed.TlsCipherSuite,
|
||||
MqttClient = closed.MqttClient,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -197,6 +202,9 @@ public sealed class ConnzHandler(NatsServer server)
|
||||
if (q.TryGetValue("limit", out var limit) && int.TryParse(limit, out var l))
|
||||
opts.Limit = l;
|
||||
|
||||
if (q.TryGetValue("mqtt_client", out var mqttClient))
|
||||
opts.MqttClient = mqttClient.ToString();
|
||||
|
||||
return opts;
|
||||
}
|
||||
|
||||
|
||||
@@ -355,8 +355,29 @@ public sealed class MqttOptsVarz
|
||||
[JsonPropertyName("port")]
|
||||
public int Port { get; set; }
|
||||
|
||||
[JsonPropertyName("no_auth_user")]
|
||||
public string NoAuthUser { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("auth_timeout")]
|
||||
public double AuthTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_map")]
|
||||
public bool TlsMap { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_timeout")]
|
||||
public double TlsTimeout { get; set; }
|
||||
|
||||
[JsonPropertyName("tls_pinned_certs")]
|
||||
public string[] TlsPinnedCerts { get; set; } = [];
|
||||
|
||||
[JsonPropertyName("js_domain")]
|
||||
public string JsDomain { get; set; } = "";
|
||||
|
||||
[JsonPropertyName("ack_wait")]
|
||||
public long AckWait { get; set; }
|
||||
|
||||
[JsonPropertyName("max_ack_pending")]
|
||||
public ushort MaxAckPending { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
||||
@@ -121,6 +121,7 @@ public sealed class VarzHandler : IDisposable
|
||||
Subscriptions = _server.SubList.Count,
|
||||
ConfigLoadTime = _server.StartTime,
|
||||
HttpReqStats = stats.HttpReqStats.ToDictionary(kv => kv.Key, kv => (ulong)kv.Value),
|
||||
Mqtt = BuildMqttVarz(),
|
||||
};
|
||||
}
|
||||
finally
|
||||
@@ -134,6 +135,27 @@ public sealed class VarzHandler : IDisposable
|
||||
_varzMu.Dispose();
|
||||
}
|
||||
|
||||
private MqttOptsVarz BuildMqttVarz()
|
||||
{
|
||||
var mqtt = _options.Mqtt;
|
||||
if (mqtt is null)
|
||||
return new MqttOptsVarz();
|
||||
|
||||
return new MqttOptsVarz
|
||||
{
|
||||
Host = mqtt.Host,
|
||||
Port = mqtt.Port,
|
||||
NoAuthUser = mqtt.NoAuthUser ?? "",
|
||||
AuthTimeout = mqtt.AuthTimeout,
|
||||
TlsMap = mqtt.TlsMap,
|
||||
TlsTimeout = mqtt.TlsTimeout,
|
||||
TlsPinnedCerts = mqtt.TlsPinnedCerts?.ToArray() ?? [],
|
||||
JsDomain = mqtt.JsDomain ?? "",
|
||||
AckWait = (long)mqtt.AckWait.TotalNanoseconds,
|
||||
MaxAckPending = mqtt.MaxAckPending,
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Formats a TimeSpan as a human-readable uptime string matching Go server format.
|
||||
/// </summary>
|
||||
|
||||
43
src/NATS.Server/MqttOptions.cs
Normal file
43
src/NATS.Server/MqttOptions.cs
Normal file
@@ -0,0 +1,43 @@
|
||||
namespace NATS.Server;
|
||||
|
||||
/// <summary>
|
||||
/// MQTT protocol configuration options.
|
||||
/// Corresponds to Go server/opts.go MQTTOpts struct.
|
||||
/// Config is parsed and stored but no MQTT listener is started yet.
|
||||
/// </summary>
|
||||
public sealed class MqttOptions
|
||||
{
|
||||
// Network
|
||||
public string Host { get; set; } = "";
|
||||
public int Port { get; set; }
|
||||
|
||||
// Auth override (MQTT-specific, separate from global auth)
|
||||
public string? NoAuthUser { get; set; }
|
||||
public string? Username { get; set; }
|
||||
public string? Password { get; set; }
|
||||
public string? Token { get; set; }
|
||||
public double AuthTimeout { get; set; }
|
||||
|
||||
// TLS
|
||||
public string? TlsCert { get; set; }
|
||||
public string? TlsKey { get; set; }
|
||||
public string? TlsCaCert { get; set; }
|
||||
public bool TlsVerify { get; set; }
|
||||
public double TlsTimeout { get; set; } = 2.0;
|
||||
public bool TlsMap { get; set; }
|
||||
public HashSet<string>? TlsPinnedCerts { get; set; }
|
||||
|
||||
// JetStream integration
|
||||
public string? JsDomain { get; set; }
|
||||
public int StreamReplicas { get; set; }
|
||||
public int ConsumerReplicas { get; set; }
|
||||
public bool ConsumerMemoryStorage { get; set; }
|
||||
public TimeSpan ConsumerInactiveThreshold { get; set; }
|
||||
|
||||
// QoS
|
||||
public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30);
|
||||
public ushort MaxAckPending { get; set; }
|
||||
public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
||||
|
||||
public bool HasTls => TlsCert != null && TlsKey != null;
|
||||
}
|
||||
@@ -8,6 +8,7 @@ using System.Text.Json;
|
||||
using System.Threading.Channels;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NATS.Server.Auth;
|
||||
using NATS.Server.Auth.Jwt;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.Subscriptions;
|
||||
using NATS.Server.Tls;
|
||||
@@ -391,6 +392,7 @@ public sealed class NatsClient : INatsClient, IDisposable
|
||||
Opts = ClientOpts,
|
||||
Nonce = _nonce ?? [],
|
||||
ClientCertificate = TlsState?.PeerCert,
|
||||
ConnectionType = JwtConnectionTypes.Standard,
|
||||
};
|
||||
|
||||
authResult = _authService.Authenticate(context);
|
||||
|
||||
@@ -115,6 +115,9 @@ public sealed class NatsOptions
|
||||
// Subject mapping / transforms (source pattern -> destination template)
|
||||
public Dictionary<string, string>? SubjectMappings { get; set; }
|
||||
|
||||
// MQTT configuration (parsed from config, no listener yet)
|
||||
public MqttOptions? Mqtt { get; set; }
|
||||
|
||||
public bool HasTls => TlsCert != null && TlsKey != null;
|
||||
|
||||
// WebSocket
|
||||
|
||||
@@ -1221,6 +1221,7 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
Rtt = client.Rtt,
|
||||
TlsVersion = client.TlsState?.TlsVersion ?? "",
|
||||
TlsCipherSuite = client.TlsState?.CipherSuite ?? "",
|
||||
MqttClient = "", // populated when MQTT transport is implemented
|
||||
});
|
||||
|
||||
// Cap closed clients list
|
||||
|
||||
Reference in New Issue
Block a user