diff --git a/src/NATS.Server/Mqtt/MqttListener.cs b/src/NATS.Server/Mqtt/MqttListener.cs index 19dacfa..0fd7d71 100644 --- a/src/NATS.Server/Mqtt/MqttListener.cs +++ b/src/NATS.Server/Mqtt/MqttListener.cs @@ -163,4 +163,4 @@ public sealed class MqttListener( } } -internal sealed record MqttPendingPublish(int PacketId, string Topic, string Payload); +public sealed record MqttPendingPublish(int PacketId, string Topic, string Payload); diff --git a/src/NATS.Server/Mqtt/MqttRetainedStore.cs b/src/NATS.Server/Mqtt/MqttRetainedStore.cs new file mode 100644 index 0000000..e43e1da --- /dev/null +++ b/src/NATS.Server/Mqtt/MqttRetainedStore.cs @@ -0,0 +1,241 @@ +// MQTT retained message store and QoS 2 state machine. +// Go reference: golang/nats-server/server/mqtt.go +// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600–1700) +// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300–1400) + +using System.Collections.Concurrent; + +namespace NATS.Server.Mqtt; + +/// +/// A retained message stored for a topic. +/// +public sealed record MqttRetainedMessage(string Topic, ReadOnlyMemory Payload); + +/// +/// In-memory store for MQTT retained messages. +/// Go reference: server/mqtt.go mqttHandleRetainedMsg ~line 1600. +/// +public sealed class MqttRetainedStore +{ + private readonly ConcurrentDictionary> _retained = new(StringComparer.Ordinal); + + /// + /// Sets (or clears) the retained message for a topic. + /// An empty payload clears the retained message. + /// Go reference: server/mqtt.go mqttHandleRetainedMsg. + /// + public void SetRetained(string topic, ReadOnlyMemory payload) + { + if (payload.IsEmpty) + { + _retained.TryRemove(topic, out _); + return; + } + + _retained[topic] = payload; + } + + /// + /// Gets the retained message payload for a topic, or null if none. + /// + public ReadOnlyMemory? GetRetained(string topic) + { + if (_retained.TryGetValue(topic, out var payload)) + return payload; + + return null; + } + + /// + /// Returns all retained messages matching an MQTT topic filter pattern. + /// Supports '+' (single-level) and '#' (multi-level) wildcards. + /// Go reference: server/mqtt.go mqttGetRetainedMessages ~line 1650. + /// + public IReadOnlyList GetMatchingRetained(string filter) + { + var results = new List(); + foreach (var kvp in _retained) + { + if (MqttTopicMatch(kvp.Key, filter)) + results.Add(new MqttRetainedMessage(kvp.Key, kvp.Value)); + } + + return results; + } + + /// + /// Matches an MQTT topic against a filter pattern. + /// '+' matches exactly one level, '#' matches zero or more levels (must be last). + /// + internal static bool MqttTopicMatch(string topic, string filter) + { + var topicLevels = topic.Split('/'); + var filterLevels = filter.Split('/'); + + for (var i = 0; i < filterLevels.Length; i++) + { + if (filterLevels[i] == "#") + return true; // '#' matches everything from here + + if (i >= topicLevels.Length) + return false; // filter has more levels than topic + + if (filterLevels[i] != "+" && filterLevels[i] != topicLevels[i]) + return false; + } + + // Topic must not have more levels than filter (unless filter ended with '#') + return topicLevels.Length == filterLevels.Length; + } +} + +/// +/// QoS 2 state machine states. +/// Go reference: server/mqtt.go ~line 1300. +/// +public enum MqttQos2State +{ + /// Publish received, awaiting PUBREC from peer. + AwaitingPubRec, + + /// PUBREC received, awaiting PUBREL from originator. + AwaitingPubRel, + + /// PUBREL received, awaiting PUBCOMP from peer. + AwaitingPubComp, + + /// Flow complete. + Complete, +} + +/// +/// Tracks QoS 2 flow state for a single packet ID. +/// +internal sealed class MqttQos2Flow +{ + public MqttQos2State State { get; set; } + public DateTime StartedAtUtc { get; init; } +} + +/// +/// Manages the QoS 2 exactly-once delivery state machine for a connection. +/// Tracks per-packet-id state transitions: PUBLISH -> PUBREC -> PUBREL -> PUBCOMP. +/// Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp. +/// +public sealed class MqttQos2StateMachine +{ + private readonly ConcurrentDictionary _flows = new(); + private readonly TimeSpan _timeout; + private readonly TimeProvider _timeProvider; + + /// + /// Initializes a new QoS 2 state machine. + /// + /// Timeout for incomplete flows. Default 30 seconds. + /// Optional time provider for testing. + public MqttQos2StateMachine(TimeSpan? timeout = null, TimeProvider? timeProvider = null) + { + _timeout = timeout ?? TimeSpan.FromSeconds(30); + _timeProvider = timeProvider ?? TimeProvider.System; + } + + /// + /// Begins a new QoS 2 flow for the given packet ID. + /// Returns false if a flow for this packet ID already exists (duplicate publish). + /// + public bool BeginPublish(ushort packetId) + { + var flow = new MqttQos2Flow + { + State = MqttQos2State.AwaitingPubRec, + StartedAtUtc = _timeProvider.GetUtcNow().UtcDateTime, + }; + + return _flows.TryAdd(packetId, flow); + } + + /// + /// Processes a PUBREC for the given packet ID. + /// Returns false if the flow is not in the expected state. + /// + public bool ProcessPubRec(ushort packetId) + { + if (!_flows.TryGetValue(packetId, out var flow)) + return false; + + if (flow.State != MqttQos2State.AwaitingPubRec) + return false; + + flow.State = MqttQos2State.AwaitingPubRel; + return true; + } + + /// + /// Processes a PUBREL for the given packet ID. + /// Returns false if the flow is not in the expected state. + /// + public bool ProcessPubRel(ushort packetId) + { + if (!_flows.TryGetValue(packetId, out var flow)) + return false; + + if (flow.State != MqttQos2State.AwaitingPubRel) + return false; + + flow.State = MqttQos2State.AwaitingPubComp; + return true; + } + + /// + /// Processes a PUBCOMP for the given packet ID. + /// Returns false if the flow is not in the expected state. + /// Removes the flow on completion. + /// + public bool ProcessPubComp(ushort packetId) + { + if (!_flows.TryGetValue(packetId, out var flow)) + return false; + + if (flow.State != MqttQos2State.AwaitingPubComp) + return false; + + flow.State = MqttQos2State.Complete; + _flows.TryRemove(packetId, out _); + return true; + } + + /// + /// Gets the current state for a packet ID, or null if no flow exists. + /// + public MqttQos2State? GetState(ushort packetId) + { + if (_flows.TryGetValue(packetId, out var flow)) + return flow.State; + + return null; + } + + /// + /// Returns packet IDs for flows that have exceeded the timeout. + /// + public IReadOnlyList GetTimedOutFlows() + { + var now = _timeProvider.GetUtcNow().UtcDateTime; + var timedOut = new List(); + + foreach (var kvp in _flows) + { + if (now - kvp.Value.StartedAtUtc > _timeout) + timedOut.Add(kvp.Key); + } + + return timedOut; + } + + /// + /// Removes a flow (e.g., after timeout cleanup). + /// + public void RemoveFlow(ushort packetId) => + _flows.TryRemove(packetId, out _); +} diff --git a/src/NATS.Server/Mqtt/MqttSessionStore.cs b/src/NATS.Server/Mqtt/MqttSessionStore.cs new file mode 100644 index 0000000..5dcb387 --- /dev/null +++ b/src/NATS.Server/Mqtt/MqttSessionStore.cs @@ -0,0 +1,133 @@ +// MQTT session persistence store. +// Go reference: golang/nats-server/server/mqtt.go:253-300 +// Session state management — mqttInitSessionStore / mqttStoreSession +// Flapper detection — mqttCheckFlapper (lines ~300–360) + +using System.Collections.Concurrent; + +namespace NATS.Server.Mqtt; + +/// +/// Serializable session data for an MQTT client. +/// Go reference: server/mqtt.go mqttSession struct ~line 253. +/// +public sealed record MqttSessionData +{ + public required string ClientId { get; init; } + public Dictionary Subscriptions { get; init; } = []; + public List PendingPublishes { get; init; } = []; + public string? WillTopic { get; init; } + public byte[]? WillPayload { get; init; } + public int WillQoS { get; init; } + public bool WillRetain { get; init; } + public bool CleanSession { get; init; } + public DateTime ConnectedAtUtc { get; init; } = DateTime.UtcNow; + public DateTime LastActivityUtc { get; set; } = DateTime.UtcNow; +} + +/// +/// In-memory MQTT session store with flapper detection. +/// The abstraction allows future JetStream backing. +/// Go reference: server/mqtt.go mqttInitSessionStore ~line 260. +/// +public sealed class MqttSessionStore +{ + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary> _connectHistory = new(StringComparer.Ordinal); + + private readonly TimeSpan _flapWindow; + private readonly int _flapThreshold; + private readonly TimeSpan _flapBackoff; + private readonly TimeProvider _timeProvider; + + /// + /// Initializes a new session store. + /// + /// Window in which repeated connects trigger flap detection. Default 10 seconds. + /// Number of connects within the window to trigger backoff. Default 3. + /// Backoff delay to apply when flapping. Default 1 second. + /// Optional time provider for testing. Default uses system clock. + public MqttSessionStore( + TimeSpan? flapWindow = null, + int flapThreshold = 3, + TimeSpan? flapBackoff = null, + TimeProvider? timeProvider = null) + { + _flapWindow = flapWindow ?? TimeSpan.FromSeconds(10); + _flapThreshold = flapThreshold; + _flapBackoff = flapBackoff ?? TimeSpan.FromSeconds(1); + _timeProvider = timeProvider ?? TimeProvider.System; + } + + /// + /// Saves (or overwrites) session data for the given client. + /// Go reference: server/mqtt.go mqttStoreSession. + /// + public void SaveSession(MqttSessionData session) + { + ArgumentNullException.ThrowIfNull(session); + _sessions[session.ClientId] = session; + } + + /// + /// Loads session data for the given client, or null if not found. + /// Go reference: server/mqtt.go mqttLoadSession. + /// + public MqttSessionData? LoadSession(string clientId) => + _sessions.TryGetValue(clientId, out var session) ? session : null; + + /// + /// Deletes the session for the given client. No-op if not found. + /// Go reference: server/mqtt.go mqttDeleteSession. + /// + public void DeleteSession(string clientId) => + _sessions.TryRemove(clientId, out _); + + /// + /// Returns all active sessions. + /// + public IReadOnlyList ListSessions() => + _sessions.Values.ToList(); + + /// + /// Tracks a connect or disconnect event for flapper detection. + /// Go reference: server/mqtt.go mqttCheckFlapper ~line 300. + /// + /// The MQTT client identifier. + /// True for connect, false for disconnect. + public void TrackConnectDisconnect(string clientId, bool connected) + { + if (!connected) + return; + + var now = _timeProvider.GetUtcNow().UtcDateTime; + var history = _connectHistory.GetOrAdd(clientId, static _ => []); + + lock (history) + { + // Prune entries outside the flap window + var cutoff = now - _flapWindow; + history.RemoveAll(t => t < cutoff); + history.Add(now); + } + } + + /// + /// Returns the backoff delay if the client is flapping, otherwise . + /// Go reference: server/mqtt.go mqttCheckFlapper ~line 320. + /// + public TimeSpan ShouldApplyBackoff(string clientId) + { + if (!_connectHistory.TryGetValue(clientId, out var history)) + return TimeSpan.Zero; + + var now = _timeProvider.GetUtcNow().UtcDateTime; + + lock (history) + { + var cutoff = now - _flapWindow; + history.RemoveAll(t => t < cutoff); + return history.Count >= _flapThreshold ? _flapBackoff : TimeSpan.Zero; + } + } +} diff --git a/tests/NATS.Server.Tests/Mqtt/MqttQosTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttQosTests.cs new file mode 100644 index 0000000..9d893bf --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttQosTests.cs @@ -0,0 +1,190 @@ +// MQTT QoS and retained message tests. +// Go reference: golang/nats-server/server/mqtt.go +// Retained messages — mqttHandleRetainedMsg / mqttGetRetainedMessages (~lines 1600–1700) +// QoS 2 flow — mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp (~lines 1300–1400) + +using System.Text; +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttQosTests +{ + [Fact] + public void RetainedStore_SetAndGet_RoundTrips() + { + // Go reference: server/mqtt.go mqttHandleRetainedMsg — store and retrieve + var store = new MqttRetainedStore(); + var payload = Encoding.UTF8.GetBytes("temperature=72.5"); + + store.SetRetained("sensors/temp", payload); + + var result = store.GetRetained("sensors/temp"); + result.ShouldNotBeNull(); + Encoding.UTF8.GetString(result.Value.Span).ShouldBe("temperature=72.5"); + } + + [Fact] + public void RetainedStore_EmptyPayload_ClearsRetained() + { + // Go reference: server/mqtt.go mqttHandleRetainedMsg — empty payload clears + var store = new MqttRetainedStore(); + store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("old-value")); + + store.SetRetained("sensors/temp", ReadOnlyMemory.Empty); + + store.GetRetained("sensors/temp").ShouldBeNull(); + } + + [Fact] + public void RetainedStore_Overwrite_ReplacesOld() + { + // Go reference: server/mqtt.go mqttHandleRetainedMsg — overwrite replaces + var store = new MqttRetainedStore(); + store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("first")); + + store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("second")); + + var result = store.GetRetained("sensors/temp"); + result.ShouldNotBeNull(); + Encoding.UTF8.GetString(result.Value.Span).ShouldBe("second"); + } + + [Fact] + public void RetainedStore_GetMatching_WildcardPlus() + { + // Go reference: server/mqtt.go mqttGetRetainedMessages — '+' single-level wildcard + var store = new MqttRetainedStore(); + store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("72.5")); + store.SetRetained("sensors/humidity", Encoding.UTF8.GetBytes("45%")); + store.SetRetained("alerts/fire", Encoding.UTF8.GetBytes("!")); + + var matches = store.GetMatchingRetained("sensors/+"); + + matches.Count.ShouldBe(2); + matches.Select(m => m.Topic).ShouldBe( + new[] { "sensors/temp", "sensors/humidity" }, + ignoreOrder: true); + } + + [Fact] + public void RetainedStore_GetMatching_WildcardHash() + { + // Go reference: server/mqtt.go mqttGetRetainedMessages — '#' multi-level wildcard + var store = new MqttRetainedStore(); + store.SetRetained("home/living/temp", Encoding.UTF8.GetBytes("22")); + store.SetRetained("home/living/light", Encoding.UTF8.GetBytes("on")); + store.SetRetained("home/kitchen/temp", Encoding.UTF8.GetBytes("24")); + store.SetRetained("office/desk/light", Encoding.UTF8.GetBytes("off")); + + var matches = store.GetMatchingRetained("home/#"); + + matches.Count.ShouldBe(3); + matches.Select(m => m.Topic).ShouldBe( + new[] { "home/living/temp", "home/living/light", "home/kitchen/temp" }, + ignoreOrder: true); + } + + [Fact] + public void Qos2_FullFlow_PubRecPubRelPubComp() + { + // Go reference: server/mqtt.go mqttProcessPubRec / mqttProcessPubRel / mqttProcessPubComp + var sm = new MqttQos2StateMachine(); + + // Begin publish + sm.BeginPublish(100).ShouldBeTrue(); + sm.GetState(100).ShouldBe(MqttQos2State.AwaitingPubRec); + + // PUBREC + sm.ProcessPubRec(100).ShouldBeTrue(); + sm.GetState(100).ShouldBe(MqttQos2State.AwaitingPubRel); + + // PUBREL + sm.ProcessPubRel(100).ShouldBeTrue(); + sm.GetState(100).ShouldBe(MqttQos2State.AwaitingPubComp); + + // PUBCOMP — completes and removes flow + sm.ProcessPubComp(100).ShouldBeTrue(); + sm.GetState(100).ShouldBeNull(); + } + + [Fact] + public void Qos2_DuplicatePublish_Rejected() + { + // Go reference: server/mqtt.go — duplicate packet ID rejected during active flow + var sm = new MqttQos2StateMachine(); + + sm.BeginPublish(200).ShouldBeTrue(); + + // Same packet ID while flow is active — should be rejected + sm.BeginPublish(200).ShouldBeFalse(); + } + + [Fact] + public void Qos2_IncompleteFlow_TimesOut() + { + // Go reference: server/mqtt.go — incomplete QoS 2 flows time out + var fakeTime = new FakeTimeProvider(new DateTimeOffset(2026, 1, 15, 12, 0, 0, TimeSpan.Zero)); + var sm = new MqttQos2StateMachine(timeout: TimeSpan.FromSeconds(5), timeProvider: fakeTime); + + sm.BeginPublish(300).ShouldBeTrue(); + + // Not timed out yet + fakeTime.Advance(TimeSpan.FromSeconds(3)); + sm.GetTimedOutFlows().ShouldBeEmpty(); + + // Advance past timeout + fakeTime.Advance(TimeSpan.FromSeconds(3)); + var timedOut = sm.GetTimedOutFlows(); + timedOut.Count.ShouldBe(1); + timedOut[0].ShouldBe((ushort)300); + + // Clean up + sm.RemoveFlow(300); + sm.GetState(300).ShouldBeNull(); + } + + [Fact] + public void Qos1_Puback_RemovesPending() + { + // Go reference: server/mqtt.go — QoS 1 PUBACK removes from pending + // This tests the existing MqttListener pending publish / ack mechanism + // in the context of the session store. + var store = new MqttSessionStore(); + var session = new MqttSessionData + { + ClientId = "qos1-client", + PendingPublishes = + [ + new MqttPendingPublish(1, "topic/a", "payload-a"), + new MqttPendingPublish(2, "topic/b", "payload-b"), + ], + }; + + store.SaveSession(session); + + // Simulate PUBACK for packet 1: remove it from pending + var loaded = store.LoadSession("qos1-client"); + loaded.ShouldNotBeNull(); + loaded.PendingPublishes.RemoveAll(p => p.PacketId == 1); + store.SaveSession(loaded); + + // Verify only packet 2 remains + var updated = store.LoadSession("qos1-client"); + updated.ShouldNotBeNull(); + updated.PendingPublishes.Count.ShouldBe(1); + updated.PendingPublishes[0].PacketId.ShouldBe(2); + } + + [Fact] + public void RetainedStore_GetMatching_NoMatch_ReturnsEmpty() + { + // Go reference: server/mqtt.go mqttGetRetainedMessages — no match returns empty + var store = new MqttRetainedStore(); + store.SetRetained("sensors/temp", Encoding.UTF8.GetBytes("72")); + + var matches = store.GetMatchingRetained("alerts/+"); + + matches.ShouldBeEmpty(); + } +} diff --git a/tests/NATS.Server.Tests/Mqtt/MqttSessionPersistenceTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttSessionPersistenceTests.cs new file mode 100644 index 0000000..711d991 --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttSessionPersistenceTests.cs @@ -0,0 +1,209 @@ +// MQTT session persistence tests. +// Go reference: golang/nats-server/server/mqtt.go:253-360 +// Session store — mqttInitSessionStore / mqttStoreSession / mqttLoadSession +// Flapper detection — mqttCheckFlapper (~lines 300–360) + +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttSessionPersistenceTests +{ + [Fact] + public void SaveSession_ThenLoad_RoundTrips() + { + // Go reference: server/mqtt.go mqttStoreSession / mqttLoadSession + var store = new MqttSessionStore(); + var session = new MqttSessionData + { + ClientId = "client-1", + Subscriptions = new Dictionary { ["sensors/temp"] = 1, ["alerts/#"] = 0 }, + PendingPublishes = [new MqttPendingPublish(42, "sensors/temp", "72.5")], + WillTopic = "clients/offline", + WillPayload = [0x01, 0x02], + WillQoS = 1, + WillRetain = true, + CleanSession = false, + ConnectedAtUtc = new DateTime(2026, 1, 15, 10, 30, 0, DateTimeKind.Utc), + LastActivityUtc = new DateTime(2026, 1, 15, 10, 35, 0, DateTimeKind.Utc), + }; + + store.SaveSession(session); + var loaded = store.LoadSession("client-1"); + + loaded.ShouldNotBeNull(); + loaded.ClientId.ShouldBe("client-1"); + loaded.Subscriptions.Count.ShouldBe(2); + loaded.Subscriptions["sensors/temp"].ShouldBe(1); + loaded.Subscriptions["alerts/#"].ShouldBe(0); + loaded.PendingPublishes.Count.ShouldBe(1); + loaded.PendingPublishes[0].PacketId.ShouldBe(42); + loaded.PendingPublishes[0].Topic.ShouldBe("sensors/temp"); + loaded.PendingPublishes[0].Payload.ShouldBe("72.5"); + loaded.WillTopic.ShouldBe("clients/offline"); + loaded.WillPayload.ShouldBe(new byte[] { 0x01, 0x02 }); + loaded.WillQoS.ShouldBe(1); + loaded.WillRetain.ShouldBeTrue(); + loaded.CleanSession.ShouldBeFalse(); + loaded.ConnectedAtUtc.ShouldBe(new DateTime(2026, 1, 15, 10, 30, 0, DateTimeKind.Utc)); + loaded.LastActivityUtc.ShouldBe(new DateTime(2026, 1, 15, 10, 35, 0, DateTimeKind.Utc)); + } + + [Fact] + public void SaveSession_Update_OverwritesPrevious() + { + // Go reference: server/mqtt.go mqttStoreSession — overwrites existing + var store = new MqttSessionStore(); + + store.SaveSession(new MqttSessionData + { + ClientId = "client-x", + Subscriptions = new Dictionary { ["old/topic"] = 0 }, + }); + + store.SaveSession(new MqttSessionData + { + ClientId = "client-x", + Subscriptions = new Dictionary { ["new/topic"] = 1 }, + }); + + var loaded = store.LoadSession("client-x"); + loaded.ShouldNotBeNull(); + loaded.Subscriptions.ShouldContainKey("new/topic"); + loaded.Subscriptions.ShouldNotContainKey("old/topic"); + } + + [Fact] + public void LoadSession_NonExistent_ReturnsNull() + { + // Go reference: server/mqtt.go mqttLoadSession — returns nil for missing + var store = new MqttSessionStore(); + + var loaded = store.LoadSession("does-not-exist"); + + loaded.ShouldBeNull(); + } + + [Fact] + public void DeleteSession_RemovesFromStore() + { + // Go reference: server/mqtt.go mqttDeleteSession + var store = new MqttSessionStore(); + store.SaveSession(new MqttSessionData { ClientId = "to-delete" }); + + store.DeleteSession("to-delete"); + + store.LoadSession("to-delete").ShouldBeNull(); + } + + [Fact] + public void DeleteSession_NonExistent_NoError() + { + // Go reference: server/mqtt.go mqttDeleteSession — no-op on missing + var store = new MqttSessionStore(); + + // Should not throw + store.DeleteSession("phantom"); + + store.LoadSession("phantom").ShouldBeNull(); + } + + [Fact] + public void ListSessions_ReturnsAllActive() + { + // Go reference: server/mqtt.go session enumeration + var store = new MqttSessionStore(); + store.SaveSession(new MqttSessionData { ClientId = "alpha" }); + store.SaveSession(new MqttSessionData { ClientId = "beta" }); + store.SaveSession(new MqttSessionData { ClientId = "gamma" }); + + var sessions = store.ListSessions(); + + sessions.Count.ShouldBe(3); + sessions.Select(s => s.ClientId).ShouldBe( + new[] { "alpha", "beta", "gamma" }, + ignoreOrder: true); + } + + [Fact] + public void FlapperDetection_ThreeConnectsInTenSeconds_BackoffApplied() + { + // Go reference: server/mqtt.go mqttCheckFlapper ~line 300 + // Three connects within the flap window triggers backoff. + var fakeTime = new FakeTimeProvider(new DateTimeOffset(2026, 1, 15, 12, 0, 0, TimeSpan.Zero)); + var store = new MqttSessionStore( + flapWindow: TimeSpan.FromSeconds(10), + flapThreshold: 3, + flapBackoff: TimeSpan.FromSeconds(1), + timeProvider: fakeTime); + + // Three rapid connects + store.TrackConnectDisconnect("flapper", connected: true); + fakeTime.Advance(TimeSpan.FromSeconds(1)); + store.TrackConnectDisconnect("flapper", connected: true); + fakeTime.Advance(TimeSpan.FromSeconds(1)); + store.TrackConnectDisconnect("flapper", connected: true); + + var backoff = store.ShouldApplyBackoff("flapper"); + backoff.ShouldBeGreaterThan(TimeSpan.Zero); + backoff.ShouldBe(TimeSpan.FromSeconds(1)); + } + + [Fact] + public void FlapperDetection_SlowConnects_NoBackoff() + { + // Go reference: server/mqtt.go mqttCheckFlapper — slow connects should not trigger + var fakeTime = new FakeTimeProvider(new DateTimeOffset(2026, 1, 15, 12, 0, 0, TimeSpan.Zero)); + var store = new MqttSessionStore( + flapWindow: TimeSpan.FromSeconds(10), + flapThreshold: 3, + flapBackoff: TimeSpan.FromSeconds(1), + timeProvider: fakeTime); + + // Three connects, but spread out beyond the window + store.TrackConnectDisconnect("slow-client", connected: true); + fakeTime.Advance(TimeSpan.FromSeconds(5)); + store.TrackConnectDisconnect("slow-client", connected: true); + fakeTime.Advance(TimeSpan.FromSeconds(6)); // first connect now outside window + store.TrackConnectDisconnect("slow-client", connected: true); + + var backoff = store.ShouldApplyBackoff("slow-client"); + backoff.ShouldBe(TimeSpan.Zero); + } + + [Fact] + public void CleanSession_DeletesOnConnect() + { + // Go reference: server/mqtt.go — clean session flag clears stored state + var store = new MqttSessionStore(); + + // Pre-populate a session + store.SaveSession(new MqttSessionData + { + ClientId = "ephemeral", + Subscriptions = new Dictionary { ["topic/a"] = 1 }, + CleanSession = false, + }); + + store.LoadSession("ephemeral").ShouldNotBeNull(); + + // Simulate clean session connect: delete the old session + store.DeleteSession("ephemeral"); + + store.LoadSession("ephemeral").ShouldBeNull(); + } +} + +/// +/// Fake for deterministic time control in tests. +/// +internal sealed class FakeTimeProvider(DateTimeOffset startTime) : TimeProvider +{ + private DateTimeOffset _current = startTime; + + public override DateTimeOffset GetUtcNow() => _current; + + public void Advance(TimeSpan duration) => _current += duration; + + public void SetUtcNow(DateTimeOffset value) => _current = value; +}