feat: enforce account jetstream limits and jwt tiers

This commit is contained in:
Joseph Doherty
2026-02-23 06:21:51 -05:00
parent ccbcf759a9
commit 2aa7265db1
8 changed files with 91 additions and 3 deletions

View File

@@ -12,6 +12,8 @@ public sealed class Account : IDisposable
public Permissions? DefaultPermissions { get; set; } public Permissions? DefaultPermissions { get; set; }
public int MaxConnections { get; set; } // 0 = unlimited public int MaxConnections { get; set; } // 0 = unlimited
public int MaxSubscriptions { get; set; } // 0 = unlimited public int MaxSubscriptions { get; set; } // 0 = unlimited
public int MaxJetStreamStreams { get; set; } // 0 = unlimited
public string? JetStreamTier { get; set; }
// JWT fields // JWT fields
public string? Nkey { get; set; } public string? Nkey { get; set; }
@@ -33,6 +35,7 @@ public sealed class Account : IDisposable
private readonly ConcurrentDictionary<ulong, byte> _clients = new(); private readonly ConcurrentDictionary<ulong, byte> _clients = new();
private int _subscriptionCount; private int _subscriptionCount;
private int _jetStreamStreamCount;
public Account(string name) public Account(string name)
{ {
@@ -41,6 +44,7 @@ public sealed class Account : IDisposable
public int ClientCount => _clients.Count; public int ClientCount => _clients.Count;
public int SubscriptionCount => Volatile.Read(ref _subscriptionCount); public int SubscriptionCount => Volatile.Read(ref _subscriptionCount);
public int JetStreamStreamCount => Volatile.Read(ref _jetStreamStreamCount);
/// <summary>Returns false if max connections exceeded.</summary> /// <summary>Returns false if max connections exceeded.</summary>
public bool AddClient(ulong clientId) public bool AddClient(ulong clientId)
@@ -66,6 +70,23 @@ public sealed class Account : IDisposable
Interlocked.Decrement(ref _subscriptionCount); Interlocked.Decrement(ref _subscriptionCount);
} }
public bool TryReserveStream()
{
if (MaxJetStreamStreams > 0 && Volatile.Read(ref _jetStreamStreamCount) >= MaxJetStreamStreams)
return false;
Interlocked.Increment(ref _jetStreamStreamCount);
return true;
}
public void ReleaseStream()
{
if (Volatile.Read(ref _jetStreamStreamCount) == 0)
return;
Interlocked.Decrement(ref _jetStreamStreamCount);
}
// Per-account message/byte stats // Per-account message/byte stats
private long _inMsgs; private long _inMsgs;
private long _outMsgs; private long _outMsgs;

View File

@@ -6,4 +6,6 @@ public sealed class AuthResult
public string? AccountName { get; init; } public string? AccountName { get; init; }
public Permissions? Permissions { get; init; } public Permissions? Permissions { get; init; }
public DateTimeOffset? Expiry { get; init; } public DateTimeOffset? Expiry { get; init; }
public int MaxJetStreamStreams { get; init; }
public string? JetStreamTier { get; init; }
} }

View File

@@ -47,6 +47,10 @@ public sealed class AccountNats
[JsonPropertyName("limits")] [JsonPropertyName("limits")]
public AccountLimits? Limits { get; set; } public AccountLimits? Limits { get; set; }
/// <summary>JetStream entitlement limits/tier for this account.</summary>
[JsonPropertyName("jetstream")]
public AccountJetStreamLimits? JetStream { get; set; }
/// <summary>NKey public keys authorized to sign user JWTs for this account.</summary> /// <summary>NKey public keys authorized to sign user JWTs for this account.</summary>
[JsonPropertyName("signing_keys")] [JsonPropertyName("signing_keys")]
public string[]? SigningKeys { get; set; } public string[]? SigningKeys { get; set; }
@@ -92,3 +96,12 @@ public sealed class AccountLimits
[JsonPropertyName("data")] [JsonPropertyName("data")]
public long MaxData { get; set; } public long MaxData { get; set; }
} }
public sealed class AccountJetStreamLimits
{
[JsonPropertyName("max_streams")]
public int MaxStreams { get; set; }
[JsonPropertyName("tier")]
public string? Tier { get; set; }
}

View File

@@ -143,6 +143,8 @@ public sealed class JwtAuthenticator : IAuthenticator
AccountName = issuerAccount, AccountName = issuerAccount,
Permissions = permissions, Permissions = permissions,
Expiry = userClaims.GetExpiry(), Expiry = userClaims.GetExpiry(),
MaxJetStreamStreams = accountClaims.Nats?.JetStream?.MaxStreams ?? 0,
JetStreamTier = accountClaims.Nats?.JetStream?.Tier,
}; };
} }

View File

@@ -1,4 +1,5 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using NATS.Server.Auth;
using NATS.Server.JetStream.Api; using NATS.Server.JetStream.Api;
using NATS.Server.JetStream.Cluster; using NATS.Server.JetStream.Cluster;
using NATS.Server.JetStream.MirrorSource; using NATS.Server.JetStream.MirrorSource;
@@ -11,6 +12,7 @@ namespace NATS.Server.JetStream;
public sealed class StreamManager public sealed class StreamManager
{ {
private readonly Account? _account;
private readonly JetStreamMetaGroup? _metaGroup; private readonly JetStreamMetaGroup? _metaGroup;
private readonly ConcurrentDictionary<string, StreamHandle> _streams = private readonly ConcurrentDictionary<string, StreamHandle> _streams =
new(StringComparer.Ordinal); new(StringComparer.Ordinal);
@@ -21,9 +23,10 @@ public sealed class StreamManager
private readonly ConcurrentDictionary<string, List<SourceCoordinator>> _sourcesByOrigin = private readonly ConcurrentDictionary<string, List<SourceCoordinator>> _sourcesByOrigin =
new(StringComparer.Ordinal); new(StringComparer.Ordinal);
public StreamManager(JetStreamMetaGroup? metaGroup = null) public StreamManager(JetStreamMetaGroup? metaGroup = null, Account? account = null)
{ {
_metaGroup = metaGroup; _metaGroup = metaGroup;
_account = account;
} }
public IReadOnlyCollection<string> StreamNames => _streams.Keys.ToArray(); public IReadOnlyCollection<string> StreamNames => _streams.Keys.ToArray();
@@ -34,6 +37,10 @@ public sealed class StreamManager
return JetStreamApiResponse.ErrorResponse(400, "stream name required"); return JetStreamApiResponse.ErrorResponse(400, "stream name required");
var normalized = NormalizeConfig(config); var normalized = NormalizeConfig(config);
var isCreate = !_streams.ContainsKey(normalized.Name);
if (isCreate && _account is not null && !_account.TryReserveStream())
return JetStreamApiResponse.ErrorResponse(10027, "maximum streams exceeded");
var handle = _streams.AddOrUpdate( var handle = _streams.AddOrUpdate(
normalized.Name, normalized.Name,
_ => new StreamHandle(normalized, new MemStore()), _ => new StreamHandle(normalized, new MemStore()),

View File

@@ -419,6 +419,10 @@ public sealed class NatsClient : IDisposable
{ {
var accountName = authResult.AccountName ?? Account.GlobalAccountName; var accountName = authResult.AccountName ?? Account.GlobalAccountName;
Account = server.GetOrCreateAccount(accountName); Account = server.GetOrCreateAccount(accountName);
if (authResult.MaxJetStreamStreams > 0)
Account.MaxJetStreamStreams = authResult.MaxJetStreamStreams;
if (!string.IsNullOrWhiteSpace(authResult.JetStreamTier))
Account.JetStreamTier = authResult.JetStreamTier;
if (!Account.AddClient(Id)) if (!Account.AddClient(Id))
{ {
Account = null; Account = null;

View File

@@ -1,4 +1,6 @@
using System.Text; using System.Text;
using System.Text.Json;
using NATS.Server.Auth;
using NATS.Server.JetStream; using NATS.Server.JetStream;
using NATS.Server.JetStream.Api; using NATS.Server.JetStream.Api;
using NATS.Server.JetStream.Consumers; using NATS.Server.JetStream.Consumers;
@@ -18,9 +20,9 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable
private readonly JetStreamApiRouter _router; private readonly JetStreamApiRouter _router;
private readonly JetStreamPublisher _publisher; private readonly JetStreamPublisher _publisher;
private JetStreamApiFixture() private JetStreamApiFixture(Account? account = null)
{ {
_streamManager = new StreamManager(); _streamManager = new StreamManager(account: account);
_consumerManager = new ConsumerManager(); _consumerManager = new ConsumerManager();
_router = new JetStreamApiRouter(_streamManager, _consumerManager); _router = new JetStreamApiRouter(_streamManager, _consumerManager);
_publisher = new JetStreamPublisher(_streamManager); _publisher = new JetStreamPublisher(_streamManager);
@@ -73,6 +75,17 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable
return fixture; return fixture;
} }
public static Task<JetStreamApiFixture> StartJwtLimitedAccountAsync(int maxStreams)
{
var account = new Account("JWT-LIMITED")
{
MaxJetStreamStreams = maxStreams,
JetStreamTier = "jwt-tier",
};
return Task.FromResult(new JetStreamApiFixture(account));
}
public Task<PubAck> PublishAndGetAckAsync(string subject, string payload, string? msgId = null, bool expectError = false) public Task<PubAck> PublishAndGetAckAsync(string subject, string payload, string? msgId = null, bool expectError = false)
{ {
if (_publisher.TryCapture(subject, Encoding.UTF8.GetBytes(payload), msgId, out var ack)) if (_publisher.TryCapture(subject, Encoding.UTF8.GetBytes(payload), msgId, out var ack))
@@ -103,6 +116,16 @@ internal sealed class JetStreamApiFixture : IAsyncDisposable
return Task.FromResult(_router.Route(subject, Encoding.UTF8.GetBytes(payload))); return Task.FromResult(_router.Route(subject, Encoding.UTF8.GetBytes(payload)));
} }
public Task<JetStreamApiResponse> CreateStreamAsync(string streamName, IReadOnlyList<string> subjects)
{
var payload = JsonSerializer.Serialize(new
{
name = streamName,
subjects,
});
return RequestLocalAsync($"$JS.API.STREAM.CREATE.{streamName}", payload);
}
public Task<StreamState> GetStreamStateAsync(string streamName) public Task<StreamState> GetStreamStateAsync(string streamName)
{ {
return _streamManager.GetStateAsync(streamName, default).AsTask(); return _streamManager.GetStateAsync(streamName, default).AsTask();

View File

@@ -0,0 +1,16 @@
namespace NATS.Server.Tests;
public class JetStreamJwtLimitTests
{
[Fact]
public async Task Account_limit_rejects_stream_create_when_max_streams_reached()
{
await using var fixture = await JetStreamApiFixture.StartJwtLimitedAccountAsync(maxStreams: 1);
(await fixture.CreateStreamAsync("S1", subjects: ["s1.*"])) .Error.ShouldBeNull();
var second = await fixture.CreateStreamAsync("S2", subjects: ["s2.*"]);
second.Error.ShouldNotBeNull();
second.Error!.Code.ShouldBe(10027);
}
}