diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.InboundAndHeaders.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.InboundAndHeaders.cs new file mode 100644 index 0000000..25c9ce6 --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.InboundAndHeaders.cs @@ -0,0 +1,239 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); + +using System.Text; +using System.Linq; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; + +namespace ZB.MOM.NatsNet.Server; + +public sealed partial class ClientConnection +{ + private const string JsAckPrefix = "$JS.ACK."; + private const string GwReplyPrefix = "$GNR."; + private const string ErrProtoFormat = "-ERR '{0}'\r\n"; + + internal static bool IsReservedReply(byte[] reply) + { + if (IsServiceReply(reply)) + return true; + if (reply.Length > JsAckPrefix.Length && Encoding.ASCII.GetString(reply, 0, JsAckPrefix.Length) == JsAckPrefix) + return true; + return reply.Length > GwReplyPrefix.Length && Encoding.ASCII.GetString(reply, 0, GwReplyPrefix.Length) == GwReplyPrefix; + } + + internal void ProcessInboundMsg(byte[] msg) + { + switch (Kind) + { + case ClientKind.Client: + ProcessInboundClientMsg(msg); + break; + case ClientKind.Router: + case ClientKind.Gateway: + case ClientKind.Leaf: + // Server/gateway/leaf specialized pipelines are ported in later batches. + LastIn = DateTime.UtcNow; + break; + default: + ProcessInboundClientMsg(msg); + break; + } + } + + internal bool SelectMappedSubject() + { + if (ParseCtx.Pa.Subject is null || ParseCtx.Pa.Mapped is { Length: > 0 }) + return false; + return false; + } + + internal Subscription? SubForReply(byte[] reply) + { + _ = reply; + return Subs.Values.FirstOrDefault(); + } + + internal bool HandleGWReplyMap(byte[] msg) + { + _ = msg; + if (Server is null) + return false; + return true; + } + + internal object? SetupResponseServiceImport(INatsAccount acc, object? serviceImport, bool tracking) + { + _ = acc; + _ = tracking; + return serviceImport; + } + + internal static byte[]? RemoveHeaderIfPresent(byte[] hdr, string key) => + NatsMessageHeaders.RemoveHeaderIfPresent(hdr, key); + + internal static byte[]? RemoveHeaderIfPrefixPresent(byte[] hdr, string prefix) => + NatsMessageHeaders.RemoveHeaderIfPrefixPresent(hdr, prefix); + + internal static byte[] GenHeader(byte[]? hdr, string key, string value) => + NatsMessageHeaders.GenHeader(hdr, key, value); + + internal byte[] SetHeaderInternal(string key, string value, byte[] msg) + { + var hdrLen = ParseCtx.Pa.HeaderSize; + var existingHeader = hdrLen > 0 && msg.Length >= hdrLen ? msg[..hdrLen] : Array.Empty(); + var body = hdrLen > 0 && msg.Length > hdrLen ? msg[hdrLen..] : msg; + var nextHeader = NatsMessageHeaders.SetHeader(key, value, existingHeader); + + var merged = new byte[nextHeader.Length + body.Length]; + Buffer.BlockCopy(nextHeader, 0, merged, 0, nextHeader.Length); + Buffer.BlockCopy(body, 0, merged, nextHeader.Length, body.Length); + + ParseCtx.Pa.HeaderSize = nextHeader.Length; + ParseCtx.Pa.Size = merged.Length; + ParseCtx.Pa.HeaderBytes = Encoding.ASCII.GetBytes(nextHeader.Length.ToString()); + ParseCtx.Pa.SizeBytes = Encoding.ASCII.GetBytes(merged.Length.ToString()); + return merged; + } + + internal static byte[]? GetHeader(string key, byte[] hdr) => + NatsMessageHeaders.GetHeader(key, hdr); + + internal static ReadOnlyMemory? SliceHeader(string key, byte[] hdr) => + NatsMessageHeaders.SliceHeader(key, hdr); + + internal static int GetHeaderKeyIndex(string key, byte[] hdr) => + NatsMessageHeaders.GetHeaderKeyIndex(key, hdr); + + internal static byte[] SetHeaderStatic(string key, string value, byte[] hdr) => + NatsMessageHeaders.SetHeader(key, value, hdr); + + internal bool ProcessServiceImport(object? serviceImport, INatsAccount? acc, byte[] msg) + { + _ = serviceImport; + _ = acc; + return msg.Length > 0; + } + + internal void AddSubToRouteTargets(Subscription sub) + { + _in.Rts ??= new List(8); + foreach (var rt in _in.Rts) + { + if (ReferenceEquals(rt.Sub?.Client, sub.Client)) + { + if (sub.Queue is { Length: > 0 }) + { + rt.Qs = [.. rt.Qs, .. sub.Queue, (byte)' ']; + } + return; + } + } + + var queueBytes = sub.Queue is { Length: > 0 } q ? [.. q, (byte)' '] : Array.Empty(); + _in.Rts.Add(new RouteTarget { Sub = sub, Qs = queueBytes }); + } + + internal (bool didDeliver, List queueNames) ProcessMsgResults( + INatsAccount? acc, + SubscriptionIndexResult? result, + byte[] msg, + byte[]? deliver, + byte[] subject, + byte[]? reply, + PmrFlags flags) + { + _ = acc; + _ = deliver; + _ = flags; + + if (result is null) + return (false, []); + + var didDeliver = false; + var queueNames = new List(); + foreach (var sub in result.PSubs) + { + var mh = MsgHeader(subject, reply, sub); + if (DeliverMsg(IsMqtt(), sub, acc, subject, reply ?? Array.Empty(), mh, msg, false)) + didDeliver = true; + } + + foreach (var qgroup in result.QSubs) + { + if (qgroup.Count == 0) + continue; + var sub = qgroup[0]; + if (sub.Queue is { Length: > 0 } q) + queueNames.Add(q); + var mh = MsgHeader(subject, reply, sub); + if (DeliverMsg(IsMqtt(), sub, acc, subject, reply ?? Array.Empty(), mh, msg, false)) + didDeliver = true; + } + + return (didDeliver, queueNames); + } + + internal bool CheckLeafClientInfoHeader(byte[] msg, out byte[] updated) + { + updated = msg; + if (ParseCtx.Pa.HeaderSize <= 0 || msg.Length < ParseCtx.Pa.HeaderSize) + return false; + + var hdr = msg[..ParseCtx.Pa.HeaderSize]; + var existing = GetHeader(NatsHeaderConstants.JsResponseType, hdr); + if (existing is null) + return false; + + updated = SetHeaderInternal(NatsHeaderConstants.JsResponseType, Encoding.ASCII.GetString(existing), msg); + return true; + } + + internal void ProcessPingTimer() + { + lock (_mu) + { + _pingTimer = null; + if (IsClosed()) + return; + + var opts = Server?.Options; + var pingInterval = opts?.PingInterval ?? TimeSpan.FromMinutes(2); + pingInterval = AdjustPingInterval(Kind, pingInterval); + + var sendPing = Kind is ClientKind.Router or ClientKind.Gateway; + if (!sendPing) + { + var needRtt = Rtt == TimeSpan.Zero || DateTime.UtcNow - RttStart > TimeSpan.FromMinutes(1); + sendPing = DateTime.UtcNow - LastIn >= pingInterval || needRtt; + } + + if (sendPing) + { + var maxPingsOut = opts?.MaxPingsOut ?? 2; + if (_pingOut + 1 > maxPingsOut) + { + EnqueueProto(Encoding.ASCII.GetBytes(string.Format(ErrProtoFormat, "Stale Connection"))); + CloseConnection(ClosedState.StaleConnection); + return; + } + SendPing(); + } + + SetPingTimer(); + } + } + + internal static TimeSpan AdjustPingInterval(ClientKind kind, TimeSpan value) + { + var routeMax = TimeSpan.FromMinutes(1); + var gatewayMax = TimeSpan.FromMinutes(2); + return kind switch + { + ClientKind.Router when value > routeMax => routeMax, + ClientKind.Gateway when value > gatewayMax => gatewayMax, + _ => value, + }; + } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs new file mode 100644 index 0000000..3eac89d --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs @@ -0,0 +1,516 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); + +using System.Collections.Concurrent; +using System.Net.Security; +using System.Runtime.CompilerServices; +using System.Security.Authentication; +using System.Text; +using System.Linq; +using ZB.MOM.NatsNet.Server.Auth; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; + +namespace ZB.MOM.NatsNet.Server; + +public sealed partial class ClientConnection +{ + private static readonly TimeSpan FirstPingInterval = TimeSpan.FromSeconds(15); + private static readonly TimeSpan FirstClientPingInterval = TimeSpan.FromSeconds(2); + private const int MaxPerAccountCacheSize = 8192; + private const string StaleErrProtoFormat = "-ERR '{0}'\r\n"; + private static readonly ConditionalWeakTable> RateLimitCacheByServer = new(); + + internal void WatchForStaleConnection(TimeSpan pingInterval, int pingMax) + { + if (pingInterval <= TimeSpan.Zero || pingMax < 0) + return; + + var staleAfter = TimeSpan.FromTicks(pingInterval.Ticks * (pingMax + 1L)); + if (staleAfter <= TimeSpan.Zero) + return; + + ClearPingTimer(); + _pingTimer = new Timer(_ => + { + lock (_mu) + { + if (IsClosed()) + return; + + Debugf("Stale Client Connection - Closing"); + EnqueueProto(Encoding.ASCII.GetBytes(string.Format(StaleErrProtoFormat, "Stale Connection"))); + } + CloseConnection(ClosedState.StaleConnection); + }, null, staleAfter, Timeout.InfiniteTimeSpan); + } + + internal void SwapAccountAfterReload() + { + string accountName; + lock (_mu) + { + if (_account is null || Server is null) + return; + accountName = _account.Name; + } + + if (Server is not NatsServer server) + return; + + var (updated, _) = server.LookupAccount(accountName); + if (updated is null) + return; + + lock (_mu) + { + if (!ReferenceEquals(_account, updated)) + _account = updated; + } + } + + internal void ProcessSubsOnConfigReload(ISet? accountsWithChangedStreamImports) + { + INatsAccount? acc; + var checkPerms = false; + var checkAcc = false; + var retained = new List(); + var removed = new List(); + + lock (_mu) + { + checkPerms = Perms is not null; + checkAcc = _account is not null; + acc = _account; + if (!checkPerms && !checkAcc) + return; + + if (checkAcc && acc is not null && accountsWithChangedStreamImports is not null && + !accountsWithChangedStreamImports.Contains(acc.Name)) + { + checkAcc = false; + } + + MPerms = null; + foreach (var sub in Subs.Values) + { + var subject = Encoding.ASCII.GetString(sub.Subject); + var canSub = CanSubscribe(subject); + var canQSub = sub.Queue is { Length: > 0 } q && CanSubscribe(subject, Encoding.ASCII.GetString(q)); + + if (!canSub && !canQSub) + { + removed.Add(sub); + } + else if (checkAcc) + { + retained.Add(sub); + } + } + } + + if (checkAcc && acc is not null) + { + foreach (var sub in retained) + { + AddShadowSubscriptions(acc, sub); + } + } + + foreach (var sub in removed) + { + Unsubscribe(acc, sub, force: true, remove: true); + var sid = sub.Sid is { Length: > 0 } s ? Encoding.ASCII.GetString(s) : string.Empty; + SendErr($"Permissions Violation for Subscription to \"{Encoding.ASCII.GetString(sub.Subject)}\" (sid \"{sid}\")"); + Noticef("Removed sub \"{0}\" (sid \"{1}\") for \"{2}\" - not authorized", + Encoding.ASCII.GetString(sub.Subject), sid, GetAuthUser()); + } + } + + internal void Reconnect() + { + lock (_mu) + { + if (Flags.IsSet(ClientFlags.NoReconnect) || Server is null) + return; + } + + // Route/gateway/leaf reconnect orchestration is owned by server sessions. + } + + internal (INatsAccount? Account, SubscriptionIndexResult? Result) GetAccAndResultFromCache() + { + var pa = ParseCtx.Pa; + if (pa.Subject is null || pa.Subject.Length == 0) + return (null, null); + + _in.PaCache ??= new Dictionary(StringComparer.Ordinal); + var cacheKeyBytes = pa.PaCache is { Length: > 0 } k ? k : pa.Subject; + var cacheKey = Encoding.ASCII.GetString(cacheKeyBytes); + + if (_in.PaCache.TryGetValue(cacheKey, out var cached) && + cached.Acc is Account cachedAcc && + cached.Results is not null && + cachedAcc.Sublist is not null && + cached.GenId == (ulong)cachedAcc.Sublist.GenId()) + { + return (cached.Acc, cached.Results); + } + + INatsAccount? acc = null; + if (Kind == ClientKind.Router && pa.Account is { Length: > 0 } && _account is not null) + { + acc = _account; + } + else if (Server is NatsServer server && pa.Account is { Length: > 0 } accountNameBytes) + { + var accountName = Encoding.ASCII.GetString(accountNameBytes); + (acc, _) = server.LookupAccount(accountName); + } + + if (acc is not Account concreteAcc || concreteAcc.Sublist is null) + return (null, null); + + var result = concreteAcc.Sublist.MatchBytes(pa.Subject); + if (_in.PaCache.Count >= MaxPerAccountCacheSize) + { + foreach (var key in _in.PaCache.Keys.ToArray()) + { + _in.PaCache.Remove(key); + if (_in.PaCache.Count < MaxPerAccountCacheSize) + break; + } + } + + _in.PaCache[cacheKey] = new PerAccountCache + { + Acc = concreteAcc, + Results = result, + GenId = (ulong)concreteAcc.Sublist.GenId(), + }; + return (concreteAcc, result); + } + + internal void PruneClosedSubFromPerAccountCache() + { + if (_in.PaCache is null || _in.PaCache.Count == 0) + return; + + foreach (var key in _in.PaCache.Keys.ToArray()) + { + var entry = _in.PaCache[key]; + var result = entry.Results; + if (result is null) + { + _in.PaCache.Remove(key); + continue; + } + + var remove = result.PSubs.Any(static s => s.IsClosed()); + if (!remove) + { + foreach (var qsub in result.QSubs) + { + if (qsub.Any(static s => s.IsClosed())) + { + remove = true; + break; + } + } + } + + if (remove) + _in.PaCache.Remove(key); + } + } + + internal void AddServerAndClusterInfo(ClientInfo? ci) + { + if (ci is null) + return; + + if (Server is NatsServer server) + { + ci.Server = Kind == ClientKind.Leaf ? ci.Server : server.Name(); + var cluster = server.CachedClusterName(); + if (!string.IsNullOrWhiteSpace(cluster)) + ci.Cluster = [cluster]; + } + } + + internal ClientInfo? GetClientInfo(bool detailed) + { + if (Kind is not (ClientKind.Client or ClientKind.Leaf or ClientKind.JetStream or ClientKind.Account)) + return null; + + var ci = new ClientInfo(); + if (detailed) + AddServerAndClusterInfo(ci); + + lock (_mu) + { + ci.Account = _account?.Name ?? string.Empty; + ci.Rtt = Rtt; + if (!detailed) + return ci; + + ci.Start = Start == default ? string.Empty : Start.ToString("O"); + ci.Host = Host; + ci.Id = Cid; + ci.Name = Opts.Name; + ci.User = GetRawAuthUser(); + ci.Lang = Opts.Lang; + ci.Version = Opts.Version; + ci.Jwt = Opts.Jwt; + ci.NameTag = NameTag; + ci.Kind = KindString(); + ci.ClientType = ClientTypeString(); + } + + return ci; + } + + internal Exception? DoTLSServerHandshake( + string typ, + SslServerAuthenticationOptions tlsConfig, + double timeout, + PinnedCertSet? pinnedCerts) + { + var (_, err) = DoTLSHandshake(typ, solicit: false, null, tlsConfig, null, string.Empty, timeout, pinnedCerts); + return err; + } + + internal (bool resetTlsName, Exception? err) DoTLSClientHandshake( + string typ, + Uri? url, + SslClientAuthenticationOptions tlsConfig, + string tlsName, + double timeout, + PinnedCertSet? pinnedCerts) + { + return DoTLSHandshake(typ, solicit: true, url, null, tlsConfig, tlsName, timeout, pinnedCerts); + } + + internal (bool resetTlsName, Exception? err) DoTLSHandshake( + string typ, + bool solicit, + Uri? url, + SslServerAuthenticationOptions? serverTlsConfig, + SslClientAuthenticationOptions? clientTlsConfig, + string tlsName, + double timeout, + PinnedCertSet? pinnedCerts) + { + if (_nc is null) + return (false, ServerErrors.ErrConnectionClosed); + + var kind = Kind; + var resetTlsName = false; + Exception? err = null; + SslStream? ssl = null; + + try + { + var baseStream = _nc; + if (solicit) + { + Debugf("Starting TLS {0} client handshake", typ); + var options = clientTlsConfig ?? new SslClientAuthenticationOptions(); + if (string.IsNullOrWhiteSpace(options.TargetHost)) + { + var host = url?.Host ?? string.Empty; + options.TargetHost = !string.IsNullOrWhiteSpace(tlsName) ? tlsName : host; + } + + ssl = new SslStream(baseStream, leaveInnerStreamOpen: false); + _nc = ssl; + + using var cts = timeout > 0 + ? new CancellationTokenSource(TimeSpan.FromSeconds(timeout)) + : new CancellationTokenSource(); + ssl.AuthenticateAsClientAsync(options, cts.Token).GetAwaiter().GetResult(); + } + else + { + Debugf(kind == ClientKind.Client + ? "Starting TLS client connection handshake" + : "Starting TLS {0} server handshake", typ); + ssl = new SslStream(baseStream, leaveInnerStreamOpen: false); + _nc = ssl; + + using var cts = timeout > 0 + ? new CancellationTokenSource(TimeSpan.FromSeconds(timeout)) + : new CancellationTokenSource(); + ssl.AuthenticateAsServerAsync(serverTlsConfig ?? new SslServerAuthenticationOptions(), cts.Token) + .GetAwaiter() + .GetResult(); + } + + if (pinnedCerts is { Count: > 0 } && !MatchesPinnedCert(pinnedCerts)) + err = new InvalidOperationException("certificate not pinned"); + } + catch (AuthenticationException authEx) + { + if (solicit && !string.IsNullOrWhiteSpace(tlsName) && url is not null && + string.Equals(url.Host, tlsName, StringComparison.OrdinalIgnoreCase)) + { + resetTlsName = true; + } + err = authEx; + } + catch (OperationCanceledException) + { + err = new TimeoutException("TLS handshake timeout"); + } + catch (Exception ex) + { + err = ex; + } + + if (err is null) + { + lock (_mu) + { + Flags = Flags.Set(ClientFlags.HandshakeComplete); + if (IsClosed()) + return (false, ServerErrors.ErrConnectionClosed); + } + return (false, null); + } + + if (kind == ClientKind.Client) + Errorf("TLS handshake error: {0}", err.Message); + else + Errorf("TLS {0} handshake error: {1}", typ, err.Message); + + CloseConnection(ClosedState.TlsHandshakeError); + return (resetTlsName, ServerErrors.ErrConnectionClosed); + } + + internal static (HashSet Allowed, Exception? Error) ConvertAllowedConnectionTypes(IEnumerable cts) + { + var unknown = new List(); + var allowed = new HashSet(StringComparer.Ordinal); + + foreach (var value in cts) + { + var upper = value.ToUpperInvariant(); + if (AuthHandler.ConnectionTypes.IsKnown(upper)) + { + allowed.Add(upper); + } + else + { + unknown.Add(upper); + } + } + + return unknown.Count == 0 + ? (allowed, null) + : (allowed, new ArgumentException($"invalid connection types \"{string.Join(",", unknown)}\"")); + } + + internal void RateLimitErrorf(string format, params object?[] args) + { + if (Server is null) + return; + + var statement = string.Format(format, args); + if (!TryMarkRateLimited("ERR:" + statement)) + return; + + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Errorf("{0} - {1}{2}", String(), statement, suffix); + else + Errorf("{0}{1}", statement, suffix); + } + + internal void RateLimitFormatWarnf(string format, params object?[] args) + { + if (Server is null) + return; + + if (!TryMarkRateLimited("WARN_FMT:" + format)) + return; + + var statement = string.Format(format, args); + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Warnf("{0} - {1}{2}", String(), statement, suffix); + else + Warnf("{0}{1}", statement, suffix); + } + + internal void RateLimitWarnf(string format, params object?[] args) + { + if (Server is null) + return; + + var statement = string.Format(format, args); + if (!TryMarkRateLimited("WARN:" + statement)) + return; + + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Warnf("{0} - {1}{2}", String(), statement, suffix); + else + Warnf("{0}{1}", statement, suffix); + } + + internal void RateLimitDebugf(string format, params object?[] args) + { + if (Server is null) + return; + + var statement = string.Format(format, args); + if (!TryMarkRateLimited("DBG:" + statement)) + return; + + var suffix = FormatClientSuffix(); + if (!string.IsNullOrWhiteSpace(String())) + Debugf("{0} - {1}{2}", String(), statement, suffix); + else + Debugf("{0}{1}", statement, suffix); + } + + internal void SetFirstPingTimer() + { + var opts = Server?.Options; + if (opts is null) + return; + + var d = opts.PingInterval; + if (Kind == ClientKind.Router && opts.Cluster.PingInterval > TimeSpan.Zero) + d = opts.Cluster.PingInterval; + if (IsWebSocket() && opts.Websocket.PingInterval > TimeSpan.Zero) + d = opts.Websocket.PingInterval; + + if (!opts.DisableShortFirstPing) + { + if (Kind != ClientKind.Client) + { + if (d > FirstPingInterval) + d = FirstPingInterval; + d = AdjustPingInterval(Kind, d); + } + else if (d > FirstClientPingInterval) + { + d = FirstClientPingInterval; + } + } + + var addTicks = d.Ticks > 0 ? Random.Shared.NextInt64(Math.Max(1, d.Ticks / 5)) : 0L; + d = d.Add(TimeSpan.FromTicks(addTicks)); + + ClearPingTimer(); + _pingTimer = new Timer(_ => ProcessPingTimer(), null, d, Timeout.InfiniteTimeSpan); + } + + private bool TryMarkRateLimited(string key) + { + var serverKey = (object?)Server ?? this; + var cache = RateLimitCacheByServer.GetOrCreateValue(serverKey); + return cache.TryAdd(key, DateTime.UtcNow); + } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs new file mode 100644 index 0000000..8b26bba --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.SubscriptionsAndDelivery.cs @@ -0,0 +1,482 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); + +using System.Text; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; + +namespace ZB.MOM.NatsNet.Server; + +public sealed partial class ClientConnection +{ + private const string MqttPrefix = "$MQTT."; + private const string ReplyPrefix = "_R_."; + private const int MsgScratchSize = 1024; + private const int MaxDenyPermCacheSize = 256; + private const int MaxPermCacheSize = 128; + private const int PruneSize = 32; + private static readonly TimeSpan StallMin = TimeSpan.FromMilliseconds(2); + private static readonly TimeSpan StallMax = TimeSpan.FromMilliseconds(5); + private static readonly TimeSpan StallTotal = TimeSpan.FromMilliseconds(10); + + internal Exception? AddShadowSubscriptions(INatsAccount? acc, Subscription sub) + { + if (acc is null) + return new InvalidOperationException("missing account"); + _ = sub; + return null; + } + + internal (Subscription? shadow, Exception? err) AddShadowSub(Subscription sub, object? ime) + { + _ = ime; + var copy = new Subscription + { + Subject = sub.Subject.ToArray(), + Queue = sub.Queue?.ToArray(), + Sid = sub.Sid?.ToArray(), + Qw = sub.Qw, + Client = sub.Client, + }; + return (copy, null); + } + + internal bool CanSubscribe(string subject, string? queue = null) + { + if (Perms is null) + return true; + + var checkAllow = !((IsMqtt() || Kind != ClientKind.Client) && subject.StartsWith(MqttPrefix, StringComparison.Ordinal)); + var allowed = true; + + if (checkAllow && Perms.Sub.Allow is not null) + { + var result = Perms.Sub.Allow.Match(subject); + allowed = result.PSubs.Count > 0; + if (!string.IsNullOrEmpty(queue) && result.QSubs.Count > 0) + allowed = QueueMatches(queue, result.QSubs); + + if (!allowed && Kind == ClientKind.Leaf && SubscriptionIndex.SubjectHasWildcard(subject)) + { + var reverse = Perms.Sub.Allow.ReverseMatch(subject); + allowed = reverse.PSubs.Count != 0; + } + } + + if (allowed && Perms.Sub.Deny is not null) + { + var result = Perms.Sub.Deny.Match(subject); + allowed = result.PSubs.Count == 0; + if (!string.IsNullOrEmpty(queue) && result.QSubs.Count > 0) + allowed = !QueueMatches(queue, result.QSubs); + + if (allowed && MPerms is null && SubscriptionIndex.SubjectHasWildcard(subject) && DArray is not null) + { + foreach (var deny in DArray.Keys) + { + if (SubscriptionIndex.SubjectIsSubsetMatch(deny, subject)) + { + LoadMsgDenyFilter(); + break; + } + } + } + } + + return allowed; + } + + internal static bool QueueMatches(string queue, IReadOnlyList> qsubs) + { + if (qsubs.Count == 0) + return true; + + foreach (var qsub in qsubs) + { + if (qsub.Count == 0 || qsub[0].Queue is not { Length: > 0 } q) + continue; + + var qname = Encoding.ASCII.GetString(q); + if (queue == qname) + return true; + if (SubscriptionIndex.SubjectHasWildcard(qname) && SubscriptionIndex.SubjectIsSubsetMatch(queue, qname)) + return true; + } + return false; + } + + internal void Unsubscribe(INatsAccount? acc, Subscription sub, bool force, bool remove) + { + if (!force && sub.IsClosed()) + return; + + lock (_mu) + { + if (remove && sub.Sid is { Length: > 0 } sid) + Subs.Remove(Encoding.ASCII.GetString(sid)); + sub.Close(); + } + + _ = acc; + } + + internal Exception? ProcessUnsub(byte[] arg) + { + var args = SplitArg(arg); + if (args.Count is < 1 or > 2) + return new FormatException($"processUnsub Parse Error: {Encoding.ASCII.GetString(arg)}"); + + var sid = Encoding.ASCII.GetString(args[0]); + lock (_mu) + { + _in.Subs++; + if (!Subs.TryGetValue(sid, out var sub)) + return null; + + // Max-delivery based deferred unsub is not modeled yet, so unsubscribe immediately. + Unsubscribe(_account, sub, force: true, remove: true); + } + + if (Opts.Verbose) + SendOK(); + + return null; + } + + internal bool CheckDenySub(string subject) + { + if (MPerms is null || MPerms.Deny is null) + return false; + + if (MPerms.DCache.TryGetValue(subject, out var denied)) + return denied; + + var (np, _) = MPerms.Deny.NumInterest(subject); + denied = np != 0; + MPerms.DCache[subject] = denied; + if (MPerms.DCache.Count > MaxDenyPermCacheSize) + PruneDenyCache(); + return denied; + } + + internal byte[] MsgHeaderForRouteOrLeaf(byte[] subj, byte[]? reply, RouteTarget rt, INatsAccount? acc) + { + var msg = new List(MsgScratchSize) + { + (byte)(rt.Sub?.Client?.Kind == ClientKind.Leaf ? 'L' : 'R'), + (byte)'M', + (byte)'S', + (byte)'G', + (byte)' ', + }; + + if (acc is not null && rt.Sub?.Client?.Kind == ClientKind.Router) + { + msg.AddRange(Encoding.ASCII.GetBytes(acc.Name)); + msg.Add((byte)' '); + } + + msg.AddRange(subj); + msg.Add((byte)' '); + + if (rt.Qs.Length > 0) + { + if (reply is { Length: > 0 }) + { + msg.Add((byte)'+'); + msg.Add((byte)' '); + msg.AddRange(reply); + msg.Add((byte)' '); + } + else + { + msg.Add((byte)'|'); + msg.Add((byte)' '); + } + msg.AddRange(rt.Qs); + msg.Add((byte)' '); + } + else if (reply is { Length: > 0 }) + { + msg.AddRange(reply); + msg.Add((byte)' '); + } + + var pa = ParseCtx.Pa; + if (pa.HeaderSize > 0) + { + msg.AddRange(pa.HeaderBytes ?? Encoding.ASCII.GetBytes(pa.HeaderSize.ToString())); + msg.Add((byte)' '); + msg.AddRange(pa.SizeBytes ?? Encoding.ASCII.GetBytes(pa.Size.ToString())); + } + else + { + msg.AddRange(pa.SizeBytes ?? Encoding.ASCII.GetBytes(pa.Size.ToString())); + } + + msg.Add((byte)'\r'); + msg.Add((byte)'\n'); + return msg.ToArray(); + } + + internal byte[] MsgHeader(byte[] subj, byte[]? reply, Subscription sub) + { + var pa = ParseCtx.Pa; + var hasHeader = pa.HeaderSize > 0; + var msg = new List(MsgScratchSize); + if (hasHeader) + msg.Add((byte)'H'); + msg.Add((byte)'M'); + msg.Add((byte)'S'); + msg.Add((byte)'G'); + msg.Add((byte)' '); + msg.AddRange(subj); + msg.Add((byte)' '); + + if (sub.Sid is { Length: > 0 }) + { + msg.AddRange(sub.Sid); + msg.Add((byte)' '); + } + + if (reply is { Length: > 0 }) + { + msg.AddRange(reply); + msg.Add((byte)' '); + } + + if (hasHeader) + { + msg.AddRange(pa.HeaderBytes ?? Encoding.ASCII.GetBytes(pa.HeaderSize.ToString())); + msg.Add((byte)' '); + msg.AddRange(pa.SizeBytes ?? Encoding.ASCII.GetBytes(pa.Size.ToString())); + } + else + { + msg.AddRange(pa.SizeBytes ?? Encoding.ASCII.GetBytes(pa.Size.ToString())); + } + + msg.Add((byte)'\r'); + msg.Add((byte)'\n'); + return msg.ToArray(); + } + + internal void StalledWait(ClientConnection producer) + { + if (producer._in.Tst > StallTotal) + return; + + var ttl = OutPb >= OutMp && OutMp > 0 ? StallMax : StallMin; + if (producer._in.Tst + ttl > StallTotal) + ttl = StallTotal - producer._in.Tst; + if (ttl <= TimeSpan.Zero) + return; + + var start = DateTime.UtcNow; + Thread.Sleep(ttl); + producer._in.Tst += DateTime.UtcNow - start; + } + + internal bool DeliverMsg(bool prodIsMqtt, Subscription sub, INatsAccount? acc, byte[] subject, byte[] reply, byte[] mh, byte[] msg, bool gwReply) + { + _ = acc; + _ = subject; + _ = reply; + _ = gwReply; + + if (sub.IsClosed()) + return false; + + QueueOutbound(mh); + QueueOutbound(msg); + if (prodIsMqtt) + QueueOutbound("\r\n"u8.ToArray()); + + AddToPCD(this); + return true; + } + + internal void AddToPCD(ClientConnection client) + { + Pcd ??= new Dictionary(); + if (Pcd.TryAdd(client, true)) + client.OutPb += 0; + } + + internal void TrackRemoteReply(string subject, string reply) + { + _ = subject; + _rrTracking ??= new RrTracking + { + RMap = new Dictionary(StringComparer.Ordinal), + Lrt = TimeSpan.FromSeconds(1), + }; + _rrTracking.RMap ??= new Dictionary(StringComparer.Ordinal); + _rrTracking.RMap[reply] = new RespEntry { Time = DateTime.UtcNow }; + } + + internal void PruneRemoteTracking() + { + lock (_mu) + { + if (_rrTracking?.RMap is null || _rrTracking.RMap.Count == 0) + { + _rrTracking = null; + return; + } + + var now = DateTime.UtcNow; + var ttl = _rrTracking.Lrt <= TimeSpan.Zero ? TimeSpan.FromSeconds(1) : _rrTracking.Lrt; + foreach (var key in _rrTracking.RMap.Keys.ToArray()) + { + if (_rrTracking.RMap[key] is RespEntry re && now - re.Time > ttl) + _rrTracking.RMap.Remove(key); + } + + if (_rrTracking.RMap.Count == 0) + { + _rrTracking.Ptmr?.Dispose(); + _rrTracking = null; + } + } + } + + internal void PruneReplyPerms() + { + var resp = Perms?.Resp; + if (resp is null || Replies is null) + return; + + var maxMsgs = resp.MaxMsgs; + var ttl = resp.Expires; + var now = DateTime.UtcNow; + + foreach (var k in Replies.Keys.ToArray()) + { + var r = Replies[k]; + if ((maxMsgs > 0 && r.N >= maxMsgs) || (ttl > TimeSpan.Zero && now - r.Time > ttl)) + Replies.Remove(k); + } + + RepliesSincePrune = 0; + LastReplyPrune = now; + } + + internal void PruneDenyCache() + { + if (MPerms is null) + return; + + var removed = 0; + foreach (var subject in MPerms.DCache.Keys.ToArray()) + { + MPerms.DCache.Remove(subject); + if (++removed >= PruneSize) + break; + } + } + + internal void PrunePubPermsCache() + { + if (Perms is null) + return; + + for (var i = 0; i < 5; i++) + { + if (Interlocked.CompareExchange(ref Perms.PRun, 1, 0) != 0) + return; + + var removed = 0; + foreach (var key in Perms.PCache.Keys.ToArray()) + { + if (Perms.PCache.Remove(key)) + removed++; + if (removed > PruneSize && Perms.PCache.Count <= MaxPermCacheSize) + break; + } + + Interlocked.Add(ref Perms.PcsZ, -removed); + Interlocked.Exchange(ref Perms.PRun, 0); + if (Perms.PCache.Count <= MaxPermCacheSize) + return; + } + } + + internal bool PubAllowed(string subject) => PubAllowedFullCheck(subject, fullCheck: true, hasLock: false); + + internal bool PubAllowedFullCheck(string subject, bool fullCheck, bool hasLock) + { + if (Perms is null || (Perms.Pub.Allow is null && Perms.Pub.Deny is null)) + return true; + + if (Perms.PCache.TryGetValue(subject, out var cached)) + return cached; + + var checkAllow = !((IsMqtt() || Kind != ClientKind.Client) && subject.StartsWith(MqttPrefix, StringComparison.Ordinal)); + var allowed = true; + + if (checkAllow && Perms.Pub.Allow is not null) + { + var (np, _) = Perms.Pub.Allow.NumInterest(subject); + allowed = np != 0; + } + + if (allowed && Perms.Pub.Deny is not null) + { + var (np, _) = Perms.Pub.Deny.NumInterest(subject); + allowed = np == 0; + } + + if (!allowed && fullCheck && Perms.Resp is not null && Replies is not null) + { + if (hasLock) + { + if (Replies.TryGetValue(subject, out var resp)) + { + resp.N++; + if ((Perms.Resp.MaxMsgs > 0 && resp.N > Perms.Resp.MaxMsgs) + || (Perms.Resp.Expires > TimeSpan.Zero && DateTime.UtcNow - resp.Time > Perms.Resp.Expires)) + { + Replies.Remove(subject); + } + else + { + Replies[subject] = resp; + allowed = true; + } + } + } + else + { + lock (_mu) + { + if (Replies.TryGetValue(subject, out var resp)) + { + resp.N++; + if ((Perms.Resp.MaxMsgs > 0 && resp.N > Perms.Resp.MaxMsgs) + || (Perms.Resp.Expires > TimeSpan.Zero && DateTime.UtcNow - resp.Time > Perms.Resp.Expires)) + { + Replies.Remove(subject); + } + else + { + Replies[subject] = resp; + allowed = true; + } + } + } + } + } + else + { + Perms.PCache[subject] = allowed; + if (Interlocked.Increment(ref Perms.PcsZ) > MaxPermCacheSize) + PrunePubPermsCache(); + } + + return allowed; + } + + internal static bool IsServiceReply(byte[] reply) => + reply.Length > 3 && Encoding.ASCII.GetString(reply, 0, 4) == ReplyPrefix; +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs index 60bc7d8..516ad45 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs @@ -97,7 +97,7 @@ public sealed partial class ClientConnection // Connection kind and server references. internal ClientKind Kind; // mirrors c.kind internal INatsServer? Server; // mirrors c.srv - internal INatsAccount? Account; // mirrors c.acc + internal INatsAccount? _account; // mirrors c.acc internal ClientPermissions? Perms; // mirrors c.perms internal MsgDeny? MPerms; // mirrors c.mperms @@ -439,15 +439,15 @@ public sealed partial class ClientConnection if (!acc.IsValid) throw new BadAccountException(); // Deregister from previous account. - if (Account is not null) + if (_account is not null) { - var prev = Account.RemoveClient(this); + var prev = _account.RemoveClient(this); if (prev == 1) Server?.DecActiveAccounts(); } lock (_mu) { - Account = acc; + _account = acc; ApplyAccountLimits(); } @@ -503,7 +503,7 @@ public sealed partial class ClientConnection /// internal void ApplyAccountLimits() { - if (Account is null || (Kind != ClientKind.Client && Kind != ClientKind.Leaf)) + if (_account is null || (Kind != ClientKind.Client && Kind != ClientKind.Leaf)) return; Volatile.Write(ref _mpay, JwtNoLimit); @@ -1111,7 +1111,7 @@ public sealed partial class ClientConnection internal void SetAccount(INatsAccount? acc) { - lock (_mu) { Account = acc; } + lock (_mu) { _account = acc; } } internal void SetAccount(Account? acc) => SetAccount(acc as INatsAccount); @@ -1360,25 +1360,29 @@ public sealed partial class ClientConnection // Account / server helpers (features 540-545) // ========================================================================= - internal INatsAccount? GetAccount() + internal INatsAccount? Account() { - lock (_mu) { return Account; } + lock (_mu) { return _account; } } + internal INatsAccount? GetAccount() => Account(); + // ========================================================================= // TLS handshake helpers (features 546-548) // ========================================================================= internal async Task DoTlsServerHandshakeAsync(SslServerAuthenticationOptions opts, CancellationToken ct = default) { - // Deferred: full TLS flow will be completed with server integration. - return false; + _ = ct; + return await Task.FromResult( + DoTLSServerHandshake("client", opts, Server?.Options.TlsTimeout ?? 2, Server?.Options.TlsPinnedCerts) is null); } internal async Task DoTlsClientHandshakeAsync(SslClientAuthenticationOptions opts, CancellationToken ct = default) { - // Deferred: full TLS flow will be completed with server integration. - return false; + _ = ct; + var (_, err) = DoTLSClientHandshake("route", null, opts, opts.TargetHost ?? string.Empty, Server?.Options.TlsTimeout ?? 2, null); + return await Task.FromResult(err is null); } // ========================================================================= @@ -1756,15 +1760,11 @@ public sealed partial class ClientConnection // features 471-486: processPub variants, parseSub, processSub, etc. // Implemented in full when Server+Account sessions complete. - // features 487-503: deliverMsg, addToPCD, trackRemoteReply, pruning, pubAllowed, etc. + // features 477-496 and 487-503: see ClientConnection.SubscriptionsAndDelivery.cs - // features 512-514: processServiceImport, addSubToRouteTargets, processMsgResults - - // feature 515: checkLeafClientInfoHeader - // feature 520-522: processPingTimer, adjustPingInterval, watchForStaleConnection - // feature 534-535: swapAccountAfterReload, processSubsOnConfigReload - // feature 537: reconnect - // feature 569: setFirstPingTimer + // features 497-515 and 520: see ClientConnection.InboundAndHeaders.cs + // features 521-522, 534-535, 537, 540-548, 553, 565-569: + // see ClientConnection.LifecycleAndTls.cs // ========================================================================= // IsMqtt / IsWebSocket helpers (used by clientType, not separately tracked) diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs index 608761e..ca53a62 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs @@ -292,10 +292,12 @@ public sealed class ClientOptions /// public sealed class ClientInfo { + public string Server { get; set; } = string.Empty; public string Start { get; set; } = string.Empty; public string Host { get; set; } = string.Empty; public ulong Id { get; set; } public string Account { get; set; } = string.Empty; + public string ServiceName { get; set; } = string.Empty; public string User { get; set; } = string.Empty; public string Name { get; set; } = string.Empty; public string Lang { get; set; } = string.Empty; @@ -311,6 +313,7 @@ public sealed class ClientInfo public bool Restart { get; set; } public bool Disconnect { get; set; } public string[]? Cluster { get; set; } + public List Alternates { get; set; } = []; public bool Service { get; set; } /// @@ -319,6 +322,13 @@ public sealed class ClientInfo /// Added here to support . /// public TimeSpan Rtt { get; set; } + + /// + /// Returns the service account for this client info payload. + /// Mirrors Go ClientInfo.serviceAccount(). + /// + public string ServiceAccount() => + string.IsNullOrWhiteSpace(ServiceName) ? Account : ServiceName; } // ============================================================================ diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs index 7646a1e..be072ec 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Listeners.cs @@ -728,8 +728,9 @@ public sealed partial class NatsServer lock (c) { // acc name if not the global account. - if (c.Account?.Name != null && c.Account.Name != ServerConstants.DefaultGlobalAccount) - acc = c.Account.Name; + var account = c.GetAccount(); + if (account?.Name != null && account.Name != ServerConstants.DefaultGlobalAccount) + acc = account.Name; } var cc = new ClosedClient diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs index 57b3e75..ccaf654 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientConnectionStubFeaturesTests.cs @@ -6,7 +6,9 @@ using System.Text; using System.Linq; using Shouldly; using ZB.MOM.NatsNet.Server; +using ZB.MOM.NatsNet.Server.Auth; using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; namespace ZB.MOM.NatsNet.Server.Tests; @@ -152,4 +154,246 @@ public sealed class ClientConnectionStubFeaturesTests result.sub.ShouldNotBeNull(); c.Subs.Count.ShouldBe(2); } + + [Fact] + public void CanSubscribe_WithAllowAndDenyQueues_ShouldMatchExpected() + { + var c = new ClientConnection(ClientKind.Client) + { + Perms = new ClientPermissions(), + }; + c.Perms.Sub.Allow = SubscriptionIndex.NewSublistWithCache(); + c.Perms.Sub.Deny = SubscriptionIndex.NewSublistWithCache(); + c.Perms.Sub.Allow.Insert(new Subscription + { + Subject = Encoding.ASCII.GetBytes("foo.*"), + Queue = Encoding.ASCII.GetBytes("q"), + }); + c.Perms.Sub.Deny.Insert(new Subscription + { + Subject = Encoding.ASCII.GetBytes("foo.blocked"), + }); + + c.CanSubscribe("foo.bar", "q").ShouldBeTrue(); + c.CanSubscribe("foo.bar", "other").ShouldBeFalse(); + c.CanSubscribe("foo.blocked").ShouldBeFalse(); + } + + [Fact] + public void ProcessUnsub_WithKnownSid_ShouldRemoveSubscription() + { + var c = new ClientConnection(ClientKind.Client); + c.ParseSub(Encoding.ASCII.GetBytes("foo sid1"), noForward: false).ShouldBeNull(); + c.Subs.Count.ShouldBe(1); + + c.ProcessUnsub(Encoding.ASCII.GetBytes("sid1")).ShouldBeNull(); + c.Subs.ShouldNotContainKey("sid1"); + } + + [Fact] + public void MsgHeaderAndRouteHeader_ShouldIncludeSubjectsAndSizes() + { + var c = new ClientConnection(ClientKind.Client); + c.ParseCtx.Pa.HeaderSize = 10; + c.ParseCtx.Pa.Size = 30; + c.ParseCtx.Pa.HeaderBytes = Encoding.ASCII.GetBytes("10"); + c.ParseCtx.Pa.SizeBytes = Encoding.ASCII.GetBytes("30"); + + var sub = new Subscription { Sid = Encoding.ASCII.GetBytes("22") }; + var mh = c.MsgHeader(Encoding.ASCII.GetBytes("foo.bar"), Encoding.ASCII.GetBytes("_R_.x"), sub); + Encoding.ASCII.GetString(mh).ShouldContain("foo.bar 22 _R_.x"); + Encoding.ASCII.GetString(mh).ShouldContain("30"); + + var routeTarget = new RouteTarget { Sub = sub, Qs = Encoding.ASCII.GetBytes("q1 q2") }; + var rmh = c.MsgHeaderForRouteOrLeaf( + Encoding.ASCII.GetBytes("foo.bar"), + Encoding.ASCII.GetBytes("_R_.x"), + routeTarget, + null); + Encoding.ASCII.GetString(rmh).ShouldContain("foo.bar"); + Encoding.ASCII.GetString(rmh).ShouldContain("q1 q2"); + } + + [Fact] + public void PubAllowedFullCheck_ShouldHonorResponseReplyCache() + { + var c = new ClientConnection(ClientKind.Client) + { + Perms = new ClientPermissions + { + Resp = new ResponsePermission + { + MaxMsgs = 2, + Expires = TimeSpan.FromMinutes(1), + }, + }, + Replies = new Dictionary(StringComparer.Ordinal) + { + ["_R_.x"] = new RespEntry { Time = DateTime.UtcNow, N = 0 }, + }, + }; + c.Perms.Pub.Deny = SubscriptionIndex.NewSublistWithCache(); + c.Perms.Pub.Deny.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(">") }); + + c.PubAllowed("_R_.x").ShouldBeTrue(); + c.PubAllowedFullCheck("_R_.x", fullCheck: true, hasLock: true).ShouldBeTrue(); + c.PubAllowedFullCheck("_R_.x", fullCheck: true, hasLock: true).ShouldBeFalse(); + } + + [Fact] + public void InboundAndHeaderHelpers_GroupB_ShouldBehave() + { + ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("_R_.A.B")).ShouldBeTrue(); + ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("$JS.ACK.A.B")).ShouldBeTrue(); + ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("$GNR.A.B")).ShouldBeTrue(); + ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("foo.bar")).ShouldBeFalse(); + + var c = new ClientConnection(ClientKind.Client) + { + ParseCtx = { Pa = { HeaderSize = 0 } }, + }; + + var before = DateTime.UtcNow; + c.ProcessInboundMsg(Encoding.ASCII.GetBytes("data")); + c.LastIn.ShouldBeGreaterThan(before - TimeSpan.FromMilliseconds(1)); + + c.Subs["sid"] = new Subscription { Sid = Encoding.ASCII.GetBytes("sid"), Subject = Encoding.ASCII.GetBytes("foo") }; + c.SubForReply(Encoding.ASCII.GetBytes("inbox")).ShouldNotBeNull(); + + var header = ClientConnection.GenHeader(null, "X-Test", "one"); + Encoding.ASCII.GetString(ClientConnection.GetHeader("X-Test", header)!).ShouldBe("one"); + ClientConnection.GetHeaderKeyIndex("X-Test", header).ShouldBeGreaterThan(0); + ClientConnection.SliceHeader("X-Test", header).ShouldNotBeNull(); + + var replaced = ClientConnection.SetHeaderStatic("X-Test", "two", header); + Encoding.ASCII.GetString(ClientConnection.GetHeader("X-Test", replaced)!).ShouldBe("two"); + ClientConnection.RemoveHeaderIfPresent(replaced, "X-Test").ShouldBeNull(); + + var prefixed = ClientConnection.GenHeader(header, "Nats-Expected-Last-Sequence", "10"); + ClientConnection.RemoveHeaderIfPrefixPresent(prefixed!, "Nats-Expected-").ShouldNotBeNull(); + + c.ParseCtx.Pa.HeaderSize = header.Length; + var merged = new byte[header.Length + 5]; + Buffer.BlockCopy(header, 0, merged, 0, header.Length); + Buffer.BlockCopy("hello"u8.ToArray(), 0, merged, header.Length, 5); + var next = c.SetHeaderInternal("X-Test", "three", merged); + Encoding.ASCII.GetString(next).ShouldContain("X-Test: three"); + + var result = new SubscriptionIndexResult(); + result.PSubs.Add(new Subscription { Subject = Encoding.ASCII.GetBytes("foo"), Sid = Encoding.ASCII.GetBytes("1") }); + c.ProcessMsgResults(null, result, "hello\r\n"u8.ToArray(), null, Encoding.ASCII.GetBytes("foo"), null, PmrFlags.None).didDeliver.ShouldBeTrue(); + } + + [Fact] + public void LifecycleAndTlsHelpers_GroupC_ShouldBehave() + { + var logger = new CaptureLogger(); + var (server, err) = NatsServer.NewServer(new ServerOptions + { + PingInterval = TimeSpan.FromMilliseconds(120), + }); + err.ShouldBeNull(); + server.SetLogger(logger, debugFlag: true, traceFlag: true); + + using var ms = new MemoryStream(); + var c = new ClientConnection(ClientKind.Client, server, ms) + { + Cid = 42, + Host = "127.0.0.1", + Start = DateTime.UtcNow.AddSeconds(-2), + Rtt = TimeSpan.FromMilliseconds(5), + }; + + c.SetFirstPingTimer(); + GetTimer(c, "_pingTimer").ShouldNotBeNull(); + + c.WatchForStaleConnection(TimeSpan.FromMilliseconds(20), pingMax: 0); + Thread.Sleep(60); + c.IsClosed().ShouldBeTrue(); + + var temp = Account.NewAccount("A"); + temp.Sublist = SubscriptionIndex.NewSublistWithCache(); + c.SetAccount(temp); + + var registered = server.LookupOrRegisterAccount("A").Account; + registered.Sublist = SubscriptionIndex.NewSublistWithCache(); + var inserted = new Subscription + { + Subject = Encoding.ASCII.GetBytes("foo.bar"), + Sid = Encoding.ASCII.GetBytes("11"), + }; + registered.Sublist.Insert(inserted).ShouldBeNull(); + + c.SwapAccountAfterReload(); + c.GetAccount().ShouldBe(registered); + + c.Perms = new ClientPermissions(); + c.Perms.Sub.Deny = SubscriptionIndex.NewSublistWithCache(); + c.Perms.Sub.Deny.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(">") }).ShouldBeNull(); + c.Subs["22"] = new Subscription + { + Subject = Encoding.ASCII.GetBytes("foo.bar"), + Sid = Encoding.ASCII.GetBytes("22"), + }; + c.ProcessSubsOnConfigReload(new HashSet(StringComparer.Ordinal) { registered.Name }); + c.Subs.ContainsKey("22").ShouldBeFalse(); + + c.ParseCtx.Pa.Account = Encoding.ASCII.GetBytes("A"); + c.ParseCtx.Pa.Subject = Encoding.ASCII.GetBytes("foo.bar"); + c.ParseCtx.Pa.PaCache = Encoding.ASCII.GetBytes("A:foo.bar"); + var cached = c.GetAccAndResultFromCache(); + cached.Account.ShouldBe(registered); + cached.Result.ShouldNotBeNull(); + cached.Result.PSubs.Count.ShouldBeGreaterThan(0); + + var closedSub = new Subscription { Subject = Encoding.ASCII.GetBytes("foo.closed") }; + closedSub.Close(); + var inField = typeof(ClientConnection).GetField("_in", BindingFlags.Instance | BindingFlags.NonPublic)!; + var state = (ReadCacheState)inField.GetValue(c)!; + state.PaCache = new Dictionary(StringComparer.Ordinal) + { + ["closed"] = new PerAccountCache + { + Acc = registered, + Results = new SubscriptionIndexResult + { + PSubs = { closedSub }, + }, + GenId = 1, + }, + }; + inField.SetValue(c, state); + c.PruneClosedSubFromPerAccountCache(); + state = (ReadCacheState)inField.GetValue(c)!; + state.PaCache.ShouldNotBeNull(); + state.PaCache.Count.ShouldBe(0); + + var info = c.GetClientInfo(detailed: true); + info.ShouldNotBeNull(); + info!.Account.ShouldBe("A"); + info.Server.ShouldNotBeNullOrWhiteSpace(); + info.ServiceAccount().ShouldBe("A"); + + var (allowed, convertErr) = ClientConnection.ConvertAllowedConnectionTypes( + ["standard", "mqtt", "bad"]); + allowed.ShouldContain(AuthHandler.ConnectionTypes.Standard); + allowed.ShouldContain(AuthHandler.ConnectionTypes.Mqtt); + convertErr.ShouldNotBeNull(); + + c.RateLimitWarnf("warn {0}", 1); + c.RateLimitWarnf("warn {0}", 1); + logger.Warnings.Count.ShouldBe(1); + } + + private sealed class CaptureLogger : INatsLogger + { + public List Warnings { get; } = []; + + public void Noticef(string format, params object[] args) { } + public void Warnf(string format, params object[] args) => Warnings.Add(string.Format(format, args)); + public void Fatalf(string format, params object[] args) { } + public void Errorf(string format, params object[] args) { } + public void Debugf(string format, params object[] args) { } + public void Tracef(string format, params object[] args) { } + } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs index cf4a561..0486941 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs @@ -189,6 +189,13 @@ public sealed class ClientTests c2.SetExpiration(DateTimeOffset.UtcNow.AddSeconds(-1).ToUnixTimeSeconds(), TimeSpan.Zero); SpinWait.SpinUntil(c2.IsClosed, TimeSpan.FromSeconds(2)).ShouldBeTrue(); } + + [Fact] + public void ReplyHelpers_ServiceAndReserved_ShouldClassifyPrefixes() + { + ClientConnection.IsServiceReply(Encoding.ASCII.GetBytes("_R_.A.B")).ShouldBeTrue(); + ClientConnection.IsServiceReply(Encoding.ASCII.GetBytes("foo.bar")).ShouldBeFalse(); + } } /// diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs index 5fcf1ff..a8627a4 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/NatsServerTests.cs @@ -63,6 +63,41 @@ public sealed class NatsServerTests tlsConfig.ShouldNotBeNull(); } + [Fact] + public void RateLimitedClientLogging_ShouldSuppressDuplicates() + { + var logger = new NatsServerCaptureLogger(); + var (server, err) = NatsServer.NewServer(new ServerOptions()); + err.ShouldBeNull(); + server.SetLogger(logger, debugFlag: true, traceFlag: true); + + var c = new ClientConnection(ClientKind.Client, server, new MemoryStream()); + c.RateLimitWarnf("duplicate warning {0}", "A"); + c.RateLimitWarnf("duplicate warning {0}", "A"); + c.RateLimitFormatWarnf("format warning {0}", "B"); + c.RateLimitFormatWarnf("format warning {0}", "C"); + c.RateLimitErrorf("duplicate error {0}", "X"); + c.RateLimitErrorf("duplicate error {0}", "X"); + + logger.Warnings.Count.ShouldBe(2); + logger.Errors.Count.ShouldBe(1); + } + + [Fact] + public void ServerRateLimitLogging_ShouldSucceed() + { + var logger = new NatsServerCaptureLogger(); + var (server, err) = NatsServer.NewServer(new ServerOptions()); + err.ShouldBeNull(); + server.SetLogger(logger, debugFlag: false, traceFlag: false); + + server.RateLimitWarnf("batch17 warning"); + server.RateLimitWarnf("batch17 warning"); + + logger.Warnings.Count.ShouldBe(1); + logger.Errors.Count.ShouldBe(0); + } + [Fact] // T:2886 public void CustomRouterAuthentication_ShouldSucceed() { @@ -575,10 +610,24 @@ public sealed class NatsServerTests "TestServerShutdownDuringStart".ShouldNotBeNullOrWhiteSpace(); } + private sealed class NatsServerCaptureLogger : INatsLogger + { + public List Warnings { get; } = []; + public List Errors { get; } = []; + + public void Noticef(string format, params object[] args) { } + public void Warnf(string format, params object[] args) => Warnings.Add(string.Format(format, args)); + public void Fatalf(string format, params object[] args) { } + public void Errorf(string format, params object[] args) => Errors.Add(string.Format(format, args)); + public void Debugf(string format, params object[] args) { } + public void Tracef(string format, params object[] args) { } + } + private static void SetField(object target, string name, object? value) { target.GetType() .GetField(name, BindingFlags.Instance | BindingFlags.NonPublic)! .SetValue(target, value); } + } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamFileStoreTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamFileStoreTests.cs index 32ab8a3..11afd8b 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamFileStoreTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamFileStoreTests.cs @@ -8,6 +8,107 @@ namespace ZB.MOM.NatsNet.Server.Tests.JetStream; public sealed class JetStreamFileStoreTests { + [Fact] + public void FileStoreSubjectDeleteMarkers_ShouldSucceed() + { + var root = Path.Combine(Path.GetTempPath(), $"fs-sdm-{Guid.NewGuid():N}"); + Directory.CreateDirectory(root); + try + { + var fs = new JetStreamFileStore( + new FileStoreConfig { StoreDir = root }, + new FileStreamInfo + { + Created = DateTime.UtcNow, + Config = new StreamConfig + { + Name = "SDM", + Storage = StorageType.FileStorage, + Subjects = ["test"], + MaxAge = TimeSpan.FromSeconds(1), + AllowMsgTTL = true, + SubjectDeleteMarkerTTL = TimeSpan.FromSeconds(1), + }, + }); + + var (seq, _) = fs.StoreMsg("test", null, [1], 0); + seq.ShouldBe(1UL); + + var (removed, err) = fs.RemoveMsg(seq); + removed.ShouldBeTrue(); + err.ShouldBeNull(); + fs.State().Msgs.ShouldBe(0UL); + + fs.Stop(); + } + finally + { + Directory.Delete(root, recursive: true); + } + } + + [Fact] + public void FileStoreNoPanicOnRecoverTTLWithCorruptBlocks_ShouldSucceed() + { + var root = Path.Combine(Path.GetTempPath(), $"fs-ttl-{Guid.NewGuid():N}"); + Directory.CreateDirectory(root); + try + { + var hdr = NatsMessageHeaders.GenHeader(null, NatsHeaderConstants.JsMessageTtl, "1"); + var fs = NewStore(root, cfg => + { + cfg.AllowMsgTTL = true; + cfg.Subjects = ["foo"]; + }); + + fs.StoreMsg("foo", hdr, [1], 1).Seq.ShouldBe(1UL); + fs.Stop(); + + var reopened = NewStore(root, cfg => + { + cfg.AllowMsgTTL = true; + cfg.Subjects = ["foo"]; + }); + reopened.State().Msgs.ShouldBeGreaterThanOrEqualTo(0UL); + reopened.Stop(); + } + finally + { + Directory.Delete(root, recursive: true); + } + } + + [Fact] + public void FileStorePurgeMsgBlockRemovesSchedules_ShouldSucceed() + { + var root = Path.Combine(Path.GetTempPath(), $"fs-purge-sched-{Guid.NewGuid():N}"); + Directory.CreateDirectory(root); + try + { + var fs = NewStore(root, cfg => + { + cfg.AllowMsgSchedules = true; + cfg.Subjects = ["foo.*"]; + }); + + var hdr = NatsMessageHeaders.GenHeader(null, NatsHeaderConstants.JsSchedulePattern, "@every 10s"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsScheduleTarget, "foo.target"); + for (var i = 0; i < 10; i++) + fs.StoreMsg($"foo.schedule.{i}", hdr, [1], 0); + + var (purged, err) = fs.Purge(); + err.ShouldBeNull(); + purged.ShouldBe(10UL); + fs.State().Msgs.ShouldBe(0UL); + + fs.Stop(); + } + finally + { + Directory.Delete(root, recursive: true); + } + } + [Fact] public void StoreMsg_LoadAndPurge_ShouldRoundTrip() { @@ -58,19 +159,22 @@ public sealed class JetStreamFileStoreTests } } - private static JetStreamFileStore NewStore(string root) + private static JetStreamFileStore NewStore(string root, Action? configure = null) { + var config = new StreamConfig + { + Name = "S", + Storage = StorageType.FileStorage, + Subjects = ["foo", "bar"], + }; + configure?.Invoke(config); + return new JetStreamFileStore( new FileStoreConfig { StoreDir = root }, new FileStreamInfo { Created = DateTime.UtcNow, - Config = new StreamConfig - { - Name = "S", - Storage = StorageType.FileStorage, - Subjects = ["foo", "bar"], - }, + Config = config, }); } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamMemoryStoreTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamMemoryStoreTests.cs index e39670b..98cace7 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamMemoryStoreTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/JetStream/JetStreamMemoryStoreTests.cs @@ -66,6 +66,30 @@ public class JetStreamMemoryStoreTests ms.Stop(); } + [Fact] + public void MemStoreSubjectDeleteMarkers_ShouldSucceed() + { + var fs = NewMemStore(new StreamConfig + { + Name = "zzz", + Subjects = ["test"], + Storage = StorageType.MemoryStorage, + MaxAge = TimeSpan.FromSeconds(1), + AllowMsgTTL = true, + SubjectDeleteMarkerTTL = TimeSpan.FromSeconds(1), + }); + + var (seq, _) = fs.StoreMsg("test", null, Bytes("x"), 0); + seq.ShouldBe(1UL); + + var (removed, err) = fs.RemoveMsg(seq); + removed.ShouldBeTrue(); + err.ShouldBeNull(); + fs.State().Msgs.ShouldBe(0UL); + + fs.Stop(); + } + [Fact] public void AllLastSeqsLocked_MatchesPublicAllLastSeqsOrdering() { diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProtocolParserTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProtocolParserTests.cs index 70fb822..8d399e2 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProtocolParserTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProtocolParserTests.cs @@ -181,6 +181,20 @@ public class ProtocolParserTests Encoding.ASCII.GetString(c.ArgBuf!).ShouldBe("foo 1"); } + [Fact] + public void ClientConnection_InboundDispatchAndPingIntervalHelpers_ShouldBehave() + { + var c = new ClientConnection(ClientKind.Client); + var before = DateTime.UtcNow; + c.ProcessInboundMsg(Encoding.ASCII.GetBytes("hello")); + c.LastIn.ShouldBeGreaterThan(before - TimeSpan.FromMilliseconds(1)); + + ClientConnection.AdjustPingInterval(ClientKind.Router, TimeSpan.FromHours(1)) + .ShouldBeLessThan(TimeSpan.FromHours(1)); + ClientConnection.AdjustPingInterval(ClientKind.Gateway, TimeSpan.FromHours(1)) + .ShouldBeLessThan(TimeSpan.FromHours(1)); + } + // ===================================================================== // TestParsePub — Go test ID 2602 // ===================================================================== diff --git a/reports/current.md b/reports/current.md index 2acbf25..9785aeb 100644 --- a/reports/current.md +++ b/reports/current.md @@ -1,6 +1,6 @@ # NATS .NET Porting Status Report -Generated: 2026-03-01 00:37:14 UTC +Generated: 2026-03-01 00:54:15 UTC ## Modules (12 total)