Issue #23: implement sta runtime and message pump #66

Merged
dohertj2 merged 1 commits from agent-1/issue-23-implement-sta-runtime-and-message-pump into main 2026-04-26 17:23:04 -04:00
8 changed files with 659 additions and 0 deletions
+11
View File
@@ -250,6 +250,17 @@ The loop should update a heartbeat timestamp after:
- finishing a command,
- processing an MXAccess event.
`StaRuntime` implements this runtime boundary in the worker. It starts one
background thread named `MxGateway.Worker.STA`, sets it to `ApartmentState.STA`,
initializes COM through `StaComApartmentInitializer`, and runs
`StaMessagePump`. Commands are scheduled through `InvokeAsync`; the command
queue signals an `AutoResetEvent` so `MsgWaitForMultipleObjectsEx` can wake the
STA without busy-waiting. `LastActivityUtc` records pump, command, startup, and
shutdown activity so the future heartbeat/watchdog can report whether the STA
is still responsive. Shutdown marks the runtime as closing, wakes the pump,
rejects new commands, cancels queued work, uninitializes COM on the STA, and
waits for the thread to exit.
## COM Creation
The MXAccess analysis source at `C:\Users\dohertj2\Desktop\mxaccess` identifies
@@ -0,0 +1,152 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaRuntimeTests
{
[Fact]
public async Task InvokeAsync_ExecutesCommandOnStaThread()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
StaCommandObservation observation = await runtime.InvokeAsync(
() => new StaCommandObservation(
Thread.CurrentThread.ManagedThreadId,
Thread.CurrentThread.GetApartmentState()));
Assert.Equal(runtime.StaThreadId, observation.ThreadId);
Assert.Equal(initializer.InitializeThreadId, observation.ThreadId);
Assert.Equal(ApartmentState.STA, observation.ApartmentState);
}
[Fact]
public async Task InvokeAsync_WakesIdlePumpForQueuedCommand()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = new(
initializer,
new StaMessagePump(),
TimeSpan.FromSeconds(30));
runtime.Start();
Stopwatch stopwatch = Stopwatch.StartNew();
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
stopwatch.Stop();
Assert.Equal(runtime.StaThreadId, threadId);
Assert.True(
stopwatch.Elapsed < TimeSpan.FromSeconds(2),
$"Command took {stopwatch.Elapsed} to execute, so the command wake event did not wake the STA promptly.");
}
[Fact]
public void Shutdown_StopsThreadAndUninitializesComApartment()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
bool stopped = runtime.Shutdown(TimeSpan.FromSeconds(2));
Assert.True(stopped);
Assert.False(runtime.IsRunning);
Assert.Equal(1, initializer.InitializeCount);
Assert.Equal(1, initializer.UninitializeCount);
Assert.Equal(initializer.InitializeThreadId, initializer.UninitializeThreadId);
}
[Fact]
public void LastActivityUtc_UpdatesWhilePumpIsIdle()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
DateTimeOffset firstActivity = runtime.LastActivityUtc;
bool updated = SpinWait.SpinUntil(
() => runtime.LastActivityUtc > firstActivity,
TimeSpan.FromSeconds(2));
Assert.True(updated);
}
[Fact]
public async Task InvokeAsync_CommandException_FaultsReturnedTaskWithoutStoppingRuntime()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync<int>(() => throw new InvalidOperationException("command failed")));
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
Assert.Equal("command failed", exception.Message);
Assert.Equal(runtime.StaThreadId, threadId);
}
[Fact]
public async Task InvokeAsync_AfterShutdown_ReturnsFaultedTask()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
runtime.Shutdown(TimeSpan.FromSeconds(2));
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId));
Assert.Contains("shutting down", exception.Message);
}
private static StaRuntime CreateRuntime(RecordingComApartmentInitializer initializer)
{
return new StaRuntime(
initializer,
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class StaCommandObservation
{
public StaCommandObservation(int threadId, ApartmentState apartmentState)
{
ThreadId = threadId;
ApartmentState = apartmentState;
}
public int ThreadId { get; }
public ApartmentState ApartmentState { get; }
}
private sealed class RecordingComApartmentInitializer : IStaComApartmentInitializer
{
public int InitializeCount { get; private set; }
public int UninitializeCount { get; private set; }
public int? InitializeThreadId { get; private set; }
public int? UninitializeThreadId { get; private set; }
public void Initialize()
{
InitializeCount++;
InitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
public void Uninitialize()
{
UninitializeCount++;
UninitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
}
}
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.Sta;
public interface IStaComApartmentInitializer
{
void Initialize();
void Uninitialize();
}
+8
View File
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.Sta;
internal interface IStaWorkItem
{
void CancelBeforeExecution();
void Execute();
}
@@ -0,0 +1,31 @@
using System;
using System.Runtime.InteropServices;
namespace MxGateway.Worker.Sta;
public sealed class StaComApartmentInitializer : IStaComApartmentInitializer
{
private const uint CoInitializeApartmentThreaded = 0x2;
private const int SOk = 0;
private const int SFalse = 1;
public void Initialize()
{
int hresult = CoInitializeEx(IntPtr.Zero, CoInitializeApartmentThreaded);
if (hresult != SOk && hresult != SFalse)
{
throw new COMException("Failed to initialize the worker STA COM apartment.", hresult);
}
}
public void Uninitialize()
{
CoUninitialize();
}
[DllImport("ole32.dll")]
private static extern int CoInitializeEx(IntPtr reserved, uint coInit);
[DllImport("ole32.dll")]
private static extern void CoUninitialize();
}
+111
View File
@@ -0,0 +1,111 @@
using System;
using System.Runtime.InteropServices;
using System.Threading;
using Microsoft.Win32.SafeHandles;
namespace MxGateway.Worker.Sta;
public sealed class StaMessagePump
{
private const uint Infinite = 0xFFFFFFFF;
private const uint MsgWaitFailed = 0xFFFFFFFF;
private const uint MwmoInputAvailable = 0x0004;
private const uint PmRemove = 0x0001;
private const uint QsAllInput = 0x04FF;
public void WaitForWorkOrMessages(WaitHandle commandWakeEvent, TimeSpan timeout)
{
if (commandWakeEvent is null)
{
throw new ArgumentNullException(nameof(commandWakeEvent));
}
uint timeoutMilliseconds = ToTimeoutMilliseconds(timeout);
SafeWaitHandle safeHandle = commandWakeEvent.SafeWaitHandle;
IntPtr[] handles = [safeHandle.DangerousGetHandle()];
uint result = MsgWaitForMultipleObjectsEx(
(uint)handles.Length,
handles,
timeoutMilliseconds,
QsAllInput,
MwmoInputAvailable);
if (result == MsgWaitFailed)
{
throw new InvalidOperationException(
"The worker STA message pump failed while waiting for command work or Windows messages.");
}
}
public int PumpPendingMessages()
{
int pumpedMessages = 0;
while (PeekMessage(out NativeMessage message, IntPtr.Zero, 0, 0, PmRemove))
{
TranslateMessage(ref message);
DispatchMessage(ref message);
pumpedMessages++;
}
return pumpedMessages;
}
private static uint ToTimeoutMilliseconds(TimeSpan timeout)
{
if (timeout == Timeout.InfiniteTimeSpan)
{
return Infinite;
}
if (timeout <= TimeSpan.Zero)
{
return 0;
}
return timeout.TotalMilliseconds >= uint.MaxValue
? uint.MaxValue - 1
: (uint)Math.Ceiling(timeout.TotalMilliseconds);
}
[DllImport("user32.dll", SetLastError = true)]
private static extern uint MsgWaitForMultipleObjectsEx(
uint count,
IntPtr[] handles,
uint milliseconds,
uint wakeMask,
uint flags);
[DllImport("user32.dll", SetLastError = true)]
private static extern bool PeekMessage(
out NativeMessage message,
IntPtr windowHandle,
uint messageFilterMin,
uint messageFilterMax,
uint removeMessage);
[DllImport("user32.dll")]
private static extern bool TranslateMessage(ref NativeMessage message);
[DllImport("user32.dll")]
private static extern IntPtr DispatchMessage(ref NativeMessage message);
[StructLayout(LayoutKind.Sequential)]
private struct NativeMessage
{
public IntPtr WindowHandle;
public uint Message;
public UIntPtr WParam;
public IntPtr LParam;
public uint Time;
public NativePoint Point;
}
[StructLayout(LayoutKind.Sequential)]
private struct NativePoint
{
public int X;
public int Y;
}
}
+267
View File
@@ -0,0 +1,267 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
namespace MxGateway.Worker.Sta;
public sealed class StaRuntime : IDisposable
{
private readonly IStaComApartmentInitializer comApartmentInitializer;
private readonly StaMessagePump messagePump;
private readonly ConcurrentQueue<IStaWorkItem> commandQueue = new();
private readonly AutoResetEvent commandWakeEvent = new(false);
private readonly ManualResetEventSlim startedEvent = new(false);
private readonly ManualResetEventSlim stoppedEvent = new(false);
private readonly object gate = new();
private readonly Thread staThread;
private readonly TimeSpan idlePumpInterval;
private bool disposed;
private bool startRequested;
private bool shutdownRequested;
private Exception? startupException;
private long lastActivityUtcTicks;
private bool comInitialized;
public StaRuntime()
: this(new StaComApartmentInitializer(), new StaMessagePump(), TimeSpan.FromMilliseconds(50))
{
}
public StaRuntime(
IStaComApartmentInitializer comApartmentInitializer,
StaMessagePump messagePump,
TimeSpan idlePumpInterval)
{
this.comApartmentInitializer = comApartmentInitializer
?? throw new ArgumentNullException(nameof(comApartmentInitializer));
this.messagePump = messagePump ?? throw new ArgumentNullException(nameof(messagePump));
if (idlePumpInterval <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(
nameof(idlePumpInterval),
"The idle pump interval must be greater than zero.");
}
this.idlePumpInterval = idlePumpInterval;
lastActivityUtcTicks = DateTimeOffset.UtcNow.UtcTicks;
staThread = new Thread(ThreadMain)
{
IsBackground = true,
Name = "MxGateway.Worker.STA"
};
staThread.SetApartmentState(ApartmentState.STA);
}
public int? StaThreadId { get; private set; }
public DateTimeOffset LastActivityUtc =>
new(new DateTime(Volatile.Read(ref lastActivityUtcTicks), DateTimeKind.Utc));
public bool IsRunning => startedEvent.IsSet && !stoppedEvent.IsSet;
public void Start()
{
ThrowIfDisposed();
lock (gate)
{
if (shutdownRequested)
{
throw new InvalidOperationException("The worker STA runtime is shutting down.");
}
if (!startRequested)
{
startRequested = true;
staThread.Start();
}
}
startedEvent.Wait();
if (startupException is not null)
{
throw new InvalidOperationException(
"The worker STA runtime failed to initialize.",
startupException);
}
}
public Task InvokeAsync(Action command, CancellationToken cancellationToken = default)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
return InvokeAsync(
() =>
{
command();
return true;
},
cancellationToken);
}
public Task<T> InvokeAsync<T>(Func<T> command, CancellationToken cancellationToken = default)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
ThrowIfDisposed();
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled<T>(cancellationToken);
}
StaWorkItem<T> workItem = new(command, cancellationToken);
lock (gate)
{
if (shutdownRequested)
{
return Task.FromException<T>(
new InvalidOperationException("The worker STA runtime is shutting down."));
}
commandQueue.Enqueue(workItem);
}
commandWakeEvent.Set();
return workItem.Task;
}
public bool Shutdown(TimeSpan timeout)
{
if (timeout < TimeSpan.Zero && timeout != Timeout.InfiniteTimeSpan)
{
throw new ArgumentOutOfRangeException(nameof(timeout));
}
lock (gate)
{
shutdownRequested = true;
}
commandWakeEvent.Set();
if (!startedEvent.IsSet && !staThread.IsAlive)
{
CancelQueuedCommands();
stoppedEvent.Set();
return true;
}
bool stopped = stoppedEvent.Wait(timeout);
if (stopped)
{
CancelQueuedCommands();
}
return stopped;
}
public void Dispose()
{
if (disposed)
{
return;
}
bool stopped = Shutdown(TimeSpan.FromSeconds(5));
if (stopped)
{
commandWakeEvent.Dispose();
startedEvent.Dispose();
stoppedEvent.Dispose();
}
disposed = true;
}
private void ThreadMain()
{
try
{
StaThreadId = Thread.CurrentThread.ManagedThreadId;
comApartmentInitializer.Initialize();
comInitialized = true;
MarkActivity();
startedEvent.Set();
while (!IsShutdownRequested())
{
ProcessQueuedCommands();
messagePump.WaitForWorkOrMessages(commandWakeEvent, idlePumpInterval);
messagePump.PumpPendingMessages();
MarkActivity();
}
ProcessQueuedCommands();
}
catch (Exception exception)
{
startupException = exception;
startedEvent.Set();
}
finally
{
CancelQueuedCommands();
try
{
if (comInitialized)
{
comApartmentInitializer.Uninitialize();
}
}
finally
{
MarkActivity();
stoppedEvent.Set();
}
}
}
private void ProcessQueuedCommands()
{
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
{
MarkActivity();
workItem.Execute();
MarkActivity();
}
}
private void CancelQueuedCommands()
{
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
{
workItem.CancelBeforeExecution();
}
}
private bool IsShutdownRequested()
{
lock (gate)
{
return shutdownRequested;
}
}
private void MarkActivity()
{
Volatile.Write(ref lastActivityUtcTicks, DateTimeOffset.UtcNow.UtcTicks);
}
private void ThrowIfDisposed()
{
if (disposed)
{
throw new ObjectDisposedException(nameof(StaRuntime));
}
}
}
+71
View File
@@ -0,0 +1,71 @@
using System;
using System.Threading;
using System.Threading.Tasks;
namespace MxGateway.Worker.Sta;
internal sealed class StaWorkItem<T> : IStaWorkItem
{
private readonly Func<T> command;
private readonly CancellationToken cancellationToken;
private readonly CancellationTokenRegistration cancellationRegistration;
private int started;
public StaWorkItem(Func<T> command, CancellationToken cancellationToken)
{
this.command = command ?? throw new ArgumentNullException(nameof(command));
this.cancellationToken = cancellationToken;
Completion = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(
() =>
{
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
{
Completion.TrySetCanceled(cancellationToken);
}
});
}
}
public Task<T> Task => Completion.Task;
private TaskCompletionSource<T> Completion { get; }
public void CancelBeforeExecution()
{
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
{
Completion.TrySetCanceled(cancellationToken);
cancellationRegistration.Dispose();
}
}
public void Execute()
{
if (Interlocked.CompareExchange(ref started, 1, 0) != 0)
{
cancellationRegistration.Dispose();
return;
}
cancellationRegistration.Dispose();
if (cancellationToken.IsCancellationRequested)
{
Completion.TrySetCanceled(cancellationToken);
return;
}
try
{
Completion.TrySetResult(command());
}
catch (Exception exception)
{
Completion.TrySetException(exception);
}
}
}