diff --git a/src/NATS.Server/Auth/Account.cs b/src/NATS.Server/Auth/Account.cs index bce25e1..c95d23e 100644 --- a/src/NATS.Server/Auth/Account.cs +++ b/src/NATS.Server/Auth/Account.cs @@ -12,6 +12,8 @@ public sealed class Account : IDisposable public Permissions? DefaultPermissions { get; set; } public int MaxConnections { get; set; } // 0 = unlimited public int MaxSubscriptions { get; set; } // 0 = unlimited + public int MaxJetStreamStreams { get; set; } // 0 = unlimited + public string? JetStreamTier { get; set; } // JWT fields public string? Nkey { get; set; } @@ -33,6 +35,7 @@ public sealed class Account : IDisposable private readonly ConcurrentDictionary _clients = new(); private int _subscriptionCount; + private int _jetStreamStreamCount; public Account(string name) { @@ -41,6 +44,7 @@ public sealed class Account : IDisposable public int ClientCount => _clients.Count; public int SubscriptionCount => Volatile.Read(ref _subscriptionCount); + public int JetStreamStreamCount => Volatile.Read(ref _jetStreamStreamCount); /// Returns false if max connections exceeded. public bool AddClient(ulong clientId) @@ -66,6 +70,23 @@ public sealed class Account : IDisposable Interlocked.Decrement(ref _subscriptionCount); } + public bool TryReserveStream() + { + if (MaxJetStreamStreams > 0 && Volatile.Read(ref _jetStreamStreamCount) >= MaxJetStreamStreams) + return false; + + Interlocked.Increment(ref _jetStreamStreamCount); + return true; + } + + public void ReleaseStream() + { + if (Volatile.Read(ref _jetStreamStreamCount) == 0) + return; + + Interlocked.Decrement(ref _jetStreamStreamCount); + } + // Per-account message/byte stats private long _inMsgs; private long _outMsgs; diff --git a/src/NATS.Server/Auth/AuthResult.cs b/src/NATS.Server/Auth/AuthResult.cs index 9e2d93c..dbc6322 100644 --- a/src/NATS.Server/Auth/AuthResult.cs +++ b/src/NATS.Server/Auth/AuthResult.cs @@ -6,4 +6,6 @@ public sealed class AuthResult public string? AccountName { get; init; } public Permissions? Permissions { get; init; } public DateTimeOffset? Expiry { get; init; } + public int MaxJetStreamStreams { get; init; } + public string? JetStreamTier { get; init; } } diff --git a/src/NATS.Server/Auth/Jwt/AccountClaims.cs b/src/NATS.Server/Auth/Jwt/AccountClaims.cs index d581d98..18c24c4 100644 --- a/src/NATS.Server/Auth/Jwt/AccountClaims.cs +++ b/src/NATS.Server/Auth/Jwt/AccountClaims.cs @@ -47,6 +47,10 @@ public sealed class AccountNats [JsonPropertyName("limits")] public AccountLimits? Limits { get; set; } + /// JetStream entitlement limits/tier for this account. + [JsonPropertyName("jetstream")] + public AccountJetStreamLimits? JetStream { get; set; } + /// NKey public keys authorized to sign user JWTs for this account. [JsonPropertyName("signing_keys")] public string[]? SigningKeys { get; set; } @@ -92,3 +96,12 @@ public sealed class AccountLimits [JsonPropertyName("data")] public long MaxData { get; set; } } + +public sealed class AccountJetStreamLimits +{ + [JsonPropertyName("max_streams")] + public int MaxStreams { get; set; } + + [JsonPropertyName("tier")] + public string? Tier { get; set; } +} diff --git a/src/NATS.Server/Auth/JwtAuthenticator.cs b/src/NATS.Server/Auth/JwtAuthenticator.cs index f28a155..5df1f27 100644 --- a/src/NATS.Server/Auth/JwtAuthenticator.cs +++ b/src/NATS.Server/Auth/JwtAuthenticator.cs @@ -143,6 +143,8 @@ public sealed class JwtAuthenticator : IAuthenticator AccountName = issuerAccount, Permissions = permissions, Expiry = userClaims.GetExpiry(), + MaxJetStreamStreams = accountClaims.Nats?.JetStream?.MaxStreams ?? 0, + JetStreamTier = accountClaims.Nats?.JetStream?.Tier, }; } diff --git a/src/NATS.Server/JetStream/StreamManager.cs b/src/NATS.Server/JetStream/StreamManager.cs index 3b1e1f3..ef55401 100644 --- a/src/NATS.Server/JetStream/StreamManager.cs +++ b/src/NATS.Server/JetStream/StreamManager.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using NATS.Server.Auth; using NATS.Server.JetStream.Api; using NATS.Server.JetStream.Cluster; using NATS.Server.JetStream.MirrorSource; @@ -11,6 +12,7 @@ namespace NATS.Server.JetStream; public sealed class StreamManager { + private readonly Account? _account; private readonly JetStreamMetaGroup? _metaGroup; private readonly ConcurrentDictionary _streams = new(StringComparer.Ordinal); @@ -21,9 +23,10 @@ public sealed class StreamManager private readonly ConcurrentDictionary> _sourcesByOrigin = new(StringComparer.Ordinal); - public StreamManager(JetStreamMetaGroup? metaGroup = null) + public StreamManager(JetStreamMetaGroup? metaGroup = null, Account? account = null) { _metaGroup = metaGroup; + _account = account; } public IReadOnlyCollection StreamNames => _streams.Keys.ToArray(); @@ -34,6 +37,10 @@ public sealed class StreamManager return JetStreamApiResponse.ErrorResponse(400, "stream name required"); var normalized = NormalizeConfig(config); + var isCreate = !_streams.ContainsKey(normalized.Name); + if (isCreate && _account is not null && !_account.TryReserveStream()) + return JetStreamApiResponse.ErrorResponse(10027, "maximum streams exceeded"); + var handle = _streams.AddOrUpdate( normalized.Name, _ => new StreamHandle(normalized, new MemStore()), diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 49017ec..0ab0e9d 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -419,6 +419,10 @@ public sealed class NatsClient : IDisposable { var accountName = authResult.AccountName ?? Account.GlobalAccountName; Account = server.GetOrCreateAccount(accountName); + if (authResult.MaxJetStreamStreams > 0) + Account.MaxJetStreamStreams = authResult.MaxJetStreamStreams; + if (!string.IsNullOrWhiteSpace(authResult.JetStreamTier)) + Account.JetStreamTier = authResult.JetStreamTier; if (!Account.AddClient(Id)) { Account = null; diff --git a/tests/NATS.Server.Tests/JetStreamApiFixture.cs b/tests/NATS.Server.Tests/JetStreamApiFixture.cs index 39d9f93..273ae87 100644 --- a/tests/NATS.Server.Tests/JetStreamApiFixture.cs +++ b/tests/NATS.Server.Tests/JetStreamApiFixture.cs @@ -1,4 +1,6 @@ using System.Text; +using System.Text.Json; +using NATS.Server.Auth; using NATS.Server.JetStream; using NATS.Server.JetStream.Api; using NATS.Server.JetStream.Consumers; @@ -18,9 +20,9 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable private readonly JetStreamApiRouter _router; private readonly JetStreamPublisher _publisher; - private JetStreamApiFixture() + private JetStreamApiFixture(Account? account = null) { - _streamManager = new StreamManager(); + _streamManager = new StreamManager(account: account); _consumerManager = new ConsumerManager(); _router = new JetStreamApiRouter(_streamManager, _consumerManager); _publisher = new JetStreamPublisher(_streamManager); @@ -73,6 +75,17 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable return fixture; } + public static Task StartJwtLimitedAccountAsync(int maxStreams) + { + var account = new Account("JWT-LIMITED") + { + MaxJetStreamStreams = maxStreams, + JetStreamTier = "jwt-tier", + }; + + return Task.FromResult(new JetStreamApiFixture(account)); + } + public Task PublishAndGetAckAsync(string subject, string payload, string? msgId = null, bool expectError = false) { if (_publisher.TryCapture(subject, Encoding.UTF8.GetBytes(payload), msgId, out var ack)) @@ -103,6 +116,16 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable return Task.FromResult(_router.Route(subject, Encoding.UTF8.GetBytes(payload))); } + public Task CreateStreamAsync(string streamName, IReadOnlyList subjects) + { + var payload = JsonSerializer.Serialize(new + { + name = streamName, + subjects, + }); + return RequestLocalAsync($"$JS.API.STREAM.CREATE.{streamName}", payload); + } + public Task GetStreamStateAsync(string streamName) { return _streamManager.GetStateAsync(streamName, default).AsTask(); diff --git a/tests/NATS.Server.Tests/JetStreamJwtLimitTests.cs b/tests/NATS.Server.Tests/JetStreamJwtLimitTests.cs new file mode 100644 index 0000000..b5ee66a --- /dev/null +++ b/tests/NATS.Server.Tests/JetStreamJwtLimitTests.cs @@ -0,0 +1,16 @@ +namespace NATS.Server.Tests; + +public class JetStreamJwtLimitTests +{ + [Fact] + public async Task Account_limit_rejects_stream_create_when_max_streams_reached() + { + await using var fixture = await JetStreamApiFixture.StartJwtLimitedAccountAsync(maxStreams: 1); + + (await fixture.CreateStreamAsync("S1", subjects: ["s1.*"])) .Error.ShouldBeNull(); + var second = await fixture.CreateStreamAsync("S2", subjects: ["s2.*"]); + + second.Error.ShouldNotBeNull(); + second.Error!.Code.ShouldBe(10027); + } +}