From b0041c5d1811b91b3603b5ec01398dc28fe4b635 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Tue, 28 Apr 2026 06:27:01 -0400 Subject: [PATCH] Fix reliability findings --- clients/go/mxgateway/client_session_test.go | 67 +++++++++++++++---- clients/go/mxgateway/session.go | 35 ++++++++-- .../Grpc/EventStreamService.cs | 2 + .../Metrics/GatewayMetrics.cs | 11 +++ .../Sessions/SessionManager.cs | 13 +++- .../Gateway/Grpc/EventStreamServiceTests.cs | 26 +++++++ .../Gateway/Sessions/SessionManagerTests.cs | 49 +++++++++++++- .../Ipc/WorkerPipeSessionTests.cs | 46 +++++++++++++ src/MxGateway.Worker/Ipc/WorkerPipeSession.cs | 5 ++ 9 files changed, 233 insertions(+), 21 deletions(-) diff --git a/clients/go/mxgateway/client_session_test.go b/clients/go/mxgateway/client_session_test.go index 46ffb5f..b1577b4 100644 --- a/clients/go/mxgateway/client_session_test.go +++ b/clients/go/mxgateway/client_session_test.go @@ -113,6 +113,40 @@ func TestEventSubscriptionCloseStopsStream(t *testing.T) { } } +func TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned(t *testing.T) { + fake := &fakeGatewayServer{ + streamStarted: make(chan struct{}), + streamDone: make(chan struct{}), + streamEventCount: 64, + } + client, cleanup := newBufconnClient(t, fake) + defer cleanup() + session := NewSessionForID(client, "session-1") + + events, err := session.EventsAfter(context.Background(), 0) + if err != nil { + t.Fatalf("EventsAfter() error = %v", err) + } + <-fake.streamStarted + + select { + case <-fake.streamDone: + case <-time.After(2 * time.Second): + t.Fatal("compatibility event stream did not stop after result channel filled") + } + + for { + select { + case _, ok := <-events: + if !ok { + return + } + case <-time.After(2 * time.Second): + t.Fatal("compatibility event channel did not close") + } + } +} + func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) { fake := &fakeGatewayServer{ invokeReply: &pb.MxCommandReply{ @@ -267,13 +301,14 @@ func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) { type fakeGatewayServer struct { pb.UnimplementedMxAccessGatewayServer - openReply *pb.OpenSessionReply - openAuth string - streamAuth string - streamStarted chan struct{} - streamDone chan struct{} - invokeReply *pb.MxCommandReply - invokeRequest *pb.MxCommandRequest + openReply *pb.OpenSessionReply + openAuth string + streamAuth string + streamStarted chan struct{} + streamDone chan struct{} + streamEventCount int + invokeReply *pb.MxCommandReply + invokeRequest *pb.MxCommandRequest } func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) { @@ -320,12 +355,18 @@ func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grp if s.streamStarted != nil { close(s.streamStarted) } - if err := stream.Send(&pb.MxEvent{ - SessionId: req.GetSessionId(), - Family: pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE, - WorkerSequence: 1, - }); err != nil { - return err + eventCount := s.streamEventCount + if eventCount == 0 { + eventCount = 1 + } + for sequence := 1; sequence <= eventCount; sequence++ { + if err := stream.Send(&pb.MxEvent{ + SessionId: req.GetSessionId(), + Family: pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE, + WorkerSequence: uint64(sequence), + }); err != nil { + return err + } } <-stream.Context().Done() return io.EOF diff --git a/clients/go/mxgateway/session.go b/clients/go/mxgateway/session.go index 81be344..4a6099c 100644 --- a/clients/go/mxgateway/session.go +++ b/clients/go/mxgateway/session.go @@ -418,7 +418,7 @@ func (s *Session) Events(ctx context.Context) (<-chan EventResult, error) { // EventsAfter streams ordered session events after the given worker sequence. func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (<-chan EventResult, error) { - subscription, err := s.SubscribeEventsAfter(ctx, afterWorkerSequence) + subscription, err := s.subscribeEventsAfter(ctx, afterWorkerSequence, true) if err != nil { return nil, err } @@ -432,6 +432,10 @@ func (s *Session) SubscribeEvents(ctx context.Context) (*EventSubscription, erro // SubscribeEventsAfter starts an owned event subscription after the given worker sequence. func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence uint64) (*EventSubscription, error) { + return s.subscribeEventsAfter(ctx, afterWorkerSequence, false) +} + +func (s *Session) subscribeEventsAfter(ctx context.Context, afterWorkerSequence uint64, cancelWhenResultBufferFull bool) (*EventSubscription, error) { streamCtx, cancel := context.WithCancel(ctx) stream, err := s.client.StreamEventsRaw(streamCtx, &pb.StreamEventsRequest{ SessionId: s.ID(), @@ -450,7 +454,7 @@ func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence for { event, err := stream.Recv() if err == nil { - if !sendEventResult(streamCtx, results, EventResult{Event: event}) { + if !sendEventResult(streamCtx, results, EventResult{Event: event}, cancelWhenResultBufferFull, cancel) { return } continue @@ -458,7 +462,12 @@ func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence if err == io.EOF || status.Code(err) == codes.Canceled || streamCtx.Err() != nil { return } - sendEventResult(streamCtx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}}) + sendEventResult( + streamCtx, + results, + EventResult{Err: &GatewayError{Op: "stream events", Err: err}}, + cancelWhenResultBufferFull, + cancel) return } }() @@ -477,7 +486,25 @@ func ensureBulkSize(name string, length int) error { return nil } -func sendEventResult(ctx context.Context, results chan<- EventResult, result EventResult) bool { +func sendEventResult( + ctx context.Context, + results chan<- EventResult, + result EventResult, + cancelWhenBufferFull bool, + cancel context.CancelFunc, +) bool { + if cancelWhenBufferFull { + select { + case results <- result: + return true + case <-ctx.Done(): + return false + default: + cancel() + return false + } + } + select { case results <- result: return true diff --git a/src/MxGateway.Server/Grpc/EventStreamService.cs b/src/MxGateway.Server/Grpc/EventStreamService.cs index 8d8249f..db17452 100644 --- a/src/MxGateway.Server/Grpc/EventStreamService.cs +++ b/src/MxGateway.Server/Grpc/EventStreamService.cs @@ -66,6 +66,8 @@ public sealed class EventStreamService( { await streamCts.CancelAsync().ConfigureAwait(false); subscriber.Dispose(); + Interlocked.Exchange(ref streamQueueDepth, 0); + metrics.SetGrpcEventStreamQueueDepth(0); metrics.StreamDisconnected("Detached"); try diff --git a/src/MxGateway.Server/Metrics/GatewayMetrics.cs b/src/MxGateway.Server/Metrics/GatewayMetrics.cs index 96e9412..928d379 100644 --- a/src/MxGateway.Server/Metrics/GatewayMetrics.cs +++ b/src/MxGateway.Server/Metrics/GatewayMetrics.cs @@ -101,6 +101,17 @@ public sealed class GatewayMetrics : IDisposable _sessionsClosedCounter.Add(1); } + public void SessionRemoved() + { + lock (_syncRoot) + { + if (_openSessions > 0) + { + _openSessions--; + } + } + } + public void WorkerStarted(TimeSpan startupDuration) { lock (_syncRoot) diff --git a/src/MxGateway.Server/Sessions/SessionManager.cs b/src/MxGateway.Server/Sessions/SessionManager.cs index fc116c8..694e639 100644 --- a/src/MxGateway.Server/Sessions/SessionManager.cs +++ b/src/MxGateway.Server/Sessions/SessionManager.cs @@ -184,8 +184,11 @@ public sealed class SessionManager : ISessionManager exception, "Graceful shutdown failed for session {SessionId}; killing worker.", session.SessionId); - session.KillWorker(GatewayShutdownReason); - await RemoveSessionAsync(session).ConfigureAwait(false); + if (_registry.TryGet(session.SessionId, out _)) + { + session.KillWorker(GatewayShutdownReason); + await RemoveSessionAsync(session).ConfigureAwait(false); + } } } } @@ -210,7 +213,13 @@ public sealed class SessionManager : ISessionManager catch (Exception exception) { session.MarkFaulted(exception.Message); + if (!wasClosed) + { + _metrics.SessionRemoved(); + } + _metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString()); + await RemoveSessionAsync(session).ConfigureAwait(false); throw new SessionManagerException( SessionManagerErrorCode.CloseFailed, $"Failed to close session {session.SessionId}.", diff --git a/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs b/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs index d0d864f..ff38b45 100644 --- a/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs +++ b/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs @@ -85,6 +85,32 @@ public sealed class EventStreamServiceTests await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 0); } + [Fact] + public async Task StreamEventsAsync_WhenDisposedWithBufferedEvents_ResetsStreamQueueDepth() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + using GatewayMetrics metrics = new(); + EventStreamService service = CreateService( + new FakeSessionManager(session), + metrics, + queueCapacity: 8); + workerClient.Events.Add(CreateWorkerEvent(sequence: 1, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 2, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 3, MxEventFamily.OnDataChange)); + workerClient.CompleteAfterConfiguredEvents = true; + await using IAsyncEnumerator subscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None) + .GetAsyncEnumerator(); + + Assert.True(await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + await WaitUntilAsync(() => metrics.GetSnapshot().GrpcEventStreamQueueDepth > 0); + + await subscriber.DisposeAsync(); + + await WaitUntilAsync(() => metrics.GetSnapshot().GrpcEventStreamQueueDepth == 0); + } + [Fact] public async Task StreamEventsAsync_WhenStreamQueueOverflows_FaultsSessionAndReportsOverflow() { diff --git a/src/MxGateway.Tests/Gateway/Sessions/SessionManagerTests.cs b/src/MxGateway.Tests/Gateway/Sessions/SessionManagerTests.cs index 13ec3bc..52efdd3 100644 --- a/src/MxGateway.Tests/Gateway/Sessions/SessionManagerTests.cs +++ b/src/MxGateway.Tests/Gateway/Sessions/SessionManagerTests.cs @@ -179,6 +179,48 @@ public sealed class SessionManagerTests Assert.Equal(1, workerClient.KillCount); } + [Fact] + public async Task CloseSessionAsync_WhenWorkerShutdownFails_RemovesSessionAndReleasesSlot() + { + FakeWorkerClient failingWorkerClient = new() + { + ShutdownException = new WorkerClientException( + WorkerClientErrorCode.ShutdownTimeout, + "Worker shutdown timed out."), + }; + FakeWorkerClient replacementWorkerClient = new(); + SessionRegistry registry = new(); + using GatewayMetrics metrics = new(); + SessionManager manager = CreateManager( + new QueueingSessionWorkerClientFactory(failingWorkerClient, replacementWorkerClient), + registry, + metrics, + CreateOptions(maxSessions: 1)); + GatewaySession firstSession = await manager.OpenSessionAsync( + CreateOpenRequest(), + "client-1", + CancellationToken.None); + metrics.EventReceived(firstSession.SessionId, MxEventFamily.OnDataChange.ToString()); + + SessionManagerException exception = await Assert.ThrowsAsync( + async () => await manager.CloseSessionAsync(firstSession.SessionId, CancellationToken.None)); + GatewaySession secondSession = await manager.OpenSessionAsync( + CreateOpenRequest(), + "client-2", + CancellationToken.None); + + Assert.Equal(SessionManagerErrorCode.CloseFailed, exception.ErrorCode); + Assert.False(manager.TryGetSession(firstSession.SessionId, out _)); + Assert.True(manager.TryGetSession(secondSession.SessionId, out _)); + Assert.Equal(1, registry.Count); + Assert.Equal(1, failingWorkerClient.KillCount); + Assert.Equal(1, failingWorkerClient.DisposeCount); + GatewayMetricsSnapshot snapshot = metrics.GetSnapshot(); + Assert.Equal(0, snapshot.SessionsClosed); + Assert.False(snapshot.EventsBySession.ContainsKey(firstSession.SessionId)); + Assert.Equal(1, snapshot.OpenSessions); + } + [Fact] public async Task OpenSessionAsync_WhenWorkerCreationFails_RemovesSessionFromRegistry() { @@ -254,14 +296,14 @@ public sealed class SessionManagerTests metrics ?? new GatewayMetrics()); } - private static GatewayOptions CreateOptions() + private static GatewayOptions CreateOptions(int maxSessions = 64) { return new GatewayOptions { Sessions = new SessionOptions { DefaultCommandTimeoutSeconds = 30, - MaxSessions = 64, + MaxSessions = maxSessions, }, Worker = new WorkerOptions { @@ -359,6 +401,8 @@ public sealed class SessionManagerTests public int KillCount { get; private set; } + public int DisposeCount { get; private set; } + public Exception? ShutdownException { get; init; } public WorkerCommand? LastCommand { get; private set; } @@ -424,6 +468,7 @@ public sealed class SessionManagerTests public ValueTask DisposeAsync() { + DisposeCount++; return ValueTask.CompletedTask; } } diff --git a/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs index c05e574..d8188cc 100644 --- a/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs +++ b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs @@ -343,6 +343,45 @@ public sealed class WorkerPipeSessionTests await runTask; } + [Fact] + public async Task RunAsync_WhenCommandThrowsAfterShutdown_DropsLateFaultAndWritesShutdownAck() + { + using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5)); + using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token); + FakeRuntimeSession runtime = new() + { + BlockDispatch = true, + ThrowAfterDispatchReleased = true, + }; + WorkerPipeSession session = CreatePipeSession( + pipePair.WorkerStream, + runtime, + new WorkerPipeSessionOptions + { + HeartbeatInterval = TimeSpan.FromSeconds(1), + HeartbeatGrace = TimeSpan.FromSeconds(5), + }); + Task runTask = session.RunAsync(cancellation.Token); + await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token); + + await pipePair.GatewayWriter.WriteAsync( + CreateCommandEnvelope("command-fails-during-shutdown"), + cancellation.Token); + Assert.True(runtime.DispatchStarted.Wait(TimeSpan.FromSeconds(2))); + + await pipePair.GatewayWriter + .WriteAsync(CreateShutdownEnvelope(), cancellation.Token); + + WorkerEnvelope firstEnvelopeAfterShutdown = await pipePair.GatewayReader + .ReadAsync(cancellation.Token); + + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerShutdownAck, firstEnvelopeAfterShutdown.BodyCase); + Assert.Equal(ProtocolStatusCode.Ok, firstEnvelopeAfterShutdown.WorkerShutdownAck.Status.Code); + Task completedTask = await Task.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), cancellation.Token)); + Assert.Same(runTask, completedTask); + await runTask; + } + private static WorkerPipeSession CreateSession( Stream inbound, Stream outbound, @@ -574,6 +613,8 @@ public sealed class WorkerPipeSessionTests public bool BlockDispatch { get; set; } + public bool ThrowAfterDispatchReleased { get; set; } + public Task StartAsync( string sessionId, int workerProcessId, @@ -613,6 +654,11 @@ public sealed class WorkerPipeSessionTests lastEventSequence: 0, currentCommandCorrelationId: string.Empty)); + if (ThrowAfterDispatchReleased) + { + throw new InvalidOperationException("Command failed after shutdown started."); + } + return new MxCommandReply { SessionId = command.SessionId, diff --git a/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs b/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs index cdbea79..a42cc7f 100644 --- a/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs +++ b/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs @@ -386,6 +386,11 @@ public sealed class WorkerPipeSession } catch (Exception exception) when (exception is not OperationCanceledException) { + if (_state is not WorkerState.Ready and not WorkerState.ExecutingCommand) + { + return; + } + _state = WorkerState.Faulted; await TryWriteFaultAsync( CreateFault(