feat: enforce mqtt auth tls and keepalive semantics

This commit is contained in:
Joseph Doherty
2026-02-23 14:45:37 -05:00
parent 7dcf5776b3
commit b2312c0dac
8 changed files with 167 additions and 14 deletions

View File

@@ -12,6 +12,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
private readonly SemaphoreSlim _writeGate = new(1, 1);
private string _clientId = string.Empty;
private bool _cleanSession = true;
private TimeSpan _idleTimeout = Timeout.InfiniteTimeSpan;
public async Task RunAsync(CancellationToken ct)
{
@@ -20,7 +21,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
string line;
try
{
line = await ReadLineAsync(ct);
line = await ReadLineAsync(ct, _idleTimeout);
}
catch
{
@@ -31,8 +32,15 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
switch (packet.Type)
{
case MqttPacketType.Connect:
if (!_listener.TryAuthenticate(packet.Username, packet.Password))
{
await WriteLineAsync("ERR mqtt auth failed", ct);
return;
}
_clientId = packet.ClientId;
_cleanSession = packet.CleanSession;
_idleTimeout = _listener.ResolveKeepAliveTimeout(packet.KeepAliveSeconds);
var pending = _listener.OpenSession(_clientId, _cleanSession);
await WriteLineAsync("CONNACK", ct);
foreach (var redelivery in pending)
@@ -83,19 +91,39 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
}
}
private async Task<string> ReadLineAsync(CancellationToken ct)
private async Task<string> ReadLineAsync(CancellationToken ct, TimeSpan idleTimeout)
{
CancellationToken token = ct;
CancellationTokenSource? timeoutCts = null;
if (idleTimeout != Timeout.InfiniteTimeSpan && idleTimeout > TimeSpan.Zero)
{
timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
timeoutCts.CancelAfter(idleTimeout);
token = timeoutCts.Token;
}
var bytes = new List<byte>(64);
var single = new byte[1];
while (true)
try
{
var read = await _stream.ReadAsync(single.AsMemory(0, 1), ct);
if (read == 0)
throw new IOException("mqtt closed");
if (single[0] == (byte)'\n')
break;
if (single[0] != (byte)'\r')
bytes.Add(single[0]);
while (true)
{
var read = await _stream.ReadAsync(single.AsMemory(0, 1), token);
if (read == 0)
throw new IOException("mqtt closed");
if (single[0] == (byte)'\n')
break;
if (single[0] != (byte)'\r')
bytes.Add(single[0]);
}
}
catch (OperationCanceledException) when (timeoutCts != null && !ct.IsCancellationRequested)
{
throw new IOException("mqtt keepalive timeout");
}
finally
{
timeoutCts?.Dispose();
}
return Encoding.UTF8.GetString([.. bytes]);

View File

@@ -1,13 +1,20 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using NATS.Server.Auth;
namespace NATS.Server.Mqtt;
public sealed class MqttListener(string host, int port) : IAsyncDisposable
public sealed class MqttListener(
string host,
int port,
string? requiredUsername = null,
string? requiredPassword = null) : IAsyncDisposable
{
private readonly string _host = host;
private int _port = port;
private readonly string? _requiredUsername = requiredUsername;
private readonly string? _requiredPassword = requiredPassword;
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
@@ -85,6 +92,17 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
session.Pending.TryRemove(packetId, out _);
}
internal bool TryAuthenticate(string? username, string? password)
=> AuthService.ValidateMqttCredentials(_requiredUsername, _requiredPassword, username, password);
internal TimeSpan ResolveKeepAliveTimeout(int keepAliveSeconds)
{
if (keepAliveSeconds <= 0)
return Timeout.InfiniteTimeSpan;
return TimeSpan.FromSeconds(Math.Max(keepAliveSeconds * 1.5, 1));
}
internal void Unregister(MqttConnection connection)
{
_connections.TryRemove(connection, out _);

View File

@@ -16,7 +16,10 @@ public sealed record MqttPacket(
string Payload,
string ClientId,
int PacketId = 0,
bool CleanSession = true);
bool CleanSession = true,
string? Username = null,
string? Password = null,
int KeepAliveSeconds = 0);
public sealed class MqttProtocolParser
{
@@ -39,6 +42,9 @@ public sealed class MqttProtocolParser
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
var cleanSession = true;
string? username = null;
string? password = null;
var keepAliveSeconds = 0;
for (var i = 2; i < parts.Length; i++)
{
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
@@ -46,6 +52,19 @@ public sealed class MqttProtocolParser
{
cleanSession = parsedClean;
}
if (parts[i].StartsWith("user=", StringComparison.OrdinalIgnoreCase))
username = parts[i]["user=".Length..];
if (parts[i].StartsWith("pass=", StringComparison.OrdinalIgnoreCase))
password = parts[i]["pass=".Length..];
if (parts[i].StartsWith("keepalive=", StringComparison.OrdinalIgnoreCase)
&& int.TryParse(parts[i]["keepalive=".Length..], out var parsedKeepAlive)
&& parsedKeepAlive >= 0)
{
keepAliveSeconds = parsedKeepAlive;
}
}
return new MqttPacket(
@@ -53,7 +72,10 @@ public sealed class MqttProtocolParser
string.Empty,
string.Empty,
parts[1],
CleanSession: cleanSession);
CleanSession: cleanSession,
Username: username,
Password: password,
KeepAliveSeconds: keepAliveSeconds);
}
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))