diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs new file mode 100644 index 0000000..e2fd343 --- /dev/null +++ b/src/NATS.Server/NatsServer.cs @@ -0,0 +1,140 @@ +using System.Collections.Concurrent; +using System.Net; +using System.Net.Sockets; +using NATS.Server.Protocol; +using NATS.Server.Subscriptions; + +namespace NATS.Server; + +public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable +{ + private readonly NatsOptions _options; + private readonly ConcurrentDictionary _clients = new(); + private readonly SubList _subList = new(); + private readonly ServerInfo _serverInfo; + private Socket? _listener; + private ulong _nextClientId; + + public SubList SubList => _subList; + + public NatsServer(NatsOptions options) + { + _options = options; + _serverInfo = new ServerInfo + { + ServerId = Guid.NewGuid().ToString("N")[..20].ToUpperInvariant(), + ServerName = options.ServerName ?? $"nats-dotnet-{Environment.MachineName}", + Version = NatsProtocol.Version, + Host = options.Host, + Port = options.Port, + MaxPayload = options.MaxPayload, + }; + } + + public async Task StartAsync(CancellationToken ct) + { + _listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + _listener.Bind(new IPEndPoint( + _options.Host == "0.0.0.0" ? IPAddress.Any : IPAddress.Parse(_options.Host), + _options.Port)); + _listener.Listen(128); + + try + { + while (!ct.IsCancellationRequested) + { + var socket = await _listener.AcceptAsync(ct); + var clientId = Interlocked.Increment(ref _nextClientId); + + var client = new NatsClient(clientId, socket, _options, _serverInfo); + client.Router = this; + _clients[clientId] = client; + + _ = RunClientAsync(client, ct); + } + } + catch (OperationCanceledException) { } + } + + private async Task RunClientAsync(NatsClient client, CancellationToken ct) + { + try + { + await client.RunAsync(ct); + } + catch (Exception) + { + // Client disconnected or errored + } + finally + { + RemoveClient(client); + } + } + + public void ProcessMessage(string subject, string? replyTo, ReadOnlyMemory headers, + ReadOnlyMemory payload, NatsClient sender) + { + var result = _subList.Match(subject); + + // Deliver to plain subscribers + foreach (var sub in result.PlainSubs) + { + if (sub.Client == null || sub.Client == sender && !(sender.ClientOpts?.Echo ?? true)) + continue; + + DeliverMessage(sub, subject, replyTo, headers, payload); + } + + // Deliver to one member of each queue group (round-robin) + foreach (var queueGroup in result.QueueSubs) + { + if (queueGroup.Length == 0) continue; + + // Simple round-robin -- pick based on total delivered across group + var idx = Math.Abs((int)Interlocked.Increment(ref sender.OutMsgs)) % queueGroup.Length; + // Undo the OutMsgs increment -- it will be incremented properly in SendMessageAsync + Interlocked.Decrement(ref sender.OutMsgs); + + for (int attempt = 0; attempt < queueGroup.Length; attempt++) + { + var sub = queueGroup[(idx + attempt) % queueGroup.Length]; + if (sub.Client != null && (sub.Client != sender || (sender.ClientOpts?.Echo ?? true))) + { + DeliverMessage(sub, subject, replyTo, headers, payload); + break; + } + } + } + } + + private static void DeliverMessage(Subscription sub, string subject, string? replyTo, + ReadOnlyMemory headers, ReadOnlyMemory payload) + { + var client = sub.Client; + if (client == null) return; + + // Check auto-unsub + var count = Interlocked.Increment(ref sub.MessageCount); + if (sub.MaxMessages > 0 && count > sub.MaxMessages) + return; + + // Fire and forget -- deliver asynchronously + _ = client.SendMessageAsync(subject, sub.Sid, replyTo, headers, payload, CancellationToken.None); + } + + public void RemoveClient(NatsClient client) + { + _clients.TryRemove(client.Id, out _); + client.RemoveAllSubscriptions(_subList); + } + + public void Dispose() + { + _listener?.Dispose(); + foreach (var client in _clients.Values) + client.Dispose(); + _subList.Dispose(); + } +} diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs new file mode 100644 index 0000000..c1f7036 --- /dev/null +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -0,0 +1,116 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using NATS.Server; + +namespace NATS.Server.Tests; + +public class ServerTests : IAsyncDisposable +{ + private readonly NatsServer _server; + private readonly int _port; + private readonly CancellationTokenSource _cts = new(); + + public ServerTests() + { + // Use random port + _port = GetFreePort(); + _server = new NatsServer(new NatsOptions { Port = _port }); + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + _server.Dispose(); + } + + private static int GetFreePort() + { + using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + sock.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)sock.LocalEndPoint!).Port; + } + + private async Task ConnectClientAsync() + { + var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await sock.ConnectAsync(IPAddress.Loopback, _port); + return sock; + } + + private static async Task ReadLineAsync(Socket sock, int bufSize = 4096) + { + var buf = new byte[bufSize]; + var n = await sock.ReceiveAsync(buf, SocketFlags.None); + return Encoding.ASCII.GetString(buf, 0, n); + } + + [Fact] + public async Task Server_accepts_connection_and_sends_INFO() + { + var serverTask = _server.StartAsync(_cts.Token); + await Task.Delay(100); // let server start + + using var client = await ConnectClientAsync(); + var response = await ReadLineAsync(client); + + Assert.StartsWith("INFO ", response); + await _cts.CancelAsync(); + } + + [Fact] + public async Task Server_basic_pubsub() + { + var serverTask = _server.StartAsync(_cts.Token); + await Task.Delay(100); + + using var pub = await ConnectClientAsync(); + using var sub = await ConnectClientAsync(); + + // Read INFO from both + await ReadLineAsync(pub); + await ReadLineAsync(sub); + + // CONNECT + SUB on subscriber + await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo 1\r\n")); + await Task.Delay(50); + + // CONNECT + PUB on publisher + await pub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPUB foo 5\r\nHello\r\n")); + await Task.Delay(100); + + // Read MSG from subscriber + var buf = new byte[4096]; + var n = await sub.ReceiveAsync(buf, SocketFlags.None); + var msg = Encoding.ASCII.GetString(buf, 0, n); + + Assert.Contains("MSG foo 1 5\r\nHello\r\n", msg); + await _cts.CancelAsync(); + } + + [Fact] + public async Task Server_wildcard_matching() + { + var serverTask = _server.StartAsync(_cts.Token); + await Task.Delay(100); + + using var pub = await ConnectClientAsync(); + using var sub = await ConnectClientAsync(); + + await ReadLineAsync(pub); + await ReadLineAsync(sub); + + await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo.* 1\r\n")); + await Task.Delay(50); + + await pub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPUB foo.bar 5\r\nHello\r\n")); + await Task.Delay(100); + + var buf = new byte[4096]; + var n = await sub.ReceiveAsync(buf, SocketFlags.None); + var msg = Encoding.ASCII.GetString(buf, 0, n); + + Assert.Contains("MSG foo.bar 1 5\r\n", msg); + await _cts.CancelAsync(); + } +}