// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");

using System.Collections.Concurrent;
using System.Linq;
using System.Net.Security;
using System.Runtime.CompilerServices;
using System.Security.Authentication;
using System.Text;
using ZB.MOM.NatsNet.Server.Auth;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Internal.DataStructures;

namespace ZB.MOM.NatsNet.Server;

public sealed partial class ClientConnection
{
    private static readonly TimeSpan FirstPingInterval = TimeSpan.FromSeconds(15);
    private static readonly TimeSpan FirstClientPingInterval = TimeSpan.FromSeconds(2);

    // Upper bound applied to every timer due time armed in this file. System.Threading.Timer
    // rejects due times that are too large, so anything above this is clamped rather than thrown.
    private static readonly TimeSpan MaxTimerDueTime = TimeSpan.FromMilliseconds(int.MaxValue);

    private const int MaxPerAccountCacheSize = 8192;
    private const string StaleErrProtoFormat = "-ERR '{0}'\r\n";

    // One rate-limit key set per owner object (the server, or this connection when no server is set).
    private static readonly ConditionalWeakTable<object, ConcurrentDictionary<string, DateTime>> RateLimitCacheByServer = new();

    /// <summary>
    /// Arms a one-shot timer that closes this connection as stale.
    /// </summary>
    /// <param name="pingInterval">Base interval; nothing is armed when it is zero or negative.</param>
    /// <param name="pingMax">
    /// Multiplier input; nothing is armed when it is negative. When zero, the timer fires after half of
    /// <paramref name="pingInterval"/> (at least one tick); otherwise after
    /// <c>pingInterval * (pingMax + 1)</c>.
    /// </param>
    internal void WatchForStaleConnection(TimeSpan pingInterval, int pingMax)
    {
        if (pingInterval <= TimeSpan.Zero || pingMax < 0)
        {
            return;
        }

        TimeSpan staleAfter;
        if (pingMax == 0)
        {
            staleAfter = TimeSpan.FromTicks(Math.Max(1, pingInterval.Ticks / 2));
        }
        else
        {
            // Saturate instead of letting the multiplication wrap around silently.
            var multiplier = pingMax + 1L;
            var ticks = pingInterval.Ticks > long.MaxValue / multiplier
                ? long.MaxValue
                : pingInterval.Ticks * multiplier;
            staleAfter = TimeSpan.FromTicks(ticks);
        }

        staleAfter = ClampTimerDueTime(staleAfter);

        ClearPingTimer();
        _pingTimer = new Timer(
            _ =>
            {
                lock (_mu)
                {
                    if (IsClosed())
                    {
                        return;
                    }

                    Debugf("Stale Client Connection - Closing");
                    EnqueueProto(Encoding.ASCII.GetBytes(string.Format(StaleErrProtoFormat, "Stale Connection")));
                }

                // Closing happens after the lock is released.
                CloseConnection(ClosedState.StaleConnection);
            },
            null,
            staleAfter,
            Timeout.InfiniteTimeSpan);
    }

    /// <summary>
    /// Looks the current account up again by name on the server and, when the server returns a
    /// different instance, stores that instance as this connection's account.
    /// </summary>
    internal void SwapAccountAfterReload()
    {
        string accountName;
        lock (_mu)
        {
            if (_account is null || Server is null)
            {
                return;
            }

            accountName = _account.Name;
        }

        if (Server is not NatsServer server)
        {
            return;
        }

        // The lookup runs outside the connection lock.
        var (updated, _) = server.LookupAccount(accountName);
        if (updated is null)
        {
            return;
        }

        lock (_mu)
        {
            if (!ReferenceEquals(_account, updated))
            {
                _account = updated;
            }
        }
    }

    /// <summary>
    /// Re-evaluates every subscription of this connection: subscriptions that pass neither the plain
    /// nor the queue permission check are unsubscribed and reported; the remaining ones get their
    /// shadow subscriptions added again when the account check is active.
    /// </summary>
    /// <param name="accountsWithChangedStreamImports">
    /// Account names to restrict the account check to. When non-null and the connection's account name
    /// is not in the set, the account check is skipped. When null, the account check stays enabled.
    /// </param>
    internal void ProcessSubsOnConfigReload(ISet<string>? accountsWithChangedStreamImports)
    {
        INatsAccount? acc;
        bool checkAcc;

        // NOTE(review): the element type of Subs.Values is not visible in this file; Subscription is
        // assumed here - confirm against the declaration of Subs.
        var retained = new List<Subscription>();
        var removed = new List<Subscription>();

        lock (_mu)
        {
            var checkPerms = Perms is not null;
            checkAcc = _account is not null;
            acc = _account;
            if (!checkPerms && !checkAcc)
            {
                return;
            }

            if (checkAcc
                && acc is not null
                && accountsWithChangedStreamImports is not null
                && !accountsWithChangedStreamImports.Contains(acc.Name))
            {
                checkAcc = false;
            }

            // MPerms is reset before the permission checks below run.
            MPerms = null;

            foreach (var sub in Subs.Values)
            {
                var subject = Encoding.ASCII.GetString(sub.Subject);

                // Both checks are always evaluated, exactly once per subscription.
                var canSub = CanSubscribe(subject);
                var canQSub = sub.Queue is { Length: > 0 } q
                    && CanSubscribe(subject, Encoding.ASCII.GetString(q));

                if (!canSub && !canQSub)
                {
                    removed.Add(sub);
                }
                else if (checkAcc)
                {
                    retained.Add(sub);
                }
            }
        }

        if (checkAcc && acc is not null)
        {
            foreach (var sub in retained)
            {
                AddShadowSubscriptions(acc, sub);
            }
        }

        foreach (var sub in removed)
        {
            Unsubscribe(acc, sub, force: true, remove: true);

            var subject = Encoding.ASCII.GetString(sub.Subject);
            var sid = sub.Sid is { Length: > 0 } s ? Encoding.ASCII.GetString(s) : string.Empty;
            SendErr($"Permissions Violation for Subscription to \"{subject}\" (sid \"{sid}\")");
            Noticef("Removed sub \"{0}\" (sid \"{1}\") for \"{2}\" - not authorized", subject, sid, GetAuthUser());
        }
    }

    /// <summary>
    /// Checks whether a reconnect may be attempted; returns early when the NoReconnect flag is set or
    /// no server is attached. No further work is done here.
    /// </summary>
    internal void Reconnect()
    {
        lock (_mu)
        {
            if (Flags.IsSet(ClientFlags.NoReconnect) || Server is null)
            {
                return;
            }
        }

        // Route/gateway/leaf reconnect orchestration is owned by server sessions.
    }

    /// <summary>
    /// Resolves the account and subject match result for the message currently described by the parse
    /// context, using and maintaining the per-account cache.
    /// </summary>
    /// <returns>
    /// The account and its match result, or <c>(null, null)</c> when the subject is empty or no
    /// concrete account with a sublist could be resolved.
    /// </returns>
    internal (INatsAccount? Account, SubscriptionIndexResult? Result) GetAccAndResultFromCache()
    {
        var pa = ParseCtx.Pa;
        if (pa.Subject is null || pa.Subject.Length == 0)
        {
            return (null, null);
        }

        _in.PaCache ??= new Dictionary<string, PerAccountCache>(StringComparer.Ordinal);
        var cache = _in.PaCache;

        // The cache key comes from pa.PaCache when present, otherwise from the subject itself.
        var cacheKeyBytes = pa.PaCache is { Length: > 0 } k ? k : pa.Subject;
        var cacheKey = Encoding.ASCII.GetString(cacheKeyBytes);

        // A cached entry is only trusted while its generation id still matches the sublist's.
        if (cache.TryGetValue(cacheKey, out var cached)
            && cached.Acc is Account cachedAcc
            && cached.Results is not null
            && cachedAcc.Sublist is not null
            && cached.GenId == (ulong)cachedAcc.Sublist.GenId())
        {
            return (cached.Acc, cached.Results);
        }

        INatsAccount? acc = null;
        if (Kind == ClientKind.Router && pa.Account is { Length: > 0 } && _account is not null)
        {
            acc = _account;
        }
        else if (Server is NatsServer server && pa.Account is { Length: > 0 } accountNameBytes)
        {
            (acc, _) = server.LookupAccount(Encoding.ASCII.GetString(accountNameBytes));
        }

        if (acc is not Account concreteAcc || concreteAcc.Sublist is null)
        {
            return (null, null);
        }

        var result = concreteAcc.Sublist.MatchBytes(pa.Subject);

        // Make room before inserting. The first enumerated key is evicted, which is the same victim
        // the previous full-key-copy loop picked, without copying every key on each insert.
        while (cache.Count >= MaxPerAccountCacheSize)
        {
            cache.Remove(cache.Keys.First());
        }

        cache[cacheKey] = new PerAccountCache
        {
            Acc = concreteAcc,
            Results = result,
            GenId = (ulong)concreteAcc.Sublist.GenId(),
        };

        return (concreteAcc, result);
    }

    /// <summary>
    /// Drops every per-account cache entry that has no result or whose result references at least one
    /// closed subscription (plain or queue).
    /// </summary>
    internal void PruneClosedSubFromPerAccountCache()
    {
        var cache = _in.PaCache;
        if (cache is null || cache.Count == 0)
        {
            return;
        }

        // Iterate over a snapshot of the keys because entries are removed along the way.
        foreach (var key in cache.Keys.ToArray())
        {
            var result = cache[key].Results;
            var stale = result is null
                || result.PSubs.Any(static s => s.IsClosed())
                || result.QSubs.Any(static group => group.Any(static s => s.IsClosed()));

            if (stale)
            {
                cache.Remove(key);
            }
        }
    }

    /// <summary>
    /// Fills in the server name (except for leaf connections, whose existing value is kept) and, when
    /// the server reports a non-blank cluster name, the cluster on <paramref name="ci"/>.
    /// </summary>
    /// <param name="ci">Info object to update; ignored when null.</param>
    internal void AddServerAndClusterInfo(ClientInfo? ci)
    {
        if (ci is null || Server is not NatsServer server)
        {
            return;
        }

        if (Kind != ClientKind.Leaf)
        {
            ci.Server = server.Name();
        }

        var cluster = server.CachedClusterName();
        if (!string.IsNullOrWhiteSpace(cluster))
        {
            ci.Cluster = [cluster];
        }
    }

    /// <summary>
    /// Builds a <see cref="ClientInfo"/> snapshot for client, leaf, JetStream and account connections.
    /// </summary>
    /// <param name="detailed">
    /// When false only the account name and RTT are populated; when true server/cluster and the
    /// connection's identity fields are populated as well.
    /// </param>
    /// <returns>The snapshot, or null for any other connection kind.</returns>
    internal ClientInfo? GetClientInfo(bool detailed)
    {
        if (Kind is not (ClientKind.Client or ClientKind.Leaf or ClientKind.JetStream or ClientKind.Account))
        {
            return null;
        }

        var ci = new ClientInfo();
        if (detailed)
        {
            // Done before taking the connection lock.
            AddServerAndClusterInfo(ci);
        }

        lock (_mu)
        {
            ci.Account = _account?.Name ?? string.Empty;
            ci.Rtt = Rtt;
            if (!detailed)
            {
                return ci;
            }

            ci.Start = Start == default ? string.Empty : Start.ToString("O");
            ci.Host = Host;
            ci.Id = Cid;
            ci.Name = Opts.Name;
            ci.User = GetRawAuthUser();
            ci.Lang = Opts.Lang;
            ci.Version = Opts.Version;
            ci.Jwt = Opts.Jwt;
            ci.NameTag = NameTag;
            ci.Kind = KindString();
            ci.ClientType = ClientTypeString();
        }

        return ci;
    }

    /// <summary>
    /// Runs the accepting (server) side of a TLS handshake via <see cref="DoTLSHandshake"/>.
    /// </summary>
    /// <returns>Null on success, otherwise the error reported by <see cref="DoTLSHandshake"/>.</returns>
    internal Exception? DoTLSServerHandshake(
        string typ,
        SslServerAuthenticationOptions tlsConfig,
        double timeout,
        PinnedCertSet? pinnedCerts)
    {
        var (_, err) = DoTLSHandshake(typ, solicit: false, null, tlsConfig, null, string.Empty, timeout, pinnedCerts);
        return err;
    }

    /// <summary>
    /// Runs the soliciting (client) side of a TLS handshake via <see cref="DoTLSHandshake"/>.
    /// </summary>
    internal (bool resetTlsName, Exception? err) DoTLSClientHandshake(
        string typ,
        Uri? url,
        SslClientAuthenticationOptions tlsConfig,
        string tlsName,
        double timeout,
        PinnedCertSet? pinnedCerts)
    {
        return DoTLSHandshake(typ, solicit: true, url, null, tlsConfig, tlsName, timeout, pinnedCerts);
    }

    /// <summary>
    /// Wraps the current connection stream in an <see cref="SslStream"/> and performs the TLS
    /// handshake synchronously, as client when <paramref name="solicit"/> is true and as server
    /// otherwise.
    /// </summary>
    /// <param name="typ">Label used in log messages.</param>
    /// <param name="solicit">True to authenticate as client, false to authenticate as server.</param>
    /// <param name="url">Remote URL; its host is the fallback TLS target host on the client side.</param>
    /// <param name="serverTlsConfig">Server options; defaults are used when null.</param>
    /// <param name="clientTlsConfig">
    /// Client options; defaults are used when null. When its TargetHost is blank it is set on the
    /// passed instance to <paramref name="tlsName"/> if non-blank, else to the URL host.
    /// </param>
    /// <param name="tlsName">Preferred TLS target host name for the client side.</param>
    /// <param name="timeout">Handshake timeout in seconds; values of zero or less mean no timeout.</param>
    /// <param name="pinnedCerts">When non-empty, the handshake fails unless MatchesPinnedCert accepts it.</param>
    /// <returns>
    /// <c>(false, null)</c> on success. On failure the connection is closed with
    /// <c>ClosedState.TlsHandshakeError</c> and <c>ErrConnectionClosed</c> is returned;
    /// <c>resetTlsName</c> is true only for a client-side <see cref="AuthenticationException"/> where
    /// <paramref name="tlsName"/> is non-blank and equals the URL host (case-insensitive).
    /// </returns>
    internal (bool resetTlsName, Exception? err) DoTLSHandshake(
        string typ,
        bool solicit,
        Uri? url,
        SslServerAuthenticationOptions? serverTlsConfig,
        SslClientAuthenticationOptions? clientTlsConfig,
        string tlsName,
        double timeout,
        PinnedCertSet? pinnedCerts)
    {
        if (_nc is null)
        {
            return (false, ServerErrors.ErrConnectionClosed);
        }

        var kind = Kind;
        var resetTlsName = false;
        Exception? err = null;

        try
        {
            var baseStream = _nc;
            if (solicit)
            {
                Debugf("Starting TLS {0} client handshake", typ);

                var options = clientTlsConfig ?? new SslClientAuthenticationOptions();
                if (string.IsNullOrWhiteSpace(options.TargetHost))
                {
                    options.TargetHost = !string.IsNullOrWhiteSpace(tlsName)
                        ? tlsName
                        : (url?.Host ?? string.Empty);
                }

                // The TLS stream replaces the connection stream before the handshake starts.
                var ssl = new SslStream(baseStream, leaveInnerStreamOpen: false);
                _nc = ssl;

                using var cts = CreateHandshakeTimeoutSource(timeout);
                ssl.AuthenticateAsClientAsync(options, cts.Token).GetAwaiter().GetResult();
            }
            else
            {
                if (kind == ClientKind.Client)
                {
                    Debugf("Starting TLS client connection handshake");
                }
                else
                {
                    Debugf("Starting TLS {0} server handshake", typ);
                }

                var ssl = new SslStream(baseStream, leaveInnerStreamOpen: false);
                _nc = ssl;

                using var cts = CreateHandshakeTimeoutSource(timeout);
                ssl.AuthenticateAsServerAsync(serverTlsConfig ?? new SslServerAuthenticationOptions(), cts.Token)
                    .GetAwaiter()
                    .GetResult();
            }

            if (pinnedCerts is { Count: > 0 } && !MatchesPinnedCert(pinnedCerts))
            {
                err = new InvalidOperationException("certificate not pinned");
            }
        }
        catch (AuthenticationException authEx)
        {
            if (solicit
                && !string.IsNullOrWhiteSpace(tlsName)
                && url is not null
                && string.Equals(url.Host, tlsName, StringComparison.OrdinalIgnoreCase))
            {
                resetTlsName = true;
            }

            err = authEx;
        }
        catch (OperationCanceledException)
        {
            // Raised when the timeout token fires during the handshake.
            err = new TimeoutException("TLS handshake timeout");
        }
        catch (Exception ex)
        {
            err = ex;
        }

        if (err is null)
        {
            lock (_mu)
            {
                Flags = Flags.Set(ClientFlags.HandshakeComplete);
                if (IsClosed())
                {
                    return (false, ServerErrors.ErrConnectionClosed);
                }
            }

            return (false, null);
        }

        if (kind == ClientKind.Client)
        {
            Errorf("TLS handshake error: {0}", err.Message);
        }
        else
        {
            Errorf("TLS {0} handshake error: {1}", typ, err.Message);
        }

        CloseConnection(ClosedState.TlsHandshakeError);
        return (resetTlsName, ServerErrors.ErrConnectionClosed);
    }

    /// <summary>
    /// Upper-cases (invariant) each entry and splits the entries into known connection types and
    /// unknown ones.
    /// </summary>
    /// <param name="cts">Connection type names. A null sequence is treated as empty.</param>
    /// <returns>
    /// The set of known types (always returned, even on error) and, when at least one entry is
    /// unknown, an <see cref="ArgumentException"/> listing the unknown entries. A null entry is
    /// reported as an unknown empty name.
    /// </returns>
    internal static (HashSet<string> Allowed, Exception? Error) ConvertAllowedConnectionTypes(IEnumerable<string> cts)
    {
        var unknown = new List<string>();
        var allowed = new HashSet<string>(StringComparer.Ordinal);

        foreach (var value in cts ?? Enumerable.Empty<string>())
        {
            var upper = (value ?? string.Empty).ToUpperInvariant();
            if (AuthHandler.ConnectionTypes.IsKnown(upper))
            {
                allowed.Add(upper);
            }
            else
            {
                unknown.Add(upper);
            }
        }

        return unknown.Count == 0
            ? (allowed, null)
            : (allowed, new ArgumentException($"invalid connection types \"{string.Join(",", unknown)}\""));
    }

    /// <summary>
    /// Logs an error at most once per distinct formatted statement (keyed per server).
    /// Does nothing when no server is attached.
    /// </summary>
    internal void RateLimitErrorf(string format, params object?[] args)
    {
        if (Server is null)
        {
            return;
        }

        var statement = string.Format(format, args);
        if (!TryMarkRateLimited("ERR:" + statement))
        {
            return;
        }

        var (logFormat, logArgs) = ComposeRateLimitedLog(statement);
        Errorf(logFormat, logArgs);
    }

    /// <summary>
    /// Logs a warning at most once per distinct <paramref name="format"/> string (keyed per server),
    /// regardless of the argument values. Does nothing when no server is attached.
    /// </summary>
    internal void RateLimitFormatWarnf(string format, params object?[] args)
    {
        if (Server is null)
        {
            return;
        }

        // Keyed on the raw format, so the statement is only formatted when it will be logged.
        if (!TryMarkRateLimited("WARN_FMT:" + format))
        {
            return;
        }

        var (logFormat, logArgs) = ComposeRateLimitedLog(string.Format(format, args));
        Warnf(logFormat, logArgs);
    }

    /// <summary>
    /// Logs a warning at most once per distinct formatted statement (keyed per server).
    /// Does nothing when no server is attached.
    /// </summary>
    internal void RateLimitWarnf(string format, params object?[] args)
    {
        if (Server is null)
        {
            return;
        }

        var statement = string.Format(format, args);
        if (!TryMarkRateLimited("WARN:" + statement))
        {
            return;
        }

        var (logFormat, logArgs) = ComposeRateLimitedLog(statement);
        Warnf(logFormat, logArgs);
    }

    /// <summary>
    /// Logs a debug message at most once per distinct formatted statement (keyed per server).
    /// Does nothing when no server is attached.
    /// </summary>
    internal void RateLimitDebugf(string format, params object?[] args)
    {
        if (Server is null)
        {
            return;
        }

        var statement = string.Format(format, args);
        if (!TryMarkRateLimited("DBG:" + statement))
        {
            return;
        }

        var (logFormat, logArgs) = ComposeRateLimitedLog(statement);
        Debugf(logFormat, logArgs);
    }

    /// <summary>
    /// Arms the one-shot timer for the first ping. The interval starts from the configured ping
    /// interval (cluster or websocket overrides apply when positive), is capped for the first ping
    /// unless short first pings are disabled, and then gets a random extra delay of up to a fifth of
    /// its length.
    /// </summary>
    internal void SetFirstPingTimer()
    {
        var opts = Server?.Options;
        if (opts is null)
        {
            return;
        }

        var d = opts.PingInterval;
        if (Kind == ClientKind.Router && opts.Cluster.PingInterval > TimeSpan.Zero)
        {
            d = opts.Cluster.PingInterval;
        }

        if (IsWebSocket() && opts.Websocket.PingInterval > TimeSpan.Zero)
        {
            d = opts.Websocket.PingInterval;
        }

        if (!opts.DisableShortFirstPing)
        {
            if (Kind != ClientKind.Client)
            {
                if (d > FirstPingInterval)
                {
                    d = FirstPingInterval;
                }

                d = AdjustPingInterval(Kind, d);
            }
            else if (d > FirstClientPingInterval)
            {
                d = FirstClientPingInterval;
            }
        }

        var jitterTicks = d.Ticks > 0 ? Random.Shared.NextInt64(Math.Max(1, d.Ticks / 5)) : 0L;
        d = d.Add(TimeSpan.FromTicks(jitterTicks));

        // A negative or oversized due time would make the Timer constructor throw.
        d = ClampTimerDueTime(d);

        ClearPingTimer();
        _pingTimer = new Timer(_ => ProcessPingTimer(), null, d, Timeout.InfiniteTimeSpan);
    }

    /// <summary>
    /// Records <paramref name="key"/> in the rate-limit set of the owning server (or of this
    /// connection when no server is attached).
    /// </summary>
    /// <returns>True when the key was not present yet, i.e. the caller should log.</returns>
    private bool TryMarkRateLimited(string key)
    {
        var serverKey = (object?)Server ?? this;
        var cache = RateLimitCacheByServer.GetOrCreateValue(serverKey);
        return cache.TryAdd(key, DateTime.UtcNow);
    }

    /// <summary>
    /// Builds the log format and arguments shared by the rate-limited log helpers: the statement
    /// followed by the client suffix, prefixed with <c>String()</c> and " - " when that is non-blank.
    /// </summary>
    private (string Format, object?[] Args) ComposeRateLimitedLog(string statement)
    {
        var suffix = FormatClientSuffix();
        var identity = String();

        return string.IsNullOrWhiteSpace(identity)
            ? ("{0}{1}", new object?[] { statement, suffix })
            : ("{0} - {1}{2}", new object?[] { identity, statement, suffix });
    }

    /// <summary>
    /// Creates the cancellation source bounding a TLS handshake: it cancels after
    /// <paramref name="timeout"/> seconds when positive and never cancels otherwise.
    /// </summary>
    private static CancellationTokenSource CreateHandshakeTimeoutSource(double timeout)
    {
        return timeout > 0
            ? new CancellationTokenSource(TimeSpan.FromSeconds(timeout))
            : new CancellationTokenSource();
    }

    /// <summary>
    /// Restricts a timer due time to the range from zero to <see cref="MaxTimerDueTime"/>.
    /// </summary>
    private static TimeSpan ClampTimerDueTime(TimeSpan dueTime)
    {
        if (dueTime < TimeSpan.Zero)
        {
            return TimeSpan.Zero;
        }

        return dueTime > MaxTimerDueTime ? MaxTimerDueTime : dueTime;
    }
}