Implement graceful worker shutdown

This commit is contained in:
Joseph Doherty
2026-04-26 19:36:22 -04:00
parent 95e71cd819
commit d890eff862
15 changed files with 694 additions and 11 deletions
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
@@ -147,6 +148,29 @@ public sealed class WorkerPipeSessionTests
Assert.Equal(ProtocolStatusCode.WorkerUnavailable, written[1].WorkerFault.ProtocolStatus.Code);
}
[Fact]
public async Task RunAsync_WithWorkerShutdown_WritesShutdownAckAndReturns()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
WorkerFrameWriter inboundWriter = new(inbound, options);
await inboundWriter.WriteAsync(CreateGatewayHelloEnvelope());
await inboundWriter.WriteAsync(CreateWorkerShutdownEnvelope());
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
await session.CompleteStartupHandshakeAsync(_ => Task.CompletedTask);
await session.RunAsync();
WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
Assert.Equal(3, written.Length);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerShutdownAck, written[2].BodyCase);
Assert.Equal(ProtocolStatusCode.Ok, written[2].WorkerShutdownAck.Status.Code);
}
private static WorkerPipeSession CreateSession(
Stream inbound,
Stream outbound,
@@ -185,6 +209,21 @@ public sealed class WorkerPipeSessionTests
};
}
private static WorkerEnvelope CreateWorkerShutdownEnvelope()
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = 2,
WorkerShutdown = new WorkerShutdown
{
GracePeriod = Google.Protobuf.WellKnownTypes.Duration.FromTimeSpan(TimeSpan.FromSeconds(1)),
Reason = "test-shutdown",
},
};
}
private static WorkerEnvelope[] ReadWrittenFrames(
MemoryStream stream,
WorkerFrameProtocolOptions options)
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
@@ -414,6 +416,57 @@ public sealed class MxAccessCommandExecutorTests
Assert.Equal(MxAccessAdviceKind.Plain, adviceHandle.AdviceKind);
}
[Fact]
public async Task ShutdownGracefullyAsync_CleansHandlesInAdviceItemServerOrder()
{
FakeMxAccessComObject fakeComObject = new(
registerHandle: 58,
addItemHandle: 510);
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register-before-shutdown", "client-a"));
await session.DispatchAsync(CreateAddItemCommand("add-before-shutdown", 58, "Galaxy.Tag.Value"));
await session.DispatchAsync(CreateAdviseCommand("advise-before-shutdown", 58, 510));
await session.DispatchAsync(CreateAdviseSupervisoryCommand("supervisory-before-shutdown", 58, 510));
MxAccessShutdownResult result = await session.ShutdownGracefullyAsync(TimeSpan.FromSeconds(2));
Assert.True(result.Succeeded);
Assert.Equal(
new[] { "UnAdvise:58:510", "RemoveItem:58:510", "Unregister:58" },
fakeComObject.OperationNames.Where(name => name.StartsWith("Un", StringComparison.Ordinal)
|| name.StartsWith("Remove", StringComparison.Ordinal)));
}
[Fact]
public async Task ShutdownGracefullyAsync_RecordsCleanupFailuresAndContinues()
{
const int hresult = unchecked((int)0x80070057);
COMException cleanupException = new("Invalid handle.", hresult);
FakeMxAccessComObject fakeComObject = new(
registerHandle: 59,
addItemHandle: 511,
unregisterException: cleanupException,
removeItemException: cleanupException,
unAdviseException: cleanupException);
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register-before-shutdown-failure", "client-a"));
await session.DispatchAsync(CreateAddItemCommand("add-before-shutdown-failure", 59, "Galaxy.Tag.Value"));
await session.DispatchAsync(CreateAdviseCommand("advise-before-shutdown-failure", 59, 511));
MxAccessShutdownResult result = await session.ShutdownGracefullyAsync(TimeSpan.FromSeconds(2));
Assert.False(result.Succeeded);
Assert.Equal(new[] { "UnAdvise", "RemoveItem", "Unregister" }, result.Failures.Select(failure => failure.Operation));
Assert.All(result.Failures, failure => Assert.Equal(hresult, failure.HResult));
Assert.Contains("Unregister:59", fakeComObject.OperationNames);
}
[Fact]
public async Task DispatchAsync_RegisterWithoutPayload_ReturnsInvalidRequest()
{
@@ -644,6 +697,7 @@ public sealed class MxAccessCommandExecutorTests
private readonly Exception? adviseException;
private readonly Exception? unAdviseException;
private readonly Exception? adviseSupervisoryException;
private readonly List<string> operationNames = new();
public FakeMxAccessComObject(
int registerHandle,
@@ -715,8 +769,11 @@ public sealed class MxAccessCommandExecutorTests
public int? AdviseSupervisoryThreadId { get; private set; }
public IReadOnlyList<string> OperationNames => operationNames.ToArray();
public int Register(string clientName)
{
operationNames.Add($"Register:{clientName}");
RegisteredClientName = clientName;
RegisterThreadId = Environment.CurrentManagedThreadId;
@@ -725,6 +782,7 @@ public sealed class MxAccessCommandExecutorTests
public void Unregister(int serverHandle)
{
operationNames.Add($"Unregister:{serverHandle}");
UnregisteredServerHandle = serverHandle;
UnregisterThreadId = Environment.CurrentManagedThreadId;
@@ -738,6 +796,7 @@ public sealed class MxAccessCommandExecutorTests
int serverHandle,
string itemDefinition)
{
operationNames.Add($"AddItem:{serverHandle}:{itemDefinition}");
AddItemServerHandle = serverHandle;
AddItemDefinition = itemDefinition;
AddItemThreadId = Environment.CurrentManagedThreadId;
@@ -755,6 +814,7 @@ public sealed class MxAccessCommandExecutorTests
string itemDefinition,
string itemContext)
{
operationNames.Add($"AddItem2:{serverHandle}:{itemDefinition}:{itemContext}");
AddItem2ServerHandle = serverHandle;
AddItem2Definition = itemDefinition;
AddItem2Context = itemContext;
@@ -772,6 +832,7 @@ public sealed class MxAccessCommandExecutorTests
int serverHandle,
int itemHandle)
{
operationNames.Add($"RemoveItem:{serverHandle}:{itemHandle}");
RemoveItemServerHandle = serverHandle;
RemovedItemHandle = itemHandle;
RemoveItemThreadId = Environment.CurrentManagedThreadId;
@@ -786,6 +847,7 @@ public sealed class MxAccessCommandExecutorTests
int serverHandle,
int itemHandle)
{
operationNames.Add($"Advise:{serverHandle}:{itemHandle}");
AdviseServerHandle = serverHandle;
AdvisedItemHandle = itemHandle;
AdviseThreadId = Environment.CurrentManagedThreadId;
@@ -800,6 +862,7 @@ public sealed class MxAccessCommandExecutorTests
int serverHandle,
int itemHandle)
{
operationNames.Add($"UnAdvise:{serverHandle}:{itemHandle}");
UnAdviseServerHandle = serverHandle;
UnAdvisedItemHandle = itemHandle;
UnAdviseThreadId = Environment.CurrentManagedThreadId;
@@ -814,6 +877,7 @@ public sealed class MxAccessCommandExecutorTests
int serverHandle,
int itemHandle)
{
operationNames.Add($"AdviseSupervisory:{serverHandle}:{itemHandle}");
AdviseSupervisoryServerHandle = serverHandle;
AdviseSupervisoryItemHandle = itemHandle;
AdviseSupervisoryThreadId = Environment.CurrentManagedThreadId;
@@ -110,6 +110,27 @@ public sealed class StaCommandDispatcherTests
Assert.Equal("correlation-1", reply.CorrelationId);
}
[Fact]
public async Task RequestShutdown_RejectsQueuedCommandButLetsCurrentCommandFinish()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> current = dispatcher.DispatchAsync(CreateCommand("current", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
Task<MxCommandReply> pending = dispatcher.DispatchAsync(CreateCommand("pending", MxCommandKind.AddItem));
dispatcher.RequestShutdown();
MxCommandReply pendingReply = await pending;
executor.Release();
MxCommandReply currentReply = await current;
Assert.Equal(ProtocolStatusCode.WorkerUnavailable, pendingReply.ProtocolStatus.Code);
Assert.Equal(ProtocolStatusCode.Ok, currentReply.ProtocolStatus.Code);
Assert.Equal(new[] { "current" }, executor.CorrelationIds);
}
[Fact]
public async Task PopulateHeartbeat_ReportsCurrentCorrelationAndPendingCount()
{