Files
natsdotnet/tests/NATS.Server.Tests/ClientTests.cs
Joseph Doherty 2980a343c1 feat: integrate authentication into server accept loop and client CONNECT processing
Wire AuthService into NatsServer and NatsClient to enforce authentication
on incoming connections. The server builds an AuthService from NatsOptions,
sets auth_required in ServerInfo, and generates per-client nonces when
NKey auth is configured. NatsClient validates credentials in ProcessConnect,
enforces publish/subscribe permissions, and implements an auth timeout that
closes connections that don't send CONNECT in time. Existing tests without
auth continue to work since AuthService.IsAuthRequired is false by default.
2026-02-22 22:55:50 -05:00

136 lines
4.3 KiB
C#

using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Server;
using NATS.Server.Auth;
using NATS.Server.Protocol;
namespace NATS.Server.Tests;
/// <summary>
/// Tests for <see cref="NatsClient"/> driven over a real loopback TCP socket pair:
/// the accepted ("server") end is handed to the <see cref="NatsClient"/> under test,
/// and each test plays the remote NATS client on the other end.
/// </summary>
public class ClientTests : IAsyncDisposable
{
    // Upper bound on every socket read so a protocol regression fails the test
    // instead of hanging the whole test run.
    private static readonly TimeSpan ReadTimeout = TimeSpan.FromSeconds(5);

    private readonly Socket _serverSocket;
    private readonly Socket _clientSocket;
    private readonly NatsClient _natsClient;
    private readonly CancellationTokenSource _cts = new();

    public ClientTests()
    {
        // Create a connected socket pair via loopback. The listener is only needed until
        // Accept returns; `using` also releases it if Connect or Accept throws.
        using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        listener.Listen(1);
        var port = ((IPEndPoint)listener.LocalEndPoint!).Port;

        _clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        _clientSocket.Connect(IPAddress.Loopback, port);
        _serverSocket = listener.Accept();

        var serverInfo = new ServerInfo
        {
            ServerId = "test",
            ServerName = "test",
            Version = "0.1.0",
            Host = "127.0.0.1",
            Port = 4222,
        };

        // Build the auth service from the same options instance the client receives,
        // so both observe identical configuration (defaults: no auth configured).
        var options = new NatsOptions();
        var authService = AuthService.Build(options);
        _natsClient = new NatsClient(1, _serverSocket, options, serverInfo, authService, null, NullLogger.Instance);
    }

    public async ValueTask DisposeAsync()
    {
        await _cts.CancelAsync();
        _natsClient.Dispose();
        _clientSocket.Dispose();
        // Socket.Dispose is idempotent, so this is safe even if NatsClient already
        // closed the accepted socket; it guarantees the handle is released either way.
        _serverSocket.Dispose();
    }

    /// <summary>
    /// Reads from the client end of the socket pair until the accumulated ASCII text
    /// contains <paramref name="terminator"/> or the peer closes the connection
    /// (ReceiveAsync returns 0). TCP may deliver one protocol line across several
    /// segments, so a single ReceiveAsync is not guaranteed to return a whole line.
    /// </summary>
    /// <remarks>
    /// Bytes received after the terminator in the same segment are included in the
    /// returned text rather than buffered for a later call; callers only use this when
    /// the server has nothing further queued.
    /// </remarks>
    /// <exception cref="OperationCanceledException">
    /// Thrown when <paramref name="ct"/> is cancelled before the terminator arrives.
    /// </exception>
    private async Task<string> ReadUntilAsync(string terminator, CancellationToken ct)
    {
        var buf = new byte[4096];
        var text = new StringBuilder();
        while (true)
        {
            var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, ct);
            if (n == 0)
            {
                // Peer closed; return whatever arrived so the assertion reports it.
                break;
            }

            text.Append(Encoding.ASCII.GetString(buf, 0, n));
            if (text.ToString().Contains(terminator, StringComparison.Ordinal))
            {
                break;
            }
        }

        return text.ToString();
    }

    [Fact]
    public async Task Client_sends_INFO_on_start()
    {
        _ = _natsClient.RunAsync(_cts.Token);
        using var readCts = new CancellationTokenSource(ReadTimeout);

        // Read from client socket — should get a complete INFO line
        var response = await ReadUntilAsync("\r\n", readCts.Token);

        response.ShouldStartWith("INFO ");
        response.ShouldContain("server_id");
        response.ShouldContain("\r\n");
        await _cts.CancelAsync();
    }

    [Fact]
    public async Task Client_responds_PONG_to_PING()
    {
        _ = _natsClient.RunAsync(_cts.Token);
        using var readCts = new CancellationTokenSource(ReadTimeout);

        // Read INFO (the server sends nothing else before CONNECT)
        await ReadUntilAsync("\r\n", readCts.Token);

        // Send CONNECT then PING
        await _clientSocket.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n"));

        // Read response — should get PONG
        var response = await ReadUntilAsync("PONG\r\n", readCts.Token);
        response.ShouldContain("PONG\r\n");
        await _cts.CancelAsync();
    }

    [Fact]
    public async Task Client_SendErrAsync_writes_correct_wire_format()
    {
        _ = _natsClient.RunAsync(_cts.Token);
        using var readCts = new CancellationTokenSource(ReadTimeout);

        // Read INFO first
        await ReadUntilAsync("\r\n", readCts.Token);

        // Trigger SendErrAsync
        await _natsClient.SendErrAsync("Invalid Subject");

        var response = await ReadUntilAsync("\r\n", readCts.Token);
        response.ShouldBe("-ERR 'Invalid Subject'\r\n");
        await _cts.CancelAsync();
    }

    [Fact]
    public async Task Client_SendErrAndCloseAsync_sends_error_then_disconnects()
    {
        _ = _natsClient.RunAsync(_cts.Token);
        using var readCts = new CancellationTokenSource(ReadTimeout);

        // Read INFO first
        await ReadUntilAsync("\r\n", readCts.Token);

        // Trigger SendErrAndCloseAsync
        await _natsClient.SendErrAndCloseAsync("maximum connections exceeded");

        var response = await ReadUntilAsync("\r\n", readCts.Token);
        response.ShouldBe("-ERR 'maximum connections exceeded'\r\n");

        // Connection should be closed — next read returns 0
        var buf = new byte[4096];
        var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None, readCts.Token);
        n.ShouldBe(0);
    }
}