// Reconstructed from a whitespace-mangled patch. Original patch headers:
//   diff --git a/src/NATS.Server/Subscriptions/SubList.cs b/src/NATS.Server/Subscriptions/SubList.cs
//   new file mode 100644  index 0000000..76bc149
// All generic type arguments and XML-doc tags below were stripped by the
// mangling and have been restored from usage; logic is otherwise unchanged.

namespace NATS.Server.Subscriptions;

/// <summary>
/// SubList is a routing mechanism to handle subject distribution and
/// provides a facility to match subjects from published messages to
/// interested subscribers. Subscribers can have wildcard subjects to
/// match multiple published subjects.
/// </summary>
public sealed class SubList
{
    // Cache is swept down once it exceeds CacheMax entries; CacheSweep is the
    // number of entries retained after a sweep.
    private const int CacheMax = 1024;
    private const int CacheSweep = 256;

    private readonly ReaderWriterLockSlim _lock = new();
    private readonly TrieLevel _root = new();

    // Match-result cache keyed by the published (literal) subject.
    private Dictionary<string, SubListResult>? _cache = new();
    private uint _count;

    // Bumped on every mutation; reserved for generation-based invalidation.
    private ulong _genId;

    /// <summary>Total number of subscriptions currently stored.</summary>
    public uint Count
    {
        get
        {
            _lock.EnterReadLock();
            try { return _count; }
            finally { _lock.ExitReadLock(); }
        }
    }

    /// <summary>
    /// Inserts a subscription into the trie, keyed by its subject tokens.
    /// Throws <see cref="ArgumentException"/> for invalid subjects
    /// (empty tokens, or tokens after a full wildcard).
    /// </summary>
    public void Insert(Subscription sub)
    {
        var subject = sub.Subject;

        _lock.EnterWriteLock();
        try
        {
            var level = _root;
            TrieNode? node = null;
            bool sawFwc = false;

            foreach (var token in new TokenEnumerator(subject))
            {
                // Empty token, or any token following '>', is invalid.
                if (token.Length == 0 || sawFwc)
                    throw new ArgumentException("Invalid subject", nameof(sub));

                if (token.Length == 1 && token[0] == SubjectMatch.Pwc)
                {
                    node = level.Pwc ??= new TrieNode();
                }
                else if (token.Length == 1 && token[0] == SubjectMatch.Fwc)
                {
                    node = level.Fwc ??= new TrieNode();
                    sawFwc = true;
                }
                else
                {
                    var key = token.ToString();
                    if (!level.Nodes.TryGetValue(key, out node))
                    {
                        node = new TrieNode();
                        level.Nodes[key] = node;
                    }
                }

                node.Next ??= new TrieLevel();
                level = node.Next;
            }

            // An empty subject never produced a node.
            if (node == null)
                throw new ArgumentException("Invalid subject", nameof(sub));

            if (sub.Queue == null)
            {
                node.PlainSubs.Add(sub);
            }
            else
            {
                if (!node.QueueSubs.TryGetValue(sub.Queue, out var qset))
                {
                    qset = [];
                    node.QueueSubs[sub.Queue] = qset;
                }
                qset.Add(sub);
            }

            _count++;
            _genId++;
            AddToCache(subject, sub);
        }
        finally
        {
            _lock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Removes a subscription (matched by instance identity/equality) from
    /// the trie and prunes any nodes left empty. No-op if not found.
    /// </summary>
    public void Remove(Subscription sub)
    {
        _lock.EnterWriteLock();
        try
        {
            var level = _root;
            TrieNode? node = null;
            bool sawFwc = false;

            // Remember the path taken so empty nodes can be pruned afterwards.
            var pathList = new List<(TrieLevel level, TrieNode node, string token, bool isPwc, bool isFwc)>();

            foreach (var token in new TokenEnumerator(sub.Subject))
            {
                if (token.Length == 0 || sawFwc)
                    return;

                bool isPwc = token.Length == 1 && token[0] == SubjectMatch.Pwc;
                bool isFwc = token.Length == 1 && token[0] == SubjectMatch.Fwc;

                if (isPwc)
                {
                    node = level.Pwc;
                }
                else if (isFwc)
                {
                    node = level.Fwc;
                    sawFwc = true;
                }
                else
                {
                    level.Nodes.TryGetValue(token.ToString(), out node);
                }

                if (node == null)
                    return; // not found

                var tokenStr = token.ToString();
                pathList.Add((level, node, tokenStr, isPwc, isFwc));
                level = node.Next ?? new TrieLevel();
            }

            if (node == null) return;

            // Remove from node
            bool removed;
            if (sub.Queue == null)
            {
                removed = node.PlainSubs.Remove(sub);
            }
            else
            {
                removed = false;
                if (node.QueueSubs.TryGetValue(sub.Queue, out var qset))
                {
                    removed = qset.Remove(sub);
                    if (qset.Count == 0)
                        node.QueueSubs.Remove(sub.Queue);
                }
            }

            if (!removed) return;

            _count--;
            _genId++;
            RemoveFromCache(sub.Subject);

            // Prune empty nodes (walk backwards so children empty out first).
            for (int i = pathList.Count - 1; i >= 0; i--)
            {
                var (l, n, t, isPwc, isFwc) = pathList[i];
                if (n.IsEmpty)
                {
                    if (isPwc) l.Pwc = null;
                    else if (isFwc) l.Fwc = null;
                    else l.Nodes.Remove(t);
                }
            }
        }
        finally
        {
            _lock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Returns all subscriptions whose subjects match the given (literal)
    /// published subject. Results are cached per subject.
    /// </summary>
    public SubListResult Match(string subject)
    {
        // Check cache under read lock first.
        _lock.EnterReadLock();
        try
        {
            if (_cache != null && _cache.TryGetValue(subject, out var cached))
                return cached;
        }
        finally
        {
            _lock.ExitReadLock();
        }

        // Cache miss -- tokenize and match under write lock (needed for cache update).
        var tokens = Tokenize(subject);
        if (tokens == null)
            return SubListResult.Empty;

        _lock.EnterWriteLock();
        try
        {
            // Re-check cache after acquiring write lock.
            if (_cache != null && _cache.TryGetValue(subject, out var cached))
                return cached;

            var plainSubs = new List<Subscription>();
            var queueSubs = new List<List<Subscription>>();

            MatchLevel(_root, tokens, 0, plainSubs, queueSubs);

            SubListResult result;
            if (plainSubs.Count == 0 && queueSubs.Count == 0)
            {
                result = SubListResult.Empty;
            }
            else
            {
                result = new SubListResult(
                    plainSubs.ToArray(),
                    queueSubs.Select(q => q.ToArray()).ToArray());
            }

            if (_cache != null)
            {
                _cache[subject] = result;

                if (_cache.Count > CacheMax)
                {
                    // Sweep: remove entries until at CacheSweep count.
                    var keys = _cache.Keys.Take(_cache.Count - CacheSweep).ToList();
                    foreach (var key in keys)
                        _cache.Remove(key);
                }
            }

            return result;
        }
        finally
        {
            _lock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Tokenize the subject into an array of token strings.
    /// Returns null if the subject is invalid (empty tokens).
    /// </summary>
    private static string[]? Tokenize(string subject)
    {
        if (string.IsNullOrEmpty(subject))
            return null;

        var tokens = new List<string>();
        int start = 0;
        for (int i = 0; i <= subject.Length; i++)
        {
            if (i == subject.Length || subject[i] == SubjectMatch.Sep)
            {
                if (i - start == 0)
                    return null; // empty token
                tokens.Add(subject[start..i]);
                start = i + 1;
            }
        }

        return tokens.Count > 0 ? tokens.ToArray() : null;
    }

    /// <summary>
    /// Recursively descend into the trie matching tokens.
    /// This follows the Go matchLevel() algorithm closely.
    /// </summary>
    private static void MatchLevel(TrieLevel? level, string[] tokens, int tokenIndex,
        List<Subscription> plainSubs, List<List<Subscription>> queueSubs)
    {
        TrieNode? pwc = null;
        TrieNode? node = null;

        for (int i = tokenIndex; i < tokens.Length; i++)
        {
            if (level == null)
                return;

            // Full wildcard (>) at this level matches all remaining tokens.
            if (level.Fwc != null)
                AddNodeToResults(level.Fwc, plainSubs, queueSubs);

            // Partial wildcard (*) -- recurse with remaining tokens.
            pwc = level.Pwc;
            if (pwc != null)
                MatchLevel(pwc.Next, tokens, i + 1, plainSubs, queueSubs);

            // Literal match
            node = null;
            if (level.Nodes.TryGetValue(tokens[i], out var found))
            {
                node = found;
                level = node.Next;
            }
            else
            {
                level = null;
            }
        }

        // After processing all tokens, add results from the final literal node.
        if (node != null)
            AddNodeToResults(node, plainSubs, queueSubs);

        // Also add results from the partial wildcard at the last level,
        // which handles the case where * matches the final token.
        if (pwc != null)
            AddNodeToResults(pwc, plainSubs, queueSubs);
    }

    /// <summary>
    /// Appends a node's plain and queue subscriptions to the result lists,
    /// merging queue subscriptions into existing groups by queue name.
    /// </summary>
    private static void AddNodeToResults(TrieNode node,
        List<Subscription> plainSubs, List<List<Subscription>> queueSubs)
    {
        // Add plain subscriptions
        foreach (var sub in node.PlainSubs)
            plainSubs.Add(sub);

        // Add queue subscriptions grouped by queue name
        foreach (var (queueName, subs) in node.QueueSubs)
        {
            if (subs.Count == 0) continue;

            // Find existing queue group or create new one
            List<Subscription>? existing = null;
            foreach (var qs in queueSubs)
            {
                if (qs.Count > 0 && qs[0].Queue == queueName)
                {
                    existing = qs;
                    break;
                }
            }
            if (existing == null)
            {
                existing = new List<Subscription>();
                queueSubs.Add(existing);
            }
            existing.AddRange(subs);
        }
    }

    /// <summary>
    /// Adds a subscription to matching cache entries.
    /// Assumes write lock is held.
    /// </summary>
    private void AddToCache(string subject, Subscription sub)
    {
        if (_cache == null)
            return;

        // If literal subject, we can do a direct lookup.
        if (SubjectMatch.IsLiteral(subject))
        {
            if (_cache.TryGetValue(subject, out var r))
            {
                _cache[subject] = AddSubToResult(r, sub);
            }
            return;
        }

        // Wildcard subscription -- check all cached keys.
        // Collect first, then mutate: the dictionary must not be modified
        // while it is being enumerated.
        var keysToUpdate = new List<(string key, SubListResult result)>();
        foreach (var (key, r) in _cache)
        {
            if (SubjectMatch.MatchLiteral(key, subject))
            {
                keysToUpdate.Add((key, r));
            }
        }
        foreach (var (key, r) in keysToUpdate)
        {
            _cache[key] = AddSubToResult(r, sub);
        }
    }

    /// <summary>
    /// Removes cache entries that match the given subject.
    /// Assumes write lock is held.
    /// </summary>
    private void RemoveFromCache(string subject)
    {
        if (_cache == null)
            return;

        // If literal subject, we can do a direct removal.
        if (SubjectMatch.IsLiteral(subject))
        {
            _cache.Remove(subject);
            return;
        }

        // Wildcard subscription -- remove all matching cached keys.
        var keysToRemove = new List<string>();
        foreach (var key in _cache.Keys)
        {
            if (SubjectMatch.MatchLiteral(key, subject))
            {
                keysToRemove.Add(key);
            }
        }
        foreach (var key in keysToRemove)
        {
            _cache.Remove(key);
        }
    }

    /// <summary>
    /// Creates a new result with the given subscription added.
    /// Results are treated as immutable, so arrays are copied, never mutated.
    /// </summary>
    private static SubListResult AddSubToResult(SubListResult result, Subscription sub)
    {
        if (sub.Queue == null)
        {
            var newPlain = new Subscription[result.PlainSubs.Length + 1];
            result.PlainSubs.CopyTo(newPlain, 0);
            newPlain[^1] = sub;
            return new SubListResult(newPlain, result.QueueSubs);
        }
        else
        {
            // Find existing queue group
            var queueSubs = result.QueueSubs;
            int slot = -1;
            for (int i = 0; i < queueSubs.Length; i++)
            {
                if (queueSubs[i].Length > 0 && queueSubs[i][0].Queue == sub.Queue)
                {
                    slot = i;
                    break;
                }
            }

            // Deep copy queue subs
            var newQueueSubs = new Subscription[queueSubs.Length + (slot < 0 ? 1 : 0)][];
            for (int i = 0; i < queueSubs.Length; i++)
            {
                if (i == slot)
                {
                    var newGroup = new Subscription[queueSubs[i].Length + 1];
                    queueSubs[i].CopyTo(newGroup, 0);
                    newGroup[^1] = sub;
                    newQueueSubs[i] = newGroup;
                }
                else
                {
                    newQueueSubs[i] = (Subscription[])queueSubs[i].Clone();
                }
            }
            if (slot < 0)
            {
                newQueueSubs[^1] = [sub];
            }

            return new SubListResult(result.PlainSubs, newQueueSubs);
        }
    }

    /// <summary>Enumerates '.' separated tokens in a subject without allocating.</summary>
    private ref struct TokenEnumerator
    {
        private ReadOnlySpan<char> _remaining;

        public TokenEnumerator(string subject)
        {
            _remaining = subject.AsSpan();
            Current = default;
        }

        public ReadOnlySpan<char> Current { get; private set; }

        public TokenEnumerator GetEnumerator() => this;

        public bool MoveNext()
        {
            if (_remaining.IsEmpty)
                return false;

            int sep = _remaining.IndexOf(SubjectMatch.Sep);
            if (sep < 0)
            {
                Current = _remaining;
                _remaining = default;
            }
            else
            {
                Current = _remaining[..sep];
                _remaining = _remaining[(sep + 1)..];
            }
            return true;
        }
    }

    /// <summary>One level of the subject trie: literal children plus the two wildcard nodes.</summary>
    private sealed class TrieLevel
    {
        public readonly Dictionary<string, TrieNode> Nodes = new();
        public TrieNode? Pwc; // partial wildcard (*)
        public TrieNode? Fwc; // full wildcard (>)
    }

    /// <summary>A trie node holding the subscriptions terminating here and the next level down.</summary>
    private sealed class TrieNode
    {
        public TrieLevel? Next;
        public readonly HashSet<Subscription> PlainSubs = [];
        public readonly Dictionary<string, List<Subscription>> QueueSubs = new();

        public bool IsEmpty => PlainSubs.Count == 0 && QueueSubs.Count == 0 &&
            (Next == null || (Next.Nodes.Count == 0 && Next.Pwc == null && Next.Fwc == null));
    }
}

// Reconstructed from the same mangled patch. Original patch headers:
//   diff --git a/tests/NATS.Server.Tests/SubListTests.cs b/tests/NATS.Server.Tests/SubListTests.cs
//   new file mode 100644  index 0000000..e7e7ba7

using NATS.Server.Subscriptions;

namespace NATS.Server.Tests;

public class SubListTests
{
    private static Subscription MakeSub(string subject, string? queue = null, string sid = "1")
        => new() { Subject = subject, Queue = queue, Sid = sid };

    [Fact]
    public void Insert_and_match_literal_subject()
    {
        var sl = new SubList();
        var sub = MakeSub("foo.bar");
        sl.Insert(sub);

        var r = sl.Match("foo.bar");
        Assert.Single(r.PlainSubs);
        Assert.Same(sub, r.PlainSubs[0]);
        Assert.Empty(r.QueueSubs);
    }

    [Fact]
    public void Match_returns_empty_for_no_match()
    {
        var sl = new SubList();
        sl.Insert(MakeSub("foo.bar"));

        var r = sl.Match("foo.baz");
        Assert.Empty(r.PlainSubs);
    }

    [Fact]
    public void Match_partial_wildcard()
    {
        var sl = new SubList();
        var sub = MakeSub("foo.*");
        sl.Insert(sub);

        Assert.Single(sl.Match("foo.bar").PlainSubs);
        Assert.Single(sl.Match("foo.baz").PlainSubs);
        Assert.Empty(sl.Match("foo.bar.baz").PlainSubs);
    }

    [Fact]
    public void Match_full_wildcard()
    {
        var sl = new SubList();
        var sub = MakeSub("foo.>");
        sl.Insert(sub);

        Assert.Single(sl.Match("foo.bar").PlainSubs);
        Assert.Single(sl.Match("foo.bar.baz").PlainSubs);
        Assert.Empty(sl.Match("foo").PlainSubs);
    }

    [Fact]
    public void Match_root_full_wildcard()
    {
        var sl = new SubList();
        sl.Insert(MakeSub(">"));

        Assert.Single(sl.Match("foo").PlainSubs);
        Assert.Single(sl.Match("foo.bar").PlainSubs);
        Assert.Single(sl.Match("foo.bar.baz").PlainSubs);
    }

    [Fact]
    public void Match_collects_multiple_subs()
    {
        var sl = new SubList();
        sl.Insert(MakeSub("foo.bar", sid: "1"));
        sl.Insert(MakeSub("foo.*", sid: "2"));
        sl.Insert(MakeSub("foo.>", sid: "3"));
        sl.Insert(MakeSub(">", sid: "4"));

        var r = sl.Match("foo.bar");
        Assert.Equal(4, r.PlainSubs.Length);
    }

    [Fact]
    public void Remove_subscription()
    {
        var sl = new SubList();
        var sub = MakeSub("foo.bar");
        sl.Insert(sub);
        Assert.Single(sl.Match("foo.bar").PlainSubs);

        sl.Remove(sub);
        Assert.Empty(sl.Match("foo.bar").PlainSubs);
    }

    [Fact]
    public void Queue_group_subscriptions()
    {
        var sl = new SubList();
        sl.Insert(MakeSub("foo.bar", queue: "workers", sid: "1"));
        sl.Insert(MakeSub("foo.bar", queue: "workers", sid: "2"));
        sl.Insert(MakeSub("foo.bar", queue: "loggers", sid: "3"));

        var r = sl.Match("foo.bar");
        Assert.Empty(r.PlainSubs);
        Assert.Equal(2, r.QueueSubs.Length); // 2 queue groups
    }

    [Fact]
    public void Count_tracks_subscriptions()
    {
        var sl = new SubList();
        Assert.Equal(0u, sl.Count);

        sl.Insert(MakeSub("foo", sid: "1"));
        sl.Insert(MakeSub("bar", sid: "2"));
        Assert.Equal(2u, sl.Count);

        sl.Remove(MakeSub("foo", sid: "1"));
        // Remove by a different instance is a no-op (removal matches by the
        // set's equality, reference identity here), so the count is unchanged.
        Assert.Equal(2u, sl.Count);
    }

    [Fact]
    public void Count_tracks_with_same_instance()
    {
        var sl = new SubList();
        var sub = MakeSub("foo");
        sl.Insert(sub);
        Assert.Equal(1u, sl.Count);
        sl.Remove(sub);
        Assert.Equal(0u, sl.Count);
    }

    [Fact]
    public void Cache_invalidation_on_insert()
    {
        var sl = new SubList();
        sl.Insert(MakeSub("foo.bar", sid: "1"));

        // Prime the cache
        var r1 = sl.Match("foo.bar");
        Assert.Single(r1.PlainSubs);

        // Insert a wildcard that matches — cache should be invalidated
        sl.Insert(MakeSub("foo.*", sid: "2"));

        var r2 = sl.Match("foo.bar");
        Assert.Equal(2, r2.PlainSubs.Length);
    }

    [Fact]
    public void Match_partial_wildcard_at_different_levels()
    {
        var sl = new SubList();
        sl.Insert(MakeSub("*.bar.baz", sid: "1"));
        sl.Insert(MakeSub("foo.*.baz", sid: "2"));
        sl.Insert(MakeSub("foo.bar.*", sid: "3"));

        var r = sl.Match("foo.bar.baz");
        Assert.Equal(3, r.PlainSubs.Length);
    }
}