diff --git a/src/NATS.Server/JetStream/Cluster/StreamReplicaGroup.cs b/src/NATS.Server/JetStream/Cluster/StreamReplicaGroup.cs new file mode 100644 index 0000000..284aad0 --- /dev/null +++ b/src/NATS.Server/JetStream/Cluster/StreamReplicaGroup.cs @@ -0,0 +1,65 @@ +using NATS.Server.Raft; + +namespace NATS.Server.JetStream.Cluster; + +public sealed class StreamReplicaGroup +{ + private readonly List _nodes; + + public string StreamName { get; } + public IReadOnlyList Nodes => _nodes; + public RaftNode Leader { get; private set; } + + public StreamReplicaGroup(string streamName, int replicas) + { + StreamName = streamName; + + var nodeCount = Math.Max(replicas, 1); + _nodes = Enumerable.Range(1, nodeCount) + .Select(i => new RaftNode($"{streamName.ToLowerInvariant()}-r{i}")) + .ToList(); + + foreach (var node in _nodes) + node.ConfigureCluster(_nodes); + + Leader = ElectLeader(_nodes[0]); + } + + public async ValueTask ProposeAsync(string command, CancellationToken ct) + { + if (!Leader.IsLeader) + Leader = ElectLeader(SelectNextCandidate(Leader)); + + return await Leader.ProposeAsync(command, ct); + } + + public Task StepDownAsync(CancellationToken ct) + { + _ = ct; + var previous = Leader; + previous.RequestStepDown(); + Leader = ElectLeader(SelectNextCandidate(previous)); + return Task.CompletedTask; + } + + private RaftNode SelectNextCandidate(RaftNode currentLeader) + { + if (_nodes.Count == 1) + return _nodes[0]; + + var index = _nodes.FindIndex(n => n.Id == currentLeader.Id); + if (index < 0) + return _nodes[0]; + + return _nodes[(index + 1) % _nodes.Count]; + } + + private RaftNode ElectLeader(RaftNode candidate) + { + candidate.StartElection(_nodes.Count); + foreach (var voter in _nodes.Where(n => n.Id != candidate.Id)) + candidate.ReceiveVote(voter.GrantVote(candidate.Term), _nodes.Count); + + return candidate; + } +} diff --git a/src/NATS.Server/JetStream/StreamManager.cs b/src/NATS.Server/JetStream/StreamManager.cs index b294cff..3b1e1f3 100644 --- a/src/NATS.Server/JetStream/StreamManager.cs +++ b/src/NATS.Server/JetStream/StreamManager.cs @@ -14,6 +14,8 @@ public sealed class StreamManager private readonly JetStreamMetaGroup? _metaGroup; private readonly ConcurrentDictionary _streams = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _replicaGroups = + new(StringComparer.Ordinal); private readonly ConcurrentDictionary> _mirrorsByOrigin = new(StringComparer.Ordinal); private readonly ConcurrentDictionary> _sourcesByOrigin = @@ -36,6 +38,12 @@ public sealed class StreamManager normalized.Name, _ => new StreamHandle(normalized, new MemStore()), (_, existing) => existing with { Config = normalized }); + _replicaGroups.AddOrUpdate( + normalized.Name, + _ => new StreamReplicaGroup(normalized.Name, normalized.Replicas), + (_, existing) => existing.Nodes.Count == Math.Max(normalized.Replicas, 1) + ? existing + : new StreamReplicaGroup(normalized.Name, normalized.Replicas)); RebuildReplicationCoordinators(); _metaGroup?.ProposeCreateStreamAsync(normalized, default).GetAwaiter().GetResult(); @@ -77,6 +85,9 @@ public sealed class StreamManager if (stream == null) return null; + if (_replicaGroups.TryGetValue(stream.Config.Name, out var replicaGroup)) + _ = replicaGroup.ProposeAsync($"PUB {subject}", default).GetAwaiter().GetResult(); + var seq = stream.Store.AppendAsync(subject, payload, default).GetAwaiter().GetResult(); EnforceLimits(stream); var stored = stream.Store.LoadAsync(seq, default).GetAwaiter().GetResult(); @@ -90,6 +101,14 @@ public sealed class StreamManager }; } + public Task StepDownStreamLeaderAsync(string stream, CancellationToken ct) + { + if (_replicaGroups.TryGetValue(stream, out var replicaGroup)) + return replicaGroup.StepDownAsync(ct); + + return Task.CompletedTask; + } + private static StreamConfig NormalizeConfig(StreamConfig config) { var copy = new StreamConfig diff --git a/src/NATS.Server/Raft/RaftNode.cs b/src/NATS.Server/Raft/RaftNode.cs index e39ce91..f3ab0af 100644 --- a/src/NATS.Server/Raft/RaftNode.cs +++ b/src/NATS.Server/Raft/RaftNode.cs @@ -9,6 +9,7 @@ public sealed class RaftNode public string Id { get; } public int Term => TermState.CurrentTerm; + public bool IsLeader => Role == RaftRole.Leader; public RaftRole Role { get; private set; } = RaftRole.Follower; public RaftTermState TermState { get; } = new(); public long AppliedIndex { get; set; } @@ -99,6 +100,8 @@ public sealed class RaftNode public void RequestStepDown() { Role = RaftRole.Follower; + _votesReceived = 0; + TermState.VotedFor = null; } private void TryBecomeLeader(int clusterSize) diff --git a/tests/NATS.Server.Tests/JetStreamStreamReplicaGroupTests.cs b/tests/NATS.Server.Tests/JetStreamStreamReplicaGroupTests.cs new file mode 100644 index 0000000..17e1cf2 --- /dev/null +++ b/tests/NATS.Server.Tests/JetStreamStreamReplicaGroupTests.cs @@ -0,0 +1,71 @@ +using System.Text; +using NATS.Server.JetStream; +using NATS.Server.JetStream.Models; +using NATS.Server.JetStream.Publish; + +namespace NATS.Server.Tests; + +public class JetStreamStreamReplicaGroupTests +{ + [Fact] + public async Task Leader_stepdown_preserves_stream_write_availability_after_new_election() + { + await using var fixture = await JetStreamReplicaFixture.StartAsync(nodes: 3); + await fixture.CreateStreamAsync("ORDERS", replicas: 3); + + await fixture.StepDownStreamLeaderAsync("ORDERS"); + var ack = await fixture.PublishAndGetAckAsync("orders.created", "1"); + + ack.Stream.ShouldBe("ORDERS"); + ack.Seq.ShouldBeGreaterThan((ulong)0); + } +} + +internal sealed class JetStreamReplicaFixture : IAsyncDisposable +{ + private readonly StreamManager _streamManager; + private readonly JetStreamPublisher _publisher; + + private JetStreamReplicaFixture(StreamManager streamManager) + { + _streamManager = streamManager; + _publisher = new JetStreamPublisher(_streamManager); + } + + public static Task StartAsync(int nodes) + { + _ = nodes; + var streamManager = new StreamManager(); + return Task.FromResult(new JetStreamReplicaFixture(streamManager)); + } + + public Task CreateStreamAsync(string name, int replicas) + { + var response = _streamManager.CreateOrUpdate(new StreamConfig + { + Name = name, + Subjects = ["orders.*"], + Replicas = replicas, + }); + + if (response.Error is not null) + throw new InvalidOperationException(response.Error.Description); + + return Task.CompletedTask; + } + + public Task StepDownStreamLeaderAsync(string stream) + { + return _streamManager.StepDownStreamLeaderAsync(stream, default); + } + + public Task PublishAndGetAckAsync(string subject, string payload) + { + if (_publisher.TryCapture(subject, Encoding.UTF8.GetBytes(payload), null, out var ack)) + return Task.FromResult(ack); + + throw new InvalidOperationException("Publish did not match a stream."); + } + + public ValueTask DisposeAsync() => ValueTask.CompletedTask; +}