feat: implement mqtt session and qos ack runtime semantics
This commit is contained in:
@@ -11,6 +11,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
|||||||
private readonly MqttProtocolParser _parser = new();
|
private readonly MqttProtocolParser _parser = new();
|
||||||
private readonly SemaphoreSlim _writeGate = new(1, 1);
|
private readonly SemaphoreSlim _writeGate = new(1, 1);
|
||||||
private string _clientId = string.Empty;
|
private string _clientId = string.Empty;
|
||||||
|
private bool _cleanSession = true;
|
||||||
|
|
||||||
public async Task RunAsync(CancellationToken ct)
|
public async Task RunAsync(CancellationToken ct)
|
||||||
{
|
{
|
||||||
@@ -31,7 +32,11 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
|||||||
{
|
{
|
||||||
case MqttPacketType.Connect:
|
case MqttPacketType.Connect:
|
||||||
_clientId = packet.ClientId;
|
_clientId = packet.ClientId;
|
||||||
|
_cleanSession = packet.CleanSession;
|
||||||
|
var pending = _listener.OpenSession(_clientId, _cleanSession);
|
||||||
await WriteLineAsync("CONNACK", ct);
|
await WriteLineAsync("CONNACK", ct);
|
||||||
|
foreach (var redelivery in pending)
|
||||||
|
await WriteLineAsync($"REDLIVER {redelivery.PacketId} {redelivery.Topic} {redelivery.Payload}", ct);
|
||||||
break;
|
break;
|
||||||
case MqttPacketType.Subscribe:
|
case MqttPacketType.Subscribe:
|
||||||
_listener.RegisterSubscription(this, packet.Topic);
|
_listener.RegisterSubscription(this, packet.Topic);
|
||||||
@@ -40,6 +45,14 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
|||||||
case MqttPacketType.Publish:
|
case MqttPacketType.Publish:
|
||||||
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
|
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
|
||||||
break;
|
break;
|
||||||
|
case MqttPacketType.PublishQos1:
|
||||||
|
_listener.RecordPendingPublish(_clientId, packet.PacketId, packet.Topic, packet.Payload);
|
||||||
|
await WriteLineAsync($"PUBACK {packet.PacketId}", ct);
|
||||||
|
await _listener.PublishAsync(packet.Topic, packet.Payload, this, ct);
|
||||||
|
break;
|
||||||
|
case MqttPacketType.Ack:
|
||||||
|
_listener.AckPendingPublish(_clientId, packet.PacketId);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
|||||||
private int _port = port;
|
private int _port = port;
|
||||||
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
|
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
|
||||||
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
|
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
|
||||||
|
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
|
||||||
private TcpListener? _listener;
|
private TcpListener? _listener;
|
||||||
private Task? _acceptLoop;
|
private Task? _acceptLoop;
|
||||||
private readonly CancellationTokenSource _cts = new();
|
private readonly CancellationTokenSource _cts = new();
|
||||||
@@ -49,6 +50,41 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal IReadOnlyList<MqttPendingPublish> OpenSession(string clientId, bool cleanSession)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(clientId))
|
||||||
|
return [];
|
||||||
|
|
||||||
|
if (cleanSession)
|
||||||
|
{
|
||||||
|
_sessions.TryRemove(clientId, out _);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState());
|
||||||
|
return session.Pending.Values
|
||||||
|
.OrderBy(static p => p.PacketId)
|
||||||
|
.ToArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
internal void RecordPendingPublish(string clientId, int packetId, string topic, string payload)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
var session = _sessions.GetOrAdd(clientId, static _ => new MqttSessionState());
|
||||||
|
session.Pending[packetId] = new MqttPendingPublish(packetId, topic, payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
internal void AckPendingPublish(string clientId, int packetId)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrWhiteSpace(clientId) || packetId <= 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (_sessions.TryGetValue(clientId, out var session))
|
||||||
|
session.Pending.TryRemove(packetId, out _);
|
||||||
|
}
|
||||||
|
|
||||||
internal void Unregister(MqttConnection connection)
|
internal void Unregister(MqttConnection connection)
|
||||||
{
|
{
|
||||||
_connections.TryRemove(connection, out _);
|
_connections.TryRemove(connection, out _);
|
||||||
@@ -69,6 +105,7 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
|||||||
|
|
||||||
_connections.Clear();
|
_connections.Clear();
|
||||||
_subscriptions.Clear();
|
_subscriptions.Clear();
|
||||||
|
_sessions.Clear();
|
||||||
_cts.Dispose();
|
_cts.Dispose();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,4 +138,11 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
|||||||
}, ct);
|
}, ct);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private sealed class MqttSessionState
|
||||||
|
{
|
||||||
|
public ConcurrentDictionary<int, MqttPendingPublish> Pending { get; } = new();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal sealed record MqttPendingPublish(int PacketId, string Topic, string Payload);
|
||||||
|
|||||||
@@ -6,9 +6,17 @@ public enum MqttPacketType
|
|||||||
Connect,
|
Connect,
|
||||||
Subscribe,
|
Subscribe,
|
||||||
Publish,
|
Publish,
|
||||||
|
PublishQos1,
|
||||||
|
Ack,
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed record MqttPacket(MqttPacketType Type, string Topic, string Payload, string ClientId);
|
public sealed record MqttPacket(
|
||||||
|
MqttPacketType Type,
|
||||||
|
string Topic,
|
||||||
|
string Payload,
|
||||||
|
string ClientId,
|
||||||
|
int PacketId = 0,
|
||||||
|
bool CleanSession = true);
|
||||||
|
|
||||||
public sealed class MqttProtocolParser
|
public sealed class MqttProtocolParser
|
||||||
{
|
{
|
||||||
@@ -26,11 +34,26 @@ public sealed class MqttProtocolParser
|
|||||||
|
|
||||||
if (trimmed.StartsWith("CONNECT ", StringComparison.Ordinal))
|
if (trimmed.StartsWith("CONNECT ", StringComparison.Ordinal))
|
||||||
{
|
{
|
||||||
|
var parts = trimmed.Split(' ', StringSplitOptions.RemoveEmptyEntries);
|
||||||
|
if (parts.Length < 2)
|
||||||
|
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||||
|
|
||||||
|
var cleanSession = true;
|
||||||
|
for (var i = 2; i < parts.Length; i++)
|
||||||
|
{
|
||||||
|
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
|
||||||
|
&& bool.TryParse(parts[i]["clean=".Length..], out var parsedClean))
|
||||||
|
{
|
||||||
|
cleanSession = parsedClean;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return new MqttPacket(
|
return new MqttPacket(
|
||||||
MqttPacketType.Connect,
|
MqttPacketType.Connect,
|
||||||
string.Empty,
|
string.Empty,
|
||||||
string.Empty,
|
string.Empty,
|
||||||
trimmed["CONNECT ".Length..].Trim());
|
parts[1],
|
||||||
|
CleanSession: cleanSession);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))
|
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))
|
||||||
@@ -54,6 +77,43 @@ public sealed class MqttProtocolParser
|
|||||||
return new MqttPacket(MqttPacketType.Publish, topic, payload, string.Empty);
|
return new MqttPacket(MqttPacketType.Publish, topic, payload, string.Empty);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (trimmed.StartsWith("PUBQ1 ", StringComparison.Ordinal))
|
||||||
|
{
|
||||||
|
var rest = trimmed["PUBQ1 ".Length..];
|
||||||
|
var firstSep = rest.IndexOf(' ');
|
||||||
|
if (firstSep <= 0)
|
||||||
|
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||||
|
|
||||||
|
var secondSep = rest.IndexOf(' ', firstSep + 1);
|
||||||
|
if (secondSep <= firstSep + 1)
|
||||||
|
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||||
|
|
||||||
|
if (!int.TryParse(rest[..firstSep], out var packetId) || packetId <= 0)
|
||||||
|
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||||
|
|
||||||
|
var topic = rest[(firstSep + 1)..secondSep].Trim();
|
||||||
|
var payload = rest[(secondSep + 1)..];
|
||||||
|
return new MqttPacket(
|
||||||
|
MqttPacketType.PublishQos1,
|
||||||
|
topic,
|
||||||
|
payload,
|
||||||
|
string.Empty,
|
||||||
|
PacketId: packetId);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (trimmed.StartsWith("ACK ", StringComparison.Ordinal))
|
||||||
|
{
|
||||||
|
if (!int.TryParse(trimmed["ACK ".Length..].Trim(), out var packetId) || packetId <= 0)
|
||||||
|
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||||
|
|
||||||
|
return new MqttPacket(
|
||||||
|
MqttPacketType.Ack,
|
||||||
|
string.Empty,
|
||||||
|
string.Empty,
|
||||||
|
string.Empty,
|
||||||
|
PacketId: packetId);
|
||||||
|
}
|
||||||
|
|
||||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ public sealed class MqttOptions
|
|||||||
public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30);
|
public TimeSpan AckWait { get; set; } = TimeSpan.FromSeconds(30);
|
||||||
public ushort MaxAckPending { get; set; }
|
public ushort MaxAckPending { get; set; }
|
||||||
public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
public TimeSpan JsApiTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
||||||
|
public bool SessionPersistence { get; set; } = true;
|
||||||
|
public TimeSpan SessionTtl { get; set; } = TimeSpan.FromHours(1);
|
||||||
|
public bool Qos1PubAck { get; set; } = true;
|
||||||
|
|
||||||
public bool HasTls => TlsCert != null && TlsKey != null;
|
public bool HasTls => TlsCert != null && TlsKey != null;
|
||||||
}
|
}
|
||||||
|
|||||||
26
tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs
Normal file
26
tests/NATS.Server.Tests/Mqtt/MqttQosAckRuntimeTests.cs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
using System.Net;
|
||||||
|
using System.Net.Sockets;
|
||||||
|
using NATS.Server.Mqtt;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.Mqtt;
|
||||||
|
|
||||||
|
public class MqttQosAckRuntimeTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked()
|
||||||
|
{
|
||||||
|
await using var listener = new MqttListener("127.0.0.1", 0);
|
||||||
|
using var cts = new CancellationTokenSource();
|
||||||
|
await listener.StartAsync(cts.Token);
|
||||||
|
|
||||||
|
using var client = new TcpClient();
|
||||||
|
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||||
|
var stream = client.GetStream();
|
||||||
|
|
||||||
|
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT qos-client clean=false");
|
||||||
|
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK");
|
||||||
|
|
||||||
|
await MqttRuntimeWire.WriteLineAsync(stream, "PUBQ1 7 sensors.temp 42");
|
||||||
|
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("PUBACK 7");
|
||||||
|
}
|
||||||
|
}
|
||||||
71
tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs
Normal file
71
tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
using System.Net;
|
||||||
|
using System.Net.Sockets;
|
||||||
|
using System.Text;
|
||||||
|
using NATS.Server.Mqtt;
|
||||||
|
|
||||||
|
namespace NATS.Server.Tests.Mqtt;
|
||||||
|
|
||||||
|
public class MqttSessionRuntimeTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task Qos1_publish_receives_puback_and_redelivery_on_session_reconnect_when_unacked()
|
||||||
|
{
|
||||||
|
await using var listener = new MqttListener("127.0.0.1", 0);
|
||||||
|
using var cts = new CancellationTokenSource();
|
||||||
|
await listener.StartAsync(cts.Token);
|
||||||
|
|
||||||
|
using (var first = new TcpClient())
|
||||||
|
{
|
||||||
|
await first.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||||
|
var firstStream = first.GetStream();
|
||||||
|
await MqttRuntimeWire.WriteLineAsync(firstStream, "CONNECT session-client clean=false");
|
||||||
|
(await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("CONNACK");
|
||||||
|
|
||||||
|
await MqttRuntimeWire.WriteLineAsync(firstStream, "PUBQ1 21 sensors.temp 99");
|
||||||
|
(await MqttRuntimeWire.ReadLineAsync(firstStream, 1000)).ShouldBe("PUBACK 21");
|
||||||
|
}
|
||||||
|
|
||||||
|
using var second = new TcpClient();
|
||||||
|
await second.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||||
|
var secondStream = second.GetStream();
|
||||||
|
await MqttRuntimeWire.WriteLineAsync(secondStream, "CONNECT session-client clean=false");
|
||||||
|
(await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("CONNACK");
|
||||||
|
(await MqttRuntimeWire.ReadLineAsync(secondStream, 1000)).ShouldBe("REDLIVER 21 sensors.temp 99");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal static class MqttRuntimeWire
|
||||||
|
{
|
||||||
|
public static async Task WriteLineAsync(NetworkStream stream, string line)
|
||||||
|
{
|
||||||
|
var bytes = Encoding.UTF8.GetBytes(line + "\n");
|
||||||
|
await stream.WriteAsync(bytes);
|
||||||
|
await stream.FlushAsync();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static async Task<string?> ReadLineAsync(NetworkStream stream, int timeoutMs)
|
||||||
|
{
|
||||||
|
using var timeout = new CancellationTokenSource(timeoutMs);
|
||||||
|
var bytes = new List<byte>();
|
||||||
|
var one = new byte[1];
|
||||||
|
try
|
||||||
|
{
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token);
|
||||||
|
if (read == 0)
|
||||||
|
return null;
|
||||||
|
if (one[0] == (byte)'\n')
|
||||||
|
break;
|
||||||
|
if (one[0] != (byte)'\r')
|
||||||
|
bytes.Add(one[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (OperationCanceledException)
|
||||||
|
{
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Encoding.UTF8.GetString([.. bytes]);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user