From 14419853c7bd7fe28ed3ca973adf74f69cfb980d Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 17:49:01 -0400 Subject: [PATCH] Issue #25: implement sta command dispatcher --- .../Sta/StaCommandDispatcherTests.cs | 279 ++++++++++++++++++ .../Sta/IStaCommandExecutor.cs | 8 + src/MxGateway.Worker/Sta/StaCommand.cs | 47 +++ .../Sta/StaCommandDispatcher.cs | 267 +++++++++++++++++ 4 files changed, 601 insertions(+) create mode 100644 src/MxGateway.Worker.Tests/Sta/StaCommandDispatcherTests.cs create mode 100644 src/MxGateway.Worker/Sta/IStaCommandExecutor.cs create mode 100644 src/MxGateway.Worker/Sta/StaCommand.cs create mode 100644 src/MxGateway.Worker/Sta/StaCommandDispatcher.cs diff --git a/src/MxGateway.Worker.Tests/Sta/StaCommandDispatcherTests.cs b/src/MxGateway.Worker.Tests/Sta/StaCommandDispatcherTests.cs new file mode 100644 index 0000000..f4ee5a5 --- /dev/null +++ b/src/MxGateway.Worker.Tests/Sta/StaCommandDispatcherTests.cs @@ -0,0 +1,279 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.Sta; + +namespace MxGateway.Worker.Tests.Sta; + +public sealed class StaCommandDispatcherTests +{ + [Fact] + public async Task DispatchAsync_ExecutesCommandsOnStaInQueueOrder() + { + using StaRuntime runtime = CreateRuntime(); + runtime.Start(); + RecordingCommandExecutor executor = new(); + StaCommandDispatcher dispatcher = new(runtime, executor); + + Task first = dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register)); + Task second = dispatcher.DispatchAsync(CreateCommand("correlation-2", MxCommandKind.AddItem)); + + MxCommandReply[] replies = await Task.WhenAll(first, second); + + Assert.Equal(new[] { "correlation-1", "correlation-2" }, executor.CorrelationIds); + Assert.All(executor.ThreadIds, threadId => Assert.Equal(runtime.StaThreadId, threadId)); + Assert.Equal("correlation-1", replies[0].CorrelationId); + Assert.Equal("correlation-2", replies[1].CorrelationId); + Assert.Equal(ProtocolStatusCode.Ok, replies[0].ProtocolStatus.Code); + } + + [Fact] + public async Task DispatchAsync_WhenExecutorThrows_ReturnsFailureReplyWithHResult() + { + using StaRuntime runtime = CreateRuntime(); + runtime.Start(); + StaCommandDispatcher dispatcher = new( + runtime, + new ThrowingCommandExecutor(new COMException("provider detail", unchecked((int)0x80070057)))); + + MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register)); + + Assert.Equal("session-1", reply.SessionId); + Assert.Equal("correlation-1", reply.CorrelationId); + Assert.Equal(MxCommandKind.Register, reply.Kind); + Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code); + Assert.Equal(unchecked((int)0x80070057), reply.Hresult); + Assert.Contains("0x80070057", reply.DiagnosticMessage); + Assert.DoesNotContain("provider detail", reply.DiagnosticMessage); + } + + [Fact] + public async Task DispatchAsync_WhenCanceledBeforeExecution_ReturnsCanceledReplyWithoutExecuting() + { + using StaRuntime runtime = CreateRuntime(); + runtime.Start(); + BlockingCommandExecutor executor = new(); + StaCommandDispatcher dispatcher = new(runtime, executor); + Task blocked = dispatcher.DispatchAsync(CreateCommand("blocked", MxCommandKind.Register)); + Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2))); + + using CancellationTokenSource cancellation = new(); + Task canceled = dispatcher.DispatchAsync( + CreateCommand("canceled", MxCommandKind.AddItem, cancellation.Token)); + cancellation.Cancel(); + + executor.Release(); + MxCommandReply canceledReply = await canceled; + await blocked; + + Assert.Equal(ProtocolStatusCode.Canceled, canceledReply.ProtocolStatus.Code); + Assert.DoesNotContain("canceled", executor.CorrelationIds); + } + + [Fact] + public async Task DispatchAsync_WhenCanceledAfterExecutionStarts_StillReturnsLateReply() + { + using StaRuntime runtime = CreateRuntime(); + runtime.Start(); + BlockingCommandExecutor executor = new(); + StaCommandDispatcher dispatcher = new(runtime, executor); + using CancellationTokenSource cancellation = new(); + + Task replyTask = dispatcher.DispatchAsync( + CreateCommand("late-reply", MxCommandKind.Register, cancellation.Token)); + + Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2))); + cancellation.Cancel(); + executor.Release(); + + MxCommandReply reply = await replyTask; + + Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code); + Assert.Contains("late-reply", executor.CorrelationIds); + } + + [Fact] + public async Task DispatchAsync_WhenShutdownRequested_RejectsNewCommands() + { + using StaRuntime runtime = CreateRuntime(); + runtime.Start(); + StaCommandDispatcher dispatcher = new(runtime, new RecordingCommandExecutor()); + + dispatcher.RequestShutdown(); + MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register)); + + Assert.Equal(ProtocolStatusCode.WorkerUnavailable, reply.ProtocolStatus.Code); + Assert.Equal("correlation-1", reply.CorrelationId); + } + + [Fact] + public async Task PopulateHeartbeat_ReportsCurrentCorrelationAndPendingCount() + { + using StaRuntime runtime = CreateRuntime(); + runtime.Start(); + BlockingCommandExecutor executor = new(); + StaCommandDispatcher dispatcher = new(runtime, executor); + + Task current = dispatcher.DispatchAsync(CreateCommand("current", MxCommandKind.Register)); + Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2))); + Task pending = dispatcher.DispatchAsync(CreateCommand("pending", MxCommandKind.AddItem)); + + WorkerHeartbeat heartbeat = new(); + dispatcher.PopulateHeartbeat(heartbeat); + + Assert.Equal("current", heartbeat.CurrentCommandCorrelationId); + Assert.Equal(1u, heartbeat.PendingCommandCount); + + executor.Release(); + await Task.WhenAll(current, pending); + } + + private static StaCommand CreateCommand( + string correlationId, + MxCommandKind kind, + CancellationToken cancellationToken = default) + { + return new StaCommand( + "session-1", + correlationId, + new MxCommand + { + Kind = kind, + Ping = new PingCommand + { + Message = correlationId, + }, + }, + cancellationToken: cancellationToken); + } + + private static StaRuntime CreateRuntime() + { + return new StaRuntime( + new NoopComApartmentInitializer(), + new StaMessagePump(), + TimeSpan.FromMilliseconds(25)); + } + + private sealed class RecordingCommandExecutor : IStaCommandExecutor + { + private readonly object gate = new(); + private readonly List correlationIds = new(); + private readonly List threadIds = new(); + + public IReadOnlyList CorrelationIds + { + get + { + lock (gate) + { + return correlationIds.ToArray(); + } + } + } + + public IReadOnlyList ThreadIds + { + get + { + lock (gate) + { + return threadIds.ToArray(); + } + } + } + + public MxCommandReply Execute(StaCommand command) + { + lock (gate) + { + correlationIds.Add(command.CorrelationId); + threadIds.Add(Thread.CurrentThread.ManagedThreadId); + } + + return new MxCommandReply + { + ProtocolStatus = new ProtocolStatus + { + Code = ProtocolStatusCode.Ok, + Message = "OK", + }, + }; + } + } + + private sealed class BlockingCommandExecutor : IStaCommandExecutor + { + private readonly ManualResetEventSlim release = new(false); + private readonly object gate = new(); + private readonly List correlationIds = new(); + + public ManualResetEventSlim Started { get; } = new(false); + + public IReadOnlyList CorrelationIds + { + get + { + lock (gate) + { + return correlationIds.ToArray(); + } + } + } + + public MxCommandReply Execute(StaCommand command) + { + lock (gate) + { + correlationIds.Add(command.CorrelationId); + } + + Started.Set(); + release.Wait(TimeSpan.FromSeconds(5)); + + return new MxCommandReply + { + ProtocolStatus = new ProtocolStatus + { + Code = ProtocolStatusCode.Ok, + Message = "OK", + }, + }; + } + + public void Release() + { + release.Set(); + } + } + + private sealed class ThrowingCommandExecutor : IStaCommandExecutor + { + private readonly Exception exception; + + public ThrowingCommandExecutor(Exception exception) + { + this.exception = exception; + } + + public MxCommandReply Execute(StaCommand command) + { + throw exception; + } + } + + private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer + { + public void Initialize() + { + } + + public void Uninitialize() + { + } + } +} diff --git a/src/MxGateway.Worker/Sta/IStaCommandExecutor.cs b/src/MxGateway.Worker/Sta/IStaCommandExecutor.cs new file mode 100644 index 0000000..f47a2c8 --- /dev/null +++ b/src/MxGateway.Worker/Sta/IStaCommandExecutor.cs @@ -0,0 +1,8 @@ +using MxGateway.Contracts.Proto; + +namespace MxGateway.Worker.Sta; + +public interface IStaCommandExecutor +{ + MxCommandReply Execute(StaCommand command); +} diff --git a/src/MxGateway.Worker/Sta/StaCommand.cs b/src/MxGateway.Worker/Sta/StaCommand.cs new file mode 100644 index 0000000..46f95d6 --- /dev/null +++ b/src/MxGateway.Worker/Sta/StaCommand.cs @@ -0,0 +1,47 @@ +using System; +using System.Threading; +using Google.Protobuf.WellKnownTypes; +using MxGateway.Contracts.Proto; + +namespace MxGateway.Worker.Sta; + +public sealed class StaCommand +{ + public StaCommand( + string sessionId, + string correlationId, + MxCommand command, + Timestamp? enqueueTimestamp = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrWhiteSpace(sessionId)) + { + throw new ArgumentException("STA command requires a session id.", nameof(sessionId)); + } + + if (string.IsNullOrWhiteSpace(correlationId)) + { + throw new ArgumentException("STA command requires a correlation id.", nameof(correlationId)); + } + + SessionId = sessionId; + CorrelationId = correlationId; + Command = command ?? throw new ArgumentNullException(nameof(command)); + EnqueueTimestamp = enqueueTimestamp ?? Timestamp.FromDateTime(DateTime.UtcNow); + CancellationToken = cancellationToken; + } + + public string SessionId { get; } + + public string CorrelationId { get; } + + public MxCommand Command { get; } + + public Timestamp EnqueueTimestamp { get; } + + public CancellationToken CancellationToken { get; } + + public MxCommandKind Kind => Command.Kind; + + public string MethodName => Kind.ToString(); +} diff --git a/src/MxGateway.Worker/Sta/StaCommandDispatcher.cs b/src/MxGateway.Worker/Sta/StaCommandDispatcher.cs new file mode 100644 index 0000000..3df663e --- /dev/null +++ b/src/MxGateway.Worker/Sta/StaCommandDispatcher.cs @@ -0,0 +1,267 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.Conversion; + +namespace MxGateway.Worker.Sta; + +public sealed class StaCommandDispatcher +{ + private readonly HResultConverter hresultConverter; + private readonly IStaCommandExecutor commandExecutor; + private readonly Queue commandQueue = new(); + private readonly StaRuntime staRuntime; + private readonly object gate = new(); + private bool drainActive; + private bool shutdownRequested; + private string currentCommandCorrelationId = string.Empty; + + public StaCommandDispatcher( + StaRuntime staRuntime, + IStaCommandExecutor commandExecutor) + : this(staRuntime, commandExecutor, new HResultConverter()) + { + } + + public StaCommandDispatcher( + StaRuntime staRuntime, + IStaCommandExecutor commandExecutor, + HResultConverter hresultConverter) + { + this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime)); + this.commandExecutor = commandExecutor ?? throw new ArgumentNullException(nameof(commandExecutor)); + this.hresultConverter = hresultConverter ?? throw new ArgumentNullException(nameof(hresultConverter)); + } + + public int PendingCommandCount + { + get + { + lock (gate) + { + return commandQueue.Count; + } + } + } + + public string CurrentCommandCorrelationId + { + get + { + lock (gate) + { + return currentCommandCorrelationId; + } + } + } + + public Task DispatchAsync(StaCommand command) + { + if (command is null) + { + throw new ArgumentNullException(nameof(command)); + } + + lock (gate) + { + if (shutdownRequested) + { + return Task.FromResult(CreateRejectedReply( + command, + ProtocolStatusCode.WorkerUnavailable, + "The STA command dispatcher is shutting down.")); + } + + QueuedStaCommand queuedCommand = new(command); + commandQueue.Enqueue(queuedCommand); + + if (!drainActive) + { + drainActive = true; + _ = DrainAsync(); + } + + return queuedCommand.Task; + } + } + + public void RequestShutdown() + { + lock (gate) + { + shutdownRequested = true; + } + } + + public void PopulateHeartbeat(WorkerHeartbeat heartbeat) + { + if (heartbeat is null) + { + throw new ArgumentNullException(nameof(heartbeat)); + } + + lock (gate) + { + heartbeat.PendingCommandCount = (uint)commandQueue.Count; + heartbeat.CurrentCommandCorrelationId = currentCommandCorrelationId; + } + } + + private async Task DrainAsync() + { + while (true) + { + QueuedStaCommand queuedCommand; + lock (gate) + { + if (commandQueue.Count == 0) + { + drainActive = false; + return; + } + + queuedCommand = commandQueue.Dequeue(); + } + + await ExecuteQueuedCommandAsync(queuedCommand).ConfigureAwait(false); + } + } + + private async Task ExecuteQueuedCommandAsync(QueuedStaCommand queuedCommand) + { + StaCommand command = queuedCommand.Command; + if (command.CancellationToken.IsCancellationRequested) + { + queuedCommand.Complete(CreateRejectedReply( + command, + ProtocolStatusCode.Canceled, + "The STA command was canceled before execution.")); + return; + } + + SetCurrentCommand(command.CorrelationId); + try + { + MxCommandReply reply = await staRuntime + .InvokeAsync(() => commandExecutor.Execute(command)) + .ConfigureAwait(false); + + queuedCommand.Complete(NormalizeReply(command, reply)); + } + catch (Exception exception) + { + queuedCommand.Complete(CreateExceptionReply(command, exception)); + } + finally + { + SetCurrentCommand(string.Empty); + } + } + + private void SetCurrentCommand(string correlationId) + { + lock (gate) + { + currentCommandCorrelationId = correlationId; + } + } + + private MxCommandReply NormalizeReply( + StaCommand command, + MxCommandReply reply) + { + if (reply is null) + { + return CreateRejectedReply( + command, + ProtocolStatusCode.ProtocolViolation, + "STA command executor returned null."); + } + + if (string.IsNullOrWhiteSpace(reply.SessionId)) + { + reply.SessionId = command.SessionId; + } + + if (string.IsNullOrWhiteSpace(reply.CorrelationId)) + { + reply.CorrelationId = command.CorrelationId; + } + + if (reply.Kind == MxCommandKind.Unspecified) + { + reply.Kind = command.Kind; + } + + if (reply.ProtocolStatus is null) + { + reply.ProtocolStatus = new ProtocolStatus + { + Code = ProtocolStatusCode.Ok, + Message = "OK", + }; + } + + return reply; + } + + private MxCommandReply CreateExceptionReply( + StaCommand command, + Exception exception) + { + HResultConversion conversion = hresultConverter.Convert(exception); + MxCommandReply reply = CreateBaseReply(command); + reply.ProtocolStatus = conversion.ProtocolStatus; + reply.Hresult = conversion.HResult; + reply.DiagnosticMessage = conversion.DiagnosticMessage; + + return reply; + } + + private static MxCommandReply CreateRejectedReply( + StaCommand command, + ProtocolStatusCode statusCode, + string message) + { + MxCommandReply reply = CreateBaseReply(command); + reply.ProtocolStatus = new ProtocolStatus + { + Code = statusCode, + Message = message, + }; + reply.DiagnosticMessage = message; + + return reply; + } + + private static MxCommandReply CreateBaseReply(StaCommand command) + { + return new MxCommandReply + { + SessionId = command.SessionId, + CorrelationId = command.CorrelationId, + Kind = command.Kind, + }; + } + + private sealed class QueuedStaCommand + { + private readonly TaskCompletionSource completion = new( + TaskCreationOptions.RunContinuationsAsynchronously); + + public QueuedStaCommand(StaCommand command) + { + Command = command; + } + + public StaCommand Command { get; } + + public Task Task => completion.Task; + + public void Complete(MxCommandReply reply) + { + completion.TrySetResult(reply); + } + } +}