diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Accounts/DirJwtStore.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Accounts/DirJwtStore.cs index ad045e4..127a3c3 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Accounts/DirJwtStore.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Accounts/DirJwtStore.cs @@ -833,6 +833,58 @@ public sealed class DirJwtStore : IDisposable // Private static helpers // --------------------------------------------------------------------------- + /// + /// Validates the supplied path exists, and enforces whether it must be a + /// directory or regular file. + /// Mirrors Go validatePathExists. + /// + internal static string ValidatePathExists(string path, bool dir) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("path is not specified", nameof(path)); + } + + string absolutePath; + try + { + absolutePath = Path.GetFullPath(path); + } + catch (Exception ex) + { + throw new InvalidOperationException($"error parsing path [{path}]: {ex.Message}", ex); + } + + if (!File.Exists(absolutePath) && !Directory.Exists(absolutePath)) + { + throw new InvalidOperationException($"the path [{absolutePath}] doesn't exist"); + } + + var attributes = File.GetAttributes(absolutePath); + var isDirectory = (attributes & FileAttributes.Directory) == FileAttributes.Directory; + + if (dir && !isDirectory) + { + throw new InvalidOperationException($"the path [{absolutePath}] is not a directory"); + } + + if (!dir && isDirectory) + { + throw new InvalidOperationException($"the path [{absolutePath}] is not a file"); + } + + return absolutePath; + } + + /// + /// Validates the supplied path exists and is a directory. + /// Mirrors Go validateDirPath. + /// + internal static string ValidateDirPath(string path) + { + return ValidatePathExists(path, dir: true); + } + /// /// Validates that exists and is a directory, optionally /// creating it when is true. @@ -841,31 +893,18 @@ public sealed class DirJwtStore : IDisposable /// private static string NewDir(string dirPath, bool create) { - if (string.IsNullOrEmpty(dirPath)) + if (Directory.Exists(dirPath) || File.Exists(dirPath)) { - throw new ArgumentException("Path is not specified", nameof(dirPath)); - } - - if (Directory.Exists(dirPath)) - { - return Path.GetFullPath(dirPath); + return ValidateDirPath(dirPath); } if (!create) { - throw new DirectoryNotFoundException( - $"The path [{dirPath}] doesn't exist"); + return ValidateDirPath(dirPath); } Directory.CreateDirectory(dirPath); - - if (!Directory.Exists(dirPath)) - { - throw new DirectoryNotFoundException( - $"Failed to create directory [{dirPath}]"); - } - - return Path.GetFullPath(dirPath); + return ValidateDirPath(dirPath); } /// @@ -1044,6 +1083,7 @@ internal sealed class ExpirationTracker { // Min-heap ordered by expiration (Unix nanoseconds stored as ticks for TimeSpan compatibility). private readonly PriorityQueue _heap; + private readonly List _compatHeap; // Index from publicKey to JwtItem for O(1) lookup and hash tracking. private readonly Dictionary _idx; @@ -1068,6 +1108,7 @@ internal sealed class ExpirationTracker EvictOnLimit = evictOnLimit; Ttl = ttl; _heap = new PriorityQueue(); + _compatHeap = []; _idx = new Dictionary(StringComparer.Ordinal); _lru = new LinkedList(); _hash = new byte[SHA256.HashSizeInBytes]; @@ -1075,6 +1116,55 @@ internal sealed class ExpirationTracker internal void SetTimer(Timer timer) => _timer = timer; + /// Returns the number of items in the compatibility heap. + /// Mirrors Go expirationTracker.Len. + internal int Len() => _compatHeap.Count; + + /// Returns true when item expires before . + /// Mirrors Go expirationTracker.Less. + internal bool Less(int i, int j) + { + return _compatHeap[i].Expiration < _compatHeap[j].Expiration; + } + + /// Swaps two compatibility heap items and updates their indexes. + /// Mirrors Go expirationTracker.Swap. + internal void Swap(int i, int j) + { + (_compatHeap[i], _compatHeap[j]) = (_compatHeap[j], _compatHeap[i]); + _compatHeap[i].Index = i; + _compatHeap[j].Index = j; + } + + /// Adds an item to the compatibility heap and index maps. + /// Mirrors Go expirationTracker.Push. + internal void Push(JwtItem item) + { + item.Index = _compatHeap.Count; + _compatHeap.Add(item); + _idx[item.PublicKey] = item; + _lru.AddLast(item.PublicKey); + } + + /// Removes and returns the last compatibility heap item. + /// Mirrors Go expirationTracker.Pop. + internal JwtItem Pop() + { + var n = _compatHeap.Count; + var item = _compatHeap[n - 1]; + _compatHeap.RemoveAt(n - 1); + item.Index = -1; + + var node = _lru.Find(item.PublicKey); + if (node != null) + { + _lru.Remove(node); + } + + _idx.Remove(item.PublicKey); + return item; + } + /// /// Adds or updates tracking for . /// When an entry already exists its expiration and hash are updated. @@ -1259,6 +1349,7 @@ internal sealed class ExpirationTracker internal void Reset() { _heap.Clear(); + _compatHeap.Clear(); _idx.Clear(); _lru.Clear(); Array.Clear(_hash); @@ -1352,6 +1443,7 @@ internal sealed class ExpirationTracker /// internal sealed class JwtItem { + internal int Index { get; set; } internal string PublicKey { get; } internal long Expiration { get; set; } internal byte[] Hash { get; set; } @@ -1367,6 +1459,7 @@ internal sealed class JwtItem internal JwtItem(string publicKey, long expiration, byte[] hash) { + Index = -1; PublicKey = publicKey; Expiration = expiration; Hash = hash; diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs index 7f526d9..2faba27 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/ClientConnection.cs @@ -17,6 +17,7 @@ using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Runtime.CompilerServices; +using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json; @@ -875,6 +876,47 @@ public sealed partial class ClientConnection } } + /// + /// Returns true when the current TLS peer certificate matches one of the pinned + /// SPKI SHA-256 key identifiers. + /// Mirrors Go client.matchesPinnedCert. + /// + internal bool MatchesPinnedCert(PinnedCertSet? tlsPinnedCerts) + { + if (tlsPinnedCerts == null) + { + return true; + } + + var certificate = GetTlsCertificate(); + if (certificate == null) + { + Debugf("Failed pinned cert test as client did not provide a certificate"); + return false; + } + + byte[] subjectPublicKeyInfo; + try + { + subjectPublicKeyInfo = certificate.PublicKey.ExportSubjectPublicKeyInfo(); + } + catch + { + subjectPublicKeyInfo = certificate.GetPublicKey(); + } + + var sha = SHA256.HashData(subjectPublicKeyInfo); + var keyId = Convert.ToHexString(sha).ToLowerInvariant(); + + if (!tlsPinnedCerts.Contains(keyId)) + { + Debugf("Failed pinned cert test for key id: {0}", keyId); + return false; + } + + return true; + } + internal void SetAccount(INatsAccount? acc) { lock (_mu) { Account = acc; } diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Reload.cs b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Reload.cs index b6d5b51..c483b1e 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Reload.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.Reload.cs @@ -14,7 +14,6 @@ // Adapted from server/reload.go in the NATS server Go source. using System.Reflection; -using System.Security.Cryptography; using System.Text.Json; using ZB.MOM.NatsNet.Server.Auth; using ZB.MOM.NatsNet.Server.Internal; @@ -1331,26 +1330,7 @@ public sealed partial class NatsServer private static bool MatchesPinnedCert(ClientConnection client, PinnedCertSet? pinnedCerts) { - if (pinnedCerts == null || pinnedCerts.Count == 0) - return true; - - var certificate = client.GetTlsCertificate(); - if (certificate == null) - return false; - - byte[] keyBytes; - try - { - keyBytes = certificate.PublicKey.ExportSubjectPublicKeyInfo(); - } - catch - { - keyBytes = certificate.GetPublicKey(); - } - - var hash = SHA256.HashData(keyBytes); - var hex = Convert.ToHexString(hash).ToLowerInvariant(); - return pinnedCerts.Contains(hex); + return client.MatchesPinnedCert(pinnedCerts); } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Accounts/DirectoryStoreTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Accounts/DirectoryStoreTests.cs index 24bc2ae..4dc95e3 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Accounts/DirectoryStoreTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Accounts/DirectoryStoreTests.cs @@ -767,4 +767,47 @@ public sealed class DirectoryStoreTests : IDisposable foreach (var s in stores) try { s?.Dispose(); } catch { /* best-effort */ } } } + + [Fact] + public void ValidateDirPath_ExistingDirectory_ReturnsAbsolutePath() + { + var dir = MakeTempDir(); + + var validated = DirJwtStore.ValidateDirPath(dir); + + validated.ShouldBe(Path.GetFullPath(dir)); + } + + [Fact] + public void ValidatePathExists_PathIsFileWhenDirectoryExpected_Throws() + { + var dir = MakeTempDir(); + var file = Path.Combine(dir, "token.jwt"); + File.WriteAllText(file, "jwt"); + + Should.Throw(() => DirJwtStore.ValidatePathExists(file, dir: true)); + } + + [Fact] + public void ExpirationTracker_HeapPrimitives_MaintainIndexAndTracking() + { + var tracker = new ExpirationTracker(limit: 10, evictOnLimit: true, ttl: TimeSpan.Zero); + var a = new JwtItem("A", expiration: 10, hash: [1, 2, 3]); + var b = new JwtItem("B", expiration: 20, hash: [4, 5, 6]); + + tracker.Push(a); + tracker.Push(b); + + tracker.Len().ShouldBe(2); + tracker.Less(0, 1).ShouldBeTrue(); + + tracker.Swap(0, 1); + tracker.Less(0, 1).ShouldBeFalse(); + + var popped = tracker.Pop(); + popped.PublicKey.ShouldBe("A"); + tracker.IsTracked("A").ShouldBeFalse(); + tracker.IsTracked("B").ShouldBeTrue(); + tracker.Len().ShouldBe(1); + } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ServerTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ServerTests.cs index 2e4ca56..649e6f1 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ServerTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ServerTests.cs @@ -22,6 +22,7 @@ using NSubstitute.ExceptionExtensions; using Shouldly; using Xunit; using ZB.MOM.NatsNet.Server.Auth; +using ZB.MOM.NatsNet.Server.Internal; namespace ZB.MOM.NatsNet.Server.Tests; @@ -218,6 +219,21 @@ public sealed class ServerTests err.ShouldNotBeNull(); } + [Fact] + public void MatchesPinnedCert_NullPinnedSet_ReturnsTrue() + { + var client = new ClientConnection(ClientKind.Client, nc: new MemoryStream()); + client.MatchesPinnedCert(null).ShouldBeTrue(); + } + + [Fact] + public void MatchesPinnedCert_NoTlsCertificate_ReturnsFalse() + { + var client = new ClientConnection(ClientKind.Client, nc: new MemoryStream()); + var pinned = new PinnedCertSet([new string('a', 64)]); + client.MatchesPinnedCert(pinned).ShouldBeFalse(); + } + // ========================================================================= // GetServerProto // ========================================================================= diff --git a/porting.db b/porting.db index fa78b0f..dc3392b 100644 Binary files a/porting.db and b/porting.db differ