using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Sta;

namespace MxGateway.Worker.Tests.Sta;

/// <summary>
/// Tests for <see cref="StaCommandDispatcher"/>. They cover:
/// <list type="bullet">
/// <item>queue ordering and thread affinity of command execution;</item>
/// <item>failure replies for executor exceptions;</item>
/// <item>cancellation before and after execution starts;</item>
/// <item>shutdown rejection;</item>
/// <item>heartbeat reporting.</item>
/// </list>
/// </summary>
public sealed class StaCommandDispatcherTests
{
    [Fact]
    public async Task DispatchAsync_ExecutesCommandsOnStaInQueueOrder()
    {
        using StaRuntime runtime = CreateRuntime();
        runtime.Start();
        RecordingCommandExecutor executor = new();
        StaCommandDispatcher dispatcher = new(runtime, executor);

        Task<MxCommandReply> first = dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
        Task<MxCommandReply> second = dispatcher.DispatchAsync(CreateCommand("correlation-2", MxCommandKind.AddItem));
        MxCommandReply[] replies = await Task.WhenAll(first, second);

        Assert.Equal(new[] { "correlation-1", "correlation-2" }, executor.CorrelationIds);
        Assert.All(executor.ThreadIds, threadId => Assert.Equal(runtime.StaThreadId, threadId));
        Assert.Equal("correlation-1", replies[0].CorrelationId);
        Assert.Equal("correlation-2", replies[1].CorrelationId);
        // Every reply must succeed, not only the first one.
        Assert.All(replies, reply => Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code));
    }

    [Fact]
    public async Task DispatchAsync_WhenExecutorThrows_ReturnsFailureReplyWithHResult()
    {
        // E_INVALIDARG, carried by the COMException the executor throws.
        const int InvalidArgHResult = unchecked((int)0x80070057);

        using StaRuntime runtime = CreateRuntime();
        runtime.Start();
        StaCommandDispatcher dispatcher = new(
            runtime,
            new ThrowingCommandExecutor(new COMException("provider detail", InvalidArgHResult)));

        MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));

        Assert.Equal("session-1", reply.SessionId);
        Assert.Equal("correlation-1", reply.CorrelationId);
        Assert.Equal(MxCommandKind.Register, reply.Kind);
        Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
        Assert.Equal(InvalidArgHResult, reply.Hresult);
        Assert.Contains("0x80070057", reply.DiagnosticMessage);
        // The raw provider message must not leak into the diagnostic text.
        Assert.DoesNotContain("provider detail", reply.DiagnosticMessage);
    }

    [Fact]
    public async Task DispatchAsync_WhenCanceledBeforeExecution_ReturnsCanceledReplyWithoutExecuting()
    {
        using StaRuntime runtime = CreateRuntime();
        runtime.Start();
        // Declared after the runtime so it is disposed (and therefore released) first.
        using BlockingCommandExecutor executor = new();
        StaCommandDispatcher dispatcher = new(runtime, executor);

        // Occupy the STA so the next command stays queued.
        Task<MxCommandReply> blocked = dispatcher.DispatchAsync(CreateCommand("blocked", MxCommandKind.Register));
        Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));

        using CancellationTokenSource cancellation = new();
        Task<MxCommandReply> canceled = dispatcher.DispatchAsync(
            CreateCommand("canceled", MxCommandKind.AddItem, cancellation.Token));
        cancellation.Cancel();
        executor.Release();

        MxCommandReply canceledReply = await canceled;
        await blocked;

        Assert.Equal(ProtocolStatusCode.Canceled, canceledReply.ProtocolStatus.Code);
        Assert.DoesNotContain("canceled", executor.CorrelationIds);
    }

    [Fact]
    public async Task DispatchAsync_WhenCanceledAfterExecutionStarts_StillReturnsLateReply()
    {
        using StaRuntime runtime = CreateRuntime();
        runtime.Start();
        using BlockingCommandExecutor executor = new();
        StaCommandDispatcher dispatcher = new(runtime, executor);
        using CancellationTokenSource cancellation = new();

        Task<MxCommandReply> replyTask = dispatcher.DispatchAsync(
            CreateCommand("late-reply", MxCommandKind.Register, cancellation.Token));
        Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
        // Cancel only once the executor is already running the command.
        cancellation.Cancel();
        executor.Release();

        MxCommandReply reply = await replyTask;

        Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
        Assert.Contains("late-reply", executor.CorrelationIds);
    }

    [Fact]
    public async Task DispatchAsync_WhenShutdownRequested_RejectsNewCommands()
    {
        using StaRuntime runtime = CreateRuntime();
        runtime.Start();
        StaCommandDispatcher dispatcher = new(runtime, new RecordingCommandExecutor());
        dispatcher.RequestShutdown();

        MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));

        Assert.Equal(ProtocolStatusCode.WorkerUnavailable, reply.ProtocolStatus.Code);
        Assert.Equal("correlation-1", reply.CorrelationId);
    }

    [Fact]
    public async Task PopulateHeartbeat_ReportsCurrentCorrelationAndPendingCount()
    {
        using StaRuntime runtime = CreateRuntime();
        runtime.Start();
        using BlockingCommandExecutor executor = new();
        StaCommandDispatcher dispatcher = new(runtime, executor);

        Task<MxCommandReply> current = dispatcher.DispatchAsync(CreateCommand("current", MxCommandKind.Register));
        Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
        Task<MxCommandReply> pending = dispatcher.DispatchAsync(CreateCommand("pending", MxCommandKind.AddItem));

        WorkerHeartbeat heartbeat = new();
        dispatcher.PopulateHeartbeat(heartbeat);

        Assert.Equal("current", heartbeat.CurrentCommandCorrelationId);
        Assert.Equal(1u, heartbeat.PendingCommandCount);

        executor.Release();
        await Task.WhenAll(current, pending);
    }

    /// <summary>
    /// Builds a command for session "session-1". Its payload is a ping whose message echoes
    /// <paramref name="correlationId"/>.
    /// </summary>
    /// <param name="correlationId">Correlation id assigned to the command.</param>
    /// <param name="kind">Command kind placed on the <see cref="MxCommand"/>.</param>
    /// <param name="cancellationToken">Token forwarded to the <see cref="StaCommand"/>.</param>
    private static StaCommand CreateCommand(
        string correlationId,
        MxCommandKind kind,
        CancellationToken cancellationToken = default)
    {
        return new StaCommand(
            "session-1",
            correlationId,
            new MxCommand
            {
                Kind = kind,
                Ping = new PingCommand
                {
                    Message = correlationId,
                },
            },
            cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Creates an STA runtime with a no-op COM apartment initializer and a real message pump.
    /// </summary>
    private static StaRuntime CreateRuntime()
    {
        // NOTE(review): the 25 ms value is presumably the pump/poll interval — confirm against StaRuntime.
        return new StaRuntime(
            new NoopComApartmentInitializer(),
            new StaMessagePump(),
            TimeSpan.FromMilliseconds(25));
    }

    /// <summary>Builds the successful reply returned by the fake executors.</summary>
    private static MxCommandReply CreateOkReply()
    {
        return new MxCommandReply
        {
            ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.Ok,
                Message = "OK",
            },
        };
    }

    /// <summary>
    /// Executor that records the correlation id and managed thread id of every command it
    /// runs, then returns an OK reply.
    /// </summary>
    private sealed class RecordingCommandExecutor : IStaCommandExecutor
    {
        private readonly object gate = new();
        private readonly List<string> correlationIds = new();
        private readonly List<int> threadIds = new();

        /// <summary>Snapshot of recorded correlation ids, in execution order.</summary>
        public IReadOnlyList<string> CorrelationIds
        {
            get
            {
                lock (gate)
                {
                    return correlationIds.ToArray();
                }
            }
        }

        /// <summary>Snapshot of the managed thread ids the commands ran on.</summary>
        public IReadOnlyList<int> ThreadIds
        {
            get
            {
                lock (gate)
                {
                    return threadIds.ToArray();
                }
            }
        }

        public MxCommandReply Execute(StaCommand command)
        {
            lock (gate)
            {
                correlationIds.Add(command.CorrelationId);
                threadIds.Add(Environment.CurrentManagedThreadId);
            }

            return CreateOkReply();
        }
    }

    /// <summary>
    /// Executor that records each command, sets <see cref="Started"/>, and then blocks the
    /// executing thread until <see cref="Release"/> is called or 5 seconds elapse.
    /// </summary>
    /// <remarks>
    /// Disposing calls <see cref="Release"/>. Tests declare it with <c>using</c> after the runtime,
    /// so it is released before the runtime is disposed, even when an assertion throws early.
    /// </remarks>
    private sealed class BlockingCommandExecutor : IStaCommandExecutor, IDisposable
    {
        private readonly ManualResetEventSlim release = new(false);
        private readonly object gate = new();
        private readonly List<string> correlationIds = new();

        /// <summary>Set once any command has entered <see cref="Execute"/>.</summary>
        public ManualResetEventSlim Started { get; } = new(false);

        /// <summary>Snapshot of recorded correlation ids, in execution order.</summary>
        public IReadOnlyList<string> CorrelationIds
        {
            get
            {
                lock (gate)
                {
                    return correlationIds.ToArray();
                }
            }
        }

        public MxCommandReply Execute(StaCommand command)
        {
            lock (gate)
            {
                correlationIds.Add(command.CorrelationId);
            }

            Started.Set();
            // Bounded wait so a test that never releases cannot hang the thread indefinitely.
            release.Wait(TimeSpan.FromSeconds(5));
            return CreateOkReply();
        }

        /// <summary>Unblocks the current and any future <see cref="Execute"/> calls.</summary>
        public void Release()
        {
            release.Set();
        }

        // Intentionally does not dispose the events. The executing thread may still be returning
        // from Wait, and ManualResetEventSlim.Dispose must not run concurrently with other members.
        public void Dispose()
        {
            Release();
        }
    }

    /// <summary>Executor that always throws the supplied exception.</summary>
    private sealed class ThrowingCommandExecutor : IStaCommandExecutor
    {
        private readonly Exception exception;

        public ThrowingCommandExecutor(Exception exception)
        {
            this.exception = exception;
        }

        public MxCommandReply Execute(StaCommand command)
        {
            throw exception;
        }
    }

    /// <summary>COM apartment initializer that does nothing, so tests need no real COM setup.</summary>
    private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
    {
        public void Initialize()
        {
        }

        public void Uninitialize()
        {
        }
    }
}