Merge pull request #72 from agent-2/issue-25-implement-sta-command-dispatcher

Issue #25: implement STA command dispatcher
This commit was merged in pull request #72.
This commit is contained in:
2026-04-26 17:53:32 -04:00
4 changed files with 601 additions and 0 deletions
@@ -0,0 +1,279 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaCommandDispatcherTests
{
[Fact]
public async Task DispatchAsync_ExecutesCommandsOnStaInQueueOrder()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
RecordingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> first = dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Task<MxCommandReply> second = dispatcher.DispatchAsync(CreateCommand("correlation-2", MxCommandKind.AddItem));
MxCommandReply[] replies = await Task.WhenAll(first, second);
Assert.Equal(new[] { "correlation-1", "correlation-2" }, executor.CorrelationIds);
Assert.All(executor.ThreadIds, threadId => Assert.Equal(runtime.StaThreadId, threadId));
Assert.Equal("correlation-1", replies[0].CorrelationId);
Assert.Equal("correlation-2", replies[1].CorrelationId);
Assert.Equal(ProtocolStatusCode.Ok, replies[0].ProtocolStatus.Code);
}
[Fact]
public async Task DispatchAsync_WhenExecutorThrows_ReturnsFailureReplyWithHResult()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
StaCommandDispatcher dispatcher = new(
runtime,
new ThrowingCommandExecutor(new COMException("provider detail", unchecked((int)0x80070057))));
MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Assert.Equal("session-1", reply.SessionId);
Assert.Equal("correlation-1", reply.CorrelationId);
Assert.Equal(MxCommandKind.Register, reply.Kind);
Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
Assert.Equal(unchecked((int)0x80070057), reply.Hresult);
Assert.Contains("0x80070057", reply.DiagnosticMessage);
Assert.DoesNotContain("provider detail", reply.DiagnosticMessage);
}
[Fact]
public async Task DispatchAsync_WhenCanceledBeforeExecution_ReturnsCanceledReplyWithoutExecuting()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> blocked = dispatcher.DispatchAsync(CreateCommand("blocked", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
using CancellationTokenSource cancellation = new();
Task<MxCommandReply> canceled = dispatcher.DispatchAsync(
CreateCommand("canceled", MxCommandKind.AddItem, cancellation.Token));
cancellation.Cancel();
executor.Release();
MxCommandReply canceledReply = await canceled;
await blocked;
Assert.Equal(ProtocolStatusCode.Canceled, canceledReply.ProtocolStatus.Code);
Assert.DoesNotContain("canceled", executor.CorrelationIds);
}
[Fact]
public async Task DispatchAsync_WhenCanceledAfterExecutionStarts_StillReturnsLateReply()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
using CancellationTokenSource cancellation = new();
Task<MxCommandReply> replyTask = dispatcher.DispatchAsync(
CreateCommand("late-reply", MxCommandKind.Register, cancellation.Token));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
cancellation.Cancel();
executor.Release();
MxCommandReply reply = await replyTask;
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Contains("late-reply", executor.CorrelationIds);
}
[Fact]
public async Task DispatchAsync_WhenShutdownRequested_RejectsNewCommands()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
StaCommandDispatcher dispatcher = new(runtime, new RecordingCommandExecutor());
dispatcher.RequestShutdown();
MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Assert.Equal(ProtocolStatusCode.WorkerUnavailable, reply.ProtocolStatus.Code);
Assert.Equal("correlation-1", reply.CorrelationId);
}
[Fact]
public async Task PopulateHeartbeat_ReportsCurrentCorrelationAndPendingCount()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> current = dispatcher.DispatchAsync(CreateCommand("current", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
Task<MxCommandReply> pending = dispatcher.DispatchAsync(CreateCommand("pending", MxCommandKind.AddItem));
WorkerHeartbeat heartbeat = new();
dispatcher.PopulateHeartbeat(heartbeat);
Assert.Equal("current", heartbeat.CurrentCommandCorrelationId);
Assert.Equal(1u, heartbeat.PendingCommandCount);
executor.Release();
await Task.WhenAll(current, pending);
}
private static StaCommand CreateCommand(
string correlationId,
MxCommandKind kind,
CancellationToken cancellationToken = default)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = kind,
Ping = new PingCommand
{
Message = correlationId,
},
},
cancellationToken: cancellationToken);
}
private static StaRuntime CreateRuntime()
{
return new StaRuntime(
new NoopComApartmentInitializer(),
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class RecordingCommandExecutor : IStaCommandExecutor
{
private readonly object gate = new();
private readonly List<string> correlationIds = new();
private readonly List<int> threadIds = new();
public IReadOnlyList<string> CorrelationIds
{
get
{
lock (gate)
{
return correlationIds.ToArray();
}
}
}
public IReadOnlyList<int> ThreadIds
{
get
{
lock (gate)
{
return threadIds.ToArray();
}
}
}
public MxCommandReply Execute(StaCommand command)
{
lock (gate)
{
correlationIds.Add(command.CorrelationId);
threadIds.Add(Thread.CurrentThread.ManagedThreadId);
}
return new MxCommandReply
{
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
}
private sealed class BlockingCommandExecutor : IStaCommandExecutor
{
private readonly ManualResetEventSlim release = new(false);
private readonly object gate = new();
private readonly List<string> correlationIds = new();
public ManualResetEventSlim Started { get; } = new(false);
public IReadOnlyList<string> CorrelationIds
{
get
{
lock (gate)
{
return correlationIds.ToArray();
}
}
}
public MxCommandReply Execute(StaCommand command)
{
lock (gate)
{
correlationIds.Add(command.CorrelationId);
}
Started.Set();
release.Wait(TimeSpan.FromSeconds(5));
return new MxCommandReply
{
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
public void Release()
{
release.Set();
}
}
private sealed class ThrowingCommandExecutor : IStaCommandExecutor
{
private readonly Exception exception;
public ThrowingCommandExecutor(Exception exception)
{
this.exception = exception;
}
public MxCommandReply Execute(StaCommand command)
{
throw exception;
}
}
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
{
public void Initialize()
{
}
public void Uninitialize()
{
}
}
}
@@ -0,0 +1,8 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Sta;
public interface IStaCommandExecutor
{
MxCommandReply Execute(StaCommand command);
}
+47
View File
@@ -0,0 +1,47 @@
using System;
using System.Threading;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Sta;
public sealed class StaCommand
{
public StaCommand(
string sessionId,
string correlationId,
MxCommand command,
Timestamp? enqueueTimestamp = null,
CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("STA command requires a session id.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(correlationId))
{
throw new ArgumentException("STA command requires a correlation id.", nameof(correlationId));
}
SessionId = sessionId;
CorrelationId = correlationId;
Command = command ?? throw new ArgumentNullException(nameof(command));
EnqueueTimestamp = enqueueTimestamp ?? Timestamp.FromDateTime(DateTime.UtcNow);
CancellationToken = cancellationToken;
}
public string SessionId { get; }
public string CorrelationId { get; }
public MxCommand Command { get; }
public Timestamp EnqueueTimestamp { get; }
public CancellationToken CancellationToken { get; }
public MxCommandKind Kind => Command.Kind;
public string MethodName => Kind.ToString();
}
@@ -0,0 +1,267 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
namespace MxGateway.Worker.Sta;
public sealed class StaCommandDispatcher
{
private readonly HResultConverter hresultConverter;
private readonly IStaCommandExecutor commandExecutor;
private readonly Queue<QueuedStaCommand> commandQueue = new();
private readonly StaRuntime staRuntime;
private readonly object gate = new();
private bool drainActive;
private bool shutdownRequested;
private string currentCommandCorrelationId = string.Empty;
public StaCommandDispatcher(
StaRuntime staRuntime,
IStaCommandExecutor commandExecutor)
: this(staRuntime, commandExecutor, new HResultConverter())
{
}
public StaCommandDispatcher(
StaRuntime staRuntime,
IStaCommandExecutor commandExecutor,
HResultConverter hresultConverter)
{
this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime));
this.commandExecutor = commandExecutor ?? throw new ArgumentNullException(nameof(commandExecutor));
this.hresultConverter = hresultConverter ?? throw new ArgumentNullException(nameof(hresultConverter));
}
public int PendingCommandCount
{
get
{
lock (gate)
{
return commandQueue.Count;
}
}
}
public string CurrentCommandCorrelationId
{
get
{
lock (gate)
{
return currentCommandCorrelationId;
}
}
}
public Task<MxCommandReply> DispatchAsync(StaCommand command)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
lock (gate)
{
if (shutdownRequested)
{
return Task.FromResult(CreateRejectedReply(
command,
ProtocolStatusCode.WorkerUnavailable,
"The STA command dispatcher is shutting down."));
}
QueuedStaCommand queuedCommand = new(command);
commandQueue.Enqueue(queuedCommand);
if (!drainActive)
{
drainActive = true;
_ = DrainAsync();
}
return queuedCommand.Task;
}
}
public void RequestShutdown()
{
lock (gate)
{
shutdownRequested = true;
}
}
public void PopulateHeartbeat(WorkerHeartbeat heartbeat)
{
if (heartbeat is null)
{
throw new ArgumentNullException(nameof(heartbeat));
}
lock (gate)
{
heartbeat.PendingCommandCount = (uint)commandQueue.Count;
heartbeat.CurrentCommandCorrelationId = currentCommandCorrelationId;
}
}
private async Task DrainAsync()
{
while (true)
{
QueuedStaCommand queuedCommand;
lock (gate)
{
if (commandQueue.Count == 0)
{
drainActive = false;
return;
}
queuedCommand = commandQueue.Dequeue();
}
await ExecuteQueuedCommandAsync(queuedCommand).ConfigureAwait(false);
}
}
private async Task ExecuteQueuedCommandAsync(QueuedStaCommand queuedCommand)
{
StaCommand command = queuedCommand.Command;
if (command.CancellationToken.IsCancellationRequested)
{
queuedCommand.Complete(CreateRejectedReply(
command,
ProtocolStatusCode.Canceled,
"The STA command was canceled before execution."));
return;
}
SetCurrentCommand(command.CorrelationId);
try
{
MxCommandReply reply = await staRuntime
.InvokeAsync(() => commandExecutor.Execute(command))
.ConfigureAwait(false);
queuedCommand.Complete(NormalizeReply(command, reply));
}
catch (Exception exception)
{
queuedCommand.Complete(CreateExceptionReply(command, exception));
}
finally
{
SetCurrentCommand(string.Empty);
}
}
private void SetCurrentCommand(string correlationId)
{
lock (gate)
{
currentCommandCorrelationId = correlationId;
}
}
private MxCommandReply NormalizeReply(
StaCommand command,
MxCommandReply reply)
{
if (reply is null)
{
return CreateRejectedReply(
command,
ProtocolStatusCode.ProtocolViolation,
"STA command executor returned null.");
}
if (string.IsNullOrWhiteSpace(reply.SessionId))
{
reply.SessionId = command.SessionId;
}
if (string.IsNullOrWhiteSpace(reply.CorrelationId))
{
reply.CorrelationId = command.CorrelationId;
}
if (reply.Kind == MxCommandKind.Unspecified)
{
reply.Kind = command.Kind;
}
if (reply.ProtocolStatus is null)
{
reply.ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
};
}
return reply;
}
private MxCommandReply CreateExceptionReply(
StaCommand command,
Exception exception)
{
HResultConversion conversion = hresultConverter.Convert(exception);
MxCommandReply reply = CreateBaseReply(command);
reply.ProtocolStatus = conversion.ProtocolStatus;
reply.Hresult = conversion.HResult;
reply.DiagnosticMessage = conversion.DiagnosticMessage;
return reply;
}
private static MxCommandReply CreateRejectedReply(
StaCommand command,
ProtocolStatusCode statusCode,
string message)
{
MxCommandReply reply = CreateBaseReply(command);
reply.ProtocolStatus = new ProtocolStatus
{
Code = statusCode,
Message = message,
};
reply.DiagnosticMessage = message;
return reply;
}
private static MxCommandReply CreateBaseReply(StaCommand command)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
};
}
private sealed class QueuedStaCommand
{
private readonly TaskCompletionSource<MxCommandReply> completion = new(
TaskCreationOptions.RunContinuationsAsynchronously);
public QueuedStaCommand(StaCommand command)
{
Command = command;
}
public StaCommand Command { get; }
public Task<MxCommandReply> Task => completion.Task;
public void Complete(MxCommandReply reply)
{
completion.TrySetResult(reply);
}
}
}