Merge remote-tracking branch 'origin/main' into agent-1/issue-17-implement-dashboard-authentication

This commit is contained in:
Joseph Doherty
2026-04-26 18:15:38 -04:00
23 changed files with 1279 additions and 37 deletions
@@ -37,6 +37,7 @@ public static class GatewayApplication
builder.Services.AddSingleton<GatewayMetrics>();
builder.Services.AddSingleton<MxAccessGrpcMapper>();
builder.Services.AddSingleton<MxAccessGrpcRequestValidator>();
builder.Services.AddSingleton<IEventStreamService, EventStreamService>();
builder.Services.AddWorkerProcessLauncher();
builder.Services.AddGatewaySessions();
builder.Services.AddGatewayDashboard();
@@ -0,0 +1,140 @@
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Microsoft.Extensions.Options;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Metrics;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;
namespace MxGateway.Server.Grpc;
public sealed class EventStreamService(
ISessionManager sessionManager,
IOptions<GatewayOptions> options,
MxAccessGrpcMapper mapper,
GatewayMetrics metrics,
ILogger<EventStreamService> logger) : IEventStreamService
{
public async IAsyncEnumerable<MxEvent> StreamEventsAsync(
StreamEventsRequest request,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
if (!sessionManager.TryGetSession(request.SessionId, out GatewaySession session))
{
throw new SessionManagerException(
SessionManagerErrorCode.SessionNotFound,
$"Session {request.SessionId} was not found.");
}
using IDisposable subscriber = session.AttachEventSubscriber(
options.Value.Sessions.AllowMultipleEventSubscribers);
using CancellationTokenSource streamCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
int streamQueueDepth = 0;
Channel<MxEvent> eventQueue = Channel.CreateBounded<MxEvent>(
new BoundedChannelOptions(options.Value.Events.QueueCapacity)
{
SingleReader = true,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait,
AllowSynchronousContinuations = false,
});
Task producerTask = ProduceEventsAsync(
session,
request.AfterWorkerSequence,
eventQueue.Writer,
() =>
{
int depth = Interlocked.Increment(ref streamQueueDepth);
metrics.SetEventQueueDepth(depth);
},
streamCts.Token);
try
{
await foreach (MxEvent mxEvent in eventQueue.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
int depth = Math.Max(0, Interlocked.Decrement(ref streamQueueDepth));
metrics.SetEventQueueDepth(depth);
yield return mxEvent;
}
await producerTask.ConfigureAwait(false);
}
finally
{
await streamCts.CancelAsync().ConfigureAwait(false);
subscriber.Dispose();
metrics.StreamDisconnected("Detached");
try
{
await producerTask.ConfigureAwait(false);
}
catch (OperationCanceledException) when (streamCts.IsCancellationRequested)
{
}
catch (Exception exception)
{
logger.LogDebug(
exception,
"Event stream producer stopped for session {SessionId}.",
request.SessionId);
}
}
}
private async Task ProduceEventsAsync(
GatewaySession session,
ulong afterWorkerSequence,
ChannelWriter<MxEvent> writer,
Action eventQueued,
CancellationToken cancellationToken)
{
try
{
await foreach (WorkerEvent workerEvent in session
.ReadEventsAsync(cancellationToken)
.WithCancellation(cancellationToken)
.ConfigureAwait(false))
{
MxEvent publicEvent = mapper.MapEvent(workerEvent);
if (publicEvent.WorkerSequence <= afterWorkerSequence)
{
continue;
}
if (!writer.TryWrite(publicEvent))
{
string message = $"Session {session.SessionId} event stream queue overflowed.";
session.MarkFaulted(message);
metrics.QueueOverflow("grpc-event-stream");
metrics.Fault(SessionManagerErrorCode.EventQueueOverflow.ToString());
writer.TryComplete(new SessionManagerException(
SessionManagerErrorCode.EventQueueOverflow,
message));
return;
}
eventQueued();
}
writer.TryComplete();
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
writer.TryComplete();
}
catch (Exception exception)
{
if (exception is WorkerClientException)
{
session.MarkFaulted(exception.Message);
metrics.Fault(WorkerClientErrorCode.WorkerFaulted.ToString());
}
writer.TryComplete(exception);
}
}
}
@@ -0,0 +1,10 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Grpc;
public interface IEventStreamService
{
IAsyncEnumerable<MxEvent> StreamEventsAsync(
StreamEventsRequest request,
CancellationToken cancellationToken);
}
@@ -12,6 +12,7 @@ public sealed class MxAccessGatewayService(
IGatewayRequestIdentityAccessor identityAccessor,
MxAccessGrpcRequestValidator requestValidator,
MxAccessGrpcMapper mapper,
IEventStreamService eventStreamService,
ILogger<MxAccessGatewayService> logger) : MxAccessGateway.MxAccessGatewayBase
{
public override async Task<OpenSessionReply> OpenSession(
@@ -102,17 +103,11 @@ public sealed class MxAccessGatewayService(
try
{
requestValidator.ValidateStreamEvents(request);
await foreach (WorkerEvent workerEvent in sessionManager
.ReadEventsAsync(request.SessionId, context.CancellationToken)
await foreach (MxEvent publicEvent in eventStreamService
.StreamEventsAsync(request, context.CancellationToken)
.WithCancellation(context.CancellationToken)
.ConfigureAwait(false))
{
MxEvent publicEvent = mapper.MapEvent(workerEvent);
if (publicEvent.WorkerSequence <= request.AfterWorkerSequence)
{
continue;
}
await responseStream.WriteAsync(publicEvent).ConfigureAwait(false);
}
}
@@ -154,6 +149,8 @@ public sealed class MxAccessGatewayService(
{
SessionManagerErrorCode.SessionNotFound => StatusCode.NotFound,
SessionManagerErrorCode.SessionNotReady => StatusCode.FailedPrecondition,
SessionManagerErrorCode.EventSubscriberAlreadyActive => StatusCode.ResourceExhausted,
SessionManagerErrorCode.EventQueueOverflow => StatusCode.ResourceExhausted,
SessionManagerErrorCode.SessionLimitExceeded => StatusCode.ResourceExhausted,
SessionManagerErrorCode.OpenFailed => StatusCode.Unavailable,
SessionManagerErrorCode.CloseFailed => StatusCode.Unavailable,
@@ -13,6 +13,7 @@ public sealed class GatewaySession
private DateTimeOffset _lastClientActivityAt;
private DateTimeOffset? _leaseExpiresAt;
private bool _closeStarted;
private int _activeEventSubscriberCount;
public GatewaySession(
string sessionId,
@@ -131,6 +132,17 @@ public sealed class GatewaySession
}
}
public int ActiveEventSubscriberCount
{
get
{
lock (_syncRoot)
{
return _activeEventSubscriberCount;
}
}
}
public void AttachWorkerClient(IWorkerClient workerClient)
{
ArgumentNullException.ThrowIfNull(workerClient);
@@ -202,6 +214,29 @@ public sealed class GatewaySession
}
}
public IDisposable AttachEventSubscriber(bool allowMultipleSubscribers)
{
lock (_syncRoot)
{
if (_state != SessionState.Ready || _workerClient?.State != WorkerClientState.Ready)
{
throw new SessionManagerException(
SessionManagerErrorCode.SessionNotReady,
$"Session {SessionId} is not ready for event streaming. Current state is {_state}.");
}
if (!allowMultipleSubscribers && _activeEventSubscriberCount > 0)
{
throw new SessionManagerException(
SessionManagerErrorCode.EventSubscriberAlreadyActive,
$"Session {SessionId} already has an active event stream subscriber.");
}
_activeEventSubscriberCount++;
return new EventSubscriberLease(this);
}
}
public async Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
CancellationToken cancellationToken)
@@ -287,4 +322,31 @@ public sealed class GatewaySession
return _workerClient;
}
}
private void DetachEventSubscriber()
{
lock (_syncRoot)
{
if (_activeEventSubscriberCount > 0)
{
_activeEventSubscriberCount--;
}
}
}
private sealed class EventSubscriberLease(GatewaySession session) : IDisposable
{
private bool _disposed;
public void Dispose()
{
if (_disposed)
{
return;
}
session.DetachEventSubscriber();
_disposed = true;
}
}
}
@@ -4,6 +4,8 @@ public enum SessionManagerErrorCode
{
SessionNotFound,
SessionNotReady,
EventSubscriberAlreadyActive,
EventQueueOverflow,
SessionLimitExceeded,
OpenFailed,
CloseFailed,
+7 -5
View File
@@ -29,6 +29,7 @@ public sealed class WorkerClient : IWorkerClient
private WorkerClientState _state;
private DateTimeOffset _lastHeartbeatAt;
private int? _processId;
private int _eventQueueDepth;
private Task? _readLoopTask;
private Task? _writeLoopTask;
private Task? _heartbeatLoopTask;
@@ -197,6 +198,8 @@ public sealed class WorkerClient : IWorkerClient
{
await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
int queueDepth = Math.Max(0, Interlocked.Decrement(ref _eventQueueDepth));
_metrics?.SetEventQueueDepth(queueDepth);
yield return workerEvent;
}
}
@@ -394,11 +397,6 @@ public sealed class WorkerClient : IWorkerClient
_metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
}
if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false))
{
return;
}
if (!_events.Writer.TryWrite(workerEvent))
{
_metrics?.QueueOverflow("worker-events");
@@ -406,7 +404,11 @@ public sealed class WorkerClient : IWorkerClient
WorkerClientErrorCode.ProtocolViolation,
"Worker event channel rejected an event.",
null);
return;
}
int queueDepth = Interlocked.Increment(ref _eventQueueDepth);
_metrics?.SetEventQueueDepth(queueDepth);
}
private void CompleteCommand(WorkerEnvelope envelope)
@@ -0,0 +1,383 @@
using System.Runtime.CompilerServices;
using Grpc.Core;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Grpc;
using MxGateway.Server.Metrics;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Grpc;
public sealed class EventStreamServiceTests
{
private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5);
[Fact]
public async Task StreamEventsAsync_YieldsEventsInWorkerOrder()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
FakeSessionManager sessionManager = new(session);
using GatewayMetrics metrics = new();
EventStreamService service = CreateService(sessionManager, metrics: metrics);
workerClient.Events.Add(CreateWorkerEvent(sequence: 10, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 11, MxEventFamily.OnWriteComplete));
workerClient.CompleteAfterConfiguredEvents = true;
List<MxEvent> events = await CollectEventsAsync(service, session.SessionId);
Assert.Equal([10UL, 11UL], events.Select(mxEvent => mxEvent.WorkerSequence).ToArray());
Assert.Equal(MxEventFamily.OnDataChange, events[0].Family);
Assert.Equal(MxEventFamily.OnWriteComplete, events[1].Family);
Assert.Equal(1, metrics.GetSnapshot().StreamDisconnects);
}
[Fact]
public async Task StreamEventsAsync_WhenSecondSubscriberStarts_RejectsClearly()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
EventStreamService service = CreateService(new FakeSessionManager(session));
using CancellationTokenSource firstSubscriberCancellation = new();
await using IAsyncEnumerator<MxEvent> firstSubscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), firstSubscriberCancellation.Token)
.GetAsyncEnumerator(firstSubscriberCancellation.Token);
Task<bool> firstMoveTask = firstSubscriber.MoveNextAsync().AsTask();
await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 1);
await using IAsyncEnumerator<MxEvent> secondSubscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None)
.GetAsyncEnumerator();
SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
async () => await secondSubscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
Assert.Equal(SessionManagerErrorCode.EventSubscriberAlreadyActive, exception.ErrorCode);
await firstSubscriberCancellation.CancelAsync();
await Assert.ThrowsAnyAsync<OperationCanceledException>(
async () => await firstMoveTask.WaitAsync(TestTimeout));
await firstSubscriber.DisposeAsync();
await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 0);
}
[Fact]
public async Task StreamEventsAsync_WhenCanceled_DetachesSubscriber()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
EventStreamService service = CreateService(new FakeSessionManager(session));
using CancellationTokenSource cancellationTokenSource = new();
await using IAsyncEnumerator<MxEvent> subscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), cancellationTokenSource.Token)
.GetAsyncEnumerator(cancellationTokenSource.Token);
Task<bool> moveTask = subscriber.MoveNextAsync().AsTask();
await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 1);
await cancellationTokenSource.CancelAsync();
await Assert.ThrowsAnyAsync<OperationCanceledException>(
async () => await moveTask.WaitAsync(TestTimeout));
await subscriber.DisposeAsync();
await WaitUntilAsync(() => session.ActiveEventSubscriberCount == 0);
}
[Fact]
public async Task StreamEventsAsync_WhenStreamQueueOverflows_FaultsSessionAndReportsOverflow()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
using GatewayMetrics metrics = new();
EventStreamService service = CreateService(
new FakeSessionManager(session),
metrics,
queueCapacity: 1);
workerClient.Events.Add(CreateWorkerEvent(sequence: 1, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 2, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 3, MxEventFamily.OnDataChange));
workerClient.CompleteAfterConfiguredEvents = true;
await using IAsyncEnumerator<MxEvent> subscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None)
.GetAsyncEnumerator();
Assert.True(await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
await WaitUntilAsync(() => session.State == SessionState.Faulted);
SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
async () => await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
Assert.Equal(SessionManagerErrorCode.EventQueueOverflow, exception.ErrorCode);
Assert.Equal(SessionState.Faulted, session.State);
Assert.Equal(1, metrics.GetSnapshot().QueueOverflows);
Assert.Equal(1, metrics.GetSnapshot().Faults);
}
[Fact]
public async Task StreamEventsAsync_DoesNotSynthesizeOperationComplete()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
EventStreamService service = CreateService(new FakeSessionManager(session));
workerClient.Events.Add(CreateWorkerEvent(sequence: 10, MxEventFamily.OnWriteComplete));
workerClient.CompleteAfterConfiguredEvents = true;
List<MxEvent> events = await CollectEventsAsync(service, session.SessionId);
MxEvent mxEvent = Assert.Single(events);
Assert.Equal(MxEventFamily.OnWriteComplete, mxEvent.Family);
Assert.DoesNotContain(events, candidate => candidate.Family == MxEventFamily.OperationComplete);
}
[Fact]
public async Task StreamEventsAsync_WhenWorkerEventStreamFaults_PropagatesTerminalFault()
{
FakeWorkerClient workerClient = new()
{
TerminalException = new WorkerClientException(
WorkerClientErrorCode.WorkerFaulted,
"worker terminal fault"),
};
GatewaySession session = CreateReadySession(workerClient);
using GatewayMetrics metrics = new();
EventStreamService service = CreateService(new FakeSessionManager(session), metrics);
await using IAsyncEnumerator<MxEvent> subscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None)
.GetAsyncEnumerator();
WorkerClientException exception = await Assert.ThrowsAsync<WorkerClientException>(
async () => await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
Assert.Equal(WorkerClientErrorCode.WorkerFaulted, exception.ErrorCode);
Assert.Equal(SessionState.Faulted, session.State);
Assert.Equal(1, metrics.GetSnapshot().Faults);
}
private static EventStreamService CreateService(
FakeSessionManager sessionManager,
GatewayMetrics? metrics = null,
int queueCapacity = 8)
{
return new EventStreamService(
sessionManager,
Options.Create(new GatewayOptions
{
Events = new EventOptions
{
QueueCapacity = queueCapacity,
},
}),
new MxAccessGrpcMapper(),
metrics ?? new GatewayMetrics(),
NullLogger<EventStreamService>.Instance);
}
private static async Task<List<MxEvent>> CollectEventsAsync(
EventStreamService service,
string sessionId)
{
List<MxEvent> events = [];
await foreach (MxEvent mxEvent in service
.StreamEventsAsync(CreateRequest(sessionId), CancellationToken.None)
.WithCancellation(CancellationToken.None))
{
events.Add(mxEvent);
}
return events;
}
private static StreamEventsRequest CreateRequest(string sessionId)
{
return new StreamEventsRequest
{
SessionId = sessionId,
};
}
private static GatewaySession CreateReadySession(FakeWorkerClient workerClient)
{
GatewaySession session = new(
"session-events",
GatewayContractInfo.DefaultBackendName,
"pipe",
"nonce",
"client",
"client-session",
"client-correlation",
TimeSpan.FromSeconds(30),
TimeSpan.FromSeconds(30),
TimeSpan.FromSeconds(10),
DateTimeOffset.UtcNow);
session.AttachWorkerClient(workerClient);
session.MarkReady();
return session;
}
private static WorkerEvent CreateWorkerEvent(
ulong sequence,
MxEventFamily family)
{
MxEvent mxEvent = new()
{
SessionId = "session-events",
Family = family,
WorkerSequence = sequence,
};
switch (family)
{
case MxEventFamily.OnDataChange:
mxEvent.OnDataChange = new OnDataChangeEvent();
break;
case MxEventFamily.OnWriteComplete:
mxEvent.OnWriteComplete = new OnWriteCompleteEvent();
break;
case MxEventFamily.OperationComplete:
mxEvent.OperationComplete = new OperationCompleteEvent();
break;
case MxEventFamily.OnBufferedDataChange:
mxEvent.OnBufferedDataChange = new OnBufferedDataChangeEvent();
break;
}
return new WorkerEvent
{
Event = mxEvent,
};
}
private static async Task WaitUntilAsync(Func<bool> predicate)
{
using CancellationTokenSource cancellationTokenSource = new(TestTimeout);
while (!predicate())
{
await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token);
}
}
private sealed class FakeSessionManager(GatewaySession session) : ISessionManager
{
public Task<GatewaySession> OpenSessionAsync(
SessionOpenRequest request,
string? clientIdentity,
CancellationToken cancellationToken)
{
return Task.FromResult(session);
}
public bool TryGetSession(
string sessionId,
out GatewaySession gatewaySession)
{
gatewaySession = session;
return string.Equals(sessionId, session.SessionId, StringComparison.Ordinal);
}
public Task<WorkerCommandReply> InvokeAsync(
string sessionId,
WorkerCommand command,
CancellationToken cancellationToken)
{
return Task.FromResult(new WorkerCommandReply());
}
public IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
string sessionId,
CancellationToken cancellationToken)
{
return session.ReadEventsAsync(cancellationToken);
}
public Task<SessionCloseResult> CloseSessionAsync(
string sessionId,
CancellationToken cancellationToken)
{
return Task.FromResult(new SessionCloseResult(sessionId, SessionState.Closed, AlreadyClosed: false));
}
public Task<int> CloseExpiredLeasesAsync(
DateTimeOffset now,
CancellationToken cancellationToken)
{
return Task.FromResult(0);
}
public Task ShutdownAsync(CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
}
private sealed class FakeWorkerClient : IWorkerClient
{
public List<WorkerEvent> Events { get; } = [];
public bool CompleteAfterConfiguredEvents { get; set; }
public Exception? TerminalException { get; init; }
public string SessionId { get; } = "session-events";
public int? ProcessId { get; } = 4321;
public WorkerClientState State { get; private set; } = WorkerClientState.Ready;
public DateTimeOffset LastHeartbeatAt { get; } = DateTimeOffset.UtcNow;
public Task StartAsync(CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
public Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
return Task.FromResult(new WorkerCommandReply());
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
foreach (WorkerEvent workerEvent in Events)
{
cancellationToken.ThrowIfCancellationRequested();
yield return workerEvent;
}
if (TerminalException is not null)
{
throw TerminalException;
}
if (CompleteAfterConfiguredEvents)
{
yield break;
}
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
}
public Task ShutdownAsync(
TimeSpan timeout,
CancellationToken cancellationToken)
{
State = WorkerClientState.Closed;
return Task.CompletedTask;
}
public void Kill(string reason)
{
State = WorkerClientState.Faulted;
}
public ValueTask DisposeAsync()
{
return ValueTask.CompletedTask;
}
}
}
@@ -184,6 +184,7 @@ public sealed class MxAccessGatewayServiceTests
identityAccessor ?? new GatewayRequestIdentityAccessor(),
new MxAccessGrpcRequestValidator(),
new MxAccessGrpcMapper(),
new FakeEventStreamService(sessionManager),
NullLogger<MxAccessGatewayService>.Instance);
}
@@ -275,6 +276,11 @@ public sealed class MxAccessGatewayServiceTests
public List<WorkerEvent> Events { get; } = [];
public void RecordReadEventsSessionId(string sessionId)
{
LastReadEventsSessionId = sessionId;
}
public Task<GatewaySession> OpenSessionAsync(
SessionOpenRequest request,
string? clientIdentity,
@@ -343,6 +349,27 @@ public sealed class MxAccessGatewayServiceTests
}
}
private sealed class FakeEventStreamService(FakeSessionManager sessionManager) : IEventStreamService
{
public async IAsyncEnumerable<MxEvent> StreamEventsAsync(
StreamEventsRequest request,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
sessionManager.RecordReadEventsSessionId(request.SessionId);
foreach (WorkerEvent workerEvent in sessionManager.Events)
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
if (workerEvent.Event.WorkerSequence <= request.AfterWorkerSequence)
{
continue;
}
yield return workerEvent.Event;
}
}
}
private sealed class FakeWorkerClient(int processId) : IWorkerClient
{
public string SessionId { get; } = "session-1";
@@ -109,6 +109,32 @@ public sealed class WorkerClientTests
Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family);
}
[Fact]
public async Task ReadLoop_WhenEventQueueOverflows_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(
pipePair,
new WorkerClientOptions
{
EventChannelCapacity = 1,
HeartbeatGrace = TimeSpan.FromSeconds(30),
HeartbeatCheckInterval = TimeSpan.FromSeconds(30),
});
await CompleteHandshakeAsync(client, pipePair);
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange));
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 12, MxEventFamily.OnDataChange));
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
[Fact]
public async Task ReadLoop_WhenPipeDisconnects_FaultsClient()
{
@@ -0,0 +1,220 @@
using System;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.MxAccess;
public sealed class MxAccessCommandExecutorTests
{
[Fact]
public async Task DispatchAsync_Register_CallsMxAccessOnStaAndPreservesServerHandle()
{
FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 42));
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
MxCommandReply reply = await session.DispatchAsync(CreateRegisterCommand("correlation-1", "client-a"));
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.True(reply.HasHresult);
Assert.Equal(0, reply.Hresult);
Assert.Equal(42, reply.Register.ServerHandle);
Assert.Equal(MxDataType.Integer, reply.ReturnValue.DataType);
Assert.Equal(42, reply.ReturnValue.Int32Value);
Assert.Equal(runtime.StaThreadId, factory.FakeComObject.RegisterThreadId);
Assert.Equal("client-a", factory.FakeComObject.RegisteredClientName);
RegisteredServerHandle registeredServerHandle = Assert.Single(
await session.GetRegisteredServerHandlesAsync());
Assert.Equal(42, registeredServerHandle.ServerHandle);
Assert.Equal("client-a", registeredServerHandle.ClientName);
}
[Fact]
public async Task DispatchAsync_Unregister_CallsMxAccessOnStaAndRemovesTrackedServerHandle()
{
FakeMxAccessComObject fakeComObject = new(registerHandle: 43);
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register", "client-a"));
MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("unregister", 43));
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Equal(43, fakeComObject.UnregisteredServerHandle);
Assert.Equal(runtime.StaThreadId, fakeComObject.UnregisterThreadId);
Assert.Empty(await session.GetRegisteredServerHandlesAsync());
}
[Fact]
public async Task DispatchAsync_UnregisterWhenMxAccessThrows_PreservesHResultAndDoesNotRewriteFailure()
{
const int hresult = unchecked((int)0x80070057);
FakeMxAccessComObject fakeComObject = new(
registerHandle: 44,
unregisterException: new COMException("Invalid handle.", hresult));
FakeMxAccessComObjectFactory factory = new(fakeComObject);
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
await session.DispatchAsync(CreateRegisterCommand("register-before-failure", "client-a"));
MxCommandReply reply = await session.DispatchAsync(CreateUnregisterCommand("invalid-unregister", 44));
Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
Assert.True(reply.HasHresult);
Assert.Equal(hresult, reply.Hresult);
Assert.Contains("0x80070057", reply.DiagnosticMessage);
Assert.Equal(44, fakeComObject.UnregisteredServerHandle);
RegisteredServerHandle registeredServerHandle = Assert.Single(
await session.GetRegisteredServerHandlesAsync());
Assert.Equal(44, registeredServerHandle.ServerHandle);
}
[Fact]
public async Task DispatchAsync_RegisterWithoutPayload_ReturnsInvalidRequest()
{
FakeMxAccessComObjectFactory factory = new(new FakeMxAccessComObject(registerHandle: 45));
using StaRuntime runtime = CreateRuntime();
using MxAccessStaSession session = new(runtime, factory, new NoopEventSink());
await session.StartAsync(workerProcessId: 1234);
MxCommandReply reply = await session.DispatchAsync(new StaCommand(
"session-1",
"missing-payload",
new MxCommand
{
Kind = MxCommandKind.Register,
}));
Assert.Equal(ProtocolStatusCode.InvalidRequest, reply.ProtocolStatus.Code);
Assert.Null(factory.FakeComObject.RegisteredClientName);
}
private static StaCommand CreateRegisterCommand(
string correlationId,
string clientName)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = MxCommandKind.Register,
Register = new RegisterCommand
{
ClientName = clientName,
},
});
}
private static StaCommand CreateUnregisterCommand(
string correlationId,
int serverHandle)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = MxCommandKind.Unregister,
Unregister = new UnregisterCommand
{
ServerHandle = serverHandle,
},
});
}
private static StaRuntime CreateRuntime()
{
return new StaRuntime(
new NoopComApartmentInitializer(),
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class FakeMxAccessComObject
{
private readonly int registerHandle;
private readonly Exception? unregisterException;
public FakeMxAccessComObject(
int registerHandle,
Exception? unregisterException = null)
{
this.registerHandle = registerHandle;
this.unregisterException = unregisterException;
}
public string? RegisteredClientName { get; private set; }
public int? RegisterThreadId { get; private set; }
public int? UnregisteredServerHandle { get; private set; }
public int? UnregisterThreadId { get; private set; }
public int Register(string clientName)
{
RegisteredClientName = clientName;
RegisterThreadId = Environment.CurrentManagedThreadId;
return registerHandle;
}
public void Unregister(int serverHandle)
{
UnregisteredServerHandle = serverHandle;
UnregisterThreadId = Environment.CurrentManagedThreadId;
if (unregisterException is not null)
{
throw unregisterException;
}
}
}
private sealed class FakeMxAccessComObjectFactory : IMxAccessComObjectFactory
{
public FakeMxAccessComObjectFactory(FakeMxAccessComObject fakeComObject)
{
FakeComObject = fakeComObject;
}
public FakeMxAccessComObject FakeComObject { get; }
public object Create()
{
return FakeComObject;
}
}
private sealed class NoopEventSink : IMxAccessEventSink
{
public void Attach(object mxAccessComObject)
{
}
public void Detach()
{
}
}
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
{
public void Initialize()
{
}
public void Uninitialize()
{
}
}
}
@@ -1,6 +1,8 @@
using System;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.MxAccess;
@@ -21,4 +23,53 @@ public sealed class MxAccessLiveComCreationTests
await session.StartAsync(workerProcessId: 1234);
}
[Fact]
public async Task RegisterAndUnregister_WhenOptedIn_RoundTripsInstalledMxAccessServerHandle()
{
if (!RunLiveMxAccessTests())
{
return;
}
using MxAccessStaSession session = new();
await session.StartAsync(workerProcessId: 1234);
MxCommandReply registerReply = await session.DispatchAsync(new StaCommand(
"session-1",
"live-register",
new MxCommand
{
Kind = MxCommandKind.Register,
Register = new RegisterCommand
{
ClientName = "MxGateway.Worker.Tests",
},
}));
Assert.Equal(ProtocolStatusCode.Ok, registerReply.ProtocolStatus.Code);
Assert.True(registerReply.Register.ServerHandle > 0);
MxCommandReply unregisterReply = await session.DispatchAsync(new StaCommand(
"session-1",
"live-unregister",
new MxCommand
{
Kind = MxCommandKind.Unregister,
Unregister = new UnregisterCommand
{
ServerHandle = registerReply.Register.ServerHandle,
},
}));
Assert.Equal(ProtocolStatusCode.Ok, unregisterReply.ProtocolStatus.Code);
}
private static bool RunLiveMxAccessTests()
{
return string.Equals(
Environment.GetEnvironmentVariable("MXGATEWAY_RUN_LIVE_MXACCESS_TESTS"),
"1",
StringComparison.Ordinal);
}
}
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.MxAccess;
public interface IMxAccessServer
{
int Register(string clientName);
void Unregister(int serverHandle);
}
@@ -0,0 +1,59 @@
using System;
using System.Reflection;
using System.Runtime.ExceptionServices;
using ArchestrA.MxAccess;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessComServer : IMxAccessServer
{
private readonly object mxAccessComObject;
public MxAccessComServer(object mxAccessComObject)
{
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
}
public int Register(string clientName)
{
if (mxAccessComObject is ILMXProxyServer mxAccessServer)
{
return mxAccessServer.Register(clientName);
}
return (int)Invoke(nameof(Register), clientName);
}
public void Unregister(int serverHandle)
{
if (mxAccessComObject is ILMXProxyServer mxAccessServer)
{
mxAccessServer.Unregister(serverHandle);
return;
}
Invoke(nameof(Unregister), serverHandle);
}
private object Invoke(
string methodName,
params object[] arguments)
{
try
{
return mxAccessComObject
.GetType()
.InvokeMember(
methodName,
BindingFlags.Instance | BindingFlags.Public | BindingFlags.InvokeMethod,
binder: null,
target: mxAccessComObject,
args: arguments);
}
catch (TargetInvocationException exception) when (exception.InnerException is not null)
{
ExceptionDispatchInfo.Capture(exception.InnerException).Throw();
throw;
}
}
}
@@ -0,0 +1,103 @@
using System;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessCommandExecutor : IStaCommandExecutor
{
private readonly MxAccessSession session;
private readonly VariantConverter variantConverter;
public MxAccessCommandExecutor(MxAccessSession session)
: this(session, new VariantConverter())
{
}
public MxAccessCommandExecutor(
MxAccessSession session,
VariantConverter variantConverter)
{
this.session = session ?? throw new ArgumentNullException(nameof(session));
this.variantConverter = variantConverter ?? throw new ArgumentNullException(nameof(variantConverter));
}
public MxCommandReply Execute(StaCommand command)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
return command.Kind switch
{
MxCommandKind.Register => ExecuteRegister(command),
MxCommandKind.Unregister => ExecuteUnregister(command),
_ => CreateInvalidRequestReply(command, $"Unsupported MXAccess command kind {command.Kind}."),
};
}
private MxCommandReply ExecuteRegister(StaCommand command)
{
if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Register)
{
return CreateInvalidRequestReply(command, "Register command payload is required.");
}
int serverHandle = session.Register(command.Command.Register.ClientName);
MxCommandReply reply = CreateOkReply(command);
reply.ReturnValue = variantConverter.Convert(serverHandle);
reply.Register = new RegisterReply
{
ServerHandle = serverHandle,
};
return reply;
}
private MxCommandReply ExecuteUnregister(StaCommand command)
{
if (command.Command.PayloadCase != MxCommand.PayloadOneofCase.Unregister)
{
return CreateInvalidRequestReply(command, "Unregister command payload is required.");
}
session.Unregister(command.Command.Unregister.ServerHandle);
return CreateOkReply(command);
}
private static MxCommandReply CreateOkReply(StaCommand command)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
Hresult = 0,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
private static MxCommandReply CreateInvalidRequestReply(
StaCommand command,
string message)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.InvalidRequest,
Message = message,
},
DiagnosticMessage = message,
};
}
}
@@ -0,0 +1,31 @@
using System.Collections.Generic;
using System.Linq;
namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessHandleRegistry
{
private readonly Dictionary<int, RegisteredServerHandle> serverHandles = new();
public IReadOnlyList<RegisteredServerHandle> ServerHandles => serverHandles
.Values
.OrderBy(handle => handle.ServerHandle)
.ToArray();
public void RegisterServerHandle(
int serverHandle,
string clientName)
{
serverHandles[serverHandle] = new RegisteredServerHandle(serverHandle, clientName);
}
public void UnregisterServerHandle(int serverHandle)
{
serverHandles.Remove(serverHandle);
}
public bool ContainsServerHandle(int serverHandle)
{
return serverHandles.ContainsKey(serverHandle);
}
}
@@ -8,21 +8,29 @@ namespace MxGateway.Worker.MxAccess;
public sealed class MxAccessSession : IDisposable
{
private readonly object mxAccessComObject;
private readonly IMxAccessServer mxAccessServer;
private readonly IMxAccessEventSink eventSink;
private readonly MxAccessHandleRegistry handleRegistry;
private bool disposed;
private MxAccessSession(
object mxAccessComObject,
IMxAccessServer mxAccessServer,
IMxAccessEventSink eventSink,
MxAccessHandleRegistry handleRegistry,
int creationThreadId)
{
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
this.mxAccessServer = mxAccessServer ?? throw new ArgumentNullException(nameof(mxAccessServer));
this.eventSink = eventSink ?? throw new ArgumentNullException(nameof(eventSink));
this.handleRegistry = handleRegistry ?? throw new ArgumentNullException(nameof(handleRegistry));
CreationThreadId = creationThreadId;
}
public int CreationThreadId { get; }
public MxAccessHandleRegistry HandleRegistry => handleRegistry;
public WorkerReady CreateWorkerReady(int workerProcessId)
{
return new WorkerReady
@@ -62,7 +70,9 @@ public sealed class MxAccessSession : IDisposable
return new MxAccessSession(
mxAccessComObject,
new MxAccessComServer(mxAccessComObject),
eventSink,
new MxAccessHandleRegistry(),
Environment.CurrentManagedThreadId);
}
catch (Exception exception)
@@ -78,6 +88,24 @@ public sealed class MxAccessSession : IDisposable
}
}
public int Register(string clientName)
{
ThrowIfDisposed();
int serverHandle = mxAccessServer.Register(clientName);
handleRegistry.RegisterServerHandle(serverHandle, clientName);
return serverHandle;
}
public void Unregister(int serverHandle)
{
ThrowIfDisposed();
mxAccessServer.Unregister(serverHandle);
handleRegistry.UnregisterServerHandle(serverHandle);
}
public void Dispose()
{
if (disposed)
@@ -94,4 +122,12 @@ public sealed class MxAccessSession : IDisposable
disposed = true;
}
private void ThrowIfDisposed()
{
if (disposed)
{
throw new ObjectDisposedException(nameof(MxAccessSession));
}
}
}
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
@@ -11,6 +12,7 @@ public sealed class MxAccessStaSession : IDisposable
private readonly IMxAccessComObjectFactory factory;
private readonly IMxAccessEventSink eventSink;
private readonly StaRuntime staRuntime;
private StaCommandDispatcher? commandDispatcher;
private MxAccessSession? session;
private bool disposed;
@@ -47,11 +49,38 @@ public sealed class MxAccessStaSession : IDisposable
}
session = MxAccessSession.Create(factory, eventSink);
commandDispatcher = new StaCommandDispatcher(
staRuntime,
new MxAccessCommandExecutor(session));
return session.CreateWorkerReady(workerProcessId);
},
cancellationToken);
}
public Task<MxCommandReply> DispatchAsync(StaCommand command)
{
if (commandDispatcher is null)
{
throw new InvalidOperationException("MXAccess COM session has not been started.");
}
return commandDispatcher.DispatchAsync(command);
}
public Task<IReadOnlyList<RegisteredServerHandle>> GetRegisteredServerHandlesAsync(
CancellationToken cancellationToken = default)
{
if (session is null)
{
throw new InvalidOperationException("MXAccess COM session has not been started.");
}
return staRuntime.InvokeAsync(
() => session.HandleRegistry.ServerHandles,
cancellationToken);
}
public void Dispose()
{
if (disposed)
@@ -59,6 +88,8 @@ public sealed class MxAccessStaSession : IDisposable
return;
}
commandDispatcher?.RequestShutdown();
if (session is not null)
{
staRuntime.InvokeAsync(() => session.Dispose()).GetAwaiter().GetResult();
@@ -0,0 +1,16 @@
namespace MxGateway.Worker.MxAccess;
public sealed class RegisteredServerHandle
{
public RegisteredServerHandle(
int serverHandle,
string clientName)
{
ServerHandle = serverHandle;
ClientName = clientName;
}
public int ServerHandle { get; }
public string ClientName { get; }
}