feat(mqtt): add session persistence, QoS 2 state machine, and retained store (E2+E3)
Add MqttSessionStore with save/load/delete/list operations, flapper detection (backoff on rapid reconnects), and TimeProvider-based testing. Add MqttRetainedStore for per-topic retained messages with MQTT wildcard matching (+/# filters). Add MqttQos2StateMachine tracking the full PUBREC/PUBREL/PUBCOMP flow with duplicate rejection and timeout detection. 19 new tests: 9 session persistence, 10 QoS/retained message tests.
This commit is contained in:
241
src/NATS.Server/Mqtt/MqttRetainedStore.cs
Normal file
241
src/NATS.Server/Mqtt/MqttRetainedStore.cs
Normal file
@@ -0,0 +1,241 @@
|
||||
// MQTT retained message store and QoS 2 state machine.
|
||||
// Go reference: golang/nats-server/server/mqtt.go
|
||||
// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600–1700)
|
||||
// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300–1400)
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
namespace NATS.Server.Mqtt;
|
||||
|
||||
/// <summary>
|
||||
/// A retained message stored for a topic.
|
||||
/// </summary>
|
||||
public sealed record MqttRetainedMessage(string Topic, ReadOnlyMemory<byte> Payload);
|
||||
|
||||
/// <summary>
|
||||
/// In-memory store for MQTT retained messages.
|
||||
/// Go reference: server/mqtt.go mqttHandleRetainedMsg ~line 1600.
|
||||
/// </summary>
|
||||
public sealed class MqttRetainedStore
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, ReadOnlyMemory<byte>> _retained = new(StringComparer.Ordinal);
|
||||
|
||||
/// <summary>
|
||||
/// Sets (or clears) the retained message for a topic.
|
||||
/// An empty payload clears the retained message.
|
||||
/// Go reference: server/mqtt.go mqttHandleRetainedMsg.
|
||||
/// </summary>
|
||||
public void SetRetained(string topic, ReadOnlyMemory<byte> payload)
|
||||
{
|
||||
if (payload.IsEmpty)
|
||||
{
|
||||
_retained.TryRemove(topic, out _);
|
||||
return;
|
||||
}
|
||||
|
||||
_retained[topic] = payload;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the retained message payload for a topic, or null if none.
|
||||
/// </summary>
|
||||
public ReadOnlyMemory<byte>? GetRetained(string topic)
|
||||
{
|
||||
if (_retained.TryGetValue(topic, out var payload))
|
||||
return payload;
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns all retained messages matching an MQTT topic filter pattern.
|
||||
/// Supports '+' (single-level) and '#' (multi-level) wildcards.
|
||||
/// Go reference: server/mqtt.go mqttGetRetainedMessages ~line 1650.
|
||||
/// </summary>
|
||||
public IReadOnlyList<MqttRetainedMessage> GetMatchingRetained(string filter)
|
||||
{
|
||||
var results = new List<MqttRetainedMessage>();
|
||||
foreach (var kvp in _retained)
|
||||
{
|
||||
if (MqttTopicMatch(kvp.Key, filter))
|
||||
results.Add(new MqttRetainedMessage(kvp.Key, kvp.Value));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Matches an MQTT topic against a filter pattern.
|
||||
/// '+' matches exactly one level, '#' matches zero or more levels (must be last).
|
||||
/// </summary>
|
||||
internal static bool MqttTopicMatch(string topic, string filter)
|
||||
{
|
||||
var topicLevels = topic.Split('/');
|
||||
var filterLevels = filter.Split('/');
|
||||
|
||||
for (var i = 0; i < filterLevels.Length; i++)
|
||||
{
|
||||
if (filterLevels[i] == "#")
|
||||
return true; // '#' matches everything from here
|
||||
|
||||
if (i >= topicLevels.Length)
|
||||
return false; // filter has more levels than topic
|
||||
|
||||
if (filterLevels[i] != "+" && filterLevels[i] != topicLevels[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
// Topic must not have more levels than filter (unless filter ended with '#')
|
||||
return topicLevels.Length == filterLevels.Length;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// QoS 2 state machine states.
|
||||
/// Go reference: server/mqtt.go ~line 1300.
|
||||
/// </summary>
|
||||
public enum MqttQos2State
|
||||
{
|
||||
/// <summary>Publish received, awaiting PUBREC from peer.</summary>
|
||||
AwaitingPubRec,
|
||||
|
||||
/// <summary>PUBREC received, awaiting PUBREL from originator.</summary>
|
||||
AwaitingPubRel,
|
||||
|
||||
/// <summary>PUBREL received, awaiting PUBCOMP from peer.</summary>
|
||||
AwaitingPubComp,
|
||||
|
||||
/// <summary>Flow complete.</summary>
|
||||
Complete,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tracks QoS 2 flow state for a single packet ID.
|
||||
/// </summary>
|
||||
internal sealed class MqttQos2Flow
|
||||
{
|
||||
public MqttQos2State State { get; set; }
|
||||
public DateTime StartedAtUtc { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Manages the QoS 2 exactly-once delivery state machine for a connection.
|
||||
/// Tracks per-packet-id state transitions: PUBLISH -> PUBREC -> PUBREL -> PUBCOMP.
|
||||
/// Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp.
|
||||
/// </summary>
|
||||
public sealed class MqttQos2StateMachine
|
||||
{
|
||||
private readonly ConcurrentDictionary<ushort, MqttQos2Flow> _flows = new();
|
||||
private readonly TimeSpan _timeout;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new QoS 2 state machine.
|
||||
/// </summary>
|
||||
/// <param name="timeout">Timeout for incomplete flows. Default 30 seconds.</param>
|
||||
/// <param name="timeProvider">Optional time provider for testing.</param>
|
||||
public MqttQos2StateMachine(TimeSpan? timeout = null, TimeProvider? timeProvider = null)
|
||||
{
|
||||
_timeout = timeout ?? TimeSpan.FromSeconds(30);
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Begins a new QoS 2 flow for the given packet ID.
|
||||
/// Returns false if a flow for this packet ID already exists (duplicate publish).
|
||||
/// </summary>
|
||||
public bool BeginPublish(ushort packetId)
|
||||
{
|
||||
var flow = new MqttQos2Flow
|
||||
{
|
||||
State = MqttQos2State.AwaitingPubRec,
|
||||
StartedAtUtc = _timeProvider.GetUtcNow().UtcDateTime,
|
||||
};
|
||||
|
||||
return _flows.TryAdd(packetId, flow);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Processes a PUBREC for the given packet ID.
|
||||
/// Returns false if the flow is not in the expected state.
|
||||
/// </summary>
|
||||
public bool ProcessPubRec(ushort packetId)
|
||||
{
|
||||
if (!_flows.TryGetValue(packetId, out var flow))
|
||||
return false;
|
||||
|
||||
if (flow.State != MqttQos2State.AwaitingPubRec)
|
||||
return false;
|
||||
|
||||
flow.State = MqttQos2State.AwaitingPubRel;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Processes a PUBREL for the given packet ID.
|
||||
/// Returns false if the flow is not in the expected state.
|
||||
/// </summary>
|
||||
public bool ProcessPubRel(ushort packetId)
|
||||
{
|
||||
if (!_flows.TryGetValue(packetId, out var flow))
|
||||
return false;
|
||||
|
||||
if (flow.State != MqttQos2State.AwaitingPubRel)
|
||||
return false;
|
||||
|
||||
flow.State = MqttQos2State.AwaitingPubComp;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Processes a PUBCOMP for the given packet ID.
|
||||
/// Returns false if the flow is not in the expected state.
|
||||
/// Removes the flow on completion.
|
||||
/// </summary>
|
||||
public bool ProcessPubComp(ushort packetId)
|
||||
{
|
||||
if (!_flows.TryGetValue(packetId, out var flow))
|
||||
return false;
|
||||
|
||||
if (flow.State != MqttQos2State.AwaitingPubComp)
|
||||
return false;
|
||||
|
||||
flow.State = MqttQos2State.Complete;
|
||||
_flows.TryRemove(packetId, out _);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current state for a packet ID, or null if no flow exists.
|
||||
/// </summary>
|
||||
public MqttQos2State? GetState(ushort packetId)
|
||||
{
|
||||
if (_flows.TryGetValue(packetId, out var flow))
|
||||
return flow.State;
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns packet IDs for flows that have exceeded the timeout.
|
||||
/// </summary>
|
||||
public IReadOnlyList<ushort> GetTimedOutFlows()
|
||||
{
|
||||
var now = _timeProvider.GetUtcNow().UtcDateTime;
|
||||
var timedOut = new List<ushort>();
|
||||
|
||||
foreach (var kvp in _flows)
|
||||
{
|
||||
if (now - kvp.Value.StartedAtUtc > _timeout)
|
||||
timedOut.Add(kvp.Key);
|
||||
}
|
||||
|
||||
return timedOut;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Removes a flow (e.g., after timeout cleanup).
|
||||
/// </summary>
|
||||
public void RemoveFlow(ushort packetId) =>
|
||||
_flows.TryRemove(packetId, out _);
|
||||
}
|
||||
Reference in New Issue
Block a user