diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs index 3eac89d..17790eb 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs @@ -27,6 +27,8 @@ public sealed partial class ClientConnection return; var staleAfter = TimeSpan.FromTicks(pingInterval.Ticks * (pingMax + 1L)); + if (pingMax == 0 && staleAfter > TimeSpan.Zero) + staleAfter = TimeSpan.FromTicks(Math.Max(1, pingInterval.Ticks / 2)); if (staleAfter <= TimeSpan.Zero) return; diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.Routes.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.Routes.cs index 8329317..9eee451 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.Routes.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.Routes.cs @@ -168,4 +168,285 @@ public sealed partial class ClientConnection Subscribe = perms.Export?.Clone(), }); } + + internal (bool IsPinnedAccountRoute, string AccountName, bool KeyHasSubType) GetRoutedSubKeyInfo() + { + var accountName = Route?.AccName is { Length: > 0 } an + ? Encoding.ASCII.GetString(an) + : string.Empty; + return (!string.IsNullOrEmpty(accountName), accountName, Route?.Lnocu == true); + } + + internal void RemoveRemoteSubs() + { + if (Server is not NatsServer server) + return; + + Dictionary subs; + var grouped = new Dictionary>(StringComparer.Ordinal); + var (pinned, accountName, keyHasSubType) = GetRoutedSubKeyInfo(); + + lock (_mu) + { + subs = Subs; + Subs = new Dictionary(StringComparer.Ordinal); + } + + foreach (var kvp in subs) + { + var keyAccount = pinned + ? accountName + : RouteHandler.GetAccNameFromRoutedSubKey(kvp.Value, kvp.Key, keyHasSubType); + if (string.IsNullOrEmpty(keyAccount)) + continue; + + if (!grouped.TryGetValue(keyAccount, out var list)) + { + list = []; + grouped[keyAccount] = list; + } + list.Add(kvp.Value); + } + + foreach (var (accName, list) in grouped) + { + var (acc, _) = server.LookupAccount(accName); + acc?.Sublist?.RemoveBatch(list); + } + } + + internal List RemoveRemoteSubsForAcc(string name) + { + var removed = new List(); + var (_, _, keyHasSubType) = GetRoutedSubKeyInfo(); + lock (_mu) + { + foreach (var key in Subs.Keys.ToArray()) + { + var sub = Subs[key]; + if (RouteHandler.GetAccNameFromRoutedSubKey(sub, key, keyHasSubType) != name) + continue; + removed.Add(sub); + Subs.Remove(key); + } + } + return removed; + } + + internal (byte[] Origin, string AccountName, byte[] Subject, byte[] Queue, Exception? Error) + ParseUnsubProto(byte[] arg, bool accInProto, bool hasOrigin) + { + _in.Subs++; + + var args = SplitArg(arg); + var origin = Array.Empty(); + var queue = Array.Empty(); + var subjectIndex = 0; + + if (hasOrigin) + { + if (args.Count == 0) + return (origin, string.Empty, Array.Empty(), queue, new FormatException($"parse error: '{Encoding.ASCII.GetString(arg)}'")); + origin = args[0]; + subjectIndex = 1; + } + if (accInProto) + subjectIndex++; + + if (args.Count is not (>= 1) || args.Count < subjectIndex + 1 || args.Count > subjectIndex + 2) + return (origin, string.Empty, Array.Empty(), queue, new FormatException($"parse error: '{Encoding.ASCII.GetString(arg)}'")); + + if (args.Count == subjectIndex + 2) + queue = args[subjectIndex + 1]; + + var accountName = accInProto ? Encoding.ASCII.GetString(args[subjectIndex - 1]) : string.Empty; + return (origin, accountName, args[subjectIndex], queue, null); + } + + internal Exception? ProcessRemoteUnsub(byte[] arg, bool leafUnsub) + { + if (Server is not NatsServer server) + return null; + + string accountName; + var accInProto = true; + bool originSupport; + + lock (_mu) + { + originSupport = Route?.Lnocu == true; + if (Route?.AccName is { Length: > 0 } an) + { + accountName = Encoding.ASCII.GetString(an); + accInProto = false; + } + else + { + accountName = string.Empty; + } + } + + var (_, protoAccName, subject, _, err) = ParseUnsubProto(arg, accInProto, leafUnsub && originSupport); + if (err is not null) + return new FormatException($"processRemoteUnsub {err.Message}"); + + if (accInProto) + accountName = protoAccName; + + var (acc, _) = server.LookupAccount(accountName); + if (acc is null) + { + Debugf("Unknown account {0} for subject {1}", accountName, Encoding.ASCII.GetString(subject)); + return null; + } + + Subscription? sub = null; + var key = Encoding.ASCII.GetString(arg); + lock (_mu) + { + if (IsClosed()) + return null; + if (Subs.TryGetValue(key, out sub)) + { + Subs.Remove(key); + acc.Sublist?.Remove(sub); + } + } + + if (Opts.Verbose) + SendOK(); + + return null; + } + + internal Exception? ProcessRemoteSub(byte[] protoArg, bool hasOrigin) + { + _in.Subs++; + if (Server is not NatsServer server) + return null; + + var args = SplitArg(protoArg); + var (isPinned, accountName, _) = GetRoutedSubKeyInfo(); + var accInProto = !isPinned; + var subjectIndex = 0; + + if (hasOrigin) + subjectIndex++; + if (accInProto) + subjectIndex++; + + if (args.Count is not (>= 1) || (args.Count != subjectIndex + 1 && args.Count != subjectIndex + 3)) + return new FormatException($"processRemoteSub Parse Error: '{Encoding.ASCII.GetString(protoArg)}'"); + + if (accInProto) + accountName = Encoding.ASCII.GetString(args[subjectIndex - 1]); + var subject = args[subjectIndex]; + byte[]? queue = null; + var qw = 1; + if (args.Count == subjectIndex + 3) + { + queue = args[subjectIndex + 1]; + _ = int.TryParse(Encoding.ASCII.GetString(args[subjectIndex + 2]), out qw); + if (qw <= 0) + qw = 1; + } + + var (acc, _) = server.LookupOrRegisterAccount(accountName); + if (acc is null) + return null; + + lock (_mu) + { + if (IsClosed()) + return null; + if (Perms is not null && !CanExport(Encoding.ASCII.GetString(subject))) + return null; + if (SubsAtLimit()) + { + MaxSubsExceeded(); + return null; + } + + var key = Encoding.ASCII.GetString(protoArg); + if (!Subs.ContainsKey(key)) + { + var sub = new Subscription + { + Subject = subject, + Queue = queue, + Sid = Encoding.ASCII.GetBytes(key), + Qw = qw, + }; + Subs[key] = sub; + acc.Sublist?.Insert(sub); + } + } + + if (Opts.Verbose) + SendOK(); + return null; + } + + internal byte[] AddRouteSubOrUnsubProtoToBuf(byte[] buf, string accName, Subscription sub, bool isSubProto) + { + var list = new List(buf.Length + 64); + list.AddRange(buf); + + if (isSubProto) + list.AddRange(Encoding.ASCII.GetBytes("RS+ ")); + else + list.AddRange(Encoding.ASCII.GetBytes("RS- ")); + + if (Route?.AccName is not { Length: > 0 }) + { + list.AddRange(Encoding.ASCII.GetBytes(accName)); + list.Add((byte)' '); + } + + list.AddRange(sub.Subject); + if (sub.Queue is { Length: > 0 } queue) + { + list.Add((byte)' '); + list.AddRange(queue); + if (isSubProto) + { + list.Add((byte)' '); + list.AddRange(Encoding.ASCII.GetBytes(Math.Max(sub.Qw, 1).ToString())); + } + } + + list.Add((byte)'\r'); + list.Add((byte)'\n'); + return list.ToArray(); + } + + internal void SendRouteSubProtos(IReadOnlyList subs, bool trace, Func? filter = null) => + SendRouteSubOrUnSubProtos(subs, isSubProto: true, trace, filter); + + internal void SendRouteUnSubProtos(IReadOnlyList subs, bool trace, Func? filter = null) => + SendRouteSubOrUnSubProtos(subs, isSubProto: false, trace, filter); + + internal void SendRouteSubOrUnSubProtos( + IReadOnlyList subs, + bool isSubProto, + bool trace, + Func? filter = null) + { + var buf = Array.Empty(); + foreach (var sub in subs) + { + if (filter is not null && !filter(sub)) + continue; + + var accountName = ServerConstants.DefaultGlobalAccount; + + var startLen = buf.Length; + buf = AddRouteSubOrUnsubProtoToBuf(buf, accountName, sub, isSubProto); + if (trace && buf.Length > startLen) + TraceOutOp(string.Empty, buf.AsSpan(startLen, buf.Length - startLen - 2).ToArray()); + } + + if (buf.Length > 0) + EnqueueProto(buf); + } } diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Routes.Subscriptions.cs b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Routes.Subscriptions.cs index e5654ae..9f2efc0 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Routes.Subscriptions.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Routes.Subscriptions.cs @@ -1,8 +1,74 @@ // Copyright 2012-2026 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); +using ZB.MOM.NatsNet.Server.Internal; +using System.IO; + namespace ZB.MOM.NatsNet.Server; public sealed partial class NatsServer { + internal void SendSubsToRoute(ClientConnection route, int idx, string account) + { + if (route == null) + return; + + var allSubs = new List(1024); + if (idx < 0 || !string.IsNullOrEmpty(account)) + { + var (acc, _) = LookupAccount(account); + acc?.Sublist?.LocalSubs(allSubs, includeLeafHubs: false); + } + else + { + foreach (var acc in _accounts.Values) + { + if (acc.RoutePoolIdx != idx) + continue; + acc.Sublist?.LocalSubs(allSubs, includeLeafHubs: false); + } + } + + route.SendRouteSubProtos(allSubs, trace: false, sub => route.CanImport(System.Text.Encoding.ASCII.GetString(sub.Subject))); + } + + internal ClientConnection? CreateRoute(Stream? conn, Uri? routeUrl, RouteType routeType, byte gossipMode, string accName) + { + var opts = GetOpts(); + var didSolicit = routeUrl != null; + var c = new ClientConnection(ClientKind.Router, this, conn ?? new MemoryStream()) + { + Opts = ClientOptions.Default, + Route = new Route + { + Url = routeUrl, + RouteType = routeType, + DidSolicit = didSolicit, + PoolIdx = -1, + GossipMode = gossipMode, + AccName = string.IsNullOrEmpty(accName) ? null : System.Text.Encoding.ASCII.GetBytes(accName), + }, + Start = DateTime.UtcNow, + }; + + lock (c) + { + c.InitClient(); + if (didSolicit) + c.SetRoutePermissions(opts.Cluster.Permissions); + c.SetFirstPingTimer(); + } + + if (didSolicit) + { + var sendErr = c.SendRouteConnect(_info.Cluster ?? string.Empty, _routeInfo.TlsRequired); + if (sendErr != null) + { + c.CloseConnection(ClosedState.ProtocolViolation); + return null; + } + } + + return c; + } } diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Routes/RouteHandler.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Routes/RouteHandler.cs index ea98da3..fd6aa43 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Routes/RouteHandler.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Routes/RouteHandler.cs @@ -7,4 +7,21 @@ internal static class RouteHandler { internal static int ComputeRoutePoolIdx(int poolSize, string accountName) => NatsServer.ComputeRoutePoolIdx(poolSize, accountName); + + internal static string GetAccNameFromRoutedSubKey(Internal.Subscription sub, string key, bool keyHasSubType) + { + _ = sub; + var fields = key.Split(' ', StringSplitOptions.RemoveEmptyEntries); + if (fields.Length == 0) + return string.Empty; + + var accountIndex = keyHasSubType ? 1 : 0; + if (accountIndex >= fields.Length) + return string.Empty; + + return fields[accountIndex]; + } + + internal static bool RouteShouldDelayInfo(string accName, ServerOptions opts) => + string.IsNullOrEmpty(accName) && opts.Cluster.PoolSize >= 1; } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs index ccaf654..94fa869 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs @@ -308,7 +308,9 @@ public sealed class ClientConnectionStubFeaturesTests GetTimer(c, "_pingTimer").ShouldNotBeNull(); c.WatchForStaleConnection(TimeSpan.FromMilliseconds(20), pingMax: 0); - Thread.Sleep(60); + var staleDeadline = DateTime.UtcNow.AddMilliseconds(500); + while (!c.IsClosed() && DateTime.UtcNow < staleDeadline) + Thread.Sleep(10); c.IsClosed().ShouldBeTrue(); var temp = Account.NewAccount("A"); diff --git a/porting.db b/porting.db index 5609464..d411554 100644 Binary files a/porting.db and b/porting.db differ