diff --git a/docs/mxaccess-worker-instance-design.md b/docs/mxaccess-worker-instance-design.md index 0682494..60eb2d0 100644 --- a/docs/mxaccess-worker-instance-design.md +++ b/docs/mxaccess-worker-instance-design.md @@ -250,6 +250,17 @@ The loop should update a heartbeat timestamp after: - finishing a command, - processing an MXAccess event. +`StaRuntime` implements this runtime boundary in the worker. It starts one +background thread named `MxGateway.Worker.STA`, sets it to `ApartmentState.STA`, +initializes COM through `StaComApartmentInitializer`, and runs +`StaMessagePump`. Commands are scheduled through `InvokeAsync`; the command +queue signals an `AutoResetEvent` so `MsgWaitForMultipleObjectsEx` can wake the +STA without busy-waiting. `LastActivityUtc` records pump, command, startup, and +shutdown activity so the future heartbeat/watchdog can report whether the STA +is still responsive. Shutdown marks the runtime as closing, wakes the pump, +rejects new commands, cancels queued work, uninitializes COM on the STA, and +waits for the thread to exit. + ## COM Creation The MXAccess analysis source at `C:\Users\dohertj2\Desktop\mxaccess` identifies diff --git a/src/MxGateway.Worker.Tests/Sta/StaRuntimeTests.cs b/src/MxGateway.Worker.Tests/Sta/StaRuntimeTests.cs new file mode 100644 index 0000000..43a1014 --- /dev/null +++ b/src/MxGateway.Worker.Tests/Sta/StaRuntimeTests.cs @@ -0,0 +1,152 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using MxGateway.Worker.Sta; + +namespace MxGateway.Worker.Tests.Sta; + +public sealed class StaRuntimeTests +{ + [Fact] + public async Task InvokeAsync_ExecutesCommandOnStaThread() + { + RecordingComApartmentInitializer initializer = new(); + using StaRuntime runtime = CreateRuntime(initializer); + + runtime.Start(); + + StaCommandObservation observation = await runtime.InvokeAsync( + () => new StaCommandObservation( + Thread.CurrentThread.ManagedThreadId, + Thread.CurrentThread.GetApartmentState())); + + Assert.Equal(runtime.StaThreadId, observation.ThreadId); + Assert.Equal(initializer.InitializeThreadId, observation.ThreadId); + Assert.Equal(ApartmentState.STA, observation.ApartmentState); + } + + [Fact] + public async Task InvokeAsync_WakesIdlePumpForQueuedCommand() + { + RecordingComApartmentInitializer initializer = new(); + using StaRuntime runtime = new( + initializer, + new StaMessagePump(), + TimeSpan.FromSeconds(30)); + runtime.Start(); + Stopwatch stopwatch = Stopwatch.StartNew(); + + int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId); + + stopwatch.Stop(); + Assert.Equal(runtime.StaThreadId, threadId); + Assert.True( + stopwatch.Elapsed < TimeSpan.FromSeconds(2), + $"Command took {stopwatch.Elapsed} to execute, so the command wake event did not wake the STA promptly."); + } + + [Fact] + public void Shutdown_StopsThreadAndUninitializesComApartment() + { + RecordingComApartmentInitializer initializer = new(); + using StaRuntime runtime = CreateRuntime(initializer); + runtime.Start(); + + bool stopped = runtime.Shutdown(TimeSpan.FromSeconds(2)); + + Assert.True(stopped); + Assert.False(runtime.IsRunning); + Assert.Equal(1, initializer.InitializeCount); + Assert.Equal(1, initializer.UninitializeCount); + Assert.Equal(initializer.InitializeThreadId, initializer.UninitializeThreadId); + } + + [Fact] + public void LastActivityUtc_UpdatesWhilePumpIsIdle() + { + RecordingComApartmentInitializer initializer = new(); + using StaRuntime runtime = CreateRuntime(initializer); + runtime.Start(); + DateTimeOffset firstActivity = runtime.LastActivityUtc; + + bool updated = SpinWait.SpinUntil( + () => runtime.LastActivityUtc > firstActivity, + TimeSpan.FromSeconds(2)); + + Assert.True(updated); + } + + [Fact] + public async Task InvokeAsync_CommandException_FaultsReturnedTaskWithoutStoppingRuntime() + { + RecordingComApartmentInitializer initializer = new(); + using StaRuntime runtime = CreateRuntime(initializer); + runtime.Start(); + + InvalidOperationException exception = await Assert.ThrowsAsync( + () => runtime.InvokeAsync(() => throw new InvalidOperationException("command failed"))); + + int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId); + Assert.Equal("command failed", exception.Message); + Assert.Equal(runtime.StaThreadId, threadId); + } + + [Fact] + public async Task InvokeAsync_AfterShutdown_ReturnsFaultedTask() + { + RecordingComApartmentInitializer initializer = new(); + using StaRuntime runtime = CreateRuntime(initializer); + runtime.Start(); + runtime.Shutdown(TimeSpan.FromSeconds(2)); + + InvalidOperationException exception = await Assert.ThrowsAsync( + () => runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId)); + + Assert.Contains("shutting down", exception.Message); + } + + private static StaRuntime CreateRuntime(RecordingComApartmentInitializer initializer) + { + return new StaRuntime( + initializer, + new StaMessagePump(), + TimeSpan.FromMilliseconds(25)); + } + + private sealed class StaCommandObservation + { + public StaCommandObservation(int threadId, ApartmentState apartmentState) + { + ThreadId = threadId; + ApartmentState = apartmentState; + } + + public int ThreadId { get; } + + public ApartmentState ApartmentState { get; } + } + + private sealed class RecordingComApartmentInitializer : IStaComApartmentInitializer + { + public int InitializeCount { get; private set; } + + public int UninitializeCount { get; private set; } + + public int? InitializeThreadId { get; private set; } + + public int? UninitializeThreadId { get; private set; } + + public void Initialize() + { + InitializeCount++; + InitializeThreadId = Thread.CurrentThread.ManagedThreadId; + } + + public void Uninitialize() + { + UninitializeCount++; + UninitializeThreadId = Thread.CurrentThread.ManagedThreadId; + } + } +} diff --git a/src/MxGateway.Worker/Sta/IStaComApartmentInitializer.cs b/src/MxGateway.Worker/Sta/IStaComApartmentInitializer.cs new file mode 100644 index 0000000..782a8eb --- /dev/null +++ b/src/MxGateway.Worker/Sta/IStaComApartmentInitializer.cs @@ -0,0 +1,8 @@ +namespace MxGateway.Worker.Sta; + +public interface IStaComApartmentInitializer +{ + void Initialize(); + + void Uninitialize(); +} diff --git a/src/MxGateway.Worker/Sta/IStaWorkItem.cs b/src/MxGateway.Worker/Sta/IStaWorkItem.cs new file mode 100644 index 0000000..735385a --- /dev/null +++ b/src/MxGateway.Worker/Sta/IStaWorkItem.cs @@ -0,0 +1,8 @@ +namespace MxGateway.Worker.Sta; + +internal interface IStaWorkItem +{ + void CancelBeforeExecution(); + + void Execute(); +} diff --git a/src/MxGateway.Worker/Sta/StaComApartmentInitializer.cs b/src/MxGateway.Worker/Sta/StaComApartmentInitializer.cs new file mode 100644 index 0000000..2f6b2e0 --- /dev/null +++ b/src/MxGateway.Worker/Sta/StaComApartmentInitializer.cs @@ -0,0 +1,31 @@ +using System; +using System.Runtime.InteropServices; + +namespace MxGateway.Worker.Sta; + +public sealed class StaComApartmentInitializer : IStaComApartmentInitializer +{ + private const uint CoInitializeApartmentThreaded = 0x2; + private const int SOk = 0; + private const int SFalse = 1; + + public void Initialize() + { + int hresult = CoInitializeEx(IntPtr.Zero, CoInitializeApartmentThreaded); + if (hresult != SOk && hresult != SFalse) + { + throw new COMException("Failed to initialize the worker STA COM apartment.", hresult); + } + } + + public void Uninitialize() + { + CoUninitialize(); + } + + [DllImport("ole32.dll")] + private static extern int CoInitializeEx(IntPtr reserved, uint coInit); + + [DllImport("ole32.dll")] + private static extern void CoUninitialize(); +} diff --git a/src/MxGateway.Worker/Sta/StaMessagePump.cs b/src/MxGateway.Worker/Sta/StaMessagePump.cs new file mode 100644 index 0000000..e0a0f21 --- /dev/null +++ b/src/MxGateway.Worker/Sta/StaMessagePump.cs @@ -0,0 +1,111 @@ +using System; +using System.Runtime.InteropServices; +using System.Threading; +using Microsoft.Win32.SafeHandles; + +namespace MxGateway.Worker.Sta; + +public sealed class StaMessagePump +{ + private const uint Infinite = 0xFFFFFFFF; + private const uint MsgWaitFailed = 0xFFFFFFFF; + private const uint MwmoInputAvailable = 0x0004; + private const uint PmRemove = 0x0001; + private const uint QsAllInput = 0x04FF; + + public void WaitForWorkOrMessages(WaitHandle commandWakeEvent, TimeSpan timeout) + { + if (commandWakeEvent is null) + { + throw new ArgumentNullException(nameof(commandWakeEvent)); + } + + uint timeoutMilliseconds = ToTimeoutMilliseconds(timeout); + + SafeWaitHandle safeHandle = commandWakeEvent.SafeWaitHandle; + IntPtr[] handles = [safeHandle.DangerousGetHandle()]; + uint result = MsgWaitForMultipleObjectsEx( + (uint)handles.Length, + handles, + timeoutMilliseconds, + QsAllInput, + MwmoInputAvailable); + + if (result == MsgWaitFailed) + { + throw new InvalidOperationException( + "The worker STA message pump failed while waiting for command work or Windows messages."); + } + } + + public int PumpPendingMessages() + { + int pumpedMessages = 0; + + while (PeekMessage(out NativeMessage message, IntPtr.Zero, 0, 0, PmRemove)) + { + TranslateMessage(ref message); + DispatchMessage(ref message); + pumpedMessages++; + } + + return pumpedMessages; + } + + private static uint ToTimeoutMilliseconds(TimeSpan timeout) + { + if (timeout == Timeout.InfiniteTimeSpan) + { + return Infinite; + } + + if (timeout <= TimeSpan.Zero) + { + return 0; + } + + return timeout.TotalMilliseconds >= uint.MaxValue + ? uint.MaxValue - 1 + : (uint)Math.Ceiling(timeout.TotalMilliseconds); + } + + [DllImport("user32.dll", SetLastError = true)] + private static extern uint MsgWaitForMultipleObjectsEx( + uint count, + IntPtr[] handles, + uint milliseconds, + uint wakeMask, + uint flags); + + [DllImport("user32.dll", SetLastError = true)] + private static extern bool PeekMessage( + out NativeMessage message, + IntPtr windowHandle, + uint messageFilterMin, + uint messageFilterMax, + uint removeMessage); + + [DllImport("user32.dll")] + private static extern bool TranslateMessage(ref NativeMessage message); + + [DllImport("user32.dll")] + private static extern IntPtr DispatchMessage(ref NativeMessage message); + + [StructLayout(LayoutKind.Sequential)] + private struct NativeMessage + { + public IntPtr WindowHandle; + public uint Message; + public UIntPtr WParam; + public IntPtr LParam; + public uint Time; + public NativePoint Point; + } + + [StructLayout(LayoutKind.Sequential)] + private struct NativePoint + { + public int X; + public int Y; + } +} diff --git a/src/MxGateway.Worker/Sta/StaRuntime.cs b/src/MxGateway.Worker/Sta/StaRuntime.cs new file mode 100644 index 0000000..7d88401 --- /dev/null +++ b/src/MxGateway.Worker/Sta/StaRuntime.cs @@ -0,0 +1,267 @@ +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; + +namespace MxGateway.Worker.Sta; + +public sealed class StaRuntime : IDisposable +{ + private readonly IStaComApartmentInitializer comApartmentInitializer; + private readonly StaMessagePump messagePump; + private readonly ConcurrentQueue commandQueue = new(); + private readonly AutoResetEvent commandWakeEvent = new(false); + private readonly ManualResetEventSlim startedEvent = new(false); + private readonly ManualResetEventSlim stoppedEvent = new(false); + private readonly object gate = new(); + private readonly Thread staThread; + private readonly TimeSpan idlePumpInterval; + private bool disposed; + private bool startRequested; + private bool shutdownRequested; + private Exception? startupException; + private long lastActivityUtcTicks; + private bool comInitialized; + + public StaRuntime() + : this(new StaComApartmentInitializer(), new StaMessagePump(), TimeSpan.FromMilliseconds(50)) + { + } + + public StaRuntime( + IStaComApartmentInitializer comApartmentInitializer, + StaMessagePump messagePump, + TimeSpan idlePumpInterval) + { + this.comApartmentInitializer = comApartmentInitializer + ?? throw new ArgumentNullException(nameof(comApartmentInitializer)); + this.messagePump = messagePump ?? throw new ArgumentNullException(nameof(messagePump)); + + if (idlePumpInterval <= TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException( + nameof(idlePumpInterval), + "The idle pump interval must be greater than zero."); + } + + this.idlePumpInterval = idlePumpInterval; + lastActivityUtcTicks = DateTimeOffset.UtcNow.UtcTicks; + staThread = new Thread(ThreadMain) + { + IsBackground = true, + Name = "MxGateway.Worker.STA" + }; + staThread.SetApartmentState(ApartmentState.STA); + } + + public int? StaThreadId { get; private set; } + + public DateTimeOffset LastActivityUtc => + new(new DateTime(Volatile.Read(ref lastActivityUtcTicks), DateTimeKind.Utc)); + + public bool IsRunning => startedEvent.IsSet && !stoppedEvent.IsSet; + + public void Start() + { + ThrowIfDisposed(); + + lock (gate) + { + if (shutdownRequested) + { + throw new InvalidOperationException("The worker STA runtime is shutting down."); + } + + if (!startRequested) + { + startRequested = true; + staThread.Start(); + } + } + + startedEvent.Wait(); + if (startupException is not null) + { + throw new InvalidOperationException( + "The worker STA runtime failed to initialize.", + startupException); + } + } + + public Task InvokeAsync(Action command, CancellationToken cancellationToken = default) + { + if (command is null) + { + throw new ArgumentNullException(nameof(command)); + } + + return InvokeAsync( + () => + { + command(); + return true; + }, + cancellationToken); + } + + public Task InvokeAsync(Func command, CancellationToken cancellationToken = default) + { + if (command is null) + { + throw new ArgumentNullException(nameof(command)); + } + + ThrowIfDisposed(); + + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + StaWorkItem workItem = new(command, cancellationToken); + + lock (gate) + { + if (shutdownRequested) + { + return Task.FromException( + new InvalidOperationException("The worker STA runtime is shutting down.")); + } + + commandQueue.Enqueue(workItem); + } + + commandWakeEvent.Set(); + return workItem.Task; + } + + public bool Shutdown(TimeSpan timeout) + { + if (timeout < TimeSpan.Zero && timeout != Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException(nameof(timeout)); + } + + lock (gate) + { + shutdownRequested = true; + } + + commandWakeEvent.Set(); + + if (!startedEvent.IsSet && !staThread.IsAlive) + { + CancelQueuedCommands(); + stoppedEvent.Set(); + return true; + } + + bool stopped = stoppedEvent.Wait(timeout); + if (stopped) + { + CancelQueuedCommands(); + } + + return stopped; + } + + public void Dispose() + { + if (disposed) + { + return; + } + + bool stopped = Shutdown(TimeSpan.FromSeconds(5)); + if (stopped) + { + commandWakeEvent.Dispose(); + startedEvent.Dispose(); + stoppedEvent.Dispose(); + } + + disposed = true; + } + + private void ThreadMain() + { + try + { + StaThreadId = Thread.CurrentThread.ManagedThreadId; + comApartmentInitializer.Initialize(); + comInitialized = true; + MarkActivity(); + startedEvent.Set(); + + while (!IsShutdownRequested()) + { + ProcessQueuedCommands(); + messagePump.WaitForWorkOrMessages(commandWakeEvent, idlePumpInterval); + messagePump.PumpPendingMessages(); + MarkActivity(); + } + + ProcessQueuedCommands(); + } + catch (Exception exception) + { + startupException = exception; + startedEvent.Set(); + } + finally + { + CancelQueuedCommands(); + try + { + if (comInitialized) + { + comApartmentInitializer.Uninitialize(); + } + } + finally + { + MarkActivity(); + stoppedEvent.Set(); + } + } + } + + private void ProcessQueuedCommands() + { + while (commandQueue.TryDequeue(out IStaWorkItem? workItem)) + { + MarkActivity(); + workItem.Execute(); + MarkActivity(); + } + } + + private void CancelQueuedCommands() + { + while (commandQueue.TryDequeue(out IStaWorkItem? workItem)) + { + workItem.CancelBeforeExecution(); + } + } + + private bool IsShutdownRequested() + { + lock (gate) + { + return shutdownRequested; + } + } + + private void MarkActivity() + { + Volatile.Write(ref lastActivityUtcTicks, DateTimeOffset.UtcNow.UtcTicks); + } + + private void ThrowIfDisposed() + { + if (disposed) + { + throw new ObjectDisposedException(nameof(StaRuntime)); + } + } +} diff --git a/src/MxGateway.Worker/Sta/StaWorkItem.cs b/src/MxGateway.Worker/Sta/StaWorkItem.cs new file mode 100644 index 0000000..3a7f8e0 --- /dev/null +++ b/src/MxGateway.Worker/Sta/StaWorkItem.cs @@ -0,0 +1,71 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace MxGateway.Worker.Sta; + +internal sealed class StaWorkItem : IStaWorkItem +{ + private readonly Func command; + private readonly CancellationToken cancellationToken; + private readonly CancellationTokenRegistration cancellationRegistration; + private int started; + + public StaWorkItem(Func command, CancellationToken cancellationToken) + { + this.command = command ?? throw new ArgumentNullException(nameof(command)); + this.cancellationToken = cancellationToken; + Completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + if (cancellationToken.CanBeCanceled) + { + cancellationRegistration = cancellationToken.Register( + () => + { + if (Interlocked.CompareExchange(ref started, 1, 0) == 0) + { + Completion.TrySetCanceled(cancellationToken); + } + }); + } + } + + public Task Task => Completion.Task; + + private TaskCompletionSource Completion { get; } + + public void CancelBeforeExecution() + { + if (Interlocked.CompareExchange(ref started, 1, 0) == 0) + { + Completion.TrySetCanceled(cancellationToken); + cancellationRegistration.Dispose(); + } + } + + public void Execute() + { + if (Interlocked.CompareExchange(ref started, 1, 0) != 0) + { + cancellationRegistration.Dispose(); + return; + } + + cancellationRegistration.Dispose(); + + if (cancellationToken.IsCancellationRequested) + { + Completion.TrySetCanceled(cancellationToken); + return; + } + + try + { + Completion.TrySetResult(command()); + } + catch (Exception exception) + { + Completion.TrySetException(exception); + } + } +}