feat: enforce MaxConnections limit in accept loop
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
using System.Collections.Concurrent;
|
using System.Collections.Concurrent;
|
||||||
using System.Net;
|
using System.Net;
|
||||||
using System.Net.Sockets;
|
using System.Net.Sockets;
|
||||||
|
using System.Text;
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using NATS.Server.Protocol;
|
using NATS.Server.Protocol;
|
||||||
using NATS.Server.Subscriptions;
|
using NATS.Server.Subscriptions;
|
||||||
@@ -56,6 +57,32 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
|||||||
while (!ct.IsCancellationRequested)
|
while (!ct.IsCancellationRequested)
|
||||||
{
|
{
|
||||||
var socket = await _listener.AcceptAsync(ct);
|
var socket = await _listener.AcceptAsync(ct);
|
||||||
|
|
||||||
|
// Check MaxConnections before creating the client
|
||||||
|
if (_options.MaxConnections > 0 && _clients.Count >= _options.MaxConnections)
|
||||||
|
{
|
||||||
|
_logger.LogWarning("Client connection rejected: maximum connections ({MaxConnections}) exceeded",
|
||||||
|
_options.MaxConnections);
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var stream = new NetworkStream(socket, ownsSocket: false);
|
||||||
|
var errBytes = Encoding.ASCII.GetBytes(
|
||||||
|
$"-ERR '{NatsProtocol.ErrMaxConnectionsExceeded}'\r\n");
|
||||||
|
await stream.WriteAsync(errBytes, ct);
|
||||||
|
await stream.FlushAsync(ct);
|
||||||
|
stream.Dispose();
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
_logger.LogDebug(ex, "Failed to send -ERR to rejected client");
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
socket.Dispose();
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
var clientId = Interlocked.Increment(ref _nextClientId);
|
var clientId = Interlocked.Increment(ref _nextClientId);
|
||||||
|
|
||||||
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
|
_logger.LogDebug("Client {ClientId} connected from {RemoteEndpoint}", clientId, socket.RemoteEndPoint);
|
||||||
|
|||||||
@@ -124,3 +124,69 @@ public class ServerTests : IAsyncLifetime
|
|||||||
msg.ShouldContain("MSG foo.bar 1 5\r\n");
|
msg.ShouldContain("MSG foo.bar 1 5\r\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class MaxConnectionsTests : IAsyncLifetime
|
||||||
|
{
|
||||||
|
private readonly NatsServer _server;
|
||||||
|
private readonly int _port;
|
||||||
|
private readonly CancellationTokenSource _cts = new();
|
||||||
|
|
||||||
|
public MaxConnectionsTests()
|
||||||
|
{
|
||||||
|
_port = GetFreePort();
|
||||||
|
_server = new NatsServer(new NatsOptions { Port = _port, MaxConnections = 2 }, NullLoggerFactory.Instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task InitializeAsync()
|
||||||
|
{
|
||||||
|
_ = _server.StartAsync(_cts.Token);
|
||||||
|
await _server.WaitForReadyAsync();
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task DisposeAsync()
|
||||||
|
{
|
||||||
|
await _cts.CancelAsync();
|
||||||
|
_server.Dispose();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int GetFreePort()
|
||||||
|
{
|
||||||
|
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||||
|
sock.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||||
|
return ((IPEndPoint)sock.LocalEndPoint!).Port;
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Server_rejects_connection_when_max_reached()
|
||||||
|
{
|
||||||
|
using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||||
|
|
||||||
|
// Connect two clients (at limit)
|
||||||
|
var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||||
|
await client1.ConnectAsync(IPAddress.Loopback, _port);
|
||||||
|
var buf = new byte[4096];
|
||||||
|
var n = await client1.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
|
||||||
|
Encoding.ASCII.GetString(buf, 0, n).ShouldStartWith("INFO ");
|
||||||
|
|
||||||
|
var client2 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||||
|
await client2.ConnectAsync(IPAddress.Loopback, _port);
|
||||||
|
n = await client2.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
|
||||||
|
Encoding.ASCII.GetString(buf, 0, n).ShouldStartWith("INFO ");
|
||||||
|
|
||||||
|
// Third client should be rejected
|
||||||
|
var client3 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||||
|
await client3.ConnectAsync(IPAddress.Loopback, _port);
|
||||||
|
|
||||||
|
n = await client3.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
|
||||||
|
var response = Encoding.ASCII.GetString(buf, 0, n);
|
||||||
|
response.ShouldContain("-ERR 'maximum connections exceeded'");
|
||||||
|
|
||||||
|
// Connection should be closed
|
||||||
|
n = await client3.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
|
||||||
|
n.ShouldBe(0);
|
||||||
|
|
||||||
|
client1.Dispose();
|
||||||
|
client2.Dispose();
|
||||||
|
client3.Dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user