Files
natsnet/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.LifecycleAndTls.cs

519 lines
16 KiB
C#

// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
using System.Collections.Concurrent;
using System.Net.Security;
using System.Runtime.CompilerServices;
using System.Security.Authentication;
using System.Text;
using System.Linq;
using ZB.MOM.NatsNet.Server.Auth;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Internal.DataStructures;
namespace ZB.MOM.NatsNet.Server;
public sealed partial class ClientConnection
{
// Cap applied by SetFirstPingTimer to the first ping delay of non-client connections
// (only when short first pings are not disabled in the server options).
private static readonly TimeSpan FirstPingInterval = TimeSpan.FromSeconds(15);
// Cap applied by SetFirstPingTimer to the first ping delay of client connections.
private static readonly TimeSpan FirstClientPingInterval = TimeSpan.FromSeconds(2);
// Entry count at which GetAccAndResultFromCache evicts entries before storing a new one.
private const int MaxPerAccountCacheSize = 8192;
// Format of the error line queued for a stale connection; {0} receives the error text.
private const string StaleErrProtoFormat = "-ERR '{0}'\r\n";
// Rate-limit bookkeeping used by TryMarkRateLimited: one dictionary per key object (the server,
// or the connection itself when no server is attached), mapping a statement key to the UTC time
// it was first recorded. The table holds its keys weakly.
// NOTE(review): nothing in this part of the class removes entries from the dictionaries —
// confirm that expiry happens elsewhere.
private static readonly ConditionalWeakTable<object, ConcurrentDictionary<string, DateTime>> RateLimitCacheByServer = new();
/// <summary>
/// Arms a one-shot timer that closes this connection as stale once the computed deadline passes.
/// </summary>
/// <param name="pingInterval">Interval between pings; non-positive values disable the watch.</param>
/// <param name="pingMax">Number of outstanding pings tolerated; negative values disable the watch.</param>
internal void WatchForStaleConnection(TimeSpan pingInterval, int pingMax)
{
    var usable = pingInterval > TimeSpan.Zero && pingMax >= 0;
    if (!usable)
    {
        return;
    }

    // Deadline is (pingMax + 1) intervals; with pingMax == 0 it is shortened to half an interval
    // (never less than one tick).
    var deadline = TimeSpan.FromTicks(pingInterval.Ticks * (pingMax + 1L));
    if (pingMax == 0 && deadline > TimeSpan.Zero)
    {
        deadline = TimeSpan.FromTicks(Math.Max(1, pingInterval.Ticks / 2));
    }

    if (deadline <= TimeSpan.Zero)
    {
        return;
    }

    ClearPingTimer();
    _pingTimer = new Timer(OnDeadline, null, deadline, Timeout.InfiniteTimeSpan);

    // Timer callback: queue the stale-connection error under the lock, then close outside of it.
    void OnDeadline(object? state)
    {
        lock (_mu)
        {
            if (IsClosed())
            {
                return;
            }

            Debugf("Stale Client Connection - Closing");
            var errLine = string.Format(StaleErrProtoFormat, "Stale Connection");
            EnqueueProto(Encoding.ASCII.GetBytes(errLine));
        }

        CloseConnection(ClosedState.StaleConnection);
    }
}
/// <summary>
/// Looks the connection's current account up by name on the server again and, when the server
/// returns a different instance, rebinds the connection to it.
/// </summary>
internal void SwapAccountAfterReload()
{
    string name;
    lock (_mu)
    {
        var bound = _account is not null && Server is not null;
        if (!bound)
        {
            return;
        }

        name = _account!.Name;
    }

    // The lookup runs outside the connection lock.
    if (Server is not NatsServer natsServer)
    {
        return;
    }

    var (fresh, _) = natsServer.LookupAccount(name);
    if (fresh is null)
    {
        return;
    }

    lock (_mu)
    {
        if (ReferenceEquals(_account, fresh))
        {
            return;
        }

        _account = fresh;
    }
}
/// <summary>
/// Re-evaluates every subscription of this connection after a configuration reload.
/// Subscriptions that are no longer permitted are force-unsubscribed and reported to the client;
/// the remaining ones get their shadow subscriptions re-added when the account check applies.
/// </summary>
/// <param name="accountsWithChangedStreamImports">
/// Names of accounts whose stream imports changed. When non-null and it does not contain this
/// connection's account, the shadow-subscription pass is skipped.
/// </param>
internal void ProcessSubsOnConfigReload(ISet<string>? accountsWithChangedStreamImports)
{
    INatsAccount? account;
    bool refreshShadows;
    var keep = new List<Subscription>();
    var drop = new List<Subscription>();

    lock (_mu)
    {
        var hasPerms = Perms is not null;
        account = _account;
        refreshShadows = account is not null;
        if (!hasPerms && !refreshShadows)
        {
            return;
        }

        if (account is not null && accountsWithChangedStreamImports?.Contains(account.Name) == false)
        {
            refreshShadows = false;
        }

        // Reset MPerms before the CanSubscribe calls below run under the lock.
        MPerms = null;

        foreach (var candidate in Subs.Values)
        {
            var subjectText = Encoding.ASCII.GetString(candidate.Subject);
            var plainOk = CanSubscribe(subjectText);
            var queueOk = candidate.Queue is { Length: > 0 } queueBytes
                && CanSubscribe(subjectText, Encoding.ASCII.GetString(queueBytes));

            if (plainOk || queueOk)
            {
                if (refreshShadows)
                {
                    keep.Add(candidate);
                }

                continue;
            }

            drop.Add(candidate);
        }
    }

    if (refreshShadows && account is not null)
    {
        foreach (var kept in keep)
        {
            AddShadowSubscriptions(account, kept);
        }
    }

    foreach (var revoked in drop)
    {
        Unsubscribe(account, revoked, force: true, remove: true);

        var sidText = string.Empty;
        if (revoked.Sid is { Length: > 0 } sidBytes)
        {
            sidText = Encoding.ASCII.GetString(sidBytes);
        }

        var revokedSubject = Encoding.ASCII.GetString(revoked.Subject);
        SendErr($"Permissions Violation for Subscription to \"{revokedSubject}\" (sid \"{sidText}\")");
        Noticef("Removed sub \"{0}\" (sid \"{1}\") for \"{2}\" - not authorized",
            revokedSubject, sidText, GetAuthUser());
    }
}
/// <summary>
/// Reconnect hook. Currently it only checks, under the lock, that reconnecting is permitted
/// (the NoReconnect flag is clear and a server is attached) and then does nothing further.
/// </summary>
internal void Reconnect()
{
    lock (_mu)
    {
        var permitted = !Flags.IsSet(ClientFlags.NoReconnect) && Server is not null;
        if (!permitted)
        {
            return;
        }
    }

    // Route/gateway/leaf reconnect orchestration is owned by server sessions.
}
/// <summary>
/// Resolves the account and sublist match result for the message currently described by
/// <c>ParseCtx.Pa</c>, using the per-account cache on this connection when possible.
/// </summary>
/// <remarks>
/// A cached entry is reused only while its stored generation id equals the account sublist's
/// current <c>GenId()</c>. On a miss the account is resolved (the connection's own account for
/// routers carrying an account name, otherwise a server lookup), the subject is matched, and the
/// result is stored. When the cache has reached <c>MaxPerAccountCacheSize</c>, just enough of the
/// first-enumerated entries are evicted to bring it below the limit before the new entry is added.
/// </remarks>
/// <returns>
/// The account and match result, or <c>(null, null)</c> when there is no subject, or no concrete
/// account with a sublist could be resolved.
/// </returns>
internal (INatsAccount? Account, SubscriptionIndexResult? Result) GetAccAndResultFromCache()
{
    var pa = ParseCtx.Pa;
    if (pa.Subject is null || pa.Subject.Length == 0)
    {
        return (null, null);
    }

    _in.PaCache ??= new Dictionary<string, PerAccountCache>(StringComparer.Ordinal);
    var cache = _in.PaCache;

    // Prefer the dedicated cache key bytes; fall back to the subject when none were parsed.
    var cacheKeyBytes = pa.PaCache is { Length: > 0 } k ? k : pa.Subject;
    var cacheKey = Encoding.ASCII.GetString(cacheKeyBytes);

    if (cache.TryGetValue(cacheKey, out var cached) &&
        cached.Acc is Account cachedAcc &&
        cached.Results is not null &&
        cachedAcc.Sublist is not null &&
        cached.GenId == (ulong)cachedAcc.Sublist.GenId())
    {
        return (cached.Acc, cached.Results);
    }

    INatsAccount? acc = null;
    if (Kind == ClientKind.Router && pa.Account is { Length: > 0 } && _account is not null)
    {
        acc = _account;
    }
    else if (Server is NatsServer server && pa.Account is { Length: > 0 } accountNameBytes)
    {
        var accountName = Encoding.ASCII.GetString(accountNameBytes);
        (acc, _) = server.LookupAccount(accountName);
    }

    if (acc is not Account concreteAcc || concreteAcc.Sublist is null)
    {
        return (null, null);
    }

    var result = concreteAcc.Sublist.MatchBytes(pa.Subject);

    if (cache.Count >= MaxPerAccountCacheSize)
    {
        // Snapshot only the keys that will be evicted. Copying the whole key set on every miss
        // once the cache is full would cost O(capacity) to drop what is normally a single entry.
        var excess = cache.Count - MaxPerAccountCacheSize + 1;
        foreach (var evictKey in cache.Keys.Take(excess).ToArray())
        {
            cache.Remove(evictKey);
        }
    }

    // Store after pruning so the fresh entry is never the one evicted.
    cache[cacheKey] = new PerAccountCache
    {
        Acc = concreteAcc,
        Results = result,
        GenId = (ulong)concreteAcc.Sublist.GenId(),
    };

    return (concreteAcc, result);
}
/// <summary>
/// Drops every per-account cache entry that has no result, or whose result references at least
/// one closed subscription (plain or queue).
/// </summary>
internal void PruneClosedSubFromPerAccountCache()
{
    var cache = _in.PaCache;
    if (cache is null || cache.Count == 0)
    {
        return;
    }

    // Iterate over a snapshot of the keys so entries can be removed while walking.
    foreach (var cacheKey in cache.Keys.ToArray())
    {
        var hits = cache[cacheKey].Results;

        if (hits is not null && !hits.PSubs.Any(static s => s.IsClosed()))
        {
            var closedInQueueGroup = false;
            foreach (var group in hits.QSubs)
            {
                if (group.Any(static s => s.IsClosed()))
                {
                    closedInQueueGroup = true;
                    break;
                }
            }

            if (!closedInQueueGroup)
            {
                // Entry is still fully live; keep it.
                continue;
            }
        }

        cache.Remove(cacheKey);
    }
}
/// <summary>
/// Fills in the server name (for non-leaf connections) and the cached cluster name on
/// <paramref name="ci"/>. Does nothing when <paramref name="ci"/> is null or the attached server
/// is not a <c>NatsServer</c>.
/// </summary>
/// <param name="ci">Client info to populate; may be null.</param>
internal void AddServerAndClusterInfo(ClientInfo? ci)
{
    if (ci is null || Server is not NatsServer natsServer)
    {
        return;
    }

    // Leaf connections keep whatever server value the info already carries.
    if (Kind != ClientKind.Leaf)
    {
        ci.Server = natsServer.Name();
    }

    var clusterName = natsServer.CachedClusterName();
    if (string.IsNullOrWhiteSpace(clusterName))
    {
        return;
    }

    ci.Cluster = [clusterName];
}
/// <summary>
/// Builds a <c>ClientInfo</c> snapshot for this connection.
/// </summary>
/// <param name="detailed">
/// When false only the account name and RTT are populated; when true server/cluster details and
/// the connection's identity and option fields are included as well.
/// </param>
/// <returns>
/// The populated info, or null when the connection kind is not Client, Leaf, JetStream or Account.
/// </returns>
internal ClientInfo? GetClientInfo(bool detailed)
{
    var supportedKind = Kind is ClientKind.Client or ClientKind.Leaf or ClientKind.JetStream or ClientKind.Account;
    if (!supportedKind)
    {
        return null;
    }

    var info = new ClientInfo();

    // Server and cluster details are added before taking the connection lock.
    if (detailed)
    {
        AddServerAndClusterInfo(info);
    }

    lock (_mu)
    {
        // Account and RTT are always reported.
        info.Account = _account?.Name ?? string.Empty;
        info.Rtt = Rtt;

        if (detailed)
        {
            info.Start = Start == default ? string.Empty : Start.ToString("O");
            info.Host = Host;
            info.Id = Cid;
            info.Name = Opts.Name;
            info.User = GetRawAuthUser();
            info.Lang = Opts.Lang;
            info.Version = Opts.Version;
            info.Jwt = Opts.Jwt;
            info.NameTag = NameTag;
            info.Kind = KindString();
            info.ClientType = ClientTypeString();
        }
    }

    return info;
}
/// <summary>
/// Runs the accepting (server) side of a TLS handshake by delegating to
/// <c>DoTLSHandshake</c> with <c>solicit: false</c>.
/// </summary>
/// <param name="typ">Connection type label used in log messages.</param>
/// <param name="tlsConfig">Server authentication options for the handshake.</param>
/// <param name="timeout">Handshake timeout in seconds.</param>
/// <param name="pinnedCerts">Optional pinned certificate set checked after the handshake.</param>
/// <returns>The handshake error, or null on success.</returns>
internal Exception? DoTLSServerHandshake(
    string typ,
    SslServerAuthenticationOptions tlsConfig,
    double timeout,
    PinnedCertSet? pinnedCerts)
{
    var outcome = DoTLSHandshake(
        typ,
        solicit: false,
        url: null,
        serverTlsConfig: tlsConfig,
        clientTlsConfig: null,
        tlsName: string.Empty,
        timeout: timeout,
        pinnedCerts: pinnedCerts);

    return outcome.err;
}
/// <summary>
/// Runs the soliciting (client) side of a TLS handshake by delegating to
/// <c>DoTLSHandshake</c> with <c>solicit: true</c>.
/// </summary>
/// <param name="typ">Connection type label used in log messages.</param>
/// <param name="url">Remote URL; its host may be used as the TLS target host.</param>
/// <param name="tlsConfig">Client authentication options for the handshake.</param>
/// <param name="tlsName">Preferred TLS target host name, if any.</param>
/// <param name="timeout">Handshake timeout in seconds.</param>
/// <param name="pinnedCerts">Optional pinned certificate set checked after the handshake.</param>
/// <returns>The tuple produced by <c>DoTLSHandshake</c>, passed through unchanged.</returns>
internal (bool resetTlsName, Exception? err) DoTLSClientHandshake(
    string typ,
    Uri? url,
    SslClientAuthenticationOptions tlsConfig,
    string tlsName,
    double timeout,
    PinnedCertSet? pinnedCerts)
    => DoTLSHandshake(
        typ,
        solicit: true,
        url: url,
        serverTlsConfig: null,
        clientTlsConfig: tlsConfig,
        tlsName: tlsName,
        timeout: timeout,
        pinnedCerts: pinnedCerts);
/// <summary>
/// Wraps the connection stream in an <see cref="SslStream"/> and performs a TLS handshake,
/// as client when <paramref name="solicit"/> is true and as server otherwise. The call blocks
/// until the handshake completes, fails, or is cancelled by the timeout.
/// </summary>
/// <param name="typ">Connection type label used in log messages.</param>
/// <param name="solicit">True to authenticate as client, false to authenticate as server.</param>
/// <param name="url">Remote URL; its host is the fallback TLS target host on the client side.</param>
/// <param name="serverTlsConfig">Server options; a default instance is used when null.</param>
/// <param name="clientTlsConfig">Client options; a default instance is used when null. Its TargetHost is assigned when blank.</param>
/// <param name="tlsName">Preferred TLS target host; used instead of the URL host when not blank.</param>
/// <param name="timeout">Timeout in seconds; values that are not positive arm no cancellation timer.</param>
/// <param name="pinnedCerts">When non-empty, the handshake fails unless MatchesPinnedCert accepts it.</param>
/// <returns>
/// <c>(false, null)</c> on success. <c>(false, ErrConnectionClosed)</c> when there is no stream or the
/// connection was closed by the time the handshake finished. On any handshake failure the connection
/// is closed and <c>(resetTlsName, ErrConnectionClosed)</c> is returned.
/// </returns>
internal (bool resetTlsName, Exception? err) DoTLSHandshake(
string typ,
bool solicit,
Uri? url,
SslServerAuthenticationOptions? serverTlsConfig,
SslClientAuthenticationOptions? clientTlsConfig,
string tlsName,
double timeout,
PinnedCertSet? pinnedCerts)
{
if (_nc is null)
return (false, ServerErrors.ErrConnectionClosed);
// Capture the kind once; it selects the wording of debug/error messages below.
var kind = Kind;
var resetTlsName = false;
Exception? err = null;
// NOTE(review): ssl is only used inside the try block; the outer declaration looks unnecessary.
SslStream? ssl = null;
try
{
var baseStream = _nc;
if (solicit)
{
Debugf("Starting TLS {0} client handshake", typ);
var options = clientTlsConfig ?? new SslClientAuthenticationOptions();
// Only fill in TargetHost when the caller left it blank: tlsName wins over the URL host.
// This writes to the caller-supplied options instance.
if (string.IsNullOrWhiteSpace(options.TargetHost))
{
var host = url?.Host ?? string.Empty;
options.TargetHost = !string.IsNullOrWhiteSpace(tlsName) ? tlsName : host;
}
// The TLS stream replaces _nc before the handshake starts, so the failure path's
// CloseConnection runs with _nc already pointing at the TLS stream.
ssl = new SslStream(baseStream, leaveInnerStreamOpen: false);
_nc = ssl;
// With a non-positive timeout the token source is never cancelled by a timer.
using var cts = timeout > 0
? new CancellationTokenSource(TimeSpan.FromSeconds(timeout))
: new CancellationTokenSource();
// Synchronously waits on the async handshake.
ssl.AuthenticateAsClientAsync(options, cts.Token).GetAwaiter().GetResult();
}
else
{
// NOTE(review): typ is passed even when the client-kind message has no {0} placeholder.
Debugf(kind == ClientKind.Client
? "Starting TLS client connection handshake"
: "Starting TLS {0} server handshake", typ);
ssl = new SslStream(baseStream, leaveInnerStreamOpen: false);
_nc = ssl;
using var cts = timeout > 0
? new CancellationTokenSource(TimeSpan.FromSeconds(timeout))
: new CancellationTokenSource();
ssl.AuthenticateAsServerAsync(serverTlsConfig ?? new SslServerAuthenticationOptions(), cts.Token)
.GetAwaiter()
.GetResult();
}
// Handshake succeeded; enforce certificate pinning when a non-empty pin set was supplied.
if (pinnedCerts is { Count: > 0 } && !MatchesPinnedCert(pinnedCerts))
err = new InvalidOperationException("certificate not pinned");
}
catch (AuthenticationException authEx)
{
// Ask the caller to reset its TLS name when a solicited handshake failed authentication
// and the URL host equals tlsName (case-insensitive).
// NOTE(review): this compares url.Host with tlsName, not the TargetHost that was actually
// used above, and it triggers on any AuthenticationException — confirm this is intended.
if (solicit && !string.IsNullOrWhiteSpace(tlsName) && url is not null &&
string.Equals(url.Host, tlsName, StringComparison.OrdinalIgnoreCase))
{
resetTlsName = true;
}
err = authEx;
}
catch (OperationCanceledException)
{
// Cancellation of the handshake token is reported as a timeout.
err = new TimeoutException("TLS handshake timeout");
}
catch (Exception ex)
{
err = ex;
}
if (err is null)
{
lock (_mu)
{
// The flag is set before the closed check, so it is set even if the connection was
// closed while the handshake was in progress.
Flags = Flags.Set(ClientFlags.HandshakeComplete);
if (IsClosed())
return (false, ServerErrors.ErrConnectionClosed);
}
return (false, null);
}
// Failure path: log, close the connection, and report it as closed to the caller.
if (kind == ClientKind.Client)
Errorf("TLS handshake error: {0}", err.Message);
else
Errorf("TLS {0} handshake error: {1}", typ, err.Message);
CloseConnection(ClosedState.TlsHandshakeError);
return (resetTlsName, ServerErrors.ErrConnectionClosed);
}
/// <summary>
/// Upper-cases each supplied connection type and splits the values into known and unknown ones.
/// </summary>
/// <param name="cts">Connection type names to validate.</param>
/// <returns>
/// The set of known (upper-cased) types, plus an <see cref="ArgumentException"/> listing the
/// unknown ones when any were found; the error is null otherwise.
/// </returns>
internal static (HashSet<string> Allowed, Exception? Error) ConvertAllowedConnectionTypes(IEnumerable<string> cts)
{
    var accepted = new HashSet<string>(StringComparer.Ordinal);
    var rejected = new List<string>();

    foreach (var raw in cts)
    {
        var normalized = raw.ToUpperInvariant();
        if (!AuthHandler.ConnectionTypes.IsKnown(normalized))
        {
            rejected.Add(normalized);
            continue;
        }

        accepted.Add(normalized);
    }

    Exception? error = null;
    if (rejected.Count != 0)
    {
        error = new ArgumentException($"invalid connection types \"{string.Join(",", rejected)}\"");
    }

    // The known types are returned even when an error is reported.
    return (accepted, error);
}
/// <summary>
/// Logs an error once per distinct formatted statement; repeats of a statement already recorded
/// by <c>TryMarkRateLimited</c> are dropped. Does nothing when no server is attached.
/// </summary>
/// <param name="format">Composite format string.</param>
/// <param name="args">Format arguments.</param>
internal void RateLimitErrorf(string format, params object?[] args)
{
    if (Server is null)
    {
        return;
    }

    // The rate-limit key is the fully formatted statement, before any client decoration.
    var message = string.Format(format, args);
    var firstOccurrence = TryMarkRateLimited("ERR:" + message);
    if (!firstOccurrence)
    {
        return;
    }

    var clientSuffix = FormatClientSuffix();
    var identity = String();
    if (string.IsNullOrWhiteSpace(identity))
    {
        Errorf("{0}{1}", message, clientSuffix);
        return;
    }

    Errorf("{0} - {1}{2}", identity, message, clientSuffix);
}
/// <summary>
/// Logs a warning once per distinct <paramref name="format"/> string: the rate-limit key is the
/// unformatted format, so differing arguments do not produce additional log lines.
/// Does nothing when no server is attached.
/// </summary>
/// <param name="format">Composite format string; also used as the rate-limit key.</param>
/// <param name="args">Format arguments.</param>
internal void RateLimitFormatWarnf(string format, params object?[] args)
{
    if (Server is null)
    {
        return;
    }

    var firstOccurrence = TryMarkRateLimited("WARN_FMT:" + format);
    if (!firstOccurrence)
    {
        return;
    }

    // Formatting is deferred until the statement is known to be logged.
    var message = string.Format(format, args);
    var clientSuffix = FormatClientSuffix();
    var identity = String();
    if (string.IsNullOrWhiteSpace(identity))
    {
        Warnf("{0}{1}", message, clientSuffix);
        return;
    }

    Warnf("{0} - {1}{2}", identity, message, clientSuffix);
}
/// <summary>
/// Logs a warning once per distinct formatted statement; repeats of a statement already recorded
/// by <c>TryMarkRateLimited</c> are dropped. Does nothing when no server is attached.
/// </summary>
/// <param name="format">Composite format string.</param>
/// <param name="args">Format arguments.</param>
internal void RateLimitWarnf(string format, params object?[] args)
{
    if (Server is null)
    {
        return;
    }

    // The rate-limit key is the fully formatted statement, before any client decoration.
    var message = string.Format(format, args);
    var firstOccurrence = TryMarkRateLimited("WARN:" + message);
    if (!firstOccurrence)
    {
        return;
    }

    var clientSuffix = FormatClientSuffix();
    var identity = String();
    if (string.IsNullOrWhiteSpace(identity))
    {
        Warnf("{0}{1}", message, clientSuffix);
        return;
    }

    Warnf("{0} - {1}{2}", identity, message, clientSuffix);
}
/// <summary>
/// Logs a debug line once per distinct formatted statement; repeats of a statement already
/// recorded by <c>TryMarkRateLimited</c> are dropped. Does nothing when no server is attached.
/// </summary>
/// <param name="format">Composite format string.</param>
/// <param name="args">Format arguments.</param>
internal void RateLimitDebugf(string format, params object?[] args)
{
    if (Server is null)
    {
        return;
    }

    // The rate-limit key is the fully formatted statement, before any client decoration.
    var message = string.Format(format, args);
    var firstOccurrence = TryMarkRateLimited("DBG:" + message);
    if (!firstOccurrence)
    {
        return;
    }

    var clientSuffix = FormatClientSuffix();
    var identity = String();
    if (string.IsNullOrWhiteSpace(identity))
    {
        Debugf("{0}{1}", message, clientSuffix);
        return;
    }

    Debugf("{0} - {1}{2}", identity, message, clientSuffix);
}
/// <summary>
/// Schedules the first ping for this connection as a one-shot timer.
/// </summary>
/// <remarks>
/// The delay starts from the server ping interval, is overridden by the cluster interval for
/// routers and by the websocket interval for websocket connections (when those are positive),
/// is capped for the first ping unless short first pings are disabled, and finally gets a random
/// extra of up to one fifth of its length. Any existing ping timer is cleared first.
/// </remarks>
internal void SetFirstPingTimer()
{
    var opts = Server?.Options;
    if (opts is null)
    {
        return;
    }

    var delay = opts.PingInterval;
    if (Kind == ClientKind.Router && opts.Cluster.PingInterval > TimeSpan.Zero)
    {
        delay = opts.Cluster.PingInterval;
    }

    if (IsWebSocket() && opts.Websocket.PingInterval > TimeSpan.Zero)
    {
        delay = opts.Websocket.PingInterval;
    }

    if (!opts.DisableShortFirstPing)
    {
        if (Kind == ClientKind.Client)
        {
            if (delay > FirstClientPingInterval)
            {
                delay = FirstClientPingInterval;
            }
        }
        else
        {
            if (delay > FirstPingInterval)
            {
                delay = FirstPingInterval;
            }

            delay = AdjustPingInterval(Kind, delay);
        }
    }

    // Random offset in [0, delay / 5); skipped entirely for non-positive delays.
    var jitterTicks = 0L;
    if (delay.Ticks > 0)
    {
        jitterTicks = Random.Shared.NextInt64(Math.Max(1, delay.Ticks / 5));
    }

    delay += TimeSpan.FromTicks(jitterTicks);

    ClearPingTimer();
    _pingTimer = new Timer(_ => ProcessPingTimer(), null, delay, Timeout.InfiniteTimeSpan);
}
/// <summary>
/// Records <paramref name="key"/> in the rate-limit dictionary associated with the attached
/// server (or with this connection when no server is attached).
/// </summary>
/// <param name="key">Rate-limit key to record.</param>
/// <returns>True when the key was newly added; false when it was already present.</returns>
private bool TryMarkRateLimited(string key)
{
    object? owner = Server;
    owner ??= this;

    var seen = RateLimitCacheByServer.GetOrCreateValue(owner);
    var added = seen.TryAdd(key, DateTime.UtcNow);
    return added;
}
}