Files
natsdotnet/src/NATS.Server/Mqtt/MqttRetainedStore.cs
Joseph Doherty b7bac8e68e feat(mqtt): add JetStream-backed session and retained message persistence
Add optional IStreamStore backing to MqttSessionStore and MqttRetainedStore,
enabling session and retained message state to survive process restarts via
JetStream persistence. Includes ConnectAsync/SaveSessionAsync for session
lifecycle, SetRetainedAsync/GetRetainedAsync with cleared-topic tombstone
tracking, and 4 new parity tests covering persist/restart/clear semantics.
2026-02-25 02:42:02 -05:00

314 lines
10 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// MQTT retained message store and QoS 2 state machine.
// Go reference: golang/nats-server/server/mqtt.go
// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600-1700)
// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300-1400)
using System.Collections.Concurrent;
using NATS.Server.JetStream.Storage;
namespace NATS.Server.Mqtt;
/// <summary>
/// A retained message stored for a topic: the topic name paired with the payload bytes.
/// </summary>
/// <param name="Topic">Topic name the payload is retained under.</param>
/// <param name="Payload">
/// Retained payload bytes. NOTE(review): record equality compares this member with
/// <see cref="ReadOnlyMemory{T}"/>'s own equality, which is presumably reference/range based
/// rather than content based — confirm before relying on value equality of two messages.
/// </param>
public sealed record MqttRetainedMessage(string Topic, ReadOnlyMemory<byte> Payload);
/// <summary>
/// Store for MQTT retained messages. Messages live in an in-memory map keyed by topic;
/// when an <see cref="IStreamStore"/> is supplied, non-empty payloads are additionally
/// appended to it so they can be looked up again by a later store instance.
/// Go reference: server/mqtt.go mqttHandleRetainedMsg ~line 1600.
/// </summary>
public sealed class MqttRetainedStore
{
    // Subject prefix under which retained payloads are appended to the backing store.
    private const string RetainedSubjectPrefix = "$MQTT.rmsgs.";

    private readonly ConcurrentDictionary<string, ReadOnlyMemory<byte>> _retained = new(StringComparer.Ordinal);

    // Topics explicitly cleared in this session — prevents falling back to backing store for cleared topics.
    private readonly ConcurrentDictionary<string, bool> _cleared = new(StringComparer.Ordinal);

    private readonly IStreamStore? _backingStore;

    /// <summary>Backing store for JetStream persistence, or null when the store is memory-only.</summary>
    public IStreamStore? BackingStore => _backingStore;

    /// <summary>
    /// Initializes a new in-memory retained message store with no backing store.
    /// </summary>
    public MqttRetainedStore() : this(null) { }

    /// <summary>
    /// Initializes a new retained message store with an optional JetStream backing store.
    /// </summary>
    /// <param name="backingStore">Optional JetStream stream store for persistence.</param>
    public MqttRetainedStore(IStreamStore? backingStore)
    {
        _backingStore = backingStore;
    }

    /// <summary>
    /// Sets (or clears) the in-memory retained message for a topic.
    /// An empty payload removes the retained message and marks the topic as cleared;
    /// a non-empty payload replaces any previous one and removes the cleared mark.
    /// Go reference: server/mqtt.go mqttHandleRetainedMsg.
    /// </summary>
    /// <param name="topic">Topic name the payload is retained under.</param>
    /// <param name="payload">
    /// Payload to retain. The memory is stored as given (it is not copied), so the caller
    /// must not reuse the underlying buffer while the message is retained.
    /// </param>
    public void SetRetained(string topic, ReadOnlyMemory<byte> payload)
    {
        if (payload.IsEmpty)
        {
            _retained.TryRemove(topic, out _);
            _cleared[topic] = true;
            return;
        }

        _cleared.TryRemove(topic, out _);
        _retained[topic] = payload;
    }

    /// <summary>
    /// Gets the in-memory retained payload for a topic. The backing store is not consulted.
    /// </summary>
    /// <param name="topic">Exact topic name (no wildcards).</param>
    /// <returns>The retained payload, or null if none is held in memory.</returns>
    public ReadOnlyMemory<byte>? GetRetained(string topic)
    {
        if (_retained.TryGetValue(topic, out var payload))
            return payload;

        return null;
    }

    /// <summary>
    /// Returns all in-memory retained messages whose topic matches an MQTT topic filter.
    /// Supports '+' (single-level) and '#' (multi-level) wildcards.
    /// The backing store is not consulted.
    /// Go reference: server/mqtt.go mqttGetRetainedMessages ~line 1650.
    /// </summary>
    /// <param name="filter">MQTT topic filter, levels separated by '/'.</param>
    /// <returns>A new list of matching messages; empty when nothing matches.</returns>
    public IReadOnlyList<MqttRetainedMessage> GetMatchingRetained(string filter)
    {
        var results = new List<MqttRetainedMessage>();
        foreach (var kvp in _retained)
        {
            if (MqttTopicMatch(kvp.Key, filter))
                results.Add(new MqttRetainedMessage(kvp.Key, kvp.Value));
        }

        return results;
    }

    /// <summary>
    /// Sets (or clears) the retained message and, for a non-empty payload, appends it to the
    /// backing store when one is configured.
    /// Go reference: server/mqtt.go mqttHandleRetainedMsg with JetStream.
    /// </summary>
    /// <param name="topic">Topic name the payload is retained under.</param>
    /// <param name="payload">Payload to retain; empty clears the topic.</param>
    /// <param name="ct">Cancellation token passed to the backing store.</param>
    public async Task SetRetainedAsync(string topic, ReadOnlyMemory<byte> payload, CancellationToken ct = default)
    {
        SetRetained(topic, payload);

        // Clear — the in-memory clear above is sufficient for this implementation.
        // A full implementation would publish a tombstone to JetStream; until then a clear
        // is only remembered by this instance and nothing is written to the backing store.
        if (_backingStore is null || payload.IsEmpty)
            return;

        await _backingStore.AppendAsync(SubjectFor(topic), payload, ct);
    }

    /// <summary>
    /// Gets the retained message, checking the backing store if it is not in memory.
    /// Returns null if the topic was explicitly cleared in this session.
    /// </summary>
    /// <param name="topic">Exact topic name (no wildcards).</param>
    /// <param name="ct">Cancellation token passed to the backing store.</param>
    /// <returns>A copy of the retained payload, or null if none is found.</returns>
    public async Task<byte[]?> GetRetainedAsync(string topic, CancellationToken ct = default)
    {
        var mem = GetRetained(topic);
        if (mem.HasValue)
            return mem.Value.ToArray();

        // Don't consult the backing store if this topic was explicitly cleared in this session.
        if (_cleared.ContainsKey(topic) || _backingStore is null)
            return null;

        var subject = SubjectFor(topic);
        byte[]? latest = null;

        // Every update appends a new message, so a topic can have several persisted payloads.
        // The retained message is the most recent one: keep scanning and let later matches
        // overwrite earlier ones instead of returning the first (oldest) hit.
        // NOTE(review): relies on ListAsync yielding messages in append order — confirm.
        var messages = await _backingStore.ListAsync(ct);
        foreach (var msg in messages)
        {
            if (msg.Subject == subject)
                latest = msg.Payload.ToArray();
        }

        return latest;
    }

    /// <summary>
    /// Matches an MQTT topic against a filter pattern, comparing '/'-separated levels.
    /// '+' matches exactly one level, '#' matches zero or more levels (must be last).
    /// The filter is not validated: a '#' found at any level ends the comparison with a match.
    /// </summary>
    /// <param name="topic">Concrete topic name.</param>
    /// <param name="filter">Topic filter, possibly containing wildcards.</param>
    /// <returns>True when the topic is covered by the filter.</returns>
    internal static bool MqttTopicMatch(string topic, string filter)
    {
        var topicLevels = topic.Split('/');
        var filterLevels = filter.Split('/');

        for (var i = 0; i < filterLevels.Length; i++)
        {
            if (filterLevels[i] == "#")
                return true; // '#' matches everything from here

            if (i >= topicLevels.Length)
                return false; // filter has more levels than topic

            if (filterLevels[i] != "+" && filterLevels[i] != topicLevels[i])
                return false;
        }

        // Topic must not have more levels than filter (unless filter ended with '#')
        return topicLevels.Length == filterLevels.Length;
    }

    // Builds the backing-store subject for a topic so writers and readers always agree on it.
    private static string SubjectFor(string topic) => RetainedSubjectPrefix + topic;
}
/// <summary>
/// Stages of a QoS 2 (exactly-once) delivery handshake, in the order they are visited.
/// Go reference: server/mqtt.go ~line 1300.
/// </summary>
public enum MqttQos2State
{
    /// <summary>PUBLISH has been seen; the next expected packet is PUBREC.</summary>
    AwaitingPubRec = 0,

    /// <summary>PUBREC has been seen; the next expected packet is PUBREL.</summary>
    AwaitingPubRel = 1,

    /// <summary>PUBREL has been seen; the next expected packet is PUBCOMP.</summary>
    AwaitingPubComp = 2,

    /// <summary>PUBCOMP has been seen; the handshake is finished.</summary>
    Complete = 3,
}
/// <summary>
/// Mutable record of one in-flight QoS 2 exchange: when it began and how far it has progressed.
/// One instance is kept per packet ID.
/// </summary>
internal sealed class MqttQos2Flow
{
    /// <summary>UTC time at which the flow was created; set once at initialization.</summary>
    public DateTime StartedAtUtc { get; init; }

    /// <summary>Current stage of the handshake; updated as each acknowledgement is processed.</summary>
    public MqttQos2State State { get; set; }
}
/// <summary>
/// Per-connection bookkeeping for the QoS 2 exactly-once handshake.
/// Every packet ID owns one flow that advances a single step per call:
/// PUBLISH -> PUBREC -> PUBREL -> PUBCOMP.
/// Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp.
/// </summary>
public sealed class MqttQos2StateMachine
{
    private readonly ConcurrentDictionary<ushort, MqttQos2Flow> _flows = new();
    private readonly TimeSpan _timeout;
    private readonly TimeProvider _timeProvider;

    /// <summary>
    /// Creates a state machine with no flows in progress.
    /// </summary>
    /// <param name="timeout">Age after which an unfinished flow is reported as timed out. Defaults to 30 seconds.</param>
    /// <param name="timeProvider">Clock used for timestamps; defaults to <see cref="TimeProvider.System"/>.</param>
    public MqttQos2StateMachine(TimeSpan? timeout = null, TimeProvider? timeProvider = null)
    {
        _timeProvider = timeProvider ?? TimeProvider.System;
        _timeout = timeout ?? TimeSpan.FromSeconds(30);
    }

    /// <summary>
    /// Registers a fresh flow for <paramref name="packetId"/>, stamped with the current UTC time
    /// and waiting for PUBREC.
    /// </summary>
    /// <returns>False when the packet ID already has a flow (duplicate publish); otherwise true.</returns>
    public bool BeginPublish(ushort packetId)
    {
        var startedAt = _timeProvider.GetUtcNow().UtcDateTime;
        var created = new MqttQos2Flow
        {
            StartedAtUtc = startedAt,
            State = MqttQos2State.AwaitingPubRec,
        };

        return _flows.TryAdd(packetId, created);
    }

    /// <summary>
    /// Handles PUBREC: moves the flow from awaiting-PUBREC to awaiting-PUBREL.
    /// </summary>
    /// <returns>False when no flow exists or it is at a different stage.</returns>
    public bool ProcessPubRec(ushort packetId) =>
        TryAdvance(packetId, MqttQos2State.AwaitingPubRec, MqttQos2State.AwaitingPubRel);

    /// <summary>
    /// Handles PUBREL: moves the flow from awaiting-PUBREL to awaiting-PUBCOMP.
    /// </summary>
    /// <returns>False when no flow exists or it is at a different stage.</returns>
    public bool ProcessPubRel(ushort packetId) =>
        TryAdvance(packetId, MqttQos2State.AwaitingPubRel, MqttQos2State.AwaitingPubComp);

    /// <summary>
    /// Handles PUBCOMP: marks the flow complete and drops it from the table.
    /// </summary>
    /// <returns>False when no flow exists or it is at a different stage.</returns>
    public bool ProcessPubComp(ushort packetId)
    {
        var advanced = TryAdvance(packetId, MqttQos2State.AwaitingPubComp, MqttQos2State.Complete);
        if (advanced)
            _flows.TryRemove(packetId, out _);

        return advanced;
    }

    /// <summary>
    /// Looks up the stage a packet ID is currently at.
    /// </summary>
    /// <returns>The flow's state, or null when the packet ID has no flow.</returns>
    public MqttQos2State? GetState(ushort packetId)
    {
        return _flows.TryGetValue(packetId, out var flow)
            ? flow.State
            : (MqttQos2State?)null;
    }

    /// <summary>
    /// Collects the packet IDs of flows whose age is strictly greater than the configured timeout.
    /// Flows are only reported, not removed.
    /// </summary>
    public IReadOnlyList<ushort> GetTimedOutFlows()
    {
        var expired = new List<ushort>();
        var now = _timeProvider.GetUtcNow().UtcDateTime;

        foreach (var (packetId, flow) in _flows)
        {
            var age = now - flow.StartedAtUtc;
            if (age <= _timeout)
                continue;

            expired.Add(packetId);
        }

        return expired;
    }

    /// <summary>
    /// Drops the flow for a packet ID if one exists (e.g., after timeout cleanup).
    /// </summary>
    public void RemoveFlow(ushort packetId)
    {
        _flows.TryRemove(packetId, out _);
    }

    // Moves a flow from `expected` to `next`; leaves it untouched and reports false when the
    // packet ID is unknown or the flow is currently at any other stage.
    private bool TryAdvance(ushort packetId, MqttQos2State expected, MqttQos2State next)
    {
        if (_flows.TryGetValue(packetId, out var flow) && flow.State == expected)
        {
            flow.State = next;
            return true;
        }

        return false;
    }
}