// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0

using System.Reflection;
using System.Text;
using System.Linq;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Auth;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Internal.DataStructures;

namespace ZB.MOM.NatsNet.Server.Tests;

/// <summary>
/// Exercises <see cref="ClientConnection"/> members (connect/pong handling, timers,
/// outbound queueing, pub/sub parsing, permissions, header helpers and lifecycle helpers)
/// and asserts the state each call leaves on the connection.
/// </summary>
public sealed class ClientConnectionStubFeaturesTests
{
    /// <summary>
    /// Runs CONNECT and PONG processing and arms the ping/auth timers on a connection
    /// backed by a <see cref="MemoryStream"/>, then verifies that a zero expiration
    /// timer leaves the connection closed.
    /// </summary>
    [Fact]
    public void ProcessConnect_ProcessPong_AndTimers_ShouldBehave()
    {
        var (server, err) = NatsServer.NewServer(new ServerOptions
        {
            PingInterval = TimeSpan.FromMilliseconds(20),
            AuthTimeout = 0.1,
        });
        err.ShouldBeNull();

        using var ms = new MemoryStream();
        var c = new ClientConnection(ClientKind.Client, server, ms)
        {
            Cid = 9,
            Trace = true,
        };

        var connectJson = Encoding.UTF8.GetBytes("{\"echo\":false,\"headers\":true,\"name\":\"unit\"}");
        c.ProcessConnect(connectJson);
        c.Opts.Name.ShouldBe("unit");
        c.Echo.ShouldBeFalse();
        c.Headers.ShouldBeTrue();

        // Back-date the RTT start so the measured round trip is strictly positive.
        c.RttStart = DateTime.UtcNow - TimeSpan.FromMilliseconds(50);
        c.ProcessPong();
        c.GetRttValue().ShouldBeGreaterThan(TimeSpan.Zero);

        c.SetPingTimer();
        GetTimer(c, "_pingTimer").ShouldNotBeNull();

        c.SetAuthTimer(TimeSpan.FromMilliseconds(20));
        GetTimer(c, "_atmr").ShouldNotBeNull();

        // These calls carry no assertion; the test only requires that they do not throw.
        c.TraceMsg(Encoding.UTF8.GetBytes("MSG"));
        c.FlushSignal();
        c.UpdateS2AutoCompressionLevel();

        c.SetExpirationTimer(TimeSpan.Zero);
        c.IsClosed().ShouldBeTrue();
    }

    /// <summary>
    /// Queues a 70,000-byte payload and checks that the pending byte count matches and
    /// that the payload is held in more than one outbound chunk.
    /// </summary>
    [Fact]
    public void QueueOutbound_ChunkingAndPendingBytes_ShouldTrackState()
    {
        var c = new ClientConnection(ClientKind.Client)
        {
            OutMp = 100_000,
        };

        c.QueueOutbound(new byte[70_000]);

        c.OutPb.ShouldBe(70_000);
        c.OutNb.Count.ShouldBeGreaterThan(1);
        c.OutNb.Sum(chunk => chunk.Count).ShouldBe(70_000);
    }

    /// <summary>
    /// Flushing a connection constructed without a server or stream returns true and
    /// leaves the queued bytes pending.
    /// </summary>
    [Fact]
    public void FlushOutbound_WithoutServerOrConn_ShouldNoOpTrue()
    {
        var c = new ClientConnection(ClientKind.Client);
        c.QueueOutbound(Encoding.ASCII.GetBytes("hello"));

        c.FlushOutbound().ShouldBeTrue();

        c.OutPb.ShouldBe(5);
    }

    /// <summary>
    /// With the <see cref="WriteTimeoutPolicy.Close"/> policy, a write timeout reports true
    /// and sets both the marked-closed and skip-flush-on-close flags.
    /// </summary>
    [Fact]
    public void HandleWriteTimeout_ClosePolicy_ShouldMarkClosed()
    {
        var c = new ClientConnection(ClientKind.Client)
        {
            OutWtp = WriteTimeoutPolicy.Close,
        };

        c.HandleWriteTimeout(0, 100, 1).ShouldBeTrue();

        c.Flags.IsSet(ClientFlags.ConnMarkedClosed).ShouldBeTrue();
        c.Flags.IsSet(ClientFlags.SkipFlushOnClose).ShouldBeTrue();
    }

    /// <summary>
    /// With the <see cref="WriteTimeoutPolicy.Retry"/> policy, a write timeout reports false
    /// and sets the slow-consumer flag.
    /// </summary>
    [Fact]
    public void HandleWriteTimeout_RetryPolicy_ShouldSetSlowConsumerFlag()
    {
        var c = new ClientConnection(ClientKind.Client)
        {
            OutWtp = WriteTimeoutPolicy.Retry,
        };

        c.HandleWriteTimeout(1, 100, 2).ShouldBeFalse();

        c.Flags.IsSet(ClientFlags.IsSlowConsumer).ShouldBeTrue();
    }

    /// <summary>
    /// PUB and HPUB argument parsing populate subject, size and header size on the
    /// connection's parse context.
    /// </summary>
    [Fact]
    public void ProcessPubAndHeaderPubWrappers_ShouldPopulateParseContext()
    {
        var c = new ClientConnection(ClientKind.Client)
        {
            Headers = true,
        };

        c.ProcessPub(Encoding.ASCII.GetBytes("foo 5")).ShouldBeNull();
        Encoding.ASCII.GetString(c.ParseCtx.Pa.Subject!).ShouldBe("foo");
        c.ParseCtx.Pa.Size.ShouldBe(5);

        c.ProcessHeaderPub(Encoding.ASCII.GetBytes("foo 3 5"), null).ShouldBeNull();
        Encoding.ASCII.GetString(c.ParseCtx.Pa.Subject!).ShouldBe("foo");
        c.ParseCtx.Pa.HeaderSize.ShouldBe(3);
        c.ParseCtx.Pa.Size.ShouldBe(5);
    }

    /// <summary>
    /// Splits a SUB argument line into tokens, then registers subscriptions through both
    /// <c>ParseSub</c> and <c>ProcessSubEx</c> and checks the subscription count.
    /// </summary>
    [Fact]
    public void SplitArgParseSubAndProcessSub_ShouldCreateSubscriptions()
    {
        var tokens = ClientConnection.SplitArg(Encoding.ASCII.GetBytes("foo queue sid\r\n"));
        tokens.Count.ShouldBe(3);
        Encoding.ASCII.GetString(tokens[0]).ShouldBe("foo");
        Encoding.ASCII.GetString(tokens[1]).ShouldBe("queue");
        Encoding.ASCII.GetString(tokens[2]).ShouldBe("sid");

        var c = new ClientConnection(ClientKind.Client);
        c.ParseSub(Encoding.ASCII.GetBytes("foo queue sid"), noForward: true).ShouldBeNull();
        c.Subs.Count.ShouldBe(1);

        var result = c.ProcessSubEx(
            Encoding.ASCII.GetBytes("bar"),
            null,
            Encoding.ASCII.GetBytes("sid2"),
            noForward: false,
            si: false,
            rsi: false);
        result.err.ShouldBeNull();
        result.sub.ShouldNotBeNull();
        c.Subs.Count.ShouldBe(2);
    }

    /// <summary>
    /// Subscribe permission checks against an allow list containing a queue-scoped entry
    /// and a deny list containing a literal subject.
    /// </summary>
    [Fact]
    public void CanSubscribe_WithAllowAndDenyQueues_ShouldMatchExpected()
    {
        var c = new ClientConnection(ClientKind.Client)
        {
            Perms = new ClientPermissions(),
        };
        c.Perms.Sub.Allow = SubscriptionIndex.NewSublistWithCache();
        c.Perms.Sub.Deny = SubscriptionIndex.NewSublistWithCache();
        c.Perms.Sub.Allow.Insert(new Subscription
        {
            Subject = Encoding.ASCII.GetBytes("foo.*"),
            Queue = Encoding.ASCII.GetBytes("q"),
        });
        c.Perms.Sub.Deny.Insert(new Subscription
        {
            Subject = Encoding.ASCII.GetBytes("foo.blocked"),
        });

        c.CanSubscribe("foo.bar", "q").ShouldBeTrue();
        c.CanSubscribe("foo.bar", "other").ShouldBeFalse();
        c.CanSubscribe("foo.blocked").ShouldBeFalse();
    }

    /// <summary>
    /// Unsubscribing with a known sid removes that subscription from the connection.
    /// </summary>
    [Fact]
    public void ProcessUnsub_WithKnownSid_ShouldRemoveSubscription()
    {
        var c = new ClientConnection(ClientKind.Client);
        c.ParseSub(Encoding.ASCII.GetBytes("foo sid1"), noForward: false).ShouldBeNull();
        c.Subs.Count.ShouldBe(1);

        c.ProcessUnsub(Encoding.ASCII.GetBytes("sid1")).ShouldBeNull();

        c.Subs.ShouldNotContainKey("sid1");
    }

    /// <summary>
    /// The client message header contains subject, sid, reply and size; the route/leaf
    /// header contains the subject and the queue names from the route target.
    /// </summary>
    [Fact]
    public void MsgHeaderAndRouteHeader_ShouldIncludeSubjectsAndSizes()
    {
        var c = new ClientConnection(ClientKind.Client);
        c.ParseCtx.Pa.HeaderSize = 10;
        c.ParseCtx.Pa.Size = 30;
        c.ParseCtx.Pa.HeaderBytes = Encoding.ASCII.GetBytes("10");
        c.ParseCtx.Pa.SizeBytes = Encoding.ASCII.GetBytes("30");

        var sub = new Subscription { Sid = Encoding.ASCII.GetBytes("22") };
        var mh = c.MsgHeader(Encoding.ASCII.GetBytes("foo.bar"), Encoding.ASCII.GetBytes("_R_.x"), sub);
        Encoding.ASCII.GetString(mh).ShouldContain("foo.bar 22 _R_.x");
        Encoding.ASCII.GetString(mh).ShouldContain("30");

        var routeTarget = new RouteTarget { Sub = sub, Qs = Encoding.ASCII.GetBytes("q1 q2") };
        var rmh = c.MsgHeaderForRouteOrLeaf(
            Encoding.ASCII.GetBytes("foo.bar"),
            Encoding.ASCII.GetBytes("_R_.x"),
            routeTarget,
            null);
        Encoding.ASCII.GetString(rmh).ShouldContain("foo.bar");
        Encoding.ASCII.GetString(rmh).ShouldContain("q1 q2");
    }

    /// <summary>
    /// With publish denied on <c>&gt;</c> and a tracked reply entry for <c>_R_.x</c>, publishing
    /// to that reply is allowed until the response permission's <c>MaxMsgs</c> is used up.
    /// </summary>
    [Fact]
    public void PubAllowedFullCheck_ShouldHonorResponseReplyCache()
    {
        var c = new ClientConnection(ClientKind.Client)
        {
            Perms = new ClientPermissions
            {
                Resp = new ResponsePermission
                {
                    MaxMsgs = 2,
                    Expires = TimeSpan.FromMinutes(1),
                },
            },
            Replies = new Dictionary<string, RespEntry>(StringComparer.Ordinal)
            {
                ["_R_.x"] = new RespEntry { Time = DateTime.UtcNow, N = 0 },
            },
        };
        c.Perms.Pub.Deny = SubscriptionIndex.NewSublistWithCache();
        c.Perms.Pub.Deny.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(">") });

        c.PubAllowed("_R_.x").ShouldBeTrue();
        c.PubAllowedFullCheck("_R_.x", fullCheck: true, hasLock: true).ShouldBeTrue();
        c.PubAllowedFullCheck("_R_.x", fullCheck: true, hasLock: true).ShouldBeFalse();
    }

    /// <summary>
    /// Covers reserved-reply detection, inbound message bookkeeping, reply subscription
    /// lookup, the static header helpers (generate/get/set/remove), in-message header
    /// replacement and delivery through <c>ProcessMsgResults</c>.
    /// </summary>
    [Fact]
    public void InboundAndHeaderHelpers_GroupB_ShouldBehave()
    {
        ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("_R_.A.B")).ShouldBeTrue();
        ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("$JS.ACK.A.B")).ShouldBeTrue();
        ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("$GNR.A.B")).ShouldBeTrue();
        ClientConnection.IsReservedReply(Encoding.ASCII.GetBytes("foo.bar")).ShouldBeFalse();

        var c = new ClientConnection(ClientKind.Client)
        {
            ParseCtx = { Pa = { HeaderSize = 0 } },
        };

        var before = DateTime.UtcNow;
        c.ProcessInboundMsg(Encoding.ASCII.GetBytes("data"));
        // One millisecond of slack below the captured timestamp.
        c.LastIn.ShouldBeGreaterThan(before - TimeSpan.FromMilliseconds(1));

        c.Subs["sid"] = new Subscription
        {
            Sid = Encoding.ASCII.GetBytes("sid"),
            Subject = Encoding.ASCII.GetBytes("foo"),
        };
        c.SubForReply(Encoding.ASCII.GetBytes("inbox")).ShouldNotBeNull();

        var header = ClientConnection.GenHeader(null, "X-Test", "one");
        Encoding.ASCII.GetString(ClientConnection.GetHeader("X-Test", header)!).ShouldBe("one");
        ClientConnection.GetHeaderKeyIndex("X-Test", header).ShouldBeGreaterThan(0);
        ClientConnection.SliceHeader("X-Test", header).ShouldNotBeNull();

        var replaced = ClientConnection.SetHeaderStatic("X-Test", "two", header);
        Encoding.ASCII.GetString(ClientConnection.GetHeader("X-Test", replaced)!).ShouldBe("two");
        ClientConnection.RemoveHeaderIfPresent(replaced, "X-Test").ShouldBeNull();

        var prefixed = ClientConnection.GenHeader(header, "Nats-Expected-Last-Sequence", "10");
        ClientConnection.RemoveHeaderIfPrefixPresent(prefixed!, "Nats-Expected-").ShouldNotBeNull();

        // Build a message buffer of header bytes followed by a 5-byte "hello" payload.
        c.ParseCtx.Pa.HeaderSize = header.Length;
        var merged = new byte[header.Length + 5];
        Buffer.BlockCopy(header, 0, merged, 0, header.Length);
        Buffer.BlockCopy("hello"u8.ToArray(), 0, merged, header.Length, 5);
        var next = c.SetHeaderInternal("X-Test", "three", merged);
        Encoding.ASCII.GetString(next).ShouldContain("X-Test: three");

        var result = new SubscriptionIndexResult();
        result.PSubs.Add(new Subscription
        {
            Subject = Encoding.ASCII.GetBytes("foo"),
            Sid = Encoding.ASCII.GetBytes("1"),
        });
        c.ProcessMsgResults(
            null,
            result,
            "hello\r\n"u8.ToArray(),
            null,
            Encoding.ASCII.GetBytes("foo"),
            null,
            PmrFlags.None).didDeliver.ShouldBeTrue();
    }

    /// <summary>
    /// Covers the first-ping timer, stale-connection watching, account swap after reload,
    /// subscription re-evaluation on config reload, the per-account result cache,
    /// client info, connection-type conversion and rate-limited warnings.
    /// </summary>
    [Fact]
    public void LifecycleAndTlsHelpers_GroupC_ShouldBehave()
    {
        var logger = new CaptureLogger();
        var (server, err) = NatsServer.NewServer(new ServerOptions
        {
            PingInterval = TimeSpan.FromMilliseconds(120),
        });
        err.ShouldBeNull();
        server.SetLogger(logger, debugFlag: true, traceFlag: true);

        using var ms = new MemoryStream();
        var c = new ClientConnection(ClientKind.Client, server, ms)
        {
            Cid = 42,
            Host = "127.0.0.1",
            Start = DateTime.UtcNow.AddSeconds(-2),
            Rtt = TimeSpan.FromMilliseconds(5),
        };

        c.SetFirstPingTimer();
        GetTimer(c, "_pingTimer").ShouldNotBeNull();

        c.WatchForStaleConnection(TimeSpan.FromMilliseconds(20), pingMax: 0);
        // Poll for up to 500 ms for the connection to be closed.
        var staleDeadline = DateTime.UtcNow.AddMilliseconds(500);
        while (!c.IsClosed() && DateTime.UtcNow < staleDeadline)
        {
            Thread.Sleep(10);
        }
        c.IsClosed().ShouldBeTrue();

        // Attach a throwaway account named "A", then expect the swap to pick up the
        // account instance the server has registered under the same name.
        var temp = Account.NewAccount("A");
        temp.Sublist = SubscriptionIndex.NewSublistWithCache();
        c.SetAccount(temp);

        var registered = server.LookupOrRegisterAccount("A").Account;
        registered.Sublist = SubscriptionIndex.NewSublistWithCache();
        var inserted = new Subscription
        {
            Subject = Encoding.ASCII.GetBytes("foo.bar"),
            Sid = Encoding.ASCII.GetBytes("11"),
        };
        registered.Sublist.Insert(inserted).ShouldBeNull();

        c.SwapAccountAfterReload();
        c.GetAccount().ShouldBe(registered);

        // Deny every subject, then expect the reload pass to drop the existing subscription.
        c.Perms = new ClientPermissions();
        c.Perms.Sub.Deny = SubscriptionIndex.NewSublistWithCache();
        c.Perms.Sub.Deny.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(">") }).ShouldBeNull();
        c.Subs["22"] = new Subscription
        {
            Subject = Encoding.ASCII.GetBytes("foo.bar"),
            Sid = Encoding.ASCII.GetBytes("22"),
        };
        c.ProcessSubsOnConfigReload(new HashSet<string>(StringComparer.Ordinal) { registered.Name });
        c.Subs.ContainsKey("22").ShouldBeFalse();

        c.ParseCtx.Pa.Account = Encoding.ASCII.GetBytes("A");
        c.ParseCtx.Pa.Subject = Encoding.ASCII.GetBytes("foo.bar");
        c.ParseCtx.Pa.PaCache = Encoding.ASCII.GetBytes("A:foo.bar");
        var cached = c.GetAccAndResultFromCache();
        cached.Account.ShouldBe(registered);
        cached.Result.ShouldNotBeNull();
        cached.Result.PSubs.Count.ShouldBeGreaterThan(0);

        var closedSub = new Subscription { Subject = Encoding.ASCII.GetBytes("foo.closed") };
        closedSub.Close();

        // Seed the private read-cache state via reflection with an entry whose only
        // subscription is closed. The state is written back with SetValue after mutation
        // and re-read after the prune call.
        var inField = typeof(ClientConnection).GetField("_in", BindingFlags.Instance | BindingFlags.NonPublic)!;
        var state = (ReadCacheState)inField.GetValue(c)!;
        state.PaCache = new Dictionary<string, PerAccountCache>(StringComparer.Ordinal)
        {
            ["closed"] = new PerAccountCache
            {
                Acc = registered,
                Results = new SubscriptionIndexResult
                {
                    PSubs = { closedSub },
                },
                GenId = 1,
            },
        };
        inField.SetValue(c, state);

        c.PruneClosedSubFromPerAccountCache();

        state = (ReadCacheState)inField.GetValue(c)!;
        state.PaCache.ShouldNotBeNull();
        state.PaCache.Count.ShouldBe(0);

        var info = c.GetClientInfo(detailed: true);
        info.ShouldNotBeNull();
        info!.Account.ShouldBe("A");
        info.Server.ShouldNotBeNullOrWhiteSpace();
        info.ServiceAccount().ShouldBe("A");

        var (allowed, convertErr) = ClientConnection.ConvertAllowedConnectionTypes(
            ["standard", "mqtt", "bad"]);
        allowed.ShouldContain(AuthHandler.ConnectionTypes.Standard);
        allowed.ShouldContain(AuthHandler.ConnectionTypes.Mqtt);
        convertErr.ShouldNotBeNull();

        // The same warning issued twice is expected to reach the logger once.
        c.RateLimitWarnf("warn {0}", 1);
        c.RateLimitWarnf("warn {0}", 1);
        logger.Warnings.Count.ShouldBe(1);
    }

    /// <summary>
    /// Reads a private instance field of <see cref="ClientConnection"/> by name via
    /// reflection and casts it to <see cref="Timer"/>.
    /// </summary>
    /// <param name="c">Connection whose field is read.</param>
    /// <param name="field">Name of the non-public instance field.</param>
    /// <returns>The field value, or <c>null</c> when the field currently holds no timer.</returns>
    private static Timer? GetTimer(ClientConnection c, string field)
    {
        return (Timer?)typeof(ClientConnection)
            .GetField(field, BindingFlags.Instance | BindingFlags.NonPublic)!
            .GetValue(c);
    }

    /// <summary>
    /// Logger that records formatted warning messages and ignores every other level.
    /// </summary>
    private sealed class CaptureLogger : INatsLogger
    {
        /// <summary>Formatted warning messages in the order they were received.</summary>
        public List<string> Warnings { get; } = [];

        public void Noticef(string format, params object[] args)
        {
        }

        public void Warnf(string format, params object[] args) => Warnings.Add(string.Format(format, args));

        public void Fatalf(string format, params object[] args)
        {
        }

        public void Errorf(string format, params object[] args)
        {
        }

        public void Debugf(string format, params object[] args)
        {
        }

        public void Tracef(string format, params object[] args)
        {
        }
    }
}