feat(mqtt): add session persistence, QoS 2 state machine, and retained store (E2+E3)
Add MqttSessionStore with save/load/delete/list operations, flapper detection (backoff on rapid reconnects), and TimeProvider-based testing. Add MqttRetainedStore for per-topic retained messages with MQTT wildcard matching (+/# filters). Add MqttQos2StateMachine tracking the full PUBREC/PUBREL/PUBCOMP flow with duplicate rejection and timeout detection. 19 new tests: 9 session persistence, 10 QoS/retained message tests.
This commit is contained in:
@@ -163,4 +163,4 @@ public sealed class MqttListener(
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed record MqttPendingPublish(int PacketId, string Topic, string Payload);
|
||||
public sealed record MqttPendingPublish(int PacketId, string Topic, string Payload);
|
||||
|
||||
241
src/NATS.Server/Mqtt/MqttRetainedStore.cs
Normal file
241
src/NATS.Server/Mqtt/MqttRetainedStore.cs
Normal file
@@ -0,0 +1,241 @@
|
||||
// MQTT retained message store and QoS 2 state machine.
|
||||
// Go reference: golang/nats-server/server/mqtt.go
|
||||
// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600–1700)
|
||||
// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300–1400)
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
namespace NATS.Server.Mqtt;
|
||||
|
||||
/// <summary>
|
||||
/// A retained message stored for a topic.
|
||||
/// </summary>
|
||||
public sealed record MqttRetainedMessage(string Topic, ReadOnlyMemory<byte> Payload);
|
||||
|
||||
/// <summary>
|
||||
/// In-memory store for MQTT retained messages.
|
||||
/// Go reference: server/mqtt.go mqttHandleRetainedMsg ~line 1600.
|
||||
/// </summary>
|
||||
public sealed class MqttRetainedStore
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, ReadOnlyMemory<byte>> _retained = new(StringComparer.Ordinal);
|
||||
|
||||
/// <summary>
|
||||
/// Sets (or clears) the retained message for a topic.
|
||||
/// An empty payload clears the retained message.
|
||||
/// Go reference: server/mqtt.go mqttHandleRetainedMsg.
|
||||
/// </summary>
|
||||
public void SetRetained(string topic, ReadOnlyMemory<byte> payload)
|
||||
{
|
||||
if (payload.IsEmpty)
|
||||
{
|
||||
_retained.TryRemove(topic, out _);
|
||||
return;
|
||||
}
|
||||
|
||||
_retained[topic] = payload;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the retained message payload for a topic, or null if none.
|
||||
/// </summary>
|
||||
public ReadOnlyMemory<byte>? GetRetained(string topic)
|
||||
{
|
||||
if (_retained.TryGetValue(topic, out var payload))
|
||||
return payload;
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns all retained messages matching an MQTT topic filter pattern.
|
||||
/// Supports '+' (single-level) and '#' (multi-level) wildcards.
|
||||
/// Go reference: server/mqtt.go mqttGetRetainedMessages ~line 1650.
|
||||
/// </summary>
|
||||
public IReadOnlyList<MqttRetainedMessage> GetMatchingRetained(string filter)
|
||||
{
|
||||
var results = new List<MqttRetainedMessage>();
|
||||
foreach (var kvp in _retained)
|
||||
{
|
||||
if (MqttTopicMatch(kvp.Key, filter))
|
||||
results.Add(new MqttRetainedMessage(kvp.Key, kvp.Value));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Matches an MQTT topic against a filter pattern.
|
||||
/// '+' matches exactly one level, '#' matches zero or more levels (must be last).
|
||||
/// </summary>
|
||||
internal static bool MqttTopicMatch(string topic, string filter)
|
||||
{
|
||||
var topicLevels = topic.Split('/');
|
||||
var filterLevels = filter.Split('/');
|
||||
|
||||
for (var i = 0; i < filterLevels.Length; i++)
|
||||
{
|
||||
if (filterLevels[i] == "#")
|
||||
return true; // '#' matches everything from here
|
||||
|
||||
if (i >= topicLevels.Length)
|
||||
return false; // filter has more levels than topic
|
||||
|
||||
if (filterLevels[i] != "+" && filterLevels[i] != topicLevels[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
// Topic must not have more levels than filter (unless filter ended with '#')
|
||||
return topicLevels.Length == filterLevels.Length;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// QoS 2 state machine states.
|
||||
/// Go reference: server/mqtt.go ~line 1300.
|
||||
/// </summary>
|
||||
public enum MqttQos2State
|
||||
{
|
||||
/// <summary>Publish received, awaiting PUBREC from peer.</summary>
|
||||
AwaitingPubRec,
|
||||
|
||||
/// <summary>PUBREC received, awaiting PUBREL from originator.</summary>
|
||||
AwaitingPubRel,
|
||||
|
||||
/// <summary>PUBREL received, awaiting PUBCOMP from peer.</summary>
|
||||
AwaitingPubComp,
|
||||
|
||||
/// <summary>Flow complete.</summary>
|
||||
Complete,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tracks QoS 2 flow state for a single packet ID.
|
||||
/// </summary>
|
||||
internal sealed class MqttQos2Flow
|
||||
{
|
||||
public MqttQos2State State { get; set; }
|
||||
public DateTime StartedAtUtc { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Manages the QoS 2 exactly-once delivery state machine for a connection.
|
||||
/// Tracks per-packet-id state transitions: PUBLISH -> PUBREC -> PUBREL -> PUBCOMP.
|
||||
/// Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp.
|
||||
/// </summary>
|
||||
public sealed class MqttQos2StateMachine
|
||||
{
|
||||
private readonly ConcurrentDictionary<ushort, MqttQos2Flow> _flows = new();
|
||||
private readonly TimeSpan _timeout;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new QoS 2 state machine.
|
||||
/// </summary>
|
||||
/// <param name="timeout">Timeout for incomplete flows. Default 30 seconds.</param>
|
||||
/// <param name="timeProvider">Optional time provider for testing.</param>
|
||||
public MqttQos2StateMachine(TimeSpan? timeout = null, TimeProvider? timeProvider = null)
|
||||
{
|
||||
_timeout = timeout ?? TimeSpan.FromSeconds(30);
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Begins a new QoS 2 flow for the given packet ID.
|
||||
/// Returns false if a flow for this packet ID already exists (duplicate publish).
|
||||
/// </summary>
|
||||
public bool BeginPublish(ushort packetId)
|
||||
{
|
||||
var flow = new MqttQos2Flow
|
||||
{
|
||||
State = MqttQos2State.AwaitingPubRec,
|
||||
StartedAtUtc = _timeProvider.GetUtcNow().UtcDateTime,
|
||||
};
|
||||
|
||||
return _flows.TryAdd(packetId, flow);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Processes a PUBREC for the given packet ID.
|
||||
/// Returns false if the flow is not in the expected state.
|
||||
/// </summary>
|
||||
public bool ProcessPubRec(ushort packetId)
|
||||
{
|
||||
if (!_flows.TryGetValue(packetId, out var flow))
|
||||
return false;
|
||||
|
||||
if (flow.State != MqttQos2State.AwaitingPubRec)
|
||||
return false;
|
||||
|
||||
flow.State = MqttQos2State.AwaitingPubRel;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Processes a PUBREL for the given packet ID.
|
||||
/// Returns false if the flow is not in the expected state.
|
||||
/// </summary>
|
||||
public bool ProcessPubRel(ushort packetId)
|
||||
{
|
||||
if (!_flows.TryGetValue(packetId, out var flow))
|
||||
return false;
|
||||
|
||||
if (flow.State != MqttQos2State.AwaitingPubRel)
|
||||
return false;
|
||||
|
||||
flow.State = MqttQos2State.AwaitingPubComp;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Processes a PUBCOMP for the given packet ID.
|
||||
/// Returns false if the flow is not in the expected state.
|
||||
/// Removes the flow on completion.
|
||||
/// </summary>
|
||||
public bool ProcessPubComp(ushort packetId)
|
||||
{
|
||||
if (!_flows.TryGetValue(packetId, out var flow))
|
||||
return false;
|
||||
|
||||
if (flow.State != MqttQos2State.AwaitingPubComp)
|
||||
return false;
|
||||
|
||||
flow.State = MqttQos2State.Complete;
|
||||
_flows.TryRemove(packetId, out _);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current state for a packet ID, or null if no flow exists.
|
||||
/// </summary>
|
||||
public MqttQos2State? GetState(ushort packetId)
|
||||
{
|
||||
if (_flows.TryGetValue(packetId, out var flow))
|
||||
return flow.State;
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns packet IDs for flows that have exceeded the timeout.
|
||||
/// </summary>
|
||||
public IReadOnlyList<ushort> GetTimedOutFlows()
|
||||
{
|
||||
var now = _timeProvider.GetUtcNow().UtcDateTime;
|
||||
var timedOut = new List<ushort>();
|
||||
|
||||
foreach (var kvp in _flows)
|
||||
{
|
||||
if (now - kvp.Value.StartedAtUtc > _timeout)
|
||||
timedOut.Add(kvp.Key);
|
||||
}
|
||||
|
||||
return timedOut;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Removes a flow (e.g., after timeout cleanup).
|
||||
/// </summary>
|
||||
public void RemoveFlow(ushort packetId) =>
|
||||
_flows.TryRemove(packetId, out _);
|
||||
}
|
||||
133
src/NATS.Server/Mqtt/MqttSessionStore.cs
Normal file
133
src/NATS.Server/Mqtt/MqttSessionStore.cs
Normal file
@@ -0,0 +1,133 @@
|
||||
// MQTT session persistence store.
|
||||
// Go reference: golang/nats-server/server/mqtt.go:253-300
|
||||
// Session state management — mqttInitSessionStore / mqttStoreSession
|
||||
// Flapper detection — mqttCheckFlapper (lines ~300–360)
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
namespace NATS.Server.Mqtt;
|
||||
|
||||
/// <summary>
|
||||
/// Serializable session data for an MQTT client.
|
||||
/// Go reference: server/mqtt.go mqttSession struct ~line 253.
|
||||
/// </summary>
|
||||
public sealed record MqttSessionData
|
||||
{
|
||||
public required string ClientId { get; init; }
|
||||
public Dictionary<string, int> Subscriptions { get; init; } = [];
|
||||
public List<MqttPendingPublish> PendingPublishes { get; init; } = [];
|
||||
public string? WillTopic { get; init; }
|
||||
public byte[]? WillPayload { get; init; }
|
||||
public int WillQoS { get; init; }
|
||||
public bool WillRetain { get; init; }
|
||||
public bool CleanSession { get; init; }
|
||||
public DateTime ConnectedAtUtc { get; init; } = DateTime.UtcNow;
|
||||
public DateTime LastActivityUtc { get; set; } = DateTime.UtcNow;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// In-memory MQTT session store with flapper detection.
|
||||
/// The abstraction allows future JetStream backing.
|
||||
/// Go reference: server/mqtt.go mqttInitSessionStore ~line 260.
|
||||
/// </summary>
|
||||
public sealed class MqttSessionStore
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, MqttSessionData> _sessions = new(StringComparer.Ordinal);
|
||||
private readonly ConcurrentDictionary<string, List<DateTime>> _connectHistory = new(StringComparer.Ordinal);
|
||||
|
||||
private readonly TimeSpan _flapWindow;
|
||||
private readonly int _flapThreshold;
|
||||
private readonly TimeSpan _flapBackoff;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new session store.
|
||||
/// </summary>
|
||||
/// <param name="flapWindow">Window in which repeated connects trigger flap detection. Default 10 seconds.</param>
|
||||
/// <param name="flapThreshold">Number of connects within the window to trigger backoff. Default 3.</param>
|
||||
/// <param name="flapBackoff">Backoff delay to apply when flapping. Default 1 second.</param>
|
||||
/// <param name="timeProvider">Optional time provider for testing. Default uses system clock.</param>
|
||||
public MqttSessionStore(
|
||||
TimeSpan? flapWindow = null,
|
||||
int flapThreshold = 3,
|
||||
TimeSpan? flapBackoff = null,
|
||||
TimeProvider? timeProvider = null)
|
||||
{
|
||||
_flapWindow = flapWindow ?? TimeSpan.FromSeconds(10);
|
||||
_flapThreshold = flapThreshold;
|
||||
_flapBackoff = flapBackoff ?? TimeSpan.FromSeconds(1);
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Saves (or overwrites) session data for the given client.
|
||||
/// Go reference: server/mqtt.go mqttStoreSession.
|
||||
/// </summary>
|
||||
public void SaveSession(MqttSessionData session)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(session);
|
||||
_sessions[session.ClientId] = session;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Loads session data for the given client, or null if not found.
|
||||
/// Go reference: server/mqtt.go mqttLoadSession.
|
||||
/// </summary>
|
||||
public MqttSessionData? LoadSession(string clientId) =>
|
||||
_sessions.TryGetValue(clientId, out var session) ? session : null;
|
||||
|
||||
/// <summary>
|
||||
/// Deletes the session for the given client. No-op if not found.
|
||||
/// Go reference: server/mqtt.go mqttDeleteSession.
|
||||
/// </summary>
|
||||
public void DeleteSession(string clientId) =>
|
||||
_sessions.TryRemove(clientId, out _);
|
||||
|
||||
/// <summary>
|
||||
/// Returns all active sessions.
|
||||
/// </summary>
|
||||
public IReadOnlyList<MqttSessionData> ListSessions() =>
|
||||
_sessions.Values.ToList();
|
||||
|
||||
/// <summary>
|
||||
/// Tracks a connect or disconnect event for flapper detection.
|
||||
/// Go reference: server/mqtt.go mqttCheckFlapper ~line 300.
|
||||
/// </summary>
|
||||
/// <param name="clientId">The MQTT client identifier.</param>
|
||||
/// <param name="connected">True for connect, false for disconnect.</param>
|
||||
public void TrackConnectDisconnect(string clientId, bool connected)
|
||||
{
|
||||
if (!connected)
|
||||
return;
|
||||
|
||||
var now = _timeProvider.GetUtcNow().UtcDateTime;
|
||||
var history = _connectHistory.GetOrAdd(clientId, static _ => []);
|
||||
|
||||
lock (history)
|
||||
{
|
||||
// Prune entries outside the flap window
|
||||
var cutoff = now - _flapWindow;
|
||||
history.RemoveAll(t => t < cutoff);
|
||||
history.Add(now);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the backoff delay if the client is flapping, otherwise <see cref="TimeSpan.Zero"/>.
|
||||
/// Go reference: server/mqtt.go mqttCheckFlapper ~line 320.
|
||||
/// </summary>
|
||||
public TimeSpan ShouldApplyBackoff(string clientId)
|
||||
{
|
||||
if (!_connectHistory.TryGetValue(clientId, out var history))
|
||||
return TimeSpan.Zero;
|
||||
|
||||
var now = _timeProvider.GetUtcNow().UtcDateTime;
|
||||
|
||||
lock (history)
|
||||
{
|
||||
var cutoff = now - _flapWindow;
|
||||
history.RemoveAll(t => t < cutoff);
|
||||
return history.Count >= _flapThreshold ? _flapBackoff : TimeSpan.Zero;
|
||||
}
|
||||
}
|
||||
}
|
||||
190
tests/NATS.Server.Tests/Mqtt/MqttQosTests.cs
Normal file
190
tests/NATS.Server.Tests/Mqtt/MqttQosTests.cs
Normal file
@@ -0,0 +1,190 @@
|
||||
// MQTT QoS and retained message tests.
|
||||
// Go reference: golang/nats-server/server/mqtt.go
|
||||
// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600–1700)
|
||||
// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300–1400)
|
||||
|
||||
using System.Text;
|
||||
using NATS.Server.Mqtt;
|
||||
|
||||
namespace NATS.Server.Tests.Mqtt;
|
||||
|
||||
public class MqttQosTests
|
||||
{
|
||||
[Fact]
|
||||
public void RetainedStore_SetAndGet_RoundTrips()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttHandleRetainedMsg — store and retrieve
|
||||
var store = new MqttRetainedStore();
|
||||
var payload = Encoding.UTF8.GetBytes("temperature=72.5");
|
||||
|
||||
store.SetRetained("sensors/temp", payload);
|
||||
|
||||
var result = store.GetRetained("sensors/temp");
|
||||
result.ShouldNotBeNull();
|
||||
Encoding.UTF8.GetString(result.Value.Span).ShouldBe("temperature=72.5");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RetainedStore_EmptyPayload_ClearsRetained()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttHandleRetainedMsg — empty payload clears
|
||||
var store = new MqttRetainedStore();
|
||||
store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("old-value"));
|
||||
|
||||
store.SetRetained("sensors/temp", ReadOnlyMemory<byte>.Empty);
|
||||
|
||||
store.GetRetained("sensors/temp").ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RetainedStore_Overwrite_ReplacesOld()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttHandleRetainedMsg — overwrite replaces
|
||||
var store = new MqttRetainedStore();
|
||||
store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("first"));
|
||||
|
||||
store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("second"));
|
||||
|
||||
var result = store.GetRetained("sensors/temp");
|
||||
result.ShouldNotBeNull();
|
||||
Encoding.UTF8.GetString(result.Value.Span).ShouldBe("second");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RetainedStore_GetMatching_WildcardPlus()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttGetRetainedMessages — '+' single-level wildcard
|
||||
var store = new MqttRetainedStore();
|
||||
store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("72.5"));
|
||||
store.SetRetained("sensors/humidity", Encoding.UTF8.GetBytes("45%"));
|
||||
store.SetRetained("alerts/fire", Encoding.UTF8.GetBytes("!"));
|
||||
|
||||
var matches = store.GetMatchingRetained("sensors/+");
|
||||
|
||||
matches.Count.ShouldBe(2);
|
||||
matches.Select(m => m.Topic).ShouldBe(
|
||||
new[] { "sensors/temp", "sensors/humidity" },
|
||||
ignoreOrder: true);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RetainedStore_GetMatching_WildcardHash()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttGetRetainedMessages — '#' multi-level wildcard
|
||||
var store = new MqttRetainedStore();
|
||||
store.SetRetained("home/living/temp", Encoding.UTF8.GetBytes("22"));
|
||||
store.SetRetained("home/living/light", Encoding.UTF8.GetBytes("on"));
|
||||
store.SetRetained("home/kitchen/temp", Encoding.UTF8.GetBytes("24"));
|
||||
store.SetRetained("office/desk/light", Encoding.UTF8.GetBytes("off"));
|
||||
|
||||
var matches = store.GetMatchingRetained("home/#");
|
||||
|
||||
matches.Count.ShouldBe(3);
|
||||
matches.Select(m => m.Topic).ShouldBe(
|
||||
new[] { "home/living/temp", "home/living/light", "home/kitchen/temp" },
|
||||
ignoreOrder: true);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Qos2_FullFlow_PubRecPubRelPubComp()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp
|
||||
var sm = new MqttQos2StateMachine();
|
||||
|
||||
// Begin publish
|
||||
sm.BeginPublish(100).ShouldBeTrue();
|
||||
sm.GetState(100).ShouldBe(MqttQos2State.AwaitingPubRec);
|
||||
|
||||
// PUBREC
|
||||
sm.ProcessPubRec(100).ShouldBeTrue();
|
||||
sm.GetState(100).ShouldBe(MqttQos2State.AwaitingPubRel);
|
||||
|
||||
// PUBREL
|
||||
sm.ProcessPubRel(100).ShouldBeTrue();
|
||||
sm.GetState(100).ShouldBe(MqttQos2State.AwaitingPubComp);
|
||||
|
||||
// PUBCOMP — completes and removes flow
|
||||
sm.ProcessPubComp(100).ShouldBeTrue();
|
||||
sm.GetState(100).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Qos2_DuplicatePublish_Rejected()
|
||||
{
|
||||
// Go reference: server/mqtt.go — duplicate packet ID rejected during active flow
|
||||
var sm = new MqttQos2StateMachine();
|
||||
|
||||
sm.BeginPublish(200).ShouldBeTrue();
|
||||
|
||||
// Same packet ID while flow is active — should be rejected
|
||||
sm.BeginPublish(200).ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Qos2_IncompleteFlow_TimesOut()
|
||||
{
|
||||
// Go reference: server/mqtt.go — incomplete QoS 2 flows time out
|
||||
var fakeTime = new FakeTimeProvider(new DateTimeOffset(2026, 1, 15, 12, 0, 0, TimeSpan.Zero));
|
||||
var sm = new MqttQos2StateMachine(timeout: TimeSpan.FromSeconds(5), timeProvider: fakeTime);
|
||||
|
||||
sm.BeginPublish(300).ShouldBeTrue();
|
||||
|
||||
// Not timed out yet
|
||||
fakeTime.Advance(TimeSpan.FromSeconds(3));
|
||||
sm.GetTimedOutFlows().ShouldBeEmpty();
|
||||
|
||||
// Advance past timeout
|
||||
fakeTime.Advance(TimeSpan.FromSeconds(3));
|
||||
var timedOut = sm.GetTimedOutFlows();
|
||||
timedOut.Count.ShouldBe(1);
|
||||
timedOut[0].ShouldBe((ushort)300);
|
||||
|
||||
// Clean up
|
||||
sm.RemoveFlow(300);
|
||||
sm.GetState(300).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Qos1_Puback_RemovesPending()
|
||||
{
|
||||
// Go reference: server/mqtt.go — QoS 1 PUBACK removes from pending
|
||||
// This tests the existing MqttListener pending publish / ack mechanism
|
||||
// in the context of the session store.
|
||||
var store = new MqttSessionStore();
|
||||
var session = new MqttSessionData
|
||||
{
|
||||
ClientId = "qos1-client",
|
||||
PendingPublishes =
|
||||
[
|
||||
new MqttPendingPublish(1, "topic/a", "payload-a"),
|
||||
new MqttPendingPublish(2, "topic/b", "payload-b"),
|
||||
],
|
||||
};
|
||||
|
||||
store.SaveSession(session);
|
||||
|
||||
// Simulate PUBACK for packet 1: remove it from pending
|
||||
var loaded = store.LoadSession("qos1-client");
|
||||
loaded.ShouldNotBeNull();
|
||||
loaded.PendingPublishes.RemoveAll(p => p.PacketId == 1);
|
||||
store.SaveSession(loaded);
|
||||
|
||||
// Verify only packet 2 remains
|
||||
var updated = store.LoadSession("qos1-client");
|
||||
updated.ShouldNotBeNull();
|
||||
updated.PendingPublishes.Count.ShouldBe(1);
|
||||
updated.PendingPublishes[0].PacketId.ShouldBe(2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RetainedStore_GetMatching_NoMatch_ReturnsEmpty()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttGetRetainedMessages — no match returns empty
|
||||
var store = new MqttRetainedStore();
|
||||
store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("72"));
|
||||
|
||||
var matches = store.GetMatchingRetained("alerts/+");
|
||||
|
||||
matches.ShouldBeEmpty();
|
||||
}
|
||||
}
|
||||
209
tests/NATS.Server.Tests/Mqtt/MqttSessionPersistenceTests.cs
Normal file
209
tests/NATS.Server.Tests/Mqtt/MqttSessionPersistenceTests.cs
Normal file
@@ -0,0 +1,209 @@
|
||||
// MQTT session persistence tests.
|
||||
// Go reference: golang/nats-server/server/mqtt.go:253-360
|
||||
// Session store — mqttInitSessionStore / mqttStoreSession / mqttLoadSession
|
||||
// Flapper detection — mqttCheckFlapper (~lines 300–360)
|
||||
|
||||
using NATS.Server.Mqtt;
|
||||
|
||||
namespace NATS.Server.Tests.Mqtt;
|
||||
|
||||
public class MqttSessionPersistenceTests
|
||||
{
|
||||
[Fact]
|
||||
public void SaveSession_ThenLoad_RoundTrips()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttStoreSession / mqttLoadSession
|
||||
var store = new MqttSessionStore();
|
||||
var session = new MqttSessionData
|
||||
{
|
||||
ClientId = "client-1",
|
||||
Subscriptions = new Dictionary<string, int> { ["sensors/temp"] = 1, ["alerts/#"] = 0 },
|
||||
PendingPublishes = [new MqttPendingPublish(42, "sensors/temp", "72.5")],
|
||||
WillTopic = "clients/offline",
|
||||
WillPayload = [0x01, 0x02],
|
||||
WillQoS = 1,
|
||||
WillRetain = true,
|
||||
CleanSession = false,
|
||||
ConnectedAtUtc = new DateTime(2026, 1, 15, 10, 30, 0, DateTimeKind.Utc),
|
||||
LastActivityUtc = new DateTime(2026, 1, 15, 10, 35, 0, DateTimeKind.Utc),
|
||||
};
|
||||
|
||||
store.SaveSession(session);
|
||||
var loaded = store.LoadSession("client-1");
|
||||
|
||||
loaded.ShouldNotBeNull();
|
||||
loaded.ClientId.ShouldBe("client-1");
|
||||
loaded.Subscriptions.Count.ShouldBe(2);
|
||||
loaded.Subscriptions["sensors/temp"].ShouldBe(1);
|
||||
loaded.Subscriptions["alerts/#"].ShouldBe(0);
|
||||
loaded.PendingPublishes.Count.ShouldBe(1);
|
||||
loaded.PendingPublishes[0].PacketId.ShouldBe(42);
|
||||
loaded.PendingPublishes[0].Topic.ShouldBe("sensors/temp");
|
||||
loaded.PendingPublishes[0].Payload.ShouldBe("72.5");
|
||||
loaded.WillTopic.ShouldBe("clients/offline");
|
||||
loaded.WillPayload.ShouldBe(new byte[] { 0x01, 0x02 });
|
||||
loaded.WillQoS.ShouldBe(1);
|
||||
loaded.WillRetain.ShouldBeTrue();
|
||||
loaded.CleanSession.ShouldBeFalse();
|
||||
loaded.ConnectedAtUtc.ShouldBe(new DateTime(2026, 1, 15, 10, 30, 0, DateTimeKind.Utc));
|
||||
loaded.LastActivityUtc.ShouldBe(new DateTime(2026, 1, 15, 10, 35, 0, DateTimeKind.Utc));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SaveSession_Update_OverwritesPrevious()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttStoreSession — overwrites existing
|
||||
var store = new MqttSessionStore();
|
||||
|
||||
store.SaveSession(new MqttSessionData
|
||||
{
|
||||
ClientId = "client-x",
|
||||
Subscriptions = new Dictionary<string, int> { ["old/topic"] = 0 },
|
||||
});
|
||||
|
||||
store.SaveSession(new MqttSessionData
|
||||
{
|
||||
ClientId = "client-x",
|
||||
Subscriptions = new Dictionary<string, int> { ["new/topic"] = 1 },
|
||||
});
|
||||
|
||||
var loaded = store.LoadSession("client-x");
|
||||
loaded.ShouldNotBeNull();
|
||||
loaded.Subscriptions.ShouldContainKey("new/topic");
|
||||
loaded.Subscriptions.ShouldNotContainKey("old/topic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LoadSession_NonExistent_ReturnsNull()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttLoadSession — returns nil for missing
|
||||
var store = new MqttSessionStore();
|
||||
|
||||
var loaded = store.LoadSession("does-not-exist");
|
||||
|
||||
loaded.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void DeleteSession_RemovesFromStore()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttDeleteSession
|
||||
var store = new MqttSessionStore();
|
||||
store.SaveSession(new MqttSessionData { ClientId = "to-delete" });
|
||||
|
||||
store.DeleteSession("to-delete");
|
||||
|
||||
store.LoadSession("to-delete").ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void DeleteSession_NonExistent_NoError()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttDeleteSession — no-op on missing
|
||||
var store = new MqttSessionStore();
|
||||
|
||||
// Should not throw
|
||||
store.DeleteSession("phantom");
|
||||
|
||||
store.LoadSession("phantom").ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ListSessions_ReturnsAllActive()
|
||||
{
|
||||
// Go reference: server/mqtt.go session enumeration
|
||||
var store = new MqttSessionStore();
|
||||
store.SaveSession(new MqttSessionData { ClientId = "alpha" });
|
||||
store.SaveSession(new MqttSessionData { ClientId = "beta" });
|
||||
store.SaveSession(new MqttSessionData { ClientId = "gamma" });
|
||||
|
||||
var sessions = store.ListSessions();
|
||||
|
||||
sessions.Count.ShouldBe(3);
|
||||
sessions.Select(s => s.ClientId).ShouldBe(
|
||||
new[] { "alpha", "beta", "gamma" },
|
||||
ignoreOrder: true);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void FlapperDetection_ThreeConnectsInTenSeconds_BackoffApplied()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttCheckFlapper ~line 300
|
||||
// Three connects within the flap window triggers backoff.
|
||||
var fakeTime = new FakeTimeProvider(new DateTimeOffset(2026, 1, 15, 12, 0, 0, TimeSpan.Zero));
|
||||
var store = new MqttSessionStore(
|
||||
flapWindow: TimeSpan.FromSeconds(10),
|
||||
flapThreshold: 3,
|
||||
flapBackoff: TimeSpan.FromSeconds(1),
|
||||
timeProvider: fakeTime);
|
||||
|
||||
// Three rapid connects
|
||||
store.TrackConnectDisconnect("flapper", connected: true);
|
||||
fakeTime.Advance(TimeSpan.FromSeconds(1));
|
||||
store.TrackConnectDisconnect("flapper", connected: true);
|
||||
fakeTime.Advance(TimeSpan.FromSeconds(1));
|
||||
store.TrackConnectDisconnect("flapper", connected: true);
|
||||
|
||||
var backoff = store.ShouldApplyBackoff("flapper");
|
||||
backoff.ShouldBeGreaterThan(TimeSpan.Zero);
|
||||
backoff.ShouldBe(TimeSpan.FromSeconds(1));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void FlapperDetection_SlowConnects_NoBackoff()
|
||||
{
|
||||
// Go reference: server/mqtt.go mqttCheckFlapper — slow connects should not trigger
|
||||
var fakeTime = new FakeTimeProvider(new DateTimeOffset(2026, 1, 15, 12, 0, 0, TimeSpan.Zero));
|
||||
var store = new MqttSessionStore(
|
||||
flapWindow: TimeSpan.FromSeconds(10),
|
||||
flapThreshold: 3,
|
||||
flapBackoff: TimeSpan.FromSeconds(1),
|
||||
timeProvider: fakeTime);
|
||||
|
||||
// Three connects, but spread out beyond the window
|
||||
store.TrackConnectDisconnect("slow-client", connected: true);
|
||||
fakeTime.Advance(TimeSpan.FromSeconds(5));
|
||||
store.TrackConnectDisconnect("slow-client", connected: true);
|
||||
fakeTime.Advance(TimeSpan.FromSeconds(6)); // first connect now outside window
|
||||
store.TrackConnectDisconnect("slow-client", connected: true);
|
||||
|
||||
var backoff = store.ShouldApplyBackoff("slow-client");
|
||||
backoff.ShouldBe(TimeSpan.Zero);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CleanSession_DeletesOnConnect()
|
||||
{
|
||||
// Go reference: server/mqtt.go — clean session flag clears stored state
|
||||
var store = new MqttSessionStore();
|
||||
|
||||
// Pre-populate a session
|
||||
store.SaveSession(new MqttSessionData
|
||||
{
|
||||
ClientId = "ephemeral",
|
||||
Subscriptions = new Dictionary<string, int> { ["topic/a"] = 1 },
|
||||
CleanSession = false,
|
||||
});
|
||||
|
||||
store.LoadSession("ephemeral").ShouldNotBeNull();
|
||||
|
||||
// Simulate clean session connect: delete the old session
|
||||
store.DeleteSession("ephemeral");
|
||||
|
||||
store.LoadSession("ephemeral").ShouldBeNull();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Fake <see cref="TimeProvider"/> for deterministic time control in tests.
|
||||
/// </summary>
|
||||
internal sealed class FakeTimeProvider(DateTimeOffset startTime) : TimeProvider
|
||||
{
|
||||
private DateTimeOffset _current = startTime;
|
||||
|
||||
public override DateTimeOffset GetUtcNow() => _current;
|
||||
|
||||
public void Advance(TimeSpan duration) => _current += duration;
|
||||
|
||||
public void SetUtcNow(DateTimeOffset value) => _current = value;
|
||||
}
|
||||
Reference in New Issue
Block a user