Issue #23: implement sta runtime and message pump

This commit is contained in:
Joseph Doherty
2026-04-26 17:18:53 -04:00
parent 603aff7004
commit e81682e367
8 changed files with 659 additions and 0 deletions
@@ -0,0 +1,152 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaRuntimeTests
{
[Fact]
public async Task InvokeAsync_ExecutesCommandOnStaThread()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
StaCommandObservation observation = await runtime.InvokeAsync(
() => new StaCommandObservation(
Thread.CurrentThread.ManagedThreadId,
Thread.CurrentThread.GetApartmentState()));
Assert.Equal(runtime.StaThreadId, observation.ThreadId);
Assert.Equal(initializer.InitializeThreadId, observation.ThreadId);
Assert.Equal(ApartmentState.STA, observation.ApartmentState);
}
[Fact]
public async Task InvokeAsync_WakesIdlePumpForQueuedCommand()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = new(
initializer,
new StaMessagePump(),
TimeSpan.FromSeconds(30));
runtime.Start();
Stopwatch stopwatch = Stopwatch.StartNew();
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
stopwatch.Stop();
Assert.Equal(runtime.StaThreadId, threadId);
Assert.True(
stopwatch.Elapsed < TimeSpan.FromSeconds(2),
$"Command took {stopwatch.Elapsed} to execute, so the command wake event did not wake the STA promptly.");
}
[Fact]
public void Shutdown_StopsThreadAndUninitializesComApartment()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
bool stopped = runtime.Shutdown(TimeSpan.FromSeconds(2));
Assert.True(stopped);
Assert.False(runtime.IsRunning);
Assert.Equal(1, initializer.InitializeCount);
Assert.Equal(1, initializer.UninitializeCount);
Assert.Equal(initializer.InitializeThreadId, initializer.UninitializeThreadId);
}
[Fact]
public void LastActivityUtc_UpdatesWhilePumpIsIdle()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
DateTimeOffset firstActivity = runtime.LastActivityUtc;
bool updated = SpinWait.SpinUntil(
() => runtime.LastActivityUtc > firstActivity,
TimeSpan.FromSeconds(2));
Assert.True(updated);
}
[Fact]
public async Task InvokeAsync_CommandException_FaultsReturnedTaskWithoutStoppingRuntime()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync<int>(() => throw new InvalidOperationException("command failed")));
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
Assert.Equal("command failed", exception.Message);
Assert.Equal(runtime.StaThreadId, threadId);
}
[Fact]
public async Task InvokeAsync_AfterShutdown_ReturnsFaultedTask()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
runtime.Shutdown(TimeSpan.FromSeconds(2));
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId));
Assert.Contains("shutting down", exception.Message);
}
private static StaRuntime CreateRuntime(RecordingComApartmentInitializer initializer)
{
return new StaRuntime(
initializer,
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class StaCommandObservation
{
public StaCommandObservation(int threadId, ApartmentState apartmentState)
{
ThreadId = threadId;
ApartmentState = apartmentState;
}
public int ThreadId { get; }
public ApartmentState ApartmentState { get; }
}
private sealed class RecordingComApartmentInitializer : IStaComApartmentInitializer
{
public int InitializeCount { get; private set; }
public int UninitializeCount { get; private set; }
public int? InitializeThreadId { get; private set; }
public int? UninitializeThreadId { get; private set; }
public void Initialize()
{
InitializeCount++;
InitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
public void Uninitialize()
{
UninitializeCount++;
UninitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
}
}
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.Sta;
public interface IStaComApartmentInitializer
{
void Initialize();
void Uninitialize();
}
+8
View File
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.Sta;
internal interface IStaWorkItem
{
void CancelBeforeExecution();
void Execute();
}
@@ -0,0 +1,31 @@
using System;
using System.Runtime.InteropServices;
namespace MxGateway.Worker.Sta;
public sealed class StaComApartmentInitializer : IStaComApartmentInitializer
{
private const uint CoInitializeApartmentThreaded = 0x2;
private const int SOk = 0;
private const int SFalse = 1;
public void Initialize()
{
int hresult = CoInitializeEx(IntPtr.Zero, CoInitializeApartmentThreaded);
if (hresult != SOk && hresult != SFalse)
{
throw new COMException("Failed to initialize the worker STA COM apartment.", hresult);
}
}
public void Uninitialize()
{
CoUninitialize();
}
[DllImport("ole32.dll")]
private static extern int CoInitializeEx(IntPtr reserved, uint coInit);
[DllImport("ole32.dll")]
private static extern void CoUninitialize();
}
+111
View File
@@ -0,0 +1,111 @@
using System;
using System.Runtime.InteropServices;
using System.Threading;
using Microsoft.Win32.SafeHandles;
namespace MxGateway.Worker.Sta;
public sealed class StaMessagePump
{
private const uint Infinite = 0xFFFFFFFF;
private const uint MsgWaitFailed = 0xFFFFFFFF;
private const uint MwmoInputAvailable = 0x0004;
private const uint PmRemove = 0x0001;
private const uint QsAllInput = 0x04FF;
public void WaitForWorkOrMessages(WaitHandle commandWakeEvent, TimeSpan timeout)
{
if (commandWakeEvent is null)
{
throw new ArgumentNullException(nameof(commandWakeEvent));
}
uint timeoutMilliseconds = ToTimeoutMilliseconds(timeout);
SafeWaitHandle safeHandle = commandWakeEvent.SafeWaitHandle;
IntPtr[] handles = [safeHandle.DangerousGetHandle()];
uint result = MsgWaitForMultipleObjectsEx(
(uint)handles.Length,
handles,
timeoutMilliseconds,
QsAllInput,
MwmoInputAvailable);
if (result == MsgWaitFailed)
{
throw new InvalidOperationException(
"The worker STA message pump failed while waiting for command work or Windows messages.");
}
}
public int PumpPendingMessages()
{
int pumpedMessages = 0;
while (PeekMessage(out NativeMessage message, IntPtr.Zero, 0, 0, PmRemove))
{
TranslateMessage(ref message);
DispatchMessage(ref message);
pumpedMessages++;
}
return pumpedMessages;
}
private static uint ToTimeoutMilliseconds(TimeSpan timeout)
{
if (timeout == Timeout.InfiniteTimeSpan)
{
return Infinite;
}
if (timeout <= TimeSpan.Zero)
{
return 0;
}
return timeout.TotalMilliseconds >= uint.MaxValue
? uint.MaxValue - 1
: (uint)Math.Ceiling(timeout.TotalMilliseconds);
}
[DllImport("user32.dll", SetLastError = true)]
private static extern uint MsgWaitForMultipleObjectsEx(
uint count,
IntPtr[] handles,
uint milliseconds,
uint wakeMask,
uint flags);
[DllImport("user32.dll", SetLastError = true)]
private static extern bool PeekMessage(
out NativeMessage message,
IntPtr windowHandle,
uint messageFilterMin,
uint messageFilterMax,
uint removeMessage);
[DllImport("user32.dll")]
private static extern bool TranslateMessage(ref NativeMessage message);
[DllImport("user32.dll")]
private static extern IntPtr DispatchMessage(ref NativeMessage message);
[StructLayout(LayoutKind.Sequential)]
private struct NativeMessage
{
public IntPtr WindowHandle;
public uint Message;
public UIntPtr WParam;
public IntPtr LParam;
public uint Time;
public NativePoint Point;
}
[StructLayout(LayoutKind.Sequential)]
private struct NativePoint
{
public int X;
public int Y;
}
}
+267
View File
@@ -0,0 +1,267 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
namespace MxGateway.Worker.Sta;
public sealed class StaRuntime : IDisposable
{
private readonly IStaComApartmentInitializer comApartmentInitializer;
private readonly StaMessagePump messagePump;
private readonly ConcurrentQueue<IStaWorkItem> commandQueue = new();
private readonly AutoResetEvent commandWakeEvent = new(false);
private readonly ManualResetEventSlim startedEvent = new(false);
private readonly ManualResetEventSlim stoppedEvent = new(false);
private readonly object gate = new();
private readonly Thread staThread;
private readonly TimeSpan idlePumpInterval;
private bool disposed;
private bool startRequested;
private bool shutdownRequested;
private Exception? startupException;
private long lastActivityUtcTicks;
private bool comInitialized;
public StaRuntime()
: this(new StaComApartmentInitializer(), new StaMessagePump(), TimeSpan.FromMilliseconds(50))
{
}
public StaRuntime(
IStaComApartmentInitializer comApartmentInitializer,
StaMessagePump messagePump,
TimeSpan idlePumpInterval)
{
this.comApartmentInitializer = comApartmentInitializer
?? throw new ArgumentNullException(nameof(comApartmentInitializer));
this.messagePump = messagePump ?? throw new ArgumentNullException(nameof(messagePump));
if (idlePumpInterval <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(
nameof(idlePumpInterval),
"The idle pump interval must be greater than zero.");
}
this.idlePumpInterval = idlePumpInterval;
lastActivityUtcTicks = DateTimeOffset.UtcNow.UtcTicks;
staThread = new Thread(ThreadMain)
{
IsBackground = true,
Name = "MxGateway.Worker.STA"
};
staThread.SetApartmentState(ApartmentState.STA);
}
public int? StaThreadId { get; private set; }
public DateTimeOffset LastActivityUtc =>
new(new DateTime(Volatile.Read(ref lastActivityUtcTicks), DateTimeKind.Utc));
public bool IsRunning => startedEvent.IsSet && !stoppedEvent.IsSet;
public void Start()
{
ThrowIfDisposed();
lock (gate)
{
if (shutdownRequested)
{
throw new InvalidOperationException("The worker STA runtime is shutting down.");
}
if (!startRequested)
{
startRequested = true;
staThread.Start();
}
}
startedEvent.Wait();
if (startupException is not null)
{
throw new InvalidOperationException(
"The worker STA runtime failed to initialize.",
startupException);
}
}
public Task InvokeAsync(Action command, CancellationToken cancellationToken = default)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
return InvokeAsync(
() =>
{
command();
return true;
},
cancellationToken);
}
public Task<T> InvokeAsync<T>(Func<T> command, CancellationToken cancellationToken = default)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
ThrowIfDisposed();
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled<T>(cancellationToken);
}
StaWorkItem<T> workItem = new(command, cancellationToken);
lock (gate)
{
if (shutdownRequested)
{
return Task.FromException<T>(
new InvalidOperationException("The worker STA runtime is shutting down."));
}
commandQueue.Enqueue(workItem);
}
commandWakeEvent.Set();
return workItem.Task;
}
public bool Shutdown(TimeSpan timeout)
{
if (timeout < TimeSpan.Zero && timeout != Timeout.InfiniteTimeSpan)
{
throw new ArgumentOutOfRangeException(nameof(timeout));
}
lock (gate)
{
shutdownRequested = true;
}
commandWakeEvent.Set();
if (!startedEvent.IsSet && !staThread.IsAlive)
{
CancelQueuedCommands();
stoppedEvent.Set();
return true;
}
bool stopped = stoppedEvent.Wait(timeout);
if (stopped)
{
CancelQueuedCommands();
}
return stopped;
}
public void Dispose()
{
if (disposed)
{
return;
}
bool stopped = Shutdown(TimeSpan.FromSeconds(5));
if (stopped)
{
commandWakeEvent.Dispose();
startedEvent.Dispose();
stoppedEvent.Dispose();
}
disposed = true;
}
private void ThreadMain()
{
try
{
StaThreadId = Thread.CurrentThread.ManagedThreadId;
comApartmentInitializer.Initialize();
comInitialized = true;
MarkActivity();
startedEvent.Set();
while (!IsShutdownRequested())
{
ProcessQueuedCommands();
messagePump.WaitForWorkOrMessages(commandWakeEvent, idlePumpInterval);
messagePump.PumpPendingMessages();
MarkActivity();
}
ProcessQueuedCommands();
}
catch (Exception exception)
{
startupException = exception;
startedEvent.Set();
}
finally
{
CancelQueuedCommands();
try
{
if (comInitialized)
{
comApartmentInitializer.Uninitialize();
}
}
finally
{
MarkActivity();
stoppedEvent.Set();
}
}
}
private void ProcessQueuedCommands()
{
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
{
MarkActivity();
workItem.Execute();
MarkActivity();
}
}
private void CancelQueuedCommands()
{
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
{
workItem.CancelBeforeExecution();
}
}
private bool IsShutdownRequested()
{
lock (gate)
{
return shutdownRequested;
}
}
private void MarkActivity()
{
Volatile.Write(ref lastActivityUtcTicks, DateTimeOffset.UtcNow.UtcTicks);
}
private void ThrowIfDisposed()
{
if (disposed)
{
throw new ObjectDisposedException(nameof(StaRuntime));
}
}
}
+71
View File
@@ -0,0 +1,71 @@
using System;
using System.Threading;
using System.Threading.Tasks;
namespace MxGateway.Worker.Sta;
internal sealed class StaWorkItem<T> : IStaWorkItem
{
private readonly Func<T> command;
private readonly CancellationToken cancellationToken;
private readonly CancellationTokenRegistration cancellationRegistration;
private int started;
public StaWorkItem(Func<T> command, CancellationToken cancellationToken)
{
this.command = command ?? throw new ArgumentNullException(nameof(command));
this.cancellationToken = cancellationToken;
Completion = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(
() =>
{
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
{
Completion.TrySetCanceled(cancellationToken);
}
});
}
}
public Task<T> Task => Completion.Task;
private TaskCompletionSource<T> Completion { get; }
public void CancelBeforeExecution()
{
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
{
Completion.TrySetCanceled(cancellationToken);
cancellationRegistration.Dispose();
}
}
public void Execute()
{
if (Interlocked.CompareExchange(ref started, 1, 0) != 0)
{
cancellationRegistration.Dispose();
return;
}
cancellationRegistration.Dispose();
if (cancellationToken.IsCancellationRequested)
{
Completion.TrySetCanceled(cancellationToken);
return;
}
try
{
Completion.TrySetResult(command());
}
catch (Exception exception)
{
Completion.TrySetException(exception);
}
}
}