// MQTT retained message store and QoS 2 state machine.
// Go reference: golang/nats-server/server/mqtt.go
// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600–1700)
// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300–1400)
using System.Collections.Concurrent;
using NATS.Server.JetStream.Storage;
namespace NATS.Server.Mqtt;
/// <summary>
/// A retained message stored for a topic.
/// </summary>
/// <param name="Topic">MQTT topic name the message is retained under.</param>
/// <param name="Payload">Retained message payload bytes.</param>
public sealed record MqttRetainedMessage(string Topic, ReadOnlyMemory<byte> Payload);
/// <summary>
/// In-memory store for MQTT retained messages, with optional JetStream persistence.
/// Go reference: server/mqtt.go mqttHandleRetainedMsg ~line 1600.
/// </summary>
public sealed class MqttRetainedStore
{
    private readonly ConcurrentDictionary<string, ReadOnlyMemory<byte>> _retained = new(StringComparer.Ordinal);

    // Topics explicitly cleared in this session — prevents falling back to the backing store for cleared topics.
    private readonly ConcurrentDictionary<string, bool> _cleared = new(StringComparer.Ordinal);

    private readonly IStreamStore? _backingStore;

    /// <summary>Backing store for JetStream persistence, or <c>null</c> when the store is memory-only.</summary>
    public IStreamStore? BackingStore => _backingStore;

    /// <summary>
    /// Initializes a new in-memory retained message store with no backing store.
    /// </summary>
    public MqttRetainedStore() : this(null) { }

    /// <summary>
    /// Initializes a new retained message store with an optional JetStream backing store.
    /// </summary>
    /// <param name="backingStore">Optional JetStream stream store for persistence; <c>null</c> keeps everything in memory.</param>
    public MqttRetainedStore(IStreamStore? backingStore)
    {
        _backingStore = backingStore;
    }

    /// <summary>
    /// Sets (or clears) the in-memory retained message for a topic.
    /// An empty payload clears the retained message and marks the topic as cleared, so
    /// <see cref="GetRetainedAsync"/> will not fall back to the backing store for it.
    /// Go reference: server/mqtt.go mqttHandleRetainedMsg.
    /// </summary>
    /// <param name="topic">MQTT topic name (compared ordinally).</param>
    /// <param name="payload">Retained payload; empty clears the topic.</param>
    public void SetRetained(string topic, ReadOnlyMemory<byte> payload)
    {
        if (payload.IsEmpty)
        {
            _retained.TryRemove(topic, out _);
            _cleared[topic] = true;
            return;
        }

        _cleared.TryRemove(topic, out _);
        // NOTE(review): the memory is stored as-is, not copied; callers must not reuse/return the
        // underlying buffer (e.g. to a pool) after handing it to the store.
        _retained[topic] = payload;
    }

    /// <summary>
    /// Gets the in-memory retained message payload for a topic.
    /// </summary>
    /// <param name="topic">MQTT topic name.</param>
    /// <returns>The retained payload, or <c>null</c> if none is held in memory.</returns>
    public ReadOnlyMemory<byte>? GetRetained(string topic)
    {
        if (_retained.TryGetValue(topic, out var payload))
            return payload;

        return null;
    }

    /// <summary>
    /// Returns all in-memory retained messages whose topic matches an MQTT topic filter.
    /// Supports '+' (single-level) and '#' (multi-level) wildcards.
    /// Go reference: server/mqtt.go mqttGetRetainedMessages ~line 1650.
    /// </summary>
    /// <param name="filter">MQTT topic filter, possibly containing wildcards.</param>
    /// <returns>A new list of matching messages (empty when nothing matches).</returns>
    public IReadOnlyList<MqttRetainedMessage> GetMatchingRetained(string filter)
    {
        var results = new List<MqttRetainedMessage>();
        foreach (var kvp in _retained)
        {
            if (MqttTopicMatch(kvp.Key, filter))
                results.Add(new MqttRetainedMessage(kvp.Key, kvp.Value));
        }

        return results;
    }

    /// <summary>
    /// Sets (or clears) the retained message and persists non-empty payloads to the backing store.
    /// Go reference: server/mqtt.go mqttHandleRetainedMsg with JetStream.
    /// </summary>
    /// <param name="topic">MQTT topic name.</param>
    /// <param name="payload">Retained payload; empty clears the topic.</param>
    /// <param name="ct">Cancellation token forwarded to the backing store.</param>
    public async Task SetRetainedAsync(string topic, ReadOnlyMemory<byte> payload, CancellationToken ct = default)
    {
        SetRetained(topic, payload);

        // Clears are recorded in memory only (via _cleared). A full implementation would publish a
        // tombstone to JetStream; without one, a cleared topic can reappear from the backing store
        // in a different store instance that shares the same backing store.
        if (_backingStore is null || payload.IsEmpty)
            return;

        await _backingStore.AppendAsync(RetainedSubject(topic), payload, ct);
    }

    /// <summary>
    /// Gets the retained message, checking the backing store if it is not in memory.
    /// Returns <c>null</c> if the topic was explicitly cleared in this session.
    /// </summary>
    /// <param name="topic">MQTT topic name.</param>
    /// <param name="ct">Cancellation token forwarded to the backing store.</param>
    /// <returns>A copy of the retained payload, or <c>null</c> if none exists.</returns>
    public async Task<byte[]?> GetRetainedAsync(string topic, CancellationToken ct = default)
    {
        var mem = GetRetained(topic);
        if (mem.HasValue)
            return mem.Value.ToArray();

        // Don't consult the backing store if this topic was explicitly cleared in this session.
        if (_cleared.ContainsKey(topic))
            return null;

        if (_backingStore is null)
            return null;

        // SetRetainedAsync appends a new message on every set, so the backing store can hold several
        // payloads for the same subject. Keep the LAST match, which is the most recent value.
        // NOTE(review): assumes ListAsync yields messages in ascending sequence order — confirm.
        var subject = RetainedSubject(topic);
        byte[]? latest = null;
        var messages = await _backingStore.ListAsync(ct);
        foreach (var msg in messages)
        {
            if (msg.Subject == subject)
                latest = msg.Payload.ToArray();
        }

        return latest;
    }

    /// <summary>Builds the backing-store subject under which a topic's retained payloads are appended.</summary>
    private static string RetainedSubject(string topic) => $"$MQTT.rmsgs.{topic}";

    /// <summary>
    /// Matches an MQTT topic against a filter pattern.
    /// '+' matches exactly one level, '#' matches zero or more levels (must be last).
    /// Filters starting with a wildcard never match topics beginning with '$'.
    /// </summary>
    /// <param name="topic">Concrete topic name (no wildcards).</param>
    /// <param name="filter">Topic filter, possibly containing '+' / '#'.</param>
    /// <returns><c>true</c> if the topic matches the filter.</returns>
    internal static bool MqttTopicMatch(string topic, string filter)
    {
        // MQTT 3.1.1 / 5.0 §4.7.2: a filter starting with '#' or '+' must not match a topic that
        // starts with '$' (e.g. "$SYS/..."); such topics match only filters naming that level explicitly.
        if (topic.StartsWith('$') && (filter.StartsWith('#') || filter.StartsWith('+')))
            return false;

        var topicLevels = topic.Split('/');
        var filterLevels = filter.Split('/');

        for (var i = 0; i < filterLevels.Length; i++)
        {
            if (filterLevels[i] == "#")
                return true; // '#' matches everything from here (including the parent level)

            if (i >= topicLevels.Length)
                return false; // filter has more levels than topic

            if (filterLevels[i] != "+" && filterLevels[i] != topicLevels[i])
                return false;
        }

        // Topic must not have more levels than filter (unless filter ended with '#').
        return topicLevels.Length == filterLevels.Length;
    }
}
/// <summary>
/// States of the QoS 2 exactly-once handshake tracked per packet id.
/// Go reference: server/mqtt.go ~line 1300.
/// </summary>
public enum MqttQos2State
{
    /// <summary>A PUBLISH started the flow; the next expected packet is PUBREC.</summary>
    AwaitingPubRec = 0,

    /// <summary>PUBREC has been processed; the next expected packet is PUBREL.</summary>
    AwaitingPubRel = 1,

    /// <summary>PUBREL has been processed; the next expected packet is PUBCOMP.</summary>
    AwaitingPubComp = 2,

    /// <summary>The handshake has finished.</summary>
    Complete = 3
}
/// <summary>
/// Per-packet-id bookkeeping for one in-flight QoS 2 exchange.
/// </summary>
internal sealed class MqttQos2Flow
{
    /// <summary>UTC instant at which the flow was created; fixed after initialization.</summary>
    public DateTime StartedAtUtc { get; init; }

    /// <summary>Current position of the flow in the QoS 2 handshake; mutable as packets arrive.</summary>
    public MqttQos2State State { get; set; }
}
/// <summary>
/// Drives the QoS 2 exactly-once handshake for a connection, moving each packet id through
/// PUBLISH -> PUBREC -> PUBREL -> PUBCOMP.
/// Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp.
/// </summary>
public sealed class MqttQos2StateMachine
{
    private readonly ConcurrentDictionary<ushort, MqttQos2Flow> _flows = new();
    private readonly TimeSpan _timeout;
    private readonly TimeProvider _timeProvider;

    /// <summary>
    /// Creates a state machine with no in-flight flows.
    /// </summary>
    /// <param name="timeout">Age after which an unfinished flow is reported by <see cref="GetTimedOutFlows"/>. Defaults to 30 seconds.</param>
    /// <param name="timeProvider">Clock source; defaults to <see cref="TimeProvider.System"/>. Inject a fake one for testing.</param>
    public MqttQos2StateMachine(TimeSpan? timeout = null, TimeProvider? timeProvider = null)
    {
        _timeProvider = timeProvider ?? TimeProvider.System;
        _timeout = timeout ?? TimeSpan.FromSeconds(30);
    }

    /// <summary>
    /// Starts tracking a new flow for <paramref name="packetId"/> in the awaiting-PUBREC state.
    /// </summary>
    /// <returns><c>false</c> when a flow for this packet id is already tracked (duplicate publish).</returns>
    public bool BeginPublish(ushort packetId) =>
        _flows.TryAdd(packetId, new MqttQos2Flow
        {
            State = MqttQos2State.AwaitingPubRec,
            StartedAtUtc = _timeProvider.GetUtcNow().UtcDateTime,
        });

    /// <summary>
    /// Handles PUBREC: advances the flow from awaiting-PUBREC to awaiting-PUBREL.
    /// </summary>
    /// <returns><c>false</c> when no flow exists or it is not awaiting PUBREC.</returns>
    public bool ProcessPubRec(ushort packetId) =>
        TryAdvance(packetId, MqttQos2State.AwaitingPubRec, MqttQos2State.AwaitingPubRel);

    /// <summary>
    /// Handles PUBREL: advances the flow from awaiting-PUBREL to awaiting-PUBCOMP.
    /// </summary>
    /// <returns><c>false</c> when no flow exists or it is not awaiting PUBREL.</returns>
    public bool ProcessPubRel(ushort packetId) =>
        TryAdvance(packetId, MqttQos2State.AwaitingPubRel, MqttQos2State.AwaitingPubComp);

    /// <summary>
    /// Handles PUBCOMP: marks the flow complete and stops tracking it.
    /// </summary>
    /// <returns><c>false</c> when no flow exists or it is not awaiting PUBCOMP.</returns>
    public bool ProcessPubComp(ushort packetId)
    {
        if (!TryAdvance(packetId, MqttQos2State.AwaitingPubComp, MqttQos2State.Complete))
            return false;

        _flows.TryRemove(packetId, out _);
        return true;
    }

    /// <summary>
    /// Looks up the current state of the flow for <paramref name="packetId"/>.
    /// </summary>
    /// <returns>The state, or <c>null</c> when no flow is tracked for the id.</returns>
    public MqttQos2State? GetState(ushort packetId) =>
        _flows.TryGetValue(packetId, out var flow) ? flow.State : null;

    /// <summary>
    /// Lists the packet ids whose flows have been open longer than the configured timeout.
    /// </summary>
    /// <returns>A new list of expired packet ids (empty when none have expired).</returns>
    public IReadOnlyList<ushort> GetTimedOutFlows()
    {
        var nowUtc = _timeProvider.GetUtcNow().UtcDateTime;
        var expired = new List<ushort>();

        foreach (var (packetId, flow) in _flows)
        {
            var age = nowUtc - flow.StartedAtUtc;
            if (age > _timeout)
                expired.Add(packetId);
        }

        return expired;
    }

    /// <summary>
    /// Stops tracking the flow for <paramref name="packetId"/> (e.g. after timeout cleanup). No-op if absent.
    /// </summary>
    public void RemoveFlow(ushort packetId) => _flows.TryRemove(packetId, out _);

    /// <summary>
    /// Moves the flow for <paramref name="packetId"/> from <paramref name="expected"/> to
    /// <paramref name="next"/>; returns <c>false</c> without changing anything otherwise.
    /// </summary>
    private bool TryAdvance(ushort packetId, MqttQos2State expected, MqttQos2State next)
    {
        if (_flows.TryGetValue(packetId, out var flow) && flow.State == expected)
        {
            flow.State = next;
            return true;
        }

        return false;
    }
}