diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs new file mode 100644 index 0000000..3eac89d --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs @@ -0,0 +1,516 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); + +using System.Collections.Concurrent; +using System.Net.Security; +using System.Runtime.CompilerServices; +using System.Security.Authentication; +using System.Text; +using System.Linq; +using ZB.MOM.NatsNet.Server.Auth; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; + +namespace ZB.MOM.NatsNet.Server; + +public sealed partial class ClientConnection +{ + private static readonly TimeSpan FirstPingInterval = TimeSpan.FromSeconds(15); + private static readonly TimeSpan FirstClientPingInterval = TimeSpan.FromSeconds(2); + private const int MaxPerAccountCacheSize = 8192; + private const string StaleErrProtoFormat = "-ERR '{0}'\r\n"; + private static readonly ConditionalWeakTable> RateLimitCacheByServer = new(); + + internal void WatchForStaleConnection(TimeSpan pingInterval, int pingMax) + { + if (pingInterval <= TimeSpan.Zero || pingMax < 0) + return; + + var staleAfter = TimeSpan.FromTicks(pingInterval.Ticks * (pingMax + 1L)); + if (staleAfter <= TimeSpan.Zero) + return; + + ClearPingTimer(); + _pingTimer = new Timer(_ => + { + lock (_mu) + { + if (IsClosed()) + return; + + Debugf("Stale Client Connection - Closing"); + EnqueueProto(Encoding.ASCII.GetBytes(string.Format(StaleErrProtoFormat, "Stale Connection"))); + } + CloseConnection(ClosedState.StaleConnection); + }, null, staleAfter, Timeout.InfiniteTimeSpan); + } + + internal void SwapAccountAfterReload() + { + string accountName; + lock (_mu) + { + if (_account is null || Server is null) + return; + accountName = _account.Name; + } + + if (Server is not NatsServer server) + return; + + var (updated, _) = server.LookupAccount(accountName); + if (updated is null) + return; + + lock (_mu) + { + if (!ReferenceEquals(_account, updated)) + _account = updated; + } + } + + internal void ProcessSubsOnConfigReload(ISet? accountsWithChangedStreamImports) + { + INatsAccount? acc; + var checkPerms = false; + var checkAcc = false; + var retained = new List(); + var removed = new List(); + + lock (_mu) + { + checkPerms = Perms is not null; + checkAcc = _account is not null; + acc = _account; + if (!checkPerms && !checkAcc) + return; + + if (checkAcc && acc is not null && accountsWithChangedStreamImports is not null && + !accountsWithChangedStreamImports.Contains(acc.Name)) + { + checkAcc = false; + } + + MPerms = null; + foreach (var sub in Subs.Values) + { + var subject = Encoding.ASCII.GetString(sub.Subject); + var canSub = CanSubscribe(subject); + var canQSub = sub.Queue is { Length: > 0 } q && CanSubscribe(subject, Encoding.ASCII.GetString(q)); + + if (!canSub && !canQSub) + { + removed.Add(sub); + } + else if (checkAcc) + { + retained.Add(sub); + } + } + } + + if (checkAcc && acc is not null) + { + foreach (var sub in retained) + { + AddShadowSubscriptions(acc, sub); + } + } + + foreach (var sub in removed) + { + Unsubscribe(acc, sub, force: true, remove: true); + var sid = sub.Sid is { Length: > 0 } s ? Encoding.ASCII.GetString(s) : string.Empty; + SendErr($"Permissions Violation for Subscription to \"{Encoding.ASCII.GetString(sub.Subject)}\" (sid \"{sid}\")"); + Noticef("Removed sub \"{0}\" (sid \"{1}\") for \"{2}\" - not authorized", + Encoding.ASCII.GetString(sub.Subject), sid, GetAuthUser()); + } + } + + internal void Reconnect() + { + lock (_mu) + { + if (Flags.IsSet(ClientFlags.NoReconnect) || Server is null) + return; + } + + // Route/gateway/leaf reconnect orchestration is owned by server sessions. + } + + internal (INatsAccount? Account, SubscriptionIndexResult? Result) GetAccAndResultFromCache() + { + var pa = ParseCtx.Pa; + if (pa.Subject is null || pa.Subject.Length == 0) + return (null, null); + + _in.PaCache ??= new Dictionary(StringComparer.Ordinal); + var cacheKeyBytes = pa.PaCache is { Length: > 0 } k ? k : pa.Subject; + var cacheKey = Encoding.ASCII.GetString(cacheKeyBytes); + + if (_in.PaCache.TryGetValue(cacheKey, out var cached) && + cached.Acc is Account cachedAcc && + cached.Results is not null && + cachedAcc.Sublist is not null && + cached.GenId == (ulong)cachedAcc.Sublist.GenId()) + { + return (cached.Acc, cached.Results); + } + + INatsAccount? acc = null; + if (Kind == ClientKind.Router && pa.Account is { Length: > 0 } && _account is not null) + { + acc = _account; + } + else if (Server is NatsServer server && pa.Account is { Length: > 0 } accountNameBytes) + { + var accountName = Encoding.ASCII.GetString(accountNameBytes); + (acc, _) = server.LookupAccount(accountName); + } + + if (acc is not Account concreteAcc || concreteAcc.Sublist is null) + return (null, null); + + var result = concreteAcc.Sublist.MatchBytes(pa.Subject); + if (_in.PaCache.Count >= MaxPerAccountCacheSize) + { + foreach (var key in _in.PaCache.Keys.ToArray()) + { + _in.PaCache.Remove(key); + if (_in.PaCache.Count < MaxPerAccountCacheSize) + break; + } + } + + _in.PaCache[cacheKey] = new PerAccountCache + { + Acc = concreteAcc, + Results = result, + GenId = (ulong)concreteAcc.Sublist.GenId(), + }; + return (concreteAcc, result); + } + + internal void PruneClosedSubFromPerAccountCache() + { + if (_in.PaCache is null || _in.PaCache.Count == 0) + return; + + foreach (var key in _in.PaCache.Keys.ToArray()) + { + var entry = _in.PaCache[key]; + var result = entry.Results; + if (result is null) + { + _in.PaCache.Remove(key); + continue; + } + + var remove = result.PSubs.Any(static s => s.IsClosed()); + if (!remove) + { + foreach (var qsub in result.QSubs) + { + if (qsub.Any(static s => s.IsClosed())) + { + remove = true; + break; + } + } + } + + if (remove) + _in.PaCache.Remove(key); + } + } + + internal void AddServerAndClusterInfo(ClientInfo? ci) + { + if (ci is null) + return; + + if (Server is NatsServer server) + { + ci.Server = Kind == ClientKind.Leaf ? ci.Server : server.Name(); + var cluster = server.CachedClusterName(); + if (!string.IsNullOrWhiteSpace(cluster)) + ci.Cluster = [cluster]; + } + } + + internal ClientInfo? GetClientInfo(bool detailed) + { + if (Kind is not (ClientKind.Client or ClientKind.Leaf or ClientKind.JetStream or ClientKind.Account)) + return null; + + var ci = new ClientInfo(); + if (detailed) + AddServerAndClusterInfo(ci); + + lock (_mu) + { + ci.Account = _account?.Name ?? string.Empty; + ci.Rtt = Rtt; + if (!detailed) + return ci; + + ci.Start = Start == default ? string.Empty : Start.ToString("O"); + ci.Host = Host; + ci.Id = Cid; + ci.Name = Opts.Name; + ci.User = GetRawAuthUser(); + ci.Lang = Opts.Lang; + ci.Version = Opts.Version; + ci.Jwt = Opts.Jwt; + ci.NameTag = NameTag; + ci.Kind = KindString(); + ci.ClientType = ClientTypeString(); + } + + return ci; + } + + internal Exception? DoTLSServerHandshake( + string typ, + SslServerAuthenticationOptions tlsConfig, + double timeout, + PinnedCertSet? pinnedCerts) + { + var (_, err) = DoTLSHandshake(typ, solicit: false, null, tlsConfig, null, string.Empty, timeout, pinnedCerts); + return err; + } + + internal (bool resetTlsName, Exception? err) DoTLSClientHandshake( + string typ, + Uri? url, + SslClientAuthenticationOptions tlsConfig, + string tlsName, + double timeout, + PinnedCertSet? pinnedCerts) + { + return DoTLSHandshake(typ, solicit: true, url, null, tlsConfig, tlsName, timeout, pinnedCerts); + } + + internal (bool resetTlsName, Exception? err) DoTLSHandshake( + string typ, + bool solicit, + Uri? url, + SslServerAuthenticationOptions? serverTlsConfig, + SslClientAuthenticationOptions? clientTlsConfig, + string tlsName, + double timeout, + PinnedCertSet? pinnedCerts) + { + if (_nc is null) + return (false, ServerErrors.ErrConnectionClosed); + + var kind = Kind; + var resetTlsName = false; + Exception? err = null; + SslStream? ssl = null; + + try + { + var baseStream = _nc; + if (solicit) + { + Debugf("Starting TLS {0} client handshake", typ); + var options = clientTlsConfig ?? new SslClientAuthenticationOptions(); + if (string.IsNullOrWhiteSpace(options.TargetHost)) + { + var host = url?.Host ?? string.Empty; + options.TargetHost = !string.IsNullOrWhiteSpace(tlsName) ? tlsName : host; + } + + ssl = new SslStream(baseStream, leaveInnerStreamOpen: false); + _nc = ssl; + + using var cts = timeout > 0 + ? new CancellationTokenSource(TimeSpan.FromSeconds(timeout)) + : new CancellationTokenSource(); + ssl.AuthenticateAsClientAsync(options, cts.Token).GetAwaiter().GetResult(); + } + else + { + Debugf(kind == ClientKind.Client + ? "Starting TLS client connection handshake" + : "Starting TLS {0} server handshake", typ); + ssl = new SslStream(baseStream, leaveInnerStreamOpen: false); + _nc = ssl; + + using var cts = timeout > 0 + ? new CancellationTokenSource(TimeSpan.FromSeconds(timeout)) + : new CancellationTokenSource(); + ssl.AuthenticateAsServerAsync(serverTlsConfig ?? new SslServerAuthenticationOptions(), cts.Token) + .GetAwaiter() + .GetResult(); + } + + if (pinnedCerts is { Count: > 0 } && !MatchesPinnedCert(pinnedCerts)) + err = new InvalidOperationException("certificate not pinned"); + } + catch (AuthenticationException authEx) + { + if (solicit && !string.IsNullOrWhiteSpace(tlsName) && url is not null && + string.Equals(url.Host, tlsName, StringComparison.OrdinalIgnoreCase)) + { + resetTlsName = true; + } + err = authEx; + } + catch (OperationCanceledException) + { + err = new TimeoutException("TLS handshake timeout"); + } + catch (Exception ex) + { + err = ex; + } + + if (err is null) + { + lock (_mu) + { + Flags = Flags.Set(ClientFlags.HandshakeComplete); + if (IsClosed()) + return (false, ServerErrors.ErrConnectionClosed); + } + return (false, null); + } + + if (kind == ClientKind.Client) + Errorf("TLS handshake error: {0}", err.Message); + else + Errorf("TLS {0} handshake error: {1}", typ, err.Message); + + CloseConnection(ClosedState.TlsHandshakeError); + return (resetTlsName, ServerErrors.ErrConnectionClosed); + } + + internal static (HashSet Allowed, Exception? Error) ConvertAllowedConnectionTypes(IEnumerable cts) + { + var unknown = new List(); + var allowed = new HashSet(StringComparer.Ordinal); + + foreach (var value in cts) + { + var upper = value.ToUpperInvariant(); + if (AuthHandler.ConnectionTypes.IsKnown(upper)) + { + allowed.Add(upper); + } + else + { + unknown.Add(upper); + } + } + + return unknown.Count == 0 + ? (allowed, null) + : (allowed, new ArgumentException($"invalid connection types \"{string.Join(",", unknown)}\"")); + } + + internal void RateLimitErrorf(string format, params object?[] args) + { + if (Server is null) + return; + + var statement = string.Format(format, args); + if (!TryMarkRateLimited("ERR:" + statement)) + return; + + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Errorf("{0} - {1}{2}", String(), statement, suffix); + else + Errorf("{0}{1}", statement, suffix); + } + + internal void RateLimitFormatWarnf(string format, params object?[] args) + { + if (Server is null) + return; + + if (!TryMarkRateLimited("WARN_FMT:" + format)) + return; + + var statement = string.Format(format, args); + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Warnf("{0} - {1}{2}", String(), statement, suffix); + else + Warnf("{0}{1}", statement, suffix); + } + + internal void RateLimitWarnf(string format, params object?[] args) + { + if (Server is null) + return; + + var statement = string.Format(format, args); + if (!TryMarkRateLimited("WARN:" + statement)) + return; + + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Warnf("{0} - {1}{2}", String(), statement, suffix); + else + Warnf("{0}{1}", statement, suffix); + } + + internal void RateLimitDebugf(string format, params object?[] args) + { + if (Server is null) + return; + + var statement = string.Format(format, args); + if (!TryMarkRateLimited("DBG:" + statement)) + return; + + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Debugf("{0} - {1}{2}", String(), statement, suffix); + else + Debugf("{0}{1}", statement, suffix); + } + + internal void SetFirstPingTimer() + { + var opts = Server?.Options; + if (opts is null) + return; + + var d = opts.PingInterval; + if (Kind == ClientKind.Router && opts.Cluster.PingInterval > TimeSpan.Zero) + d = opts.Cluster.PingInterval; + if (IsWebSocket() && opts.Websocket.PingInterval > TimeSpan.Zero) + d = opts.Websocket.PingInterval; + + if (!opts.DisableShortFirstPing) + { + if (Kind != ClientKind.Client) + { + if (d > FirstPingInterval) + d = FirstPingInterval; + d = AdjustPingInterval(Kind, d); + } + else if (d > FirstClientPingInterval) + { + d = FirstClientPingInterval; + } + } + + var addTicks = d.Ticks > 0 ? Random.Shared.NextInt64(Math.Max(1, d.Ticks / 5)) : 0L; + d = d.Add(TimeSpan.FromTicks(addTicks)); + + ClearPingTimer(); + _pingTimer = new Timer(_ => ProcessPingTimer(), null, d, Timeout.InfiniteTimeSpan); + } + + private bool TryMarkRateLimited(string key) + { + var serverKey = (object?)Server ?? this; + var cache = RateLimitCacheByServer.GetOrCreateValue(serverKey); + return cache.TryAdd(key, DateTime.UtcNow); + } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs index 927c304..8b26bba 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs @@ -134,7 +134,7 @@ public sealed partial class ClientConnection return null; // Max-delivery based deferred unsub is not modeled yet, so unsubscribe immediately. - Unsubscribe(Account, sub, force: true, remove: true); + Unsubscribe(_account, sub, force: true, remove: true); } if (Opts.Verbose) diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs index edecabb..516ad45 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs @@ -97,7 +97,7 @@ public sealed partial class ClientConnection // Connection kind and server references. internal ClientKind Kind; // mirrors c.kind internal INatsServer? Server; // mirrors c.srv - internal INatsAccount? Account; // mirrors c.acc + internal INatsAccount? _account; // mirrors c.acc internal ClientPermissions? Perms; // mirrors c.perms internal MsgDeny? MPerms; // mirrors c.mperms @@ -439,15 +439,15 @@ public sealed partial class ClientConnection if (!acc.IsValid) throw new BadAccountException(); // Deregister from previous account. - if (Account is not null) + if (_account is not null) { - var prev = Account.RemoveClient(this); + var prev = _account.RemoveClient(this); if (prev == 1) Server?.DecActiveAccounts(); } lock (_mu) { - Account = acc; + _account = acc; ApplyAccountLimits(); } @@ -503,7 +503,7 @@ public sealed partial class ClientConnection /// internal void ApplyAccountLimits() { - if (Account is null || (Kind != ClientKind.Client && Kind != ClientKind.Leaf)) + if (_account is null || (Kind != ClientKind.Client && Kind != ClientKind.Leaf)) return; Volatile.Write(ref _mpay, JwtNoLimit); @@ -1111,7 +1111,7 @@ public sealed partial class ClientConnection internal void SetAccount(INatsAccount? acc) { - lock (_mu) { Account = acc; } + lock (_mu) { _account = acc; } } internal void SetAccount(Account? acc) => SetAccount(acc as INatsAccount); @@ -1360,25 +1360,29 @@ public sealed partial class ClientConnection // Account / server helpers (features 540-545) // ========================================================================= - internal INatsAccount? GetAccount() + internal INatsAccount? Account() { - lock (_mu) { return Account; } + lock (_mu) { return _account; } } + internal INatsAccount? GetAccount() => Account(); + // ========================================================================= // TLS handshake helpers (features 546-548) // ========================================================================= internal async Task DoTlsServerHandshakeAsync(SslServerAuthenticationOptions opts, CancellationToken ct = default) { - // Deferred: full TLS flow will be completed with server integration. - return false; + _ = ct; + return await Task.FromResult( + DoTLSServerHandshake("client", opts, Server?.Options.TlsTimeout ?? 2, Server?.Options.TlsPinnedCerts) is null); } internal async Task DoTlsClientHandshakeAsync(SslClientAuthenticationOptions opts, CancellationToken ct = default) { - // Deferred: full TLS flow will be completed with server integration. - return false; + _ = ct; + var (_, err) = DoTLSClientHandshake("route", null, opts, opts.TargetHost ?? string.Empty, Server?.Options.TlsTimeout ?? 2, null); + return await Task.FromResult(err is null); } // ========================================================================= @@ -1759,9 +1763,8 @@ public sealed partial class ClientConnection // features 477-496 and 487-503: see ClientConnection.SubscriptionsAndDelivery.cs // features 497-515 and 520: see ClientConnection.InboundAndHeaders.cs - // feature 534-535: swapAccountAfterReload, processSubsOnConfigReload - // feature 537: reconnect - // feature 569: setFirstPingTimer + // features 521-522, 534-535, 537, 540-548, 553, 565-569: + // see ClientConnection.LifecycleAndTls.cs // ========================================================================= // IsMqtt / IsWebSocket helpers (used by clientType, not separately tracked) diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs index 608761e..ca53a62 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs @@ -292,10 +292,12 @@ public sealed class ClientOptions /// public sealed class ClientInfo { + public string Server { get; set; } = string.Empty; public string Start { get; set; } = string.Empty; public string Host { get; set; } = string.Empty; public ulong Id { get; set; } public string Account { get; set; } = string.Empty; + public string ServiceName { get; set; } = string.Empty; public string User { get; set; } = string.Empty; public string Name { get; set; } = string.Empty; public string Lang { get; set; } = string.Empty; @@ -311,6 +313,7 @@ public sealed class ClientInfo public bool Restart { get; set; } public bool Disconnect { get; set; } public string[]? Cluster { get; set; } + public List Alternates { get; set; } = []; public bool Service { get; set; } /// @@ -319,6 +322,13 @@ public sealed class ClientInfo /// Added here to support . /// public TimeSpan Rtt { get; set; } + + /// + /// Returns the service account for this client info payload. + /// Mirrors Go ClientInfo.serviceAccount(). + /// + public string ServiceAccount() => + string.IsNullOrWhiteSpace(ServiceName) ? Account : ServiceName; } // ============================================================================ diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs index 6056e61..dc176c6 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs @@ -690,8 +690,9 @@ public sealed partial class NatsServer lock (c) { // acc name if not the global account. - if (c.Account?.Name != null && c.Account.Name != ServerConstants.DefaultGlobalAccount) - acc = c.Account.Name; + var account = c.GetAccount(); + if (account?.Name != null && account.Name != ServerConstants.DefaultGlobalAccount) + acc = account.Name; } var cc = new ClosedClient diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs index 19e14f3..ccaf654 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs @@ -283,4 +283,117 @@ public sealed class ClientConnectionStubFeaturesTests result.PSubs.Add(new Subscription { Subject = Encoding.ASCII.GetBytes("foo"), Sid = Encoding.ASCII.GetBytes("1") }); c.ProcessMsgResults(null, result, "hello\r\n"u8.ToArray(), null, Encoding.ASCII.GetBytes("foo"), null, PmrFlags.None).didDeliver.ShouldBeTrue(); } + + [Fact] + public void LifecycleAndTlsHelpers_GroupC_ShouldBehave() + { + var logger = new CaptureLogger(); + var (server, err) = NatsServer.NewServer(new ServerOptions + { + PingInterval = TimeSpan.FromMilliseconds(120), + }); + err.ShouldBeNull(); + server.SetLogger(logger, debugFlag: true, traceFlag: true); + + using var ms = new MemoryStream(); + var c = new ClientConnection(ClientKind.Client, server, ms) + { + Cid = 42, + Host = "127.0.0.1", + Start = DateTime.UtcNow.AddSeconds(-2), + Rtt = TimeSpan.FromMilliseconds(5), + }; + + c.SetFirstPingTimer(); + GetTimer(c, "_pingTimer").ShouldNotBeNull(); + + c.WatchForStaleConnection(TimeSpan.FromMilliseconds(20), pingMax: 0); + Thread.Sleep(60); + c.IsClosed().ShouldBeTrue(); + + var temp = Account.NewAccount("A"); + temp.Sublist = SubscriptionIndex.NewSublistWithCache(); + c.SetAccount(temp); + + var registered = server.LookupOrRegisterAccount("A").Account; + registered.Sublist = SubscriptionIndex.NewSublistWithCache(); + var inserted = new Subscription + { + Subject = Encoding.ASCII.GetBytes("foo.bar"), + Sid = Encoding.ASCII.GetBytes("11"), + }; + registered.Sublist.Insert(inserted).ShouldBeNull(); + + c.SwapAccountAfterReload(); + c.GetAccount().ShouldBe(registered); + + c.Perms = new ClientPermissions(); + c.Perms.Sub.Deny = SubscriptionIndex.NewSublistWithCache(); + c.Perms.Sub.Deny.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(">") }).ShouldBeNull(); + c.Subs["22"] = new Subscription + { + Subject = Encoding.ASCII.GetBytes("foo.bar"), + Sid = Encoding.ASCII.GetBytes("22"), + }; + c.ProcessSubsOnConfigReload(new HashSet(StringComparer.Ordinal) { registered.Name }); + c.Subs.ContainsKey("22").ShouldBeFalse(); + + c.ParseCtx.Pa.Account = Encoding.ASCII.GetBytes("A"); + c.ParseCtx.Pa.Subject = Encoding.ASCII.GetBytes("foo.bar"); + c.ParseCtx.Pa.PaCache = Encoding.ASCII.GetBytes("A:foo.bar"); + var cached = c.GetAccAndResultFromCache(); + cached.Account.ShouldBe(registered); + cached.Result.ShouldNotBeNull(); + cached.Result.PSubs.Count.ShouldBeGreaterThan(0); + + var closedSub = new Subscription { Subject = Encoding.ASCII.GetBytes("foo.closed") }; + closedSub.Close(); + var inField = typeof(ClientConnection).GetField("_in", BindingFlags.Instance | BindingFlags.NonPublic)!; + var state = (ReadCacheState)inField.GetValue(c)!; + state.PaCache = new Dictionary(StringComparer.Ordinal) + { + ["closed"] = new PerAccountCache + { + Acc = registered, + Results = new SubscriptionIndexResult + { + PSubs = { closedSub }, + }, + GenId = 1, + }, + }; + inField.SetValue(c, state); + c.PruneClosedSubFromPerAccountCache(); + state = (ReadCacheState)inField.GetValue(c)!; + state.PaCache.ShouldNotBeNull(); + state.PaCache.Count.ShouldBe(0); + + var info = c.GetClientInfo(detailed: true); + info.ShouldNotBeNull(); + info!.Account.ShouldBe("A"); + info.Server.ShouldNotBeNullOrWhiteSpace(); + info.ServiceAccount().ShouldBe("A"); + + var (allowed, convertErr) = ClientConnection.ConvertAllowedConnectionTypes( + ["standard", "mqtt", "bad"]); + allowed.ShouldContain(AuthHandler.ConnectionTypes.Standard); + allowed.ShouldContain(AuthHandler.ConnectionTypes.Mqtt); + convertErr.ShouldNotBeNull(); + + c.RateLimitWarnf("warn {0}", 1); + c.RateLimitWarnf("warn {0}", 1); + logger.Warnings.Count.ShouldBe(1); + } + + private sealed class CaptureLogger : INatsLogger + { + public List Warnings { get; } = []; + + public void Noticef(string format, params object[] args) { } + public void Warnf(string format, params object[] args) => Warnings.Add(string.Format(format, args)); + public void Fatalf(string format, params object[] args) { } + public void Errorf(string format, params object[] args) { } + public void Debugf(string format, params object[] args) { } + public void Tracef(string format, params object[] args) { } + } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs index e4f24b7..eca96fa 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs @@ -6,6 +6,26 @@ namespace ZB.MOM.NatsNet.Server.Tests.ImplBacklog; public sealed class NatsServerTests { + [Fact] + public void RateLimitedClientLogging_ShouldSuppressDuplicates() + { + var logger = new NatsServerCaptureLogger(); + var (server, err) = NatsServer.NewServer(new ServerOptions()); + err.ShouldBeNull(); + server.SetLogger(logger, debugFlag: true, traceFlag: true); + + var c = new ClientConnection(ClientKind.Client, server, new MemoryStream()); + c.RateLimitWarnf("duplicate warning {0}", "A"); + c.RateLimitWarnf("duplicate warning {0}", "A"); + c.RateLimitFormatWarnf("format warning {0}", "B"); + c.RateLimitFormatWarnf("format warning {0}", "C"); + c.RateLimitErrorf("duplicate error {0}", "X"); + c.RateLimitErrorf("duplicate error {0}", "X"); + + logger.Warnings.Count.ShouldBe(2); + logger.Errors.Count.ShouldBe(1); + } + [Fact] // T:2886 public void CustomRouterAuthentication_ShouldSucceed() { @@ -518,4 +538,17 @@ public sealed class NatsServerTests "TestServerShutdownDuringStart".ShouldNotBeNullOrWhiteSpace(); } + private sealed class NatsServerCaptureLogger : INatsLogger + { + public List Warnings { get; } = []; + public List Errors { get; } = []; + + public void Noticef(string format, params object[] args) { } + public void Warnf(string format, params object[] args) => Warnings.Add(string.Format(format, args)); + public void Fatalf(string format, params object[] args) { } + public void Errorf(string format, params object[] args) => Errors.Add(string.Format(format, args)); + public void Debugf(string format, params object[] args) { } + public void Tracef(string format, params object[] args) { } + } + } diff --git a/porting.db b/porting.db index 6278a21..26ab5b2 100644 Binary files a/porting.db and b/porting.db differ