From 0c12b0f6e3165b1f952ace1ebd5e5ac6e5355e36 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 21:44:18 -0500 Subject: [PATCH] feat: enforce MaxConnections limit in accept loop --- src/NATS.Server/NatsServer.cs | 27 +++++++++++ tests/NATS.Server.Tests/ServerTests.cs | 66 ++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 7e63146..5900aef 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using System.Net; using System.Net.Sockets; +using System.Text; using Microsoft.Extensions.Logging; using NATS.Server.Protocol; using NATS.Server.Subscriptions; @@ -56,6 +57,32 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable while (!ct.IsCancellationRequested) { var socket = await _listener.AcceptAsync(ct); + + // Check MaxConnections before creating the client + if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections) + { + _logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded", + _options.MaxConnections); + try + { + var stream = new NetworkStream(socket, ownsSocket: false); + var errBytes = Encoding.ASCII.GetBytes( + $"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n"); + await stream.WriteAsync(errBytes, ct); + await stream.FlushAsync(ct); + stream.Dispose(); + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to send -ERR to rejected client"); + } + finally + { + socket.Dispose(); + } + continue; + } + var clientId = Interlocked.Increment(ref _nextClientId); _logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint); diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs index dd0f7ca..e80bf7c 100644 --- a/tests/NATS.Server.Tests/ServerTests.cs +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -124,3 +124,69 @@ public class ServerTests : IAsyncLifetime msg.ShouldContain("MSG foo.bar 1 5\r\n"); } } + +public class MaxConnectionsTests : IAsyncLifetime +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + + public MaxConnectionsTests() + { + _port = GetFreePort(); + _server = new NatsServer(new NatsOptions { Port = _port, MaxConnections = 2 }, NullLoggerFactory.Instance); + } + + public async Task InitializeAsync() + { + _ = _server.StartAsync(_cts.Token); + await _server.WaitForReadyAsync(); + } + + public async Task DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + [Fact] + public async Task Server_rejects_connection_when_max_reached() + { + using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + // Connect two clients (at limit) + var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client1.ConnectAsync(IPAddress.Loopback, _port); + var buf = new byte[4096]; + var n = await client1.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + Encoding.ASCII.GetString(buf, 0, n).ShouldStartWith("INFO "); + + var client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client2.ConnectAsync(IPAddress.Loopback, _port); + n = await client2.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + Encoding.ASCII.GetString(buf, 0, n).ShouldStartWith("INFO "); + + // Third client should be rejected + var client3 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client3.ConnectAsync(IPAddress.Loopback, _port); + + n = await client3.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + var response = Encoding.ASCII.GetString(buf, 0, n); + response.ShouldContain("-ERR 'maximum connections exceeded'"); + + // Connection should be closed + n = await client3.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + n.ShouldBe(0); + + client1.Dispose(); + client2.Dispose(); + client3.Dispose(); + } +}