Issue #26: implement Register and Unregister #75

Merged
dohertj2 merged 1 commits from agent-2/issue-26-implement-register-and-unregister into main 2026-04-26 18:13:43 -04:00
11 changed files with 578 additions and 2 deletions
+2 -1
View File
@@ -189,6 +189,8 @@ Tests:
Labels: `area:worker`, `type:feature`, `priority:p0`
Status: implemented.
Deliverables:
- `Register`,
@@ -447,4 +449,3 @@ Acceptance criteria:
- each public method has planned parity fixture or documented gap,
- gateway results preserve HRESULT/status/value/event shape.
+21 -1
View File
@@ -294,7 +294,10 @@ creates `LMXProxyServerClass` through `MxAccessComObjectFactory` on the STA,
attaches `MxAccessBaseEventSink`, and returns `WorkerReady` only after those
steps succeed. `MxAccessSession` keeps the raw COM object private, records the
STA managed thread id that created it, detaches the base event sink during
disposal, and releases the COM reference on the STA.
disposal, and releases the COM reference on the STA. After creation,
`MxAccessStaSession` owns a `StaCommandDispatcher` backed by
`MxAccessCommandExecutor`; `DispatchAsync` queues contract commands back to the
same STA instead of exposing the COM object to callers.
Creation rules:
@@ -414,6 +417,21 @@ Diagnostics:
Implement method-specific dispatch instead of a generic string method invoker.
Parity tests need stable command-specific request and reply shapes.
`MxAccessCommandExecutor` implements the first command pair:
- `Register` calls `LMXProxyServerClass.Register` with the requested client
name and preserves the returned server handle in both `ReturnValue` and
`RegisterReply.ServerHandle`.
- `Unregister` calls `LMXProxyServerClass.Unregister` with the requested server
handle. The reply has no method-specific payload because the public MXAccess
method returns `void`.
Both commands set `Hresult` to `0` only after the COM call returns normally.
COM exceptions flow through `StaCommandDispatcher`, which captures the thrown
HRESULT and converts the reply to `ProtocolStatusCode.MxaccessFailure`.
`MxAccessStaSession.GetRegisteredServerHandlesAsync` returns an STA-read
snapshot of tracked server handles for diagnostics and future cleanup logic.
## Handle Registry
The worker should track MXAccess state for diagnostics and cleanup, while still
@@ -434,6 +452,8 @@ Rules:
- Do not invent handles.
- Do not rewrite handles returned by MXAccess.
- Record server handles only after `Register` succeeds.
- Remove server handles only after `Unregister` succeeds.
- Preserve invalid-handle behavior from MXAccess.
- Preserve cross-server handle behavior from MXAccess.
- Use registry state for cleanup and diagnostics, not semantic correction.
@@ -0,0 +1,220 @@
using System;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.MxAccess;
public sealed class MxAccessCommandExecutorTests
{
[Fact]
public async Task DispatchAsync_Register_CallsMxAccessOnStaAndPreservesServerHandle()
{
FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 42));
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
MxCommandReply reply = await session.DispatchAsync(CreateRegisterCommand("correlation-1", "client-a"));
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.True(reply.HasHresult);
Assert.Equal(0, reply.Hresult);
Assert.Equal(42, reply.Register.ServerHandle);
Assert.Equal(MxDataType.Integer, reply.ReturnValue.DataType);
Assert.Equal(42, reply.ReturnValue.Int32Value);
Assert.Equal(runtime.StaThreadId, factory.FakeComObject.RegisterThreadId);
Assert.Equal("client-a", factory.FakeComObject.RegisteredClientName);
RegisteredServerHandle registeredServerHandle = Assert.Single(
await session.GetRegisteredServerHandlesAsync());
Assert.Equal(42, registeredServerHandle.ServerHandle);
Assert.Equal("client-a", registeredServerHandle.ClientName);
}
[Fact]
public async Task DispatchAsync_Unregister_CallsMxAccessOnStaAndRemovesTrackedServerHandle()
{
FakeMxAccessComObject fakeComObject = new(registerHandle: 43);
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register", "client-a"));
MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("unregister", 43));
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Equal(43, fakeComObject.UnregisteredServerHandle);
Assert.Equal(runtime.StaThreadId, fakeComObject.UnregisterThreadId);
Assert.Empty(await session.GetRegisteredServerHandlesAsync());
}
[Fact]
public async Task DispatchAsync_UnregisterWhenMxAccessThrows_PreservesHResultAndDoesNotRewriteFailure()
{
const int hresult = unchecked((int)0x80070057);
FakeMxAccessComObject fakeComObject = new(
registerHandle: 44,
unregisterException: new COMException("Invalid handle.", hresult));
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register-before-failure", "client-a"));
MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("invalid-unregister", 44));
Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
Assert.True(reply.HasHresult);
Assert.Equal(hresult, reply.Hresult);
Assert.Contains("0x80070057", reply.DiagnosticMessage);
Assert.Equal(44, fakeComObject.UnregisteredServerHandle);
RegisteredServerHandle registeredServerHandle = Assert.Single(
await session.GetRegisteredServerHandlesAsync());
Assert.Equal(44, registeredServerHandle.ServerHandle);
}
[Fact]
public async Task DispatchAsync_RegisterWithoutPayload_ReturnsInvalidRequest()
{
FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 45));
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
MxCommandReply reply = await session.DispatchAsync(new StaCommand(
"session-1",
"missing-payload",
new MxCommand
{
Kind = MxCommandKind.Register,
}));
Assert.Equal(ProtocolStatusCode.InvalidRequest, reply.ProtocolStatus.Code);
Assert.Null(factory.FakeComObject.RegisteredClientName);
}
private static StaCommand CreateRegisterCommand(
string correlationId,
string clientName)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = MxCommandKind.Register,
Register = new RegisterCommand
{
ClientName = clientName,
},
});
}
private static StaCommand CreateUnregisterCommand(
string correlationId,
int serverHandle)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = MxCommandKind.Unregister,
Unregister = new UnregisterCommand
{
ServerHandle = serverHandle,
},
});
}
private static StaRuntime CreateRuntime()
{
return new StaRuntime(
new NoopComApartmentInitializer(),
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class FakeMxAccessComObject
{
private readonly int registerHandle;
private readonly Exception? unregisterException;
public FakeMxAccessComObject(
int registerHandle,
Exception? unregisterException = null)
{
this.registerHandle = registerHandle;
this.unregisterException = unregisterException;
}
public string? RegisteredClientName { get; private set; }
public int? RegisterThreadId { get; private set; }
public int? UnregisteredServerHandle { get; private set; }
public int? UnregisterThreadId { get; private set; }
public int Register(string clientName)
{
RegisteredClientName = clientName;
RegisterThreadId = Environment.CurrentManagedThreadId;
return registerHandle;
}
public void Unregister(int serverHandle)
{
UnregisteredServerHandle = serverHandle;
UnregisterThreadId = Environment.CurrentManagedThreadId;
if (unregisterException is not null)
{
throw unregisterException;
}
}
}
private sealed class FakeMxAccessComObjectFactory : IMxAccessComObjectFactory
{
public FakeMxAccessComObjectFactory(FakeMxAccessComObject fakeComObject)
{
FakeComObject = fakeComObject;
}
public FakeMxAccessComObject FakeComObject { get; }
public object Create()
{
return FakeComObject;
}
}
private sealed class NoopEventSink : IMxAccessEventSink
{
public void Attach(object mxAccessComObject)
{
}
public void Detach()
{
}
}
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
{
public void Initialize()
{
}
public void Uninitialize()
{
}
}
}
@@ -1,6 +1,8 @@
using System;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.MxAccess;
@@ -21,4 +23,53 @@ public sealed class MxAccessLiveComCreationTests
await session.StartAsync(workerProcessId: 1234);
}
[Fact]
public async Task RegisterAndUnregister_WhenOptedIn_RoundTripsInstalledMxAccessServerHandle()
{
if (!RunLiveMxAccessTests())
{
return;
}
using MxAccessStaSession session = new();
await session.StartAsync(workerProcessId: 1234);
MxCommandReply registerReply = await session.DispatchAsync(new StaCommand(
"session-1",
"live-register",
new MxCommand
{
Kind = MxCommandKind.Register,
Register = new RegisterCommand
{
ClientName = "MxGateway.Worker.Tests",
},
}));
Assert.Equal(ProtocolStatusCode.Ok, registerReply.ProtocolStatus.Code);
Assert.True(registerReply.Register.ServerHandle > 0);
MxCommandReply unregisterReply = await session.DispatchAsync(new StaCommand(
"session-1",
"live-unregister",
new MxCommand
{
Kind = MxCommandKind.Unregister,
Unregister = new UnregisterCommand
{
ServerHandle = registerReply.Register.ServerHandle,
},
}));
Assert.Equal(ProtocolStatusCode.Ok, unregisterReply.ProtocolStatus.Code);
}
private static bool RunLiveMxAccessTests()
{
return string.Equals(
Environment.GetEnvironmentVariable("MXGATEWAY_RUN_LIVE_MXACCESS_TESTS"),
"1",
StringComparison.Ordinal);
}
}
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.MxAccess;
public interface IMxAccessServer
{
int Register(string clientName);
void Unregister(int serverHandle);
}
@@ -0,0 +1,59 @@
using System;
using System.Reflection;
using System.Runtime.ExceptionServices;
using ArchestrA.MxAccess;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessComServer : IMxAccessServer
{
private readonly object mxAccessComObject;
public MxAccessComServer(object mxAccessComObject)
{
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
}
public int Register(string clientName)
{
if (mxAccessComObject is ILMXProxyServer mxAccessServer)
{
return mxAccessServer.Register(clientName);
}
return (int)Invoke(nameof(Register), clientName);
}
public void Unregister(int serverHandle)
{
if (mxAccessComObject is ILMXProxyServer mxAccessServer)
{
mxAccessServer.Unregister(serverHandle);
return;
}
Invoke(nameof(Unregister), serverHandle);
}
private object Invoke(
string methodName,
params object[] arguments)
{
try
{
return mxAccessComObject
.GetType()
.InvokeMember(
methodName,
BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod,
binder: null,
target: mxAccessComObject,
args: arguments);
}
catch (TargetInvocationException exception) when (exception.InnerException is not null)
{
ExceptionDispatchInfo.Capture(exception.InnerException).Throw();
throw;
}
}
}
@@ -0,0 +1,103 @@
using System;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessCommandExecutor : IStaCommandExecutor
{
private readonly MxAccessSession session;
private readonly VariantConverter variantConverter;
public MxAccessCommandExecutor(MxAccessSession session)
: this(session, new VariantConverter())
{
}
public MxAccessCommandExecutor(
MxAccessSession session,
VariantConverter variantConverter)
{
this.session = session ?? throw new ArgumentNullException(nameof(session));
this.variantConverter = variantConverter ?? throw new ArgumentNullException(nameof(variantConverter));
}
public MxCommandReply Execute(StaCommand command)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
return command.Kind switch
{
MxCommandKind.Register => ExecuteRegister(command),
MxCommandKind.Unregister => ExecuteUnregister(command),
_ => CreateInvalidRequestReply(command, $"Unsupported MXAccess command kind {command.Kind}."),
};
}
private MxCommandReply ExecuteRegister(StaCommand command)
{
if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Register)
{
return CreateInvalidRequestReply(command, "Register command payload is required.");
}
int serverHandle = session.Register(command.Command.Register.ClientName);
MxCommandReply reply = CreateOkReply(command);
reply.ReturnValue = variantConverter.Convert(serverHandle);
reply.Register = new RegisterReply
{
ServerHandle = serverHandle,
};
return reply;
}
private MxCommandReply ExecuteUnregister(StaCommand command)
{
if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Unregister)
{
return CreateInvalidRequestReply(command, "Unregister command payload is required.");
}
session.Unregister(command.Command.Unregister.ServerHandle);
return CreateOkReply(command);
}
private static MxCommandReply CreateOkReply(StaCommand command)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
Hresult = 0,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
private static MxCommandReply CreateInvalidRequestReply(
StaCommand command,
string message)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.InvalidRequest,
Message = message,
},
DiagnosticMessage = message,
};
}
}
@@ -0,0 +1,31 @@
using System.Collections.Generic;
using System.Linq;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessHandleRegistry
{
private readonly Dictionary<int, RegisteredServerHandle> serverHandles = new();
public IReadOnlyList<RegisteredServerHandle> ServerHandles => serverHandles
.Values
.OrderBy(handle => handle.ServerHandle)
.ToArray();
public void RegisterServerHandle(
int serverHandle,
string clientName)
{
serverHandles[serverHandle] = new RegisteredServerHandle(serverHandle, clientName);
}
public void UnregisterServerHandle(int serverHandle)
{
serverHandles.Remove(serverHandle);
}
public bool ContainsServerHandle(int serverHandle)
{
return serverHandles.ContainsKey(serverHandle);
}
}
@@ -8,21 +8,29 @@ namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessSession : IDisposable
{
private readonly object mxAccessComObject;
private readonly IMxAccessServer mxAccessServer;
private readonly IMxAccessEventSink eventSink;
private readonly MxAccessHandleRegistry handleRegistry;
private bool disposed;
private MxAccessSession(
object mxAccessComObject,
IMxAccessServer mxAccessServer,
IMxAccessEventSink eventSink,
MxAccessHandleRegistry handleRegistry,
int creationThreadId)
{
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
this.mxAccessServer = mxAccessServer ?? throw new ArgumentNullException(nameof(mxAccessServer));
this.eventSink = eventSink ?? throw new ArgumentNullException(nameof(eventSink));
this.handleRegistry = handleRegistry ?? throw new ArgumentNullException(nameof(handleRegistry));
CreationThreadId = creationThreadId;
}
public int CreationThreadId { get; }
public MxAccessHandleRegistry HandleRegistry => handleRegistry;
public WorkerReady CreateWorkerReady(int workerProcessId)
{
return new WorkerReady
@@ -62,7 +70,9 @@ public sealed class MxAccessSession : IDisposable
return new MxAccessSession(
mxAccessComObject,
new MxAccessComServer(mxAccessComObject),
eventSink,
new MxAccessHandleRegistry(),
Environment.CurrentManagedThreadId);
}
catch (Exception exception)
@@ -78,6 +88,24 @@ public sealed class MxAccessSession : IDisposable
}
}
public int Register(string clientName)
{
ThrowIfDisposed();
int serverHandle = mxAccessServer.Register(clientName);
handleRegistry.RegisterServerHandle(serverHandle, clientName);
return serverHandle;
}
public void Unregister(int serverHandle)
{
ThrowIfDisposed();
mxAccessServer.Unregister(serverHandle);
handleRegistry.UnregisterServerHandle(serverHandle);
}
public void Dispose()
{
if (disposed)
@@ -94,4 +122,12 @@ public sealed class MxAccessSession : IDisposable
disposed = true;
}
private void ThrowIfDisposed()
{
if (disposed)
{
throw new ObjectDisposedException(nameof(MxAccessSession));
}
}
}
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
@@ -11,6 +12,7 @@ public sealed class MxAccessStaSession : IDisposable
private readonly IMxAccessComObjectFactory factory;
private readonly IMxAccessEventSink eventSink;
private readonly StaRuntime staRuntime;
private StaCommandDispatcher? commandDispatcher;
private MxAccessSession? session;
private bool disposed;
@@ -47,11 +49,38 @@ public sealed class MxAccessStaSession : IDisposable
}
session = MxAccessSession.Create(factory, eventSink);
commandDispatcher = new StaCommandDispatcher(
staRuntime,
new MxAccessCommandExecutor(session));
return session.CreateWorkerReady(workerProcessId);
},
cancellationToken);
}
public Task<MxCommandReply> DispatchAsync(StaCommand command)
{
if (commandDispatcher is null)
{
throw new InvalidOperationException("MXAccess COM session has not been started.");
}
return commandDispatcher.DispatchAsync(command);
}
public Task<IReadOnlyList<RegisteredServerHandle>> GetRegisteredServerHandlesAsync(
CancellationToken cancellationToken = default)
{
if (session is null)
{
throw new InvalidOperationException("MXAccess COM session has not been started.");
}
return staRuntime.InvokeAsync(
() => session.HandleRegistry.ServerHandles,
cancellationToken);
}
public void Dispose()
{
if (disposed)
@@ -59,6 +88,8 @@ public sealed class MxAccessStaSession : IDisposable
return;
}
commandDispatcher?.RequestShutdown();
if (session is not null)
{
staRuntime.InvokeAsync(() => session.Dispose()).GetAwaiter().GetResult();
@@ -0,0 +1,16 @@
namespace MxGateway.Worker.MxAccess;
public sealed class RegisteredServerHandle
{
public RegisteredServerHandle(
int serverHandle,
string clientName)
{
ServerHandle = serverHandle;
ClientName = clientName;
}
public int ServerHandle { get; }
public string ClientName { get; }
}