using System;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;

namespace MxGateway.Worker.Tests.MxAccess;

/// <summary>
/// Tests how <see cref="MxAccessStaSession"/> dispatches Register/Unregister commands to the
/// MXAccess object on the STA thread, and how it reports and tracks the resulting server handles.
/// </summary>
public sealed class MxAccessCommandExecutorTests
{
    // Session identifier stamped on every StaCommand built by these tests.
    private const string SessionId = "session-1";

    // Worker process id passed to StartAsync; no test asserts on this value.
    private const int WorkerProcessId = 1234;

    [Fact]
    public async Task DispatchAsync_Register_CallsMxAccessOnStaAndPreservesServerHandle()
    {
        FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 42));
        using StaRuntime runtime = CreateRuntime();
        using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
        await session.StartAsync(workerProcessId: WorkerProcessId);

        MxCommandReply reply = await session.DispatchAsync(CreateRegisterCommand("correlation-1", "client-a"));

        Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
        Assert.True(reply.HasHresult);
        Assert.Equal(0, reply.Hresult);
        Assert.Equal(42, reply.Register.ServerHandle);
        Assert.Equal(MxDataType.Integer, reply.ReturnValue.DataType);
        Assert.Equal(42, reply.ReturnValue.Int32Value);

        // The fake records the managed thread it ran on; it must be the runtime's STA thread.
        Assert.Equal(runtime.StaThreadId, factory.FakeComObject.RegisterThreadId);
        Assert.Equal("client-a", factory.FakeComObject.RegisteredClientName);

        RegisteredServerHandle registeredServerHandle = Assert.Single(
            await session.GetRegisteredServerHandlesAsync());
        Assert.Equal(42, registeredServerHandle.ServerHandle);
        Assert.Equal("client-a", registeredServerHandle.ClientName);
    }

    [Fact]
    public async Task DispatchAsync_Unregister_CallsMxAccessOnStaAndRemovesTrackedServerHandle()
    {
        FakeMxAccessComObject fakeComObject = new(registerHandle: 43);
        FakeMxAccessComObjectFactory factory = new(fakeComObject);
        using StaRuntime runtime = CreateRuntime();
        using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
        await session.StartAsync(workerProcessId: WorkerProcessId);

        // Precondition: the handle must actually be registered and tracked. Without this,
        // the final Assert.Empty would pass vacuously if registration had failed.
        MxCommandReply registerReply = await session.DispatchAsync(CreateRegisterCommand("register", "client-a"));
        Assert.Equal(ProtocolStatusCode.Ok, registerReply.ProtocolStatus.Code);
        RegisteredServerHandle trackedBeforeUnregister = Assert.Single(
            await session.GetRegisteredServerHandlesAsync());
        Assert.Equal(43, trackedBeforeUnregister.ServerHandle);

        MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("unregister", 43));

        Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
        Assert.Equal(43, fakeComObject.UnregisteredServerHandle);
        Assert.Equal(runtime.StaThreadId, fakeComObject.UnregisterThreadId);
        Assert.Empty(await session.GetRegisteredServerHandlesAsync());
    }

    [Fact]
    public async Task DispatchAsync_UnregisterWhenMxAccessThrows_PreservesHResultAndDoesNotRewriteFailure()
    {
        // E_INVALIDARG; the reply must carry this exact HRESULT rather than a rewritten one.
        const int hresult = unchecked((int)0x80070057);
        FakeMxAccessComObject fakeComObject = new(
            registerHandle: 44,
            unregisterException: new COMException("Invalid handle.", hresult));
        FakeMxAccessComObjectFactory factory = new(fakeComObject);
        using StaRuntime runtime = CreateRuntime();
        using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
        await session.StartAsync(workerProcessId: WorkerProcessId);

        // Precondition: registration succeeds, so the later failure is attributable to Unregister.
        MxCommandReply registerReply = await session.DispatchAsync(
            CreateRegisterCommand("register-before-failure", "client-a"));
        Assert.Equal(ProtocolStatusCode.Ok, registerReply.ProtocolStatus.Code);

        MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("invalid-unregister", 44));

        Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
        Assert.True(reply.HasHresult);
        Assert.Equal(hresult, reply.Hresult);
        Assert.Contains("0x80070057", reply.DiagnosticMessage);
        Assert.Equal(44, fakeComObject.UnregisteredServerHandle);

        // A failed Unregister must leave the handle tracked.
        RegisteredServerHandle registeredServerHandle = Assert.Single(
            await session.GetRegisteredServerHandlesAsync());
        Assert.Equal(44, registeredServerHandle.ServerHandle);
    }

    [Fact]
    public async Task DispatchAsync_RegisterWithoutPayload_ReturnsInvalidRequest()
    {
        FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 45));
        using StaRuntime runtime = CreateRuntime();
        using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
        await session.StartAsync(workerProcessId: WorkerProcessId);

        // Kind says Register but the Register payload is deliberately left unset.
        MxCommandReply reply = await session.DispatchAsync(new StaCommand(
            SessionId,
            "missing-payload",
            new MxCommand
            {
                Kind = MxCommandKind.Register,
            }));

        Assert.Equal(ProtocolStatusCode.InvalidRequest, reply.ProtocolStatus.Code);

        // The malformed request must be rejected before it reaches the MXAccess object.
        Assert.Null(factory.FakeComObject.RegisteredClientName);
    }

    /// <summary>Builds a Register command for <paramref name="clientName"/>.</summary>
    private static StaCommand CreateRegisterCommand(string correlationId, string clientName)
    {
        return new StaCommand(
            SessionId,
            correlationId,
            new MxCommand
            {
                Kind = MxCommandKind.Register,
                Register = new RegisterCommand
                {
                    ClientName = clientName,
                },
            });
    }

    /// <summary>Builds an Unregister command targeting <paramref name="serverHandle"/>.</summary>
    private static StaCommand CreateUnregisterCommand(string correlationId, int serverHandle)
    {
        return new StaCommand(
            SessionId,
            correlationId,
            new MxCommand
            {
                Kind = MxCommandKind.Unregister,
                Unregister = new UnregisterCommand
                {
                    ServerHandle = serverHandle,
                },
            });
    }

    /// <summary>
    /// Creates an <see cref="StaRuntime"/> whose COM apartment initializer does nothing, so the tests
    /// do not touch real COM initialization.
    /// </summary>
    private static StaRuntime CreateRuntime()
    {
        // NOTE(review): the meaning of the 25 ms TimeSpan is defined by StaRuntime
        // (presumably a pump/poll interval) — confirm against its constructor.
        return new StaRuntime(
            new NoopComApartmentInitializer(),
            new StaMessagePump(),
            TimeSpan.FromMilliseconds(25));
    }

    /// <summary>
    /// Stand-in for the MXAccess COM object. It records the arguments and the managed thread id of each
    /// Register/Unregister call, and can optionally throw a supplied exception from Unregister.
    /// </summary>
    private sealed class FakeMxAccessComObject
    {
        private readonly int registerHandle;
        private readonly Exception? unregisterException;

        /// <param name="registerHandle">Value returned from every <see cref="Register"/> call.</param>
        /// <param name="unregisterException">If non-null, thrown from <see cref="Unregister"/> after recording the call.</param>
        public FakeMxAccessComObject(int registerHandle, Exception? unregisterException = null)
        {
            this.registerHandle = registerHandle;
            this.unregisterException = unregisterException;
        }

        public string? RegisteredClientName { get; private set; }

        public int? RegisterThreadId { get; private set; }

        public int? UnregisteredServerHandle { get; private set; }

        public int? UnregisterThreadId { get; private set; }

        public int Register(string clientName)
        {
            RegisteredClientName = clientName;
            RegisterThreadId = Environment.CurrentManagedThreadId;
            return registerHandle;
        }

        public void Unregister(int serverHandle)
        {
            // Record first so tests can verify the call was attempted even when it throws.
            UnregisteredServerHandle = serverHandle;
            UnregisterThreadId = Environment.CurrentManagedThreadId;
            if (unregisterException is not null)
            {
                throw unregisterException;
            }
        }
    }

    /// <summary>Factory that always hands out the same pre-built <see cref="FakeMxAccessComObject"/>.</summary>
    private sealed class FakeMxAccessComObjectFactory : IMxAccessComObjectFactory
    {
        public FakeMxAccessComObjectFactory(FakeMxAccessComObject fakeComObject)
        {
            FakeComObject = fakeComObject;
        }

        public FakeMxAccessComObject FakeComObject { get; }

        public object Create()
        {
            return FakeComObject;
        }
    }

    /// <summary>Event sink that ignores attach/detach; these tests do not exercise events.</summary>
    private sealed class NoopEventSink : IMxAccessEventSink
    {
        public void Attach(object mxAccessComObject)
        {
        }

        public void Detach()
        {
        }
    }

    /// <summary>Apartment initializer that performs no COM initialization.</summary>
    private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
    {
        public void Initialize()
        {
        }

        public void Uninitialize()
        {
        }
    }
}