From 8db2de37cd183eb4d772acc247974afe0de0c2de Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 22 Feb 2026 20:24:35 -0500 Subject: [PATCH] feat: implement NatsClient connection handler with read/write pipeline --- src/NATS.Server/NatsClient.cs | 284 ++++++++++++++++++ src/NATS.Server/Subscriptions/Subscription.cs | 3 + tests/NATS.Server.Tests/ClientTests.cs | 87 ++++++ 3 files changed, 374 insertions(+) create mode 100644 src/NATS.Server/NatsClient.cs create mode 100644 tests/NATS.Server.Tests/ClientTests.cs diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs new file mode 100644 index 0000000..9dfb3db --- /dev/null +++ b/src/NATS.Server/NatsClient.cs @@ -0,0 +1,284 @@ +using System.Buffers; +using System.IO.Pipelines; +using System.Net.Sockets; +using System.Text; +using System.Text.Json; +using NATS.Server.Protocol; +using NATS.Server.Subscriptions; + +namespace NATS.Server; + +public interface IMessageRouter +{ + void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory headers, + ReadOnlyMemory payload, NatsClient sender); + void RemoveClient(NatsClient client); +} + +public interface ISubListAccess +{ + SubList SubList { get; } +} + +public sealed class NatsClient : IDisposable +{ + private readonly Socket _socket; + private readonly NetworkStream _stream; + private readonly NatsOptions _options; + private readonly ServerInfo _serverInfo; + private readonly NatsParser _parser; + private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly Dictionary _subs = new(); + + public ulong Id { get; } + public ClientOptions? ClientOpts { get; private set; } + public IMessageRouter? Router { get; set; } + public bool ConnectReceived { get; private set; } + + // Stats + public long InMsgs; + public long OutMsgs; + public long InBytes; + public long OutBytes; + + public IReadOnlyDictionary Subscriptions => _subs; + + public NatsClient(ulong id, Socket socket, NatsOptions options, ServerInfo serverInfo) + { + Id = id; + _socket = socket; + _stream = new NetworkStream(socket, ownsSocket: false); + _options = options; + _serverInfo = serverInfo; + _parser = new NatsParser(options.MaxPayload); + } + + public async Task RunAsync(CancellationToken ct) + { + var pipe = new Pipe(); + try + { + // Send INFO + await SendInfoAsync(ct); + + // Start read pump and command processing in parallel + var fillTask = FillPipeAsync(pipe.Writer, ct); + var processTask = ProcessCommandsAsync(pipe.Reader, ct); + + await Task.WhenAny(fillTask, processTask); + } + catch (OperationCanceledException) { } + catch (Exception) { /* connection error -- clean up */ } + finally + { + Router?.RemoveClient(this); + } + } + + private async Task FillPipeAsync(PipeWriter writer, CancellationToken ct) + { + try + { + while (!ct.IsCancellationRequested) + { + var memory = writer.GetMemory(4096); + int bytesRead = await _stream.ReadAsync(memory, ct); + if (bytesRead == 0) + break; + + writer.Advance(bytesRead); + var result = await writer.FlushAsync(ct); + if (result.IsCompleted) + break; + } + } + finally + { + await writer.CompleteAsync(); + } + } + + private async Task ProcessCommandsAsync(PipeReader reader, CancellationToken ct) + { + try + { + while (!ct.IsCancellationRequested) + { + var result = await reader.ReadAsync(ct); + var buffer = result.Buffer; + + while (_parser.TryParse(ref buffer, out var cmd)) + { + await DispatchCommandAsync(cmd, ct); + } + + reader.AdvanceTo(buffer.Start, buffer.End); + + if (result.IsCompleted) + break; + } + } + finally + { + await reader.CompleteAsync(); + } + } + + private async ValueTask DispatchCommandAsync(ParsedCommand cmd, CancellationToken ct) + { + switch (cmd.Type) + { + case CommandType.Connect: + ProcessConnect(cmd); + break; + + case CommandType.Ping: + await WriteAsync(NatsProtocol.PongBytes, ct); + break; + + case CommandType.Pong: + // Update RTT tracking (placeholder) + break; + + case CommandType.Sub: + ProcessSub(cmd); + break; + + case CommandType.Unsub: + ProcessUnsub(cmd); + break; + + case CommandType.Pub: + case CommandType.HPub: + ProcessPub(cmd); + break; + } + } + + private void ProcessConnect(ParsedCommand cmd) + { + ClientOpts = JsonSerializer.Deserialize(cmd.Payload.Span) + ?? new ClientOptions(); + ConnectReceived = true; + } + + private void ProcessSub(ParsedCommand cmd) + { + var sub = new Subscription + { + Subject = cmd.Subject!, + Queue = cmd.Queue, + Sid = cmd.Sid!, + }; + + _subs[cmd.Sid!] = sub; + sub.Client = this; + + if (Router is ISubListAccess sl) + sl.SubList.Insert(sub); + } + + private void ProcessUnsub(ParsedCommand cmd) + { + if (!_subs.TryGetValue(cmd.Sid!, out var sub)) + return; + + if (cmd.MaxMessages > 0) + { + sub.MaxMessages = cmd.MaxMessages; + // Will be cleaned up when MessageCount reaches MaxMessages + return; + } + + _subs.Remove(cmd.Sid!); + + if (Router is ISubListAccess sl) + sl.SubList.Remove(sub); + } + + private void ProcessPub(ParsedCommand cmd) + { + Interlocked.Increment(ref InMsgs); + Interlocked.Add(ref InBytes, cmd.Payload.Length); + + ReadOnlyMemory headers = default; + ReadOnlyMemory payload = cmd.Payload; + + if (cmd.Type == CommandType.HPub && cmd.HeaderSize > 0) + { + headers = cmd.Payload[..cmd.HeaderSize]; + payload = cmd.Payload[cmd.HeaderSize..]; + } + + Router?.ProcessMessage(cmd.Subject!, cmd.ReplyTo, headers, payload, this); + } + + private async Task SendInfoAsync(CancellationToken ct) + { + var infoJson = JsonSerializer.Serialize(_serverInfo); + var infoLine = Encoding.ASCII.GetBytes($"INFO {infoJson}\r\n"); + await WriteAsync(infoLine, ct); + } + + public async Task SendMessageAsync(string subject, string sid, string? replyTo, + ReadOnlyMemory headers, ReadOnlyMemory payload, CancellationToken ct) + { + Interlocked.Increment(ref OutMsgs); + Interlocked.Add(ref OutBytes, payload.Length + headers.Length); + + byte[] line; + if (headers.Length > 0) + { + int totalSize = headers.Length + payload.Length; + line = Encoding.ASCII.GetBytes($"HMSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{headers.Length} {totalSize}\r\n"); + } + else + { + line = Encoding.ASCII.GetBytes($"MSG {subject} {sid} {(replyTo != null ? replyTo + " " : "")}{payload.Length}\r\n"); + } + + await _writeLock.WaitAsync(ct); + try + { + await _stream.WriteAsync(line, ct); + if (headers.Length > 0) + await _stream.WriteAsync(headers, ct); + if (payload.Length > 0) + await _stream.WriteAsync(payload, ct); + await _stream.WriteAsync(NatsProtocol.CrLf, ct); + await _stream.FlushAsync(ct); + } + finally + { + _writeLock.Release(); + } + } + + private async Task WriteAsync(byte[] data, CancellationToken ct) + { + await _writeLock.WaitAsync(ct); + try + { + await _stream.WriteAsync(data, ct); + await _stream.FlushAsync(ct); + } + finally + { + _writeLock.Release(); + } + } + + public void RemoveAllSubscriptions(SubList subList) + { + foreach (var sub in _subs.Values) + subList.Remove(sub); + _subs.Clear(); + } + + public void Dispose() + { + _stream.Dispose(); + _socket.Dispose(); + _writeLock.Dispose(); + } +} diff --git a/src/NATS.Server/Subscriptions/Subscription.cs b/src/NATS.Server/Subscriptions/Subscription.cs index b77e0ad..d96095b 100644 --- a/src/NATS.Server/Subscriptions/Subscription.cs +++ b/src/NATS.Server/Subscriptions/Subscription.cs @@ -1,3 +1,5 @@ +using NATS.Server; + namespace NATS.Server.Subscriptions; public sealed class Subscription @@ -7,4 +9,5 @@ public sealed class Subscription public required string Sid { get; init; } public long MessageCount; // Interlocked public long MaxMessages; // 0 = unlimited + public NatsClient? Client { get; set; } } diff --git a/tests/NATS.Server.Tests/ClientTests.cs b/tests/NATS.Server.Tests/ClientTests.cs new file mode 100644 index 0000000..67a1a08 --- /dev/null +++ b/tests/NATS.Server.Tests/ClientTests.cs @@ -0,0 +1,87 @@ +using System.IO.Pipelines; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Text.Json; +using NATS.Server; +using NATS.Server.Protocol; + +namespace NATS.Server.Tests; + +public class ClientTests : IAsyncDisposable +{ + private readonly Socket _serverSocket; + private readonly Socket _clientSocket; + private readonly NatsClient _natsClient; + private readonly CancellationTokenSource _cts = new(); + + public ClientTests() + { + // Create connected socket pair via loopback + var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + var port = ((IPEndPoint)listener.LocalEndPoint!).Port; + + _clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _clientSocket.Connect(IPAddress.Loopback, port); + _serverSocket = listener.Accept(); + listener.Dispose(); + + var serverInfo = new ServerInfo + { + ServerId = "test", + ServerName = "test", + Version = "0.1.0", + Host = "127.0.0.1", + Port = 4222, + }; + + _natsClient = new NatsClient(1, _serverSocket, new NatsOptions(), serverInfo); + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + _natsClient.Dispose(); + _clientSocket.Dispose(); + } + + [Fact] + public async Task Client_sends_INFO_on_start() + { + var runTask = _natsClient.RunAsync(_cts.Token); + + // Read from client socket — should get INFO + var buf = new byte[4096]; + var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None); + var response = Encoding.ASCII.GetString(buf, 0, n); + + Assert.StartsWith("INFO ", response); + Assert.Contains("server_id", response); + Assert.Contains("\r\n", response); + + await _cts.CancelAsync(); + } + + [Fact] + public async Task Client_responds_PONG_to_PING() + { + var runTask = _natsClient.RunAsync(_cts.Token); + + // Read INFO + var buf = new byte[4096]; + await _clientSocket.ReceiveAsync(buf, SocketFlags.None); + + // Send CONNECT then PING + await _clientSocket.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPING\r\n")); + + // Read response — should get PONG + var n = await _clientSocket.ReceiveAsync(buf, SocketFlags.None); + var response = Encoding.ASCII.GetString(buf, 0, n); + + Assert.Contains("PONG\r\n", response); + + await _cts.CancelAsync(); + } +}