feat: implement mqtt session and qos ack runtime semantics
This commit is contained in:
@@ -11,6 +11,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
private readonly MqttProtocolParser _parser = new();
|
||||
private readonly SemaphoreSlim _writeGate = new(1, 1);
|
||||
private string _clientId = string.Empty;
|
||||
private bool _cleanSession = true;
|
||||
|
||||
public async Task RunAsync(CancellationToken ct)
|
||||
{
|
||||
@@ -31,7 +32,11 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
{
|
||||
case MqttPacketType.Connect:
|
||||
_clientId = packet.ClientId;
|
||||
_cleanSession = packet.CleanSession;
|
||||
var pending = _listener.OpenSession(_clientId, _cleanSession);
|
||||
await WriteLineAsync("CONNACK", ct);
|
||||
foreach (var redelivery in pending)
|
||||
await WriteLineAsync($"REDLIVER {redelivery.PacketId} {redelivery.Topic} {redelivery.Payload}", ct);
|
||||
break;
|
||||
case MqttPacketType.Subscribe:
|
||||
_listener.RegisterSubscription(this, packet.Topic);
|
||||
@@ -40,6 +45,14 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
case MqttPacketType.Publish:
|
||||
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
|
||||
break;
|
||||
case MqttPacketType.PublishQos1:
|
||||
_listener.RecordPendingPublish(_clientId, packet.PacketId, packet.Topic, packet.Payload);
|
||||
await WriteLineAsync($"PUBACK {packet.PacketId}", ct);
|
||||
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
|
||||
break;
|
||||
case MqttPacketType.Ack:
|
||||
_listener.AckPendingPublish(_clientId, packet.PacketId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
||||
private int _port = port;
|
||||
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
|
||||
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
|
||||
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
|
||||
private TcpListener? _listener;
|
||||
private Task? _acceptLoop;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
@@ -49,6 +50,41 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
||||
}
|
||||
}
|
||||
|
||||
internal IReadOnlyList<MqttPendingPublish> OpenSession(string clientId, bool cleanSession)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(clientId))
|
||||
return [];
|
||||
|
||||
if (cleanSession)
|
||||
{
|
||||
_sessions.TryRemove(clientId, out _);
|
||||
return [];
|
||||
}
|
||||
|
||||
var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState());
|
||||
return session.Pending.Values
|
||||
.OrderBy(static p => p.PacketId)
|
||||
.ToArray();
|
||||
}
|
||||
|
||||
internal void RecordPendingPublish(string clientId, int packetId, string topic, string payload)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0)
|
||||
return;
|
||||
|
||||
var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState());
|
||||
session.Pending[packetId] = new MqttPendingPublish(packetId, topic, payload);
|
||||
}
|
||||
|
||||
internal void AckPendingPublish(string clientId, int packetId)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0)
|
||||
return;
|
||||
|
||||
if (_sessions.TryGetValue(clientId, out var session))
|
||||
session.Pending.TryRemove(packetId, out _);
|
||||
}
|
||||
|
||||
internal void Unregister(MqttConnection connection)
|
||||
{
|
||||
_connections.TryRemove(connection, out _);
|
||||
@@ -69,6 +105,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
||||
|
||||
_connections.Clear();
|
||||
_subscriptions.Clear();
|
||||
_sessions.Clear();
|
||||
_cts.Dispose();
|
||||
}
|
||||
|
||||
@@ -101,4 +138,11 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
||||
}, ct);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class MqttSessionState
|
||||
{
|
||||
public ConcurrentDictionary<int, MqttPendingPublish> Pending { get; } = new();
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed record MqttPendingPublish(int PacketId, string Topic, string Payload);
|
||||
|
||||
@@ -6,9 +6,17 @@ public enum MqttPacketType
|
||||
Connect,
|
||||
Subscribe,
|
||||
Publish,
|
||||
PublishQos1,
|
||||
Ack,
|
||||
}
|
||||
|
||||
public sealed record MqttPacket(MqttPacketType Type, string Topic, string Payload, string ClientId);
|
||||
public sealed record MqttPacket(
|
||||
MqttPacketType Type,
|
||||
string Topic,
|
||||
string Payload,
|
||||
string ClientId,
|
||||
int PacketId = 0,
|
||||
bool CleanSession = true);
|
||||
|
||||
public sealed class MqttProtocolParser
|
||||
{
|
||||
@@ -26,11 +34,26 @@ public sealed class MqttProtocolParser
|
||||
|
||||
if (trimmed.StartsWith("CONNECT ", StringComparison.Ordinal))
|
||||
{
|
||||
var parts = trimmed.Split(' ', StringSplitOptions.RemoveEmptyEntries);
|
||||
if (parts.Length < 2)
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
|
||||
var cleanSession = true;
|
||||
for (var i = 2; i < parts.Length; i++)
|
||||
{
|
||||
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
|
||||
&& bool.TryParse(parts[i]["clean=".Length..], out var parsedClean))
|
||||
{
|
||||
cleanSession = parsedClean;
|
||||
}
|
||||
}
|
||||
|
||||
return new MqttPacket(
|
||||
MqttPacketType.Connect,
|
||||
string.Empty,
|
||||
string.Empty,
|
||||
trimmed["CONNECT ".Length..].Trim());
|
||||
parts[1],
|
||||
CleanSession: cleanSession);
|
||||
}
|
||||
|
||||
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))
|
||||
@@ -54,6 +77,43 @@ public sealed class MqttProtocolParser
|
||||
return new MqttPacket(MqttPacketType.Publish, topic, payload, string.Empty);
|
||||
}
|
||||
|
||||
if (trimmed.StartsWith("PUBQ1 ", StringComparison.Ordinal))
|
||||
{
|
||||
var rest = trimmed["PUBQ1 ".Length..];
|
||||
var firstSep = rest.IndexOf(' ');
|
||||
if (firstSep <= 0)
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
|
||||
var secondSep = rest.IndexOf(' ', firstSep + 1);
|
||||
if (secondSep <= firstSep + 1)
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
|
||||
if (!int.TryParse(rest[..firstSep], out var packetId) || packetId <= 0)
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
|
||||
var topic = rest[(firstSep + 1)..secondSep].Trim();
|
||||
var payload = rest[(secondSep + 1)..];
|
||||
return new MqttPacket(
|
||||
MqttPacketType.PublishQos1,
|
||||
topic,
|
||||
payload,
|
||||
string.Empty,
|
||||
PacketId: packetId);
|
||||
}
|
||||
|
||||
if (trimmed.StartsWith("ACK ", StringComparison.Ordinal))
|
||||
{
|
||||
if (!int.TryParse(trimmed["ACK ".Length..].Trim(), out var packetId) || packetId <= 0)
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
|
||||
return new MqttPacket(
|
||||
MqttPacketType.Ack,
|
||||
string.Empty,
|
||||
string.Empty,
|
||||
string.Empty,
|
||||
PacketId: packetId);
|
||||
}
|
||||
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,9 @@ public sealed class MqttOptions
|
||||
public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30);
|
||||
public ushort MaxAckPending { get; set; }
|
||||
public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
||||
public bool SessionPersistence { get; set; } = true;
|
||||
public TimeSpan SessionTtl { get; set; } = TimeSpan.FromHours(1);
|
||||
public bool Qos1PubAck { get; set; } = true;
|
||||
|
||||
public bool HasTls => TlsCert != null && TlsKey != null;
|
||||
}
|
||||
|
||||
26
tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs
Normal file
26
tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs
Normal file
@@ -0,0 +1,26 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using NATS.Server.Mqtt;
|
||||
|
||||
namespace NATS.Server.Tests.Mqtt;
|
||||
|
||||
public class MqttQosAckRuntimeTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked()
|
||||
{
|
||||
await using var listener = new MqttListener("127.0.0.1", 0);
|
||||
using var cts = new CancellationTokenSource();
|
||||
await listener.StartAsync(cts.Token);
|
||||
|
||||
using var client = new TcpClient();
|
||||
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||
var stream = client.GetStream();
|
||||
|
||||
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT qos-client clean=false");
|
||||
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK");
|
||||
|
||||
await MqttRuntimeWire.WriteLineAsync(stream, "PUBQ1 7 sensors.temp 42");
|
||||
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("PUBACK 7");
|
||||
}
|
||||
}
|
||||
71
tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs
Normal file
71
tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs
Normal file
@@ -0,0 +1,71 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using NATS.Server.Mqtt;
|
||||
|
||||
namespace NATS.Server.Tests.Mqtt;
|
||||
|
||||
public class MqttSessionRuntimeTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked()
|
||||
{
|
||||
await using var listener = new MqttListener("127.0.0.1", 0);
|
||||
using var cts = new CancellationTokenSource();
|
||||
await listener.StartAsync(cts.Token);
|
||||
|
||||
using (var first = new TcpClient())
|
||||
{
|
||||
await first.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||
var firstStream = first.GetStream();
|
||||
await MqttRuntimeWire.WriteLineAsync(firstStream, "CONNECT session-client clean=false");
|
||||
(await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("CONNACK");
|
||||
|
||||
await MqttRuntimeWire.WriteLineAsync(firstStream, "PUBQ1 21 sensors.temp 99");
|
||||
(await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("PUBACK 21");
|
||||
}
|
||||
|
||||
using var second = new TcpClient();
|
||||
await second.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||
var secondStream = second.GetStream();
|
||||
await MqttRuntimeWire.WriteLineAsync(secondStream, "CONNECT session-client clean=false");
|
||||
(await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("CONNACK");
|
||||
(await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("REDLIVER 21 sensors.temp 99");
|
||||
}
|
||||
}
|
||||
|
||||
internal static class MqttRuntimeWire
|
||||
{
|
||||
public static async Task WriteLineAsync(NetworkStream stream, string line)
|
||||
{
|
||||
var bytes = Encoding.UTF8.GetBytes(line + "\n");
|
||||
await stream.WriteAsync(bytes);
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
|
||||
public static async Task<string?> ReadLineAsync(NetworkStream stream, int timeoutMs)
|
||||
{
|
||||
using var timeout = new CancellationTokenSource(timeoutMs);
|
||||
var bytes = new List<byte>();
|
||||
var one = new byte[1];
|
||||
try
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token);
|
||||
if (read == 0)
|
||||
return null;
|
||||
if (one[0] == (byte)'\n')
|
||||
break;
|
||||
if (one[0] != (byte)'\r')
|
||||
bytes.Add(one[0]);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return Encoding.UTF8.GetString([.. bytes]);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user