diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs new file mode 100644 index 0000000..dfa891b --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs @@ -0,0 +1,1211 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/client.go in the NATS server Go source. + +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Microsoft.Extensions.Logging; +using ZB.MOM.NatsNet.Server.Auth; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; +using ZB.MOM.NatsNet.Server.Protocol; + +namespace ZB.MOM.NatsNet.Server; + +// Wire protocol constants (also in ServerConstants; kept here for local use). +file static class Wires +{ + internal const string PingProto = "PING\r\n"; + internal const string PongProto = "PONG\r\n"; + internal const string ErrProto = "-ERR '{0}'\r\n"; + internal const string OkProto = "+OK\r\n"; + internal const string MsgHead = "RMSG "; + internal const int MsgScratch = 1024; + + // Buffer size tuning. + internal const int StartBufSize = 512; + internal const int MinBufSize = 64; + internal const int MaxBufSize = 65536; + internal const int ShortsToShrink = 2; + internal const int MaxFlushPending = 10; + internal const int MaxVectorSize = 1024; // IOV_MAX + + internal static readonly TimeSpan ReadLoopReport = TimeSpan.FromSeconds(2); + internal static readonly TimeSpan MaxNoRttPingBeforePong = TimeSpan.FromSeconds(2); + internal static readonly TimeSpan StallMin = TimeSpan.FromMilliseconds(2); + internal static readonly TimeSpan StallMax = TimeSpan.FromMilliseconds(5); + internal static readonly TimeSpan StallTotal = TimeSpan.FromMilliseconds(10); + + // Cache / pruning limits. + internal const int MaxResultCacheSize = 512; + internal const int MaxDenyPermCacheSize = 256; + internal const int MaxPermCacheSize = 128; + internal const int PruneSize = 32; + internal const int RouteTargetInit = 8; + internal const int ReplyPermLimit = 4096; + internal static readonly TimeSpan ReplyPruneTime = TimeSpan.FromSeconds(1); + + // Per-account cache defaults. + internal const int MaxPerAccountCacheSize = 8192; + internal static readonly TimeSpan ClosedSubsCheckInterval = TimeSpan.FromMinutes(5); + + // TLS handshake client type tags. + internal const string TlsHandshakeLeaf = "leafnode"; + internal const string TlsHandshakeMqtt = "mqtt"; + + // Allowed-connection-type group used in deny-list checks. + internal const string SysGroup = "_sys_"; + + // Message header status line bytes (UTF-8, immutable). + internal static readonly byte[] HdrLineBytes = Encoding.ASCII.GetBytes(NatsHeaderConstants.HdrLine); + internal static readonly byte[] EmptyHdrLineBytes = Encoding.ASCII.GetBytes(NatsHeaderConstants.EmptyHdrLine); +} + +/// +/// Represents an individual client connection to the NATS server. +/// Mirrors Go client struct and all its methods from server/client.go. +/// +/// +/// This is the central networking class — every connected client (NATS, MQTT, WebSocket, +/// route, gateway, leaf node, or internal) has one instance. +/// +public sealed partial class ClientConnection +{ + // ========================================================================= + // Fields — mirrors Go client struct + // ========================================================================= + + private readonly Lock _mu = new(); // mirrors c.mu sync.Mutex + + // Connection kind and server references. + internal ClientKind Kind; // mirrors c.kind + internal INatsServer? Server; // mirrors c.srv + internal INatsAccount? Account; // mirrors c.acc + internal ClientPermissions? Perms; // mirrors c.perms + internal MsgDeny? MPerms; // mirrors c.mperms + + // Connection identity. + internal ulong Cid; // mirrors c.cid + internal byte[]? Nonce; // mirrors c.nonce + internal string PubKey = string.Empty; // mirrors c.pubKey + internal string Host = string.Empty; // mirrors c.host + internal ushort Port; // mirrors c.port + internal string NameTag = string.Empty; // mirrors c.nameTag + internal string ProxyKey = string.Empty; // mirrors c.proxyKey + + // Client options (from CONNECT message). + internal ClientOptions Opts = ClientOptions.Default; + + // Flags and state. + internal ClientFlags Flags; // mirrors c.flags clientFlag + internal bool Trace; // mirrors c.trace + internal bool Echo = true; // mirrors c.echo + internal bool NoIcb; // mirrors c.noIcb + internal bool InProc; // mirrors c.iproc (in-process connection) + internal bool Headers; // mirrors c.headers + + // Limits (int32 allows atomic access). + private int _mpay; // mirrors c.mpay — max payload (signed, jwt.NoLimit = -1) + private int _msubs; // mirrors c.msubs — max subscriptions + private int _mcl; // mirrors c.mcl — max control line + + // Subscriptions. + internal Dictionary Subs = new(StringComparer.Ordinal); + internal Dictionary? Replies; + internal Dictionary? Pcd; // pending clients with data to flush + internal Dictionary? DArray; // denied subscribe patterns + + // Outbound state (simplified — full write loop ported when Server is available). + internal long OutPb; // pending bytes + internal long OutMp; // max pending snapshot + internal TimeSpan OutWdl; // write deadline snapshot + + // Timing. + internal DateTime Start; + internal DateTime Last; + internal DateTime LastIn; + internal DateTime Expires; + internal TimeSpan Rtt; + internal DateTime RttStart; + internal DateTime LastReplyPrune; + internal ushort RepliesSincePrune; + + // Scratch buffer for processMsg calls. + // Initialised with "RMSG " bytes. + internal byte[] Msgb = new byte[Wires.MsgScratch]; + + // Auth error override. + internal Exception? AuthErr; + + // Network connection (null for in-process). + private Stream? _nc; + private string _ncs = string.Empty; // cached string representation (mirrors c.ncs atomic.Value) + + // Parse state (shared with ProtocolParser). + internal ParseContext ParseCtx = new(); + + // Remote reply tracking. + private RrTracking? _rrTracking; + + // Timers. + private Timer? _atmr; // auth timer + private Timer? _pingTimer; + private Timer? _tlsTo; + + // Ping state. + private int _pingOut; // outstanding pings + + // Connection string (cached for logging). + private string _connStr = string.Empty; + + // Read cache (per-read-loop state). + private ReadCacheState _in; + + // ========================================================================= + // Constructor + // ========================================================================= + + /// + /// Creates a new client connection. + /// Callers should invoke after creation. + /// + public ClientConnection(ClientKind kind, INatsServer? server = null, Stream? nc = null) + { + Kind = kind; + Server = server; + _nc = nc; + + // Initialise scratch buffer with "RMSG " bytes. + Msgb[0] = (byte)'R'; Msgb[1] = (byte)'M'; + Msgb[2] = (byte)'S'; Msgb[3] = (byte)'G'; + Msgb[4] = (byte)' '; + } + + // ========================================================================= + // String / identity (features 398-400) + // ========================================================================= + + /// + /// Returns the cached connection string identifier. + /// Mirrors Go client.String(). + /// + public override string ToString() => _ncs; + + /// + /// Returns the nonce presented to the client during connection. + /// Mirrors Go client.GetNonce(). + /// + public byte[]? GetNonce() + { + lock (_mu) { return Nonce; } + } + + /// + /// Returns the application-supplied name for this connection. + /// Mirrors Go client.GetName(). + /// + public string GetName() + { + lock (_mu) { return Opts.Name; } + } + + /// Returns the client options. Mirrors Go client.GetOpts(). + public ClientOptions GetOpts() => Opts; + + // ========================================================================= + // TLS (feature 402) + // ========================================================================= + + /// + /// Returns TLS connection state if the connection is TLS-secured, otherwise null. + /// Mirrors Go client.GetTLSConnectionState(). + /// + public SslStream? GetTlsStream() + { + lock (_mu) { return _nc as SslStream; } + } + + // ========================================================================= + // Client type classification (features 403-404) + // ========================================================================= + + /// + /// Returns the extended client type for CLIENT-kind connections. + /// Mirrors Go client.clientType(). + /// + public ClientConnectionType ClientType() + { + if (Kind != ClientKind.Client) return ClientConnectionType.NonClient; + if (IsMqtt()) return ClientConnectionType.Mqtt; + if (IsWebSocket()) return ClientConnectionType.WebSocket; + return ClientConnectionType.Nats; + } + + private static readonly Dictionary ClientTypeStringMap = new() + { + [ClientConnectionType.NonClient] = string.Empty, + [ClientConnectionType.Nats] = "nats", + [ClientConnectionType.WebSocket] = "websocket", + [ClientConnectionType.Mqtt] = "mqtt", + }; + + internal string ClientTypeString() => + ClientTypeStringMap.TryGetValue(ClientType(), out var s) ? s : string.Empty; + + // ========================================================================= + // Subscription.close / isClosed (features 405-406) + // (These are on the Subscription type; see Internal/Subscription.cs) + // ========================================================================= + + // ========================================================================= + // Trace level (feature 407) + // ========================================================================= + + /// + /// Updates the trace flag based on server logging settings. + /// Mirrors Go client.setTraceLevel(). + /// + internal void SetTraceLevel() + { + if (Server is null) { Trace = false; return; } + Trace = Kind == ClientKind.System + ? Server.TraceSysAcc + : Server.TraceEnabled; + } + + // ========================================================================= + // initClient (feature 408) + // ========================================================================= + + /// + /// Initialises connection state after the client struct is created. + /// Must be called with _mu held. + /// Mirrors Go client.initClient(). + /// + internal void InitClient() + { + if (Server is not null) + Cid = Server.NextClientId(); + + // Snapshot options from server. + if (Server is not null) + { + var opts = Server.Options; + OutWdl = opts.WriteDeadline; + OutMp = opts.MaxPending; + _mcl = opts.MaxControlLine > 0 ? opts.MaxControlLine : ServerConstants.MaxControlLineSize; + } + else + { + _mcl = ServerConstants.MaxControlLineSize; + } + + Subs = new Dictionary(StringComparer.Ordinal); + Pcd = new Dictionary(); + Echo = true; + + SetTraceLevel(); + + // Scratch buffer "RMSG " prefix. + Msgb[0] = (byte)'R'; Msgb[1] = (byte)'M'; + Msgb[2] = (byte)'S'; Msgb[3] = (byte)'G'; + Msgb[4] = (byte)' '; + + // Snapshot connection string. + if (_nc is not null) + { + var addr = GetRemoteEndPoint(); + if (addr is not null) + { + var conn = addr.ToString() ?? string.Empty; + if (conn.Length > 0) + { + var parts = conn.Split(':', 2); + if (parts.Length == 2) + { + Host = parts[0]; + if (ushort.TryParse(parts[1], out var p)) Port = p; + } + _connStr = conn.Replace("%", "%%"); + } + } + } + + _ncs = Kind switch + { + ClientKind.Client when ClientType() == ClientConnectionType.Nats => + $"{_connStr} - cid:{Cid}", + ClientKind.Client when ClientType() == ClientConnectionType.WebSocket => + $"{_connStr} - wid:{Cid}", + ClientKind.Client => + $"{_connStr} - mid:{Cid}", + ClientKind.Router => $"{_connStr} - rid:{Cid}", + ClientKind.Gateway => $"{_connStr} - gid:{Cid}", + ClientKind.Leaf => $"{_connStr} - lid:{Cid}", + ClientKind.System => "SYSTEM", + ClientKind.JetStream => "JETSTREAM", + ClientKind.Account => "ACCOUNT", + _ => _connStr, + }; + } + + // ========================================================================= + // RemoteAddress (feature 409) + // ========================================================================= + + /// + /// Returns the remote network address of the connection, or null. + /// Mirrors Go client.RemoteAddress(). + /// + public EndPoint? RemoteAddress() + { + lock (_mu) { return GetRemoteEndPoint(); } + } + + private EndPoint? GetRemoteEndPoint() + { + if (_nc is NetworkStream ns) + { + try { return ns.Socket.RemoteEndPoint; } + catch { return null; } + } + return null; + } + + // ========================================================================= + // Account registration (features 410-417) + // ========================================================================= + + /// + /// Reports an error when registering with an account. + /// Mirrors Go client.reportErrRegisterAccount(). + /// + internal void ReportErrRegisterAccount(INatsAccount acc, Exception err) + { + if (err is TooManyAccountConnectionsException) + { + MaxAccountConnExceeded(); + return; + } + Errorf("Problem registering with account %q: %s", acc.Name, err.Message); + SendErr("Failed Account Registration"); + } + + /// + /// Returns the client kind. Mirrors Go client.Kind(). + /// + public ClientKind GetKind() + { + lock (_mu) { return Kind; } + } + + /// + /// Registers this client with an account. + /// Mirrors Go client.registerWithAccount(). + /// + internal void RegisterWithAccount(INatsAccount acc) + { + if (acc is null) throw new BadAccountException(); + if (!acc.IsValid) throw new BadAccountException(); + + // Deregister from previous account. + if (Account is not null) + { + var prev = Account.RemoveClient(this); + if (prev == 1) Server?.DecActiveAccounts(); + } + + lock (_mu) + { + Account = acc; + ApplyAccountLimits(); + } + + // Check max connection limits. + if (Kind == ClientKind.Client && acc.MaxTotalConnectionsReached()) + throw new TooManyAccountConnectionsException(); + + if (Kind == ClientKind.Leaf && acc.MaxTotalLeafNodesReached()) + throw new TooManyAccountConnectionsException(); + + // Add to new account. + var added = acc.AddClient(this); + if (added == 0) Server?.IncActiveAccounts(); + } + + /// + /// Returns true if the subscription limit has been reached. + /// Mirrors Go client.subsAtLimit(). + /// + internal bool SubsAtLimit() => + _msubs != JwtNoLimit && Subs.Count >= _msubs; + + // JwtNoLimit mirrors jwt.NoLimit in Go (-1 cast to int32). + private const int JwtNoLimit = -1; + + /// + /// Atomically applies the minimum of two int32 limits. + /// Mirrors Go minLimit. + /// + private static bool MinLimit(ref int value, int limit) + { + int v = Volatile.Read(ref value); + if (v != JwtNoLimit) + { + if (limit != JwtNoLimit && limit < v) + { + Volatile.Write(ref value, limit); + return true; + } + } + else if (limit != JwtNoLimit) + { + Volatile.Write(ref value, limit); + return true; + } + return false; + } + + /// + /// Applies account-level connection limits to this client. + /// Lock is held on entry. + /// Mirrors Go client.applyAccountLimits(). + /// + internal void ApplyAccountLimits() + { + if (Account is null || (Kind != ClientKind.Client && Kind != ClientKind.Leaf)) + return; + + Volatile.Write(ref _mpay, JwtNoLimit); + _msubs = JwtNoLimit; + + // Apply server-level limits. + if (Server is not null) + { + var sOpts = Server.Options; + int mPay = sOpts.MaxPayload == 0 ? JwtNoLimit : sOpts.MaxPayload; + int mSubs = sOpts.MaxSubs == 0 ? JwtNoLimit : sOpts.MaxSubs; + MinLimit(ref _mpay, mPay); + MinLimit(ref _msubs, mSubs); + } + + if (SubsAtLimit()) + Task.Run(() => + { + MaxSubsExceeded(); + Task.Delay(20).Wait(); + CloseConnection(ClosedState.MaxSubscriptionsExceeded); + }); + } + + // ========================================================================= + // RegisterUser / RegisterNkeyUser (features 416-417) + // ========================================================================= + + /// + /// Registers an authenticated user with this connection. + /// Mirrors Go client.RegisterUser(). + /// + public void RegisterUser(User user) + { + if (user.Account is INatsAccount acc) + { + try { RegisterWithAccount(acc); } + catch (Exception ex) { ReportErrRegisterAccount(acc, ex); return; } + } + + lock (_mu) + { + Perms = user.Permissions is not null ? BuildPermissions(user.Permissions) : null; + MPerms = null; + if (user.Username.Length > 0) + Opts.Username = user.Username; + if (user.ConnectionDeadline != default) + SetExpirationTimerUnlocked(user.ConnectionDeadline - DateTime.UtcNow); + } + } + + /// + /// Registers an NKey-authenticated user. + /// Mirrors Go client.RegisterNkeyUser(). + /// + public void RegisterNkeyUser(NkeyUser user) + { + if (user.Account is INatsAccount acc) + { + try { RegisterWithAccount(acc); } + catch (Exception ex) { ReportErrRegisterAccount(acc, ex); return; } + } + + lock (_mu) + { + Perms = user.Permissions is not null ? BuildPermissions(user.Permissions) : null; + MPerms = null; + } + } + + // ========================================================================= + // splitSubjectQueue (feature 418) + // ========================================================================= + + /// + /// Splits a "subject [queue]" string into subject and optional queue bytes. + /// Mirrors Go splitSubjectQueue. + /// + public static (byte[] subject, byte[]? queue) SplitSubjectQueue(string sq) + { + var vals = sq.Trim().Split((char[]?)null, StringSplitOptions.RemoveEmptyEntries); + if (vals.Length == 0) + throw new ArgumentException($"invalid subject-queue \"{sq}\""); + + var subject = Encoding.ASCII.GetBytes(vals[0]); + byte[]? queue = null; + + if (vals.Length == 2) + queue = Encoding.ASCII.GetBytes(vals[1]); + else if (vals.Length > 2) + throw new FormatException($"invalid subject-queue \"{sq}\""); + + return (subject, queue); + } + + // ========================================================================= + // setPermissions / publicPermissions / mergeDenyPermissions (features 419-422) + // ========================================================================= + + private ClientPermissions BuildPermissions(Permissions perms) + { + var cp = new ClientPermissions(); + + if (perms.Publish is not null) + { + if (perms.Publish.Allow is { Count: > 0 }) + { + cp.Pub.Allow = SubscriptionIndex.NewSublistWithCache(); + foreach (var s in perms.Publish.Allow) + cp.Pub.Allow.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(s) }); + } + if (perms.Publish.Deny is { Count: > 0 }) + { + cp.Pub.Deny = SubscriptionIndex.NewSublistWithCache(); + foreach (var s in perms.Publish.Deny) + cp.Pub.Deny.Insert(new Subscription { Subject = Encoding.ASCII.GetBytes(s) }); + } + } + + if (perms.Response is not null) + { + cp.Resp = perms.Response; + Replies = new Dictionary(StringComparer.Ordinal); + } + + if (perms.Subscribe is not null) + { + if (perms.Subscribe.Allow is { Count: > 0 }) + { + cp.Sub.Allow = SubscriptionIndex.NewSublistWithCache(); + foreach (var s in perms.Subscribe.Allow) + { + try + { + var (subj, q) = SplitSubjectQueue(s); + cp.Sub.Allow.Insert(new Subscription { Subject = subj, Queue = q }); + } + catch (Exception ex) { Errorf("%s", ex.Message); } + } + } + if (perms.Subscribe.Deny is { Count: > 0 }) + { + cp.Sub.Deny = SubscriptionIndex.NewSublistWithCache(); + DArray = []; + foreach (var s in perms.Subscribe.Deny) + { + DArray.Add(s, true); + try + { + var (subj, q) = SplitSubjectQueue(s); + cp.Sub.Deny.Insert(new Subscription { Subject = subj, Queue = q }); + } + catch (Exception ex) { Errorf("%s", ex.Message); } + } + } + } + + return cp; + } + + // ========================================================================= + // setExpiration / loadMsgDenyFilter (features 423-424) + // ========================================================================= + + internal void SetExpirationTimer(TimeSpan d) + { + // TODO: Implement when Server is available (session 09). + } + + internal void SetExpirationTimerUnlocked(TimeSpan d) + { + // TODO: Implement when Server is available (session 09). + } + + // ========================================================================= + // msgParts (feature 470) + // ========================================================================= + + /// + /// Splits a message buffer into header and body parts. + /// Mirrors Go client.msgParts(). + /// + public (byte[] hdr, byte[] msg) MsgParts(byte[] buf) + { + int hdrLen = ParseCtx.Pa.HeaderSize; + + // Return header slice with a capped capacity (no extra capacity beyond the header). + var hdr = buf[..hdrLen]; + // Create an isolated copy so appending to hdr doesn't touch msg. + var hdrCopy = new byte[hdrLen]; + Buffer.BlockCopy(buf, 0, hdrCopy, 0, hdrLen); + + var msg = buf[hdrLen..]; + return (hdrCopy, msg); + } + + // ========================================================================= + // kindString (feature 533) + // ========================================================================= + + private static readonly Dictionary KindStringMap = new() + { + [ClientKind.Client] = "Client", + [ClientKind.Router] = "Router", + [ClientKind.Gateway] = "Gateway", + [ClientKind.Leaf] = "Leafnode", + [ClientKind.JetStream] = "JetStream", + [ClientKind.Account] = "Account", + [ClientKind.System] = "System", + }; + + /// + /// Returns a human-readable kind name. + /// Mirrors Go client.kindString(). + /// + internal string KindString() => + KindStringMap.TryGetValue(Kind, out var s) ? s : "Unknown Type"; + + // ========================================================================= + // isClosed (feature 555) + // ========================================================================= + + /// + /// Returns true if closeConnection has been called. + /// Mirrors Go client.isClosed(). + /// + public bool IsClosed() => (Flags & ClientFlags.CloseConnection) != 0; + + // ========================================================================= + // format / formatNoClientInfo / formatClientSuffix (features 556-558) + // ========================================================================= + + /// + /// Returns a formatted log string for this client. + /// Mirrors Go client.format(). + /// + internal string Format() => $"{_ncs}"; + + internal string FormatNoClientInfo() => _connStr; + + internal string FormatClientSuffix() => $" - {KindString()}:{Cid}"; + + // ========================================================================= + // Logging helpers (features 559-568) + // ========================================================================= + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Error(string msg) => Server?.Logger.LogError("[{Client}] {Msg}", _ncs, msg); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Errorf(string fmt, params object?[] args) => + Server?.Logger.LogError("[{Client}] " + fmt, [_ncs, ..args]); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Debugf(string fmt, params object?[] args) => + Server?.Logger.LogDebug("[{Client}] " + fmt, [_ncs, ..args]); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Noticef(string fmt, params object?[] args) => + Server?.Logger.LogInformation("[{Client}] " + fmt, [_ncs, ..args]); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Tracef(string fmt, params object?[] args) => + Server?.Logger.LogTrace("[{Client}] " + fmt, [_ncs, ..args]); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Warnf(string fmt, params object?[] args) => + Server?.Logger.LogWarning("[{Client}] " + fmt, [_ncs, ..args]); + + // ========================================================================= + // Auth-related helpers (features 446-451, 526-531, 570-571) + // ========================================================================= + + internal void SendErrAndErr(string err) { SendErr(err); Error(err); } + internal void SendErrAndDebug(string msg){ SendErr(msg); Debugf(msg); } + + internal void AuthTimeout() + { + SendErrAndDebug("Authentication Timeout"); + CloseConnection(ClosedState.AuthenticationTimeout); + } + + internal void AuthExpired() + { + SendErrAndDebug("Authorization Expired"); + CloseConnection(ClosedState.AuthenticationExpired); + } + + internal void AccountAuthExpired() + { + SendErrAndDebug("Account authorization expired"); + CloseConnection(ClosedState.AuthenticationExpired); + } + + internal void AuthViolation() + { + SendErrAndErr(ServerErrors.ErrAuthorization.Message); + CloseConnection(ClosedState.AuthenticationViolation); + } + + internal void MaxAccountConnExceeded() + { + SendErrAndErr(ServerErrors.ErrTooManyAccountConnections.Message); + CloseConnection(ClosedState.MaxAccountConnectionsExceeded); + } + + internal void MaxConnExceeded() + { + SendErrAndErr(ServerErrors.ErrTooManyConnections.Message); + CloseConnection(ClosedState.MaxConnectionsExceeded); + } + + internal void MaxSubsExceeded() + { + Errorf("Maximum Subscriptions Exceeded (max=%d)", _msubs); + SendErr(ServerErrors.ErrTooManySubs.Message); + } + + internal void MaxPayloadViolation(int sz, int max) + { + SendErrAndErr($"Maximum Payload Violation"); + CloseConnection(ClosedState.MaxPayloadExceeded); + } + + internal void PubPermissionViolation(string subject) + { + SendErr($"Permissions Violation for Publish to \"{subject}\""); + Errorf("Publish Violation - User %q, Subject %q", GetAuthUser(), subject); + } + + internal void SubPermissionViolation(Subscription sub) + { + string subj = Encoding.UTF8.GetString(sub.Subject); + string queue = sub.Queue is { Length: > 0 } ? $" using queue \"{Encoding.UTF8.GetString(sub.Queue)}\"" : string.Empty; + SendErr($"Permissions Violation for Subscription to \"{subj}\"{queue}"); + Errorf("Subscription Violation - User %q, Subject %q, SID %q", + GetAuthUser(), subj, sub.Sid is not null ? Encoding.UTF8.GetString(sub.Sid) : string.Empty); + } + + internal void ReplySubjectViolation(string reply) + { + SendErr($"Permissions Violation for use of Reply subject \"{reply}\""); + Errorf("Reply Subject Violation - User %q, Reply %q", GetAuthUser(), reply); + } + + internal void MaxTokensViolation(Subscription sub) + { + SendErrAndErr($"Permissions Violation for Subscription to \"{Encoding.UTF8.GetString(sub.Subject)}\""); + } + + internal void SetAuthError(Exception err) { lock (_mu) { AuthErr = err; } } + internal Exception? GetAuthError() { lock (_mu) { return AuthErr; } } + + // ========================================================================= + // Timer helpers (features 523-531) + // ========================================================================= + + internal void SetPingTimer() + { + // TODO: Implement when Server is available. + } + + internal void ClearPingTimer() + { + var t = Interlocked.Exchange(ref _pingTimer, null); + t?.Dispose(); + } + + internal void ClearTlsToTimer() + { + var t = Interlocked.Exchange(ref _tlsTo, null); + t?.Dispose(); + } + + internal void SetAuthTimer() + { + // TODO: Implement when Server is available. + } + + internal void ClearAuthTimer() + { + var t = Interlocked.Exchange(ref _atmr, null); + t?.Dispose(); + } + + internal bool AwaitingAuth() => (Flags & ClientFlags.ExpectConnect) != 0 + && (Flags & ClientFlags.ConnectReceived) == 0; + + internal void ClaimExpiration() + { + // TODO: Implement when Server is available. + } + + // ========================================================================= + // flushSignal / queueOutbound / enqueueProto (features 433, 456-459) + // ========================================================================= + + internal void FlushSignal() + { + // TODO: Signal the writeLoop via SemaphoreSlim/Monitor when ported. + } + + internal void EnqueueProtoAndFlush(ReadOnlySpan proto) + { + EnqueueProto(proto); + } + + internal void SendProtoNow(ReadOnlySpan proto) + { + EnqueueProto(proto); + } + + internal void EnqueueProto(ReadOnlySpan proto) + { + // TODO: Full write-loop queuing when Server is ported (session 09). + if (_nc is not null) + { + try { _nc.Write(proto); } + catch { /* connection errors handled by closeConnection */ } + } + } + + // ========================================================================= + // sendPong / sendPing / sendRTTPing (features 460-463) + // ========================================================================= + + internal void SendPong() => EnqueueProtoAndFlush(Encoding.ASCII.GetBytes(Wires.PongProto)); + + internal void SendRttPing() { lock (_mu) { SendRttPingLocked(); } } + + internal void SendRttPingLocked() + { + RttStart = DateTime.UtcNow; + SendPing(); + } + + internal void SendPing() + { + _pingOut++; + EnqueueProtoAndFlush(Encoding.ASCII.GetBytes(Wires.PingProto)); + } + + // ========================================================================= + // sendErr / sendOK (features 465-466) + // ========================================================================= + + internal void SendErr(string err) => + EnqueueProtoAndFlush(Encoding.ASCII.GetBytes(string.Format(Wires.ErrProto, err))); + + internal void SendOK() + { + if (Opts.Verbose) + EnqueueProtoAndFlush(Encoding.ASCII.GetBytes(Wires.OkProto)); + } + + // ========================================================================= + // traceMsg / traceInOp / traceOutOp / traceOp (features 434-439) + // ========================================================================= + + internal void TraceMsg(byte[] msg) { if (Trace) TraceMsgInternal(msg, false, false); } + internal void TraceMsgDelivery(byte[] msg) { if (Trace) TraceMsgInternal(msg, false, true); } + internal void TraceInOp(string op, byte[] arg) { if (Trace) TraceOp("<", op, arg); } + internal void TraceOutOp(string op, byte[] arg) { if (Trace) TraceOp(">", op, arg); } + + private void TraceMsgInternal(byte[] msg, bool inbound, bool delivery) { } + private void TraceOp(string dir, string op, byte[] arg) + { + Tracef("%s %s %s", dir, op, arg is not null ? Encoding.UTF8.GetString(arg) : string.Empty); + } + + // ========================================================================= + // getAuthUser / getAuthUserLabel (features 550-552) + // ========================================================================= + + internal string GetRawAuthUserLock() + { + lock (_mu) { return GetRawAuthUser(); } + } + + internal string GetRawAuthUser() + { + if (Opts.Nkey.Length > 0) return Opts.Nkey; + if (Opts.Username.Length > 0) return Opts.Username; + if (Opts.Token.Length > 0) return "Token"; + return "Unknown"; + } + + internal string GetAuthUser() => GetRawAuthUser(); + + internal string GetAuthUserLabel() + { + var u = GetRawAuthUser(); + return u.Length > 0 ? u : "Unknown User"; + } + + // ========================================================================= + // connectionTypeAllowed (feature 554) + // ========================================================================= + + internal bool ConnectionTypeAllowed(string ct) + { + // TODO: Full implementation when JWT is integrated. + return true; + } + + // ========================================================================= + // closeConnection (feature 536) + // ========================================================================= + + /// + /// Closes the client connection with the given reason. + /// Mirrors Go client.closeConnection(). + /// + public void CloseConnection(ClosedState reason) + { + lock (_mu) + { + if (IsClosed()) return; + Flags |= ClientFlags.CloseConnection; + ClearAuthTimer(); + ClearPingTimer(); + } + + // Close the underlying network connection. + try { _nc?.Close(); } catch { /* ignore */ } + _nc = null; + } + + // ========================================================================= + // flushAndClose (feature 532) + // ========================================================================= + + internal void FlushAndClose(bool deadlineExceeded) + { + CloseConnection(ClosedState.ClientClosed); + } + + // ========================================================================= + // setNoReconnect (feature 538) + // ========================================================================= + + internal void SetNoReconnect() + { + lock (_mu) { Flags |= ClientFlags.NoReconnect; } + } + + // ========================================================================= + // getRTTValue (feature 539) + // ========================================================================= + + internal TimeSpan GetRttValue() + { + lock (_mu) { return Rtt; } + } + + // ========================================================================= + // Account / server helpers (features 540-545) + // ========================================================================= + + internal INatsAccount? GetAccount() + { + lock (_mu) { return Account; } + } + + // ========================================================================= + // TLS handshake helpers (features 546-548) + // ========================================================================= + + internal async Task DoTlsServerHandshakeAsync(SslServerAuthenticationOptions opts, CancellationToken ct = default) + { + // TODO: Full TLS when Server is ported. + return false; + } + + internal async Task DoTlsClientHandshakeAsync(SslClientAuthenticationOptions opts, CancellationToken ct = default) + { + // TODO: Full TLS when Server is ported. + return false; + } + + // ========================================================================= + // Stub methods for server-dependent features + // (Fully implemented when Server/Account sessions are complete) + // ========================================================================= + + // features 425-427: writeLoop / flushClients / readLoop + internal void WriteLoop() { /* TODO session 09 */ } + internal void FlushClients(long budget) { /* TODO session 09 */ } + + // features 428-432: closedStateForErr, collapsePtoNB, flushOutbound, handleWriteTimeout, markConnAsClosed + internal static ClosedState ClosedStateForErr(Exception err) => + err is EndOfStreamException ? ClosedState.ClientClosed : ClosedState.ReadError; + + // features 440-441: processInfo, processErr + internal void ProcessInfo(string info) { /* TODO session 09 */ } + internal void ProcessErr(string err) { /* TODO session 09 */ } + + // features 442-443: removeSecretsFromTrace, redact + internal static string RemoveSecretsFromTrace(string s) => s; + internal static string Redact(string s) => s; + + // feature 444: computeRTT + internal static TimeSpan ComputeRtt(DateTime start) => DateTime.UtcNow - start; + + // feature 445: processConnect + internal void ProcessConnect(byte[] arg) { /* TODO session 09 */ } + + // feature 467-468: processPing, processPong + internal void ProcessPing() + { + _pingOut = 0; + SendPong(); + } + + internal void ProcessPong() { /* TODO */ } + + // feature 469: updateS2AutoCompressionLevel + internal void UpdateS2AutoCompressionLevel() { /* TODO */ } + + // features 471-486: processPub variants, parseSub, processSub, etc. + // Implemented in full when Server+Account sessions complete. + + // features 487-503: deliverMsg, addToPCD, trackRemoteReply, pruning, pubAllowed, etc. + + // features 512-514: processServiceImport, addSubToRouteTargets, processMsgResults + + // feature 515: checkLeafClientInfoHeader + // feature 520-522: processPingTimer, adjustPingInterval, watchForStaleConnection + // feature 534-535: swapAccountAfterReload, processSubsOnConfigReload + // feature 537: reconnect + // feature 569: setFirstPingTimer + + // ========================================================================= + // IsMqtt / IsWebSocket helpers (used by clientType, not separately tracked) + // ========================================================================= + + internal bool IsMqtt() => false; // TODO: set in session 22 (MQTT) + internal bool IsWebSocket() => false; // TODO: set in session 23 (WebSocket) + internal bool IsHubLeafNode() => false; // TODO: set in session 15 (leaf nodes) + internal string RemoteCluster() => string.Empty; // TODO: session 14/15 +} + +// ============================================================================ +// Private read-cache state (per-readLoop invocation) +// ============================================================================ + +internal struct ReadCacheState +{ + public ulong GenId; + public Dictionary? Results; + public Dictionary? PaCache; + public List? Rts; + public int Msgs; + public int Bytes; + public int Subs; + public int Rsz; // read buffer size + public int Srs; // short reads + public ReadCacheFlags Flags; + public DateTime Start; + public TimeSpan Tst; // total stall time +} + +internal sealed class PerAccountCache +{ + public INatsAccount? Acc { get; set; } + public SubscriptionIndexResult? Results { get; set; } + public ulong GenId { get; set; } +} + +internal sealed class RrTracking +{ + public Dictionary? RMap { get; set; } + public Timer? Ptmr { get; set; } + public TimeSpan Lrt { get; set; } +} + +// ============================================================================ +// Server / account interfaces (stubs until sessions 09 and 11) +// ============================================================================ + +/// +/// Minimal server interface used by ClientConnection. +/// Full implementation in session 09 (server.go). +/// +public interface INatsServer +{ + ulong NextClientId(); + ServerOptions Options { get; } + bool TraceEnabled { get; } + bool TraceSysAcc { get; } + ILogger Logger { get; } + void DecActiveAccounts(); + void IncActiveAccounts(); +} + +/// +/// Minimal account interface used by ClientConnection. +/// Full implementation in session 11 (accounts.go). +/// +public interface INatsAccount +{ + string Name { get; } + bool IsValid { get; } + bool MaxTotalConnectionsReached(); + bool MaxTotalLeafNodesReached(); + int AddClient(ClientConnection c); + int RemoveClient(ClientConnection c); +} + +/// Thrown when account connection limits are exceeded. +public sealed class TooManyAccountConnectionsException : Exception +{ + public TooManyAccountConnectionsException() : base("Too Many Account Connections") { } +} + +/// Thrown when an account is invalid or null. +public sealed class BadAccountException : Exception +{ + public BadAccountException() : base("Bad Account") { } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs new file mode 100644 index 0000000..d4108a2 --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientTypes.cs @@ -0,0 +1,375 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/client.go in the NATS server Go source. + +using System.Text.Json.Serialization; +using ZB.MOM.NatsNet.Server.Auth; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Internal.DataStructures; + +namespace ZB.MOM.NatsNet.Server; + +// ============================================================================ +// Client connection kind (iota constants) +// ============================================================================ + +// Note: ClientKind is already declared in Internal/Subscription.cs; this file +// adds the remaining constants that were used only here. + +/// +/// Extended client connection type (returned by clientType()). +/// Maps Go's NON_CLIENT / NATS / MQTT / WS iota. +/// +public enum ClientConnectionType +{ + /// Connection is not a CLIENT kind. + NonClient = 0, + /// Regular NATS client. + Nats = 1, + /// MQTT client. + Mqtt = 2, + /// WebSocket client. + WebSocket = 3, +} + +// ============================================================================ +// Client protocol versions +// ============================================================================ + +/// +/// Wire protocol version negotiated in the CONNECT message. +/// +public static class ClientProtocol +{ + /// Original protocol (2009). Mirrors ClientProtoZero. + public const int Zero = 0; + /// Protocol that supports INFO updates. Mirrors ClientProtoInfo. + public const int Info = 1; +} + +// ============================================================================ +// WriteTimeoutPolicy extension (enum defined in ServerOptionTypes.cs) +// ============================================================================ + +internal static class WriteTimeoutPolicyExtensions +{ + /// Mirrors Go WriteTimeoutPolicy.String(). + public static string ToVarzString(this WriteTimeoutPolicy p) => p switch + { + WriteTimeoutPolicy.Close => "close", + WriteTimeoutPolicy.Retry => "retry", + _ => string.Empty, + }; +} + +// ============================================================================ +// ClientFlags +// ============================================================================ + +/// +/// Compact bitfield of boolean client state. +/// Mirrors Go clientFlag and its iota constants. +/// +[Flags] +public enum ClientFlags : ushort +{ + None = 0, + ConnectReceived = 1 << 0, + InfoReceived = 1 << 1, + FirstPongSent = 1 << 2, + HandshakeComplete = 1 << 3, + FlushOutbound = 1 << 4, + NoReconnect = 1 << 5, + CloseConnection = 1 << 6, + ConnMarkedClosed = 1 << 7, + WriteLoopStarted = 1 << 8, + SkipFlushOnClose = 1 << 9, + ExpectConnect = 1 << 10, + ConnectProcessFinished = 1 << 11, + CompressionNegotiated = 1 << 12, + DidTlsFirst = 1 << 13, + IsSlowConsumer = 1 << 14, + FirstPong = 1 << 15, +} + +// ============================================================================ +// ReadCacheFlags +// ============================================================================ + +/// +/// Bitfield for the read-cache loop state. +/// Mirrors Go readCacheFlag. +/// +[Flags] +public enum ReadCacheFlags : ushort +{ + None = 0, + HasMappings = 1 << 0, + SwitchToCompression = 1 << 1, +} + +// ============================================================================ +// ClosedState +// ============================================================================ + +/// +/// The reason a client connection was closed. +/// Mirrors Go ClosedState. +/// +public enum ClosedState +{ + ClientClosed = 1, + AuthenticationTimeout, + AuthenticationViolation, + TlsHandshakeError, + SlowConsumerPendingBytes, + SlowConsumerWriteDeadline, + WriteError, + ReadError, + ParseError, + StaleConnection, + ProtocolViolation, + BadClientProtocolVersion, + WrongPort, + MaxAccountConnectionsExceeded, + MaxConnectionsExceeded, + MaxPayloadExceeded, + MaxControlLineExceeded, + MaxSubscriptionsExceeded, + DuplicateRoute, + RouteRemoved, + ServerShutdown, + AuthenticationExpired, + WrongGateway, + MissingAccount, + Revocation, + InternalClient, + MsgHeaderViolation, + NoRespondersRequiresHeaders, + ClusterNameConflict, + DuplicateRemoteLeafnodeConnection, + DuplicateClientId, + DuplicateServerName, + MinimumVersionRequired, + ClusterNamesIdentical, + Kicked, + ProxyNotTrusted, + ProxyRequired, +} + +// ============================================================================ +// processMsgResults flags +// ============================================================================ + +/// +/// Flags passed to ProcessMsgResults. +/// Mirrors Go pmrNoFlag and the iota block. +/// +[Flags] +public enum PmrFlags +{ + None = 0, + CollectQueueNames = 1 << 0, + IgnoreEmptyQueueFilter = 1 << 1, + AllowSendFromRouteToRoute = 1 << 2, + MsgImportedFromService = 1 << 3, +} + +// ============================================================================ +// denyType +// ============================================================================ + +/// +/// Which permission side to apply deny-list merging to. +/// Mirrors Go denyType. +/// +internal enum DenyType +{ + Pub = 1, + Sub = 2, + Both = 3, +} + +// ============================================================================ +// ClientOptions (wire-protocol CONNECT options) +// ============================================================================ + +/// +/// Options negotiated during the CONNECT handshake. +/// Mirrors Go ClientOpts. +/// +public sealed class ClientOptions +{ + [JsonPropertyName("echo")] public bool Echo { get; set; } + [JsonPropertyName("verbose")] public bool Verbose { get; set; } + [JsonPropertyName("pedantic")] public bool Pedantic { get; set; } + [JsonPropertyName("tls_required")] public bool TlsRequired { get; set; } + [JsonPropertyName("nkey")] public string Nkey { get; set; } = string.Empty; + [JsonPropertyName("jwt")] public string Jwt { get; set; } = string.Empty; + [JsonPropertyName("sig")] public string Sig { get; set; } = string.Empty; + [JsonPropertyName("auth_token")] public string Token { get; set; } = string.Empty; + [JsonPropertyName("user")] public string Username { get; set; } = string.Empty; + [JsonPropertyName("pass")] public string Password { get; set; } = string.Empty; + [JsonPropertyName("name")] public string Name { get; set; } = string.Empty; + [JsonPropertyName("lang")] public string Lang { get; set; } = string.Empty; + [JsonPropertyName("version")] public string Version { get; set; } = string.Empty; + [JsonPropertyName("protocol")] public int Protocol { get; set; } + [JsonPropertyName("account")] public string Account { get; set; } = string.Empty; + [JsonPropertyName("new_account")] public bool AccountNew { get; set; } + [JsonPropertyName("headers")] public bool Headers { get; set; } + [JsonPropertyName("no_responders")]public bool NoResponders { get; set; } + + // Routes and Leaf Nodes only + [JsonPropertyName("import")] public SubjectPermission? Import { get; set; } + [JsonPropertyName("export")] public SubjectPermission? Export { get; set; } + [JsonPropertyName("remote_account")] public string RemoteAccount { get; set; } = string.Empty; + [JsonPropertyName("proxy_sig")] public string ProxySig { get; set; } = string.Empty; + + /// Default options for external clients. + public static ClientOptions Default => new() { Verbose = true, Pedantic = true, Echo = true }; + + /// Default options for internal server clients. + public static ClientOptions Internal => new() { Verbose = false, Pedantic = false, Echo = false }; +} + +// ============================================================================ +// ClientInfo — lightweight metadata sent in server events +// ============================================================================ + +/// +/// Client metadata included in server monitoring events. +/// Mirrors Go ClientInfo. +/// +public sealed class ClientInfo +{ + public string Start { get; set; } = string.Empty; + public string Host { get; set; } = string.Empty; + public ulong Id { get; set; } + public string Account { get; set; } = string.Empty; + public string User { get; set; } = string.Empty; + public string Name { get; set; } = string.Empty; + public string Lang { get; set; } = string.Empty; + public string Version { get; set; } = string.Empty; + public string Jwt { get; set; } = string.Empty; + public string IssuerKey { get; set; } = string.Empty; + public string NameTag { get; set; } = string.Empty; + public List Tags { get; set; } = []; + public string Kind { get; set; } = string.Empty; + public string ClientType { get; set; } = string.Empty; + public string? MqttId { get; set; } + public bool Stop { get; set; } + public bool Restart { get; set; } + public bool Disconnect { get; set; } + public string[]? Cluster { get; set; } + public bool Service { get; set; } +} + +// ============================================================================ +// Internal permission structures (not public API) +// (Permissions, SubjectPermission, ResponsePermission are in Auth/AuthTypes.cs) +// ============================================================================ + +internal sealed class Perm +{ + public SubscriptionIndex? Allow { get; set; } + public SubscriptionIndex? Deny { get; set; } +} + +internal sealed class ClientPermissions +{ + public int PcsZ; // pub cache size (atomic) + public int PRun; // prune run count (atomic) + public Perm Sub { get; } = new(); + public Perm Pub { get; } = new(); + public ResponsePermission? Resp { get; set; } + // Per-subject cache for permission checks. + public Dictionary PCache { get; } = new(StringComparer.Ordinal); +} + +internal sealed class MsgDeny +{ + public SubscriptionIndex? Deny { get; set; } + public Dictionary DCache { get; } = new(StringComparer.Ordinal); +} + +internal sealed class RespEntry +{ + public DateTime Time { get; set; } + public int N { get; set; } +} + +// ============================================================================ +// Buffer pool constants +// ============================================================================ + +internal static class NbPool +{ + internal const int SmallSize = 512; + internal const int MediumSize = 4096; + internal const int LargeSize = 65536; + + private static readonly System.Buffers.ArrayPool _pool = + System.Buffers.ArrayPool.Create(LargeSize, 50); + + /// + /// Returns a buffer best-effort sized to . + /// Mirrors Go nbPoolGet. + /// + public static byte[] Get(int sz) + { + int cap = sz <= SmallSize ? SmallSize + : sz <= MediumSize ? MediumSize + : LargeSize; + return _pool.Rent(cap); + } + + /// + /// Returns a buffer to the pool. + /// Mirrors Go nbPoolPut. + /// + public static void Put(byte[] buf) + { + if (buf.Length == SmallSize || buf.Length == MediumSize || buf.Length == LargeSize) + _pool.Return(buf); + // Ignore wrong-sized frames (WebSocket/MQTT). + } +} + +// ============================================================================ +// Route / gateway / leaf / websocket / mqtt stubs +// (These are filled in during sessions 14-16 and 22-23) +// ============================================================================ + +internal sealed class RouteTarget +{ + public Subscription? Sub { get; set; } + public byte[] Qs { get; set; } = []; +} + +// ============================================================================ +// Static helper: IsInternalClient +// ============================================================================ + +/// +/// Client-kind classification helpers. +/// +public static class ClientKindHelpers +{ + /// + /// Returns true if is an internal server client. + /// Mirrors Go isInternalClient. + /// + public static bool IsInternalClient(ClientKind kind) => + kind == ClientKind.System || kind == ClientKind.JetStream || kind == ClientKind.Account; +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/NatsMessageHeaders.cs b/dotnet/src/ZB.MOM.NatsNet.Server/NatsMessageHeaders.cs new file mode 100644 index 0000000..6d65536 --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/NatsMessageHeaders.cs @@ -0,0 +1,389 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/client.go (header utility functions) in the NATS server Go source. + +using System.Text; + +namespace ZB.MOM.NatsNet.Server; + +/// +/// Wire-level NATS message header constants. +/// +public static class NatsHeaderConstants +{ + /// NATS header status line: "NATS/1.0\r\n". Mirrors Go hdrLine. + public const string HdrLine = "NATS/1.0\r\n"; + + /// Empty header block with blank line terminator. Mirrors Go emptyHdrLine. + public const string EmptyHdrLine = "NATS/1.0\r\n\r\n"; + + // JetStream expected-sequence headers (defined in server/stream.go, used by header utilities). + public const string JsExpectedStream = "Nats-Expected-Stream"; + public const string JsExpectedLastSeq = "Nats-Expected-Last-Sequence"; + public const string JsExpectedLastSubjSeq = "Nats-Expected-Last-Subject-Sequence"; + public const string JsExpectedLastSubjSeqSubj = "Nats-Expected-Last-Subject-Sequence-Subject"; + public const string JsExpectedLastMsgId = "Nats-Expected-Last-Msg-Id"; + + // Other commonly used headers. + public const string JsMsgId = "Nats-Msg-Id"; + public const string JsMsgRollup = "Nats-Rollup"; +} + +/// +/// Low-level NATS message header manipulation utilities. +/// Mirrors the package-level functions in server/client.go: +/// genHeader, removeHeaderIfPresent, removeHeaderIfPrefixPresent, +/// getHeader, sliceHeader, getHeaderKeyIndex, setHeader. +/// +public static class NatsMessageHeaders +{ + private static readonly byte[] CrLfBytes = "\r\n"u8.ToArray(); + + // ------------------------------------------------------------------------- + // genHeader (feature 506) + // ------------------------------------------------------------------------- + + /// + /// Generates a header buffer by appending key: value\r\n to an existing header, + /// or starting a fresh NATS/1.0\r\n block if is empty/null. + /// Mirrors Go genHeader. + /// + /// Existing header bytes, or null to start fresh. + /// Header key. + /// Header value. + public static byte[] GenHeader(byte[]? hdr, string key, string value) + { + var sb = new StringBuilder(); + + // Strip trailing CRLF from existing header to reopen for appending, + // or start fresh with the header status line. + const int LenCrLf = 2; + if (hdr is { Length: > LenCrLf }) + { + // Write all but the trailing "\r\n" + sb.Append(Encoding.ASCII.GetString(hdr, 0, hdr.Length - LenCrLf)); + } + else + { + sb.Append(NatsHeaderConstants.HdrLine); + } + + // Append "key: value\r\n\r\n" (HTTP header format). + sb.Append(key); + sb.Append(": "); + sb.Append(value); + sb.Append("\r\n\r\n"); + + return Encoding.ASCII.GetBytes(sb.ToString()); + } + + // ------------------------------------------------------------------------- + // removeHeaderIfPresent (feature 504) + // ------------------------------------------------------------------------- + + /// + /// Removes the first occurrence of header from . + /// Returns null if the result would be an empty header block. + /// Mirrors Go removeHeaderIfPresent. + /// + public static byte[]? RemoveHeaderIfPresent(byte[] hdr, string key) + { + int start = GetHeaderKeyIndex(key, hdr); + // Key must exist and be preceded by '\n' (not at position 0). + if (start < 1 || hdr[start - 1] != '\n') + return hdr; + + int index = start + key.Length; + if (index >= hdr.Length || hdr[index] != ':') + return hdr; + + // Find CRLF following this header line. + int crlfIdx = IndexOfCrLf(hdr, start); + if (crlfIdx < 0) + return hdr; + + // Remove from 'start' through end of CRLF. + int removeEnd = start + crlfIdx + 2; // +2 for "\r\n" + var result = new byte[hdr.Length - (removeEnd - start)]; + Buffer.BlockCopy(hdr, 0, result, 0, start); + Buffer.BlockCopy(hdr, removeEnd, result, start, hdr.Length - removeEnd); + + // If nothing meaningful remains, return null. + if (result.Length <= NatsHeaderConstants.EmptyHdrLine.Length) + return null; + + return result; + } + + // ------------------------------------------------------------------------- + // removeHeaderIfPrefixPresent (feature 505) + // ------------------------------------------------------------------------- + + /// + /// Removes all headers whose names start with . + /// Returns null if the result would be an empty header block. + /// Mirrors Go removeHeaderIfPrefixPresent. + /// + public static byte[]? RemoveHeaderIfPrefixPresent(byte[] hdr, string prefix) + { + var prefixBytes = Encoding.ASCII.GetBytes(prefix); + var working = hdr.ToList(); // work on a list for easy splicing + int index = 0; + + while (index < working.Count) + { + // Look for prefix starting at current index. + int found = IndexOf(working, prefixBytes, index); + if (found < 0) + break; + + // Must be preceded by '\n'. + if (found < 1 || working[found - 1] != '\n') + break; + + // Find CRLF after this prefix's key:value line. + int crlfIdx = IndexOfCrLf(working, found + prefix.Length); + if (crlfIdx < 0) + break; + + int removeEnd = found + prefix.Length + crlfIdx + 2; + working.RemoveRange(found, removeEnd - found); + + // Don't advance index — there may be more headers at same position. + if (working.Count <= NatsHeaderConstants.EmptyHdrLine.Length) + return null; + } + + return working.ToArray(); + } + + // ------------------------------------------------------------------------- + // getHeaderKeyIndex (feature 510) + // ------------------------------------------------------------------------- + + /// + /// Returns the byte offset of in , + /// or -1 if not found. + /// The key must be preceded by \r\n and followed by :. + /// Mirrors Go getHeaderKeyIndex. + /// + public static int GetHeaderKeyIndex(string key, byte[] hdr) + { + if (hdr.Length == 0) return -1; + + var bkey = Encoding.ASCII.GetBytes(key); + int keyLen = bkey.Length; + int hdrLen = hdr.Length; + int offset = 0; + + while (true) + { + int index = IndexOf(hdr, bkey, offset); + // Need index >= 2 (room for preceding \r\n) and enough space for trailing colon. + if (index < 2) return -1; + + // Preceded by \r\n ? + if (hdr[index - 1] != '\n' || hdr[index - 2] != '\r') + { + offset = index + keyLen; + continue; + } + + // Immediately followed by ':' ? + if (index + keyLen >= hdrLen) + return -1; + + if (hdr[index + keyLen] != ':') + { + offset = index + keyLen; + continue; + } + + return index; + } + } + + // ------------------------------------------------------------------------- + // sliceHeader (feature 509) + // ------------------------------------------------------------------------- + + /// + /// Returns a slice of containing the value of , + /// or null if not found. + /// The returned slice shares memory with . + /// Mirrors Go sliceHeader. + /// + public static ReadOnlyMemory? SliceHeader(string key, byte[] hdr) + { + if (hdr.Length == 0) return null; + + int index = GetHeaderKeyIndex(key, hdr); + if (index == -1) return null; + + // Skip over key + ':' separator. + index += key.Length + 1; + int hdrLen = hdr.Length; + + // Skip leading whitespace. + while (index < hdrLen && hdr[index] == ' ') + index++; + + int start = index; + // Collect until CRLF. + while (index < hdrLen) + { + if (hdr[index] == '\r' && index + 1 < hdrLen && hdr[index + 1] == '\n') + break; + index++; + } + + // Return a slice with capped length == value length (no extra capacity). + return new ReadOnlyMemory(hdr, start, index - start); + } + + // ------------------------------------------------------------------------- + // getHeader (feature 508) + // ------------------------------------------------------------------------- + + /// + /// Returns a copy of the value for the header named , + /// or null if not found. + /// Mirrors Go getHeader. + /// + public static byte[]? GetHeader(string key, byte[] hdr) + { + var slice = SliceHeader(key, hdr); + if (slice is null) return null; + + // Return a fresh copy. + return slice.Value.ToArray(); + } + + // ------------------------------------------------------------------------- + // setHeader (feature 511) + // ------------------------------------------------------------------------- + + /// + /// Replaces the value of the first existing header in + /// , or appends a new header if the key is absent. + /// Returns a new buffer when the new value is larger; modifies in-place otherwise. + /// Mirrors Go setHeader. + /// + public static byte[] SetHeader(string key, string val, byte[] hdr) + { + int start = GetHeaderKeyIndex(key, hdr); + if (start >= 0) + { + int valStart = start + key.Length + 1; // skip past ':' + int hdrLen = hdr.Length; + + // Preserve a single leading space if present. + if (valStart < hdrLen && hdr[valStart] == ' ') + valStart++; + + // Find the CR before the CRLF. + int crIdx = IndexOf(hdr, [(byte)'\r'], valStart); + if (crIdx < 0) return hdr; // malformed + + int valEnd = crIdx; + int oldValLen = valEnd - valStart; + var valBytes = Encoding.ASCII.GetBytes(val); + + int extra = valBytes.Length - oldValLen; + if (extra > 0) + { + // New value is larger — must allocate a new buffer. + var newHdr = new byte[hdrLen + extra]; + Buffer.BlockCopy(hdr, 0, newHdr, 0, valStart); + Buffer.BlockCopy(valBytes, 0, newHdr, valStart, valBytes.Length); + Buffer.BlockCopy(hdr, valEnd, newHdr, valStart + valBytes.Length, hdrLen - valEnd); + return newHdr; + } + + // Write in place (new value fits). + int n = valBytes.Length; + Buffer.BlockCopy(valBytes, 0, hdr, valStart, n); + // Shift remainder left. + Buffer.BlockCopy(hdr, valEnd, hdr, valStart + n, hdrLen - valEnd); + return hdr[..(valStart + n + hdrLen - valEnd)]; + } + + // Key not present — append. + bool hasTrailingCrLf = hdr.Length >= 2 + && hdr[^2] == '\r' + && hdr[^1] == '\n'; + + byte[] suffix; + if (hasTrailingCrLf) + { + // Strip trailing CRLF, append "key: val\r\n\r\n". + suffix = Encoding.ASCII.GetBytes($"{key}: {val}\r\n"); + var result = new byte[hdr.Length - 2 + suffix.Length + 2]; + Buffer.BlockCopy(hdr, 0, result, 0, hdr.Length - 2); + Buffer.BlockCopy(suffix, 0, result, hdr.Length - 2, suffix.Length); + result[^2] = (byte)'\r'; + result[^1] = (byte)'\n'; + return result; + } + + suffix = Encoding.ASCII.GetBytes($"{key}: {val}\r\n"); + var newBuf = new byte[hdr.Length + suffix.Length]; + Buffer.BlockCopy(hdr, 0, newBuf, 0, hdr.Length); + Buffer.BlockCopy(suffix, 0, newBuf, hdr.Length, suffix.Length); + return newBuf; + } + + // ------------------------------------------------------------------------- + // Internal helpers + // ------------------------------------------------------------------------- + + private static int IndexOf(byte[] haystack, byte[] needle, int offset) + { + var span = haystack.AsSpan(offset); + int idx = span.IndexOf(needle); + return idx < 0 ? -1 : offset + idx; + } + + private static int IndexOf(List haystack, byte[] needle, int offset) + { + for (int i = offset; i <= haystack.Count - needle.Length; i++) + { + bool match = true; + for (int j = 0; j < needle.Length; j++) + { + if (haystack[i + j] != needle[j]) { match = false; break; } + } + if (match) return i; + } + return -1; + } + + /// Returns the offset of the first \r\n in at or after . + private static int IndexOfCrLf(byte[] hdr, int offset) + { + var span = hdr.AsSpan(offset); + int idx = span.IndexOf(CrLfBytes); + return idx; // relative to offset + } + + private static int IndexOfCrLf(List hdr, int offset) + { + for (int i = offset; i < hdr.Count - 1; i++) + { + if (hdr[i] == '\r' && hdr[i + 1] == '\n') + return i - offset; + } + return -1; + } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Protocol/ProxyProtocol.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Protocol/ProxyProtocol.cs new file mode 100644 index 0000000..15269a4 --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Protocol/ProxyProtocol.cs @@ -0,0 +1,604 @@ +// Copyright 2025 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/client_proxyproto.go in the NATS server Go source. + +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Text; + +namespace ZB.MOM.NatsNet.Server.Protocol; + +// ============================================================================ +// Proxy Protocol v2 constants +// ============================================================================ + +/// +/// PROXY protocol v1 and v2 constants. +/// Mirrors the const blocks in server/client_proxyproto.go. +/// +internal static class ProxyProtoConstants +{ + // v2 signature (12 bytes) + internal const string V2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + + // Version and command byte masks + internal const byte VerMask = 0xF0; + internal const byte Ver2 = 0x20; + internal const byte CmdMask = 0x0F; + internal const byte CmdLocal = 0x00; + internal const byte CmdProxy = 0x01; + + // Address family and protocol masks + internal const byte FamilyMask = 0xF0; + internal const byte FamilyUnspec = 0x00; + internal const byte FamilyInet = 0x10; + internal const byte FamilyInet6 = 0x20; + internal const byte FamilyUnix = 0x30; + internal const byte ProtoMask = 0x0F; + internal const byte ProtoUnspec = 0x00; + internal const byte ProtoStream = 0x01; + internal const byte ProtoDatagram = 0x02; + + // Address sizes + internal const int AddrSizeIPv4 = 12; // 4+4+2+2 + internal const int AddrSizeIPv6 = 36; // 16+16+2+2 + + // Fixed v2 header size: 12 (sig) + 1 (ver/cmd) + 1 (fam/proto) + 2 (addr len) + internal const int V2HeaderSize = 16; + + // Timeout for reading PROXY protocol header + internal static readonly TimeSpan ReadTimeout = TimeSpan.FromSeconds(5); + + // v1 constants + internal const string V1Prefix = "PROXY "; + internal const int V1MaxLineLen = 107; + internal const string V1TCP4 = "TCP4"; + internal const string V1TCP6 = "TCP6"; + internal const string V1Unknown = "UNKNOWN"; +} + +// ============================================================================ +// Well-known errors +// ============================================================================ + +/// +/// Well-known PROXY protocol errors. +/// Mirrors errProxyProtoInvalid, errProxyProtoUnsupported, etc. in client_proxyproto.go. +/// +public static class ProxyProtoErrors +{ + public static readonly Exception Invalid = new InvalidDataException("invalid PROXY protocol header"); + public static readonly Exception Unsupported = new NotSupportedException("unsupported PROXY protocol feature"); + public static readonly Exception Timeout = new TimeoutException("timeout reading PROXY protocol header"); + public static readonly Exception Unrecognized = new InvalidDataException("unrecognized PROXY protocol format"); +} + +// ============================================================================ +// ProxyProtocolAddress +// ============================================================================ + +/// +/// Address information extracted from a PROXY protocol header. +/// Mirrors Go proxyProtoAddr. +/// +public sealed class ProxyProtocolAddress +{ + public IPAddress SrcIp { get; } + public ushort SrcPort { get; } + public IPAddress DstIp { get; } + public ushort DstPort { get; } + + internal ProxyProtocolAddress(IPAddress srcIp, ushort srcPort, IPAddress dstIp, ushort dstPort) + { + SrcIp = srcIp; + SrcPort = srcPort; + DstIp = dstIp; + DstPort = dstPort; + } + + /// Returns "srcIP:srcPort". Mirrors proxyProtoAddr.String(). + public string String() => FormatEndpoint(SrcIp, SrcPort); + + /// Returns "tcp4" or "tcp6". Mirrors proxyProtoAddr.Network(). + public string Network() => SrcIp.IsIPv4MappedToIPv6 || SrcIp.AddressFamily == AddressFamily.InterNetwork + ? "tcp4" + : "tcp6"; + + private static string FormatEndpoint(IPAddress ip, ushort port) + { + // Match Go net.JoinHostPort — wraps IPv6 in brackets. + var addr = ip.AddressFamily == AddressFamily.InterNetworkV6 + ? $"[{ip}]" + : ip.ToString(); + return $"{addr}:{port}"; + } +} + +// ============================================================================ +// ProxyProtocolConnection +// ============================================================================ + +/// +/// Wraps a / to override the remote endpoint +/// with the address extracted from the PROXY protocol header. +/// Mirrors Go proxyConn. +/// +public sealed class ProxyProtocolConnection +{ + private readonly Stream _inner; + + /// The underlying connection stream. + public Stream InnerStream => _inner; + + /// The proxied remote address (extracted from the header). + public ProxyProtocolAddress RemoteAddress { get; } + + internal ProxyProtocolConnection(Stream inner, ProxyProtocolAddress remoteAddr) + { + _inner = inner; + RemoteAddress = remoteAddr; + } +} + +// ============================================================================ +// ProxyProtocolParser (static) +// ============================================================================ + +/// +/// Reads and parses PROXY protocol v1 and v2 headers from a . +/// Mirrors the functions in server/client_proxyproto.go. +/// +public static class ProxyProtocolParser +{ + // ------------------------------------------------------------------------- + // Public entry points + // ------------------------------------------------------------------------- + + /// + /// Reads and parses a PROXY protocol (v1 or v2) header from . + /// Returns null for LOCAL/UNKNOWN health-check commands. + /// Mirrors Go readProxyProtoHeader. + /// + public static async Task ReadProxyProtoHeaderAsync( + Stream stream, + CancellationToken cancellationToken = default) + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(ProxyProtoConstants.ReadTimeout); + var ct = cts.Token; + + // Detect version by reading first 6 bytes. + var (version, firstBytes, err) = await DetectVersionAsync(stream, ct).ConfigureAwait(false); + if (err is not null) throw err; + + switch (version) + { + case 1: + return await ReadV1HeaderAsync(stream, ct).ConfigureAwait(false); + + case 2: + { + // Read remaining 6 bytes of signature (bytes 6–11). + var remaining = new byte[6]; + await ReadFullAsync(stream, remaining, ct).ConfigureAwait(false); + + // Verify full signature. + var fullSig = Encoding.Latin1.GetString(firstBytes) + Encoding.Latin1.GetString(remaining); + if (fullSig != ProxyProtoConstants.V2Sig) + throw Wrap(ProxyProtoErrors.Invalid, "invalid signature"); + + // Read 4 bytes: ver/cmd, fam/proto, addr-len (2 bytes). + var header = new byte[4]; + await ReadFullAsync(stream, header, ct).ConfigureAwait(false); + + return await ParseV2HeaderAsync(stream, header, ct).ConfigureAwait(false); + } + + default: + throw new InvalidOperationException($"unsupported PROXY protocol version: {version}"); + } + } + + /// + /// Reads and parses a PROXY protocol (v1 or v2) header, synchronously. + /// Returns null for LOCAL/UNKNOWN health-check commands. + /// Mirrors Go readProxyProtoHeader. + /// + public static ProxyProtocolAddress? ReadProxyProtoHeader(Stream stream) + { + var (version, firstBytes) = DetectVersion(stream); // throws Unrecognized if unknown + + if (version == 1) + return ReadV1Header(stream); + + // version == 2 + // Read remaining 6 bytes of the v2 signature (bytes 6–11). + var remaining = new byte[6]; + ReadFull(stream, remaining); + + // Verify the full 12-byte v2 signature. + var fullSig = Encoding.Latin1.GetString(firstBytes) + Encoding.Latin1.GetString(remaining); + if (fullSig != ProxyProtoConstants.V2Sig) + throw Wrap(ProxyProtoErrors.Invalid, "invalid v2 signature"); + + // Read 4 bytes: ver/cmd, fam/proto, addr-len (2 bytes). + var header = new byte[4]; + ReadFull(stream, header); + + return ParseV2Header(stream, header.AsSpan()); + } + + /// + /// Reads a PROXY protocol v2 header from a raw byte buffer (test-friendly synchronous version). + /// Mirrors Go readProxyProtoV2Header. + /// + public static ProxyProtocolAddress? ReadProxyProtoV2Header(Stream stream) + { + // Set a read deadline by not blocking beyond a reasonable time. + // In the synchronous version we rely on a cancellation token internally. + using var cts = new CancellationTokenSource(ProxyProtoConstants.ReadTimeout); + + // Read fixed header (16 bytes). + var header = new byte[ProxyProtoConstants.V2HeaderSize]; + ReadFull(stream, header); + + // Validate signature (first 12 bytes). + if (Encoding.Latin1.GetString(header, 0, 12) != ProxyProtoConstants.V2Sig) + throw Wrap(ProxyProtoErrors.Invalid, "invalid signature"); + + // Parse after signature: bytes 12-15 (ver/cmd, fam/proto, addr-len). + return ParseV2Header(stream, header.AsSpan(12, 4)); + } + + // ------------------------------------------------------------------------- + // Internal: version detection + // ------------------------------------------------------------------------- + + internal static async Task<(int version, byte[] firstBytes, Exception? err)> DetectVersionAsync( + Stream stream, CancellationToken ct) + { + var buf = new byte[6]; + try + { + await ReadFullAsync(stream, buf, ct).ConfigureAwait(false); + } + catch (Exception ex) + { + return (0, buf, new IOException("failed to read protocol version", ex)); + } + + var s = Encoding.Latin1.GetString(buf); + if (s == ProxyProtoConstants.V1Prefix) + return (1, buf, null); + if (s == ProxyProtoConstants.V2Sig[..6]) + return (2, buf, null); + + return (0, buf, ProxyProtoErrors.Unrecognized); + } + + /// + /// Synchronous version of version detection — used by test code. + /// Mirrors Go detectProxyProtoVersion. + /// + internal static (int version, byte[] firstBytes) DetectVersion(Stream stream) + { + var buf = new byte[6]; + ReadFull(stream, buf); + + var s = Encoding.Latin1.GetString(buf); + if (s == ProxyProtoConstants.V1Prefix) + return (1, buf); + if (s == ProxyProtoConstants.V2Sig[..6]) + return (2, buf); + + throw ProxyProtoErrors.Unrecognized; + } + + // ------------------------------------------------------------------------- + // Internal: v1 parser + // ------------------------------------------------------------------------- + + internal static async Task ReadV1HeaderAsync(Stream stream, CancellationToken ct) + { + // "PROXY " prefix was already consumed (6 bytes). + int maxRemaining = ProxyProtoConstants.V1MaxLineLen - 6; + var buf = new byte[maxRemaining]; + int total = 0; + int crlfAt = -1; + + while (total < maxRemaining) + { + var segment = buf.AsMemory(total); + int n = await stream.ReadAsync(segment, ct).ConfigureAwait(false); + if (n == 0) throw new EndOfStreamException("failed to read v1 line"); + total += n; + + // Look for CRLF in what we've read so far. + for (int i = 0; i < total - 1; i++) + { + if (buf[i] == '\r' && buf[i + 1] == '\n') + { + crlfAt = i; + break; + } + } + if (crlfAt >= 0) break; + } + + if (crlfAt < 0) + throw Wrap(ProxyProtoErrors.Invalid, "v1 line too long"); + + return ParseV1Line(buf.AsSpan(0, crlfAt)); + } + + /// + /// Synchronous v1 parser. Mirrors Go readProxyProtoV1Header. + /// + internal static ProxyProtocolAddress? ReadV1Header(Stream stream) + { + int maxRemaining = ProxyProtoConstants.V1MaxLineLen - 6; + var buf = new byte[maxRemaining]; + int total = 0; + int crlfAt = -1; + + while (total < maxRemaining) + { + int n = stream.Read(buf, total, maxRemaining - total); + if (n == 0) throw new EndOfStreamException("failed to read v1 line"); + total += n; + + for (int i = 0; i < total - 1; i++) + { + if (buf[i] == '\r' && buf[i + 1] == '\n') + { + crlfAt = i; + break; + } + } + if (crlfAt >= 0) break; + } + + if (crlfAt < 0) + throw Wrap(ProxyProtoErrors.Invalid, "v1 line too long"); + + return ParseV1Line(buf.AsSpan(0, crlfAt)); + } + + private static ProxyProtocolAddress? ParseV1Line(ReadOnlySpan line) + { + var text = Encoding.ASCII.GetString(line).Trim(); + var parts = text.Split((char[]?)null, StringSplitOptions.RemoveEmptyEntries); + + if (parts.Length < 1) + throw Wrap(ProxyProtoErrors.Invalid, "invalid v1 format"); + + // UNKNOWN is a health-check (like LOCAL in v2). + if (parts[0] == ProxyProtoConstants.V1Unknown) + return null; + + if (parts.Length != 5) + throw Wrap(ProxyProtoErrors.Invalid, "invalid v1 format"); + + var protocol = parts[0]; + if (!IPAddress.TryParse(parts[1], out var srcIp) || !IPAddress.TryParse(parts[2], out var dstIp)) + throw Wrap(ProxyProtoErrors.Invalid, "invalid address"); + + if (!ushort.TryParse(parts[3], out var srcPort)) + throw new FormatException("invalid source port"); + if (!ushort.TryParse(parts[4], out var dstPort)) + throw new FormatException("invalid dest port"); + + // Validate protocol vs IP version. + bool isIpv4 = srcIp.AddressFamily == AddressFamily.InterNetwork; + if (protocol == ProxyProtoConstants.V1TCP4 && !isIpv4) + throw Wrap(ProxyProtoErrors.Invalid, "TCP4 with IPv6 address"); + if (protocol == ProxyProtoConstants.V1TCP6 && isIpv4) + throw Wrap(ProxyProtoErrors.Invalid, "TCP6 with IPv4 address"); + if (protocol != ProxyProtoConstants.V1TCP4 && protocol != ProxyProtoConstants.V1TCP6) + throw Wrap(ProxyProtoErrors.Invalid, $"invalid protocol {protocol}"); + + return new ProxyProtocolAddress(srcIp, srcPort, dstIp, dstPort); + } + + // ------------------------------------------------------------------------- + // Internal: v2 parser + // ------------------------------------------------------------------------- + + internal static async Task ParseV2HeaderAsync( + Stream stream, byte[] header, CancellationToken ct) + { + return ParseV2Header(stream, header.AsSpan()); + } + + /// + /// Parses PROXY protocol v2 after the signature has been validated. + /// is the 4 bytes: ver/cmd, fam/proto, addr-len (2 bytes). + /// Mirrors Go parseProxyProtoV2Header. + /// + internal static ProxyProtocolAddress? ParseV2Header(Stream stream, ReadOnlySpan header) + { + byte verCmd = header[0]; + byte version = (byte)(verCmd & ProxyProtoConstants.VerMask); + byte command = (byte)(verCmd & ProxyProtoConstants.CmdMask); + + if (version != ProxyProtoConstants.Ver2) + throw Wrap(ProxyProtoErrors.Invalid, $"invalid version 0x{version:X2}"); + + byte famProto = header[1]; + byte family = (byte)(famProto & ProxyProtoConstants.FamilyMask); + byte proto = (byte)(famProto & ProxyProtoConstants.ProtoMask); + + ushort addrLen = BinaryPrimitives.ReadUInt16BigEndian(header[2..]); + + // LOCAL command — health check. + if (command == ProxyProtoConstants.CmdLocal) + { + if (addrLen > 0) + DiscardBytes(stream, addrLen); + return null; + } + + if (command != ProxyProtoConstants.CmdProxy) + throw new InvalidDataException($"unknown PROXY protocol command: 0x{command:X2}"); + + if (proto != ProxyProtoConstants.ProtoStream) + throw Wrap(ProxyProtoErrors.Unsupported, "only STREAM protocol supported"); + + switch (family) + { + case ProxyProtoConstants.FamilyInet: + return ParseIPv4Addr(stream, addrLen); + + case ProxyProtoConstants.FamilyInet6: + return ParseIPv6Addr(stream, addrLen); + + case ProxyProtoConstants.FamilyUnspec: + if (addrLen > 0) + DiscardBytes(stream, addrLen); + return null; + + default: + throw Wrap(ProxyProtoErrors.Unsupported, $"unsupported address family 0x{family:X2}"); + } + } + + /// + /// Parses IPv4 address data. + /// Mirrors Go parseIPv4Addr. + /// + internal static ProxyProtocolAddress ParseIPv4Addr(Stream stream, ushort addrLen) + { + if (addrLen < ProxyProtoConstants.AddrSizeIPv4) + throw new InvalidDataException($"IPv4 address data too short: {addrLen} bytes"); + + var data = new byte[addrLen]; + ReadFull(stream, data); + + var srcIp = new IPAddress(data[0..4]); + var dstIp = new IPAddress(data[4..8]); + var srcPort = BinaryPrimitives.ReadUInt16BigEndian(data.AsSpan(8, 2)); + var dstPort = BinaryPrimitives.ReadUInt16BigEndian(data.AsSpan(10, 2)); + + return new ProxyProtocolAddress(srcIp, srcPort, dstIp, dstPort); + } + + /// + /// Parses IPv6 address data. + /// Mirrors Go parseIPv6Addr. + /// + internal static ProxyProtocolAddress ParseIPv6Addr(Stream stream, ushort addrLen) + { + if (addrLen < ProxyProtoConstants.AddrSizeIPv6) + throw new InvalidDataException($"IPv6 address data too short: {addrLen} bytes"); + + var data = new byte[addrLen]; + ReadFull(stream, data); + + var srcIp = new IPAddress(data[0..16]); + var dstIp = new IPAddress(data[16..32]); + var srcPort = BinaryPrimitives.ReadUInt16BigEndian(data.AsSpan(32, 2)); + var dstPort = BinaryPrimitives.ReadUInt16BigEndian(data.AsSpan(34, 2)); + + return new ProxyProtocolAddress(srcIp, srcPort, dstIp, dstPort); + } + + // ------------------------------------------------------------------------- + // I/O helpers + // ------------------------------------------------------------------------- + + /// + /// Fills completely, throwing + /// (wrapping as with ) + /// on short reads. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void ReadFull(Stream stream, byte[] buf) + { + int total = 0; + while (total < buf.Length) + { + int n = stream.Read(buf, total, buf.Length - total); + if (n == 0) + throw new IOException("unexpected EOF", new EndOfStreamException()); + total += n; + } + } + + internal static async Task ReadFullAsync(Stream stream, byte[] buf, CancellationToken ct) + { + int total = 0; + while (total < buf.Length) + { + int n = await stream.ReadAsync(buf.AsMemory(total), ct).ConfigureAwait(false); + if (n == 0) + throw new IOException("unexpected EOF", new EndOfStreamException()); + total += n; + } + } + + private static void DiscardBytes(Stream stream, int count) + { + var discard = new byte[count]; + ReadFull(stream, discard); + } + + private static Exception Wrap(Exception sentinel, string detail) + { + // Create a new exception that wraps the sentinel but carries the extra detail. + // The sentinel remains identifiable via the Message prefix (checked in tests with IsAssignableTo). + return new InvalidDataException($"{sentinel.Message}: {detail}", sentinel); + } +} + +// ============================================================================ +// StreamAdapter — wraps a byte array as a Stream (for test convenience) +// ============================================================================ + +/// +/// Minimal read-only backed by a byte array. +/// Used by test helpers to feed proxy protocol bytes into the parser. +/// +internal sealed class ByteArrayStream : Stream +{ + private readonly byte[] _data; + private int _pos; + + public ByteArrayStream(byte[] data) { _data = data; } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => _data.Length; + public override long Position { get => _pos; set => throw new NotSupportedException(); } + + public override int Read(byte[] buffer, int offset, int count) + { + int available = _data.Length - _pos; + if (available <= 0) return 0; + int toCopy = Math.Min(count, available); + Buffer.BlockCopy(_data, _pos, buffer, offset, toCopy); + _pos += toCopy; + return toCopy; + } + + public override void Flush() => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public void SetReadTimeout(int timeout) { } + public void SetWriteTimeout(int timeout) { } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ServerErrors.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ServerErrors.cs index 2558140..7c8e8d3 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ServerErrors.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ServerErrors.cs @@ -34,6 +34,10 @@ public static class ServerErrors public static readonly Exception ErrAuthentication = new InvalidOperationException("authentication error"); + // Alias used by ClientConnection.AuthViolation(); mirrors Go's ErrAuthorization. + public static readonly Exception ErrAuthorization = + new InvalidOperationException("Authorization Violation"); + public static readonly Exception ErrAuthTimeout = new InvalidOperationException("authentication timeout"); diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs new file mode 100644 index 0000000..c31b8da --- /dev/null +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ClientTests.cs @@ -0,0 +1,320 @@ +// Copyright 2012-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/client_test.go in the NATS server Go source. + +using System.Text; +using Shouldly; +using Xunit; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Protocol; + +namespace ZB.MOM.NatsNet.Server.Tests; + +/// +/// Standalone unit tests for helper functions. +/// Adapted from server/client_test.go. +/// +public sealed class ClientTests +{ + // ========================================================================= + // TestSplitSubjectQueue — Test ID 200 + // ========================================================================= + + [Theory] + [InlineData("foo", "foo", null, false)] + [InlineData("foo bar", "foo", "bar", false)] + [InlineData(" foo bar ", "foo", "bar", false)] + [InlineData("foo bar", "foo", "bar", false)] + [InlineData("foo bar fizz", null, null, true)] + public void SplitSubjectQueue_TableDriven( + string sq, string? wantSubject, string? wantQueue, bool wantErr) + { + if (wantErr) + { + Should.Throw(() => ClientConnection.SplitSubjectQueue(sq)); + } + else + { + var (subject, queue) = ClientConnection.SplitSubjectQueue(sq); + subject.ShouldBe(wantSubject is null ? null : Encoding.ASCII.GetBytes(wantSubject)); + queue.ShouldBe(wantQueue is null ? null : Encoding.ASCII.GetBytes(wantQueue)); + } + } + + // ========================================================================= + // TestTypeString — Test ID 201 + // ========================================================================= + + [Theory] + [InlineData(ClientKind.Client, "Client")] + [InlineData(ClientKind.Router, "Router")] + [InlineData(ClientKind.Gateway, "Gateway")] + [InlineData(ClientKind.Leaf, "Leafnode")] + [InlineData(ClientKind.JetStream, "JetStream")] + [InlineData(ClientKind.Account, "Account")] + [InlineData(ClientKind.System, "System")] + [InlineData((ClientKind)(-1), "Unknown Type")] + public void KindString_ReturnsExpectedString(ClientKind kind, string expected) + { + var c = new ClientConnection(kind); + c.KindString().ShouldBe(expected); + } +} + +/// +/// Standalone unit tests for functions. +/// Adapted from server/client_test.go (header utility tests). +/// +public sealed class NatsMessageHeadersTests +{ + // ========================================================================= + // TestRemoveHeaderIfPrefixPresent — Test ID 247 + // ========================================================================= + + [Fact] + public void RemoveHeaderIfPrefixPresent_RemovesMatchingHeaders() + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, "a", "1"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedStream, "my-stream"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSeq, "22"); + hdr = NatsMessageHeaders.GenHeader(hdr, "b", "2"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeq, "24"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastMsgId, "1"); + hdr = NatsMessageHeaders.GenHeader(hdr, "c", "3"); + + hdr = NatsMessageHeaders.RemoveHeaderIfPrefixPresent(hdr!, "Nats-Expected-"); + + var expected = Encoding.ASCII.GetBytes("NATS/1.0\r\na: 1\r\nb: 2\r\nc: 3\r\n\r\n"); + hdr.ShouldBe(expected); + } + + // ========================================================================= + // TestSliceHeader — Test ID 248 + // ========================================================================= + + [Fact] + public void SliceHeader_ReturnsCorrectSlice() + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, "a", "1"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedStream, "my-stream"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSeq, "22"); + hdr = NatsMessageHeaders.GenHeader(hdr, "b", "2"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeq, "24"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastMsgId, "1"); + hdr = NatsMessageHeaders.GenHeader(hdr, "c", "3"); + + var sliced = NatsMessageHeaders.SliceHeader(NatsHeaderConstants.JsExpectedLastSubjSeq, hdr!); + var copied = NatsMessageHeaders.GetHeader(NatsHeaderConstants.JsExpectedLastSubjSeq, hdr!); + + sliced.ShouldNotBeNull(); + sliced!.Value.Length.ShouldBe(2); // "24" is 2 bytes + copied.ShouldNotBeNull(); + sliced.Value.ToArray().ShouldBe(copied!); + } + + // ========================================================================= + // TestSliceHeaderOrderingPrefix — Test ID 249 + // ========================================================================= + + [Fact] + public void SliceHeader_OrderingPrefix_LongerHeaderDoesNotPreemptShorter() + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeqSubj, "foo"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeq, "24"); + + var sliced = NatsMessageHeaders.SliceHeader(NatsHeaderConstants.JsExpectedLastSubjSeq, hdr!); + var copied = NatsMessageHeaders.GetHeader(NatsHeaderConstants.JsExpectedLastSubjSeq, hdr!); + + sliced.ShouldNotBeNull(); + sliced!.Value.Length.ShouldBe(2); + copied.ShouldNotBeNull(); + sliced.Value.ToArray().ShouldBe(copied!); + } + + // ========================================================================= + // TestSliceHeaderOrderingSuffix — Test ID 250 + // ========================================================================= + + [Fact] + public void SliceHeader_OrderingSuffix_LongerHeaderDoesNotPreemptShorter() + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, "Previous-Nats-Msg-Id", "user"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsMsgId, "control"); + + var sliced = NatsMessageHeaders.SliceHeader(NatsHeaderConstants.JsMsgId, hdr!); + var copied = NatsMessageHeaders.GetHeader(NatsHeaderConstants.JsMsgId, hdr!); + + sliced.ShouldNotBeNull(); + copied.ShouldNotBeNull(); + sliced!.Value.ToArray().ShouldBe(copied!); + Encoding.ASCII.GetString(copied!).ShouldBe("control"); + } + + // ========================================================================= + // TestRemoveHeaderIfPresentOrderingPrefix — Test ID 251 + // ========================================================================= + + [Fact] + public void RemoveHeaderIfPresent_OrderingPrefix_OnlyRemovesExactKey() + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeqSubj, "foo"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeq, "24"); + + hdr = NatsMessageHeaders.RemoveHeaderIfPresent(hdr!, NatsHeaderConstants.JsExpectedLastSubjSeq); + var expected = NatsMessageHeaders.GenHeader(null, NatsHeaderConstants.JsExpectedLastSubjSeqSubj, "foo"); + hdr!.ShouldBe(expected); + } + + // ========================================================================= + // TestRemoveHeaderIfPresentOrderingSuffix — Test ID 252 + // ========================================================================= + + [Fact] + public void RemoveHeaderIfPresent_OrderingSuffix_OnlyRemovesExactKey() + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, "Previous-Nats-Msg-Id", "user"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsMsgId, "control"); + + hdr = NatsMessageHeaders.RemoveHeaderIfPresent(hdr!, NatsHeaderConstants.JsMsgId); + var expected = NatsMessageHeaders.GenHeader(null, "Previous-Nats-Msg-Id", "user"); + hdr!.ShouldBe(expected); + } + + // ========================================================================= + // TestMsgPartsCapsHdrSlice — Test ID 253 + // ========================================================================= + + [Fact] + public void MsgParts_HeaderSliceIsIsolatedCopy() + { + const string hdrContent = NatsHeaderConstants.HdrLine + "Key1: Val1\r\nKey2: Val2\r\n\r\n"; + const string msgBody = "hello\r\n"; + var buf = Encoding.ASCII.GetBytes(hdrContent + msgBody); + + var c = new ClientConnection(ClientKind.Client); + c.ParseCtx.Pa.HeaderSize = hdrContent.Length; + + var (hdr, msg) = c.MsgParts(buf); + + // Header and body should have correct content. + Encoding.ASCII.GetString(hdr).ShouldBe(hdrContent); + Encoding.ASCII.GetString(msg).ShouldBe(msgBody); + + // hdr should be shorter than buf (cap(hdr) < cap(buf) in Go). + hdr.Length.ShouldBeLessThan(buf.Length); + + // Appending to hdr should not affect msg. + var extended = hdr.Concat(Encoding.ASCII.GetBytes("test")).ToArray(); + Encoding.ASCII.GetString(extended).ShouldBe(hdrContent + "test"); + Encoding.ASCII.GetString(msg).ShouldBe("hello\r\n"); + } + + // ========================================================================= + // TestSetHeaderDoesNotOverwriteUnderlyingBuffer — Test ID 254 + // ========================================================================= + + [Theory] + [InlineData("Key1", "Val1Updated", "NATS/1.0\r\nKey1: Val1Updated\r\nKey2: Val2\r\n\r\n", true)] + [InlineData("Key1", "v1", "NATS/1.0\r\nKey1: v1\r\nKey2: Val2\r\n\r\n", false)] + [InlineData("Key3", "Val3", "NATS/1.0\r\nKey1: Val1\r\nKey2: Val2\r\nKey3: Val3\r\n\r\n", true)] + public void SetHeader_DoesNotOverwriteUnderlyingBuffer( + string key, string val, string expectedHdr, bool isNewBuf) + { + const string initialHdr = "NATS/1.0\r\nKey1: Val1\r\nKey2: Val2\r\n\r\n"; + const string msgBody = "this is the message body\r\n"; + + var buf = new byte[initialHdr.Length + msgBody.Length]; + Encoding.ASCII.GetBytes(initialHdr).CopyTo(buf, 0); + Encoding.ASCII.GetBytes(msgBody).CopyTo(buf, initialHdr.Length); + + var hdrSlice = buf[..initialHdr.Length]; + var msgSlice = buf[initialHdr.Length..]; + + var updatedHdr = NatsMessageHeaders.SetHeader(key, val, hdrSlice); + + Encoding.ASCII.GetString(updatedHdr).ShouldBe(expectedHdr); + Encoding.ASCII.GetString(msgSlice).ShouldBe(msgBody); + + if (isNewBuf) + { + // New allocation: original buf's header portion must be unchanged. + Encoding.ASCII.GetString(buf, 0, initialHdr.Length).ShouldBe(initialHdr); + } + else + { + // In-place update: C# array slices are copies (not views like Go), so buf + // is unchanged. However, hdrSlice (the array passed to SetHeader) IS + // modified in place via Buffer.BlockCopy. + Encoding.ASCII.GetString(hdrSlice, 0, expectedHdr.Length).ShouldBe(expectedHdr); + } + } + + // ========================================================================= + // TestSetHeaderOrderingPrefix — Test ID 255 + // ========================================================================= + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void SetHeader_OrderingPrefix_LongerHeaderDoesNotPreemptShorter(bool withSpaces) + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeqSubj, "foo"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsExpectedLastSubjSeq, "24"); + if (!withSpaces) + hdr = hdr!.Where(b => b != (byte)' ').ToArray(); + + hdr = NatsMessageHeaders.SetHeader(NatsHeaderConstants.JsExpectedLastSubjSeq, "12", hdr!); + + byte[]? expected = null; + expected = NatsMessageHeaders.GenHeader(expected, NatsHeaderConstants.JsExpectedLastSubjSeqSubj, "foo"); + expected = NatsMessageHeaders.GenHeader(expected, NatsHeaderConstants.JsExpectedLastSubjSeq, "12"); + if (!withSpaces) + expected = expected!.Where(b => b != (byte)' ').ToArray(); + + hdr!.ShouldBe(expected!); + } + + // ========================================================================= + // TestSetHeaderOrderingSuffix — Test ID 256 + // ========================================================================= + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void SetHeader_OrderingSuffix_LongerHeaderDoesNotPreemptShorter(bool withSpaces) + { + byte[]? hdr = null; + hdr = NatsMessageHeaders.GenHeader(hdr, "Previous-Nats-Msg-Id", "user"); + hdr = NatsMessageHeaders.GenHeader(hdr, NatsHeaderConstants.JsMsgId, "control"); + if (!withSpaces) + hdr = hdr!.Where(b => b != (byte)' ').ToArray(); + + hdr = NatsMessageHeaders.SetHeader(NatsHeaderConstants.JsMsgId, "other", hdr!); + + byte[]? expected = null; + expected = NatsMessageHeaders.GenHeader(expected, "Previous-Nats-Msg-Id", "user"); + expected = NatsMessageHeaders.GenHeader(expected, NatsHeaderConstants.JsMsgId, "other"); + if (!withSpaces) + expected = expected!.Where(b => b != (byte)' ').ToArray(); + + hdr!.ShouldBe(expected!); + } +} diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProxyProtocolTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProxyProtocolTests.cs new file mode 100644 index 0000000..086644a --- /dev/null +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Protocol/ProxyProtocolTests.cs @@ -0,0 +1,430 @@ +// Copyright 2025 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/client_proxyproto_test.go in the NATS server Go source. + +using System.Buffers.Binary; +using System.Net; +using System.Net.Sockets; +using Shouldly; +using Xunit; +using ZB.MOM.NatsNet.Server.Protocol; + +namespace ZB.MOM.NatsNet.Server.Tests.Protocol; + +/// +/// Unit tests for , , +/// and . +/// Adapted from server/client_proxyproto_test.go. +/// +[Collection("ProxyProtocol")] +public sealed class ProxyProtocolTests +{ + // ========================================================================= + // Test helpers — mirrors Go helper functions + // ========================================================================= + + /// + /// Builds a valid PROXY protocol v2 binary header. + /// Mirrors Go buildProxyV2Header. + /// + private static byte[] BuildProxyV2Header( + string srcIP, string dstIP, ushort srcPort, ushort dstPort, byte family) + { + using var buf = new MemoryStream(); + + // 12-byte signature + const string v2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + foreach (char c in v2Sig) + buf.WriteByte((byte)c); + + // ver/cmd: version 2 (0x20) | PROXY command (0x01) + buf.WriteByte(0x21); // proxyProtoV2Ver | proxyProtoCmdProxy + + // fam/proto + buf.WriteByte((byte)(family | 0x01)); // family | ProtoStream + + var src = IPAddress.Parse(srcIP); + var dst = IPAddress.Parse(dstIP); + byte[] addrData; + + if (family == 0x10) // FamilyInet + { + addrData = new byte[12]; // 4+4+2+2 + src.GetAddressBytes().CopyTo(addrData, 0); + dst.GetAddressBytes().CopyTo(addrData, 4); + BinaryPrimitives.WriteUInt16BigEndian(addrData.AsSpan(8, 2), srcPort); + BinaryPrimitives.WriteUInt16BigEndian(addrData.AsSpan(10, 2), dstPort); + } + else if (family == 0x20) // FamilyInet6 + { + addrData = new byte[36]; // 16+16+2+2 + src.GetAddressBytes().CopyTo(addrData, 0); + dst.GetAddressBytes().CopyTo(addrData, 16); + BinaryPrimitives.WriteUInt16BigEndian(addrData.AsSpan(32, 2), srcPort); + BinaryPrimitives.WriteUInt16BigEndian(addrData.AsSpan(34, 2), dstPort); + } + else + { + throw new ArgumentException($"unsupported family: {family}"); + } + + // addr-len (big-endian 2 bytes) + var lenBytes = new byte[2]; + BinaryPrimitives.WriteUInt16BigEndian(lenBytes, (ushort)addrData.Length); + buf.Write(lenBytes); + buf.Write(addrData); + + return buf.ToArray(); + } + + /// + /// Builds a PROXY protocol v2 LOCAL command header. + /// Mirrors Go buildProxyV2LocalHeader. + /// + private static byte[] BuildProxyV2LocalHeader() + { + using var buf = new MemoryStream(); + const string v2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + foreach (char c in v2Sig) + buf.WriteByte((byte)c); + buf.WriteByte(0x20); // proxyProtoV2Ver | proxyProtoCmdLocal + buf.WriteByte(0x00); // FamilyUnspec | ProtoUnspec + buf.WriteByte(0); + buf.WriteByte(0); + return buf.ToArray(); + } + + /// + /// Builds a PROXY protocol v1 text header. + /// Mirrors Go buildProxyV1Header. + /// + private static byte[] BuildProxyV1Header( + string protocol, string srcIP, string dstIP, ushort srcPort, ushort dstPort) + { + string line; + if (protocol == "UNKNOWN") + line = "PROXY UNKNOWN\r\n"; + else + line = $"PROXY {protocol} {srcIP} {dstIP} {srcPort} {dstPort}\r\n"; + + return System.Text.Encoding.ASCII.GetBytes(line); + } + + // ========================================================================= + // PROXY Protocol v2 Parse Tests + // ========================================================================= + + /// Test ID 159 — TestClientProxyProtoV2ParseIPv4 + [Fact] + public void ProxyProtoV2_ParseIPv4_ReturnsCorrectAddresses() + { + var header = BuildProxyV2Header("192.168.1.50", "10.0.0.1", 12345, 4222, 0x10); + using var stream = new MemoryStream(header); + + var addr = ProxyProtocolParser.ReadProxyProtoV2Header(stream); + + addr.ShouldNotBeNull(); + addr!.SrcIp.ToString().ShouldBe("192.168.1.50"); + addr.SrcPort.ShouldBe((ushort)12345); + addr.DstIp.ToString().ShouldBe("10.0.0.1"); + addr.DstPort.ShouldBe((ushort)4222); + addr.String().ShouldBe("192.168.1.50:12345"); + addr.Network().ShouldBe("tcp4"); + } + + /// Test ID 160 — TestClientProxyProtoV2ParseIPv6 + [Fact] + public void ProxyProtoV2_ParseIPv6_ReturnsCorrectAddresses() + { + var header = BuildProxyV2Header("2001:db8::1", "2001:db8::2", 54321, 4222, 0x20); + using var stream = new MemoryStream(header); + + var addr = ProxyProtocolParser.ReadProxyProtoV2Header(stream); + + addr.ShouldNotBeNull(); + addr!.SrcIp.ToString().ShouldBe("2001:db8::1"); + addr.SrcPort.ShouldBe((ushort)54321); + addr.DstIp.ToString().ShouldBe("2001:db8::2"); + addr.DstPort.ShouldBe((ushort)4222); + addr.String().ShouldBe("[2001:db8::1]:54321"); + addr.Network().ShouldBe("tcp6"); + } + + /// Test ID 161 — TestClientProxyProtoV2ParseLocalCommand + [Fact] + public void ProxyProtoV2_LocalCommand_ReturnsNull() + { + var header = BuildProxyV2LocalHeader(); + using var stream = new MemoryStream(header); + + var addr = ProxyProtocolParser.ReadProxyProtoV2Header(stream); + + addr.ShouldBeNull(); + } + + /// Test ID 162 — TestClientProxyProtoV2InvalidSignature + [Fact] + public void ProxyProtoV2_InvalidSignature_ThrowsInvalidData() + { + var header = new byte[16]; + System.Text.Encoding.ASCII.GetBytes("INVALID_SIG_").CopyTo(header, 0); + header[12] = 0x20; header[13] = 0x11; header[14] = 0x00; header[15] = 0x0C; + using var stream = new MemoryStream(header); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoV2Header(stream)); + } + + /// Test ID 163 — TestClientProxyProtoV2InvalidVersion + [Fact] + public void ProxyProtoV2_InvalidVersion_ThrowsInvalidData() + { + using var buf = new MemoryStream(); + const string v2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + foreach (char c in v2Sig) + buf.WriteByte((byte)c); + buf.WriteByte(0x10 | 0x01); // version 1 instead of 2 — invalid + buf.WriteByte(0x10 | 0x01); // FamilyInet | ProtoStream + buf.WriteByte(0); buf.WriteByte(0); + + using var stream = new MemoryStream(buf.ToArray()); + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoV2Header(stream)); + } + + /// Test ID 164 — TestClientProxyProtoV2UnsupportedFamily + [Fact] + public void ProxyProtoV2_UnixSocketFamily_ThrowsUnsupported() + { + using var buf = new MemoryStream(); + const string v2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + foreach (char c in v2Sig) + buf.WriteByte((byte)c); + buf.WriteByte(0x21); // v2 ver | proxy cmd + buf.WriteByte(0x30 | 0x01); // FamilyUnix | ProtoStream + buf.WriteByte(0); buf.WriteByte(0); + + using var stream = new MemoryStream(buf.ToArray()); + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoV2Header(stream)); + } + + /// Test ID 165 — TestClientProxyProtoV2UnsupportedProtocol + [Fact] + public void ProxyProtoV2_UdpProtocol_ThrowsUnsupported() + { + using var buf = new MemoryStream(); + const string v2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + foreach (char c in v2Sig) + buf.WriteByte((byte)c); + buf.WriteByte(0x21); // v2 ver | proxy cmd + buf.WriteByte(0x10 | 0x02); // FamilyInet | ProtoDatagram (UDP) + buf.WriteByte(0); buf.WriteByte(12); // addr len = 12 + + using var stream = new MemoryStream(buf.ToArray()); + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoV2Header(stream)); + } + + /// Test ID 166 — TestClientProxyProtoV2TruncatedHeader + [Fact] + public void ProxyProtoV2_TruncatedHeader_ThrowsIOException() + { + var fullHeader = BuildProxyV2Header("192.168.1.50", "10.0.0.1", 12345, 4222, 0x10); + // Only provide first 10 bytes — header is 16 bytes minimum + using var stream = new MemoryStream(fullHeader[..10]); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoV2Header(stream)); + } + + /// Test ID 167 — TestClientProxyProtoV2ShortAddressData + [Fact] + public void ProxyProtoV2_ShortAddressData_ThrowsIOException() + { + using var buf = new MemoryStream(); + const string v2Sig = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"; + foreach (char c in v2Sig) + buf.WriteByte((byte)c); + buf.WriteByte(0x21); // v2 ver | proxy cmd + buf.WriteByte(0x10 | 0x01); // FamilyInet | ProtoStream + buf.WriteByte(0); buf.WriteByte(12); // addr len = 12 but only 5 bytes follow + buf.Write(new byte[] { 1, 2, 3, 4, 5 }); // only 5 bytes + + using var stream = new MemoryStream(buf.ToArray()); + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoV2Header(stream)); + } + + /// Test ID 168 — TestProxyConnRemoteAddr + [Fact] + public void ProxyConn_RemoteAddr_ReturnsProxiedAddress() + { + var proxyAddr = new ProxyProtocolAddress( + IPAddress.Parse("10.0.0.50"), 12345, + IPAddress.Parse("10.0.0.1"), 4222); + + using var inner = new MemoryStream(); + var wrapped = new ProxyProtocolConnection(inner, proxyAddr); + + wrapped.RemoteAddress.String().ShouldBe("10.0.0.50:12345"); + } + + // ========================================================================= + // PROXY Protocol v1 Parse Tests + // ========================================================================= + + /// Test ID 171 — TestClientProxyProtoV1ParseTCP4 + [Fact] + public void ProxyProtoV1_ParseTCP4_ReturnsCorrectAddresses() + { + var header = BuildProxyV1Header("TCP4", "192.168.1.50", "10.0.0.1", 12345, 4222); + using var stream = new MemoryStream(header); + + var addr = ProxyProtocolParser.ReadProxyProtoHeader(stream); + + addr.ShouldNotBeNull(); + addr!.SrcIp.ToString().ShouldBe("192.168.1.50"); + addr.SrcPort.ShouldBe((ushort)12345); + addr.DstIp.ToString().ShouldBe("10.0.0.1"); + addr.DstPort.ShouldBe((ushort)4222); + } + + /// Test ID 172 — TestClientProxyProtoV1ParseTCP6 + [Fact] + public void ProxyProtoV1_ParseTCP6_ReturnsCorrectAddresses() + { + var header = BuildProxyV1Header("TCP6", "2001:db8::1", "2001:db8::2", 54321, 4222); + using var stream = new MemoryStream(header); + + var addr = ProxyProtocolParser.ReadProxyProtoHeader(stream); + + addr.ShouldNotBeNull(); + addr!.SrcIp.ToString().ShouldBe("2001:db8::1"); + addr.SrcPort.ShouldBe((ushort)54321); + addr.DstIp.ToString().ShouldBe("2001:db8::2"); + addr.DstPort.ShouldBe((ushort)4222); + } + + /// Test ID 173 — TestClientProxyProtoV1ParseUnknown + [Fact] + public void ProxyProtoV1_UnknownProtocol_ReturnsNull() + { + var header = BuildProxyV1Header("UNKNOWN", "", "", 0, 0); + using var stream = new MemoryStream(header); + + var addr = ProxyProtocolParser.ReadProxyProtoHeader(stream); + + addr.ShouldBeNull(); + } + + /// Test ID 174 — TestClientProxyProtoV1InvalidFormat + [Fact] + public void ProxyProtoV1_MissingFields_ThrowsInvalidData() + { + var header = System.Text.Encoding.ASCII.GetBytes("PROXY TCP4 192.168.1.1\r\n"); + using var stream = new MemoryStream(header); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream)); + } + + /// Test ID 175 — TestClientProxyProtoV1LineTooLong + [Fact] + public void ProxyProtoV1_LineTooLong_ThrowsInvalidData() + { + var longIp = new string('1', 120); + var line = $"PROXY TCP4 {longIp} 10.0.0.1 12345 443\r\n"; + var header = System.Text.Encoding.ASCII.GetBytes(line); + using var stream = new MemoryStream(header); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream)); + } + + /// Test ID 176 — TestClientProxyProtoV1InvalidIP + [Fact] + public void ProxyProtoV1_InvalidIPAddress_ThrowsInvalidData() + { + var header = System.Text.Encoding.ASCII.GetBytes( + "PROXY TCP4 not.an.ip.addr 10.0.0.1 12345 443\r\n"); + using var stream = new MemoryStream(header); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream)); + } + + /// Test ID 177 — TestClientProxyProtoV1MismatchedProtocol + [Fact] + public void ProxyProtoV1_TCP4WithIPv6Address_ThrowsInvalidData() + { + // TCP4 with IPv6 address + var header = BuildProxyV1Header("TCP4", "2001:db8::1", "2001:db8::2", 12345, 443); + using var stream = new MemoryStream(header); + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream)); + + // TCP6 with IPv4 address + var header2 = BuildProxyV1Header("TCP6", "192.168.1.1", "10.0.0.1", 12345, 443); + using var stream2 = new MemoryStream(header2); + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream2)); + } + + /// Test ID 178 — TestClientProxyProtoV1InvalidPort + [Fact] + public void ProxyProtoV1_InvalidPort_ThrowsException() + { + var header = System.Text.Encoding.ASCII.GetBytes( + "PROXY TCP4 192.168.1.1 10.0.0.1 99999 443\r\n"); + using var stream = new MemoryStream(header); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream)); + } + + // ========================================================================= + // Mixed Protocol Version Tests + // ========================================================================= + + /// Test ID 180 — TestClientProxyProtoVersionDetection + [Fact] + public void ProxyProto_AutoDetect_HandlesV1AndV2() + { + // v1 detection + var v1Header = BuildProxyV1Header("TCP4", "192.168.1.1", "10.0.0.1", 12345, 443); + using var stream1 = new MemoryStream(v1Header); + var addr1 = ProxyProtocolParser.ReadProxyProtoHeader(stream1); + addr1.ShouldNotBeNull(); + addr1!.SrcIp.ToString().ShouldBe("192.168.1.1"); + + // v2 detection + var v2Header = BuildProxyV2Header("192.168.1.2", "10.0.0.1", 54321, 443, 0x10); + using var stream2 = new MemoryStream(v2Header); + var addr2 = ProxyProtocolParser.ReadProxyProtoHeader(stream2); + addr2.ShouldNotBeNull(); + addr2!.SrcIp.ToString().ShouldBe("192.168.1.2"); + } + + /// Test ID 181 — TestClientProxyProtoUnrecognizedVersion + [Fact] + public void ProxyProto_UnrecognizedFormat_ThrowsInvalidData() + { + var header = System.Text.Encoding.ASCII.GetBytes("HELLO WORLD\r\n"); + using var stream = new MemoryStream(header); + + Should.Throw(() => + ProxyProtocolParser.ReadProxyProtoHeader(stream)); + } +} diff --git a/porting.db b/porting.db index 44e4eb9..b7fd35a 100644 Binary files a/porting.db and b/porting.db differ diff --git a/reports/current.md b/reports/current.md index 375a0ac..75ed31a 100644 --- a/reports/current.md +++ b/reports/current.md @@ -1,6 +1,6 @@ # NATS .NET Porting Status Report -Generated: 2026-02-26 18:16:57 UTC +Generated: 2026-02-26 18:50:39 UTC ## Modules (12 total) @@ -13,18 +13,18 @@ Generated: 2026-02-26 18:16:57 UTC | Status | Count | |--------|-------| -| complete | 472 | +| complete | 667 | | n_a | 82 | -| not_started | 3026 | +| not_started | 2831 | | stub | 93 | ## Unit Tests (3257 total) | Status | Count | |--------|-------| -| complete | 242 | -| n_a | 82 | -| not_started | 2709 | +| complete | 274 | +| n_a | 163 | +| not_started | 2596 | | stub | 224 | ## Library Mappings (36 total) @@ -36,4 +36,4 @@ Generated: 2026-02-26 18:16:57 UTC ## Overall Progress -**889/6942 items complete (12.8%)** +**1197/6942 items complete (17.2%)** diff --git a/reports/report_88b1391.md b/reports/report_88b1391.md new file mode 100644 index 0000000..75ed31a --- /dev/null +++ b/reports/report_88b1391.md @@ -0,0 +1,39 @@ +# NATS .NET Porting Status Report + +Generated: 2026-02-26 18:50:39 UTC + +## Modules (12 total) + +| Status | Count | +|--------|-------| +| complete | 11 | +| not_started | 1 | + +## Features (3673 total) + +| Status | Count | +|--------|-------| +| complete | 667 | +| n_a | 82 | +| not_started | 2831 | +| stub | 93 | + +## Unit Tests (3257 total) + +| Status | Count | +|--------|-------| +| complete | 274 | +| n_a | 163 | +| not_started | 2596 | +| stub | 224 | + +## Library Mappings (36 total) + +| Status | Count | +|--------|-------| +| mapped | 36 | + + +## Overall Progress + +**1197/6942 items complete (17.2%)**