Files
natsdotnet/src/NATS.Server/LeafNodes/LeafNodeManager.cs
2026-02-23 12:11:19 -05:00

214 lines
6.8 KiB
C#

using System.Collections.Concurrent;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using NATS.Server.Configuration;
using NATS.Server.Subscriptions;
namespace NATS.Server.LeafNodes;
/// <summary>
/// Runs the leaf-node listener and the outbound connect/retry loops for configured remotes.
/// Every connection whose handshake completes (inbound or outbound) is registered in a shared
/// table; subscriptions and messages it receives are forwarded to the sinks supplied at
/// construction, and <see cref="ServerStats"/>.Leafs tracks the number registered.
/// </summary>
public sealed class LeafNodeManager : IAsyncDisposable
{
    private readonly LeafNodeOptions _options;
    private readonly ServerStats _stats;
    private readonly string _serverId;
    private readonly Action<RemoteSubscription> _remoteSubSink;
    private readonly Action<LeafMessage> _messageSink;
    private readonly ILogger<LeafNodeManager> _logger;

    // Keys are "{RemoteId}:{RemoteEndpoint}:{guid}", so every registration gets its own entry.
    private readonly ConcurrentDictionary<string, LeafConnection> _connections = new(StringComparer.Ordinal);

    // Non-null only between a successful StartAsync and DisposeAsync.
    private CancellationTokenSource? _cts;
    private Socket? _listener;
    private Task? _acceptLoopTask;

    /// <summary>
    /// The listen address as "host:port". After <see cref="StartAsync"/>, a configured port of 0
    /// is replaced with the port the OS actually bound.
    /// </summary>
    public string ListenEndpoint => $"{_options.Host}:{_options.Port}";

    /// <summary>Creates a manager; nothing is bound or connected until <see cref="StartAsync"/>.</summary>
    /// <param name="options">Listen host/port and remote list. <c>Port</c> is overwritten when it is 0.</param>
    /// <param name="stats">Server statistics whose <c>Leafs</c> counter this manager maintains.</param>
    /// <param name="serverId">Local server id passed to the leaf handshakes.</param>
    /// <param name="remoteSubSink">Invoked for each subscription received from a leaf connection.</param>
    /// <param name="messageSink">Invoked for each message received from a leaf connection.</param>
    /// <param name="logger">Diagnostic logger.</param>
    public LeafNodeManager(
        LeafNodeOptions options,
        ServerStats stats,
        string serverId,
        Action<RemoteSubscription> remoteSubSink,
        Action<LeafMessage> messageSink,
        ILogger<LeafNodeManager> logger)
    {
        _options = options;
        _stats = stats;
        _serverId = serverId;
        _remoteSubSink = remoteSubSink;
        _messageSink = messageSink;
        _logger = logger;
    }

    /// <summary>
    /// Binds the leaf listener and starts the accept loop. It also starts one connect/retry loop
    /// per configured remote, de-duplicated case-insensitively. Returns once the background work
    /// has been scheduled.
    /// </summary>
    /// <param name="ct">Linked into the manager's shutdown token.</param>
    /// <remarks>If bind/listen fails, the socket and token source are released and the exception is rethrown.</remarks>
    public Task StartAsync(CancellationToken ct)
    {
        var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
            listener.Bind(new IPEndPoint(IPAddress.Parse(_options.Host), _options.Port));
            listener.Listen(128);
        }
        catch
        {
            listener.Dispose();
            cts.Dispose();
            throw;
        }

        _cts = cts;
        _listener = listener;
        if (_options.Port == 0)
            _options.Port = ((IPEndPoint)listener.LocalEndPoint!).Port;

        // Capture the token now. The lambdas below run later, when DisposeAsync may already
        // have cleared _cts.
        var token = cts.Token;
        _acceptLoopTask = Task.Run(() => AcceptLoopAsync(token));
        foreach (var remote in _options.Remotes.Distinct(StringComparer.OrdinalIgnoreCase))
            _ = Task.Run(() => ConnectWithRetryAsync(remote, token));

        _logger.LogDebug("Leaf manager started (listen={Host}:{Port})", _options.Host, _options.Port);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Sends a message to every currently registered connection, one at a time. An exception
    /// from any send propagates, and the remaining connections are skipped.
    /// </summary>
    public async Task ForwardMessageAsync(string subject, string? replyTo, ReadOnlyMemory<byte> payload, CancellationToken ct)
    {
        foreach (var connection in _connections.Values)
            await connection.SendMessageAsync(subject, replyTo, payload, ct);
    }

    /// <summary>Announces a local subscription (LS+) to every registered connection, fire-and-forget.</summary>
    public void PropagateLocalSubscription(string account, string subject, string? queue)
    {
        var token = ShutdownToken;
        foreach (var connection in _connections.Values)
            _ = connection.SendLsPlusAsync(account, subject, queue, token);
    }

    /// <summary>Announces removal of a local subscription (LS-) to every registered connection, fire-and-forget.</summary>
    public void PropagateLocalUnsubscription(string account, string subject, string? queue)
    {
        var token = ShutdownToken;
        foreach (var connection in _connections.Values)
            _ = connection.SendLsMinusAsync(account, subject, queue, token);
    }

    /// <summary>
    /// Cancels all background work, closes the listener, and unregisters and disposes every
    /// registered connection. Calls before start, or after a previous dispose, do nothing.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        // Claim the source atomically: shutdown runs once even under concurrent calls, and
        // ShutdownToken can no longer observe a source that is about to be disposed.
        var cts = Interlocked.Exchange(ref _cts, null);
        if (cts is null)
            return;

        await cts.CancelAsync();
        _listener?.Dispose();
        if (_acceptLoopTask is not null)
            await _acceptLoopTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

        // Whichever side removes an entry (this loop or WatchConnectionAsync) owns its Leafs
        // decrement and its disposal. Each connection is therefore counted down and disposed
        // exactly once.
        foreach (var key in _connections.Keys)
        {
            if (_connections.TryRemove(key, out var connection))
            {
                Interlocked.Decrement(ref _stats.Leafs);
                await connection.DisposeAsync();
            }
        }

        cts.Dispose();
        _logger.LogDebug("Leaf manager stopped");
    }

    // Token for fire-and-forget sends; CancellationToken.None when the manager is not running.
    private CancellationToken ShutdownToken => _cts?.Token ?? CancellationToken.None;

    /// <summary>Accepts inbound sockets until cancellation or a non-transient accept failure.</summary>
    private async Task AcceptLoopAsync(CancellationToken ct)
    {
        while (!ct.IsCancellationRequested)
        {
            Socket socket;
            try
            {
                socket = await _listener!.AcceptAsync(ct);
            }
            catch (SocketException ex) when (!ct.IsCancellationRequested
                && ex.SocketErrorCode is SocketError.ConnectionReset or SocketError.ConnectionAborted)
            {
                // A client that resets before accept completes must not take the listener down.
                _logger.LogDebug(ex, "Transient leaf accept failure; continuing");
                continue;
            }
            catch (Exception ex)
            {
                // Cancellation and the listener being disposed are the normal shutdown path.
                if (!ct.IsCancellationRequested && ex is not ObjectDisposedException)
                    _logger.LogWarning(ex, "Leaf accept loop stopped");
                break;
            }

            // ct is deliberately not passed to Task.Run. With an already-cancelled token the
            // delegate would never run and the accepted socket would leak. HandleInboundAsync
            // observes ct itself and cleans up.
            _ = Task.Run(() => HandleInboundAsync(socket, ct));
        }
    }

    /// <summary>
    /// Performs the inbound handshake on an accepted socket and registers the connection. On
    /// any failure, the connection and socket are released.
    /// </summary>
    private async Task HandleInboundAsync(Socket socket, CancellationToken ct)
    {
        LeafConnection? connection = null;
        try
        {
            connection = new LeafConnection(socket);
            await connection.PerformInboundHandshakeAsync(_serverId, ct);
            Register(connection, ct);
        }
        catch (Exception ex)
        {
            if (!ct.IsCancellationRequested)
                _logger.LogDebug(ex, "Inbound leaf handshake failed");
            await DisposeAttemptAsync(socket, connection);
        }
    }

    /// <summary>
    /// Connects to <paramref name="remote"/> and retries every 250 ms until one connection is
    /// registered or <paramref name="ct"/> is cancelled. An unparseable remote is logged and
    /// abandoned, because it can never succeed.
    /// </summary>
    private async Task ConnectWithRetryAsync(string remote, CancellationToken ct)
    {
        IPEndPoint endPoint;
        try
        {
            endPoint = ParseEndpoint(remote);
        }
        catch (Exception ex) when (ex is FormatException or OverflowException or ArgumentOutOfRangeException)
        {
            _logger.LogWarning(ex, "Invalid leaf remote {Remote}; not connecting", remote);
            return;
        }

        while (!ct.IsCancellationRequested)
        {
            var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            LeafConnection? connection = null;
            try
            {
                await socket.ConnectAsync(endPoint.Address, endPoint.Port, ct);
                connection = new LeafConnection(socket);
                await connection.PerformOutboundHandshakeAsync(_serverId, ct);
                Register(connection, ct);
                return;
            }
            catch (OperationCanceledException)
            {
                await DisposeAttemptAsync(socket, connection);
                return;
            }
            catch (Exception ex)
            {
                // Release this attempt's socket before retrying; otherwise each retry leaks one.
                await DisposeAttemptAsync(socket, connection);
                _logger.LogDebug(ex, "Leaf connect retry for {Remote}", remote);
            }

            try
            {
                await Task.Delay(250, ct);
            }
            catch (OperationCanceledException)
            {
                return;
            }
        }
    }

    /// <summary>
    /// Releases everything a failed connection attempt created. The socket is disposed even when
    /// a connection wraps it. Socket.Dispose is idempotent, so this is safe whether or not the
    /// connection already closed it.
    /// </summary>
    private static async ValueTask DisposeAttemptAsync(Socket socket, LeafConnection? connection)
    {
        try
        {
            if (connection is not null)
                await connection.DisposeAsync();
        }
        finally
        {
            socket.Dispose();
        }
    }

    /// <summary>
    /// Adds a handshaken connection to the table and wires its callbacks to the sinks. It then
    /// starts the connection's loop and a watcher that unregisters the connection when it closes.
    /// </summary>
    /// <remarks>
    /// If <c>StartLoop</c> throws, the entry is removed and the exception rethrown; the caller
    /// then owns disposal.
    /// </remarks>
    private void Register(LeafConnection connection, CancellationToken ct)
    {
        connection.RemoteSubscriptionReceived = sub =>
        {
            _remoteSubSink(sub);
            return Task.CompletedTask;
        };
        connection.MessageReceived = msg =>
        {
            _messageSink(msg);
            return Task.CompletedTask;
        };

        // The GUID suffix keeps keys unique even if the same remote registers more than once.
        var key = $"{connection.RemoteId}:{connection.RemoteEndpoint}:{Guid.NewGuid():N}";
        if (!_connections.TryAdd(key, connection))
        {
            _ = connection.DisposeAsync();
            return;
        }

        Interlocked.Increment(ref _stats.Leafs);
        try
        {
            connection.StartLoop(ct);
        }
        catch
        {
            if (_connections.TryRemove(key, out _))
                Interlocked.Decrement(ref _stats.Leafs);
            throw;
        }

        // Use the captured ct rather than _cts, which DisposeAsync may null before this task runs.
        _ = Task.Run(() => WatchConnectionAsync(key, connection, ct));
    }

    /// <summary>
    /// Waits for the connection to close or for shutdown. If the entry is still registered, it
    /// then unregisters the entry, decrements Leafs, and disposes the connection.
    /// </summary>
    private async Task WatchConnectionAsync(string key, LeafConnection connection, CancellationToken ct)
    {
        try
        {
            await connection.WaitUntilClosedAsync(ct);
        }
        catch (OperationCanceledException)
        {
            // Manager shutdown; fall through to unregister.
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Leaf connection {Key} closed with error", key);
        }
        finally
        {
            // DisposeAsync may have removed the entry already; in that case it owns the cleanup.
            if (_connections.TryRemove(key, out _))
            {
                Interlocked.Decrement(ref _stats.Leafs);
                await connection.DisposeAsync();
            }
        }
    }

    /// <summary>
    /// Parses "address:port", splitting at the first ':'. The address must be an IP literal;
    /// IPv6 literals are not handled by this split.
    /// </summary>
    /// <exception cref="FormatException">Wrong shape, bad address, or non-numeric port.</exception>
    /// <exception cref="OverflowException">Port does not fit in an <see cref="int"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Port outside the valid TCP range.</exception>
    private static IPEndPoint ParseEndpoint(string endpoint)
    {
        var parts = endpoint.Split(':', 2, StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);
        if (parts.Length != 2)
            throw new FormatException($"Invalid endpoint: {endpoint}");
        return new IPEndPoint(IPAddress.Parse(parts[0]), int.Parse(parts[1], CultureInfo.InvariantCulture));
    }
}