feat: implement mqtt session and qos ack runtime semantics

This commit is contained in:
Joseph Doherty
2026-02-23 14:43:08 -05:00
parent 7faf42c588
commit 7dcf5776b3
6 changed files with 219 additions and 2 deletions

View File

@@ -11,6 +11,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
private readonly MqttProtocolParser _parser = new();
private readonly SemaphoreSlim _writeGate = new(1, 1);
private string _clientId = string.Empty;
private bool _cleanSession = true;
public async Task RunAsync(CancellationToken ct)
{
@@ -31,7 +32,11 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
{
case MqttPacketType.Connect:
_clientId = packet.ClientId;
_cleanSession = packet.CleanSession;
var pending = _listener.OpenSession(_clientId, _cleanSession);
await WriteLineAsync("CONNACK", ct);
foreach (var redelivery in pending)
await WriteLineAsync($"REDLIVER {redelivery.PacketId} {redelivery.Topic} {redelivery.Payload}", ct);
break;
case MqttPacketType.Subscribe:
_listener.RegisterSubscription(this, packet.Topic);
@@ -40,6 +45,14 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
case MqttPacketType.Publish:
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
break;
case MqttPacketType.PublishQos1:
_listener.RecordPendingPublish(_clientId, packet.PacketId, packet.Topic, packet.Payload);
await WriteLineAsync($"PUBACK {packet.PacketId}", ct);
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
break;
case MqttPacketType.Ack:
_listener.AckPendingPublish(_clientId, packet.PacketId);
break;
}
}
}

View File

@@ -10,6 +10,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
private int _port = port;
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
private TcpListener? _listener;
private Task? _acceptLoop;
private readonly CancellationTokenSource _cts = new();
@@ -49,6 +50,41 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
}
}
internal IReadOnlyList<MqttPendingPublish> OpenSession(string clientId, bool cleanSession)
{
if (string.IsNullOrWhiteSpace(clientId))
return [];
if (cleanSession)
{
_sessions.TryRemove(clientId, out _);
return [];
}
var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState());
return session.Pending.Values
.OrderBy(static p => p.PacketId)
.ToArray();
}
internal void RecordPendingPublish(string clientId, int packetId, string topic, string payload)
{
if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0)
return;
var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState());
session.Pending[packetId] = new MqttPendingPublish(packetId, topic, payload);
}
internal void AckPendingPublish(string clientId, int packetId)
{
if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0)
return;
if (_sessions.TryGetValue(clientId, out var session))
session.Pending.TryRemove(packetId, out _);
}
internal void Unregister(MqttConnection connection)
{
_connections.TryRemove(connection, out _);
@@ -69,6 +105,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
_connections.Clear();
_subscriptions.Clear();
_sessions.Clear();
_cts.Dispose();
}
@@ -101,4 +138,11 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
}, ct);
}
}
private sealed class MqttSessionState
{
public ConcurrentDictionary<int, MqttPendingPublish> Pending { get; } = new();
}
}
internal sealed record MqttPendingPublish(int PacketId, string Topic, string Payload);

View File

@@ -6,9 +6,17 @@ public enum MqttPacketType
Connect,
Subscribe,
Publish,
PublishQos1,
Ack,
}
public sealed record MqttPacket(MqttPacketType Type, string Topic, string Payload, string ClientId);
public sealed record MqttPacket(
MqttPacketType Type,
string Topic,
string Payload,
string ClientId,
int PacketId = 0,
bool CleanSession = true);
public sealed class MqttProtocolParser
{
@@ -26,11 +34,26 @@ public sealed class MqttProtocolParser
if (trimmed.StartsWith("CONNECT ", StringComparison.Ordinal))
{
var parts = trimmed.Split(' ', StringSplitOptions.RemoveEmptyEntries);
if (parts.Length < 2)
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
var cleanSession = true;
for (var i = 2; i < parts.Length; i++)
{
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
&& bool.TryParse(parts[i]["clean=".Length..], out var parsedClean))
{
cleanSession = parsedClean;
}
}
return new MqttPacket(
MqttPacketType.Connect,
string.Empty,
string.Empty,
trimmed["CONNECT ".Length..].Trim());
parts[1],
CleanSession: cleanSession);
}
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))
@@ -54,6 +77,43 @@ public sealed class MqttProtocolParser
return new MqttPacket(MqttPacketType.Publish, topic, payload, string.Empty);
}
if (trimmed.StartsWith("PUBQ1 ", StringComparison.Ordinal))
{
var rest = trimmed["PUBQ1 ".Length..];
var firstSep = rest.IndexOf(' ');
if (firstSep <= 0)
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
var secondSep = rest.IndexOf(' ', firstSep + 1);
if (secondSep <= firstSep + 1)
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
if (!int.TryParse(rest[..firstSep], out var packetId) || packetId <= 0)
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
var topic = rest[(firstSep + 1)..secondSep].Trim();
var payload = rest[(secondSep + 1)..];
return new MqttPacket(
MqttPacketType.PublishQos1,
topic,
payload,
string.Empty,
PacketId: packetId);
}
if (trimmed.StartsWith("ACK ", StringComparison.Ordinal))
{
if (!int.TryParse(trimmed["ACK ".Length..].Trim(), out var packetId) || packetId <= 0)
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
return new MqttPacket(
MqttPacketType.Ack,
string.Empty,
string.Empty,
string.Empty,
PacketId: packetId);
}
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
}
}

View File

@@ -38,6 +38,9 @@ public sealed class MqttOptions
public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30);
public ushort MaxAckPending { get; set; }
public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5);
public bool SessionPersistence { get; set; } = true;
public TimeSpan SessionTtl { get; set; } = TimeSpan.FromHours(1);
public bool Qos1PubAck { get; set; } = true;
public bool HasTls => TlsCert != null && TlsKey != null;
}

View File

@@ -0,0 +1,26 @@
using System.Net;
using System.Net.Sockets;
using NATS.Server.Mqtt;
namespace NATS.Server.Tests.Mqtt;
public class MqttQosAckRuntimeTests
{
[Fact]
public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked()
{
await using var listener = new MqttListener("127.0.0.1", 0);
using var cts = new CancellationTokenSource();
await listener.StartAsync(cts.Token);
using var client = new TcpClient();
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
var stream = client.GetStream();
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT qos-client clean=false");
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK");
await MqttRuntimeWire.WriteLineAsync(stream, "PUBQ1 7 sensors.temp 42");
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("PUBACK 7");
}
}

View File

@@ -0,0 +1,71 @@
using System.Net;
using System.Net.Sockets;
using System.Text;
using NATS.Server.Mqtt;
namespace NATS.Server.Tests.Mqtt;
public class MqttSessionRuntimeTests
{
[Fact]
public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked()
{
await using var listener = new MqttListener("127.0.0.1", 0);
using var cts = new CancellationTokenSource();
await listener.StartAsync(cts.Token);
using (var first = new TcpClient())
{
await first.ConnectAsync(IPAddress.Loopback, listener.Port);
var firstStream = first.GetStream();
await MqttRuntimeWire.WriteLineAsync(firstStream, "CONNECT session-client clean=false");
(await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("CONNACK");
await MqttRuntimeWire.WriteLineAsync(firstStream, "PUBQ1 21 sensors.temp 99");
(await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("PUBACK 21");
}
using var second = new TcpClient();
await second.ConnectAsync(IPAddress.Loopback, listener.Port);
var secondStream = second.GetStream();
await MqttRuntimeWire.WriteLineAsync(secondStream, "CONNECT session-client clean=false");
(await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("CONNACK");
(await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("REDLIVER 21 sensors.temp 99");
}
}
internal static class MqttRuntimeWire
{
public static async Task WriteLineAsync(NetworkStream stream, string line)
{
var bytes = Encoding.UTF8.GetBytes(line + "\n");
await stream.WriteAsync(bytes);
await stream.FlushAsync();
}
public static async Task<string?> ReadLineAsync(NetworkStream stream, int timeoutMs)
{
using var timeout = new CancellationTokenSource(timeoutMs);
var bytes = new List<byte>();
var one = new byte[1];
try
{
while (true)
{
var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token);
if (read == 0)
return null;
if (one[0] == (byte)'\n')
break;
if (one[0] != (byte)'\r')
bytes.Add(one[0]);
}
}
catch (OperationCanceledException)
{
return null;
}
return Encoding.UTF8.GetString([.. bytes]);
}
}