Implement worker register and unregister

This commit is contained in:
Joseph Doherty
2026-04-26 18:08:45 -04:00
parent 9b3637257c
commit 556c3bfa83
11 changed files with 578 additions and 2 deletions
@@ -0,0 +1,220 @@
using System;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.MxAccess;
public sealed class MxAccessCommandExecutorTests
{
[Fact]
public async Task DispatchAsync_Register_CallsMxAccessOnStaAndPreservesServerHandle()
{
FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 42));
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
MxCommandReply reply = await session.DispatchAsync(CreateRegisterCommand("correlation-1", "client-a"));
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.True(reply.HasHresult);
Assert.Equal(0, reply.Hresult);
Assert.Equal(42, reply.Register.ServerHandle);
Assert.Equal(MxDataType.Integer, reply.ReturnValue.DataType);
Assert.Equal(42, reply.ReturnValue.Int32Value);
Assert.Equal(runtime.StaThreadId, factory.FakeComObject.RegisterThreadId);
Assert.Equal("client-a", factory.FakeComObject.RegisteredClientName);
RegisteredServerHandle registeredServerHandle = Assert.Single(
await session.GetRegisteredServerHandlesAsync());
Assert.Equal(42, registeredServerHandle.ServerHandle);
Assert.Equal("client-a", registeredServerHandle.ClientName);
}
[Fact]
public async Task DispatchAsync_Unregister_CallsMxAccessOnStaAndRemovesTrackedServerHandle()
{
FakeMxAccessComObject fakeComObject = new(registerHandle: 43);
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register", "client-a"));
MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("unregister", 43));
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Equal(43, fakeComObject.UnregisteredServerHandle);
Assert.Equal(runtime.StaThreadId, fakeComObject.UnregisterThreadId);
Assert.Empty(await session.GetRegisteredServerHandlesAsync());
}
[Fact]
public async Task DispatchAsync_UnregisterWhenMxAccessThrows_PreservesHResultAndDoesNotRewriteFailure()
{
const int hresult = unchecked((int)0x80070057);
FakeMxAccessComObject fakeComObject = new(
registerHandle: 44,
unregisterException: new COMException("Invalid handle.", hresult));
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register-before-failure", "client-a"));
MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("invalid-unregister", 44));
Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
Assert.True(reply.HasHresult);
Assert.Equal(hresult, reply.Hresult);
Assert.Contains("0x80070057", reply.DiagnosticMessage);
Assert.Equal(44, fakeComObject.UnregisteredServerHandle);
RegisteredServerHandle registeredServerHandle = Assert.Single(
await session.GetRegisteredServerHandlesAsync());
Assert.Equal(44, registeredServerHandle.ServerHandle);
}
[Fact]
public async Task DispatchAsync_RegisterWithoutPayload_ReturnsInvalidRequest()
{
FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 45));
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
MxCommandReply reply = await session.DispatchAsync(new StaCommand(
"session-1",
"missing-payload",
new MxCommand
{
Kind = MxCommandKind.Register,
}));
Assert.Equal(ProtocolStatusCode.InvalidRequest, reply.ProtocolStatus.Code);
Assert.Null(factory.FakeComObject.RegisteredClientName);
}
private static StaCommand CreateRegisterCommand(
string correlationId,
string clientName)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = MxCommandKind.Register,
Register = new RegisterCommand
{
ClientName = clientName,
},
});
}
private static StaCommand CreateUnregisterCommand(
string correlationId,
int serverHandle)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = MxCommandKind.Unregister,
Unregister = new UnregisterCommand
{
ServerHandle = serverHandle,
},
});
}
private static StaRuntime CreateRuntime()
{
return new StaRuntime(
new NoopComApartmentInitializer(),
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class FakeMxAccessComObject
{
private readonly int registerHandle;
private readonly Exception? unregisterException;
public FakeMxAccessComObject(
int registerHandle,
Exception? unregisterException = null)
{
this.registerHandle = registerHandle;
this.unregisterException = unregisterException;
}
public string? RegisteredClientName { get; private set; }
public int? RegisterThreadId { get; private set; }
public int? UnregisteredServerHandle { get; private set; }
public int? UnregisterThreadId { get; private set; }
public int Register(string clientName)
{
RegisteredClientName = clientName;
RegisterThreadId = Environment.CurrentManagedThreadId;
return registerHandle;
}
public void Unregister(int serverHandle)
{
UnregisteredServerHandle = serverHandle;
UnregisterThreadId = Environment.CurrentManagedThreadId;
if (unregisterException is not null)
{
throw unregisterException;
}
}
}
private sealed class FakeMxAccessComObjectFactory : IMxAccessComObjectFactory
{
public FakeMxAccessComObjectFactory(FakeMxAccessComObject fakeComObject)
{
FakeComObject = fakeComObject;
}
public FakeMxAccessComObject FakeComObject { get; }
public object Create()
{
return FakeComObject;
}
}
private sealed class NoopEventSink : IMxAccessEventSink
{
public void Attach(object mxAccessComObject)
{
}
public void Detach()
{
}
}
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
{
public void Initialize()
{
}
public void Uninitialize()
{
}
}
}
@@ -1,6 +1,8 @@
using System;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.MxAccess;
@@ -21,4 +23,53 @@ public sealed class MxAccessLiveComCreationTests
await session.StartAsync(workerProcessId: 1234);
}
[Fact]
public async Task RegisterAndUnregister_WhenOptedIn_RoundTripsInstalledMxAccessServerHandle()
{
if (!RunLiveMxAccessTests())
{
return;
}
using MxAccessStaSession session = new();
await session.StartAsync(workerProcessId: 1234);
MxCommandReply registerReply = await session.DispatchAsync(new StaCommand(
"session-1",
"live-register",
new MxCommand
{
Kind = MxCommandKind.Register,
Register = new RegisterCommand
{
ClientName = "MxGateway.Worker.Tests",
},
}));
Assert.Equal(ProtocolStatusCode.Ok, registerReply.ProtocolStatus.Code);
Assert.True(registerReply.Register.ServerHandle > 0);
MxCommandReply unregisterReply = await session.DispatchAsync(new StaCommand(
"session-1",
"live-unregister",
new MxCommand
{
Kind = MxCommandKind.Unregister,
Unregister = new UnregisterCommand
{
ServerHandle = registerReply.Register.ServerHandle,
},
}));
Assert.Equal(ProtocolStatusCode.Ok, unregisterReply.ProtocolStatus.Code);
}
private static bool RunLiveMxAccessTests()
{
return string.Equals(
Environment.GetEnvironmentVariable("MXGATEWAY_RUN_LIVE_MXACCESS_TESTS"),
"1",
StringComparison.Ordinal);
}
}
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.MxAccess;
public interface IMxAccessServer
{
int Register(string clientName);
void Unregister(int serverHandle);
}
@@ -0,0 +1,59 @@
using System;
using System.Reflection;
using System.Runtime.ExceptionServices;
using ArchestrA.MxAccess;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessComServer : IMxAccessServer
{
private readonly object mxAccessComObject;
public MxAccessComServer(object mxAccessComObject)
{
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
}
public int Register(string clientName)
{
if (mxAccessComObject is ILMXProxyServer mxAccessServer)
{
return mxAccessServer.Register(clientName);
}
return (int)Invoke(nameof(Register), clientName);
}
public void Unregister(int serverHandle)
{
if (mxAccessComObject is ILMXProxyServer mxAccessServer)
{
mxAccessServer.Unregister(serverHandle);
return;
}
Invoke(nameof(Unregister), serverHandle);
}
private object Invoke(
string methodName,
params object[] arguments)
{
try
{
return mxAccessComObject
.GetType()
.InvokeMember(
methodName,
BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod,
binder: null,
target: mxAccessComObject,
args: arguments);
}
catch (TargetInvocationException exception) when (exception.InnerException is not null)
{
ExceptionDispatchInfo.Capture(exception.InnerException).Throw();
throw;
}
}
}
@@ -0,0 +1,103 @@
using System;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessCommandExecutor : IStaCommandExecutor
{
private readonly MxAccessSession session;
private readonly VariantConverter variantConverter;
public MxAccessCommandExecutor(MxAccessSession session)
: this(session, new VariantConverter())
{
}
public MxAccessCommandExecutor(
MxAccessSession session,
VariantConverter variantConverter)
{
this.session = session ?? throw new ArgumentNullException(nameof(session));
this.variantConverter = variantConverter ?? throw new ArgumentNullException(nameof(variantConverter));
}
public MxCommandReply Execute(StaCommand command)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
return command.Kind switch
{
MxCommandKind.Register => ExecuteRegister(command),
MxCommandKind.Unregister => ExecuteUnregister(command),
_ => CreateInvalidRequestReply(command, $"Unsupported MXAccess command kind {command.Kind}."),
};
}
private MxCommandReply ExecuteRegister(StaCommand command)
{
if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Register)
{
return CreateInvalidRequestReply(command, "Register command payload is required.");
}
int serverHandle = session.Register(command.Command.Register.ClientName);
MxCommandReply reply = CreateOkReply(command);
reply.ReturnValue = variantConverter.Convert(serverHandle);
reply.Register = new RegisterReply
{
ServerHandle = serverHandle,
};
return reply;
}
private MxCommandReply ExecuteUnregister(StaCommand command)
{
if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Unregister)
{
return CreateInvalidRequestReply(command, "Unregister command payload is required.");
}
session.Unregister(command.Command.Unregister.ServerHandle);
return CreateOkReply(command);
}
private static MxCommandReply CreateOkReply(StaCommand command)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
Hresult = 0,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
private static MxCommandReply CreateInvalidRequestReply(
StaCommand command,
string message)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.InvalidRequest,
Message = message,
},
DiagnosticMessage = message,
};
}
}
@@ -0,0 +1,31 @@
using System.Collections.Generic;
using System.Linq;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessHandleRegistry
{
private readonly Dictionary<int, RegisteredServerHandle> serverHandles = new();
public IReadOnlyList<RegisteredServerHandle> ServerHandles => serverHandles
.Values
.OrderBy(handle => handle.ServerHandle)
.ToArray();
public void RegisterServerHandle(
int serverHandle,
string clientName)
{
serverHandles[serverHandle] = new RegisteredServerHandle(serverHandle, clientName);
}
public void UnregisterServerHandle(int serverHandle)
{
serverHandles.Remove(serverHandle);
}
public bool ContainsServerHandle(int serverHandle)
{
return serverHandles.ContainsKey(serverHandle);
}
}
@@ -8,21 +8,29 @@ namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessSession : IDisposable
{
private readonly object mxAccessComObject;
private readonly IMxAccessServer mxAccessServer;
private readonly IMxAccessEventSink eventSink;
private readonly MxAccessHandleRegistry handleRegistry;
private bool disposed;
private MxAccessSession(
object mxAccessComObject,
IMxAccessServer mxAccessServer,
IMxAccessEventSink eventSink,
MxAccessHandleRegistry handleRegistry,
int creationThreadId)
{
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
this.mxAccessServer = mxAccessServer ?? throw new ArgumentNullException(nameof(mxAccessServer));
this.eventSink = eventSink ?? throw new ArgumentNullException(nameof(eventSink));
this.handleRegistry = handleRegistry ?? throw new ArgumentNullException(nameof(handleRegistry));
CreationThreadId = creationThreadId;
}
public int CreationThreadId { get; }
public MxAccessHandleRegistry HandleRegistry => handleRegistry;
public WorkerReady CreateWorkerReady(int workerProcessId)
{
return new WorkerReady
@@ -62,7 +70,9 @@ public sealed class MxAccessSession : IDisposable
return new MxAccessSession(
mxAccessComObject,
new MxAccessComServer(mxAccessComObject),
eventSink,
new MxAccessHandleRegistry(),
Environment.CurrentManagedThreadId);
}
catch (Exception exception)
@@ -78,6 +88,24 @@ public sealed class MxAccessSession : IDisposable
}
}
public int Register(string clientName)
{
ThrowIfDisposed();
int serverHandle = mxAccessServer.Register(clientName);
handleRegistry.RegisterServerHandle(serverHandle, clientName);
return serverHandle;
}
public void Unregister(int serverHandle)
{
ThrowIfDisposed();
mxAccessServer.Unregister(serverHandle);
handleRegistry.UnregisterServerHandle(serverHandle);
}
public void Dispose()
{
if (disposed)
@@ -94,4 +122,12 @@ public sealed class MxAccessSession : IDisposable
disposed = true;
}
private void ThrowIfDisposed()
{
if (disposed)
{
throw new ObjectDisposedException(nameof(MxAccessSession));
}
}
}
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
@@ -11,6 +12,7 @@ public sealed class MxAccessStaSession : IDisposable
private readonly IMxAccessComObjectFactory factory;
private readonly IMxAccessEventSink eventSink;
private readonly StaRuntime staRuntime;
private StaCommandDispatcher? commandDispatcher;
private MxAccessSession? session;
private bool disposed;
@@ -47,11 +49,38 @@ public sealed class MxAccessStaSession : IDisposable
}
session = MxAccessSession.Create(factory, eventSink);
commandDispatcher = new StaCommandDispatcher(
staRuntime,
new MxAccessCommandExecutor(session));
return session.CreateWorkerReady(workerProcessId);
},
cancellationToken);
}
public Task<MxCommandReply> DispatchAsync(StaCommand command)
{
if (commandDispatcher is null)
{
throw new InvalidOperationException("MXAccess COM session has not been started.");
}
return commandDispatcher.DispatchAsync(command);
}
public Task<IReadOnlyList<RegisteredServerHandle>> GetRegisteredServerHandlesAsync(
CancellationToken cancellationToken = default)
{
if (session is null)
{
throw new InvalidOperationException("MXAccess COM session has not been started.");
}
return staRuntime.InvokeAsync(
() => session.HandleRegistry.ServerHandles,
cancellationToken);
}
public void Dispose()
{
if (disposed)
@@ -59,6 +88,8 @@ public sealed class MxAccessStaSession : IDisposable
return;
}
commandDispatcher?.RequestShutdown();
if (session is not null)
{
staRuntime.InvokeAsync(() => session.Dispose()).GetAwaiter().GetResult();
@@ -0,0 +1,16 @@
namespace MxGateway.Worker.MxAccess;
public sealed class RegisteredServerHandle
{
public RegisteredServerHandle(
int serverHandle,
string clientName)
{
ServerHandle = serverHandle;
ClientName = clientName;
}
public int ServerHandle { get; }
public string ClientName { get; }
}