using System.Globalization;
using System.Net.Sockets;
using System.Text;
using System.Text.Json;
using NATS.Server.Subscriptions;

namespace NATS.Server.Routes;

/// <summary>
/// A single server-to-server route connection over a socket. It performs the line-based
/// <c>ROUTE &lt;server-id&gt;</c> handshake and sends <c>RS+</c>, <c>RS-</c> and <c>RMSG</c> frames.
/// It can also run a background loop that parses inbound frames and hands them to
/// <see cref="RemoteSubscriptionReceived"/> and <see cref="RoutedMessageReceived"/>.
/// </summary>
public sealed class RouteConnection(Socket socket) : IAsyncDisposable
{
    // Terminator written after every RMSG payload; cached so each send does not allocate it.
    private static readonly byte[] CrLf = "\r\n"u8.ToArray();

    private readonly Socket _socket = socket;
    private readonly NetworkStream _stream = new(socket, ownsSocket: true);

    // Serializes writers so that a control line and its payload are never interleaved
    // with another concurrent send.
    private readonly SemaphoreSlim _writeGate = new(1, 1);

    // Cancelled by DisposeAsync to stop the frame loop.
    private readonly CancellationTokenSource _closedCts = new();

    // Linked (caller token + _closedCts) source for the frame loop; kept so it can be disposed.
    private CancellationTokenSource? _frameLoopCts;
    private Task? _frameLoopTask;

    // 0 = live, 1 = disposed. Makes DisposeAsync safe to call more than once.
    private int _disposed;

    /// <summary>
    /// Server id announced by the peer in its <c>ROUTE</c> handshake line;
    /// <c>null</c> until a handshake has completed.
    /// </summary>
    public string? RemoteServerId { get; private set; }

    /// <summary>
    /// The socket's remote endpoint as text. When the socket reports no remote endpoint,
    /// a newly generated GUID string is returned, which is a different value on every access.
    /// </summary>
    public string RemoteEndpoint => _socket.RemoteEndPoint?.ToString() ?? Guid.NewGuid().ToString("N");

    /// <summary>
    /// The pool index assigned to this route connection. Used for account-based
    /// routing to deterministically select which pool connection handles traffic
    /// for a given account.
    /// </summary>
    public int PoolIndex { get; set; }

    /// <summary>
    /// The pool size agreed upon during handshake negotiation with the remote peer.
    /// Defaults to 0 (no pooling / pre-negotiation state). Set after handshake completes.
    /// Go reference: server/route.go negotiateRoutePool.
    /// </summary>
    public int NegotiatedPoolSize { get; private set; }

    /// <summary>
    /// Negotiates the effective route pool size between local and remote peers.
    /// Returns <c>Math.Min(localPoolSize, remotePoolSize)</c>, but returns 0 if
    /// either side is 0 for backward compatibility with peers that do not support pooling.
    /// Go reference: server/route.go negotiateRoutePool.
    /// </summary>
    public static int NegotiatePoolSize(int localPoolSize, int remotePoolSize)
    {
        if (localPoolSize == 0 || remotePoolSize == 0)
            return 0;

        return Math.Min(localPoolSize, remotePoolSize);
    }

    /// <summary>
    /// Applies the result of pool size negotiation to this connection.
    /// </summary>
    internal void SetNegotiatedPoolSize(int negotiatedPoolSize)
    {
        NegotiatedPoolSize = negotiatedPoolSize;
    }

    /// <summary>
    /// Invoked by the frame loop for every parsed inbound <c>RS+</c> (interest added)
    /// or <c>RS-</c> (interest removed) frame.
    /// </summary>
    public Func<RemoteSubscription, Task>? RemoteSubscriptionReceived { get; set; }

    /// <summary>
    /// Invoked by the frame loop for every inbound <c>RMSG</c> frame once its payload has been read.
    /// </summary>
    public Func<RouteMessage, Task>? RoutedMessageReceived { get; set; }

    /// <summary>
    /// Outbound side of the handshake: sends <c>ROUTE &lt;serverId&gt;</c>, then reads and
    /// parses the peer's <c>ROUTE</c> line into <see cref="RemoteServerId"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The peer's line is not a valid ROUTE handshake.</exception>
    /// <exception cref="IOException">The connection closed before a line was read.</exception>
    public async Task PerformOutboundHandshakeAsync(string serverId, CancellationToken ct)
    {
        await WriteLineAsync($"ROUTE {serverId}", ct);
        var line = await ReadLineAsync(ct);
        RemoteServerId = ParseHandshake(line);
    }

    /// <summary>
    /// Inbound side of the handshake: reads and parses the peer's <c>ROUTE</c> line into
    /// <see cref="RemoteServerId"/>, then replies with <c>ROUTE &lt;serverId&gt;</c>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The peer's line is not a valid ROUTE handshake.</exception>
    /// <exception cref="IOException">The connection closed before a line was read.</exception>
    public async Task PerformInboundHandshakeAsync(string serverId, CancellationToken ct)
    {
        var line = await ReadLineAsync(ct);
        RemoteServerId = ParseHandshake(line);
        await WriteLineAsync($"ROUTE {serverId}", ct);
    }

    /// <summary>
    /// Starts the background loop that reads and dispatches inbound frames. Calls after the
    /// first one are ignored. The loop stops when <paramref name="ct"/> is cancelled, the
    /// connection is disposed, or the stream fails or closes.
    /// </summary>
    public void StartFrameLoop(CancellationToken ct)
    {
        if (_frameLoopTask != null)
            return;

        // Keep the linked source so DisposeAsync can release it. Otherwise it would stay
        // registered with the caller's token for as long as that token lives.
        _frameLoopCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _closedCts.Token);
        var loopToken = _frameLoopCts.Token;
        _frameLoopTask = Task.Run(() => ReadFramesAsync(loopToken), loopToken);
    }

    /// <summary>
    /// Sends <c>RS+ &lt;account&gt; &lt;subject&gt; [queue]</c>. The queue token is omitted
    /// when <paramref name="queue"/> is null or empty.
    /// </summary>
    public async Task SendRsPlusAsync(string account, string subject, string? queue, CancellationToken ct)
    {
        await SendInterestAsync("RS+", account, subject, queue, ct);
    }

    /// <summary>
    /// Sends <c>RS- &lt;account&gt; &lt;subject&gt; [queue]</c>. The queue token is omitted
    /// when <paramref name="queue"/> is null or empty.
    /// </summary>
    public async Task SendRsMinusAsync(string account, string subject, string? queue, CancellationToken ct)
    {
        await SendInterestAsync("RS-", account, subject, queue, ct);
    }

    /// <summary>
    /// Sends <c>RMSG &lt;account&gt; &lt;subject&gt; &lt;reply|-&gt; &lt;size&gt;\r\n&lt;payload&gt;\r\n</c>
    /// while holding the write gate. A null or empty <paramref name="replyTo"/> is sent as <c>-</c>.
    /// </summary>
    public async Task SendRmsgAsync(string account, string subject, string? replyTo, ReadOnlyMemory<byte> payload, CancellationToken ct)
    {
        var replyToken = string.IsNullOrEmpty(replyTo) ? "-" : replyTo;
        var control = Encoding.ASCII.GetBytes($"RMSG {account} {subject} {replyToken} {payload.Length}\r\n");

        await _writeGate.WaitAsync(ct);
        try
        {
            await _stream.WriteAsync(control, ct);
            if (!payload.IsEmpty)
                await _stream.WriteAsync(payload, ct);
            await _stream.WriteAsync(CrLf, ct);
            await _stream.FlushAsync(ct);
        }
        finally
        {
            _writeGate.Release();
        }
    }

    /// <summary>
    /// Waits for the frame loop to finish. Returns immediately if the loop was never started
    /// or the connection has already been disposed.
    /// </summary>
    /// <exception cref="OperationCanceledException">
    /// <paramref name="ct"/> was cancelled, or the connection was disposed while waiting.
    /// </exception>
    public async Task WaitUntilClosedAsync(CancellationToken ct)
    {
        // After DisposeAsync the loop has already been awaited, and _closedCts is disposed,
        // so reading its Token would throw ObjectDisposedException.
        if (_frameLoopTask == null || Volatile.Read(ref _disposed) != 0)
            return;

        using var linked = CancellationTokenSource.CreateLinkedTokenSource(ct, _closedCts.Token);
        await _frameLoopTask.WaitAsync(linked.Token);
    }

    /// <summary>
    /// Stops the frame loop and waits for it (ignoring any fault), then releases the
    /// cancellation sources, the write gate and the stream (which owns the socket).
    /// Subsequent calls are no-ops.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        if (Interlocked.Exchange(ref _disposed, 1) != 0)
            return;

        await _closedCts.CancelAsync();
        if (_frameLoopTask != null)
            await _frameLoopTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

        _frameLoopCts?.Dispose();
        _closedCts.Dispose();
        _writeGate.Dispose();
        await _stream.DisposeAsync();
    }

    /// <summary>
    /// Frame loop body. Reads one line at a time and dispatches it:
    /// <list type="bullet">
    /// <item><description><c>RS+</c> and <c>RS-</c> go to <see cref="RemoteSubscriptionReceived"/>.</description></item>
    /// <item><description><c>RMSG</c> has its payload read and goes to <see cref="RoutedMessageReceived"/>.</description></item>
    /// <item><description>Anything else is ignored.</description></item>
    /// </list>
    /// The loop ends quietly on cancellation or on an I/O failure while reading either a
    /// line or a payload. Exceptions thrown by the handlers propagate and fault the loop.
    /// </summary>
    private async Task ReadFramesAsync(CancellationToken ct)
    {
        while (!ct.IsCancellationRequested)
        {
            string line;
            try
            {
                line = await ReadLineAsync(ct);
            }
            catch (OperationCanceledException)
            {
                break;
            }
            catch (IOException)
            {
                break;
            }

            if (line.StartsWith("RS+ ", StringComparison.Ordinal))
            {
                await DispatchInterestAsync(line, isRemoval: false);
                continue;
            }

            if (line.StartsWith("RS- ", StringComparison.Ordinal))
            {
                await DispatchInterestAsync(line, isRemoval: true);
                continue;
            }

            // NOTE(review): an RMSG header with an unparsable or negative size is skipped
            // without consuming its payload, so any payload bytes are then read as lines.
            if (!TryParseRmsgHeader(line, out var account, out var subject, out var reply, out var size))
                continue;

            ReadOnlyMemory<byte> payload;
            try
            {
                payload = await ReadPayloadAsync(size, ct);
            }
            catch (OperationCanceledException)
            {
                // Same treatment as a failed line read: the loop is shutting down.
                break;
            }
            catch (IOException)
            {
                // Peer closed mid-payload or sent a bad trailer: the stream is unusable.
                break;
            }

            var messageHandler = RoutedMessageReceived;
            if (messageHandler != null)
                await messageHandler(new RouteMessage(subject, reply, payload, account));
        }
    }

    // Parses an RS+/RS- line and, if a handler is set and the line is valid, invokes it
    // with an add or removal RemoteSubscription.
    private async Task DispatchInterestAsync(string line, bool isRemoval)
    {
        // Read the property once so a concurrent reset to null cannot cause a NullReferenceException.
        var handler = RemoteSubscriptionReceived;
        if (handler == null)
            return;

        var parts = line.Split(' ', StringSplitOptions.RemoveEmptyEntries);
        if (!TryParseAccountScopedInterest(parts, out var account, out var subject, out var queue))
            return;

        var origin = RemoteServerId ?? string.Empty;
        var subscription = isRemoval
            ? RemoteSubscription.Removal(subject, queue, origin, account)
            : new RemoteSubscription(subject, queue, origin, account);
        await handler(subscription);
    }

    // Parses an RMSG control line. Returns false if the line is not RMSG, has too few
    // tokens, or has a size that is not a non-negative integer.
    private static bool TryParseRmsgHeader(string line, out string account, out string subject, out string? reply, out int size)
    {
        account = "$G";
        subject = string.Empty;
        reply = null;
        size = 0;

        if (!line.StartsWith("RMSG ", StringComparison.Ordinal))
            return false;

        var args = line.Split(' ', StringSplitOptions.RemoveEmptyEntries);
        if (args.Length < 4)
            return false;

        string replyToken;
        string sizeToken;

        // New format:    RMSG <account> <subject> <reply|-> <size>
        // Legacy format: RMSG <subject> <reply|-> <size>
        if (args.Length >= 5 && !LooksLikeSubject(args[1]))
        {
            account = args[1];
            subject = args[2];
            replyToken = args[3];
            sizeToken = args[4];
        }
        else
        {
            subject = args[1];
            replyToken = args[2];
            sizeToken = args[3];
        }

        reply = replyToken == "-" ? null : replyToken;

        // Wire data: parse independently of the current culture.
        return int.TryParse(sizeToken, NumberStyles.Integer, CultureInfo.InvariantCulture, out size) && size >= 0;
    }

    /// <summary>
    /// Reads exactly <paramref name="size"/> payload bytes followed by a mandatory CRLF trailer.
    /// </summary>
    /// <exception cref="EndOfStreamException">
    /// The connection closed before all bytes arrived. This type derives from <see cref="IOException"/>.
    /// </exception>
    /// <exception cref="IOException">The two bytes after the payload are not CRLF.</exception>
    private async Task<ReadOnlyMemory<byte>> ReadPayloadAsync(int size, CancellationToken ct)
    {
        var payload = new byte[size];
        await _stream.ReadExactlyAsync(payload, ct);

        var trailer = new byte[2];
        await _stream.ReadExactlyAsync(trailer, ct);
        if (trailer[0] != (byte)'\r' || trailer[1] != (byte)'\n')
            throw new IOException("Invalid route payload trailer");

        return payload;
    }

    // Writes "<line>\r\n" as ASCII while holding the write gate.
    private async Task WriteLineAsync(string line, CancellationToken ct)
    {
        await _writeGate.WaitAsync(ct);
        try
        {
            var bytes = Encoding.ASCII.GetBytes($"{line}\r\n");
            await _stream.WriteAsync(bytes, ct);
            await _stream.FlushAsync(ct);
        }
        finally
        {
            _writeGate.Release();
        }
    }

    /// <summary>
    /// Reads one '\n'-terminated line and decodes it as ASCII. Every '\r' byte is dropped,
    /// not only a trailing one. Reads a single byte per call so that no payload bytes are
    /// pulled off the stream before <see cref="ReadPayloadAsync"/> reads them directly.
    /// </summary>
    /// <exception cref="IOException">The connection closed before a '\n' was read.</exception>
    private async Task<string> ReadLineAsync(CancellationToken ct)
    {
        var bytes = new List<byte>(64);
        var single = new byte[1];
        while (true)
        {
            var read = await _stream.ReadAsync(single, ct);
            if (read == 0)
                throw new IOException("Route connection closed");
            if (single[0] == (byte)'\n')
                break;
            if (single[0] != (byte)'\r')
                bytes.Add(single[0]);
        }

        return Encoding.ASCII.GetString([.. bytes]);
    }

    // Validates a "ROUTE <id>" line. The prefix match is case-insensitive. Returns the
    // trimmed id, or throws InvalidOperationException if the prefix or the id is missing.
    private static string ParseHandshake(string line)
    {
        if (!line.StartsWith("ROUTE ", StringComparison.OrdinalIgnoreCase))
            throw new InvalidOperationException("Invalid route handshake");

        var id = line[6..].Trim();
        if (id.Length == 0)
            throw new InvalidOperationException("Route handshake missing server id");

        return id;
    }

    // Parses the tokens of an RS+/RS- line. The account defaults to "$G" for the legacy
    // format. Returns false only when there is no subject token.
    private static bool TryParseAccountScopedInterest(string[] parts, out string account, out string subject, out string? queue)
    {
        account = "$G";
        subject = string.Empty;
        queue = null;
        if (parts.Length < 2)
            return false;

        // New format:    RS+ <account> <subject> [queue]
        // Legacy format: RS+ <subject> [queue]
        // NOTE(review): the formats are told apart heuristically. A legacy line whose subject
        // has no '.', '*' or '>' and that carries a queue reads as the new format.
        if (parts.Length >= 3 && !LooksLikeSubject(parts[1]))
        {
            account = parts[1];
            subject = parts[2];
            queue = parts.Length >= 4 ? parts[3] : null;
            return true;
        }

        subject = parts[1];
        queue = parts.Length >= 3 ? parts[2] : null;
        return true;
    }

    // Heuristic: a token containing '.', '*' or '>' is treated as a subject rather than an account name.
    private static bool LooksLikeSubject(string token)
        => token.Contains('.', StringComparison.Ordinal)
           || token.Contains('*', StringComparison.Ordinal)
           || token.Contains('>', StringComparison.Ordinal);

    /// <summary>
    /// Serializes <c>{ server_id, accounts, topology }</c> to JSON. A null
    /// <paramref name="accounts"/> becomes an empty array, and a null
    /// <paramref name="topologySnapshot"/> becomes an empty string.
    /// </summary>
    public static string BuildConnectInfoJson(string serverId, IEnumerable<string>? accounts, string? topologySnapshot)
    {
        var payload = new
        {
            server_id = serverId,
            accounts = (accounts ?? []).ToArray(),
            topology = topologySnapshot ?? string.Empty,
        };
        return JsonSerializer.Serialize(payload);
    }
}

/// <summary>
/// A routed message received from a peer. <see cref="ReplyTo"/> is null when the peer sent
/// <c>-</c>, and <see cref="Account"/> defaults to <c>$G</c>.
/// </summary>
public sealed record RouteMessage(string Subject, string? ReplyTo, ReadOnlyMemory<byte> Payload, string Account = "$G");