// Copyright 2012-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ZB.MOM.NatsNet.Server.Internal;
/// <summary>
/// A Go-like WaitGroup: tracks a set of in-flight operations and lets callers
/// block until all of them complete.
/// </summary>
/// <remarks>
/// The counter and the completion source are always read and updated together
/// under a single lock, so concurrent <see cref="Add"/>, <see cref="Done"/> and
/// <see cref="Wait"/> calls never observe them out of sync. Each run of the
/// counter from zero to positive and back to zero gets its own
/// <see cref="TaskCompletionSource"/>; a waiter is released when the run it
/// observed ends, even if a new run has already started.
/// </remarks>
internal sealed class WaitGroup
{
    // Guards _count and _tcs; they must change together to stay consistent.
    private readonly object _gate = new();

    // Number of operations still in flight. Never negative.
    private int _count;

    // Completion source for the current run. Completed while _count is zero;
    // replaced by a fresh, unsignaled instance on each 0 -> positive transition.
    private TaskCompletionSource _tcs;

    /// <summary>Creates a wait group whose counter starts at zero.</summary>
    public WaitGroup()
    {
        _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        _tcs.SetResult(); // counter starts at zero, so the group is already "done"
    }

    /// <summary>
    /// Adds <paramref name="delta"/> (usually 1) to the counter.
    /// Must be called before starting the operation it tracks.
    /// </summary>
    /// <param name="delta">
    /// Amount to add to the counter; negative values decrement it.
    /// When the counter reaches zero, all current waiters are released.
    /// </param>
    /// <exception cref="InvalidOperationException">
    /// The counter would go negative or overflow. The counter is left unchanged.
    /// </exception>
    public void Add(int delta = 1)
    {
        TaskCompletionSource? finishedRun = null;

        lock (_gate)
        {
            // Compute in 64 bits so overflow is detected instead of wrapping.
            long newCount = (long)_count + delta;
            if (newCount < 0)
            {
                throw new InvalidOperationException("WaitGroup counter went negative");
            }
            if (newCount > int.MaxValue)
            {
                throw new InvalidOperationException("WaitGroup counter overflowed");
            }

            if (_count == 0 && newCount > 0)
            {
                // Starting a new run: install an unsignaled source so Wait() blocks.
                _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
            }
            else if (_count > 0 && newCount == 0)
            {
                // Run finished: remember this run's source and signal it below.
                finishedRun = _tcs;
            }

            _count = (int)newCount;
        }

        // Signal outside the lock; this is the source of the run that just ended,
        // never one installed by a later Add.
        finishedRun?.TrySetResult();
    }

    /// <summary>Decrements the counter by 1. Called when a tracked operation finishes.</summary>
    /// <exception cref="InvalidOperationException">The counter would go negative.</exception>
    public void Done() => Add(-1);

    /// <summary>
    /// Blocks the calling thread until the counter reaches zero.
    /// Returns immediately if the counter is already zero.
    /// </summary>
    public void Wait()
    {
        Task pending;

        lock (_gate)
        {
            if (_count == 0)
            {
                return;
            }

            // Capture the task under the lock so it belongs to the run whose
            // non-zero count we just observed.
            pending = _tcs.Task;
        }

        // Block outside the lock so Add/Done can make progress.
        pending.GetAwaiter().GetResult();
    }
}