diff --git a/src/NATS.Server/Subscriptions/SubList.cs b/src/NATS.Server/Subscriptions/SubList.cs index 0056734..9a83f0f 100644 --- a/src/NATS.Server/Subscriptions/SubList.cs +++ b/src/NATS.Server/Subscriptions/SubList.cs @@ -113,25 +113,35 @@ public sealed class SubList : IDisposable try { var key = $"{sub.RouteId}|{sub.Account}|{sub.Subject}|{sub.Queue}"; + var changed = false; if (sub.IsRemoval) { - _remoteSubs.Remove(key); - InterestChanged?.Invoke(new InterestChange( - InterestChangeKind.RemoteRemoved, - sub.Subject, - sub.Queue, - sub.Account)); + changed = _remoteSubs.Remove(key); + if (changed) + { + InterestChanged?.Invoke(new InterestChange( + InterestChangeKind.RemoteRemoved, + sub.Subject, + sub.Queue, + sub.Account)); + } } else { - _remoteSubs[key] = sub; - InterestChanged?.Invoke(new InterestChange( - InterestChangeKind.RemoteAdded, - sub.Subject, - sub.Queue, - sub.Account)); + if (!_remoteSubs.TryGetValue(key, out var existing) || existing != sub) + { + _remoteSubs[key] = sub; + changed = true; + InterestChanged?.Invoke(new InterestChange( + InterestChangeKind.RemoteAdded, + sub.Subject, + sub.Queue, + sub.Account)); + } } - Interlocked.Increment(ref _generation); + + if (changed) + Interlocked.Increment(ref _generation); } finally { diff --git a/tests/NATS.Server.Tests/Gateways/GatewayInterestIdempotencyTests.cs b/tests/NATS.Server.Tests/Gateways/GatewayInterestIdempotencyTests.cs new file mode 100644 index 0000000..dbe9150 --- /dev/null +++ b/tests/NATS.Server.Tests/Gateways/GatewayInterestIdempotencyTests.cs @@ -0,0 +1,87 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using NATS.Server.Gateways; +using NATS.Server.Subscriptions; + +namespace NATS.Server.Tests.Gateways; + +public class GatewayInterestIdempotencyTests +{ + [Fact] + public async Task Duplicate_RSplus_or_reconnect_replay_does_not_double_count_remote_interest() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + using var remoteSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await remoteSocket.ConnectAsync(IPAddress.Loopback, port); + using var gatewaySocket = await listener.AcceptSocketAsync(); + await using var gateway = new GatewayConnection(gatewaySocket); + using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + var handshakeTask = gateway.PerformOutboundHandshakeAsync("LOCAL", timeout.Token); + (await ReadLineAsync(remoteSocket, timeout.Token)).ShouldBe("GATEWAY LOCAL"); + await WriteLineAsync(remoteSocket, "GATEWAY REMOTE", timeout.Token); + await handshakeTask; + + using var subList = new SubList(); + var remoteAdded = 0; + subList.InterestChanged += change => + { + if (change.Kind == InterestChangeKind.RemoteAdded) + remoteAdded++; + }; + + gateway.RemoteSubscriptionReceived = sub => + { + subList.ApplyRemoteSub(sub); + return Task.CompletedTask; + }; + gateway.StartLoop(timeout.Token); + + await WriteLineAsync(remoteSocket, "A+ A orders.*", timeout.Token); + await WaitForAsync(() => subList.HasRemoteInterest("A", "orders.created"), timeout.Token); + + await WriteLineAsync(remoteSocket, "A+ A orders.*", timeout.Token); + await Task.Delay(100, timeout.Token); + + subList.MatchRemote("A", "orders.created").Count.ShouldBe(1); + remoteAdded.ShouldBe(1); + } + + private static async Task ReadLineAsync(Socket socket, CancellationToken ct) + { + var bytes = new List(64); + var single = new byte[1]; + while (true) + { + var read = await socket.ReceiveAsync(single, SocketFlags.None, ct); + if (read == 0) + break; + if (single[0] == (byte)'\n') + break; + if (single[0] != (byte)'\r') + bytes.Add(single[0]); + } + + return Encoding.ASCII.GetString([.. bytes]); + } + + private static Task WriteLineAsync(Socket socket, string line, CancellationToken ct) + => socket.SendAsync(Encoding.ASCII.GetBytes($"{line}\r\n"), SocketFlags.None, ct).AsTask(); + + private static async Task WaitForAsync(Func predicate, CancellationToken ct) + { + while (!ct.IsCancellationRequested) + { + if (predicate()) + return; + + await Task.Delay(20, ct); + } + + throw new TimeoutException("Timed out waiting for condition."); + } +} diff --git a/tests/NATS.Server.Tests/LeafNodes/LeafInterestIdempotencyTests.cs b/tests/NATS.Server.Tests/LeafNodes/LeafInterestIdempotencyTests.cs new file mode 100644 index 0000000..13e27f4 --- /dev/null +++ b/tests/NATS.Server.Tests/LeafNodes/LeafInterestIdempotencyTests.cs @@ -0,0 +1,87 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using NATS.Server.LeafNodes; +using NATS.Server.Subscriptions; + +namespace NATS.Server.Tests.LeafNodes; + +public class LeafInterestIdempotencyTests +{ + [Fact] + public async Task Duplicate_RSplus_or_reconnect_replay_does_not_double_count_remote_interest() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + using var remoteSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await remoteSocket.ConnectAsync(IPAddress.Loopback, port); + using var leafSocket = await listener.AcceptSocketAsync(); + await using var leaf = new LeafConnection(leafSocket); + using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + var handshakeTask = leaf.PerformOutboundHandshakeAsync("LOCAL", timeout.Token); + (await ReadLineAsync(remoteSocket, timeout.Token)).ShouldBe("LEAF LOCAL"); + await WriteLineAsync(remoteSocket, "LEAF REMOTE", timeout.Token); + await handshakeTask; + + using var subList = new SubList(); + var remoteAdded = 0; + subList.InterestChanged += change => + { + if (change.Kind == InterestChangeKind.RemoteAdded) + remoteAdded++; + }; + + leaf.RemoteSubscriptionReceived = sub => + { + subList.ApplyRemoteSub(sub); + return Task.CompletedTask; + }; + leaf.StartLoop(timeout.Token); + + await WriteLineAsync(remoteSocket, "LS+ A orders.*", timeout.Token); + await WaitForAsync(() => subList.HasRemoteInterest("A", "orders.created"), timeout.Token); + + await WriteLineAsync(remoteSocket, "LS+ A orders.*", timeout.Token); + await Task.Delay(100, timeout.Token); + + subList.MatchRemote("A", "orders.created").Count.ShouldBe(1); + remoteAdded.ShouldBe(1); + } + + private static async Task ReadLineAsync(Socket socket, CancellationToken ct) + { + var bytes = new List(64); + var single = new byte[1]; + while (true) + { + var read = await socket.ReceiveAsync(single, SocketFlags.None, ct); + if (read == 0) + break; + if (single[0] == (byte)'\n') + break; + if (single[0] != (byte)'\r') + bytes.Add(single[0]); + } + + return Encoding.ASCII.GetString([.. bytes]); + } + + private static Task WriteLineAsync(Socket socket, string line, CancellationToken ct) + => socket.SendAsync(Encoding.ASCII.GetBytes($"{line}\r\n"), SocketFlags.None, ct).AsTask(); + + private static async Task WaitForAsync(Func predicate, CancellationToken ct) + { + while (!ct.IsCancellationRequested) + { + if (predicate()) + return; + + await Task.Delay(20, ct); + } + + throw new TimeoutException("Timed out waiting for condition."); + } +} diff --git a/tests/NATS.Server.Tests/Routes/RouteInterestIdempotencyTests.cs b/tests/NATS.Server.Tests/Routes/RouteInterestIdempotencyTests.cs new file mode 100644 index 0000000..c596cab --- /dev/null +++ b/tests/NATS.Server.Tests/Routes/RouteInterestIdempotencyTests.cs @@ -0,0 +1,87 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using NATS.Server.Routes; +using NATS.Server.Subscriptions; + +namespace NATS.Server.Tests.Routes; + +public class RouteInterestIdempotencyTests +{ + [Fact] + public async Task Duplicate_RSplus_or_reconnect_replay_does_not_double_count_remote_interest() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + using var remoteSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await remoteSocket.ConnectAsync(IPAddress.Loopback, port); + using var routeSocket = await listener.AcceptSocketAsync(); + await using var route = new RouteConnection(routeSocket); + using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + var handshakeTask = route.PerformOutboundHandshakeAsync("LOCAL", timeout.Token); + (await ReadLineAsync(remoteSocket, timeout.Token)).ShouldBe("ROUTE LOCAL"); + await WriteLineAsync(remoteSocket, "ROUTE REMOTE", timeout.Token); + await handshakeTask; + + using var subList = new SubList(); + var remoteAdded = 0; + subList.InterestChanged += change => + { + if (change.Kind == InterestChangeKind.RemoteAdded) + remoteAdded++; + }; + + route.RemoteSubscriptionReceived = sub => + { + subList.ApplyRemoteSub(sub); + return Task.CompletedTask; + }; + route.StartFrameLoop(timeout.Token); + + await WriteLineAsync(remoteSocket, "RS+ A orders.*", timeout.Token); + await WaitForAsync(() => subList.HasRemoteInterest("A", "orders.created"), timeout.Token); + + await WriteLineAsync(remoteSocket, "RS+ A orders.*", timeout.Token); + await Task.Delay(100, timeout.Token); + + subList.MatchRemote("A", "orders.created").Count.ShouldBe(1); + remoteAdded.ShouldBe(1); + } + + private static async Task ReadLineAsync(Socket socket, CancellationToken ct) + { + var bytes = new List(64); + var single = new byte[1]; + while (true) + { + var read = await socket.ReceiveAsync(single, SocketFlags.None, ct); + if (read == 0) + break; + if (single[0] == (byte)'\n') + break; + if (single[0] != (byte)'\r') + bytes.Add(single[0]); + } + + return Encoding.ASCII.GetString([.. bytes]); + } + + private static Task WriteLineAsync(Socket socket, string line, CancellationToken ct) + => socket.SendAsync(Encoding.ASCII.GetBytes($"{line}\r\n"), SocketFlags.None, ct).AsTask(); + + private static async Task WaitForAsync(Func predicate, CancellationToken ct) + { + while (!ct.IsCancellationRequested) + { + if (predicate()) + return; + + await Task.Delay(20, ct); + } + + throw new TimeoutException("Timed out waiting for condition."); + } +}