feat: implement SubList trie with wildcard matching and cache

This commit is contained in:
Joseph Doherty
2026-02-22 20:07:35 -05:00
parent 9e36b7c0fc
commit afc419ce3f
2 changed files with 674 additions and 0 deletions

View File

@@ -0,0 +1,515 @@
namespace NATS.Server.Subscriptions;
/// <summary>
/// SubList is a routing mechanism to handle subject distribution and
/// provides a facility to match subjects from published messages to
/// interested subscribers. Subscribers can have wildcard subjects to
/// match multiple published subjects.
/// </summary>
public sealed class SubList
{
private const int CacheMax = 1024;
private const int CacheSweep = 256;
private readonly ReaderWriterLockSlim _lock = new();
private readonly TrieLevel _root = new();
private Dictionary<string, SubListResult>? _cache = new();
private uint _count;
private ulong _genId;
public uint Count
{
get
{
_lock.EnterReadLock();
try { return _count; }
finally { _lock.ExitReadLock(); }
}
}
public void Insert(Subscription sub)
{
var subject = sub.Subject;
_lock.EnterWriteLock();
try
{
var level = _root;
TrieNode? node = null;
bool sawFwc = false;
foreach (var token in new TokenEnumerator(subject))
{
if (token.Length == 0 || sawFwc)
throw new ArgumentException("Invalid subject", nameof(sub));
if (token.Length == 1 && token[0] == SubjectMatch.Pwc)
{
node = level.Pwc ??= new TrieNode();
}
else if (token.Length == 1 && token[0] == SubjectMatch.Fwc)
{
node = level.Fwc ??= new TrieNode();
sawFwc = true;
}
else
{
var key = token.ToString();
if (!level.Nodes.TryGetValue(key, out node))
{
node = new TrieNode();
level.Nodes[key] = node;
}
}
node.Next ??= new TrieLevel();
level = node.Next;
}
if (node == null)
throw new ArgumentException("Invalid subject", nameof(sub));
if (sub.Queue == null)
{
node.PlainSubs.Add(sub);
}
else
{
if (!node.QueueSubs.TryGetValue(sub.Queue, out var qset))
{
qset = [];
node.QueueSubs[sub.Queue] = qset;
}
qset.Add(sub);
}
_count++;
_genId++;
AddToCache(subject, sub);
}
finally
{
_lock.ExitWriteLock();
}
}
public void Remove(Subscription sub)
{
_lock.EnterWriteLock();
try
{
var level = _root;
TrieNode? node = null;
bool sawFwc = false;
var pathList = new List<(TrieLevel level, TrieNode node, string token, bool isPwc, bool isFwc)>();
foreach (var token in new TokenEnumerator(sub.Subject))
{
if (token.Length == 0 || sawFwc)
return;
bool isPwc = token.Length == 1 && token[0] == SubjectMatch.Pwc;
bool isFwc = token.Length == 1 && token[0] == SubjectMatch.Fwc;
if (isPwc)
{
node = level.Pwc;
}
else if (isFwc)
{
node = level.Fwc;
sawFwc = true;
}
else
{
level.Nodes.TryGetValue(token.ToString(), out node);
}
if (node == null)
return; // not found
var tokenStr = token.ToString();
pathList.Add((level, node, tokenStr, isPwc, isFwc));
level = node.Next ?? new TrieLevel();
}
if (node == null) return;
// Remove from node
bool removed;
if (sub.Queue == null)
{
removed = node.PlainSubs.Remove(sub);
}
else
{
removed = false;
if (node.QueueSubs.TryGetValue(sub.Queue, out var qset))
{
removed = qset.Remove(sub);
if (qset.Count == 0)
node.QueueSubs.Remove(sub.Queue);
}
}
if (!removed) return;
_count--;
_genId++;
RemoveFromCache(sub.Subject);
// Prune empty nodes (walk backwards)
for (int i = pathList.Count - 1; i >= 0; i--)
{
var (l, n, t, isPwc, isFwc) = pathList[i];
if (n.IsEmpty)
{
if (isPwc) l.Pwc = null;
else if (isFwc) l.Fwc = null;
else l.Nodes.Remove(t);
}
}
}
finally
{
_lock.ExitWriteLock();
}
}
public SubListResult Match(string subject)
{
// Check cache under read lock first.
_lock.EnterReadLock();
try
{
if (_cache != null && _cache.TryGetValue(subject, out var cached))
return cached;
}
finally
{
_lock.ExitReadLock();
}
// Cache miss -- tokenize and match under write lock (needed for cache update).
// Tokenize the subject.
var tokens = Tokenize(subject);
if (tokens == null)
return SubListResult.Empty;
_lock.EnterWriteLock();
try
{
// Re-check cache after acquiring write lock.
if (_cache != null && _cache.TryGetValue(subject, out var cached))
return cached;
var plainSubs = new List<Subscription>();
var queueSubs = new List<List<Subscription>>();
MatchLevel(_root, tokens, 0, plainSubs, queueSubs);
SubListResult result;
if (plainSubs.Count == 0 && queueSubs.Count == 0)
{
result = SubListResult.Empty;
}
else
{
result = new SubListResult(
plainSubs.ToArray(),
queueSubs.Select(q => q.ToArray()).ToArray());
}
if (_cache != null)
{
_cache[subject] = result;
if (_cache.Count > CacheMax)
{
// Sweep: remove entries until at CacheSweep count.
var keys = _cache.Keys.Take(_cache.Count - CacheSweep).ToList();
foreach (var key in keys)
_cache.Remove(key);
}
}
return result;
}
finally
{
_lock.ExitWriteLock();
}
}
/// <summary>
/// Tokenize the subject into an array of token strings.
/// Returns null if the subject is invalid (empty tokens).
/// </summary>
private static string[]? Tokenize(string subject)
{
if (string.IsNullOrEmpty(subject))
return null;
var tokens = new List<string>();
int start = 0;
for (int i = 0; i <= subject.Length; i++)
{
if (i == subject.Length || subject[i] == SubjectMatch.Sep)
{
if (i - start == 0)
return null; // empty token
tokens.Add(subject[start..i]);
start = i + 1;
}
}
return tokens.Count > 0 ? tokens.ToArray() : null;
}
/// <summary>
/// Recursively descend into the trie matching tokens.
/// This follows the Go matchLevel() algorithm closely.
/// </summary>
private static void MatchLevel(TrieLevel? level, string[] tokens, int tokenIndex,
List<Subscription> plainSubs, List<List<Subscription>> queueSubs)
{
TrieNode? pwc = null;
TrieNode? node = null;
for (int i = tokenIndex; i < tokens.Length; i++)
{
if (level == null)
return;
// Full wildcard (>) at this level matches all remaining tokens.
if (level.Fwc != null)
AddNodeToResults(level.Fwc, plainSubs, queueSubs);
// Partial wildcard (*) -- recurse with remaining tokens.
pwc = level.Pwc;
if (pwc != null)
MatchLevel(pwc.Next, tokens, i + 1, plainSubs, queueSubs);
// Literal match
node = null;
if (level.Nodes.TryGetValue(tokens[i], out var found))
{
node = found;
level = node.Next;
}
else
{
level = null;
}
}
// After processing all tokens, add results from the final literal node.
if (node != null)
AddNodeToResults(node, plainSubs, queueSubs);
// Also add results from the partial wildcard at the last level,
// which handles the case where * matches the final token.
if (pwc != null)
AddNodeToResults(pwc, plainSubs, queueSubs);
}
private static void AddNodeToResults(TrieNode node,
List<Subscription> plainSubs, List<List<Subscription>> queueSubs)
{
// Add plain subscriptions
foreach (var sub in node.PlainSubs)
plainSubs.Add(sub);
// Add queue subscriptions grouped by queue name
foreach (var (queueName, subs) in node.QueueSubs)
{
if (subs.Count == 0) continue;
// Find existing queue group or create new one
List<Subscription>? existing = null;
foreach (var qs in queueSubs)
{
if (qs.Count > 0 && qs[0].Queue == queueName)
{
existing = qs;
break;
}
}
if (existing == null)
{
existing = new List<Subscription>();
queueSubs.Add(existing);
}
existing.AddRange(subs);
}
}
/// <summary>
/// Adds a subscription to matching cache entries.
/// Assumes write lock is held.
/// </summary>
private void AddToCache(string subject, Subscription sub)
{
if (_cache == null)
return;
// If literal subject, we can do a direct lookup.
if (SubjectMatch.IsLiteral(subject))
{
if (_cache.TryGetValue(subject, out var r))
{
_cache[subject] = AddSubToResult(r, sub);
}
return;
}
// Wildcard subscription -- check all cached keys.
var keysToUpdate = new List<(string key, SubListResult result)>();
foreach (var (key, r) in _cache)
{
if (SubjectMatch.MatchLiteral(key, subject))
{
keysToUpdate.Add((key, r));
}
}
foreach (var (key, r) in keysToUpdate)
{
_cache[key] = AddSubToResult(r, sub);
}
}
/// <summary>
/// Removes cache entries that match the given subject.
/// Assumes write lock is held.
/// </summary>
private void RemoveFromCache(string subject)
{
if (_cache == null)
return;
// If literal subject, we can do a direct removal.
if (SubjectMatch.IsLiteral(subject))
{
_cache.Remove(subject);
return;
}
// Wildcard subscription -- remove all matching cached keys.
var keysToRemove = new List<string>();
foreach (var key in _cache.Keys)
{
if (SubjectMatch.MatchLiteral(key, subject))
{
keysToRemove.Add(key);
}
}
foreach (var key in keysToRemove)
{
_cache.Remove(key);
}
}
/// <summary>
/// Creates a new result with the given subscription added.
/// </summary>
private static SubListResult AddSubToResult(SubListResult result, Subscription sub)
{
if (sub.Queue == null)
{
var newPlain = new Subscription[result.PlainSubs.Length + 1];
result.PlainSubs.CopyTo(newPlain, 0);
newPlain[^1] = sub;
return new SubListResult(newPlain, result.QueueSubs);
}
else
{
// Find existing queue group
var queueSubs = result.QueueSubs;
int slot = -1;
for (int i = 0; i < queueSubs.Length; i++)
{
if (queueSubs[i].Length > 0 && queueSubs[i][0].Queue == sub.Queue)
{
slot = i;
break;
}
}
// Deep copy queue subs
var newQueueSubs = new Subscription[queueSubs.Length + (slot < 0 ? 1 : 0)][];
for (int i = 0; i < queueSubs.Length; i++)
{
if (i == slot)
{
var newGroup = new Subscription[queueSubs[i].Length + 1];
queueSubs[i].CopyTo(newGroup, 0);
newGroup[^1] = sub;
newQueueSubs[i] = newGroup;
}
else
{
newQueueSubs[i] = (Subscription[])queueSubs[i].Clone();
}
}
if (slot < 0)
{
newQueueSubs[^1] = [sub];
}
return new SubListResult(result.PlainSubs, newQueueSubs);
}
}
/// <summary>Enumerates '.' separated tokens in a subject without allocating.</summary>
private ref struct TokenEnumerator
{
private ReadOnlySpan<char> _remaining;
public TokenEnumerator(string subject)
{
_remaining = subject.AsSpan();
Current = default;
}
public ReadOnlySpan<char> Current { get; private set; }
public TokenEnumerator GetEnumerator() => this;
public bool MoveNext()
{
if (_remaining.IsEmpty)
return false;
int sep = _remaining.IndexOf(SubjectMatch.Sep);
if (sep < 0)
{
Current = _remaining;
_remaining = default;
}
else
{
Current = _remaining[..sep];
_remaining = _remaining[(sep + 1)..];
}
return true;
}
}
private sealed class TrieLevel
{
public readonly Dictionary<string, TrieNode> Nodes = new();
public TrieNode? Pwc; // partial wildcard (*)
public TrieNode? Fwc; // full wildcard (>)
}
private sealed class TrieNode
{
public TrieLevel? Next;
public readonly HashSet<Subscription> PlainSubs = [];
public readonly Dictionary<string, HashSet<Subscription>> QueueSubs = new();
public bool IsEmpty => PlainSubs.Count == 0 && QueueSubs.Count == 0 &&
(Next == null || (Next.Nodes.Count == 0 && Next.Pwc == null && Next.Fwc == null));
}
}

View File

@@ -0,0 +1,159 @@
using NATS.Server.Subscriptions;
namespace NATS.Server.Tests;
public class SubListTests
{
private static Subscription MakeSub(string subject, string? queue = null, string sid = "1")
=> new() { Subject = subject, Queue = queue, Sid = sid };
[Fact]
public void Insert_and_match_literal_subject()
{
var sl = new SubList();
var sub = MakeSub("foo.bar");
sl.Insert(sub);
var r = sl.Match("foo.bar");
Assert.Single(r.PlainSubs);
Assert.Same(sub, r.PlainSubs[0]);
Assert.Empty(r.QueueSubs);
}
[Fact]
public void Match_returns_empty_for_no_match()
{
var sl = new SubList();
sl.Insert(MakeSub("foo.bar"));
var r = sl.Match("foo.baz");
Assert.Empty(r.PlainSubs);
}
[Fact]
public void Match_partial_wildcard()
{
var sl = new SubList();
var sub = MakeSub("foo.*");
sl.Insert(sub);
Assert.Single(sl.Match("foo.bar").PlainSubs);
Assert.Single(sl.Match("foo.baz").PlainSubs);
Assert.Empty(sl.Match("foo.bar.baz").PlainSubs);
}
[Fact]
public void Match_full_wildcard()
{
var sl = new SubList();
var sub = MakeSub("foo.>");
sl.Insert(sub);
Assert.Single(sl.Match("foo.bar").PlainSubs);
Assert.Single(sl.Match("foo.bar.baz").PlainSubs);
Assert.Empty(sl.Match("foo").PlainSubs);
}
[Fact]
public void Match_root_full_wildcard()
{
var sl = new SubList();
sl.Insert(MakeSub(">"));
Assert.Single(sl.Match("foo").PlainSubs);
Assert.Single(sl.Match("foo.bar").PlainSubs);
Assert.Single(sl.Match("foo.bar.baz").PlainSubs);
}
[Fact]
public void Match_collects_multiple_subs()
{
var sl = new SubList();
sl.Insert(MakeSub("foo.bar", sid: "1"));
sl.Insert(MakeSub("foo.*", sid: "2"));
sl.Insert(MakeSub("foo.>", sid: "3"));
sl.Insert(MakeSub(">", sid: "4"));
var r = sl.Match("foo.bar");
Assert.Equal(4, r.PlainSubs.Length);
}
[Fact]
public void Remove_subscription()
{
var sl = new SubList();
var sub = MakeSub("foo.bar");
sl.Insert(sub);
Assert.Single(sl.Match("foo.bar").PlainSubs);
sl.Remove(sub);
Assert.Empty(sl.Match("foo.bar").PlainSubs);
}
[Fact]
public void Queue_group_subscriptions()
{
var sl = new SubList();
sl.Insert(MakeSub("foo.bar", queue: "workers", sid: "1"));
sl.Insert(MakeSub("foo.bar", queue: "workers", sid: "2"));
sl.Insert(MakeSub("foo.bar", queue: "loggers", sid: "3"));
var r = sl.Match("foo.bar");
Assert.Empty(r.PlainSubs);
Assert.Equal(2, r.QueueSubs.Length); // 2 queue groups
}
[Fact]
public void Count_tracks_subscriptions()
{
var sl = new SubList();
Assert.Equal(0u, sl.Count);
sl.Insert(MakeSub("foo", sid: "1"));
sl.Insert(MakeSub("bar", sid: "2"));
Assert.Equal(2u, sl.Count);
sl.Remove(MakeSub("foo", sid: "1"));
// Remove by reference won't work — we need the same instance
}
[Fact]
public void Count_tracks_with_same_instance()
{
var sl = new SubList();
var sub = MakeSub("foo");
sl.Insert(sub);
Assert.Equal(1u, sl.Count);
sl.Remove(sub);
Assert.Equal(0u, sl.Count);
}
[Fact]
public void Cache_invalidation_on_insert()
{
var sl = new SubList();
sl.Insert(MakeSub("foo.bar", sid: "1"));
// Prime the cache
var r1 = sl.Match("foo.bar");
Assert.Single(r1.PlainSubs);
// Insert a wildcard that matches — cache should be invalidated
sl.Insert(MakeSub("foo.*", sid: "2"));
var r2 = sl.Match("foo.bar");
Assert.Equal(2, r2.PlainSubs.Length);
}
[Fact]
public void Match_partial_wildcard_at_different_levels()
{
var sl = new SubList();
sl.Insert(MakeSub("*.bar.baz", sid: "1"));
sl.Insert(MakeSub("foo.*.baz", sid: "2"));
sl.Insert(MakeSub("foo.bar.*", sid: "3"));
var r = sl.Match("foo.bar.baz");
Assert.Equal(3, r.PlainSubs.Length);
}
}