Fix reliability findings

This commit is contained in:
Joseph Doherty
2026-04-28 06:27:01 -04:00
parent 907aa49aea
commit b0041c5d18
9 changed files with 233 additions and 21 deletions
+54 -13
View File
@@ -113,6 +113,40 @@ func TestEventSubscriptionCloseStopsStream(t *testing.T) {
}
}
// TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned checks that a
// caller who obtains the EventsAfter channel but never reads from it does not
// keep the gateway stream alive. The fake server is configured with 64 events;
// the test expects the server-side stream to finish and the result channel to
// be closed, each within a bounded wait.
func TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned(t *testing.T) {
	// Single limit shared by every wait so no step can hang the test run.
	const waitLimit = 2 * time.Second

	fake := &fakeGatewayServer{
		streamStarted:    make(chan struct{}),
		streamDone:       make(chan struct{}),
		streamEventCount: 64,
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	events, err := session.EventsAfter(context.Background(), 0)
	if err != nil {
		t.Fatalf("EventsAfter() error = %v", err)
	}

	// Bound the start handshake as well: if the server handler is never
	// invoked, fail fast instead of blocking until the global test timeout.
	select {
	case <-fake.streamStarted:
	case <-time.After(waitLimit):
		t.Fatal("compatibility event stream did not start")
	}

	// The events channel is intentionally left unread here; the stream must
	// stop on its own once the result channel can take no more events.
	select {
	case <-fake.streamDone:
	case <-time.After(waitLimit):
		t.Fatal("compatibility event stream did not stop after result channel filled")
	}

	// Drain whatever was buffered and require the channel to be closed.
	for {
		select {
		case _, ok := <-events:
			if !ok {
				return
			}
		case <-time.After(waitLimit):
			t.Fatal("compatibility event channel did not close")
		}
	}
}
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
fake := &fakeGatewayServer{
invokeReply: &pb.MxCommandReply{
@@ -267,13 +301,14 @@ func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) {
type fakeGatewayServer struct {
pb.UnimplementedMxAccessGatewayServer
openReply *pb.OpenSessionReply
openAuth string
streamAuth string
streamStarted chan struct{}
streamDone chan struct{}
invokeReply *pb.MxCommandReply
invokeRequest *pb.MxCommandRequest
openReply *pb.OpenSessionReply
openAuth string
streamAuth string
streamStarted chan struct{}
streamDone chan struct{}
streamEventCount int
invokeReply *pb.MxCommandReply
invokeRequest *pb.MxCommandRequest
}
func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) {
@@ -320,12 +355,18 @@ func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grp
if s.streamStarted != nil {
close(s.streamStarted)
}
if err := stream.Send(&pb.MxEvent{
SessionId: req.GetSessionId(),
Family: pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
WorkerSequence: 1,
}); err != nil {
return err
eventCount := s.streamEventCount
if eventCount == 0 {
eventCount = 1
}
for sequence := 1; sequence <= eventCount; sequence++ {
if err := stream.Send(&pb.MxEvent{
SessionId: req.GetSessionId(),
Family: pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
WorkerSequence: uint64(sequence),
}); err != nil {
return err
}
}
<-stream.Context().Done()
return io.EOF
+31 -4
View File
@@ -418,7 +418,7 @@ func (s *Session) Events(ctx context.Context) (<-chan EventResult, error) {
// EventsAfter streams ordered session events after the given worker sequence.
func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (<-chan EventResult, error) {
subscription, err := s.SubscribeEventsAfter(ctx, afterWorkerSequence)
subscription, err := s.subscribeEventsAfter(ctx, afterWorkerSequence, true)
if err != nil {
return nil, err
}
@@ -432,6 +432,10 @@ func (s *Session) SubscribeEvents(ctx context.Context) (*EventSubscription, erro
// SubscribeEventsAfter opens an event subscription owned by the caller that
// resumes after afterWorkerSequence. It delegates to subscribeEventsAfter with
// the cancel-when-result-buffer-full option turned off.
func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence uint64) (*EventSubscription, error) {
	const cancelWhenResultBufferFull = false
	return s.subscribeEventsAfter(ctx, afterWorkerSequence, cancelWhenResultBufferFull)
}
func (s *Session) subscribeEventsAfter(ctx context.Context, afterWorkerSequence uint64, cancelWhenResultBufferFull bool) (*EventSubscription, error) {
streamCtx, cancel := context.WithCancel(ctx)
stream, err := s.client.StreamEventsRaw(streamCtx, &pb.StreamEventsRequest{
SessionId: s.ID(),
@@ -450,7 +454,7 @@ func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence
for {
event, err := stream.Recv()
if err == nil {
if !sendEventResult(streamCtx, results, EventResult{Event: event}) {
if !sendEventResult(streamCtx, results, EventResult{Event: event}, cancelWhenResultBufferFull, cancel) {
return
}
continue
@@ -458,7 +462,12 @@ func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence
if err == io.EOF || status.Code(err) == codes.Canceled || streamCtx.Err() != nil {
return
}
sendEventResult(streamCtx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}})
sendEventResult(
streamCtx,
results,
EventResult{Err: &GatewayError{Op: "stream events", Err: err}},
cancelWhenResultBufferFull,
cancel)
return
}
}()
@@ -477,7 +486,25 @@ func ensureBulkSize(name string, length int) error {
return nil
}
func sendEventResult(ctx context.Context, results chan<- EventResult, result EventResult) bool {
func sendEventResult(
ctx context.Context,
results chan<- EventResult,
result EventResult,
cancelWhenBufferFull bool,
cancel context.CancelFunc,
) bool {
if cancelWhenBufferFull {
select {
case results <- result:
return true
case <-ctx.Done():
return false
default:
cancel()
return false
}
}
select {
case results <- result:
return true
@@ -66,6 +66,8 @@ public sealed class EventStreamService(
{
await streamCts.CancelAsync().ConfigureAwait(false);
subscriber.Dispose();
Interlocked.Exchange(ref streamQueueDepth, 0);
metrics.SetGrpcEventStreamQueueDepth(0);
metrics.StreamDisconnected("Detached");
try
@@ -101,6 +101,17 @@ public sealed class GatewayMetrics : IDisposable
_sessionsClosedCounter.Add(1);
}
/// <summary>
/// Lowers the tracked open-session count by one, never letting it drop below zero.
/// Only the open-session count is modified here.
/// </summary>
public void SessionRemoved()
{
    lock (_syncRoot)
    {
        // Nothing to release; leave the count untouched rather than underflow.
        if (_openSessions <= 0)
        {
            return;
        }

        _openSessions -= 1;
    }
}
public void WorkerStarted(TimeSpan startupDuration)
{
lock (_syncRoot)
@@ -184,8 +184,11 @@ public sealed class SessionManager : ISessionManager
exception,
"Graceful shutdown failed for session {SessionId}; killing worker.",
session.SessionId);
session.KillWorker(GatewayShutdownReason);
await RemoveSessionAsync(session).ConfigureAwait(false);
if (_registry.TryGet(session.SessionId, out _))
{
session.KillWorker(GatewayShutdownReason);
await RemoveSessionAsync(session).ConfigureAwait(false);
}
}
}
}
@@ -210,7 +213,13 @@ public sealed class SessionManager : ISessionManager
catch (Exception exception)
{
session.MarkFaulted(exception.Message);
if (!wasClosed)
{
_metrics.SessionRemoved();
}
_metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString());
await RemoveSessionAsync(session).ConfigureAwait(false);
throw new SessionManagerException(
SessionManagerErrorCode.CloseFailed,
$"Failed to close session {session.SessionId}.",
@@ -85,6 +85,32 @@ public sealed class EventStreamServiceTests
await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 0);
}
// Verifies that disposing a subscriber while events are still buffered brings the
// reported gRPC event-stream queue depth metric back to zero.
[Fact]
public async Task StreamEventsAsync_WhenDisposedWithBufferedEvents_ResetsStreamQueueDepth()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
using GatewayMetrics metrics = new();
EventStreamService service = CreateService(
new FakeSessionManager(session),
metrics,
queueCapacity: 8);
// Three events are configured but only the first is consumed below, so the
// remaining ones are expected to show up as a non-zero queue depth.
workerClient.Events.Add(CreateWorkerEvent(sequence: 1, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 2, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 3, MxEventFamily.OnDataChange));
workerClient.CompleteAfterConfiguredEvents = true;
await using IAsyncEnumerator<MxEvent> subscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None)
.GetAsyncEnumerator();
// Read exactly one event, bounded by TestTimeout.
Assert.True(await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
// Precondition: the metric must report buffered events before the dispose.
await WaitUntilAsync(() => metrics.GetSnapshot().GrpcEventStreamQueueDepth > 0);
// Explicit dispose is the action under test; the `await using` above disposes
// again at scope exit. NOTE(review): relies on DisposeAsync tolerating a second call.
await subscriber.DisposeAsync();
await WaitUntilAsync(() => metrics.GetSnapshot().GrpcEventStreamQueueDepth == 0);
}
[Fact]
public async Task StreamEventsAsync_WhenStreamQueueOverflows_FaultsSessionAndReportsOverflow()
{
@@ -179,6 +179,48 @@ public sealed class SessionManagerTests
Assert.Equal(1, workerClient.KillCount);
}
// Verifies that when closing a session fails because the worker shutdown throws,
// the session is still removed from the registry, its worker is killed and
// disposed, and a new session can be opened afterwards.
[Fact]
public async Task CloseSessionAsync_WhenWorkerShutdownFails_RemovesSessionAndReleasesSlot()
{
// First worker is configured with a shutdown exception; the second is a normal
// replacement handed out for the follow-up session.
FakeWorkerClient failingWorkerClient = new()
{
ShutdownException = new WorkerClientException(
WorkerClientErrorCode.ShutdownTimeout,
"Worker shutdown timed out."),
};
FakeWorkerClient replacementWorkerClient = new();
SessionRegistry registry = new();
using GatewayMetrics metrics = new();
// maxSessions: 1 — presumably the second open below can only succeed if the
// failed close released the single slot; confirm against OpenSessionAsync.
SessionManager manager = CreateManager(
new QueueingSessionWorkerClientFactory(failingWorkerClient, replacementWorkerClient),
registry,
metrics,
CreateOptions(maxSessions: 1));
GatewaySession firstSession = await manager.OpenSessionAsync(
CreateOpenRequest(),
"client-1",
CancellationToken.None);
// Record a per-session event so the test can later check that the session's
// entry is gone from the metrics snapshot.
metrics.EventReceived(firstSession.SessionId, MxEventFamily.OnDataChange.ToString());
SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
async () => await manager.CloseSessionAsync(firstSession.SessionId, CancellationToken.None));
GatewaySession secondSession = await manager.OpenSessionAsync(
CreateOpenRequest(),
"client-2",
CancellationToken.None);
// The close surfaces CloseFailed, yet only the second session remains registered.
Assert.Equal(SessionManagerErrorCode.CloseFailed, exception.ErrorCode);
Assert.False(manager.TryGetSession(firstSession.SessionId, out _));
Assert.True(manager.TryGetSession(secondSession.SessionId, out _));
Assert.Equal(1, registry.Count);
// The failing worker must have been killed and disposed exactly once.
Assert.Equal(1, failingWorkerClient.KillCount);
Assert.Equal(1, failingWorkerClient.DisposeCount);
GatewayMetricsSnapshot snapshot = metrics.GetSnapshot();
// A failed close is not counted as a closed session, and the open-session
// count reflects only the second session.
Assert.Equal(0, snapshot.SessionsClosed);
Assert.False(snapshot.EventsBySession.ContainsKey(firstSession.SessionId));
Assert.Equal(1, snapshot.OpenSessions);
}
[Fact]
public async Task OpenSessionAsync_WhenWorkerCreationFails_RemovesSessionFromRegistry()
{
@@ -254,14 +296,14 @@ public sealed class SessionManagerTests
metrics ?? new GatewayMetrics());
}
private static GatewayOptions CreateOptions()
private static GatewayOptions CreateOptions(int maxSessions = 64)
{
return new GatewayOptions
{
Sessions = new SessionOptions
{
DefaultCommandTimeoutSeconds = 30,
MaxSessions = 64,
MaxSessions = maxSessions,
},
Worker = new WorkerOptions
{
@@ -359,6 +401,8 @@ public sealed class SessionManagerTests
public int KillCount { get; private set; }
public int DisposeCount { get; private set; }
public Exception? ShutdownException { get; init; }
public WorkerCommand? LastCommand { get; private set; }
@@ -424,6 +468,7 @@ public sealed class SessionManagerTests
/// <summary>
/// Records that the fake was disposed and completes synchronously.
/// </summary>
public ValueTask DisposeAsync()
{
    // Counted so tests can assert how many times disposal happened.
    DisposeCount += 1;

    return ValueTask.CompletedTask;
}
}
@@ -343,6 +343,45 @@ public sealed class WorkerPipeSessionTests
await runTask;
}
// Verifies that a command which throws after shutdown has been requested does not
// produce a fault envelope: the first envelope the gateway reads after sending the
// shutdown request must be an Ok shutdown acknowledgement, and RunAsync must finish.
[Fact]
public async Task RunAsync_WhenCommandThrowsAfterShutdown_DropsLateFaultAndWritesShutdownAck()
{
// Overall 5-second budget for the whole exchange.
using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5));
using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token);
// BlockDispatch presumably holds the command inside dispatch until released
// (verify in FakeRuntimeSession); ThrowAfterDispatchReleased makes the fake
// throw once dispatch continues.
FakeRuntimeSession runtime = new()
{
BlockDispatch = true,
ThrowAfterDispatchReleased = true,
};
WorkerPipeSession session = CreatePipeSession(
pipePair.WorkerStream,
runtime,
new WorkerPipeSessionOptions
{
HeartbeatInterval = TimeSpan.FromSeconds(1),
HeartbeatGrace = TimeSpan.FromSeconds(5),
});
Task runTask = session.RunAsync(cancellation.Token);
await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token);
await pipePair.GatewayWriter.WriteAsync(
CreateCommandEnvelope("command-fails-during-shutdown"),
cancellation.Token);
// Wait until the command is actually in dispatch before requesting shutdown,
// so the failure happens after shutdown has begun.
Assert.True(runtime.DispatchStarted.Wait(TimeSpan.FromSeconds(2)));
await pipePair.GatewayWriter
.WriteAsync(CreateShutdownEnvelope(), cancellation.Token);
// The very next envelope must be the shutdown ack, not a fault.
WorkerEnvelope firstEnvelopeAfterShutdown = await pipePair.GatewayReader
.ReadAsync(cancellation.Token);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerShutdownAck, firstEnvelopeAfterShutdown.BodyCase);
Assert.Equal(ProtocolStatusCode.Ok, firstEnvelopeAfterShutdown.WorkerShutdownAck.Status.Code);
// RunAsync must complete within 2 seconds; awaiting it afterwards surfaces any exception.
Task completedTask = await Task.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), cancellation.Token));
Assert.Same(runTask, completedTask);
await runTask;
}
private static WorkerPipeSession CreateSession(
Stream inbound,
Stream outbound,
@@ -574,6 +613,8 @@ public sealed class WorkerPipeSessionTests
public bool BlockDispatch { get; set; }
public bool ThrowAfterDispatchReleased { get; set; }
public Task<WorkerReady> StartAsync(
string sessionId,
int workerProcessId,
@@ -613,6 +654,11 @@ public sealed class WorkerPipeSessionTests
lastEventSequence: 0,
currentCommandCorrelationId: string.Empty));
if (ThrowAfterDispatchReleased)
{
throw new InvalidOperationException("Command failed after shutdown started.");
}
return new MxCommandReply
{
SessionId = command.SessionId,
@@ -386,6 +386,11 @@ public sealed class WorkerPipeSession
}
catch (Exception exception) when (exception is not OperationCanceledException)
{
if (_state is not WorkerState.Ready and not WorkerState.ExecutingCommand)
{
return;
}
_state = WorkerState.Faulted;
await TryWriteFaultAsync(
CreateFault(