diff --git a/docs/gateway-process-design.md b/docs/gateway-process-design.md index 408b687..8c0036f 100644 --- a/docs/gateway-process-design.md +++ b/docs/gateway-process-design.md @@ -206,13 +206,23 @@ accounting and a clear fan-out policy. Behavior: 1. Validate session id and authorize event access. -2. Attach a stream cursor to the session event channel. -3. Send events in worker sequence order. -4. Stop on client cancellation, session close, or session fault. -5. Emit a terminal status when the session faults if gRPC status alone cannot +2. Attach the single active subscriber lease for the session. +3. Read worker events into a bounded public stream queue. +4. Send events in worker sequence order. +5. Stop on client cancellation, session close, or session fault. +6. Emit a terminal status when the session faults if gRPC status alone cannot preserve the required details. -The gateway must not reorder events from one worker. +`EventStreamService` owns subscriber tracking and public stream backpressure. +The default policy allows one active subscriber per session. A second subscriber +is rejected with `EventSubscriberAlreadyActive`. Stream cancellation releases +the subscriber lease so a later stream can attach to the session. + +The gateway must not reorder events from one worker. `EventStreamService` writes +mapped events to a bounded first-in, first-out queue and faults the session with +`EventQueueOverflow` if the queue fills. The gateway does not synthesize +`OperationComplete`; it forwards that family only when the worker reports a +native MXAccess `OperationComplete` event. ## Web Dashboard @@ -584,7 +594,8 @@ worker MXAccess event -> worker outbound event queue -> worker pipe writer -> gateway read loop - -> session event channel + -> worker client event queue + -> EventStreamService bounded stream queue -> gRPC StreamEvents ``` @@ -598,13 +609,15 @@ The gateway should record: Default backpressure policy for parity testing should be fail-fast: -1. If the session event channel fills, fault the session. +1. If the worker client event queue fills, fault the worker client. +2. If the public stream queue fills, fault the gateway session. 2. Preserve the overflow details in logs and metrics. 3. Do not silently drop data-change events. -Do not set a production event-rate target before measurement. Emit event rate, -queue depth, stream send latency, and overflow metrics. Later production modes -may support explicit coalescing by item handle as an opt-in behavior. +Do not set a production event-rate target before measurement. `GatewayMetrics` +records received event counts by family, queue depth, stream disconnects, and +overflow counts. Later production modes may support explicit coalescing by item +handle as an opt-in behavior. The gateway should not synthesize `OperationComplete` from write completion, command replies, ASB completion queues, or completion-only status frames. Forward diff --git a/docs/implementation-plan-mxaccess-worker.md b/docs/implementation-plan-mxaccess-worker.md index 96c1ded..2f6cc73 100644 --- a/docs/implementation-plan-mxaccess-worker.md +++ b/docs/implementation-plan-mxaccess-worker.md @@ -189,6 +189,8 @@ Tests: Labels: `area:worker`, `type:feature`, `priority:p0` +Status: implemented. + Deliverables: - `Register`, @@ -447,4 +449,3 @@ Acceptance criteria: - each public method has planned parity fixture or documented gap, - gateway results preserve HRESULT/status/value/event shape. - diff --git a/docs/mxaccess-worker-instance-design.md b/docs/mxaccess-worker-instance-design.md index b11fcd3..7f3ea2b 100644 --- a/docs/mxaccess-worker-instance-design.md +++ b/docs/mxaccess-worker-instance-design.md @@ -294,7 +294,10 @@ creates `LMXProxyServerClass` through `MxAccessComObjectFactory` on the STA, attaches `MxAccessBaseEventSink`, and returns `WorkerReady` only after those steps succeed. `MxAccessSession` keeps the raw COM object private, records the STA managed thread id that created it, detaches the base event sink during -disposal, and releases the COM reference on the STA. +disposal, and releases the COM reference on the STA. After creation, +`MxAccessStaSession` owns a `StaCommandDispatcher` backed by +`MxAccessCommandExecutor`; `DispatchAsync` queues contract commands back to the +same STA instead of exposing the COM object to callers. Creation rules: @@ -414,6 +417,21 @@ Diagnostics: Implement method-specific dispatch instead of a generic string method invoker. Parity tests need stable command-specific request and reply shapes. +`MxAccessCommandExecutor` implements the first command pair: + +- `Register` calls `LMXProxyServerClass.Register` with the requested client + name and preserves the returned server handle in both `ReturnValue` and + `RegisterReply.ServerHandle`. +- `Unregister` calls `LMXProxyServerClass.Unregister` with the requested server + handle. The reply has no method-specific payload because the public MXAccess + method returns `void`. + +Both commands set `Hresult` to `0` only after the COM call returns normally. +COM exceptions flow through `StaCommandDispatcher`, which captures the thrown +HRESULT and converts the reply to `ProtocolStatusCode.MxaccessFailure`. +`MxAccessStaSession.GetRegisteredServerHandlesAsync` returns an STA-read +snapshot of tracked server handles for diagnostics and future cleanup logic. + ## Handle Registry The worker should track MXAccess state for diagnostics and cleanup, while still @@ -434,6 +452,8 @@ Rules: - Do not invent handles. - Do not rewrite handles returned by MXAccess. +- Record server handles only after `Register` succeeds. +- Remove server handles only after `Unregister` succeeds. - Preserve invalid-handle behavior from MXAccess. - Preserve cross-server handle behavior from MXAccess. - Use registry state for cleanup and diagnostics, not semantic correction. diff --git a/gateway.md b/gateway.md index ca82673..5243f1e 100644 --- a/gateway.md +++ b/gateway.md @@ -527,11 +527,7 @@ Worker policy: - bounded outbound event channel, - never block MXAccess event handler on pipe writes, -- if the outbound channel is full, apply configured policy: - - disconnect session, - - drop oldest low-priority data-change events, - - coalesce data changes by item handle, - - or block briefly then fault. +- fail the worker session when the outbound channel is full. For full parity testing, default should be fail-fast rather than silent drop. For production high-rate telemetry, add explicit coalescing modes. @@ -540,9 +536,15 @@ Gateway policy: - one event sequencer per session, - preserve per-session event order, -- support multiple client event subscribers only if explicitly required, -- apply backpressure from slow gRPC streams, -- disconnect or coalesce according to client-selected mode. +- allow one active client event subscriber per session, +- reject a second subscriber with a clear session error, +- use a bounded `EventStreamService` queue between worker events and gRPC + writes, +- fault the session when the bounded stream queue overflows, +- detach the subscriber when the stream is canceled. + +The gateway forwards only events reported by the worker. It does not synthesize +`OperationComplete` from write completion, command replies, or status frames. ## Isolation And Fault Handling @@ -864,10 +866,11 @@ translation code testable. The gateway maps `MxAccessGateway` to `MxAccessGatewayService`. The service implements `OpenSession`, `CloseSession`, `Invoke`, and `StreamEvents` by validating public requests, delegating session work to `ISessionManager`, and -using explicit mapper code for public-to-worker commands, worker replies, and -events. Missing sessions and transport failures return gRPC status errors; -worker command replies preserve MXAccess HRESULT and status details in the -public reply. +using explicit mapper code for public-to-worker commands and worker replies. +`StreamEvents` delegates subscriber ownership, ordering, and backpressure to +`EventStreamService`. Missing sessions and transport failures return gRPC +status errors; worker command replies preserve MXAccess HRESULT and status +details in the public reply. ## C# Worker Versus C++ Worker diff --git a/src/MxGateway.Server/GatewayApplication.cs b/src/MxGateway.Server/GatewayApplication.cs index 0a3f92d..b838bd1 100644 --- a/src/MxGateway.Server/GatewayApplication.cs +++ b/src/MxGateway.Server/GatewayApplication.cs @@ -37,6 +37,7 @@ public static class GatewayApplication builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); + builder.Services.AddSingleton(); builder.Services.AddWorkerProcessLauncher(); builder.Services.AddGatewaySessions(); builder.Services.AddGatewayDashboard(); diff --git a/src/MxGateway.Server/Grpc/EventStreamService.cs b/src/MxGateway.Server/Grpc/EventStreamService.cs new file mode 100644 index 0000000..1aacd39 --- /dev/null +++ b/src/MxGateway.Server/Grpc/EventStreamService.cs @@ -0,0 +1,140 @@ +using System.Runtime.CompilerServices; +using System.Threading.Channels; +using Microsoft.Extensions.Options; +using MxGateway.Contracts.Proto; +using MxGateway.Server.Configuration; +using MxGateway.Server.Metrics; +using MxGateway.Server.Sessions; +using MxGateway.Server.Workers; + +namespace MxGateway.Server.Grpc; + +public sealed class EventStreamService( + ISessionManager sessionManager, + IOptions options, + MxAccessGrpcMapper mapper, + GatewayMetrics metrics, + ILogger logger) : IEventStreamService +{ + public async IAsyncEnumerable StreamEventsAsync( + StreamEventsRequest request, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + if (!sessionManager.TryGetSession(request.SessionId, out GatewaySession session)) + { + throw new SessionManagerException( + SessionManagerErrorCode.SessionNotFound, + $"Session {request.SessionId} was not found."); + } + + using IDisposable subscriber = session.AttachEventSubscriber( + options.Value.Sessions.AllowMultipleEventSubscribers); + using CancellationTokenSource streamCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + int streamQueueDepth = 0; + Channel eventQueue = Channel.CreateBounded( + new BoundedChannelOptions(options.Value.Events.QueueCapacity) + { + SingleReader = true, + SingleWriter = true, + FullMode = BoundedChannelFullMode.Wait, + AllowSynchronousContinuations = false, + }); + Task producerTask = ProduceEventsAsync( + session, + request.AfterWorkerSequence, + eventQueue.Writer, + () => + { + int depth = Interlocked.Increment(ref streamQueueDepth); + metrics.SetEventQueueDepth(depth); + }, + streamCts.Token); + + try + { + await foreach (MxEvent mxEvent in eventQueue.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + int depth = Math.Max(0, Interlocked.Decrement(ref streamQueueDepth)); + metrics.SetEventQueueDepth(depth); + yield return mxEvent; + } + + await producerTask.ConfigureAwait(false); + } + finally + { + await streamCts.CancelAsync().ConfigureAwait(false); + subscriber.Dispose(); + metrics.StreamDisconnected("Detached"); + + try + { + await producerTask.ConfigureAwait(false); + } + catch (OperationCanceledException) when (streamCts.IsCancellationRequested) + { + } + catch (Exception exception) + { + logger.LogDebug( + exception, + "Event stream producer stopped for session {SessionId}.", + request.SessionId); + } + } + } + + private async Task ProduceEventsAsync( + GatewaySession session, + ulong afterWorkerSequence, + ChannelWriter writer, + Action eventQueued, + CancellationToken cancellationToken) + { + try + { + await foreach (WorkerEvent workerEvent in session + .ReadEventsAsync(cancellationToken) + .WithCancellation(cancellationToken) + .ConfigureAwait(false)) + { + MxEvent publicEvent = mapper.MapEvent(workerEvent); + if (publicEvent.WorkerSequence <= afterWorkerSequence) + { + continue; + } + + if (!writer.TryWrite(publicEvent)) + { + string message = $"Session {session.SessionId} event stream queue overflowed."; + session.MarkFaulted(message); + metrics.QueueOverflow("grpc-event-stream"); + metrics.Fault(SessionManagerErrorCode.EventQueueOverflow.ToString()); + writer.TryComplete(new SessionManagerException( + SessionManagerErrorCode.EventQueueOverflow, + message)); + return; + } + + eventQueued(); + } + + writer.TryComplete(); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + writer.TryComplete(); + } + catch (Exception exception) + { + if (exception is WorkerClientException) + { + session.MarkFaulted(exception.Message); + metrics.Fault(WorkerClientErrorCode.WorkerFaulted.ToString()); + } + + writer.TryComplete(exception); + } + } +} diff --git a/src/MxGateway.Server/Grpc/IEventStreamService.cs b/src/MxGateway.Server/Grpc/IEventStreamService.cs new file mode 100644 index 0000000..06f4acb --- /dev/null +++ b/src/MxGateway.Server/Grpc/IEventStreamService.cs @@ -0,0 +1,10 @@ +using MxGateway.Contracts.Proto; + +namespace MxGateway.Server.Grpc; + +public interface IEventStreamService +{ + IAsyncEnumerable StreamEventsAsync( + StreamEventsRequest request, + CancellationToken cancellationToken); +} diff --git a/src/MxGateway.Server/Grpc/MxAccessGatewayService.cs b/src/MxGateway.Server/Grpc/MxAccessGatewayService.cs index 1ee8777..2b75baf 100644 --- a/src/MxGateway.Server/Grpc/MxAccessGatewayService.cs +++ b/src/MxGateway.Server/Grpc/MxAccessGatewayService.cs @@ -12,6 +12,7 @@ public sealed class MxAccessGatewayService( IGatewayRequestIdentityAccessor identityAccessor, MxAccessGrpcRequestValidator requestValidator, MxAccessGrpcMapper mapper, + IEventStreamService eventStreamService, ILogger logger) : MxAccessGateway.MxAccessGatewayBase { public override async Task OpenSession( @@ -102,17 +103,11 @@ public sealed class MxAccessGatewayService( try { requestValidator.ValidateStreamEvents(request); - await foreach (WorkerEvent workerEvent in sessionManager - .ReadEventsAsync(request.SessionId, context.CancellationToken) + await foreach (MxEvent publicEvent in eventStreamService + .StreamEventsAsync(request, context.CancellationToken) .WithCancellation(context.CancellationToken) .ConfigureAwait(false)) { - MxEvent publicEvent = mapper.MapEvent(workerEvent); - if (publicEvent.WorkerSequence <= request.AfterWorkerSequence) - { - continue; - } - await responseStream.WriteAsync(publicEvent).ConfigureAwait(false); } } @@ -154,6 +149,8 @@ public sealed class MxAccessGatewayService( { SessionManagerErrorCode.SessionNotFound => StatusCode.NotFound, SessionManagerErrorCode.SessionNotReady => StatusCode.FailedPrecondition, + SessionManagerErrorCode.EventSubscriberAlreadyActive => StatusCode.ResourceExhausted, + SessionManagerErrorCode.EventQueueOverflow => StatusCode.ResourceExhausted, SessionManagerErrorCode.SessionLimitExceeded => StatusCode.ResourceExhausted, SessionManagerErrorCode.OpenFailed => StatusCode.Unavailable, SessionManagerErrorCode.CloseFailed => StatusCode.Unavailable, diff --git a/src/MxGateway.Server/Sessions/GatewaySession.cs b/src/MxGateway.Server/Sessions/GatewaySession.cs index d010669..dab1f34 100644 --- a/src/MxGateway.Server/Sessions/GatewaySession.cs +++ b/src/MxGateway.Server/Sessions/GatewaySession.cs @@ -13,6 +13,7 @@ public sealed class GatewaySession private DateTimeOffset _lastClientActivityAt; private DateTimeOffset? _leaseExpiresAt; private bool _closeStarted; + private int _activeEventSubscriberCount; public GatewaySession( string sessionId, @@ -131,6 +132,17 @@ public sealed class GatewaySession } } + public int ActiveEventSubscriberCount + { + get + { + lock (_syncRoot) + { + return _activeEventSubscriberCount; + } + } + } + public void AttachWorkerClient(IWorkerClient workerClient) { ArgumentNullException.ThrowIfNull(workerClient); @@ -202,6 +214,29 @@ public sealed class GatewaySession } } + public IDisposable AttachEventSubscriber(bool allowMultipleSubscribers) + { + lock (_syncRoot) + { + if (_state != SessionState.Ready || _workerClient?.State != WorkerClientState.Ready) + { + throw new SessionManagerException( + SessionManagerErrorCode.SessionNotReady, + $"Session {SessionId} is not ready for event streaming. Current state is {_state}."); + } + + if (!allowMultipleSubscribers && _activeEventSubscriberCount > 0) + { + throw new SessionManagerException( + SessionManagerErrorCode.EventSubscriberAlreadyActive, + $"Session {SessionId} already has an active event stream subscriber."); + } + + _activeEventSubscriberCount++; + return new EventSubscriberLease(this); + } + } + public async Task InvokeAsync( WorkerCommand command, CancellationToken cancellationToken) @@ -287,4 +322,31 @@ public sealed class GatewaySession return _workerClient; } } + + private void DetachEventSubscriber() + { + lock (_syncRoot) + { + if (_activeEventSubscriberCount > 0) + { + _activeEventSubscriberCount--; + } + } + } + + private sealed class EventSubscriberLease(GatewaySession session) : IDisposable + { + private bool _disposed; + + public void Dispose() + { + if (_disposed) + { + return; + } + + session.DetachEventSubscriber(); + _disposed = true; + } + } } diff --git a/src/MxGateway.Server/Sessions/SessionManagerErrorCode.cs b/src/MxGateway.Server/Sessions/SessionManagerErrorCode.cs index dcbca45..f7e723c 100644 --- a/src/MxGateway.Server/Sessions/SessionManagerErrorCode.cs +++ b/src/MxGateway.Server/Sessions/SessionManagerErrorCode.cs @@ -4,6 +4,8 @@ public enum SessionManagerErrorCode { SessionNotFound, SessionNotReady, + EventSubscriberAlreadyActive, + EventQueueOverflow, SessionLimitExceeded, OpenFailed, CloseFailed, diff --git a/src/MxGateway.Server/Workers/WorkerClient.cs b/src/MxGateway.Server/Workers/WorkerClient.cs index 9e0daa9..18b8cb6 100644 --- a/src/MxGateway.Server/Workers/WorkerClient.cs +++ b/src/MxGateway.Server/Workers/WorkerClient.cs @@ -29,6 +29,7 @@ public sealed class WorkerClient : IWorkerClient private WorkerClientState _state; private DateTimeOffset _lastHeartbeatAt; private int? _processId; + private int _eventQueueDepth; private Task? _readLoopTask; private Task? _writeLoopTask; private Task? _heartbeatLoopTask; @@ -197,6 +198,8 @@ public sealed class WorkerClient : IWorkerClient { await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { + int queueDepth = Math.Max(0, Interlocked.Decrement(ref _eventQueueDepth)); + _metrics?.SetEventQueueDepth(queueDepth); yield return workerEvent; } } @@ -394,11 +397,6 @@ public sealed class WorkerClient : IWorkerClient _metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString()); } - if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false)) - { - return; - } - if (!_events.Writer.TryWrite(workerEvent)) { _metrics?.QueueOverflow("worker-events"); @@ -406,7 +404,11 @@ public sealed class WorkerClient : IWorkerClient WorkerClientErrorCode.ProtocolViolation, "Worker event channel rejected an event.", null); + return; } + + int queueDepth = Interlocked.Increment(ref _eventQueueDepth); + _metrics?.SetEventQueueDepth(queueDepth); } private void CompleteCommand(WorkerEnvelope envelope) diff --git a/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs b/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs new file mode 100644 index 0000000..4960895 --- /dev/null +++ b/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs @@ -0,0 +1,383 @@ +using System.Runtime.CompilerServices; +using Grpc.Core; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using MxGateway.Contracts; +using MxGateway.Contracts.Proto; +using MxGateway.Server.Configuration; +using MxGateway.Server.Grpc; +using MxGateway.Server.Metrics; +using MxGateway.Server.Sessions; +using MxGateway.Server.Workers; + +namespace MxGateway.Tests.Gateway.Grpc; + +public sealed class EventStreamServiceTests +{ + private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5); + + [Fact] + public async Task StreamEventsAsync_YieldsEventsInWorkerOrder() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + FakeSessionManager sessionManager = new(session); + using GatewayMetrics metrics = new(); + EventStreamService service = CreateService(sessionManager, metrics: metrics); + workerClient.Events.Add(CreateWorkerEvent(sequence: 10, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 11, MxEventFamily.OnWriteComplete)); + workerClient.CompleteAfterConfiguredEvents = true; + + List events = await CollectEventsAsync(service, session.SessionId); + + Assert.Equal([10UL, 11UL], events.Select(mxEvent => mxEvent.WorkerSequence).ToArray()); + Assert.Equal(MxEventFamily.OnDataChange, events[0].Family); + Assert.Equal(MxEventFamily.OnWriteComplete, events[1].Family); + Assert.Equal(1, metrics.GetSnapshot().StreamDisconnects); + } + + [Fact] + public async Task StreamEventsAsync_WhenSecondSubscriberStarts_RejectsClearly() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + EventStreamService service = CreateService(new FakeSessionManager(session)); + using CancellationTokenSource firstSubscriberCancellation = new(); + await using IAsyncEnumerator firstSubscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), firstSubscriberCancellation.Token) + .GetAsyncEnumerator(firstSubscriberCancellation.Token); + Task firstMoveTask = firstSubscriber.MoveNextAsync().AsTask(); + + await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 1); + await using IAsyncEnumerator secondSubscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None) + .GetAsyncEnumerator(); + + SessionManagerException exception = await Assert.ThrowsAsync( + async () => await secondSubscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + + Assert.Equal(SessionManagerErrorCode.EventSubscriberAlreadyActive, exception.ErrorCode); + await firstSubscriberCancellation.CancelAsync(); + await Assert.ThrowsAnyAsync( + async () => await firstMoveTask.WaitAsync(TestTimeout)); + await firstSubscriber.DisposeAsync(); + await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 0); + } + + [Fact] + public async Task StreamEventsAsync_WhenCanceled_DetachesSubscriber() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + EventStreamService service = CreateService(new FakeSessionManager(session)); + using CancellationTokenSource cancellationTokenSource = new(); + await using IAsyncEnumerator subscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), cancellationTokenSource.Token) + .GetAsyncEnumerator(cancellationTokenSource.Token); + Task moveTask = subscriber.MoveNextAsync().AsTask(); + + await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 1); + await cancellationTokenSource.CancelAsync(); + await Assert.ThrowsAnyAsync( + async () => await moveTask.WaitAsync(TestTimeout)); + await subscriber.DisposeAsync(); + + await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 0); + } + + [Fact] + public async Task StreamEventsAsync_WhenStreamQueueOverflows_FaultsSessionAndReportsOverflow() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + using GatewayMetrics metrics = new(); + EventStreamService service = CreateService( + new FakeSessionManager(session), + metrics, + queueCapacity: 1); + workerClient.Events.Add(CreateWorkerEvent(sequence: 1, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 2, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 3, MxEventFamily.OnDataChange)); + workerClient.CompleteAfterConfiguredEvents = true; + await using IAsyncEnumerator subscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None) + .GetAsyncEnumerator(); + + Assert.True(await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + await WaitUntilAsync(() => session.State == SessionState.Faulted); + SessionManagerException exception = await Assert.ThrowsAsync( + async () => await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + + Assert.Equal(SessionManagerErrorCode.EventQueueOverflow, exception.ErrorCode); + Assert.Equal(SessionState.Faulted, session.State); + Assert.Equal(1, metrics.GetSnapshot().QueueOverflows); + Assert.Equal(1, metrics.GetSnapshot().Faults); + } + + [Fact] + public async Task StreamEventsAsync_DoesNotSynthesizeOperationComplete() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + EventStreamService service = CreateService(new FakeSessionManager(session)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 10, MxEventFamily.OnWriteComplete)); + workerClient.CompleteAfterConfiguredEvents = true; + + List events = await CollectEventsAsync(service, session.SessionId); + + MxEvent mxEvent = Assert.Single(events); + Assert.Equal(MxEventFamily.OnWriteComplete, mxEvent.Family); + Assert.DoesNotContain(events, candidate => candidate.Family == MxEventFamily.OperationComplete); + } + + [Fact] + public async Task StreamEventsAsync_WhenWorkerEventStreamFaults_PropagatesTerminalFault() + { + FakeWorkerClient workerClient = new() + { + TerminalException = new WorkerClientException( + WorkerClientErrorCode.WorkerFaulted, + "worker terminal fault"), + }; + GatewaySession session = CreateReadySession(workerClient); + using GatewayMetrics metrics = new(); + EventStreamService service = CreateService(new FakeSessionManager(session), metrics); + await using IAsyncEnumerator subscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None) + .GetAsyncEnumerator(); + + WorkerClientException exception = await Assert.ThrowsAsync( + async () => await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + + Assert.Equal(WorkerClientErrorCode.WorkerFaulted, exception.ErrorCode); + Assert.Equal(SessionState.Faulted, session.State); + Assert.Equal(1, metrics.GetSnapshot().Faults); + } + + private static EventStreamService CreateService( + FakeSessionManager sessionManager, + GatewayMetrics? metrics = null, + int queueCapacity = 8) + { + return new EventStreamService( + sessionManager, + Options.Create(new GatewayOptions + { + Events = new EventOptions + { + QueueCapacity = queueCapacity, + }, + }), + new MxAccessGrpcMapper(), + metrics ?? new GatewayMetrics(), + NullLogger.Instance); + } + + private static async Task> CollectEventsAsync( + EventStreamService service, + string sessionId) + { + List events = []; + await foreach (MxEvent mxEvent in service + .StreamEventsAsync(CreateRequest(sessionId), CancellationToken.None) + .WithCancellation(CancellationToken.None)) + { + events.Add(mxEvent); + } + + return events; + } + + private static StreamEventsRequest CreateRequest(string sessionId) + { + return new StreamEventsRequest + { + SessionId = sessionId, + }; + } + + private static GatewaySession CreateReadySession(FakeWorkerClient workerClient) + { + GatewaySession session = new( + "session-events", + GatewayContractInfo.DefaultBackendName, + "pipe", + "nonce", + "client", + "client-session", + "client-correlation", + TimeSpan.FromSeconds(30), + TimeSpan.FromSeconds(30), + TimeSpan.FromSeconds(10), + DateTimeOffset.UtcNow); + session.AttachWorkerClient(workerClient); + session.MarkReady(); + + return session; + } + + private static WorkerEvent CreateWorkerEvent( + ulong sequence, + MxEventFamily family) + { + MxEvent mxEvent = new() + { + SessionId = "session-events", + Family = family, + WorkerSequence = sequence, + }; + + switch (family) + { + case MxEventFamily.OnDataChange: + mxEvent.OnDataChange = new OnDataChangeEvent(); + break; + case MxEventFamily.OnWriteComplete: + mxEvent.OnWriteComplete = new OnWriteCompleteEvent(); + break; + case MxEventFamily.OperationComplete: + mxEvent.OperationComplete = new OperationCompleteEvent(); + break; + case MxEventFamily.OnBufferedDataChange: + mxEvent.OnBufferedDataChange = new OnBufferedDataChangeEvent(); + break; + } + + return new WorkerEvent + { + Event = mxEvent, + }; + } + + private static async Task WaitUntilAsync(Func predicate) + { + using CancellationTokenSource cancellationTokenSource = new(TestTimeout); + while (!predicate()) + { + await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token); + } + } + + private sealed class FakeSessionManager(GatewaySession session) : ISessionManager + { + public Task OpenSessionAsync( + SessionOpenRequest request, + string? clientIdentity, + CancellationToken cancellationToken) + { + return Task.FromResult(session); + } + + public bool TryGetSession( + string sessionId, + out GatewaySession gatewaySession) + { + gatewaySession = session; + return string.Equals(sessionId, session.SessionId, StringComparison.Ordinal); + } + + public Task InvokeAsync( + string sessionId, + WorkerCommand command, + CancellationToken cancellationToken) + { + return Task.FromResult(new WorkerCommandReply()); + } + + public IAsyncEnumerable ReadEventsAsync( + string sessionId, + CancellationToken cancellationToken) + { + return session.ReadEventsAsync(cancellationToken); + } + + public Task CloseSessionAsync( + string sessionId, + CancellationToken cancellationToken) + { + return Task.FromResult(new SessionCloseResult(sessionId, SessionState.Closed, AlreadyClosed: false)); + } + + public Task CloseExpiredLeasesAsync( + DateTimeOffset now, + CancellationToken cancellationToken) + { + return Task.FromResult(0); + } + + public Task ShutdownAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + } + + private sealed class FakeWorkerClient : IWorkerClient + { + public List Events { get; } = []; + + public bool CompleteAfterConfiguredEvents { get; set; } + + public Exception? TerminalException { get; init; } + + public string SessionId { get; } = "session-events"; + + public int? ProcessId { get; } = 4321; + + public WorkerClientState State { get; private set; } = WorkerClientState.Ready; + + public DateTimeOffset LastHeartbeatAt { get; } = DateTimeOffset.UtcNow; + + public Task StartAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public Task InvokeAsync( + WorkerCommand command, + TimeSpan timeout, + CancellationToken cancellationToken) + { + return Task.FromResult(new WorkerCommandReply()); + } + + public async IAsyncEnumerable ReadEventsAsync( + [EnumeratorCancellation] CancellationToken cancellationToken) + { + foreach (WorkerEvent workerEvent in Events) + { + cancellationToken.ThrowIfCancellationRequested(); + yield return workerEvent; + } + + if (TerminalException is not null) + { + throw TerminalException; + } + + if (CompleteAfterConfiguredEvents) + { + yield break; + } + + await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken); + } + + public Task ShutdownAsync( + TimeSpan timeout, + CancellationToken cancellationToken) + { + State = WorkerClientState.Closed; + return Task.CompletedTask; + } + + public void Kill(string reason) + { + State = WorkerClientState.Faulted; + } + + public ValueTask DisposeAsync() + { + return ValueTask.CompletedTask; + } + } +} diff --git a/src/MxGateway.Tests/Gateway/Grpc/MxAccessGatewayServiceTests.cs b/src/MxGateway.Tests/Gateway/Grpc/MxAccessGatewayServiceTests.cs index 85ce3fe..0d7d605 100644 --- a/src/MxGateway.Tests/Gateway/Grpc/MxAccessGatewayServiceTests.cs +++ b/src/MxGateway.Tests/Gateway/Grpc/MxAccessGatewayServiceTests.cs @@ -184,6 +184,7 @@ public sealed class MxAccessGatewayServiceTests identityAccessor ?? new GatewayRequestIdentityAccessor(), new MxAccessGrpcRequestValidator(), new MxAccessGrpcMapper(), + new FakeEventStreamService(sessionManager), NullLogger.Instance); } @@ -275,6 +276,11 @@ public sealed class MxAccessGatewayServiceTests public List Events { get; } = []; + public void RecordReadEventsSessionId(string sessionId) + { + LastReadEventsSessionId = sessionId; + } + public Task OpenSessionAsync( SessionOpenRequest request, string? clientIdentity, @@ -343,6 +349,27 @@ public sealed class MxAccessGatewayServiceTests } } + private sealed class FakeEventStreamService(FakeSessionManager sessionManager) : IEventStreamService + { + public async IAsyncEnumerable StreamEventsAsync( + StreamEventsRequest request, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + sessionManager.RecordReadEventsSessionId(request.SessionId); + foreach (WorkerEvent workerEvent in sessionManager.Events) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Yield(); + if (workerEvent.Event.WorkerSequence <= request.AfterWorkerSequence) + { + continue; + } + + yield return workerEvent.Event; + } + } + } + private sealed class FakeWorkerClient(int processId) : IWorkerClient { public string SessionId { get; } = "session-1"; diff --git a/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs b/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs index c519c07..cf55511 100644 --- a/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs +++ b/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs @@ -109,6 +109,32 @@ public sealed class WorkerClientTests Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family); } + [Fact] + public async Task ReadLoop_WhenEventQueueOverflows_FaultsClient() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient( + pipePair, + new WorkerClientOptions + { + EventChannelCapacity = 1, + HeartbeatGrace = TimeSpan.FromSeconds(30), + HeartbeatCheckInterval = TimeSpan.FromSeconds(30), + }); + await CompleteHandshakeAsync(client, pipePair); + + await pipePair.WorkerWriter.WriteAsync( + CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange)); + await pipePair.WorkerWriter.WriteAsync( + CreateEventEnvelope(sequence: 12, MxEventFamily.OnDataChange)); + + await WaitUntilAsync( + () => client.State == WorkerClientState.Faulted, + TestTimeout); + + Assert.Equal(WorkerClientState.Faulted, client.State); + } + [Fact] public async Task ReadLoop_WhenPipeDisconnects_FaultsClient() { diff --git a/src/MxGateway.Worker.Tests/MxAccess/MxAccessCommandExecutorTests.cs b/src/MxGateway.Worker.Tests/MxAccess/MxAccessCommandExecutorTests.cs new file mode 100644 index 0000000..1a42f49 --- /dev/null +++ b/src/MxGateway.Worker.Tests/MxAccess/MxAccessCommandExecutorTests.cs @@ -0,0 +1,220 @@ +using System; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.MxAccess; +using MxGateway.Worker.Sta; + +namespace MxGateway.Worker.Tests.MxAccess; + +public sealed class MxAccessCommandExecutorTests +{ + [Fact] + public async Task DispatchAsync_Register_CallsMxAccessOnStaAndPreservesServerHandle() + { + FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 42)); + using StaRuntime runtime = CreateRuntime(); + using MxAccessStaSession session = new(runtime, factory, new NoopEventSink()); + await session.StartAsync(workerProcessId: 1234); + + MxCommandReply reply = await session.DispatchAsync(CreateRegisterCommand("correlation-1", "client-a")); + + Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code); + Assert.True(reply.HasHresult); + Assert.Equal(0, reply.Hresult); + Assert.Equal(42, reply.Register.ServerHandle); + Assert.Equal(MxDataType.Integer, reply.ReturnValue.DataType); + Assert.Equal(42, reply.ReturnValue.Int32Value); + Assert.Equal(runtime.StaThreadId, factory.FakeComObject.RegisterThreadId); + Assert.Equal("client-a", factory.FakeComObject.RegisteredClientName); + + RegisteredServerHandle registeredServerHandle = Assert.Single( + await session.GetRegisteredServerHandlesAsync()); + Assert.Equal(42, registeredServerHandle.ServerHandle); + Assert.Equal("client-a", registeredServerHandle.ClientName); + } + + [Fact] + public async Task DispatchAsync_Unregister_CallsMxAccessOnStaAndRemovesTrackedServerHandle() + { + FakeMxAccessComObject fakeComObject = new(registerHandle: 43); + FakeMxAccessComObjectFactory factory = new(fakeComObject); + using StaRuntime runtime = CreateRuntime(); + using MxAccessStaSession session = new(runtime, factory, new NoopEventSink()); + await session.StartAsync(workerProcessId: 1234); + await session.DispatchAsync(CreateRegisterCommand("register", "client-a")); + + MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("unregister", 43)); + + Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code); + Assert.Equal(43, fakeComObject.UnregisteredServerHandle); + Assert.Equal(runtime.StaThreadId, fakeComObject.UnregisterThreadId); + Assert.Empty(await session.GetRegisteredServerHandlesAsync()); + } + + [Fact] + public async Task DispatchAsync_UnregisterWhenMxAccessThrows_PreservesHResultAndDoesNotRewriteFailure() + { + const int hresult = unchecked((int)0x80070057); + FakeMxAccessComObject fakeComObject = new( + registerHandle: 44, + unregisterException: new COMException("Invalid handle.", hresult)); + FakeMxAccessComObjectFactory factory = new(fakeComObject); + using StaRuntime runtime = CreateRuntime(); + using MxAccessStaSession session = new(runtime, factory, new NoopEventSink()); + await session.StartAsync(workerProcessId: 1234); + await session.DispatchAsync(CreateRegisterCommand("register-before-failure", "client-a")); + + MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("invalid-unregister", 44)); + + Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code); + Assert.True(reply.HasHresult); + Assert.Equal(hresult, reply.Hresult); + Assert.Contains("0x80070057", reply.DiagnosticMessage); + Assert.Equal(44, fakeComObject.UnregisteredServerHandle); + + RegisteredServerHandle registeredServerHandle = Assert.Single( + await session.GetRegisteredServerHandlesAsync()); + Assert.Equal(44, registeredServerHandle.ServerHandle); + } + + [Fact] + public async Task DispatchAsync_RegisterWithoutPayload_ReturnsInvalidRequest() + { + FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 45)); + using StaRuntime runtime = CreateRuntime(); + using MxAccessStaSession session = new(runtime, factory, new NoopEventSink()); + await session.StartAsync(workerProcessId: 1234); + + MxCommandReply reply = await session.DispatchAsync(new StaCommand( + "session-1", + "missing-payload", + new MxCommand + { + Kind = MxCommandKind.Register, + })); + + Assert.Equal(ProtocolStatusCode.InvalidRequest, reply.ProtocolStatus.Code); + Assert.Null(factory.FakeComObject.RegisteredClientName); + } + + private static StaCommand CreateRegisterCommand( + string correlationId, + string clientName) + { + return new StaCommand( + "session-1", + correlationId, + new MxCommand + { + Kind = MxCommandKind.Register, + Register = new RegisterCommand + { + ClientName = clientName, + }, + }); + } + + private static StaCommand CreateUnregisterCommand( + string correlationId, + int serverHandle) + { + return new StaCommand( + "session-1", + correlationId, + new MxCommand + { + Kind = MxCommandKind.Unregister, + Unregister = new UnregisterCommand + { + ServerHandle = serverHandle, + }, + }); + } + + private static StaRuntime CreateRuntime() + { + return new StaRuntime( + new NoopComApartmentInitializer(), + new StaMessagePump(), + TimeSpan.FromMilliseconds(25)); + } + + private sealed class FakeMxAccessComObject + { + private readonly int registerHandle; + private readonly Exception? unregisterException; + + public FakeMxAccessComObject( + int registerHandle, + Exception? unregisterException = null) + { + this.registerHandle = registerHandle; + this.unregisterException = unregisterException; + } + + public string? RegisteredClientName { get; private set; } + + public int? RegisterThreadId { get; private set; } + + public int? UnregisteredServerHandle { get; private set; } + + public int? UnregisterThreadId { get; private set; } + + public int Register(string clientName) + { + RegisteredClientName = clientName; + RegisterThreadId = Environment.CurrentManagedThreadId; + + return registerHandle; + } + + public void Unregister(int serverHandle) + { + UnregisteredServerHandle = serverHandle; + UnregisterThreadId = Environment.CurrentManagedThreadId; + + if (unregisterException is not null) + { + throw unregisterException; + } + } + } + + private sealed class FakeMxAccessComObjectFactory : IMxAccessComObjectFactory + { + public FakeMxAccessComObjectFactory(FakeMxAccessComObject fakeComObject) + { + FakeComObject = fakeComObject; + } + + public FakeMxAccessComObject FakeComObject { get; } + + public object Create() + { + return FakeComObject; + } + } + + private sealed class NoopEventSink : IMxAccessEventSink + { + public void Attach(object mxAccessComObject) + { + } + + public void Detach() + { + } + } + + private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer + { + public void Initialize() + { + } + + public void Uninitialize() + { + } + } +} diff --git a/src/MxGateway.Worker.Tests/MxAccess/MxAccessLiveComCreationTests.cs b/src/MxGateway.Worker.Tests/MxAccess/MxAccessLiveComCreationTests.cs index 98014b1..f8ab696 100644 --- a/src/MxGateway.Worker.Tests/MxAccess/MxAccessLiveComCreationTests.cs +++ b/src/MxGateway.Worker.Tests/MxAccess/MxAccessLiveComCreationTests.cs @@ -1,6 +1,8 @@ using System; using System.Threading.Tasks; +using MxGateway.Contracts.Proto; using MxGateway.Worker.MxAccess; +using MxGateway.Worker.Sta; namespace MxGateway.Worker.Tests.MxAccess; @@ -21,4 +23,53 @@ public sealed class MxAccessLiveComCreationTests await session.StartAsync(workerProcessId: 1234); } + + [Fact] + public async Task RegisterAndUnregister_WhenOptedIn_RoundTripsInstalledMxAccessServerHandle() + { + if (!RunLiveMxAccessTests()) + { + return; + } + + using MxAccessStaSession session = new(); + await session.StartAsync(workerProcessId: 1234); + + MxCommandReply registerReply = await session.DispatchAsync(new StaCommand( + "session-1", + "live-register", + new MxCommand + { + Kind = MxCommandKind.Register, + Register = new RegisterCommand + { + ClientName = "MxGateway.Worker.Tests", + }, + })); + + Assert.Equal(ProtocolStatusCode.Ok, registerReply.ProtocolStatus.Code); + Assert.True(registerReply.Register.ServerHandle > 0); + + MxCommandReply unregisterReply = await session.DispatchAsync(new StaCommand( + "session-1", + "live-unregister", + new MxCommand + { + Kind = MxCommandKind.Unregister, + Unregister = new UnregisterCommand + { + ServerHandle = registerReply.Register.ServerHandle, + }, + })); + + Assert.Equal(ProtocolStatusCode.Ok, unregisterReply.ProtocolStatus.Code); + } + + private static bool RunLiveMxAccessTests() + { + return string.Equals( + Environment.GetEnvironmentVariable("MXGATEWAY_RUN_LIVE_MXACCESS_TESTS"), + "1", + StringComparison.Ordinal); + } } diff --git a/src/MxGateway.Worker/MxAccess/IMxAccessServer.cs b/src/MxGateway.Worker/MxAccess/IMxAccessServer.cs new file mode 100644 index 0000000..5ad7525 --- /dev/null +++ b/src/MxGateway.Worker/MxAccess/IMxAccessServer.cs @@ -0,0 +1,8 @@ +namespace MxGateway.Worker.MxAccess; + +public interface IMxAccessServer +{ + int Register(string clientName); + + void Unregister(int serverHandle); +} diff --git a/src/MxGateway.Worker/MxAccess/MxAccessComServer.cs b/src/MxGateway.Worker/MxAccess/MxAccessComServer.cs new file mode 100644 index 0000000..e1a601b --- /dev/null +++ b/src/MxGateway.Worker/MxAccess/MxAccessComServer.cs @@ -0,0 +1,59 @@ +using System; +using System.Reflection; +using System.Runtime.ExceptionServices; +using ArchestrA.MxAccess; + +namespace MxGateway.Worker.MxAccess; + +public sealed class MxAccessComServer : IMxAccessServer +{ + private readonly object mxAccessComObject; + + public MxAccessComServer(object mxAccessComObject) + { + this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject)); + } + + public int Register(string clientName) + { + if (mxAccessComObject is ILMXProxyServer mxAccessServer) + { + return mxAccessServer.Register(clientName); + } + + return (int)Invoke(nameof(Register), clientName); + } + + public void Unregister(int serverHandle) + { + if (mxAccessComObject is ILMXProxyServer mxAccessServer) + { + mxAccessServer.Unregister(serverHandle); + return; + } + + Invoke(nameof(Unregister), serverHandle); + } + + private object Invoke( + string methodName, + params object[] arguments) + { + try + { + return mxAccessComObject + .GetType() + .InvokeMember( + methodName, + BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod, + binder: null, + target: mxAccessComObject, + args: arguments); + } + catch (TargetInvocationException exception) when (exception.InnerException is not null) + { + ExceptionDispatchInfo.Capture(exception.InnerException).Throw(); + throw; + } + } +} diff --git a/src/MxGateway.Worker/MxAccess/MxAccessCommandExecutor.cs b/src/MxGateway.Worker/MxAccess/MxAccessCommandExecutor.cs new file mode 100644 index 0000000..9c6ebb1 --- /dev/null +++ b/src/MxGateway.Worker/MxAccess/MxAccessCommandExecutor.cs @@ -0,0 +1,103 @@ +using System; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.Conversion; +using MxGateway.Worker.Sta; + +namespace MxGateway.Worker.MxAccess; + +public sealed class MxAccessCommandExecutor : IStaCommandExecutor +{ + private readonly MxAccessSession session; + private readonly VariantConverter variantConverter; + + public MxAccessCommandExecutor(MxAccessSession session) + : this(session, new VariantConverter()) + { + } + + public MxAccessCommandExecutor( + MxAccessSession session, + VariantConverter variantConverter) + { + this.session = session ?? throw new ArgumentNullException(nameof(session)); + this.variantConverter = variantConverter ?? throw new ArgumentNullException(nameof(variantConverter)); + } + + public MxCommandReply Execute(StaCommand command) + { + if (command is null) + { + throw new ArgumentNullException(nameof(command)); + } + + return command.Kind switch + { + MxCommandKind.Register => ExecuteRegister(command), + MxCommandKind.Unregister => ExecuteUnregister(command), + _ => CreateInvalidRequestReply(command, $"Unsupported MXAccess command kind {command.Kind}."), + }; + } + + private MxCommandReply ExecuteRegister(StaCommand command) + { + if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Register) + { + return CreateInvalidRequestReply(command, "Register command payload is required."); + } + + int serverHandle = session.Register(command.Command.Register.ClientName); + MxCommandReply reply = CreateOkReply(command); + reply.ReturnValue = variantConverter.Convert(serverHandle); + reply.Register = new RegisterReply + { + ServerHandle = serverHandle, + }; + + return reply; + } + + private MxCommandReply ExecuteUnregister(StaCommand command) + { + if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Unregister) + { + return CreateInvalidRequestReply(command, "Unregister command payload is required."); + } + + session.Unregister(command.Command.Unregister.ServerHandle); + return CreateOkReply(command); + } + + private static MxCommandReply CreateOkReply(StaCommand command) + { + return new MxCommandReply + { + SessionId = command.SessionId, + CorrelationId = command.CorrelationId, + Kind = command.Kind, + Hresult = 0, + ProtocolStatus = new ProtocolStatus + { + Code = ProtocolStatusCode.Ok, + Message = "OK", + }, + }; + } + + private static MxCommandReply CreateInvalidRequestReply( + StaCommand command, + string message) + { + return new MxCommandReply + { + SessionId = command.SessionId, + CorrelationId = command.CorrelationId, + Kind = command.Kind, + ProtocolStatus = new ProtocolStatus + { + Code = ProtocolStatusCode.InvalidRequest, + Message = message, + }, + DiagnosticMessage = message, + }; + } +} diff --git a/src/MxGateway.Worker/MxAccess/MxAccessHandleRegistry.cs b/src/MxGateway.Worker/MxAccess/MxAccessHandleRegistry.cs new file mode 100644 index 0000000..669acb3 --- /dev/null +++ b/src/MxGateway.Worker/MxAccess/MxAccessHandleRegistry.cs @@ -0,0 +1,31 @@ +using System.Collections.Generic; +using System.Linq; + +namespace MxGateway.Worker.MxAccess; + +public sealed class MxAccessHandleRegistry +{ + private readonly Dictionary serverHandles = new(); + + public IReadOnlyList ServerHandles => serverHandles + .Values + .OrderBy(handle => handle.ServerHandle) + .ToArray(); + + public void RegisterServerHandle( + int serverHandle, + string clientName) + { + serverHandles[serverHandle] = new RegisteredServerHandle(serverHandle, clientName); + } + + public void UnregisterServerHandle(int serverHandle) + { + serverHandles.Remove(serverHandle); + } + + public bool ContainsServerHandle(int serverHandle) + { + return serverHandles.ContainsKey(serverHandle); + } +} diff --git a/src/MxGateway.Worker/MxAccess/MxAccessSession.cs b/src/MxGateway.Worker/MxAccess/MxAccessSession.cs index 78a168e..6874ef0 100644 --- a/src/MxGateway.Worker/MxAccess/MxAccessSession.cs +++ b/src/MxGateway.Worker/MxAccess/MxAccessSession.cs @@ -8,21 +8,29 @@ namespace MxGateway.Worker.MxAccess; public sealed class MxAccessSession : IDisposable { private readonly object mxAccessComObject; + private readonly IMxAccessServer mxAccessServer; private readonly IMxAccessEventSink eventSink; + private readonly MxAccessHandleRegistry handleRegistry; private bool disposed; private MxAccessSession( object mxAccessComObject, + IMxAccessServer mxAccessServer, IMxAccessEventSink eventSink, + MxAccessHandleRegistry handleRegistry, int creationThreadId) { this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject)); + this.mxAccessServer = mxAccessServer ?? throw new ArgumentNullException(nameof(mxAccessServer)); this.eventSink = eventSink ?? throw new ArgumentNullException(nameof(eventSink)); + this.handleRegistry = handleRegistry ?? throw new ArgumentNullException(nameof(handleRegistry)); CreationThreadId = creationThreadId; } public int CreationThreadId { get; } + public MxAccessHandleRegistry HandleRegistry => handleRegistry; + public WorkerReady CreateWorkerReady(int workerProcessId) { return new WorkerReady @@ -62,7 +70,9 @@ public sealed class MxAccessSession : IDisposable return new MxAccessSession( mxAccessComObject, + new MxAccessComServer(mxAccessComObject), eventSink, + new MxAccessHandleRegistry(), Environment.CurrentManagedThreadId); } catch (Exception exception) @@ -78,6 +88,24 @@ public sealed class MxAccessSession : IDisposable } } + public int Register(string clientName) + { + ThrowIfDisposed(); + + int serverHandle = mxAccessServer.Register(clientName); + handleRegistry.RegisterServerHandle(serverHandle, clientName); + + return serverHandle; + } + + public void Unregister(int serverHandle) + { + ThrowIfDisposed(); + + mxAccessServer.Unregister(serverHandle); + handleRegistry.UnregisterServerHandle(serverHandle); + } + public void Dispose() { if (disposed) @@ -94,4 +122,12 @@ public sealed class MxAccessSession : IDisposable disposed = true; } + + private void ThrowIfDisposed() + { + if (disposed) + { + throw new ObjectDisposedException(nameof(MxAccessSession)); + } + } } diff --git a/src/MxGateway.Worker/MxAccess/MxAccessStaSession.cs b/src/MxGateway.Worker/MxAccess/MxAccessStaSession.cs index 770a332..945633b 100644 --- a/src/MxGateway.Worker/MxAccess/MxAccessStaSession.cs +++ b/src/MxGateway.Worker/MxAccess/MxAccessStaSession.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using MxGateway.Contracts.Proto; @@ -11,6 +12,7 @@ public sealed class MxAccessStaSession : IDisposable private readonly IMxAccessComObjectFactory factory; private readonly IMxAccessEventSink eventSink; private readonly StaRuntime staRuntime; + private StaCommandDispatcher? commandDispatcher; private MxAccessSession? session; private bool disposed; @@ -47,11 +49,38 @@ public sealed class MxAccessStaSession : IDisposable } session = MxAccessSession.Create(factory, eventSink); + commandDispatcher = new StaCommandDispatcher( + staRuntime, + new MxAccessCommandExecutor(session)); + return session.CreateWorkerReady(workerProcessId); }, cancellationToken); } + public Task DispatchAsync(StaCommand command) + { + if (commandDispatcher is null) + { + throw new InvalidOperationException("MXAccess COM session has not been started."); + } + + return commandDispatcher.DispatchAsync(command); + } + + public Task> GetRegisteredServerHandlesAsync( + CancellationToken cancellationToken = default) + { + if (session is null) + { + throw new InvalidOperationException("MXAccess COM session has not been started."); + } + + return staRuntime.InvokeAsync( + () => session.HandleRegistry.ServerHandles, + cancellationToken); + } + public void Dispose() { if (disposed) @@ -59,6 +88,8 @@ public sealed class MxAccessStaSession : IDisposable return; } + commandDispatcher?.RequestShutdown(); + if (session is not null) { staRuntime.InvokeAsync(() => session.Dispose()).GetAwaiter().GetResult(); diff --git a/src/MxGateway.Worker/MxAccess/RegisteredServerHandle.cs b/src/MxGateway.Worker/MxAccess/RegisteredServerHandle.cs new file mode 100644 index 0000000..1560970 --- /dev/null +++ b/src/MxGateway.Worker/MxAccess/RegisteredServerHandle.cs @@ -0,0 +1,16 @@ +namespace MxGateway.Worker.MxAccess; + +public sealed class RegisteredServerHandle +{ + public RegisteredServerHandle( + int serverHandle, + string clientName) + { + ServerHandle = serverHandle; + ClientName = clientName; + } + + public int ServerHandle { get; } + + public string ClientName { get; } +}