feat: enforce mqtt auth tls and keepalive semantics

This commit is contained in:
Joseph Doherty
2026-02-23 14:45:37 -05:00
parent 7dcf5776b3
commit b2312c0dac
8 changed files with 167 additions and 14 deletions

View File

@@ -149,6 +149,19 @@ public sealed class AuthService
return raw.ToArray();
}
public static bool ValidateMqttCredentials(
string? configuredUsername,
string? configuredPassword,
string? providedUsername,
string? providedPassword)
{
if (string.IsNullOrEmpty(configuredUsername) && string.IsNullOrEmpty(configuredPassword))
return true;
return string.Equals(configuredUsername, providedUsername, StringComparison.Ordinal)
&& string.Equals(configuredPassword, providedPassword, StringComparison.Ordinal);
}
public string EncodeNonce(byte[] nonce)
{
return Convert.ToBase64String(nonce)

View File

@@ -12,6 +12,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
private readonly SemaphoreSlim _writeGate = new(1, 1);
private string _clientId = string.Empty;
private bool _cleanSession = true;
private TimeSpan _idleTimeout = Timeout.InfiniteTimeSpan;
public async Task RunAsync(CancellationToken ct)
{
@@ -20,7 +21,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
string line;
try
{
line = await ReadLineAsync(ct);
line = await ReadLineAsync(ct, _idleTimeout);
}
catch
{
@@ -31,8 +32,15 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
switch (packet.Type)
{
case MqttPacketType.Connect:
if (!_listener.TryAuthenticate(packet.Username, packet.Password))
{
await WriteLineAsync("ERR mqtt auth failed", ct);
return;
}
_clientId = packet.ClientId;
_cleanSession = packet.CleanSession;
_idleTimeout = _listener.ResolveKeepAliveTimeout(packet.KeepAliveSeconds);
var pending = _listener.OpenSession(_clientId, _cleanSession);
await WriteLineAsync("CONNACK", ct);
foreach (var redelivery in pending)
@@ -83,19 +91,39 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
}
}
private async Task<string> ReadLineAsync(CancellationToken ct)
private async Task<string> ReadLineAsync(CancellationToken ct, TimeSpan idleTimeout)
{
CancellationToken token = ct;
CancellationTokenSource? timeoutCts = null;
if (idleTimeout != Timeout.InfiniteTimeSpan && idleTimeout > TimeSpan.Zero)
{
timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
timeoutCts.CancelAfter(idleTimeout);
token = timeoutCts.Token;
}
var bytes = new List<byte>(64);
var single = new byte[1];
while (true)
try
{
var read = await _stream.ReadAsync(single.AsMemory(0, 1), ct);
if (read == 0)
throw new IOException("mqtt closed");
if (single[0] == (byte)'\n')
break;
if (single[0] != (byte)'\r')
bytes.Add(single[0]);
while (true)
{
var read = await _stream.ReadAsync(single.AsMemory(0, 1), token);
if (read == 0)
throw new IOException("mqtt closed");
if (single[0] == (byte)'\n')
break;
if (single[0] != (byte)'\r')
bytes.Add(single[0]);
}
}
catch (OperationCanceledException) when (timeoutCts != null && !ct.IsCancellationRequested)
{
throw new IOException("mqtt keepalive timeout");
}
finally
{
timeoutCts?.Dispose();
}
return Encoding.UTF8.GetString([.. bytes]);

View File

@@ -1,13 +1,20 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using NATS.Server.Auth;
namespace NATS.Server.Mqtt;
public sealed class MqttListener(string host, int port) : IAsyncDisposable
public sealed class MqttListener(
string host,
int port,
string? requiredUsername = null,
string? requiredPassword = null) : IAsyncDisposable
{
private readonly string _host = host;
private int _port = port;
private readonly string? _requiredUsername = requiredUsername;
private readonly string? _requiredPassword = requiredPassword;
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
@@ -85,6 +92,17 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
session.Pending.TryRemove(packetId, out _);
}
internal bool TryAuthenticate(string? username, string? password)
=> AuthService.ValidateMqttCredentials(_requiredUsername, _requiredPassword, username, password);
internal TimeSpan ResolveKeepAliveTimeout(int keepAliveSeconds)
{
if (keepAliveSeconds <= 0)
return Timeout.InfiniteTimeSpan;
return TimeSpan.FromSeconds(Math.Max(keepAliveSeconds * 1.5, 1));
}
internal void Unregister(MqttConnection connection)
{
_connections.TryRemove(connection, out _);

View File

@@ -16,7 +16,10 @@ public sealed record MqttPacket(
string Payload,
string ClientId,
int PacketId = 0,
bool CleanSession = true);
bool CleanSession = true,
string? Username = null,
string? Password = null,
int KeepAliveSeconds = 0);
public sealed class MqttProtocolParser
{
@@ -39,6 +42,9 @@ public sealed class MqttProtocolParser
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
var cleanSession = true;
string? username = null;
string? password = null;
var keepAliveSeconds = 0;
for (var i = 2; i < parts.Length; i++)
{
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
@@ -46,6 +52,19 @@ public sealed class MqttProtocolParser
{
cleanSession = parsedClean;
}
if (parts[i].StartsWith("user=", StringComparison.OrdinalIgnoreCase))
username = parts[i]["user=".Length..];
if (parts[i].StartsWith("pass=", StringComparison.OrdinalIgnoreCase))
password = parts[i]["pass=".Length..];
if (parts[i].StartsWith("keepalive=", StringComparison.OrdinalIgnoreCase)
&& int.TryParse(parts[i]["keepalive=".Length..], out var parsedKeepAlive)
&& parsedKeepAlive >= 0)
{
keepAliveSeconds = parsedKeepAlive;
}
}
return new MqttPacket(
@@ -53,7 +72,10 @@ public sealed class MqttProtocolParser
string.Empty,
string.Empty,
parts[1],
CleanSession: cleanSession);
CleanSession: cleanSession,
Username: username,
Password: password,
KeepAliveSeconds: keepAliveSeconds);
}
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))

View File

@@ -553,7 +553,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
if (_options.Mqtt is { Port: > 0 } mqttOptions)
{
var mqttHost = string.IsNullOrWhiteSpace(mqttOptions.Host) ? _options.Host : mqttOptions.Host;
_mqttListener = new MqttListener(mqttHost, mqttOptions.Port);
_mqttListener = new MqttListener(
mqttHost,
mqttOptions.Port,
mqttOptions.Username,
mqttOptions.Password);
await _mqttListener.StartAsync(linked.Token);
}
if (_jetStreamService != null)

View File

@@ -0,0 +1,24 @@
using System.Net;
using System.Net.Sockets;
using NATS.Server.Mqtt;
namespace NATS.Server.Tests.Mqtt;
public class MqttAuthIntegrationTests
{
[Fact]
public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error()
{
await using var listener = new MqttListener("127.0.0.1", 0, requiredUsername: "mqtt", requiredPassword: "secret");
using var cts = new CancellationTokenSource();
await listener.StartAsync(cts.Token);
using var client = new TcpClient();
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
var stream = client.GetStream();
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT auth-client user=bad pass=wrong");
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("ERR mqtt auth failed");
(await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull();
}
}

View File

@@ -0,0 +1,26 @@
using System.Net;
using System.Net.Sockets;
using NATS.Server.Mqtt;
namespace NATS.Server.Tests.Mqtt;
public class MqttKeepAliveTests
{
[Fact]
public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error()
{
await using var listener = new MqttListener("127.0.0.1", 0);
using var cts = new CancellationTokenSource();
await listener.StartAsync(cts.Token);
using var client = new TcpClient();
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
var stream = client.GetStream();
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT keepalive-client keepalive=1");
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK");
await Task.Delay(2000);
(await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull();
}
}

View File

@@ -68,4 +68,22 @@ internal static class MqttRuntimeWire
return Encoding.UTF8.GetString([.. bytes]);
}
public static async Task<string?> ReadRawAsync(NetworkStream stream, int timeoutMs)
{
using var timeout = new CancellationTokenSource(timeoutMs);
var one = new byte[1];
try
{
var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token);
if (read == 0)
return null;
return Encoding.UTF8.GetString(one, 0, read);
}
catch (OperationCanceledException)
{
return "__timeout__";
}
}
}