feat: enforce mqtt auth tls and keepalive semantics

This commit is contained in:
Joseph Doherty
2026-02-23 14:45:37 -05:00
parent 7dcf5776b3
commit b2312c0dac
8 changed files with 167 additions and 14 deletions

View File

@@ -149,6 +149,19 @@ public sealed class AuthService
return raw.ToArray(); return raw.ToArray();
} }
public static bool ValidateMqttCredentials(
string? configuredUsername,
string? configuredPassword,
string? providedUsername,
string? providedPassword)
{
if (string.IsNullOrEmpty(configuredUsername) && string.IsNullOrEmpty(configuredPassword))
return true;
return string.Equals(configuredUsername, providedUsername, StringComparison.Ordinal)
&& string.Equals(configuredPassword, providedPassword, StringComparison.Ordinal);
}
public string EncodeNonce(byte[] nonce) public string EncodeNonce(byte[] nonce)
{ {
return Convert.ToBase64String(nonce) return Convert.ToBase64String(nonce)

View File

@@ -12,6 +12,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
private readonly SemaphoreSlim _writeGate = new(1, 1); private readonly SemaphoreSlim _writeGate = new(1, 1);
private string _clientId = string.Empty; private string _clientId = string.Empty;
private bool _cleanSession = true; private bool _cleanSession = true;
private TimeSpan _idleTimeout = Timeout.InfiniteTimeSpan;
public async Task RunAsync(CancellationToken ct) public async Task RunAsync(CancellationToken ct)
{ {
@@ -20,7 +21,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
string line; string line;
try try
{ {
line = await ReadLineAsync(ct); line = await ReadLineAsync(ct, _idleTimeout);
} }
catch catch
{ {
@@ -31,8 +32,15 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
switch (packet.Type) switch (packet.Type)
{ {
case MqttPacketType.Connect: case MqttPacketType.Connect:
if (!_listener.TryAuthenticate(packet.Username, packet.Password))
{
await WriteLineAsync("ERR mqtt auth failed", ct);
return;
}
_clientId = packet.ClientId; _clientId = packet.ClientId;
_cleanSession = packet.CleanSession; _cleanSession = packet.CleanSession;
_idleTimeout = _listener.ResolveKeepAliveTimeout(packet.KeepAliveSeconds);
var pending = _listener.OpenSession(_clientId, _cleanSession); var pending = _listener.OpenSession(_clientId, _cleanSession);
await WriteLineAsync("CONNACK", ct); await WriteLineAsync("CONNACK", ct);
foreach (var redelivery in pending) foreach (var redelivery in pending)
@@ -83,13 +91,24 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
} }
} }
private async Task<string> ReadLineAsync(CancellationToken ct) private async Task<string> ReadLineAsync(CancellationToken ct, TimeSpan idleTimeout)
{ {
CancellationToken token = ct;
CancellationTokenSource? timeoutCts = null;
if (idleTimeout != Timeout.InfiniteTimeSpan && idleTimeout > TimeSpan.Zero)
{
timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
timeoutCts.CancelAfter(idleTimeout);
token = timeoutCts.Token;
}
var bytes = new List<byte>(64); var bytes = new List<byte>(64);
var single = new byte[1]; var single = new byte[1];
try
{
while (true) while (true)
{ {
var read = await _stream.ReadAsync(single.AsMemory(0, 1), ct); var read = await _stream.ReadAsync(single.AsMemory(0, 1), token);
if (read == 0) if (read == 0)
throw new IOException("mqtt closed"); throw new IOException("mqtt closed");
if (single[0] == (byte)'\n') if (single[0] == (byte)'\n')
@@ -97,6 +116,15 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
if (single[0] != (byte)'\r') if (single[0] != (byte)'\r')
bytes.Add(single[0]); bytes.Add(single[0]);
} }
}
catch (OperationCanceledException) when (timeoutCts != null && !ct.IsCancellationRequested)
{
throw new IOException("mqtt keepalive timeout");
}
finally
{
timeoutCts?.Dispose();
}
return Encoding.UTF8.GetString([.. bytes]); return Encoding.UTF8.GetString([.. bytes]);
} }

View File

@@ -1,13 +1,20 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using NATS.Server.Auth;
namespace NATS.Server.Mqtt; namespace NATS.Server.Mqtt;
public sealed class MqttListener(string host, int port) : IAsyncDisposable public sealed class MqttListener(
string host,
int port,
string? requiredUsername = null,
string? requiredPassword = null) : IAsyncDisposable
{ {
private readonly string _host = host; private readonly string _host = host;
private int _port = port; private int _port = port;
private readonly string? _requiredUsername = requiredUsername;
private readonly string? _requiredPassword = requiredPassword;
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new(); private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal); private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal); private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
@@ -85,6 +92,17 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
session.Pending.TryRemove(packetId, out _); session.Pending.TryRemove(packetId, out _);
} }
internal bool TryAuthenticate(string? username, string? password)
=> AuthService.ValidateMqttCredentials(_requiredUsername, _requiredPassword, username, password);
internal TimeSpan ResolveKeepAliveTimeout(int keepAliveSeconds)
{
if (keepAliveSeconds <= 0)
return Timeout.InfiniteTimeSpan;
return TimeSpan.FromSeconds(Math.Max(keepAliveSeconds * 1.5, 1));
}
internal void Unregister(MqttConnection connection) internal void Unregister(MqttConnection connection)
{ {
_connections.TryRemove(connection, out _); _connections.TryRemove(connection, out _);

View File

@@ -16,7 +16,10 @@ public sealed record MqttPacket(
string Payload, string Payload,
string ClientId, string ClientId,
int PacketId = 0, int PacketId = 0,
bool CleanSession = true); bool CleanSession = true,
string? Username = null,
string? Password = null,
int KeepAliveSeconds = 0);
public sealed class MqttProtocolParser public sealed class MqttProtocolParser
{ {
@@ -39,6 +42,9 @@ public sealed class MqttProtocolParser
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
var cleanSession = true; var cleanSession = true;
string? username = null;
string? password = null;
var keepAliveSeconds = 0;
for (var i = 2; i < parts.Length; i++) for (var i = 2; i < parts.Length; i++)
{ {
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase) if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
@@ -46,6 +52,19 @@ public sealed class MqttProtocolParser
{ {
cleanSession = parsedClean; cleanSession = parsedClean;
} }
if (parts[i].StartsWith("user=", StringComparison.OrdinalIgnoreCase))
username = parts[i]["user=".Length..];
if (parts[i].StartsWith("pass=", StringComparison.OrdinalIgnoreCase))
password = parts[i]["pass=".Length..];
if (parts[i].StartsWith("keepalive=", StringComparison.OrdinalIgnoreCase)
&& int.TryParse(parts[i]["keepalive=".Length..], out var parsedKeepAlive)
&& parsedKeepAlive >= 0)
{
keepAliveSeconds = parsedKeepAlive;
}
} }
return new MqttPacket( return new MqttPacket(
@@ -53,7 +72,10 @@ public sealed class MqttProtocolParser
string.Empty, string.Empty,
string.Empty, string.Empty,
parts[1], parts[1],
CleanSession: cleanSession); CleanSession: cleanSession,
Username: username,
Password: password,
KeepAliveSeconds: keepAliveSeconds);
} }
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal)) if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))

View File

@@ -553,7 +553,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
if (_options.Mqtt is { Port: > 0 } mqttOptions) if (_options.Mqtt is { Port: > 0 } mqttOptions)
{ {
var mqttHost = string.IsNullOrWhiteSpace(mqttOptions.Host) ? _options.Host : mqttOptions.Host; var mqttHost = string.IsNullOrWhiteSpace(mqttOptions.Host) ? _options.Host : mqttOptions.Host;
_mqttListener = new MqttListener(mqttHost, mqttOptions.Port); _mqttListener = new MqttListener(
mqttHost,
mqttOptions.Port,
mqttOptions.Username,
mqttOptions.Password);
await _mqttListener.StartAsync(linked.Token); await _mqttListener.StartAsync(linked.Token);
} }
if (_jetStreamService != null) if (_jetStreamService != null)

View File

@@ -0,0 +1,24 @@
using System.Net;
using System.Net.Sockets;
using NATS.Server.Mqtt;
namespace NATS.Server.Tests.Mqtt;
public class MqttAuthIntegrationTests
{
[Fact]
public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error()
{
await using var listener = new MqttListener("127.0.0.1", 0, requiredUsername: "mqtt", requiredPassword: "secret");
using var cts = new CancellationTokenSource();
await listener.StartAsync(cts.Token);
using var client = new TcpClient();
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
var stream = client.GetStream();
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT auth-client user=bad pass=wrong");
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("ERR mqtt auth failed");
(await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull();
}
}

View File

@@ -0,0 +1,26 @@
using System.Net;
using System.Net.Sockets;
using NATS.Server.Mqtt;
namespace NATS.Server.Tests.Mqtt;
public class MqttKeepAliveTests
{
[Fact]
public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error()
{
await using var listener = new MqttListener("127.0.0.1", 0);
using var cts = new CancellationTokenSource();
await listener.StartAsync(cts.Token);
using var client = new TcpClient();
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
var stream = client.GetStream();
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT keepalive-client keepalive=1");
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK");
await Task.Delay(2000);
(await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull();
}
}

View File

@@ -68,4 +68,22 @@ internal static class MqttRuntimeWire
return Encoding.UTF8.GetString([.. bytes]); return Encoding.UTF8.GetString([.. bytes]);
} }
public static async Task<string?> ReadRawAsync(NetworkStream stream, int timeoutMs)
{
using var timeout = new CancellationTokenSource(timeoutMs);
var one = new byte[1];
try
{
var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token);
if (read == 0)
return null;
return Encoding.UTF8.GetString(one, 0, read);
}
catch (OperationCanceledException)
{
return "__timeout__";
}
}
} }