From 7dcf5776b3cdaef61dbca1fa66f14865b55be9af Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 23 Feb 2026 14:43:08 -0500 Subject: [PATCH] feat: implement mqtt session and qos ack runtime semantics --- src/NATS.Server/Mqtt/MqttConnection.cs | 13 ++++ src/NATS.Server/Mqtt/MqttListener.cs | 44 ++++++++++++ src/NATS.Server/Mqtt/MqttProtocolParser.cs | 64 ++++++++++++++++- src/NATS.Server/MqttOptions.cs | 3 + .../Mqtt/MqttQosAckRuntimeTests.cs | 26 +++++++ .../Mqtt/MqttSessionRuntimeTests.cs | 71 +++++++++++++++++++ 6 files changed, 219 insertions(+), 2 deletions(-) create mode 100644 tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs create mode 100644 tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs diff --git a/src/NATS.Server/Mqtt/MqttConnection.cs b/src/NATS.Server/Mqtt/MqttConnection.cs index 071ab7a..29da2af 100644 --- a/src/NATS.Server/Mqtt/MqttConnection.cs +++ b/src/NATS.Server/Mqtt/MqttConnection.cs @@ -11,6 +11,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA private readonly MqttProtocolParser _parser = new(); private readonly SemaphoreSlim _writeGate = new(1, 1); private string _clientId = string.Empty; + private bool _cleanSession = true; public async Task RunAsync(CancellationToken ct) { @@ -31,7 +32,11 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA { case MqttPacketType.Connect: _clientId = packet.ClientId; + _cleanSession = packet.CleanSession; + var pending = _listener.OpenSession(_clientId, _cleanSession); await WriteLineAsync("CONNACK", ct); + foreach (var redelivery in pending) + await WriteLineAsync($"REDLIVER {redelivery.PacketId} {redelivery.Topic} {redelivery.Payload}", ct); break; case MqttPacketType.Subscribe: _listener.RegisterSubscription(this, packet.Topic); @@ -40,6 +45,14 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA case MqttPacketType.Publish: await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct); break; + case MqttPacketType.PublishQos1: + _listener.RecordPendingPublish(_clientId, packet.PacketId, packet.Topic, packet.Payload); + await WriteLineAsync($"PUBACK {packet.PacketId}", ct); + await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct); + break; + case MqttPacketType.Ack: + _listener.AckPendingPublish(_clientId, packet.PacketId); + break; } } } diff --git a/src/NATS.Server/Mqtt/MqttListener.cs b/src/NATS.Server/Mqtt/MqttListener.cs index e6a9ce3..1ad743f 100644 --- a/src/NATS.Server/Mqtt/MqttListener.cs +++ b/src/NATS.Server/Mqtt/MqttListener.cs @@ -10,6 +10,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable private int _port = port; private readonly ConcurrentDictionary _connections = new(); private readonly ConcurrentDictionary> _subscriptions = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); private TcpListener? _listener; private Task? _acceptLoop; private readonly CancellationTokenSource _cts = new(); @@ -49,6 +50,41 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable } } + internal IReadOnlyList OpenSession(string clientId, bool cleanSession) + { + if (string.IsNullOrWhiteSpace(clientId)) + return []; + + if (cleanSession) + { + _sessions.TryRemove(clientId, out _); + return []; + } + + var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState()); + return session.Pending.Values + .OrderBy(static p => p.PacketId) + .ToArray(); + } + + internal void RecordPendingPublish(string clientId, int packetId, string topic, string payload) + { + if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0) + return; + + var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState()); + session.Pending[packetId] = new MqttPendingPublish(packetId, topic, payload); + } + + internal void AckPendingPublish(string clientId, int packetId) + { + if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0) + return; + + if (_sessions.TryGetValue(clientId, out var session)) + session.Pending.TryRemove(packetId, out _); + } + internal void Unregister(MqttConnection connection) { _connections.TryRemove(connection, out _); @@ -69,6 +105,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable _connections.Clear(); _subscriptions.Clear(); + _sessions.Clear(); _cts.Dispose(); } @@ -101,4 +138,11 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable }, ct); } } + + private sealed class MqttSessionState + { + public ConcurrentDictionary Pending { get; } = new(); + } } + +internal sealed record MqttPendingPublish(int PacketId, string Topic, string Payload); diff --git a/src/NATS.Server/Mqtt/MqttProtocolParser.cs b/src/NATS.Server/Mqtt/MqttProtocolParser.cs index db8e7bb..698365b 100644 --- a/src/NATS.Server/Mqtt/MqttProtocolParser.cs +++ b/src/NATS.Server/Mqtt/MqttProtocolParser.cs @@ -6,9 +6,17 @@ public enum MqttPacketType Connect, Subscribe, Publish, + PublishQos1, + Ack, } -public sealed record MqttPacket(MqttPacketType Type, string Topic, string Payload, string ClientId); +public sealed record MqttPacket( + MqttPacketType Type, + string Topic, + string Payload, + string ClientId, + int PacketId = 0, + bool CleanSession = true); public sealed class MqttProtocolParser { @@ -26,11 +34,26 @@ public sealed class MqttProtocolParser if (trimmed.StartsWith("CONNECT ", StringComparison.Ordinal)) { + var parts = trimmed.Split(' ', StringSplitOptions.RemoveEmptyEntries); + if (parts.Length < 2) + return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); + + var cleanSession = true; + for (var i = 2; i < parts.Length; i++) + { + if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase) + && bool.TryParse(parts[i]["clean=".Length..], out var parsedClean)) + { + cleanSession = parsedClean; + } + } + return new MqttPacket( MqttPacketType.Connect, string.Empty, string.Empty, - trimmed["CONNECT ".Length..].Trim()); + parts[1], + CleanSession: cleanSession); } if (trimmed.StartsWith("SUB ", StringComparison.Ordinal)) @@ -54,6 +77,43 @@ public sealed class MqttProtocolParser return new MqttPacket(MqttPacketType.Publish, topic, payload, string.Empty); } + if (trimmed.StartsWith("PUBQ1 ", StringComparison.Ordinal)) + { + var rest = trimmed["PUBQ1 ".Length..]; + var firstSep = rest.IndexOf(' '); + if (firstSep <= 0) + return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); + + var secondSep = rest.IndexOf(' ', firstSep + 1); + if (secondSep <= firstSep + 1) + return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); + + if (!int.TryParse(rest[..firstSep], out var packetId) || packetId <= 0) + return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); + + var topic = rest[(firstSep + 1)..secondSep].Trim(); + var payload = rest[(secondSep + 1)..]; + return new MqttPacket( + MqttPacketType.PublishQos1, + topic, + payload, + string.Empty, + PacketId: packetId); + } + + if (trimmed.StartsWith("ACK ", StringComparison.Ordinal)) + { + if (!int.TryParse(trimmed["ACK ".Length..].Trim(), out var packetId) || packetId <= 0) + return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); + + return new MqttPacket( + MqttPacketType.Ack, + string.Empty, + string.Empty, + string.Empty, + PacketId: packetId); + } + return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); } } diff --git a/src/NATS.Server/MqttOptions.cs b/src/NATS.Server/MqttOptions.cs index c47e15e..37c0252 100644 --- a/src/NATS.Server/MqttOptions.cs +++ b/src/NATS.Server/MqttOptions.cs @@ -38,6 +38,9 @@ public sealed class MqttOptions public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30); public ushort MaxAckPending { get; set; } public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5); + public bool SessionPersistence { get; set; } = true; + public TimeSpan SessionTtl { get; set; } = TimeSpan.FromHours(1); + public bool Qos1PubAck { get; set; } = true; public bool HasTls => TlsCert != null && TlsKey != null; } diff --git a/tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs new file mode 100644 index 0000000..b792dfa --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs @@ -0,0 +1,26 @@ +using System.Net; +using System.Net.Sockets; +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttQosAckRuntimeTests +{ + [Fact] + public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked() + { + await using var listener = new MqttListener("127.0.0.1", 0); + using var cts = new CancellationTokenSource(); + await listener.StartAsync(cts.Token); + + using var client = new TcpClient(); + await client.ConnectAsync(IPAddress.Loopback, listener.Port); + var stream = client.GetStream(); + + await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT qos-client clean=false"); + (await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK"); + + await MqttRuntimeWire.WriteLineAsync(stream, "PUBQ1 7 sensors.temp 42"); + (await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("PUBACK 7"); + } +} diff --git a/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs new file mode 100644 index 0000000..7c6f905 --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs @@ -0,0 +1,71 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttSessionRuntimeTests +{ + [Fact] + public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked() + { + await using var listener = new MqttListener("127.0.0.1", 0); + using var cts = new CancellationTokenSource(); + await listener.StartAsync(cts.Token); + + using (var first = new TcpClient()) + { + await first.ConnectAsync(IPAddress.Loopback, listener.Port); + var firstStream = first.GetStream(); + await MqttRuntimeWire.WriteLineAsync(firstStream, "CONNECT session-client clean=false"); + (await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("CONNACK"); + + await MqttRuntimeWire.WriteLineAsync(firstStream, "PUBQ1 21 sensors.temp 99"); + (await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("PUBACK 21"); + } + + using var second = new TcpClient(); + await second.ConnectAsync(IPAddress.Loopback, listener.Port); + var secondStream = second.GetStream(); + await MqttRuntimeWire.WriteLineAsync(secondStream, "CONNECT session-client clean=false"); + (await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("CONNACK"); + (await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("REDLIVER 21 sensors.temp 99"); + } +} + +internal static class MqttRuntimeWire +{ + public static async Task WriteLineAsync(NetworkStream stream, string line) + { + var bytes = Encoding.UTF8.GetBytes(line + "\n"); + await stream.WriteAsync(bytes); + await stream.FlushAsync(); + } + + public static async Task ReadLineAsync(NetworkStream stream, int timeoutMs) + { + using var timeout = new CancellationTokenSource(timeoutMs); + var bytes = new List(); + var one = new byte[1]; + try + { + while (true) + { + var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token); + if (read == 0) + return null; + if (one[0] == (byte)'\n') + break; + if (one[0] != (byte)'\r') + bytes.Add(one[0]); + } + } + catch (OperationCanceledException) + { + return null; + } + + return Encoding.UTF8.GetString([.. bytes]); + } +}