feat: port session 07 — Protocol Parser, Auth extras (TPM/certidp/certstore), Internal utilities & data structures

Session 07 scope (5 features, 17 tests, ~1165 Go LOC):
- Protocol/ParserTypes.cs: ParserState enum (79 states), PublishArgument, ParseContext
- Protocol/IProtocolHandler.cs: handler interface decoupling parser from client
- Protocol/ProtocolParser.cs: Parse(), ProtoSnippet(), OverMaxControlLineLimit(),
  ProcessPub/HeaderPub/RoutedMsgArgs/RoutedHeaderMsgArgs, ClonePubArg(), GetHeader()
- tests/Protocol/ProtocolParserTests.cs: 17 tests via TestProtocolHandler stub

Auth extras from session 06 (committed separately):
- Auth/TpmKeyProvider.cs, Auth/CertificateIdentityProvider/, Auth/CertificateStore/

Internal utilities & data structures (session 06 overflow):
- Internal/AccessTimeService.cs, ElasticPointer.cs, SystemMemory.cs, ProcessStatsProvider.cs
- Internal/DataStructures/GenericSublist.cs, HashWheel.cs
- Internal/DataStructures/SubjectTree.cs, SubjectTreeNode.cs, SubjectTreeParts.cs

All 461 tests pass (460 unit + 1 integration). DB updated for features 2588-2592 and tests 2598-2614.
This commit is contained in:
Joseph Doherty
2026-02-26 13:16:56 -05:00
parent 0a54d342ba
commit 88b1391ef0
56 changed files with 9006 additions and 6 deletions

View File

@@ -0,0 +1,57 @@
namespace ZB.MOM.NatsNet.Server.Auth.CertificateIdentityProvider;
/// <summary>
/// Error and debug message constants for the OCSP peer identity provider.
/// Mirrors certidp/messages.go.
/// </summary>
public static class OcspMessages
{
// Returned errors
public const string ErrIllegalPeerOptsConfig = "expected map to define OCSP peer options, got [{0}]";
public const string ErrIllegalCacheOptsConfig = "expected map to define OCSP peer cache options, got [{0}]";
public const string ErrParsingPeerOptFieldGeneric = "error parsing tls peer config, unknown field [\"{0}\"]";
public const string ErrParsingPeerOptFieldTypeConversion = "error parsing tls peer config, conversion error: {0}";
public const string ErrParsingCacheOptFieldTypeConversion = "error parsing OCSP peer cache config, conversion error: {0}";
public const string ErrUnableToPlugTLSEmptyConfig = "unable to plug TLS verify connection, config is nil";
public const string ErrMTLSRequired = "OCSP peer verification for client connections requires TLS verify (mTLS) to be enabled";
public const string ErrUnableToPlugTLSClient = "unable to register client OCSP verification";
public const string ErrUnableToPlugTLSServer = "unable to register server OCSP verification";
public const string ErrCannotWriteCompressed = "error writing to compression writer: {0}";
public const string ErrCannotReadCompressed = "error reading compression reader: {0}";
public const string ErrTruncatedWrite = "short write on body ({0} != {1})";
public const string ErrCannotCloseWriter = "error closing compression writer: {0}";
public const string ErrParsingCacheOptFieldGeneric = "error parsing OCSP peer cache config, unknown field [\"{0}\"]";
public const string ErrUnknownCacheType = "error parsing OCSP peer cache config, unknown type [{0}]";
public const string ErrInvalidChainlink = "invalid chain link";
public const string ErrBadResponderHTTPStatus = "bad OCSP responder http status: [{0}]";
public const string ErrNoAvailOCSPServers = "no available OCSP servers";
public const string ErrFailedWithAllRequests = "exhausted OCSP responders: {0}";
// Direct logged errors
public const string ErrLoadCacheFail = "Unable to load OCSP peer cache: {0}";
public const string ErrSaveCacheFail = "Unable to save OCSP peer cache: {0}";
public const string ErrBadCacheTypeConfig = "Unimplemented OCSP peer cache type [{0}]";
public const string ErrResponseCompressFail = "Unable to compress OCSP response for key [{0}]: {1}";
public const string ErrResponseDecompressFail = "Unable to decompress OCSP response for key [{0}]: {1}";
public const string ErrPeerEmptyNoEvent = "Peer certificate is nil, cannot send OCSP peer reject event";
public const string ErrPeerEmptyAutoReject = "Peer certificate is nil, rejecting OCSP peer";
// Debug messages
public const string DbgPlugTLSForKind = "Plugging TLS OCSP peer for [{0}]";
public const string DbgNumServerChains = "Peer OCSP enabled: {0} TLS server chain(s) will be evaluated";
public const string DbgNumClientChains = "Peer OCSP enabled: {0} TLS client chain(s) will be evaluated";
public const string DbgLinksInChain = "Chain [{0}]: {1} total link(s)";
public const string DbgSelfSignedValid = "Chain [{0}] is self-signed, thus peer is valid";
public const string DbgValidNonOCSPChain = "Chain [{0}] has no OCSP eligible links, thus peer is valid";
public const string DbgChainIsOCSPEligible = "Chain [{0}] has {1} OCSP eligible link(s)";
public const string DbgChainIsOCSPValid = "Chain [{0}] is OCSP valid for all eligible links, thus peer is valid";
public const string DbgNoOCSPValidChains = "No OCSP valid chains, thus peer is invalid";
public const string DbgCheckingCacheForCert = "Checking OCSP peer cache for [{0}], key [{1}]";
public const string DbgCurrentResponseCached = "Cached OCSP response is current, status [{0}]";
public const string DbgExpiredResponseCached = "Cached OCSP response is expired, status [{0}]";
public const string DbgOCSPValidPeerLink = "OCSP verify pass for [{0}]";
public const string DbgMakingCARequest = "Making OCSP CA request to [{0}]";
public const string DbgResponseExpired = "OCSP response expired: NextUpdate={0}, now={1}, skew={2}";
public const string DbgResponseTTLExpired = "OCSP response TTL expired: expiry={0}, now={1}, skew={2}";
public const string DbgResponseFutureDated = "OCSP response is future-dated: ThisUpdate={0}, now={1}, skew={2}";
}

View File

@@ -0,0 +1,129 @@
using System.Security.Cryptography.X509Certificates;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace ZB.MOM.NatsNet.Server.Auth.CertificateIdentityProvider;
/// <summary>OCSP certificate status values.</summary>
/// <remarks>Mirrors the Go <c>ocsp.Good/Revoked/Unknown</c> constants (0/1/2).</remarks>
[JsonConverter(typeof(OcspStatusAssertionJsonConverter))]
public enum OcspStatusAssertion
{
Good = 0,
Revoked = 1,
Unknown = 2,
}
/// <summary>JSON converter: serializes <see cref="OcspStatusAssertion"/> as lowercase string.</summary>
public sealed class OcspStatusAssertionJsonConverter : JsonConverter<OcspStatusAssertion>
{
private static readonly IReadOnlyDictionary<string, OcspStatusAssertion> StrToVal =
new Dictionary<string, OcspStatusAssertion>(StringComparer.OrdinalIgnoreCase)
{
["good"] = OcspStatusAssertion.Good,
["revoked"] = OcspStatusAssertion.Revoked,
["unknown"] = OcspStatusAssertion.Unknown,
};
private static readonly IReadOnlyDictionary<OcspStatusAssertion, string> ValToStr =
new Dictionary<OcspStatusAssertion, string>
{
[OcspStatusAssertion.Good] = "good",
[OcspStatusAssertion.Revoked] = "revoked",
[OcspStatusAssertion.Unknown] = "unknown",
};
public override OcspStatusAssertion Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var s = reader.GetString() ?? string.Empty;
return StrToVal.TryGetValue(s, out var v) ? v : OcspStatusAssertion.Unknown;
}
public override void Write(Utf8JsonWriter writer, OcspStatusAssertion value, JsonSerializerOptions options)
{
writer.WriteStringValue(ValToStr.TryGetValue(value, out var s) ? s : "unknown");
}
}
/// <summary>
/// Returns the string representation of an OCSP status integer.
/// Falls back to "unknown" for unrecognized values (never defaults to "good").
/// </summary>
public static class OcspStatusAssertionExtensions
{
public static string GetStatusAssertionStr(int statusInt) => statusInt switch
{
0 => "good",
1 => "revoked",
_ => "unknown",
};
}
/// <summary>Parsed OCSP peer configuration.</summary>
public sealed class OcspPeerConfig
{
public static readonly TimeSpan DefaultAllowedClockSkew = TimeSpan.FromSeconds(30);
public static readonly TimeSpan DefaultOCSPResponderTimeout = TimeSpan.FromSeconds(2);
public static readonly TimeSpan DefaultTTLUnsetNextUpdate = TimeSpan.FromHours(1);
public bool Verify { get; set; } = false;
public double Timeout { get; set; } = DefaultOCSPResponderTimeout.TotalSeconds;
public double ClockSkew { get; set; } = DefaultAllowedClockSkew.TotalSeconds;
public bool WarnOnly { get; set; } = false;
public bool UnknownIsGood { get; set; } = false;
public bool AllowWhenCAUnreachable { get; set; } = false;
public double TTLUnsetNextUpdate { get; set; } = DefaultTTLUnsetNextUpdate.TotalSeconds;
/// <summary>Returns a new <see cref="OcspPeerConfig"/> with defaults populated.</summary>
public static OcspPeerConfig Create() => new();
}
/// <summary>
/// Represents a certificate chain link: a leaf certificate and its issuer,
/// plus the OCSP web endpoints parsed from the leaf's AIA extension.
/// </summary>
public sealed class ChainLink
{
public X509Certificate2? Leaf { get; set; }
public X509Certificate2? Issuer { get; set; }
public IReadOnlyList<Uri>? OcspWebEndpoints { get; set; }
}
/// <summary>
/// Parsed OCSP response data. Mirrors the fields of <c>golang.org/x/crypto/ocsp.Response</c>
/// needed by <see cref="OcspUtilities"/>.
/// </summary>
/// <remarks>
/// Full OCSP response parsing (DER/ASN.1) requires an additional library (e.g. Bouncy Castle).
/// This type represents the already-parsed response for use in validation and caching logic.
/// </remarks>
public sealed class OcspResponse
{
public OcspStatusAssertion Status { get; init; }
public DateTime ThisUpdate { get; init; }
/// <summary><see cref="DateTime.MinValue"/> means "not set" (CA did not supply NextUpdate).</summary>
public DateTime NextUpdate { get; init; }
/// <summary>Optional delegated signer certificate (RFC 6960 §4.2.2.2).</summary>
public X509Certificate2? Certificate { get; init; }
}
/// <summary>Neutral logging interface for plugin use. Mirrors the Go <c>certidp.Log</c> struct.</summary>
public sealed class OcspLog
{
public Action<string, object[]>? Debugf { get; set; }
public Action<string, object[]>? Noticef { get; set; }
public Action<string, object[]>? Warnf { get; set; }
public Action<string, object[]>? Errorf { get; set; }
public Action<string, object[]>? Tracef { get; set; }
internal void Debug(string format, params object[] args) => Debugf?.Invoke(format, args);
}
/// <summary>JSON-serializable certificate information.</summary>
public sealed class CertInfo
{
[JsonPropertyName("subject")] public string? Subject { get; init; }
[JsonPropertyName("issuer")] public string? Issuer { get; init; }
[JsonPropertyName("fingerprint")] public string? Fingerprint { get; init; }
[JsonPropertyName("raw")] public byte[]? Raw { get; init; }
}

View File

@@ -0,0 +1,73 @@
using System.Net.Http;
namespace ZB.MOM.NatsNet.Server.Auth.CertificateIdentityProvider;
/// <summary>
/// OCSP responder communication: fetches raw OCSP response bytes from CA endpoints.
/// Mirrors certidp/ocsp_responder.go.
/// </summary>
public static class OcspResponder
{
/// <summary>
/// Fetches an OCSP response from the responder URLs in <paramref name="link"/>.
/// Tries each endpoint in order and returns the first successful response.
/// </summary>
/// <param name="link">Chain link containing leaf cert, issuer cert, and OCSP endpoints.</param>
/// <param name="opts">Configuration (timeout, etc.).</param>
/// <param name="log">Optional logger.</param>
/// <param name="ocspRequest">DER-encoded OCSP request bytes to send.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Raw DER bytes of the OCSP response.</returns>
public static async Task<byte[]> FetchOCSPResponseAsync(
ChainLink link,
OcspPeerConfig opts,
byte[] ocspRequest,
OcspLog? log = null,
CancellationToken cancellationToken = default)
{
if (link.Leaf is null || link.Issuer is null)
throw new ArgumentException(OcspMessages.ErrInvalidChainlink, nameof(link));
if (link.OcspWebEndpoints is null || link.OcspWebEndpoints.Count == 0)
throw new InvalidOperationException(OcspMessages.ErrNoAvailOCSPServers);
var timeout = TimeSpan.FromSeconds(opts.Timeout <= 0
? OcspPeerConfig.DefaultOCSPResponderTimeout.TotalSeconds
: opts.Timeout);
var reqEnc = EncodeOCSPRequest(ocspRequest);
using var hc = new HttpClient { Timeout = timeout };
Exception? lastError = null;
foreach (var endpoint in link.OcspWebEndpoints)
{
var responderUrl = endpoint.ToString().TrimEnd('/');
log?.Debug(OcspMessages.DbgMakingCARequest, responderUrl);
try
{
var url = $"{responderUrl}/{reqEnc}";
using var response = await hc.GetAsync(url, cancellationToken).ConfigureAwait(false);
if (!response.IsSuccessStatusCode)
throw new HttpRequestException(
string.Format(OcspMessages.ErrBadResponderHTTPStatus, (int)response.StatusCode));
return await response.Content.ReadAsByteArrayAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
lastError = ex;
}
}
throw new InvalidOperationException(
string.Format(OcspMessages.ErrFailedWithAllRequests, lastError?.Message), lastError);
}
/// <summary>
/// Base64-encodes the OCSP request DER bytes and URL-escapes the result
/// for use as a path segment (RFC 6960 Appendix A.1).
/// </summary>
public static string EncodeOCSPRequest(byte[] reqDer) =>
Uri.EscapeDataString(Convert.ToBase64String(reqDer));
}

View File

@@ -0,0 +1,219 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
namespace ZB.MOM.NatsNet.Server.Auth.CertificateIdentityProvider;
/// <summary>
/// Utility methods for OCSP peer certificate validation.
/// Mirrors certidp/certidp.go.
/// </summary>
public static class OcspUtilities
{
// OCSP AIA extension OID.
private const string OidAuthorityInfoAccess = "1.3.6.1.5.5.7.1.1";
// OCSPSigning extended key usage OID.
private const string OidOcspSigning = "1.3.6.1.5.5.7.3.9";
/// <summary>Returns the SHA-256 fingerprint of the certificate's raw DER bytes, base64-encoded.</summary>
public static string GenerateFingerprint(X509Certificate2 cert)
{
var hash = SHA256.HashData(cert.RawData);
return Convert.ToBase64String(hash);
}
/// <summary>
/// Filters a list of URI strings to those that are valid HTTP or HTTPS URLs.
/// </summary>
public static IReadOnlyList<Uri> GetWebEndpoints(IEnumerable<string> uris)
{
var result = new List<Uri>();
foreach (var uri in uris)
{
if (!Uri.TryCreate(uri, UriKind.Absolute, out var parsed))
continue;
if (parsed.Scheme != "http" && parsed.Scheme != "https")
continue;
result.Add(parsed);
}
return result;
}
/// <summary>
/// Returns the certificate subject in RDN sequence form, for logging.
/// Not suitable for reliable cache matching.
/// </summary>
public static string GetSubjectDNForm(X509Certificate2? cert) =>
cert is null ? string.Empty : cert.Subject;
/// <summary>
/// Returns the certificate issuer in RDN sequence form, for logging.
/// Not suitable for reliable cache matching.
/// </summary>
public static string GetIssuerDNForm(X509Certificate2? cert) =>
cert is null ? string.Empty : cert.Issuer;
/// <summary>
/// Returns true if the leaf certificate in the chain has OCSP responder endpoints
/// in its Authority Information Access extension.
/// Also populates <see cref="ChainLink.OcspWebEndpoints"/> on the link.
/// </summary>
public static bool CertOCSPEligible(ChainLink? link)
{
if (link?.Leaf is null || link.Leaf.RawData is not { Length: > 0 })
return false;
var ocspUris = GetOcspUris(link.Leaf);
var endpoints = GetWebEndpoints(ocspUris);
if (endpoints.Count == 0)
return false;
link.OcspWebEndpoints = endpoints;
return true;
}
/// <summary>
/// Returns the issuer certificate at position <paramref name="leafPos"/> + 1 in the chain.
/// Returns null if the chain is too short or the leaf is self-signed.
/// </summary>
public static X509Certificate2? GetLeafIssuerCert(IReadOnlyList<X509Certificate2> chain, int leafPos)
{
if (chain.Count == 0 || leafPos < 0)
return null;
if (leafPos >= chain.Count - 1)
return null;
return chain[leafPos + 1];
}
/// <summary>
/// Returns true if the OCSP response is still current within the configured clock skew.
/// </summary>
public static bool OCSPResponseCurrent(OcspResponse response, OcspPeerConfig opts, OcspLog? log = null)
{
var skew = TimeSpan.FromSeconds(opts.ClockSkew < 0 ? OcspPeerConfig.DefaultAllowedClockSkew.TotalSeconds : opts.ClockSkew);
var now = DateTime.UtcNow;
// Check NextUpdate (when set by CA).
if (response.NextUpdate != DateTime.MinValue && response.NextUpdate < now - skew)
{
log?.Debug(OcspMessages.DbgResponseExpired,
response.NextUpdate.ToString("o"), now.ToString("o"), skew);
return false;
}
// If NextUpdate not set, apply TTL from ThisUpdate.
if (response.NextUpdate == DateTime.MinValue)
{
var ttl = TimeSpan.FromSeconds(opts.TTLUnsetNextUpdate < 0
? OcspPeerConfig.DefaultTTLUnsetNextUpdate.TotalSeconds
: opts.TTLUnsetNextUpdate);
var expiry = response.ThisUpdate + ttl;
if (expiry < now - skew)
{
log?.Debug(OcspMessages.DbgResponseTTLExpired,
expiry.ToString("o"), now.ToString("o"), skew);
return false;
}
}
// Check ThisUpdate is not future-dated.
if (response.ThisUpdate > now + skew)
{
log?.Debug(OcspMessages.DbgResponseFutureDated,
response.ThisUpdate.ToString("o"), now.ToString("o"), skew);
return false;
}
return true;
}
/// <summary>
/// Validates that the OCSP response was signed by a valid CA issuer or authorised delegate
/// per RFC 6960 §4.2.2.2.
/// </summary>
public static bool ValidDelegationCheck(X509Certificate2? issuer, OcspResponse? response)
{
if (issuer is null || response is null)
return false;
// Not a delegated response — the CA signed directly.
if (response.Certificate is null)
return true;
// Delegate is the same as the issuer — effectively a direct signing.
if (response.Certificate.Thumbprint == issuer.Thumbprint)
return true;
// Check the delegate has id-kp-OCSPSigning in its extended key usage.
foreach (var ext in response.Certificate.Extensions)
{
if (ext is not X509EnhancedKeyUsageExtension eku)
continue;
foreach (var oid in eku.EnhancedKeyUsages)
{
if (oid.Value == OidOcspSigning)
return true;
}
}
return false;
}
// --- Helpers ---
private static IEnumerable<string> GetOcspUris(X509Certificate2 cert)
{
foreach (var ext in cert.Extensions)
{
if (ext.Oid?.Value != OidAuthorityInfoAccess)
continue;
foreach (var uri in ParseAiaUris(ext.RawData, isOcsp: true))
yield return uri;
}
}
private static List<string> ParseAiaUris(byte[] aiaExtDer, bool isOcsp)
{
// OID for id-ad-ocsp: 1.3.6.1.5.5.7.48.1 → 2B 06 01 05 05 07 30 01
byte[] ocspOid = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01];
// OID for id-ad-caIssuers: 1.3.6.1.5.5.7.48.2 → 2B 06 01 05 05 07 30 02
byte[] caIssuersOid = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02];
var target = isOcsp ? ocspOid : caIssuersOid;
var result = new List<string>();
int i = 0;
while (i < aiaExtDer.Length - target.Length - 4)
{
// Look for OID tag (0x06) followed by length matching our OID.
if (aiaExtDer[i] == 0x06 && i + 1 < aiaExtDer.Length && aiaExtDer[i + 1] == target.Length)
{
var match = true;
for (int k = 0; k < target.Length; k++)
{
if (aiaExtDer[i + 2 + k] != target[k]) { match = false; break; }
}
if (match)
{
// Next element should be context [6] IA5String (GeneralName uniformResourceIdentifier).
int pos = i + 2 + target.Length;
if (pos < aiaExtDer.Length && aiaExtDer[pos] == 0x86)
{
pos++;
if (pos < aiaExtDer.Length)
{
int len = aiaExtDer[pos++];
if (pos + len <= aiaExtDer.Length)
{
result.Add(System.Text.Encoding.ASCII.GetString(aiaExtDer, pos, len));
i = pos + len;
continue;
}
}
}
}
}
i++;
}
return result;
}
}

View File

@@ -0,0 +1,137 @@
// Copyright 2022-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System.Security.Cryptography.X509Certificates;
namespace ZB.MOM.NatsNet.Server.Auth.CertificateStore;
/// <summary>
/// Windows certificate store location.
/// Mirrors the Go certstore <c>StoreType</c> enum (windowsCurrentUser=1, windowsLocalMachine=2).
/// </summary>
public enum StoreType
{
Empty = 0,
WindowsCurrentUser = 1,
WindowsLocalMachine = 2,
}
/// <summary>
/// Certificate lookup criterion.
/// Mirrors the Go certstore <c>MatchByType</c> enum (matchByIssuer=1, matchBySubject=2, matchByThumbprint=3).
/// </summary>
public enum MatchByType
{
Empty = 0,
Issuer = 1,
Subject = 2,
Thumbprint = 3,
}
/// <summary>
/// Result returned by <see cref="CertificateStoreService.TLSConfig"/>.
/// Mirrors the data that the Go <c>TLSConfig</c> populates into <c>*tls.Config</c>.
/// </summary>
public sealed class CertStoreTlsResult
{
public CertStoreTlsResult(X509Certificate2 leaf, X509Certificate2Collection? caCerts = null)
{
Leaf = leaf;
CaCerts = caCerts;
}
/// <summary>The leaf certificate (with private key) to use as the server/client identity.</summary>
public X509Certificate2 Leaf { get; }
/// <summary>Optional pool of CA certificates used to validate client certificates (mTLS).</summary>
public X509Certificate2Collection? CaCerts { get; }
}
/// <summary>
/// Error constants for the Windows certificate store module.
/// Mirrors certstore/errors.go.
/// </summary>
public static class CertStoreErrors
{
public static readonly InvalidOperationException ErrBadCryptoStoreProvider =
new("unable to open certificate store or store not available");
public static readonly InvalidOperationException ErrBadRSAHashAlgorithm =
new("unsupported RSA hash algorithm");
public static readonly InvalidOperationException ErrBadSigningAlgorithm =
new("unsupported signing algorithm");
public static readonly InvalidOperationException ErrStoreRSASigningError =
new("unable to obtain RSA signature from store");
public static readonly InvalidOperationException ErrStoreECDSASigningError =
new("unable to obtain ECDSA signature from store");
public static readonly InvalidOperationException ErrNoPrivateKeyStoreRef =
new("unable to obtain private key handle from store");
public static readonly InvalidOperationException ErrExtractingPrivateKeyMetadata =
new("unable to extract private key metadata");
public static readonly InvalidOperationException ErrExtractingECCPublicKey =
new("unable to extract ECC public key from store");
public static readonly InvalidOperationException ErrExtractingRSAPublicKey =
new("unable to extract RSA public key from store");
public static readonly InvalidOperationException ErrExtractingPublicKey =
new("unable to extract public key from store");
public static readonly InvalidOperationException ErrBadPublicKeyAlgorithm =
new("unsupported public key algorithm");
public static readonly InvalidOperationException ErrExtractPropertyFromKey =
new("unable to extract property from key");
public static readonly InvalidOperationException ErrBadECCCurveName =
new("unsupported ECC curve name");
public static readonly InvalidOperationException ErrFailedCertSearch =
new("unable to find certificate in store");
public static readonly InvalidOperationException ErrFailedX509Extract =
new("unable to extract x509 from certificate");
public static readonly InvalidOperationException ErrBadMatchByType =
new("cert match by type not implemented");
public static readonly InvalidOperationException ErrBadCertStore =
new("cert store type not implemented");
public static readonly InvalidOperationException ErrConflictCertFileAndStore =
new("'cert_file' and 'cert_store' may not both be configured");
public static readonly InvalidOperationException ErrBadCertStoreField =
new("expected 'cert_store' to be a valid non-empty string");
public static readonly InvalidOperationException ErrBadCertMatchByField =
new("expected 'cert_match_by' to be a valid non-empty string");
public static readonly InvalidOperationException ErrBadCertMatchField =
new("expected 'cert_match' to be a valid non-empty string");
public static readonly InvalidOperationException ErrBadCaCertMatchField =
new("expected 'ca_certs_match' to be a valid non-empty string array");
public static readonly InvalidOperationException ErrBadCertMatchSkipInvalidField =
new("expected 'cert_match_skip_invalid' to be a boolean");
public static readonly InvalidOperationException ErrOSNotCompatCertStore =
new("cert_store not compatible with current operating system");
}

View File

@@ -0,0 +1,264 @@
// Copyright 2022-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Adapted from certstore/certstore.go and certstore/certstore_windows.go in
// the NATS server Go source. The .NET implementation uses System.Security.
// Cryptography.X509Certificates.X509Store in place of Win32 P/Invoke calls.
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
namespace ZB.MOM.NatsNet.Server.Auth.CertificateStore;
/// <summary>
/// Provides access to the Windows certificate store for TLS certificate provisioning.
/// Mirrors certstore/certstore.go and certstore/certstore_windows.go.
///
/// On non-Windows platforms all methods that require the Windows store throw
/// <see cref="CertStoreErrors.ErrOSNotCompatCertStore"/>.
/// </summary>
public static class CertificateStoreService
{
private static readonly IReadOnlyDictionary<string, StoreType> StoreMap =
new Dictionary<string, StoreType>(StringComparer.OrdinalIgnoreCase)
{
["windowscurrentuser"] = StoreType.WindowsCurrentUser,
["windowslocalmachine"] = StoreType.WindowsLocalMachine,
};
private static readonly IReadOnlyDictionary<string, MatchByType> MatchByMap =
new Dictionary<string, MatchByType>(StringComparer.OrdinalIgnoreCase)
{
["issuer"] = MatchByType.Issuer,
["subject"] = MatchByType.Subject,
["thumbprint"] = MatchByType.Thumbprint,
};
// -------------------------------------------------------------------------
// Cross-platform parse helpers
// -------------------------------------------------------------------------
/// <summary>
/// Parses a cert_store string to a <see cref="StoreType"/>.
/// Returns an error if the string is unrecognised or not valid on the current OS.
/// Mirrors <c>ParseCertStore</c>.
/// </summary>
public static (StoreType store, Exception? error) ParseCertStore(string certStore)
{
if (!StoreMap.TryGetValue(certStore, out var st))
return (StoreType.Empty, CertStoreErrors.ErrBadCertStore);
// All currently supported store types are Windows-only.
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return (StoreType.Empty, CertStoreErrors.ErrOSNotCompatCertStore);
return (st, null);
}
/// <summary>
/// Parses a cert_match_by string to a <see cref="MatchByType"/>.
/// Mirrors <c>ParseCertMatchBy</c>.
/// </summary>
public static (MatchByType matchBy, Exception? error) ParseCertMatchBy(string certMatchBy)
{
if (!MatchByMap.TryGetValue(certMatchBy, out var mb))
return (MatchByType.Empty, CertStoreErrors.ErrBadMatchByType);
return (mb, null);
}
/// <summary>
/// Returns the issuer certificate for <paramref name="leaf"/> by building a chain.
/// Returns null if the chain cannot be built or the leaf is self-signed.
/// Mirrors <c>GetLeafIssuer</c>.
/// </summary>
public static X509Certificate2? GetLeafIssuer(X509Certificate2 leaf)
{
using var chain = new X509Chain();
chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
chain.ChainPolicy.VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority;
if (!chain.Build(leaf) || chain.ChainElements.Count < 2)
return null;
// chain.ChainElements[0] is the leaf; [1] is its issuer.
return new X509Certificate2(chain.ChainElements[1].Certificate);
}
// -------------------------------------------------------------------------
// TLS configuration entry point
// -------------------------------------------------------------------------
/// <summary>
/// Finds a certificate in the Windows certificate store matching the given criteria and
/// returns a <see cref="CertStoreTlsResult"/> suitable for populating TLS options.
///
/// On non-Windows platforms throws <see cref="CertStoreErrors.ErrOSNotCompatCertStore"/>.
/// Mirrors <c>TLSConfig</c> (certstore_windows.go).
/// </summary>
/// <param name="storeType">Which Windows store to use (CurrentUser or LocalMachine).</param>
/// <param name="matchBy">How to match the certificate (Subject, Issuer, or Thumbprint).</param>
/// <param name="certMatch">The match value (subject name, issuer name, or thumbprint hex).</param>
/// <param name="caCertsMatch">Optional list of subject strings to locate CA certificates.</param>
/// <param name="skipInvalid">If true, skip expired or not-yet-valid certificates.</param>
public static CertStoreTlsResult TLSConfig(
StoreType storeType,
MatchByType matchBy,
string certMatch,
IReadOnlyList<string>? caCertsMatch = null,
bool skipInvalid = false)
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
throw CertStoreErrors.ErrOSNotCompatCertStore;
if (storeType is not (StoreType.WindowsCurrentUser or StoreType.WindowsLocalMachine))
throw CertStoreErrors.ErrBadCertStore;
var location = storeType == StoreType.WindowsCurrentUser
? StoreLocation.CurrentUser
: StoreLocation.LocalMachine;
// Find the leaf certificate.
var leaf = matchBy switch
{
MatchByType.Subject or MatchByType.Empty => CertBySubject(certMatch, location, skipInvalid),
MatchByType.Issuer => CertByIssuer(certMatch, location, skipInvalid),
MatchByType.Thumbprint => CertByThumbprint(certMatch, location, skipInvalid),
_ => throw CertStoreErrors.ErrBadMatchByType,
} ?? throw CertStoreErrors.ErrFailedCertSearch;
// Optionally find CA certificates.
X509Certificate2Collection? caPool = null;
if (caCertsMatch is { Count: > 0 })
caPool = CreateCACertsPool(location, caCertsMatch, skipInvalid);
return new CertStoreTlsResult(leaf, caPool);
}
// -------------------------------------------------------------------------
// Certificate search helpers (mirror winCertStore.certByXxx / certSearch)
// -------------------------------------------------------------------------
/// <summary>
/// Finds the first certificate in the personal (MY) store by subject name.
/// Mirrors <c>certBySubject</c>.
/// </summary>
public static X509Certificate2? CertBySubject(string subject, StoreLocation location, bool skipInvalid) =>
CertSearch(StoreName.My, location, X509FindType.FindBySubjectName, subject, skipInvalid);
/// <summary>
/// Finds the first certificate in the personal (MY) store by issuer name.
/// Mirrors <c>certByIssuer</c>.
/// </summary>
public static X509Certificate2? CertByIssuer(string issuer, StoreLocation location, bool skipInvalid) =>
CertSearch(StoreName.My, location, X509FindType.FindByIssuerName, issuer, skipInvalid);
/// <summary>
/// Finds the first certificate in the personal (MY) store by SHA-1 thumbprint (hex string).
/// Mirrors <c>certByThumbprint</c>.
/// </summary>
public static X509Certificate2? CertByThumbprint(string thumbprint, StoreLocation location, bool skipInvalid) =>
CertSearch(StoreName.My, location, X509FindType.FindByThumbprint, thumbprint, skipInvalid);
/// <summary>
/// Searches Root, AuthRoot, and CA stores for certificates matching the given subject name.
/// Returns all matching certificates across all three locations.
/// Mirrors <c>caCertsBySubjectMatch</c>.
/// </summary>
public static IReadOnlyList<X509Certificate2> CaCertsBySubjectMatch(
string subject,
StoreLocation location,
bool skipInvalid)
{
if (string.IsNullOrEmpty(subject))
throw CertStoreErrors.ErrBadCaCertMatchField;
var results = new List<X509Certificate2>();
var searchLocations = new[] { StoreName.Root, StoreName.AuthRoot, StoreName.CertificateAuthority };
foreach (var storeName in searchLocations)
{
var cert = CertSearch(storeName, location, X509FindType.FindBySubjectName, subject, skipInvalid);
if (cert != null)
results.Add(cert);
}
if (results.Count == 0)
throw CertStoreErrors.ErrFailedCertSearch;
return results;
}
/// <summary>
/// Core certificate search — opens the specified store and finds a matching certificate.
/// Returns null if not found.
/// Mirrors <c>certSearch</c>.
/// </summary>
public static X509Certificate2? CertSearch(
StoreName storeName,
StoreLocation storeLocation,
X509FindType findType,
string findValue,
bool skipInvalid)
{
using var store = new X509Store(storeName, storeLocation, OpenFlags.ReadOnly | OpenFlags.OpenExistingOnly);
var certs = store.Certificates.Find(findType, findValue, validOnly: skipInvalid);
if (certs.Count == 0)
return null;
// Pick first that has a private key (mirrors certKey requirement in Go).
foreach (var cert in certs)
{
if (cert.HasPrivateKey)
return cert;
}
// Fall back to first even without private key (e.g. CA cert lookup).
return certs[0];
}
// -------------------------------------------------------------------------
// CA cert pool builder (mirrors createCACertsPool)
// -------------------------------------------------------------------------
/// <summary>
/// Builds a collection of CA certificates from the trusted Root, AuthRoot, and CA stores
/// for each subject name in <paramref name="caCertsMatch"/>.
/// Mirrors <c>createCACertsPool</c>.
/// </summary>
public static X509Certificate2Collection CreateCACertsPool(
StoreLocation location,
IReadOnlyList<string> caCertsMatch,
bool skipInvalid)
{
var pool = new X509Certificate2Collection();
var failCount = 0;
foreach (var subject in caCertsMatch)
{
try
{
var matches = CaCertsBySubjectMatch(subject, location, skipInvalid);
foreach (var cert in matches)
pool.Add(cert);
}
catch
{
failCount++;
}
}
if (failCount == caCertsMatch.Count)
throw new InvalidOperationException("unable to match any CA certificate");
return pool;
}
}

View File

@@ -0,0 +1,61 @@
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace ZB.MOM.NatsNet.Server.Auth;
/// <summary>
/// Provides JetStream encryption key management via the Trusted Platform Module (TPM).
/// Windows only — non-Windows platforms throw <see cref="PlatformNotSupportedException"/>.
/// </summary>
/// <remarks>
/// On Windows, the full implementation requires the Tpm2Lib NuGet package and accesses
/// the TPM to seal/unseal keys using PCR-based authorization. The sealed public and
/// private key blobs are persisted to disk as JSON.
/// </remarks>
public static class TpmKeyProvider
{
/// <summary>
/// Loads (or creates) the JetStream encryption key from the TPM.
/// On first call (key file does not exist), generates a new NKey seed, seals it to the
/// TPM, and writes the blobs to <paramref name="jsKeyFile"/>.
/// On subsequent calls, reads the blobs from disk and unseals them using the TPM.
/// </summary>
/// <param name="srkPassword">Storage Root Key password (may be empty).</param>
/// <param name="jsKeyFile">Path to the persisted key blobs JSON file.</param>
/// <param name="jsKeyPassword">Password used to seal/unseal the JetStream key.</param>
/// <param name="pcr">PCR index to bind the authorization policy to.</param>
/// <returns>The JetStream encryption key seed string.</returns>
/// <exception cref="PlatformNotSupportedException">Thrown on non-Windows platforms.</exception>
public static string LoadJetStreamEncryptionKeyFromTpm(
string srkPassword,
string jsKeyFile,
string jsKeyPassword,
int pcr)
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
throw new PlatformNotSupportedException("TPM functionality is not supported on this platform.");
// Windows implementation requires Tpm2Lib NuGet package.
// Add <PackageReference Include="Tpm2Lib" Version="*" /> to the .csproj
// under a Windows-conditional ItemGroup before enabling this path.
throw new PlatformNotSupportedException(
"TPM functionality is not supported on this platform. " +
"On Windows, add Tpm2Lib NuGet package and implement via tpm2.OpenTPM().");
}
}
/// <summary>
/// Persisted TPM key blobs stored on disk as JSON.
/// </summary>
internal sealed class NatsPersistedTpmKeys
{
[JsonPropertyName("version")]
public int Version { get; set; }
[JsonPropertyName("private_key")]
public byte[] PrivateKey { get; set; } = [];
[JsonPropertyName("public_key")]
public byte[] PublicKey { get; set; } = [];
}

View File

@@ -0,0 +1,100 @@
namespace ZB.MOM.NatsNet.Server.Internal;
/// <summary>
/// Provides an efficiently-cached Unix nanosecond timestamp updated every
/// <see cref="TickInterval"/> by a shared background timer.
/// Register before use and Unregister when done; the timer shuts down when all
/// registrants have unregistered.
/// </summary>
/// <remarks>
/// Mirrors the Go <c>ats</c> package. Intended for high-frequency cache
/// access-time reads that do not need sub-100ms precision.
/// </remarks>
public static class AccessTimeService
{
/// <summary>How often the cached time is refreshed.</summary>
public static readonly TimeSpan TickInterval = TimeSpan.FromMilliseconds(100);
private static long _utime;
private static long _refs;
private static Timer? _timer;
private static readonly object _lock = new();
static AccessTimeService()
{
// Mirror Go's init(): nothing to pre-allocate in .NET.
}
/// <summary>
/// Registers a user. Starts the background timer when the first registrant calls this.
/// Each call to <see cref="Register"/> must be paired with a call to <see cref="Unregister"/>.
/// </summary>
public static void Register()
{
var v = Interlocked.Increment(ref _refs);
if (v == 1)
{
Interlocked.Exchange(ref _utime, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() * 1_000_000L);
lock (_lock)
{
_timer?.Dispose();
_timer = new Timer(_ =>
{
Interlocked.Exchange(ref _utime, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() * 1_000_000L);
}, null, TickInterval, TickInterval);
}
}
}
/// <summary>
/// Unregisters a user. Stops the background timer when the last registrant calls this.
/// </summary>
/// <exception cref="InvalidOperationException">Thrown when unregister is called more times than register.</exception>
public static void Unregister()
{
var v = Interlocked.Decrement(ref _refs);
if (v == 0)
{
lock (_lock)
{
_timer?.Dispose();
_timer = null;
}
}
else if (v < 0)
{
Interlocked.Exchange(ref _refs, 0);
throw new InvalidOperationException("ats: unbalanced unregister for access time state");
}
}
/// <summary>
/// Returns the last cached Unix nanosecond timestamp.
/// If no registrant is active, returns a fresh timestamp (avoids returning zero).
/// </summary>
public static long AccessTime()
{
var v = Interlocked.Read(ref _utime);
if (v == 0)
{
v = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() * 1_000_000L;
Interlocked.CompareExchange(ref _utime, v, 0);
v = Interlocked.Read(ref _utime);
}
return v;
}
/// <summary>
/// Resets all state. For testing only.
/// </summary>
internal static void Reset()
{
lock (_lock)
{
_timer?.Dispose();
_timer = null;
}
Interlocked.Exchange(ref _refs, 0);
Interlocked.Exchange(ref _utime, 0);
}
}

View File

@@ -0,0 +1,678 @@
// Copyright 2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ZB.MOM.NatsNet.Server.Internal.DataStructures;
// Sublist is a routing mechanism to handle subject distribution and
// provides a facility to match subjects from published messages to
// interested subscribers. Subscribers can have wildcard subjects to
// match multiple published subjects.
/// <summary>
/// A value type used with <see cref="SimpleSublist"/> to track interest without
/// storing any associated data. Equivalent to Go's <c>struct{}</c>.
/// </summary>
public readonly struct EmptyStruct : IEquatable<EmptyStruct>
{
public static readonly EmptyStruct Value = default;
public bool Equals(EmptyStruct other) => true;
public override bool Equals(object? obj) => obj is EmptyStruct;
public override int GetHashCode() => 0;
public static bool operator ==(EmptyStruct left, EmptyStruct right) => true;
public static bool operator !=(EmptyStruct left, EmptyStruct right) => false;
}
/// <summary>
/// A thread-safe trie-based NATS subject routing list that efficiently stores and
/// retrieves subscriptions. Wildcards <c>*</c> (single-token) and <c>&gt;</c>
/// (full-wildcard) are supported.
/// </summary>
/// <typeparam name="T">The subscription value type. Must be non-null.</typeparam>
public class GenericSublist<T> where T : notnull
{
// Token separator and wildcard constants (mirrors Go's const block).
private const char Pwc = '*';
private const char Fwc = '>';
private const char Btsep = '.';
// -------------------------------------------------------------------------
// Public error singletons (mirrors Go's var block).
// -------------------------------------------------------------------------
/// <summary>Thrown when a subject is syntactically invalid.</summary>
public static readonly ArgumentException ErrInvalidSubject =
new("gsl: invalid subject");
/// <summary>Thrown when a subscription is not found during removal.</summary>
public static readonly KeyNotFoundException ErrNotFound =
new("gsl: no matches found");
/// <summary>Thrown when a value is already registered for the given subject.</summary>
public static readonly InvalidOperationException ErrAlreadyRegistered =
new("gsl: notification already registered");
// -------------------------------------------------------------------------
// Fields
// -------------------------------------------------------------------------
private readonly TrieLevel _root;
private uint _count;
private readonly ReaderWriterLockSlim _lock = new(LockRecursionPolicy.NoRecursion);
// -------------------------------------------------------------------------
// Construction
// -------------------------------------------------------------------------
internal GenericSublist()
{
_root = new TrieLevel();
}
/// <summary>Creates a new <see cref="GenericSublist{T}"/>.</summary>
public static GenericSublist<T> NewSublist() => new();
/// <summary>Creates a new <see cref="SimpleSublist"/>.</summary>
public static SimpleSublist NewSimpleSublist() => new();
// -------------------------------------------------------------------------
// Public API
// -------------------------------------------------------------------------
/// <summary>Returns the total number of subscriptions stored.</summary>
public uint Count
{
get
{
_lock.EnterReadLock();
try { return _count; }
finally { _lock.ExitReadLock(); }
}
}
/// <summary>
/// Inserts a subscription into the trie.
/// Throws <see cref="ArgumentException"/> if <paramref name="subject"/> is invalid.
/// </summary>
public void Insert(string subject, T value)
{
_lock.EnterWriteLock();
try
{
InsertCore(subject, value);
}
finally
{
_lock.ExitWriteLock();
}
}
/// <summary>
/// Removes a subscription from the trie.
/// Throws <see cref="ArgumentException"/> if the subject is invalid, or
/// <see cref="KeyNotFoundException"/> if not found.
/// </summary>
public void Remove(string subject, T value)
{
_lock.EnterWriteLock();
try
{
RemoveCore(subject, value);
}
finally
{
_lock.ExitWriteLock();
}
}
/// <summary>
/// Calls <paramref name="action"/> for every value whose subscription matches
/// the literal <paramref name="subject"/>.
/// </summary>
public void Match(string subject, Action<T> action)
{
_lock.EnterReadLock();
try
{
var tokens = TokenizeForMatch(subject);
if (tokens == null) return;
MatchLevel(_root, tokens, 0, action);
}
finally
{
_lock.ExitReadLock();
}
}
/// <summary>
/// Calls <paramref name="action"/> for every value whose subscription matches
/// <paramref name="subject"/> supplied as a UTF-8 byte span.
/// </summary>
public void MatchBytes(ReadOnlySpan<byte> subject, Action<T> action)
{
Match(System.Text.Encoding.UTF8.GetString(subject), action);
}
/// <summary>
/// Returns <see langword="true"/> when at least one subscription matches
/// <paramref name="subject"/>.
/// </summary>
public bool HasInterest(string subject)
{
_lock.EnterReadLock();
try
{
var tokens = TokenizeForMatch(subject);
if (tokens == null) return false;
int dummy = 0;
return MatchLevelForAny(_root, tokens, 0, ref dummy);
}
finally
{
_lock.ExitReadLock();
}
}
/// <summary>
/// Returns the number of subscriptions that match <paramref name="subject"/>.
/// </summary>
public int NumInterest(string subject)
{
_lock.EnterReadLock();
try
{
var tokens = TokenizeForMatch(subject);
if (tokens == null) return 0;
int np = 0;
MatchLevelForAny(_root, tokens, 0, ref np);
return np;
}
finally
{
_lock.ExitReadLock();
}
}
/// <summary>
/// Returns <see langword="true"/> if the trie contains any subscription that
/// could match a subject whose tokens begin with the tokens of
/// <paramref name="subject"/>. Used for trie intersection checks.
/// </summary>
public bool HasInterestStartingIn(string subject)
{
_lock.EnterReadLock();
try
{
var tokens = TokenizeSubjectIntoSlice(subject);
return HasInterestStartingInLevel(_root, tokens, 0);
}
finally
{
_lock.ExitReadLock();
}
}
// -------------------------------------------------------------------------
// Internal helpers (accessible to tests in the same assembly).
// -------------------------------------------------------------------------
/// <summary>Returns the maximum depth of the trie. Used in tests.</summary>
internal int NumLevels() => VisitLevel(_root, 0);
// -------------------------------------------------------------------------
// Private: Insert core (lock must be held by caller)
// -------------------------------------------------------------------------
private void InsertCore(string subject, T value)
{
var sfwc = false; // seen full-wildcard token
TrieNode? n = null;
var l = _root;
// Iterate tokens split by '.' using index arithmetic to avoid allocations.
var start = 0;
while (start <= subject.Length)
{
// Find end of this token.
var end = subject.IndexOf(Btsep, start);
var isLast = end < 0;
if (isLast) end = subject.Length;
var tokenLen = end - start;
if (tokenLen == 0 || sfwc)
throw new ArgumentException(ErrInvalidSubject.Message);
if (tokenLen > 1)
{
var t = subject.Substring(start, tokenLen);
if (!l.Nodes.TryGetValue(t, out n))
{
n = new TrieNode();
l.Nodes[t] = n;
}
}
else
{
switch (subject[start])
{
case Pwc:
if (l.PwcNode == null) l.PwcNode = new TrieNode();
n = l.PwcNode;
break;
case Fwc:
if (l.FwcNode == null) l.FwcNode = new TrieNode();
n = l.FwcNode;
sfwc = true;
break;
default:
var t = subject.Substring(start, 1);
if (!l.Nodes.TryGetValue(t, out n))
{
n = new TrieNode();
l.Nodes[t] = n;
}
break;
}
}
n.Next ??= new TrieLevel();
l = n.Next;
if (isLast) break;
start = end + 1;
}
if (n == null)
throw new ArgumentException(ErrInvalidSubject.Message);
n.Subs[value] = subject;
_count++;
}
// -------------------------------------------------------------------------
// Private: Remove core (lock must be held by caller)
// -------------------------------------------------------------------------
private void RemoveCore(string subject, T value)
{
var sfwc = false;
var l = _root;
// We use a fixed-size stack-style array to track visited (level, node, token)
// triples so we can prune upward after removal. 32 is the same as Go's [32]lnt.
var levels = new LevelNodeToken[32];
var levelCount = 0;
TrieNode? n = null;
var start = 0;
while (start <= subject.Length)
{
var end = subject.IndexOf(Btsep, start);
var isLast = end < 0;
if (isLast) end = subject.Length;
var tokenLen = end - start;
if (tokenLen == 0 || sfwc)
throw new ArgumentException(ErrInvalidSubject.Message);
if (l == null!)
throw new KeyNotFoundException(ErrNotFound.Message);
var tokenStr = subject.Substring(start, tokenLen);
if (tokenLen > 1)
{
l.Nodes.TryGetValue(tokenStr, out n);
}
else
{
switch (tokenStr[0])
{
case Pwc:
n = l.PwcNode;
break;
case Fwc:
n = l.FwcNode;
sfwc = true;
break;
default:
l.Nodes.TryGetValue(tokenStr, out n);
break;
}
}
if (n != null)
{
if (levelCount < levels.Length)
levels[levelCount++] = new LevelNodeToken(l, n, tokenStr);
l = n.Next!;
}
else
{
l = null!;
}
if (isLast) break;
start = end + 1;
}
// Remove from the final node's subscription map.
if (!RemoveFromNode(n, value))
throw new KeyNotFoundException(ErrNotFound.Message);
_count--;
// Prune empty nodes upward.
for (var i = levelCount - 1; i >= 0; i--)
{
var (lv, nd, tk) = levels[i];
if (nd.IsEmpty())
lv.PruneNode(nd, tk);
}
}
private static bool RemoveFromNode(TrieNode? n, T value)
{
if (n == null) return false;
return n.Subs.Remove(value);
}
// -------------------------------------------------------------------------
// Private: matchLevel - recursive trie descent with callback
// Mirrors Go's matchLevel function exactly.
// -------------------------------------------------------------------------
private static void MatchLevel(TrieLevel? l, string[] tokens, int start, Action<T> action)
{
TrieNode? pwc = null;
TrieNode? n = null;
for (var i = start; i < tokens.Length; i++)
{
if (l == null) return;
// Full-wildcard at this level matches everything at/below.
if (l.FwcNode != null)
CallbacksForResults(l.FwcNode, action);
pwc = l.PwcNode;
if (pwc != null)
MatchLevel(pwc.Next, tokens, i + 1, action);
l.Nodes.TryGetValue(tokens[i], out n);
l = n?.Next;
}
// After consuming all tokens, emit subs from exact and pwc matches.
if (n != null)
CallbacksForResults(n, action);
if (pwc != null)
CallbacksForResults(pwc, action);
}
private static void CallbacksForResults(TrieNode n, Action<T> action)
{
foreach (var sub in n.Subs.Keys)
action(sub);
}
// -------------------------------------------------------------------------
// Private: matchLevelForAny - returns true on first match, counting via np
// Mirrors Go's matchLevelForAny function exactly.
// -------------------------------------------------------------------------
private static bool MatchLevelForAny(TrieLevel? l, string[] tokens, int start, ref int np)
{
TrieNode? pwc = null;
TrieNode? n = null;
for (var i = start; i < tokens.Length; i++)
{
if (l == null) return false;
if (l.FwcNode != null)
{
np += l.FwcNode.Subs.Count;
return true;
}
pwc = l.PwcNode;
if (pwc != null)
{
if (MatchLevelForAny(pwc.Next, tokens, i + 1, ref np))
return true;
}
l.Nodes.TryGetValue(tokens[i], out n);
l = n?.Next;
}
if (n != null)
{
np += n.Subs.Count;
if (n.Subs.Count > 0) return true;
}
if (pwc != null)
{
np += pwc.Subs.Count;
return pwc.Subs.Count > 0;
}
return false;
}
// -------------------------------------------------------------------------
// Private: hasInterestStartingIn - mirrors Go's hasInterestStartingIn
// -------------------------------------------------------------------------
private static bool HasInterestStartingInLevel(TrieLevel? l, string[] tokens, int start)
{
if (l == null) return false;
if (start >= tokens.Length) return true;
if (l.FwcNode != null) return true;
var found = false;
if (l.PwcNode != null)
found = HasInterestStartingInLevel(l.PwcNode.Next, tokens, start + 1);
if (!found && l.Nodes.TryGetValue(tokens[start], out var n))
found = HasInterestStartingInLevel(n.Next, tokens, start + 1);
return found;
}
// -------------------------------------------------------------------------
// Private: numLevels helper - mirrors Go's visitLevel
// -------------------------------------------------------------------------
private static int VisitLevel(TrieLevel? l, int depth)
{
if (l == null || l.NumNodes() == 0) return depth;
depth++;
var maxDepth = depth;
foreach (var n in l.Nodes.Values)
{
var d = VisitLevel(n.Next, depth);
if (d > maxDepth) maxDepth = d;
}
if (l.PwcNode != null)
{
var d = VisitLevel(l.PwcNode.Next, depth);
if (d > maxDepth) maxDepth = d;
}
if (l.FwcNode != null)
{
var d = VisitLevel(l.FwcNode.Next, depth);
if (d > maxDepth) maxDepth = d;
}
return maxDepth;
}
// -------------------------------------------------------------------------
// Private: tokenization helpers
// -------------------------------------------------------------------------
/// <summary>
/// Tokenizes a subject for match/hasInterest operations.
/// Returns <see langword="null"/> if the subject contains an empty token,
/// because an empty token can never match any subscription in the trie.
/// Mirrors Go's inline tokenization in <c>match()</c> and <c>hasInterest()</c>.
/// </summary>
private static string[]? TokenizeForMatch(string subject)
{
if (subject.Length == 0) return null;
var tokens = new List<string>(8);
var start = 0;
for (var i = 0; i < subject.Length; i++)
{
if (subject[i] == Btsep)
{
if (i - start == 0) return null; // empty token
tokens.Add(subject.Substring(start, i - start));
start = i + 1;
}
}
// Trailing separator produces empty last token.
if (start >= subject.Length) return null;
tokens.Add(subject.Substring(start));
return tokens.ToArray();
}
/// <summary>
/// Tokenizes a subject into a string array without validation.
/// Mirrors Go's <c>tokenizeSubjectIntoSlice</c>.
/// </summary>
private static string[] TokenizeSubjectIntoSlice(string subject)
{
var tokens = new List<string>(8);
var start = 0;
for (var i = 0; i < subject.Length; i++)
{
if (subject[i] == Btsep)
{
tokens.Add(subject.Substring(start, i - start));
start = i + 1;
}
}
tokens.Add(subject.Substring(start));
return tokens.ToArray();
}
// -------------------------------------------------------------------------
// Private: Trie node and level types
// -------------------------------------------------------------------------
/// <summary>
/// A trie node holding a subscription map and an optional link to the next level.
/// Mirrors Go's <c>node[T]</c>.
/// </summary>
private sealed class TrieNode
{
/// <summary>Maps subscription value → original subject string.</summary>
public readonly Dictionary<T, string> Subs = new();
/// <summary>The next trie level below this node, or null if at a leaf.</summary>
public TrieLevel? Next;
/// <summary>
/// Returns true when the node has no subscriptions and no live children.
/// Used during removal to decide whether to prune this node.
/// Mirrors Go's <c>node.isEmpty()</c>.
/// </summary>
public bool IsEmpty() => Subs.Count == 0 && (Next == null || Next.NumNodes() == 0);
}
/// <summary>
/// A trie level containing named child nodes and special wildcard slots.
/// Mirrors Go's <c>level[T]</c>.
/// </summary>
private sealed class TrieLevel
{
public readonly Dictionary<string, TrieNode> Nodes = new();
public TrieNode? PwcNode; // '*' single-token wildcard node
public TrieNode? FwcNode; // '>' full-wildcard node
/// <summary>
/// Returns the total count of live nodes at this level.
/// Mirrors Go's <c>level.numNodes()</c>.
/// </summary>
public int NumNodes()
{
var num = Nodes.Count;
if (PwcNode != null) num++;
if (FwcNode != null) num++;
return num;
}
/// <summary>
/// Removes an empty node from this level, using reference equality to
/// distinguish wildcard slots from named slots.
/// Mirrors Go's <c>level.pruneNode()</c>.
/// </summary>
public void PruneNode(TrieNode n, string token)
{
if (ReferenceEquals(n, FwcNode))
FwcNode = null;
else if (ReferenceEquals(n, PwcNode))
PwcNode = null;
else
Nodes.Remove(token);
}
}
/// <summary>
/// Tracks a (level, node, token) triple during removal for upward pruning.
/// Mirrors Go's <c>lnt[T]</c>.
/// </summary>
private readonly struct LevelNodeToken
{
public readonly TrieLevel Level;
public readonly TrieNode Node;
public readonly string Token;
public LevelNodeToken(TrieLevel level, TrieNode node, string token)
{
Level = level;
Node = node;
Token = token;
}
public void Deconstruct(out TrieLevel level, out TrieNode node, out string token)
{
level = Level;
node = Node;
token = Token;
}
}
}
/// <summary>
/// A lightweight sublist that tracks interest only, without storing any associated data.
/// Equivalent to Go's <c>SimpleSublist = GenericSublist[struct{}]</c>.
/// </summary>
public sealed class SimpleSublist : GenericSublist<EmptyStruct>
{
internal SimpleSublist() { }
}

View File

@@ -0,0 +1,263 @@
using System.Buffers.Binary;
namespace ZB.MOM.NatsNet.Server.Internal.DataStructures;
/// <summary>
/// A time-based hash wheel for efficiently scheduling and expiring timer tasks keyed by sequence number.
/// Each slot covers a 1-second window; the wheel has 4096 slots (covering ~68 minutes before wrapping).
/// Not thread-safe.
/// </summary>
/// <remarks>
/// Mirrors the Go <c>thw.HashWheel</c> type. Timestamps are Unix nanoseconds (<see cref="long"/>).
/// </remarks>
public sealed class HashWheel
{
/// <summary>Slot width in nanoseconds (1 second).</summary>
private const long TickDuration = 1_000_000_000L;
private const int WheelBits = 12;
private const int WheelSize = 1 << WheelBits; // 4096
private const int WheelMask = WheelSize - 1;
private const int HeaderLen = 17; // 1 magic + 8 count + 8 highSeq
public static readonly Exception ErrTaskNotFound = new InvalidOperationException("thw: task not found");
public static readonly Exception ErrInvalidVersion = new InvalidDataException("thw: encoded version not known");
private readonly Slot?[] _wheel = new Slot?[WheelSize];
private long _lowest = long.MaxValue;
private ulong _count;
// --- Slot ---
private sealed class Slot
{
public readonly Dictionary<ulong, long> Entries = new();
public long Lowest = long.MaxValue;
}
/// <summary>Creates a new empty <see cref="HashWheel"/>.</summary>
public static HashWheel NewHashWheel() => new();
private static Slot NewSlot() => new();
private long GetPosition(long expires) => (expires / TickDuration) & WheelMask;
// --- Public API ---
/// <summary>Returns the number of tasks currently scheduled.</summary>
public ulong Count => _count;
/// <summary>Schedules a new timer task.</summary>
public void Add(ulong seq, long expires)
{
var pos = (int)GetPosition(expires);
_wheel[pos] ??= NewSlot();
var slot = _wheel[pos]!;
if (!slot.Entries.ContainsKey(seq))
_count++;
slot.Entries[seq] = expires;
if (expires < slot.Lowest)
{
slot.Lowest = expires;
if (expires < _lowest)
_lowest = expires;
}
}
/// <summary>Removes a timer task.</summary>
/// <exception cref="InvalidOperationException">Thrown (as <see cref="ErrTaskNotFound"/>) when not found.</exception>
public void Remove(ulong seq, long expires)
{
var pos = (int)GetPosition(expires);
var slot = _wheel[pos];
if (slot is null || !slot.Entries.ContainsKey(seq))
throw ErrTaskNotFound;
slot.Entries.Remove(seq);
_count--;
if (slot.Entries.Count == 0)
_wheel[pos] = null;
}
/// <summary>Updates the expiration time of an existing timer task.</summary>
public void Update(ulong seq, long oldExpires, long newExpires)
{
Remove(seq, oldExpires);
Add(seq, newExpires);
}
/// <summary>
/// Expires all tasks whose timestamp is &lt;= now. The callback receives each task;
/// if it returns <see langword="true"/> the task is removed, otherwise it is kept.
/// </summary>
public void ExpireTasks(Func<ulong, long, bool> callback)
{
var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() * 1_000_000L;
ExpireTasksInternal(now, callback);
}
internal void ExpireTasksInternal(long ts, Func<ulong, long, bool> callback)
{
if (_lowest > ts)
return;
var globalLowest = long.MaxValue;
for (var pos = 0; pos < WheelSize; pos++)
{
var slot = _wheel[pos];
if (slot is null || slot.Lowest > ts)
{
if (slot is not null && slot.Lowest < globalLowest)
globalLowest = slot.Lowest;
continue;
}
var slotLowest = long.MaxValue;
// Snapshot keys to allow removal during iteration.
var keys = slot.Entries.Keys.ToArray();
foreach (var seq in keys)
{
var exp = slot.Entries[seq];
if (exp <= ts && callback(seq, exp))
{
slot.Entries.Remove(seq);
_count--;
continue;
}
if (exp < slotLowest)
slotLowest = exp;
}
if (slot.Entries.Count == 0)
{
_wheel[pos] = null;
}
else
{
slot.Lowest = slotLowest;
if (slotLowest < globalLowest)
globalLowest = slotLowest;
}
}
_lowest = globalLowest;
}
/// <summary>
/// Returns the earliest expiration timestamp before <paramref name="before"/>,
/// or <see cref="long.MaxValue"/> if none.
/// </summary>
public long GetNextExpiration(long before) =>
_lowest < before ? _lowest : long.MaxValue;
// --- Encode / Decode ---
/// <summary>
/// Serializes the wheel to a byte array. <paramref name="highSeq"/> is stored
/// in the header and returned by <see cref="Decode"/>.
/// </summary>
public byte[] Encode(ulong highSeq)
{
// Preallocate conservatively: header + up to 2 varints per entry.
var buf = new List<byte>(HeaderLen + (int)(_count * 16));
buf.Add(1); // magic version
AppendUint64LE(buf, _count);
AppendUint64LE(buf, highSeq);
foreach (var slot in _wheel)
{
if (slot is null)
continue;
foreach (var (seq, ts) in slot.Entries)
{
AppendVarint(buf, ts);
AppendUvarint(buf, seq);
}
}
return buf.ToArray();
}
/// <summary>
/// Replaces this wheel's contents with those from a binary snapshot.
/// Returns the <c>highSeq</c> stored in the header.
/// </summary>
public ulong Decode(ReadOnlySpan<byte> b)
{
if (b.Length < HeaderLen)
throw (InvalidDataException)ErrInvalidVersion;
if (b[0] != 1)
throw (InvalidDataException)ErrInvalidVersion;
// Reset wheel.
Array.Clear(_wheel);
_lowest = long.MaxValue;
_count = 0;
var count = BinaryPrimitives.ReadUInt64LittleEndian(b[1..]);
var stamp = BinaryPrimitives.ReadUInt64LittleEndian(b[9..]);
var pos = HeaderLen;
for (ulong i = 0; i < count; i++)
{
var ts = ReadVarint(b, ref pos);
var seq = ReadUvarint(b, ref pos);
Add(seq, ts);
}
return stamp;
}
// --- Encoding helpers ---
private static void AppendUint64LE(List<byte> buf, ulong v)
{
buf.Add((byte)v);
buf.Add((byte)(v >> 8));
buf.Add((byte)(v >> 16));
buf.Add((byte)(v >> 24));
buf.Add((byte)(v >> 32));
buf.Add((byte)(v >> 40));
buf.Add((byte)(v >> 48));
buf.Add((byte)(v >> 56));
}
private static void AppendVarint(List<byte> buf, long v)
{
// ZigZag encode like Go's binary.AppendVarint.
var uv = (ulong)((v << 1) ^ (v >> 63));
AppendUvarint(buf, uv);
}
private static void AppendUvarint(List<byte> buf, ulong v)
{
while (v >= 0x80)
{
buf.Add((byte)(v | 0x80));
v >>= 7;
}
buf.Add((byte)v);
}
private static long ReadVarint(ReadOnlySpan<byte> b, ref int pos)
{
var uv = ReadUvarint(b, ref pos);
var v = (long)(uv >> 1);
if ((uv & 1) != 0)
v = ~v;
return v;
}
private static ulong ReadUvarint(ReadOnlySpan<byte> b, ref int pos)
{
ulong x = 0;
int s = 0;
while (pos < b.Length)
{
var by = b[pos++];
x |= (ulong)(by & 0x7F) << s;
if ((by & 0x80) == 0)
return x;
s += 7;
}
throw new InvalidDataException("thw: unexpected EOF in varint");
}
}

View File

@@ -0,0 +1,488 @@
// Copyright 2023-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ZB.MOM.NatsNet.Server.Internal.DataStructures;
/// <summary>
/// An adaptive radix trie (ART) for storing subject information on literal NATS subjects.
/// Uses dynamic nodes (4/10/16/48/256 children), path compression, and lazy expansion.
/// Supports exact lookup, wildcard matching ('*' and '>'), and ordered/fast iteration.
/// Not thread-safe.
/// </summary>
public sealed class SubjectTree<T>
{
internal ISubjectTreeNode<T>? _root;
private int _size;
/// <summary>Returns the number of entries stored in the tree.</summary>
public int Size() => _size;
/// <summary>Returns true if the tree has no entries.</summary>
public bool Empty() => _size == 0;
/// <summary>Clears all entries from the tree.</summary>
public SubjectTree<T> Reset()
{
_root = null;
_size = 0;
return this;
}
/// <summary>
/// Inserts a value into the tree under the given subject.
/// If the subject already exists, returns the old value with updated=true.
/// Subjects containing byte 127 (the noPivot sentinel) are rejected silently.
/// </summary>
public (T? oldVal, bool updated) Insert(ReadOnlySpan<byte> subject, T value)
{
if (subject.IndexOf(SubjectTreeParts.NoPivot) >= 0)
return (default, false);
var subjectBytes = subject.ToArray();
var (old, updated) = DoInsert(ref _root, subjectBytes, value, 0);
if (!updated)
_size++;
return (old, updated);
}
/// <summary>
/// Finds the value stored at the given literal subject.
/// Returns (value, true) if found, (default, false) otherwise.
/// </summary>
public (T? val, bool found) Find(ReadOnlySpan<byte> subject)
{
var si = 0;
var n = _root;
var subjectBytes = subject.ToArray();
while (n != null)
{
if (n.IsLeaf)
{
var ln = (SubjectTreeLeaf<T>)n;
return ln.Match(subjectBytes.AsSpan(si))
? (ln.Value, true)
: (default, false);
}
var prefix = n.Prefix;
if (prefix.Length > 0)
{
var end = Math.Min(si + prefix.Length, subjectBytes.Length);
if (!subjectBytes.AsSpan(si, end - si).SequenceEqual(prefix.AsSpan(0, end - si)))
return (default, false);
si += prefix.Length;
}
var next = n.FindChild(SubjectTreeParts.Pivot(subjectBytes, si));
if (next == null) return (default, false);
n = next;
}
return (default, false);
}
/// <summary>
/// Deletes the entry at the given literal subject.
/// Returns (value, true) if deleted, (default, false) if not found.
/// </summary>
public (T? val, bool found) Delete(ReadOnlySpan<byte> subject)
{
if (_root == null || subject.IsEmpty) return (default, false);
var subjectBytes = subject.ToArray();
var (val, deleted) = DoDelete(ref _root, subjectBytes, 0);
if (deleted) _size--;
return (val, deleted);
}
/// <summary>
/// Matches all stored subjects against a filter that may contain wildcards ('*' and '>').
/// Invokes fn for each match. Return false from the callback to stop early.
/// </summary>
public void Match(ReadOnlySpan<byte> filter, Func<byte[], T, bool> fn)
{
if (_root == null || filter.IsEmpty || fn == null) return;
var parts = SubjectTreeParts.GenParts(filter.ToArray());
MatchNode(_root, parts, Array.Empty<byte>(), fn);
}
/// <summary>
/// Like Match but returns false if the callback stopped iteration early.
/// Returns true if matching ran to completion.
/// </summary>
public bool MatchUntil(ReadOnlySpan<byte> filter, Func<byte[], T, bool> fn)
{
if (_root == null || filter.IsEmpty || fn == null) return true;
var parts = SubjectTreeParts.GenParts(filter.ToArray());
return MatchNode(_root, parts, Array.Empty<byte>(), fn);
}
/// <summary>
/// Walks all entries in lexicographical order.
/// Return false from the callback to stop early.
/// </summary>
public void IterOrdered(Func<byte[], T, bool> fn)
{
if (_root == null || fn == null) return;
IterNode(_root, Array.Empty<byte>(), ordered: true, fn);
}
/// <summary>
/// Walks all entries in storage order (no ordering guarantee).
/// Return false from the callback to stop early.
/// </summary>
public void IterFast(Func<byte[], T, bool> fn)
{
if (_root == null || fn == null) return;
IterNode(_root, Array.Empty<byte>(), ordered: false, fn);
}
// -------------------------------------------------------------------------
// Internal recursive insert
// -------------------------------------------------------------------------
private static (T? old, bool updated) DoInsert(ref ISubjectTreeNode<T>? np, byte[] subject, T value, int si)
{
if (np == null)
{
np = new SubjectTreeLeaf<T>(subject, value);
return (default, false);
}
if (np.IsLeaf)
{
var ln = (SubjectTreeLeaf<T>)np;
if (ln.Match(subject.AsSpan(si)))
{
var oldVal = ln.Value;
ln.Value = value;
return (oldVal, true);
}
// Split the leaf: compute common prefix between existing suffix and new subject tail.
var cpi = SubjectTreeParts.CommonPrefixLen(ln.Suffix, subject.AsSpan(si));
var nn = new SubjectTreeNode4<T>(subject[si..(si + cpi)]);
ln.SetSuffix(ln.Suffix[cpi..]);
si += cpi;
var p = SubjectTreeParts.Pivot(ln.Suffix, 0);
if (cpi > 0 && si < subject.Length && p == subject[si])
{
// Same pivot after the split — recurse to separate further.
DoInsert(ref np, subject, value, si);
nn.AddChild(p, np!);
}
else
{
var nl = new SubjectTreeLeaf<T>(subject[si..], value);
nn.AddChild(SubjectTreeParts.Pivot(nl.Suffix, 0), nl);
nn.AddChild(SubjectTreeParts.Pivot(ln.Suffix, 0), ln);
}
np = nn;
return (default, false);
}
// Non-leaf node.
var prefix = np.Prefix;
if (prefix.Length > 0)
{
var cpi = SubjectTreeParts.CommonPrefixLen(prefix, subject.AsSpan(si));
if (cpi >= prefix.Length)
{
// Full prefix match: move past this node.
si += prefix.Length;
var pivotByte = SubjectTreeParts.Pivot(subject, si);
var existingChild = np.FindChild(pivotByte);
if (existingChild != null)
{
var before = existingChild;
var (old, upd) = DoInsert(ref existingChild, subject, value, si);
// Only re-register if the child reference changed identity (grew or split).
if (!ReferenceEquals(before, existingChild))
{
np.DeleteChild(pivotByte);
np.AddChild(pivotByte, existingChild!);
}
return (old, upd);
}
if (np.IsFull)
np = np.Grow();
np.AddChild(SubjectTreeParts.Pivot(subject, si), new SubjectTreeLeaf<T>(subject[si..], value));
return (default, false);
}
else
{
// Partial prefix match — insert a new node4 above the current node.
var newPrefix = subject[si..(si + cpi)];
si += cpi;
var splitNode = new SubjectTreeNode4<T>(newPrefix);
((SubjectTreeMeta<T>)np).SetPrefix(prefix[cpi..]);
// Use np.Prefix (updated) to get the correct pivot for the demoted node.
splitNode.AddChild(SubjectTreeParts.Pivot(np.Prefix, 0), np);
splitNode.AddChild(
SubjectTreeParts.Pivot(subject.AsSpan(si), 0),
new SubjectTreeLeaf<T>(subject[si..], value));
np = splitNode;
}
}
else
{
// No prefix on this node.
var pivotByte = SubjectTreeParts.Pivot(subject, si);
var existingChild = np.FindChild(pivotByte);
if (existingChild != null)
{
var before = existingChild;
var (old, upd) = DoInsert(ref existingChild, subject, value, si);
if (!ReferenceEquals(before, existingChild))
{
np.DeleteChild(pivotByte);
np.AddChild(pivotByte, existingChild!);
}
return (old, upd);
}
if (np.IsFull)
np = np.Grow();
np.AddChild(SubjectTreeParts.Pivot(subject, si), new SubjectTreeLeaf<T>(subject[si..], value));
}
return (default, false);
}
// -------------------------------------------------------------------------
// Internal recursive delete
// -------------------------------------------------------------------------
private static (T? val, bool deleted) DoDelete(ref ISubjectTreeNode<T>? np, byte[] subject, int si)
{
if (np == null || subject.Length == 0) return (default, false);
var n = np;
if (n.IsLeaf)
{
var ln = (SubjectTreeLeaf<T>)n;
if (ln.Match(subject.AsSpan(si)))
{
np = null;
return (ln.Value, true);
}
return (default, false);
}
// Check prefix.
var prefix = n.Prefix;
if (prefix.Length > 0)
{
if (subject.Length < si + prefix.Length)
return (default, false);
if (!subject.AsSpan(si, prefix.Length).SequenceEqual(prefix))
return (default, false);
si += prefix.Length;
}
var p = SubjectTreeParts.Pivot(subject, si);
var childNode = n.FindChild(p);
if (childNode == null) return (default, false);
if (childNode.IsLeaf)
{
var childLeaf = (SubjectTreeLeaf<T>)childNode;
if (childLeaf.Match(subject.AsSpan(si)))
{
n.DeleteChild(p);
TryShrink(ref np!, prefix);
return (childLeaf.Value, true);
}
return (default, false);
}
// Recurse into non-leaf child.
var (val, deleted) = DoDelete(ref childNode, subject, si);
if (deleted)
{
if (childNode == null)
{
// Child was nulled out — remove slot and try to shrink.
n.DeleteChild(p);
TryShrink(ref np!, prefix);
}
else
{
// Child changed identity — re-register.
n.DeleteChild(p);
n.AddChild(p, childNode);
}
}
return (val, deleted);
}
private static void TryShrink(ref ISubjectTreeNode<T> np, byte[] parentPrefix)
{
var shrunk = np.Shrink();
if (shrunk == null) return;
if (shrunk.IsLeaf)
{
var shrunkLeaf = (SubjectTreeLeaf<T>)shrunk;
if (parentPrefix.Length > 0)
shrunkLeaf.Suffix = [.. parentPrefix, .. shrunkLeaf.Suffix];
}
else if (parentPrefix.Length > 0)
{
((SubjectTreeMeta<T>)shrunk).SetPrefix([.. parentPrefix, .. shrunk.Prefix]);
}
np = shrunk;
}
// -------------------------------------------------------------------------
// Internal recursive wildcard match
// -------------------------------------------------------------------------
private static bool MatchNode(ISubjectTreeNode<T> n, byte[][] parts, byte[] pre, Func<byte[], T, bool> fn)
{
var hasFwc = parts.Length > 0 && parts[^1].Length == 1 && parts[^1][0] == SubjectTreeParts.Fwc;
while (n != null!)
{
var (nparts, matched) = n.MatchParts(parts);
if (!matched) return true;
if (n.IsLeaf)
{
if (nparts.Length == 0 || (hasFwc && nparts.Length == 1))
{
var ln = (SubjectTreeLeaf<T>)n;
if (!fn(ConcatBytes(pre, ln.Suffix), ln.Value)) return false;
}
return true;
}
// Append this node's prefix to the running accumulator.
var prefix = n.Prefix;
if (prefix.Length > 0)
pre = ConcatBytes(pre, prefix);
if (nparts.Length == 0 && !hasFwc)
{
// No parts remaining and no fwc — look for terminal matches.
var hasTermPwc = parts.Length > 0 && parts[^1].Length == 1 && parts[^1][0] == SubjectTreeParts.Pwc;
var termParts = hasTermPwc ? parts[^1..] : Array.Empty<byte[]>();
foreach (var cn in n.Children())
{
if (cn == null!) continue;
if (cn.IsLeaf)
{
var ln = (SubjectTreeLeaf<T>)cn;
if (ln.Suffix.Length == 0)
{
if (!fn(ConcatBytes(pre, ln.Suffix), ln.Value)) return false;
}
else if (hasTermPwc && Array.IndexOf(ln.Suffix, SubjectTreeParts.TSep) < 0)
{
if (!fn(ConcatBytes(pre, ln.Suffix), ln.Value)) return false;
}
}
else if (hasTermPwc)
{
if (!MatchNode(cn, termParts, pre, fn)) return false;
}
}
return true;
}
// Re-put the terminal fwc if nparts was exhausted by matching.
if (hasFwc && nparts.Length == 0)
nparts = parts[^1..];
var fp = nparts[0];
var pByte = SubjectTreeParts.Pivot(fp, 0);
if (fp.Length == 1 && (pByte == SubjectTreeParts.Pwc || pByte == SubjectTreeParts.Fwc))
{
// Wildcard part — iterate all children.
foreach (var cn in n.Children())
{
if (cn != null!)
{
if (!MatchNode(cn, nparts, pre, fn)) return false;
}
}
return true;
}
// Literal part — find specific child and loop.
var nextNode = n.FindChild(pByte);
if (nextNode == null) return true;
n = nextNode;
parts = nparts;
}
return true;
}
// -------------------------------------------------------------------------
// Internal iteration
// -------------------------------------------------------------------------
private static bool IterNode(ISubjectTreeNode<T> n, byte[] pre, bool ordered, Func<byte[], T, bool> fn)
{
if (n.IsLeaf)
{
var ln = (SubjectTreeLeaf<T>)n;
return fn(ConcatBytes(pre, ln.Suffix), ln.Value);
}
pre = ConcatBytes(pre, n.Prefix);
if (!ordered)
{
foreach (var cn in n.Children())
{
if (cn == null!) continue;
if (!IterNode(cn, pre, false, fn)) return false;
}
return true;
}
// Ordered: sort children by their path bytes lexicographically.
var children = n.Children().Where(c => c != null!).ToArray();
Array.Sort(children, static (a, b) => a.Path.AsSpan().SequenceCompareTo(b.Path.AsSpan()));
foreach (var cn in children)
{
if (!IterNode(cn, pre, true, fn)) return false;
}
return true;
}
// -------------------------------------------------------------------------
// Byte array helpers
// -------------------------------------------------------------------------
internal static byte[] ConcatBytes(byte[] a, byte[] b)
{
if (a.Length == 0) return b.Length == 0 ? Array.Empty<byte>() : b;
if (b.Length == 0) return a;
var result = new byte[a.Length + b.Length];
a.CopyTo(result, 0);
b.CopyTo(result, a.Length);
return result;
}
}

View File

@@ -0,0 +1,483 @@
// Copyright 2023-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ZB.MOM.NatsNet.Server.Internal.DataStructures;
// Internal node interface for the adaptive radix trie.
internal interface ISubjectTreeNode<T>
{
bool IsLeaf { get; }
byte[] Prefix { get; }
void AddChild(byte key, ISubjectTreeNode<T> child);
ISubjectTreeNode<T>? FindChild(byte key);
void DeleteChild(byte key);
bool IsFull { get; }
ISubjectTreeNode<T> Grow();
ISubjectTreeNode<T>? Shrink();
ISubjectTreeNode<T>[] Children();
int NumChildren { get; }
byte[] Path { get; }
(byte[][] remainingParts, bool matched) MatchParts(byte[][] parts);
string Kind { get; }
}
// Base class for non-leaf nodes, holding prefix and child count.
internal abstract class SubjectTreeMeta<T> : ISubjectTreeNode<T>
{
protected byte[] _prefix;
protected int _size;
protected SubjectTreeMeta(byte[] prefix)
{
_prefix = SubjectTreeParts.CopyBytes(prefix);
}
public bool IsLeaf => false;
public byte[] Prefix => _prefix;
public int NumChildren => _size;
public byte[] Path => _prefix;
public void SetPrefix(byte[] prefix)
{
_prefix = SubjectTreeParts.CopyBytes(prefix);
}
public (byte[][] remainingParts, bool matched) MatchParts(byte[][] parts)
=> SubjectTreeParts.MatchParts(parts, _prefix);
public abstract void AddChild(byte key, ISubjectTreeNode<T> child);
public abstract ISubjectTreeNode<T>? FindChild(byte key);
public abstract void DeleteChild(byte key);
public abstract bool IsFull { get; }
public abstract ISubjectTreeNode<T> Grow();
public abstract ISubjectTreeNode<T>? Shrink();
public abstract ISubjectTreeNode<T>[] Children();
public abstract string Kind { get; }
}
// Leaf node storing the terminal value plus a suffix byte[].
internal sealed class SubjectTreeLeaf<T> : ISubjectTreeNode<T>
{
public T Value;
public byte[] Suffix;
public SubjectTreeLeaf(byte[] suffix, T value)
{
Suffix = SubjectTreeParts.CopyBytes(suffix);
Value = value;
}
public bool IsLeaf => true;
public byte[] Prefix => Array.Empty<byte>();
public int NumChildren => 0;
public byte[] Path => Suffix;
public string Kind => "LEAF";
public bool Match(ReadOnlySpan<byte> subject)
=> subject.SequenceEqual(Suffix);
public void SetSuffix(byte[] suffix)
=> Suffix = SubjectTreeParts.CopyBytes(suffix);
public bool IsFull => true;
public (byte[][] remainingParts, bool matched) MatchParts(byte[][] parts)
=> SubjectTreeParts.MatchParts(parts, Suffix);
// Leaf nodes do not support child operations.
public void AddChild(byte key, ISubjectTreeNode<T> child)
=> throw new InvalidOperationException("AddChild called on leaf");
public ISubjectTreeNode<T>? FindChild(byte key)
=> throw new InvalidOperationException("FindChild called on leaf");
public void DeleteChild(byte key)
=> throw new InvalidOperationException("DeleteChild called on leaf");
public ISubjectTreeNode<T> Grow()
=> throw new InvalidOperationException("Grow called on leaf");
public ISubjectTreeNode<T>? Shrink()
=> throw new InvalidOperationException("Shrink called on leaf");
public ISubjectTreeNode<T>[] Children()
=> Array.Empty<ISubjectTreeNode<T>>();
}
// Node with up to 4 children (keys + children arrays, unsorted).
internal sealed class SubjectTreeNode4<T> : SubjectTreeMeta<T>
{
private readonly byte[] _keys = new byte[4];
private readonly ISubjectTreeNode<T>?[] _children = new ISubjectTreeNode<T>?[4];
public SubjectTreeNode4(byte[] prefix) : base(prefix) { }
public override string Kind => "NODE4";
public override void AddChild(byte key, ISubjectTreeNode<T> child)
{
if (_size >= 4) throw new InvalidOperationException("node4 full!");
_keys[_size] = key;
_children[_size] = child;
_size++;
}
public override ISubjectTreeNode<T>? FindChild(byte key)
{
for (var i = 0; i < _size; i++)
{
if (_keys[i] == key) return _children[i];
}
return null;
}
public override void DeleteChild(byte key)
{
for (var i = 0; i < _size; i++)
{
if (_keys[i] == key)
{
var last = _size - 1;
if (i < last)
{
_keys[i] = _keys[last];
_children[i] = _children[last];
}
_keys[last] = 0;
_children[last] = null;
_size--;
return;
}
}
}
public override bool IsFull => _size >= 4;
public override ISubjectTreeNode<T> Grow()
{
var nn = new SubjectTreeNode10<T>(_prefix);
for (var i = 0; i < 4; i++)
nn.AddChild(_keys[i], _children[i]!);
return nn;
}
public override ISubjectTreeNode<T>? Shrink()
{
if (_size == 1) return _children[0];
return null;
}
public override ISubjectTreeNode<T>[] Children()
{
var result = new ISubjectTreeNode<T>[_size];
for (var i = 0; i < _size; i++)
result[i] = _children[i]!;
return result;
}
// Internal access for tests.
internal byte GetKey(int index) => _keys[index];
internal ISubjectTreeNode<T>? GetChild(int index) => _children[index];
}
// Node with up to 10 children (for numeric token segments).
internal sealed class SubjectTreeNode10<T> : SubjectTreeMeta<T>
{
private readonly byte[] _keys = new byte[10];
private readonly ISubjectTreeNode<T>?[] _children = new ISubjectTreeNode<T>?[10];
public SubjectTreeNode10(byte[] prefix) : base(prefix) { }
public override string Kind => "NODE10";
public override void AddChild(byte key, ISubjectTreeNode<T> child)
{
if (_size >= 10) throw new InvalidOperationException("node10 full!");
_keys[_size] = key;
_children[_size] = child;
_size++;
}
public override ISubjectTreeNode<T>? FindChild(byte key)
{
for (var i = 0; i < _size; i++)
{
if (_keys[i] == key) return _children[i];
}
return null;
}
public override void DeleteChild(byte key)
{
for (var i = 0; i < _size; i++)
{
if (_keys[i] == key)
{
var last = _size - 1;
if (i < last)
{
_keys[i] = _keys[last];
_children[i] = _children[last];
}
_keys[last] = 0;
_children[last] = null;
_size--;
return;
}
}
}
public override bool IsFull => _size >= 10;
public override ISubjectTreeNode<T> Grow()
{
var nn = new SubjectTreeNode16<T>(_prefix);
for (var i = 0; i < _size; i++)
nn.AddChild(_keys[i], _children[i]!);
return nn;
}
public override ISubjectTreeNode<T>? Shrink()
{
if (_size > 4) return null;
var nn = new SubjectTreeNode4<T>(Array.Empty<byte>());
for (var i = 0; i < _size; i++)
nn.AddChild(_keys[i], _children[i]!);
return nn;
}
public override ISubjectTreeNode<T>[] Children()
{
var result = new ISubjectTreeNode<T>[_size];
for (var i = 0; i < _size; i++)
result[i] = _children[i]!;
return result;
}
}
// Node with up to 16 children.
internal sealed class SubjectTreeNode16<T> : SubjectTreeMeta<T>
{
private readonly byte[] _keys = new byte[16];
private readonly ISubjectTreeNode<T>?[] _children = new ISubjectTreeNode<T>?[16];
public SubjectTreeNode16(byte[] prefix) : base(prefix) { }
public override string Kind => "NODE16";
public override void AddChild(byte key, ISubjectTreeNode<T> child)
{
if (_size >= 16) throw new InvalidOperationException("node16 full!");
_keys[_size] = key;
_children[_size] = child;
_size++;
}
public override ISubjectTreeNode<T>? FindChild(byte key)
{
for (var i = 0; i < _size; i++)
{
if (_keys[i] == key) return _children[i];
}
return null;
}
public override void DeleteChild(byte key)
{
for (var i = 0; i < _size; i++)
{
if (_keys[i] == key)
{
var last = _size - 1;
if (i < last)
{
_keys[i] = _keys[last];
_children[i] = _children[last];
}
_keys[last] = 0;
_children[last] = null;
_size--;
return;
}
}
}
public override bool IsFull => _size >= 16;
public override ISubjectTreeNode<T> Grow()
{
var nn = new SubjectTreeNode48<T>(_prefix);
for (var i = 0; i < _size; i++)
nn.AddChild(_keys[i], _children[i]!);
return nn;
}
public override ISubjectTreeNode<T>? Shrink()
{
if (_size > 10) return null;
var nn = new SubjectTreeNode10<T>(Array.Empty<byte>());
for (var i = 0; i < _size; i++)
nn.AddChild(_keys[i], _children[i]!);
return nn;
}
public override ISubjectTreeNode<T>[] Children()
{
var result = new ISubjectTreeNode<T>[_size];
for (var i = 0; i < _size; i++)
result[i] = _children[i]!;
return result;
}
}
// Node with up to 48 children, using a 256-byte key index (1-indexed, 0 means empty).
internal sealed class SubjectTreeNode48<T> : SubjectTreeMeta<T>
{
// _keyIndex[byte] = 1-based index into _children; 0 means no entry.
private readonly byte[] _keyIndex = new byte[256];
private readonly ISubjectTreeNode<T>?[] _children = new ISubjectTreeNode<T>?[48];
public SubjectTreeNode48(byte[] prefix) : base(prefix) { }
public override string Kind => "NODE48";
public override void AddChild(byte key, ISubjectTreeNode<T> child)
{
if (_size >= 48) throw new InvalidOperationException("node48 full!");
_children[_size] = child;
_keyIndex[key] = (byte)(_size + 1); // 1-indexed
_size++;
}
public override ISubjectTreeNode<T>? FindChild(byte key)
{
var i = _keyIndex[key];
if (i == 0) return null;
return _children[i - 1];
}
public override void DeleteChild(byte key)
{
var i = _keyIndex[key];
if (i == 0) return;
i--; // Convert from 1-indexed
var last = _size - 1;
if (i < last)
{
_children[i] = _children[last];
// Find which key index points to 'last' and redirect it to 'i'.
for (var ic = 0; ic < 256; ic++)
{
if (_keyIndex[ic] == last + 1)
{
_keyIndex[ic] = (byte)(i + 1);
break;
}
}
}
_children[last] = null;
_keyIndex[key] = 0;
_size--;
}
public override bool IsFull => _size >= 48;
public override ISubjectTreeNode<T> Grow()
{
var nn = new SubjectTreeNode256<T>(_prefix);
for (var c = 0; c < 256; c++)
{
var i = _keyIndex[c];
if (i > 0)
nn.AddChild((byte)c, _children[i - 1]!);
}
return nn;
}
public override ISubjectTreeNode<T>? Shrink()
{
if (_size > 16) return null;
var nn = new SubjectTreeNode16<T>(Array.Empty<byte>());
for (var c = 0; c < 256; c++)
{
var i = _keyIndex[c];
if (i > 0)
nn.AddChild((byte)c, _children[i - 1]!);
}
return nn;
}
public override ISubjectTreeNode<T>[] Children()
{
var result = new ISubjectTreeNode<T>[_size];
var idx = 0;
for (var i = 0; i < _size; i++)
{
if (_children[i] != null)
result[idx++] = _children[i]!;
}
return result[..idx];
}
// Internal access for tests.
internal byte GetKeyIndex(int key) => _keyIndex[key];
internal ISubjectTreeNode<T>? GetChildAt(int index) => _children[index];
}
// Node with 256 children, indexed directly by byte value.
internal sealed class SubjectTreeNode256<T> : SubjectTreeMeta<T>
{
private readonly ISubjectTreeNode<T>?[] _children = new ISubjectTreeNode<T>?[256];
public SubjectTreeNode256(byte[] prefix) : base(prefix) { }
public override string Kind => "NODE256";
public override void AddChild(byte key, ISubjectTreeNode<T> child)
{
_children[key] = child;
_size++;
}
public override ISubjectTreeNode<T>? FindChild(byte key)
=> _children[key];
public override void DeleteChild(byte key)
{
if (_children[key] != null)
{
_children[key] = null;
_size--;
}
}
public override bool IsFull => false;
public override ISubjectTreeNode<T> Grow()
=> throw new InvalidOperationException("Grow cannot be called on node256");
public override ISubjectTreeNode<T>? Shrink()
{
if (_size > 48) return null;
var nn = new SubjectTreeNode48<T>(Array.Empty<byte>());
for (var c = 0; c < 256; c++)
{
if (_children[c] != null)
nn.AddChild((byte)c, _children[c]!);
}
return nn;
}
public override ISubjectTreeNode<T>[] Children()
=> _children.Where(c => c != null).Select(c => c!).ToArray();
}

View File

@@ -0,0 +1,242 @@
// Copyright 2023-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ZB.MOM.NatsNet.Server.Internal.DataStructures;
/// <summary>
/// Utility methods for NATS subject matching, wildcard part decomposition,
/// common prefix computation, and byte manipulation used by SubjectTree.
/// </summary>
internal static class SubjectTreeParts
{
// NATS subject special bytes.
internal const byte Pwc = (byte)'*'; // single-token wildcard
internal const byte Fwc = (byte)'>'; // full wildcard (terminal)
internal const byte TSep = (byte)'.'; // token separator
// Sentinel pivot returned when subject position is past end.
internal const byte NoPivot = 127;
/// <summary>
/// Returns the pivot byte at <paramref name="pos"/> in <paramref name="subject"/>,
/// or <see cref="NoPivot"/> if the position is at or beyond the end.
/// </summary>
internal static byte Pivot(ReadOnlySpan<byte> subject, int pos)
=> pos >= subject.Length ? NoPivot : subject[pos];
/// <summary>
/// Returns the pivot byte at <paramref name="pos"/> in <paramref name="subject"/>,
/// or <see cref="NoPivot"/> if the position is at or beyond the end.
/// </summary>
internal static byte Pivot(byte[] subject, int pos)
=> pos >= subject.Length ? NoPivot : subject[pos];
/// <summary>
/// Computes the number of leading bytes that are equal between two spans.
/// </summary>
internal static int CommonPrefixLen(ReadOnlySpan<byte> s1, ReadOnlySpan<byte> s2)
{
var limit = Math.Min(s1.Length, s2.Length);
var i = 0;
while (i < limit && s1[i] == s2[i])
i++;
return i;
}
/// <summary>
/// Returns a copy of <paramref name="src"/>, or an empty array if src is empty.
/// </summary>
internal static byte[] CopyBytes(ReadOnlySpan<byte> src)
{
if (src.IsEmpty) return Array.Empty<byte>();
return src.ToArray();
}
/// <summary>
/// Returns a copy of <paramref name="src"/>, or an empty array if src is null or empty.
/// </summary>
internal static byte[] CopyBytes(byte[]? src)
{
if (src == null || src.Length == 0) return Array.Empty<byte>();
var dst = new byte[src.Length];
src.CopyTo(dst, 0);
return dst;
}
/// <summary>
/// Converts a byte array to a string using Latin-1 (ISO-8859-1) encoding,
/// which preserves a 1:1 byte-to-char mapping for all byte values 0-255.
/// </summary>
internal static string BytesToString(byte[] bytes)
{
if (bytes.Length == 0) return string.Empty;
return System.Text.Encoding.Latin1.GetString(bytes);
}
/// <summary>
/// Breaks a filter subject into parts separated by wildcards ('*' and '>').
/// Each literal segment between wildcards becomes one part; each wildcard
/// becomes its own single-byte part.
/// </summary>
internal static byte[][] GenParts(byte[] filter)
{
var parts = new List<byte[]>(8);
var start = 0;
var e = filter.Length - 1;
for (var i = 0; i < filter.Length; i++)
{
if (filter[i] == TSep)
{
// Check if next token is pwc (internal or terminal).
if (i < e && filter[i + 1] == Pwc &&
((i + 2 <= e && filter[i + 2] == TSep) || i + 1 == e))
{
if (i > start)
parts.Add(filter[start..(i + 1)]);
parts.Add(filter[(i + 1)..(i + 2)]);
i++; // skip pwc
if (i + 2 <= e)
i++; // skip next tsep from next part
start = i + 1;
}
else if (i < e && filter[i + 1] == Fwc && i + 1 == e)
{
if (i > start)
parts.Add(filter[start..(i + 1)]);
parts.Add(filter[(i + 1)..(i + 2)]);
i++; // skip fwc
start = i + 1;
}
}
else if (filter[i] == Pwc || filter[i] == Fwc)
{
// Wildcard must be preceded by tsep (or be at start).
var prev = i - 1;
if (prev >= 0 && filter[prev] != TSep)
continue;
// Wildcard must be at end or followed by tsep.
var next = i + 1;
if (next == e || (next < e && filter[next] != TSep))
continue;
// Full wildcard must be terminal.
if (filter[i] == Fwc && i < e)
break;
// Leading wildcard.
parts.Add(filter[i..(i + 1)]);
if (i + 1 <= e)
i++; // skip next tsep
start = i + 1;
}
}
if (start < filter.Length)
{
// Eat leading tsep if present.
if (filter[start] == TSep)
start++;
if (start < filter.Length)
parts.Add(filter[start..]);
}
return parts.ToArray();
}
/// <summary>
/// Matches parts against a fragment (prefix or suffix).
/// Returns the remaining parts and whether matching succeeded.
/// </summary>
internal static (byte[][] remainingParts, bool matched) MatchParts(byte[][] parts, byte[] frag)
{
var lf = frag.Length;
if (lf == 0) return (parts, true);
var si = 0;
var lpi = parts.Length - 1;
for (var i = 0; i < parts.Length; i++)
{
if (si >= lf)
return (parts[i..], true);
var part = parts[i];
var lp = part.Length;
// Check for wildcard placeholders.
if (lp == 1)
{
if (part[0] == Pwc)
{
// Find the next token separator.
var index = Array.IndexOf(frag, TSep, si);
if (index < 0)
{
// No tsep found.
if (i == lpi)
return (Array.Empty<byte[]>(), true);
return (parts[i..], true);
}
si = index + 1;
continue;
}
else if (part[0] == Fwc)
{
return (Array.Empty<byte[]>(), true);
}
}
var end = Math.Min(si + lp, lf);
// If part is larger than the remaining fragment, adjust.
var comparePart = part;
if (si + lp > end)
comparePart = part[..(end - si)];
if (!frag.AsSpan(si, end - si).SequenceEqual(comparePart))
return (parts, false);
// Fragment still has bytes left.
if (end < lf)
{
si = end;
continue;
}
// We matched a partial part.
if (end < si + lp)
{
if (end >= lf)
{
// Create a copy of parts with the current part trimmed.
var newParts = new byte[parts.Length][];
parts.CopyTo(newParts, 0);
newParts[i] = parts[i][(lf - si)..];
return (newParts[i..], true);
}
else
{
return (parts[(i + 1)..], true);
}
}
if (i == lpi)
return (Array.Empty<byte[]>(), true);
si += part.Length;
}
return (parts, false);
}
}

View File

@@ -0,0 +1,64 @@
namespace ZB.MOM.NatsNet.Server.Internal;
/// <summary>
/// A pointer that can be toggled between weak and strong references, allowing
/// the garbage collector to reclaim the target when weakened.
/// Mirrors the Go <c>elastic.Pointer[T]</c> type.
/// </summary>
/// <typeparam name="T">The type of the referenced object. Must be a reference type.</typeparam>
public sealed class ElasticPointer<T> where T : class
{
private WeakReference<T>? _weak;
private T? _strong;
/// <summary>
/// Creates a new <see cref="ElasticPointer{T}"/> holding a weak reference to <paramref name="value"/>.
/// </summary>
public static ElasticPointer<T> Make(T value)
{
return new ElasticPointer<T> { _weak = new WeakReference<T>(value) };
}
/// <summary>
/// Updates the target. If the pointer is currently strengthened, the strong reference is updated too.
/// </summary>
public void Set(T value)
{
_weak = new WeakReference<T>(value);
if (_strong != null)
_strong = value;
}
/// <summary>
/// Promotes to a strong reference, preventing the GC from collecting the target.
/// No-op if already strengthened or if the weak target has been collected.
/// </summary>
public void Strengthen()
{
if (_strong != null)
return;
if (_weak != null && _weak.TryGetTarget(out var target))
_strong = target;
}
/// <summary>
/// Reverts to a weak reference, allowing the GC to reclaim the target.
/// No-op if already weakened.
/// </summary>
public void Weaken()
{
_strong = null;
}
/// <summary>
/// Returns the target value, or <see langword="null"/> if the weak reference has been collected.
/// </summary>
public T? Value()
{
if (_strong != null)
return _strong;
if (_weak != null && _weak.TryGetTarget(out var target))
return target;
return null;
}
}

View File

@@ -0,0 +1,106 @@
using System.Diagnostics;
namespace ZB.MOM.NatsNet.Server.Internal;
/// <summary>
/// Provides cross-platform process CPU and memory usage statistics.
/// Mirrors the Go <c>pse</c> (Process Status Emulation) package, replacing
/// per-platform implementations (rusage, /proc/stat, PDH) with
/// <see cref="System.Diagnostics.Process"/>.
/// </summary>
public static class ProcessStatsProvider
{
private static readonly Process _self = Process.GetCurrentProcess();
private static readonly int _processorCount = Environment.ProcessorCount;
private static readonly object _lock = new();
private static TimeSpan _lastCpuTime;
private static DateTime _lastSampleTime;
private static double _cachedPcpu;
private static long _cachedRss;
private static long _cachedVss;
static ProcessStatsProvider()
{
UpdateUsage();
StartPeriodicSampling();
}
/// <summary>
/// Returns the current process CPU percentage, RSS (bytes), and VSS (bytes).
/// Values are refreshed approximately every second by a background timer.
/// </summary>
/// <param name="pcpu">Percent CPU utilization (0100 × core count).</param>
/// <param name="rss">Resident set size in bytes.</param>
/// <param name="vss">Virtual memory size in bytes.</param>
public static void ProcUsage(out double pcpu, out long rss, out long vss)
{
lock (_lock)
{
pcpu = _cachedPcpu;
rss = _cachedRss;
vss = _cachedVss;
}
}
private static void UpdateUsage()
{
try
{
_self.Refresh();
var now = DateTime.UtcNow;
var cpuTime = _self.TotalProcessorTime;
lock (_lock)
{
var elapsed = now - _lastSampleTime;
if (elapsed >= TimeSpan.FromMilliseconds(500))
{
var cpuDelta = (cpuTime - _lastCpuTime).TotalSeconds;
// Normalize against elapsed wall time.
// Result is 0100; does not multiply by ProcessorCount to match Go behaviour.
_cachedPcpu = elapsed.TotalSeconds > 0
? Math.Round(cpuDelta / elapsed.TotalSeconds * 1000.0) / 10.0
: 0;
_lastSampleTime = now;
_lastCpuTime = cpuTime;
}
_cachedRss = _self.WorkingSet64;
_cachedVss = _self.VirtualMemorySize64;
}
}
catch
{
// Suppress — diagnostics should never crash the server.
}
}
private static void StartPeriodicSampling()
{
var timer = new Timer(_ => UpdateUsage(), null,
dueTime: TimeSpan.FromSeconds(1),
period: TimeSpan.FromSeconds(1));
// Keep timer alive for the process lifetime.
GC.KeepAlive(timer);
}
// --- Windows PDH helpers (replaced by Process class in .NET) ---
// The following methods exist to satisfy the porting mapping but delegate
// to the cross-platform Process API above.
internal static string GetProcessImageName() =>
Path.GetFileNameWithoutExtension(Environment.ProcessPath ?? _self.ProcessName);
internal static void InitCounters()
{
// No-op: .NET Process class initializes lazily.
}
internal static double PdhOpenQuery() => 0; // Mapped to Process API.
internal static double PdhAddCounter() => 0;
internal static double PdhCollectQueryData() => 0;
internal static double PdhGetFormattedCounterArrayDouble() => 0;
internal static double GetCounterArrayData() => 0;
}

View File

@@ -0,0 +1,95 @@
using System.Runtime.InteropServices;
namespace ZB.MOM.NatsNet.Server.Internal;
/// <summary>
/// Returns total physical memory available to the system in bytes.
/// Mirrors the Go <c>sysmem</c> package with platform-specific implementations.
/// Returns 0 if the value cannot be determined on the current platform.
/// </summary>
public static class SystemMemory
{
/// <summary>Returns total physical memory in bytes, or 0 on failure.</summary>
public static long Memory()
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return MemoryWindows();
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
return MemoryDarwin();
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
return MemoryLinux();
return 0;
}
// --- macOS ---
internal static long MemoryDarwin() => SysctlInt64("hw.memsize");
/// <summary>
/// Reads an int64 sysctl value by name on BSD-derived systems (macOS, FreeBSD, etc.).
/// </summary>
internal static unsafe long SysctlInt64(string name)
{
var size = (nuint)sizeof(long);
long value = 0;
var ret = sysctlbyname(name, &value, &size, IntPtr.Zero, 0);
return ret == 0 ? value : 0;
}
[DllImport("libc", EntryPoint = "sysctlbyname", SetLastError = true)]
private static extern unsafe int sysctlbyname(
string name,
void* oldp,
nuint* oldlenp,
IntPtr newp,
nuint newlen);
// --- Linux ---
internal static long MemoryLinux()
{
try
{
// Parse MemTotal from /proc/meminfo (value is in kB).
foreach (var line in File.ReadLines("/proc/meminfo"))
{
if (!line.StartsWith("MemTotal:", StringComparison.Ordinal))
continue;
var parts = line.Split(' ', StringSplitOptions.RemoveEmptyEntries);
if (parts.Length >= 2 && long.TryParse(parts[1], out var kb))
return kb * 1024L;
}
}
catch
{
// Fall through to return 0.
}
return 0;
}
// --- Windows ---
[StructLayout(LayoutKind.Sequential)]
private struct MemoryStatusEx
{
public uint dwLength;
public uint dwMemoryLoad;
public ulong ullTotalPhys;
public ulong ullAvailPhys;
public ulong ullTotalPageFile;
public ulong ullAvailPageFile;
public ulong ullTotalVirtual;
public ulong ullAvailVirtual;
public ulong ullAvailExtendedVirtual;
}
[DllImport("kernel32.dll", SetLastError = true)]
[return: MarshalAs(UnmanagedType.Bool)]
private static extern bool GlobalMemoryStatusEx(ref MemoryStatusEx lpBuffer);
internal static long MemoryWindows()
{
var msx = new MemoryStatusEx { dwLength = (uint)Marshal.SizeOf<MemoryStatusEx>() };
return GlobalMemoryStatusEx(ref msx) ? (long)msx.ullTotalPhys : 0;
}
}

View File

@@ -0,0 +1,80 @@
// Copyright 2012-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Adapted from server/parser.go and server/client.go in the NATS server Go source.
namespace ZB.MOM.NatsNet.Server.Protocol;
/// <summary>
/// Interface for the protocol handler callbacks invoked by <see cref="ProtocolParser.Parse"/>.
/// Decouples the state machine from the client implementation.
/// The client connection will implement this interface in later sessions.
/// </summary>
public interface IProtocolHandler
{
// ---- Dynamic connection state ----
bool IsMqtt { get; }
bool Trace { get; }
bool HasMappings { get; }
bool IsAwaitingAuth { get; }
/// <summary>
/// Attempts to register the no-auth user for this connection.
/// Returns true if a no-auth user was found and registered (allowing parse to continue).
/// </summary>
bool TryRegisterNoAuthUser();
/// <summary>
/// Returns true if this is a gateway inbound connection that has not yet received CONNECT.
/// </summary>
bool IsGatewayInboundNotConnected { get; }
// ---- Protocol action handlers ----
Exception? ProcessConnect(byte[] arg);
Exception? ProcessInfo(byte[] arg);
void ProcessPing();
void ProcessPong();
void ProcessErr(string arg);
// ---- Sub/unsub handlers (kind-specific) ----
Exception? ProcessClientSub(byte[] arg);
Exception? ProcessClientUnsub(byte[] arg);
Exception? ProcessRemoteSub(byte[] arg, bool isLeaf);
Exception? ProcessRemoteUnsub(byte[] arg, bool isLeafUnsub);
Exception? ProcessGatewayRSub(byte[] arg);
Exception? ProcessGatewayRUnsub(byte[] arg);
Exception? ProcessLeafSub(byte[] arg);
Exception? ProcessLeafUnsub(byte[] arg);
Exception? ProcessAccountSub(byte[] arg);
void ProcessAccountUnsub(byte[] arg);
// ---- Message processing ----
void ProcessInboundMsg(byte[] msg);
bool SelectMappedSubject();
// ---- Tracing ----
void TraceInOp(string name, byte[]? arg);
void TraceMsg(byte[] msg);
// ---- Error handling ----
void SendErr(string msg);
void AuthViolation();
void CloseConnection(int reason);
string KindString();
}

View File

@@ -0,0 +1,171 @@
// Copyright 2012-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Adapted from server/parser.go in the NATS server Go source.
using ZB.MOM.NatsNet.Server.Internal;
namespace ZB.MOM.NatsNet.Server.Protocol;
/// <summary>
/// Parser state machine states.
/// Mirrors the Go <c>parserState</c> const block in parser.go (79 states).
/// </summary>
public enum ParserState
{
OpStart = 0,
OpPlus,
OpPlusO,
OpPlusOk,
OpMinus,
OpMinusE,
OpMinusEr,
OpMinusErr,
OpMinusErrSpc,
MinusErrArg,
OpC,
OpCo,
OpCon,
OpConn,
OpConne,
OpConnec,
OpConnect,
ConnectArg,
OpH,
OpHp,
OpHpu,
OpHpub,
OpHpubSpc,
HpubArg,
OpHm,
OpHms,
OpHmsg,
OpHmsgSpc,
HmsgArg,
OpP,
OpPu,
OpPub,
OpPubSpc,
PubArg,
OpPi,
OpPin,
OpPing,
OpPo,
OpPon,
OpPong,
MsgPayload,
MsgEndR,
MsgEndN,
OpS,
OpSu,
OpSub,
OpSubSpc,
SubArg,
OpA,
OpAsub,
OpAsubSpc,
AsubArg,
OpAusub,
OpAusubSpc,
AusubArg,
OpL,
OpLs,
OpR,
OpRs,
OpU,
OpUn,
OpUns,
OpUnsu,
OpUnsub,
OpUnsubSpc,
UnsubArg,
OpM,
OpMs,
OpMsg,
OpMsgSpc,
MsgArg,
OpI,
OpIn,
OpInf,
OpInfo,
InfoArg,
}
/// <summary>
/// Parsed publish/message arguments.
/// Mirrors Go <c>pubArg</c> struct in parser.go.
/// </summary>
public sealed class PublishArgument
{
public byte[]? Arg { get; set; }
public byte[]? PaCache { get; set; }
public byte[]? Origin { get; set; }
public byte[]? Account { get; set; }
public byte[]? Subject { get; set; }
public byte[]? Deliver { get; set; }
public byte[]? Mapped { get; set; }
public byte[]? Reply { get; set; }
public byte[]? SizeBytes { get; set; }
public byte[]? HeaderBytes { get; set; }
public List<byte[]>? Queues { get; set; }
public int Size { get; set; }
public int HeaderSize { get; set; } = -1;
public bool Delivered { get; set; }
/// <summary>Resets all fields to their defaults.</summary>
public void Reset()
{
Arg = null;
PaCache = null;
Origin = null;
Account = null;
Subject = null;
Deliver = null;
Mapped = null;
Reply = null;
SizeBytes = null;
HeaderBytes = null;
Queues = null;
Size = 0;
HeaderSize = -1;
Delivered = false;
}
}
/// <summary>
/// Holds the parser state for a single connection.
/// Mirrors Go <c>parseState</c> struct embedded in <c>client</c>.
/// </summary>
public sealed class ParseContext
{
// ---- Parser state ----
public ParserState State { get; set; }
public byte Op { get; set; }
public int ArgStart { get; set; }
public int Drop { get; set; }
public PublishArgument Pa { get; } = new();
public byte[]? ArgBuf { get; set; }
public byte[]? MsgBuf { get; set; }
// ---- Connection-level properties (set once at creation) ----
public ClientKind Kind { get; set; }
public int MaxControlLine { get; set; } = ServerConstants.MaxControlLineSize;
public int MaxPayload { get; set; } = -1;
public bool HasHeaders { get; set; }
// ---- Internal scratch buffer ----
internal byte[] Scratch { get; } = new byte[ServerConstants.MaxControlLineSize];
}

File diff suppressed because it is too large Load Diff