diff --git a/src/NATS.Server/Mqtt/MqttSessionStore.cs b/src/NATS.Server/Mqtt/MqttSessionStore.cs
index f2871de..9ea9531 100644
--- a/src/NATS.Server/Mqtt/MqttSessionStore.cs
+++ b/src/NATS.Server/Mqtt/MqttSessionStore.cs
@@ -9,6 +9,25 @@ using NATS.Server.JetStream.Storage;
 
 namespace NATS.Server.Mqtt;
 
+/// <summary>
+/// Per-client flapper detection state tracking connect/disconnect cycling.
+/// Go reference: server/mqtt.go mqttCheckFlapper ~line 300.
+/// </summary>
+public sealed class FlapperState
+{
+    /// <summary>Total number of connect/disconnect events tracked in the current window.</summary>
+    public int ConnectDisconnectCount { get; set; }
+
+    /// <summary>Start of the current detection window.</summary>
+    public DateTime WindowStart { get; set; }
+
+    /// <summary>When the backoff expires, or null if not currently backing off.</summary>
+    public DateTime? BackoffUntil { get; set; }
+
+    /// <summary>Current exponential backoff level (0 = 1s, 1 = 2s, 2 = 4s, …, capped at 60s).</summary>
+    public int BackoffLevel { get; set; }
+}
+
 /// <summary>
 /// Will message to be published on abnormal client disconnection.
 /// Go reference: server/mqtt.go mqttWill struct ~line 270.
@@ -51,1 +70,2 @@ public sealed class MqttSessionStore
     private readonly ConcurrentDictionary<string, List<DateTime>> _connectHistory = new(StringComparer.Ordinal);
+    private readonly ConcurrentDictionary<string, FlapperState> _flapperStates = new(StringComparer.Ordinal);
@@ -196,26 +216,149 @@ public sealed class MqttSessionStore
         _sessions.Values.ToList();
 
     /// <summary>
-    /// Tracks a connect or disconnect event for flapper detection.
+    /// Backward-compatible overload: tracks a connect event (disconnect is ignored).
+    /// Delegates to <see cref="TrackConnectDisconnect(string)"/> when <paramref name="connected"/> is true.
     /// Go reference: server/mqtt.go mqttCheckFlapper ~line 300.
     /// </summary>
     /// <param name="clientId">The MQTT client identifier.</param>
     /// <param name="connected">True for connect, false for disconnect.</param>
     public void TrackConnectDisconnect(string clientId, bool connected)
     {
         if (!connected)
             return;
 
+        // Feed the exponential-backoff tracker that IsFlapper/GetBackoffMs read.
+        TrackConnectDisconnect(clientId);
+
+        // Also maintain the legacy _connectHistory for ShouldApplyBackoff callers
         var now = _timeProvider.GetUtcNow().UtcDateTime;
         var history = _connectHistory.GetOrAdd(clientId, static _ => []);
 
         lock (history)
         {
             // Prune entries outside the flap window
             var cutoff = now - _flapWindow;
             history.RemoveAll(t => t < cutoff);
             history.Add(now);
         }
     }
 
+    /// <summary>
+    /// Records one connect event for flapper detection and returns the client's state.
+    /// Events are counted within a window of <c>_flapWindow</c>; once the count reaches
+    /// <c>_flapThreshold</c>, each further event re-arms an exponential backoff of
+    /// min(2^backoffLevel * 1000, 60000) ms and then raises the backoff level.
+    /// Go reference: server/mqtt.go mqttCheckFlapper ~line 300.
+    /// </summary>
+    /// <param name="clientId">The MQTT client identifier.</param>
+    /// <returns>The updated <see cref="FlapperState"/> for the client.</returns>
+    public FlapperState TrackConnectDisconnect(string clientId)
+    {
+        // 2^6 s already exceeds the 60 s cap, so a larger exponent never changes the result.
+        const int maxBackoffExponent = 6;
+        const long maxBackoffMs = 60_000;
+
+        var now = _timeProvider.GetUtcNow().UtcDateTime;
+
+        // Seed the window from the injected clock, not DateTime.UtcNow, so the window
+        // arithmetic below never mixes two different time sources.
+        var state = _flapperStates.GetOrAdd(
+            clientId,
+            static (_, windowStart) => new FlapperState { WindowStart = windowStart },
+            now);
+
+        lock (state)
+        {
+            // Reset window if we're past the flap window duration
+            if (now - state.WindowStart > _flapWindow)
+            {
+                state.ConnectDisconnectCount = 0;
+                state.WindowStart = now;
+            }
+
+            state.ConnectDisconnectCount++;
+
+            if (state.ConnectDisconnectCount >= _flapThreshold)
+            {
+                // Clamp the exponent and shift in 64-bit: an int 2^level * 1000 wraps
+                // negative from level 22 on, which would silently lift the backoff.
+                var exponent = Math.Clamp(state.BackoffLevel, 0, maxBackoffExponent);
+                var backoffMs = Math.Min(1000L << exponent, maxBackoffMs);
+                state.BackoffUntil = now + TimeSpan.FromMilliseconds(backoffMs);
+                state.BackoffLevel++;
+            }
+        }
+
+        return state;
+    }
+
+    /// <summary>
+    /// Returns true if the client is currently in a backoff period (is a flapper).
+    /// Go reference: server/mqtt.go mqttCheckFlapper ~line 320.
+    /// </summary>
+    public bool IsFlapper(string clientId)
+    {
+        if (!_flapperStates.TryGetValue(clientId, out var state))
+            return false;
+
+        var now = _timeProvider.GetUtcNow().UtcDateTime;
+        lock (state)
+        {
+            return state.BackoffUntil.HasValue && state.BackoffUntil.Value > now;
+        }
+    }
+
+    /// <summary>
+    /// Returns the remaining backoff in milliseconds, or 0 if the client is not flapping.
+    /// Go reference: server/mqtt.go mqttCheckFlapper ~line 325.
+    /// </summary>
+    public long GetBackoffMs(string clientId)
+    {
+        if (!_flapperStates.TryGetValue(clientId, out var state))
+            return 0;
+
+        var now = _timeProvider.GetUtcNow().UtcDateTime;
+        lock (state)
+        {
+            if (!state.BackoffUntil.HasValue || state.BackoffUntil.Value <= now)
+                return 0;
+
+            return (long)(state.BackoffUntil.Value - now).TotalMilliseconds;
+        }
+    }
+
+    /// <summary>
+    /// Removes all flapper tracking state for the given client.
+    /// Called when stability is restored or the client is cleanly disconnected.
+    /// </summary>
+    public void ClearFlapperState(string clientId) =>
+        _flapperStates.TryRemove(clientId, out _);
+
+    /// <summary>
+    /// Sweeps flapper records that no longer carry information: entries whose
+    /// <see cref="FlapperState.BackoffUntil"/> expired at least
+    /// <paramref name="stableThreshold"/> ago, and entries that never entered backoff
+    /// whose detection window has lapsed (the next event would reset them anyway).
+    /// </summary>
+    /// <param name="stableThreshold">How long past the BackoffUntil expiry before clearing.</param>
+    public void CheckAndClearStableClients(TimeSpan stableThreshold)
+    {
+        var now = _timeProvider.GetUtcNow().UtcDateTime;
+        foreach (var entry in _flapperStates)
+        {
+            var state = entry.Value;
+            lock (state)
+            {
+                var stable = state.BackoffUntil is { } backoffUntil
+                    ? now - backoffUntil >= stableThreshold
+                    : now - state.WindowStart > _flapWindow;
+
+                // Remove only this exact instance so a record that was cleared and
+                // re-created concurrently is not evicted by mistake.
+                if (stable)
+                    _flapperStates.TryRemove(entry);
+            }
+        }
+    }
+
     /// <summary>
diff --git a/tests/NATS.Server.Tests/Mqtt/MqttFlapperDetectionTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttFlapperDetectionTests.cs
new file mode 100644
index 0000000..4525be7
--- /dev/null
+++ b/tests/NATS.Server.Tests/Mqtt/MqttFlapperDetectionTests.cs
@@ -0,0 +1,253 @@
+// MQTT flapper detection tests — exponential backoff for rapid reconnectors.
+// Go reference: golang/nats-server/server/mqtt.go mqttCheckFlapper ~lines 300–360.
+
+using NATS.Server.Mqtt;
+using Shouldly;
+
+namespace NATS.Server.Tests.Mqtt;
+
+public class MqttFlapperDetectionTests
+{
+    // -------------------------------------------------------------------------
+    // Helper: FakeTimeProvider ships in the Microsoft.Extensions.TimeProvider.Testing package
+    // (namespace Microsoft.Extensions.Time.Testing). NOTE(review): this file has no using for
+    // that namespace — presumably a global using supplies it; confirm.
+    // -------------------------------------------------------------------------
+
+    private static MqttSessionStore CreateStore(
+        FakeTimeProvider? time = null,
+        int flapThreshold = 3,
+        TimeSpan? flapWindow = null) =>
+        new(
+            flapWindow: flapWindow ?? TimeSpan.FromSeconds(10),
+            flapThreshold: flapThreshold,
+            timeProvider: time);
+
+    // -------------------------------------------------------------------------
+    // 1. TrackConnectDisconnect_counts_events
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void TrackConnectDisconnect_counts_events()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — each connect increments the counter.
+        var store = CreateStore();
+
+        var s1 = store.TrackConnectDisconnect("client-a");
+        s1.ConnectDisconnectCount.ShouldBe(1);
+
+        var s2 = store.TrackConnectDisconnect("client-a");
+        s2.ConnectDisconnectCount.ShouldBe(2);
+    }
+
+    // -------------------------------------------------------------------------
+    // 2. Not_flapper_below_threshold
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void Not_flapper_below_threshold()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — threshold is 3; 2 events should not mark as flapper.
+        var store = CreateStore();
+
+        store.TrackConnectDisconnect("client-b");
+        store.TrackConnectDisconnect("client-b");
+
+        store.IsFlapper("client-b").ShouldBeFalse();
+    }
+
+    // -------------------------------------------------------------------------
+    // 3. Becomes_flapper_at_threshold
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void Becomes_flapper_at_threshold()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — 3 events within window marks the client.
+        var time = new FakeTimeProvider(DateTimeOffset.UtcNow);
+        var store = CreateStore(time);
+
+        store.TrackConnectDisconnect("client-c");
+        time.Advance(TimeSpan.FromSeconds(1));
+        store.TrackConnectDisconnect("client-c");
+        time.Advance(TimeSpan.FromSeconds(1));
+        store.TrackConnectDisconnect("client-c");
+
+        store.IsFlapper("client-c").ShouldBeTrue();
+    }
+
+    // -------------------------------------------------------------------------
+    // 4. Backoff_increases_exponentially
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void Backoff_increases_exponentially()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — backoff doubles on each new flap trigger.
+        // Level 0 → 1 s, Level 1 → 2 s, Level 2 → 4 s.
+        var time = new FakeTimeProvider(DateTimeOffset.UtcNow);
+        var store = CreateStore(time);
+
+        // First flap at level 0 (1 s backoff)
+        store.TrackConnectDisconnect("client-d");
+        store.TrackConnectDisconnect("client-d");
+        var s1 = store.TrackConnectDisconnect("client-d");
+        s1.BackoffLevel.ShouldBe(1); // incremented after applying level 0
+        s1.BackoffUntil.ShouldNotBeNull();
+        var backoff1 = s1.BackoffUntil!.Value - time.GetUtcNow().UtcDateTime;
+        backoff1.TotalMilliseconds.ShouldBeInRange(900, 1100); // ~1 000 ms
+
+        // Advance past the backoff and trigger again — level 1 (2 s)
+        time.Advance(TimeSpan.FromSeconds(2));
+        var s2 = store.TrackConnectDisconnect("client-d");
+        s2.BackoffLevel.ShouldBe(2);
+        var backoff2 = s2.BackoffUntil!.Value - time.GetUtcNow().UtcDateTime;
+        backoff2.TotalMilliseconds.ShouldBeInRange(1900, 2100); // ~2 000 ms
+
+        // Advance past and trigger once more — level 2 (4 s)
+        time.Advance(TimeSpan.FromSeconds(3));
+        var s3 = store.TrackConnectDisconnect("client-d");
+        s3.BackoffLevel.ShouldBe(3);
+        var backoff3 = s3.BackoffUntil!.Value - time.GetUtcNow().UtcDateTime;
+        backoff3.TotalMilliseconds.ShouldBeInRange(3900, 4100); // ~4 000 ms
+    }
+
+    // -------------------------------------------------------------------------
+    // 5. Backoff_capped_at_60_seconds
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void Backoff_capped_at_60_seconds()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — cap the maximum backoff at 60 s.
+        var time = new FakeTimeProvider(DateTimeOffset.UtcNow);
+        var store = CreateStore(time);
+
+        // Trigger enough flaps to overflow past 60 s (level 6 = 64 s, which should cap at 60 s)
+        for (var i = 0; i < 10; i++)
+        {
+            store.TrackConnectDisconnect("client-e");
+            time.Advance(TimeSpan.FromMilliseconds(100));
+        }
+
+        var state = store.TrackConnectDisconnect("client-e");
+        var remaining = state.BackoffUntil!.Value - time.GetUtcNow().UtcDateTime;
+        // The uncapped delay would be 2^8 s by now, so exactly the 60 s cap must apply.
+        remaining.TotalMilliseconds.ShouldBe(60_000d);
+    }
+
+    // -------------------------------------------------------------------------
+    // 6. GetBackoffMs_returns_remaining
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void GetBackoffMs_returns_remaining()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — caller can query remaining backoff time.
+        var time = new FakeTimeProvider(DateTimeOffset.UtcNow);
+        var store = CreateStore(time);
+
+        store.TrackConnectDisconnect("client-f");
+        store.TrackConnectDisconnect("client-f");
+        store.TrackConnectDisconnect("client-f"); // threshold hit
+
+        // The fake clock has not moved yet, so the whole level-0 backoff is pending.
+        store.GetBackoffMs("client-f").ShouldBe(1000);
+
+        time.Advance(TimeSpan.FromMilliseconds(400));
+        store.GetBackoffMs("client-f").ShouldBe(600);
+    }
+
+    // -------------------------------------------------------------------------
+    // 7. GetBackoffMs_zero_when_not_flapping
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void GetBackoffMs_zero_when_not_flapping()
+    {
+        // Not enough events to trigger backoff — remaining ms should be 0.
+        var store = CreateStore();
+
+        store.TrackConnectDisconnect("client-g");
+        store.TrackConnectDisconnect("client-g");
+
+        store.GetBackoffMs("client-g").ShouldBe(0);
+    }
+
+    // -------------------------------------------------------------------------
+    // 8. ClearFlapperState_removes_tracking
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void ClearFlapperState_removes_tracking()
+    {
+        // Go reference: server/mqtt.go — stable clients should have state purged.
+        var store = CreateStore();
+
+        store.TrackConnectDisconnect("client-h");
+        store.TrackConnectDisconnect("client-h");
+        store.TrackConnectDisconnect("client-h");
+        store.IsFlapper("client-h").ShouldBeTrue();
+
+        store.ClearFlapperState("client-h");
+
+        store.IsFlapper("client-h").ShouldBeFalse();
+        store.GetBackoffMs("client-h").ShouldBe(0);
+    }
+
+    // -------------------------------------------------------------------------
+    // 9. Window_resets_after_10_seconds
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void Window_resets_after_10_seconds()
+    {
+        // Go reference: server/mqtt.go mqttCheckFlapper — window-based detection resets.
+        // Track 2 events, advance past the window, add 1 more — should NOT be a flapper.
+        var time = new FakeTimeProvider(DateTimeOffset.UtcNow);
+        var store = CreateStore(time);
+
+        store.TrackConnectDisconnect("client-i");
+        store.TrackConnectDisconnect("client-i");
+
+        // Advance past the 10 s flap window
+        time.Advance(TimeSpan.FromSeconds(11));
+
+        // The next event opens a fresh window with a count of 1, so a single
+        // event there must not cross the threshold.
+        var state = store.TrackConnectDisconnect("client-i");
+
+        state.ConnectDisconnectCount.ShouldBe(1);
+        store.IsFlapper("client-i").ShouldBeFalse();
+    }
+
+    // -------------------------------------------------------------------------
+    // 10. CheckAndClearStableClients_clears_old
+    // -------------------------------------------------------------------------
+
+    [Fact]
+    public void CheckAndClearStableClients_clears_old()
+    {
+        // Go reference: server/mqtt.go — periodic sweep clears long-stable flapper records.
+        var time = new FakeTimeProvider(DateTimeOffset.UtcNow);
+        var store = CreateStore(time);
+
+        // Trigger flap
+        store.TrackConnectDisconnect("client-j");
+        store.TrackConnectDisconnect("client-j");
+        var state = store.TrackConnectDisconnect("client-j");
+        store.IsFlapper("client-j").ShouldBeTrue();
+
+        // Manually backdate BackoffUntil so it's already expired
+        lock (state)
+        {
+            state.BackoffUntil = time.GetUtcNow().UtcDateTime - TimeSpan.FromSeconds(61);
+        }
+
+        // A stable-threshold sweep of 60 s should evict the now-expired entry
+        store.CheckAndClearStableClients(TimeSpan.FromSeconds(60));
+
+        store.IsFlapper("client-j").ShouldBeFalse();
+        store.GetBackoffMs("client-j").ShouldBe(0);
+    }
+}