Compare commits

..

4 Commits

Author SHA1 Message Date
Joseph Doherty 14419853c7 Issue #25: implement sta command dispatcher 2026-04-26 17:49:01 -04:00
dohertj2 a20517f5ad Merge pull request #71 from agent-3/issue-13-implement-public-grpc-service
Issue #13: implement public grpc service
2026-04-26 17:48:35 -04:00
dohertj2 626e7762d9 Merge PR #70: Issue #31 implement MXSTATUS_PROXY and HRESULT conversion
Verified after merging current main with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 17:43:08 -04:00
Joseph Doherty 8d6d3f6188 Issue #13: implement public grpc service 2026-04-26 17:42:46 -04:00
12 changed files with 1563 additions and 2 deletions
+11 -2
View File
@@ -64,8 +64,8 @@ MxGateway.Server
Configuration
Grpc
MxAccessGatewayService
RequestReplyMapper
EventMapper
MxAccessGrpcRequestValidator
MxAccessGrpcMapper
Dashboard
Pages
Components
@@ -105,6 +105,15 @@ service MxAccessGateway {
}
```
`MxAccessGatewayService` implements these public RPCs in the gateway process.
It validates public requests with `MxAccessGrpcRequestValidator`, delegates
session lifecycle and command routing to `ISessionManager`, and maps worker
command replies and events through `MxAccessGrpcMapper`. Session lookup,
validation, and worker transport failures become gRPC status errors. MXAccess
method replies that reached the worker remain `MxCommandReply` payloads so
HRESULT values, status arrays, and method-specific reply fields survive
transport boundaries.
Add this later only after the command and event model is stable:
```protobuf
+8
View File
@@ -852,6 +852,14 @@ The gRPC layer should be thin:
Avoid embedding MXAccess-specific business logic in gRPC handlers. Keep the
translation code testable.
The gateway maps `MxAccessGateway` to `MxAccessGatewayService`. The service
implements `OpenSession`, `CloseSession`, `Invoke`, and `StreamEvents` by
validating public requests, delegating session work to `ISessionManager`, and
using explicit mapper code for public-to-worker commands, worker replies, and
events. Missing sessions and transport failures return gRPC status errors;
worker command replies preserve MXAccess HRESULT and status details in the
public reply.
## C# Worker Versus C++ Worker
Start with a C# .NET Framework 4.8 x86 worker.
@@ -1,6 +1,7 @@
using MxGateway.Contracts;
using MxGateway.Server.Configuration;
using MxGateway.Server.Diagnostics;
using MxGateway.Server.Grpc;
using MxGateway.Server.Metrics;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
@@ -31,6 +32,8 @@ public static class GatewayApplication
builder.Services.AddGatewayGrpcAuthorization();
builder.Services.AddHealthChecks();
builder.Services.AddSingleton<GatewayMetrics>();
builder.Services.AddSingleton<MxAccessGrpcMapper>();
builder.Services.AddSingleton<MxAccessGrpcRequestValidator>();
builder.Services.AddWorkerProcessLauncher();
builder.Services.AddGatewaySessions();
@@ -49,6 +52,8 @@ public static class GatewayApplication
WorkerProtocolVersion: GatewayContractInfo.WorkerProtocolVersion)))
.WithName("LiveHealth");
endpoints.MapGrpcService<MxAccessGatewayService>();
return endpoints;
}
}
@@ -0,0 +1,179 @@
using Grpc.Core;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Security.Authorization;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;
namespace MxGateway.Server.Grpc;
public sealed class MxAccessGatewayService(
ISessionManager sessionManager,
IGatewayRequestIdentityAccessor identityAccessor,
MxAccessGrpcRequestValidator requestValidator,
MxAccessGrpcMapper mapper,
ILogger<MxAccessGatewayService> logger) : MxAccessGateway.MxAccessGatewayBase
{
public override async Task<OpenSessionReply> OpenSession(
OpenSessionRequest request,
ServerCallContext context)
{
try
{
requestValidator.ValidateOpenSession(request);
GatewaySession session = await sessionManager
.OpenSessionAsync(
SessionOpenRequest.FromContract(request),
ResolveClientIdentity(),
context.CancellationToken)
.ConfigureAwait(false);
OpenSessionReply reply = new()
{
SessionId = session.SessionId,
BackendName = session.BackendName,
WorkerProcessId = session.WorkerProcessId ?? 0,
WorkerProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
DefaultCommandTimeout = Google.Protobuf.WellKnownTypes.Duration.FromTimeSpan(session.CommandTimeout),
ProtocolStatus = MxAccessGrpcMapper.Ok(),
};
reply.Capabilities.Add("unary-open-session");
reply.Capabilities.Add("unary-close-session");
reply.Capabilities.Add("unary-invoke");
reply.Capabilities.Add("server-stream-events");
return reply;
}
catch (Exception exception) when (exception is not RpcException)
{
throw MapException(exception);
}
}
public override async Task<CloseSessionReply> CloseSession(
CloseSessionRequest request,
ServerCallContext context)
{
try
{
requestValidator.ValidateCloseSession(request);
SessionCloseResult result = await sessionManager
.CloseSessionAsync(request.SessionId, context.CancellationToken)
.ConfigureAwait(false);
return new CloseSessionReply
{
SessionId = result.SessionId,
FinalState = result.FinalState,
ProtocolStatus = MxAccessGrpcMapper.Ok(result.AlreadyClosed ? "Session was already closed." : "Session closed."),
};
}
catch (Exception exception) when (exception is not RpcException)
{
throw MapException(exception);
}
}
public override async Task<MxCommandReply> Invoke(
MxCommandRequest request,
ServerCallContext context)
{
try
{
requestValidator.ValidateInvoke(request);
WorkerCommand workerCommand = mapper.MapCommand(request);
WorkerCommandReply workerReply = await sessionManager
.InvokeAsync(request.SessionId, workerCommand, context.CancellationToken)
.ConfigureAwait(false);
return mapper.MapCommandReply(workerReply);
}
catch (Exception exception) when (exception is not RpcException)
{
throw MapException(exception);
}
}
public override async Task StreamEvents(
StreamEventsRequest request,
IServerStreamWriter<MxEvent> responseStream,
ServerCallContext context)
{
try
{
requestValidator.ValidateStreamEvents(request);
await foreach (WorkerEvent workerEvent in sessionManager
.ReadEventsAsync(request.SessionId, context.CancellationToken)
.WithCancellation(context.CancellationToken)
.ConfigureAwait(false))
{
MxEvent publicEvent = mapper.MapEvent(workerEvent);
if (publicEvent.WorkerSequence <= request.AfterWorkerSequence)
{
continue;
}
await responseStream.WriteAsync(publicEvent).ConfigureAwait(false);
}
}
catch (Exception exception) when (exception is not RpcException)
{
throw MapException(exception);
}
}
private string? ResolveClientIdentity()
{
return identityAccessor.Current?.DisplayName ?? identityAccessor.Current?.KeyId;
}
private RpcException MapException(Exception exception)
{
if (exception is OperationCanceledException)
{
return new RpcException(new Status(StatusCode.Cancelled, "gRPC request was canceled."));
}
if (exception is SessionManagerException sessionException)
{
return MapSessionException(sessionException);
}
if (exception is WorkerClientException workerClientException)
{
return MapWorkerClientException(workerClientException);
}
logger.LogWarning(exception, "Public gRPC request failed.");
return new RpcException(new Status(StatusCode.Unavailable, "Gateway request failed before an MXAccess reply was available."));
}
private static RpcException MapSessionException(SessionManagerException exception)
{
StatusCode statusCode = exception.ErrorCode switch
{
SessionManagerErrorCode.SessionNotFound => StatusCode.NotFound,
SessionManagerErrorCode.SessionNotReady => StatusCode.FailedPrecondition,
SessionManagerErrorCode.SessionLimitExceeded => StatusCode.ResourceExhausted,
SessionManagerErrorCode.OpenFailed => StatusCode.Unavailable,
SessionManagerErrorCode.CloseFailed => StatusCode.Unavailable,
_ => StatusCode.Unavailable,
};
return new RpcException(new Status(statusCode, exception.Message));
}
private static RpcException MapWorkerClientException(WorkerClientException exception)
{
StatusCode statusCode = exception.ErrorCode switch
{
WorkerClientErrorCode.CommandTimeout => StatusCode.DeadlineExceeded,
WorkerClientErrorCode.GatewayShutdown => StatusCode.Cancelled,
WorkerClientErrorCode.InvalidState => StatusCode.FailedPrecondition,
WorkerClientErrorCode.ProtocolViolation => StatusCode.Internal,
_ => StatusCode.Unavailable,
};
return new RpcException(new Status(statusCode, exception.Message));
}
}
@@ -0,0 +1,124 @@
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Grpc;
public sealed class MxAccessGrpcMapper
{
private readonly TimeProvider _timeProvider;
public MxAccessGrpcMapper(TimeProvider? timeProvider = null)
{
_timeProvider = timeProvider ?? TimeProvider.System;
}
public WorkerCommand MapCommand(MxCommandRequest request)
{
ArgumentNullException.ThrowIfNull(request);
ArgumentNullException.ThrowIfNull(request.Command);
return new WorkerCommand
{
Command = request.Command.Clone(),
EnqueueTimestamp = Timestamp.FromDateTimeOffset(_timeProvider.GetUtcNow()),
};
}
public MxCommandReply MapCommandReply(WorkerCommandReply reply)
{
ArgumentNullException.ThrowIfNull(reply);
if (reply.Reply is null)
{
return new MxCommandReply
{
ProtocolStatus = ProtocolViolation("Worker command reply did not contain a public reply payload."),
};
}
return reply.Reply.Clone();
}
public MxEvent MapEvent(WorkerEvent workerEvent)
{
ArgumentNullException.ThrowIfNull(workerEvent);
return workerEvent.Event?.Clone() ?? new MxEvent
{
Family = MxEventFamily.Unspecified,
RawStatus = "Worker event did not contain a public event payload.",
};
}
public static ProtocolStatus Ok(string message = "OK")
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = message,
};
}
public static ProtocolStatus InvalidRequest(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.InvalidRequest,
Message = message,
};
}
public static ProtocolStatus SessionNotFound(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.SessionNotFound,
Message = message,
};
}
public static ProtocolStatus SessionNotReady(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.SessionNotReady,
Message = message,
};
}
public static ProtocolStatus WorkerUnavailable(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.WorkerUnavailable,
Message = message,
};
}
public static ProtocolStatus Timeout(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.Timeout,
Message = message,
};
}
public static ProtocolStatus Canceled(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.Canceled,
Message = message,
};
}
public static ProtocolStatus ProtocolViolation(string message)
{
return new ProtocolStatus
{
Code = ProtocolStatusCode.ProtocolViolation,
Message = message,
};
}
}
@@ -0,0 +1,101 @@
using Grpc.Core;
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Grpc;
public sealed class MxAccessGrpcRequestValidator
{
public void ValidateOpenSession(OpenSessionRequest request)
{
ArgumentNullException.ThrowIfNull(request);
if (request.CommandTimeout is not null && request.CommandTimeout.ToTimeSpan() <= TimeSpan.Zero)
{
throw InvalidArgument("Command timeout must be greater than zero when provided.");
}
}
public void ValidateCloseSession(CloseSessionRequest request)
{
ArgumentNullException.ThrowIfNull(request);
RequireSessionId(request.SessionId);
}
public void ValidateStreamEvents(StreamEventsRequest request)
{
ArgumentNullException.ThrowIfNull(request);
RequireSessionId(request.SessionId);
}
public void ValidateInvoke(MxCommandRequest request)
{
ArgumentNullException.ThrowIfNull(request);
RequireSessionId(request.SessionId);
if (request.Command is null)
{
throw InvalidArgument("Invoke requires a command payload.");
}
if (request.Command.Kind is MxCommandKind.Unspecified)
{
throw InvalidArgument("Invoke requires a command kind.");
}
ValidateCommandPayload(request.Command);
}
private static void RequireSessionId(string sessionId)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw InvalidArgument("Session id is required.");
}
}
private static void ValidateCommandPayload(MxCommand command)
{
MxCommand.PayloadOneofCase expectedPayload = ExpectedPayload(command.Kind);
if (command.PayloadCase != expectedPayload)
{
throw InvalidArgument(
$"Command kind {command.Kind} requires payload {expectedPayload} but received {command.PayloadCase}.");
}
}
private static MxCommand.PayloadOneofCase ExpectedPayload(MxCommandKind kind)
{
return kind switch
{
MxCommandKind.Register => MxCommand.PayloadOneofCase.Register,
MxCommandKind.Unregister => MxCommand.PayloadOneofCase.Unregister,
MxCommandKind.AddItem => MxCommand.PayloadOneofCase.AddItem,
MxCommandKind.AddItem2 => MxCommand.PayloadOneofCase.AddItem2,
MxCommandKind.RemoveItem => MxCommand.PayloadOneofCase.RemoveItem,
MxCommandKind.Advise => MxCommand.PayloadOneofCase.Advise,
MxCommandKind.UnAdvise => MxCommand.PayloadOneofCase.UnAdvise,
MxCommandKind.AdviseSupervisory => MxCommand.PayloadOneofCase.AdviseSupervisory,
MxCommandKind.AddBufferedItem => MxCommand.PayloadOneofCase.AddBufferedItem,
MxCommandKind.SetBufferedUpdateInterval => MxCommand.PayloadOneofCase.SetBufferedUpdateInterval,
MxCommandKind.Suspend => MxCommand.PayloadOneofCase.Suspend,
MxCommandKind.Activate => MxCommand.PayloadOneofCase.Activate,
MxCommandKind.Write => MxCommand.PayloadOneofCase.Write,
MxCommandKind.Write2 => MxCommand.PayloadOneofCase.Write2,
MxCommandKind.WriteSecured => MxCommand.PayloadOneofCase.WriteSecured,
MxCommandKind.WriteSecured2 => MxCommand.PayloadOneofCase.WriteSecured2,
MxCommandKind.AuthenticateUser => MxCommand.PayloadOneofCase.AuthenticateUser,
MxCommandKind.ArchestraUserToId => MxCommand.PayloadOneofCase.ArchestraUserToId,
MxCommandKind.Ping => MxCommand.PayloadOneofCase.Ping,
MxCommandKind.GetSessionState => MxCommand.PayloadOneofCase.GetSessionState,
MxCommandKind.GetWorkerInfo => MxCommand.PayloadOneofCase.GetWorkerInfo,
MxCommandKind.DrainEvents => MxCommand.PayloadOneofCase.DrainEvents,
MxCommandKind.ShutdownWorker => MxCommand.PayloadOneofCase.ShutdownWorker,
_ => MxCommand.PayloadOneofCase.None,
};
}
private static RpcException InvalidArgument(string detail)
{
return new RpcException(new Status(StatusCode.InvalidArgument, detail));
}
}
@@ -0,0 +1,458 @@
using System.Runtime.CompilerServices;
using Google.Protobuf.WellKnownTypes;
using Grpc.Core;
using Microsoft.Extensions.Logging.Abstractions;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Grpc;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Grpc;
public sealed class MxAccessGatewayServiceTests
{
[Fact]
public async Task OpenSession_WithValidRequest_ReturnsSessionDetails()
{
GatewayRequestIdentityAccessor identityAccessor = new();
FakeSessionManager sessionManager = new()
{
OpenSessionResult = CreateSession("session-1", processId: 4321),
};
MxAccessGatewayService service = CreateService(sessionManager, identityAccessor);
using IDisposable identityScope = identityAccessor.Push(CreateIdentity());
OpenSessionReply reply = await service.OpenSession(
new OpenSessionRequest
{
ClientSessionName = "operator-session",
CommandTimeout = Duration.FromTimeSpan(TimeSpan.FromSeconds(7)),
},
new TestServerCallContext());
Assert.Equal("session-1", reply.SessionId);
Assert.Equal(GatewayContractInfo.DefaultBackendName, reply.BackendName);
Assert.Equal(4321, reply.WorkerProcessId);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, reply.WorkerProtocolVersion);
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Contains("unary-invoke", reply.Capabilities);
Assert.Equal("Operator Key", sessionManager.LastClientIdentity);
Assert.Equal("operator-session", sessionManager.LastOpenRequest?.ClientSessionName);
}
[Fact]
public async Task Invoke_WhenSessionMissing_ThrowsNotFound()
{
FakeSessionManager sessionManager = new()
{
InvokeException = new SessionManagerException(
SessionManagerErrorCode.SessionNotFound,
"Session session-missing was not found."),
};
MxAccessGatewayService service = CreateService(sessionManager);
RpcException exception = await Assert.ThrowsAsync<RpcException>(
async () => await service.Invoke(
CreatePingRequest("session-missing"),
new TestServerCallContext()));
Assert.Equal(StatusCode.NotFound, exception.StatusCode);
Assert.Contains("session-missing", exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task Invoke_WithMismatchedPayload_ThrowsInvalidArgumentAndDoesNotCallSessionManager()
{
FakeSessionManager sessionManager = new();
MxAccessGatewayService service = CreateService(sessionManager);
MxCommandRequest request = new()
{
SessionId = "session-1",
Command = new MxCommand
{
Kind = MxCommandKind.AddItem,
Ping = new PingCommand { Message = "wrong-payload" },
},
};
RpcException exception = await Assert.ThrowsAsync<RpcException>(
async () => await service.Invoke(request, new TestServerCallContext()));
Assert.Equal(StatusCode.InvalidArgument, exception.StatusCode);
Assert.Equal(0, sessionManager.InvokeCount);
}
[Fact]
public async Task Invoke_WithWorkerReply_ReturnsHresultStatusAndMethodPayload()
{
const int hresult = unchecked((int)0x80004005);
FakeSessionManager sessionManager = new()
{
InvokeReply = new WorkerCommandReply
{
Reply = new MxCommandReply
{
SessionId = "session-1",
CorrelationId = "worker-correlation",
Kind = MxCommandKind.AddItem,
ProtocolStatus = MxAccessGrpcMapper.Ok(),
Hresult = hresult,
AddItem = new AddItemReply { ItemHandle = 42 },
DiagnosticMessage = "mxaccess diagnostic",
},
},
};
sessionManager.InvokeReply.Reply.Statuses.Add(new MxStatusProxy
{
Success = 0,
Category = MxStatusCategory.SoftwareError,
Detail = 1001,
DiagnosticText = "status detail",
});
MxAccessGatewayService service = CreateService(sessionManager);
MxCommandRequest request = new()
{
SessionId = "session-1",
ClientCorrelationId = "client-correlation",
Command = new MxCommand
{
Kind = MxCommandKind.AddItem,
AddItem = new AddItemCommand
{
ServerHandle = 12,
ItemDefinition = "Galaxy.Tag.Value",
},
},
};
MxCommandReply reply = await service.Invoke(request, new TestServerCallContext());
Assert.Equal(MxCommandKind.AddItem, sessionManager.LastWorkerCommand?.Command.Kind);
Assert.Equal("Galaxy.Tag.Value", sessionManager.LastWorkerCommand?.Command.AddItem.ItemDefinition);
Assert.NotNull(sessionManager.LastWorkerCommand?.EnqueueTimestamp);
Assert.Equal(hresult, reply.Hresult);
Assert.Equal(42, reply.AddItem.ItemHandle);
Assert.Equal("status detail", Assert.Single(reply.Statuses).DiagnosticText);
Assert.Equal("mxaccess diagnostic", reply.DiagnosticMessage);
}
[Fact]
public async Task StreamEvents_WithAfterSequence_WritesOnlyLaterEvents()
{
FakeSessionManager sessionManager = new();
sessionManager.Events.Add(CreateWorkerEvent("session-1", workerSequence: 1));
sessionManager.Events.Add(CreateWorkerEvent("session-1", workerSequence: 2));
MxAccessGatewayService service = CreateService(sessionManager);
TestServerStreamWriter<MxEvent> writer = new();
await service.StreamEvents(
new StreamEventsRequest
{
SessionId = "session-1",
AfterWorkerSequence = 1,
},
writer,
new TestServerCallContext());
MxEvent writtenEvent = Assert.Single(writer.Messages);
Assert.Equal((ulong)2, writtenEvent.WorkerSequence);
Assert.Equal("session-1", sessionManager.LastReadEventsSessionId);
}
[Fact]
public async Task CloseSession_WithBlankSessionId_ThrowsInvalidArgument()
{
MxAccessGatewayService service = CreateService(new FakeSessionManager());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
async () => await service.CloseSession(
new CloseSessionRequest(),
new TestServerCallContext()));
Assert.Equal(StatusCode.InvalidArgument, exception.StatusCode);
}
private static MxAccessGatewayService CreateService(
FakeSessionManager sessionManager,
IGatewayRequestIdentityAccessor? identityAccessor = null)
{
return new MxAccessGatewayService(
sessionManager,
identityAccessor ?? new GatewayRequestIdentityAccessor(),
new MxAccessGrpcRequestValidator(),
new MxAccessGrpcMapper(),
NullLogger<MxAccessGatewayService>.Instance);
}
private static ApiKeyIdentity CreateIdentity()
{
return new ApiKeyIdentity(
KeyId: "operator01",
KeyPrefix: "mxgw_operator01",
DisplayName: "Operator Key",
Scopes: new HashSet<string>(StringComparer.Ordinal));
}
private static GatewaySession CreateSession(
string sessionId,
int processId)
{
GatewaySession session = new(
sessionId,
GatewayContractInfo.DefaultBackendName,
"pipe",
"nonce",
"Operator Key",
"operator-session",
"client-correlation",
TimeSpan.FromSeconds(7),
TimeSpan.FromSeconds(30),
TimeSpan.FromSeconds(10),
DateTimeOffset.UtcNow);
session.AttachWorkerClient(new FakeWorkerClient(processId));
session.MarkReady();
return session;
}
private static MxCommandRequest CreatePingRequest(string sessionId)
{
return new MxCommandRequest
{
SessionId = sessionId,
Command = new MxCommand
{
Kind = MxCommandKind.Ping,
Ping = new PingCommand { Message = "ping" },
},
};
}
private static WorkerEvent CreateWorkerEvent(
string sessionId,
ulong workerSequence)
{
return new WorkerEvent
{
Event = new MxEvent
{
Family = MxEventFamily.OnDataChange,
SessionId = sessionId,
WorkerSequence = workerSequence,
OnDataChange = new OnDataChangeEvent(),
},
};
}
private sealed class FakeSessionManager : ISessionManager
{
public GatewaySession? OpenSessionResult { get; init; }
public SessionOpenRequest? LastOpenRequest { get; private set; }
public string? LastClientIdentity { get; private set; }
public string? LastReadEventsSessionId { get; private set; }
public WorkerCommand? LastWorkerCommand { get; private set; }
public WorkerCommandReply InvokeReply { get; init; } = new()
{
Reply = new MxCommandReply
{
SessionId = "session-1",
Kind = MxCommandKind.Ping,
ProtocolStatus = MxAccessGrpcMapper.Ok(),
},
};
public Exception? InvokeException { get; init; }
public int InvokeCount { get; private set; }
public List<WorkerEvent> Events { get; } = [];
public Task<GatewaySession> OpenSessionAsync(
SessionOpenRequest request,
string? clientIdentity,
CancellationToken cancellationToken)
{
LastOpenRequest = request;
LastClientIdentity = clientIdentity;
return Task.FromResult(OpenSessionResult ?? CreateSession("session-1", processId: 1234));
}
public bool TryGetSession(
string sessionId,
out GatewaySession session)
{
session = OpenSessionResult ?? CreateSession(sessionId, processId: 1234);
return true;
}
public Task<WorkerCommandReply> InvokeAsync(
string sessionId,
WorkerCommand command,
CancellationToken cancellationToken)
{
InvokeCount++;
LastWorkerCommand = command;
if (InvokeException is not null)
{
throw InvokeException;
}
return Task.FromResult(InvokeReply);
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
string sessionId,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
LastReadEventsSessionId = sessionId;
foreach (WorkerEvent workerEvent in Events)
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
yield return workerEvent;
}
}
public Task<SessionCloseResult> CloseSessionAsync(
string sessionId,
CancellationToken cancellationToken)
{
return Task.FromResult(new SessionCloseResult(sessionId, SessionState.Closed, AlreadyClosed: false));
}
public Task<int> CloseExpiredLeasesAsync(
DateTimeOffset now,
CancellationToken cancellationToken)
{
return Task.FromResult(0);
}
public Task ShutdownAsync(CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
}
private sealed class FakeWorkerClient(int processId) : IWorkerClient
{
public string SessionId { get; } = "session-1";
public int? ProcessId { get; } = processId;
public WorkerClientState State { get; } = WorkerClientState.Ready;
public DateTimeOffset LastHeartbeatAt { get; } = DateTimeOffset.UtcNow;
public Task StartAsync(CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
public Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
return Task.FromResult(new WorkerCommandReply());
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
await Task.CompletedTask;
yield break;
}
public Task ShutdownAsync(
TimeSpan timeout,
CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
public void Kill(string reason)
{
}
public ValueTask DisposeAsync()
{
return ValueTask.CompletedTask;
}
}
private sealed class TestServerStreamWriter<T> : IServerStreamWriter<T>
{
public List<T> Messages { get; } = [];
public WriteOptions? WriteOptions { get; set; }
public Task WriteAsync(T message)
{
Messages.Add(message);
return Task.CompletedTask;
}
}
private sealed class TestServerCallContext(CancellationToken cancellationToken = default) : ServerCallContext
{
private readonly Metadata requestHeaders = [];
private readonly Metadata responseTrailers = [];
private readonly Dictionary<object, object> userState = [];
private Status status;
private WriteOptions? writeOptions;
protected override string MethodCore => "/mxaccess_gateway.v1.MxAccessGateway/Test";
protected override string HostCore => "localhost";
protected override string PeerCore => "ipv4:127.0.0.1:5000";
protected override DateTime DeadlineCore => DateTime.UtcNow.AddMinutes(1);
protected override Metadata RequestHeadersCore => requestHeaders;
protected override CancellationToken CancellationTokenCore => cancellationToken;
protected override Metadata ResponseTrailersCore => responseTrailers;
protected override Status StatusCore
{
get => status;
set => status = value;
}
protected override WriteOptions? WriteOptionsCore
{
get => writeOptions;
set => writeOptions = value;
}
protected override AuthContext AuthContextCore { get; } = new(
string.Empty,
new Dictionary<string, List<AuthProperty>>(StringComparer.Ordinal));
protected override IDictionary<object, object> UserStateCore => userState;
protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
{
return Task.CompletedTask;
}
protected override ContextPropagationToken CreatePropagationTokenCore(
ContextPropagationOptions? options)
{
throw new NotSupportedException();
}
}
}
@@ -0,0 +1,76 @@
using MxGateway.Contracts.Proto;
using MxGateway.Server.Grpc;
namespace MxGateway.Tests.Gateway.Grpc;
public sealed class MxAccessGrpcMapperTests
{
[Fact]
public void MapCommand_ClonesMethodSpecificPayloadForWorkerBoundary()
{
MxAccessGrpcMapper mapper = new();
MxCommandRequest request = new()
{
SessionId = "session-1",
Command = new MxCommand
{
Kind = MxCommandKind.Write,
Write = new WriteCommand
{
ServerHandle = 10,
ItemHandle = 20,
UserId = 30,
Value = new MxValue
{
DataType = MxDataType.String,
StringValue = "value",
},
},
},
};
WorkerCommand workerCommand = mapper.MapCommand(request);
request.Command.Write.Value.StringValue = "changed";
Assert.Equal(MxCommandKind.Write, workerCommand.Command.Kind);
Assert.Equal("value", workerCommand.Command.Write.Value.StringValue);
Assert.NotNull(workerCommand.EnqueueTimestamp);
}
[Fact]
public void MapCommandReply_PreservesHresultStatusesAndPayload()
{
const int hresult = unchecked((int)0x80070005);
WorkerCommandReply workerReply = new()
{
Reply = new MxCommandReply
{
SessionId = "session-1",
Kind = MxCommandKind.Register,
ProtocolStatus = MxAccessGrpcMapper.Ok(),
Hresult = hresult,
Register = new RegisterReply { ServerHandle = 50 },
},
};
workerReply.Reply.Statuses.Add(new MxStatusProxy
{
Success = 0,
Category = MxStatusCategory.SecurityError,
DiagnosticText = "denied",
});
MxCommandReply publicReply = new MxAccessGrpcMapper().MapCommandReply(workerReply);
Assert.Equal(hresult, publicReply.Hresult);
Assert.Equal(50, publicReply.Register.ServerHandle);
Assert.Equal("denied", Assert.Single(publicReply.Statuses).DiagnosticText);
}
[Fact]
public void MapCommandReply_WhenWorkerReplyMissing_ReturnsProtocolViolationReply()
{
MxCommandReply publicReply = new MxAccessGrpcMapper().MapCommandReply(new WorkerCommandReply());
Assert.Equal(ProtocolStatusCode.ProtocolViolation, publicReply.ProtocolStatus.Code);
}
}
@@ -0,0 +1,279 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaCommandDispatcherTests
{
[Fact]
public async Task DispatchAsync_ExecutesCommandsOnStaInQueueOrder()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
RecordingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> first = dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Task<MxCommandReply> second = dispatcher.DispatchAsync(CreateCommand("correlation-2", MxCommandKind.AddItem));
MxCommandReply[] replies = await Task.WhenAll(first, second);
Assert.Equal(new[] { "correlation-1", "correlation-2" }, executor.CorrelationIds);
Assert.All(executor.ThreadIds, threadId => Assert.Equal(runtime.StaThreadId, threadId));
Assert.Equal("correlation-1", replies[0].CorrelationId);
Assert.Equal("correlation-2", replies[1].CorrelationId);
Assert.Equal(ProtocolStatusCode.Ok, replies[0].ProtocolStatus.Code);
}
[Fact]
public async Task DispatchAsync_WhenExecutorThrows_ReturnsFailureReplyWithHResult()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
StaCommandDispatcher dispatcher = new(
runtime,
new ThrowingCommandExecutor(new COMException("provider detail", unchecked((int)0x80070057))));
MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Assert.Equal("session-1", reply.SessionId);
Assert.Equal("correlation-1", reply.CorrelationId);
Assert.Equal(MxCommandKind.Register, reply.Kind);
Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
Assert.Equal(unchecked((int)0x80070057), reply.Hresult);
Assert.Contains("0x80070057", reply.DiagnosticMessage);
Assert.DoesNotContain("provider detail", reply.DiagnosticMessage);
}
[Fact]
public async Task DispatchAsync_WhenCanceledBeforeExecution_ReturnsCanceledReplyWithoutExecuting()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> blocked = dispatcher.DispatchAsync(CreateCommand("blocked", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
using CancellationTokenSource cancellation = new();
Task<MxCommandReply> canceled = dispatcher.DispatchAsync(
CreateCommand("canceled", MxCommandKind.AddItem, cancellation.Token));
cancellation.Cancel();
executor.Release();
MxCommandReply canceledReply = await canceled;
await blocked;
Assert.Equal(ProtocolStatusCode.Canceled, canceledReply.ProtocolStatus.Code);
Assert.DoesNotContain("canceled", executor.CorrelationIds);
}
[Fact]
public async Task DispatchAsync_WhenCanceledAfterExecutionStarts_StillReturnsLateReply()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
using CancellationTokenSource cancellation = new();
Task<MxCommandReply> replyTask = dispatcher.DispatchAsync(
CreateCommand("late-reply", MxCommandKind.Register, cancellation.Token));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
cancellation.Cancel();
executor.Release();
MxCommandReply reply = await replyTask;
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Contains("late-reply", executor.CorrelationIds);
}
[Fact]
public async Task DispatchAsync_WhenShutdownRequested_RejectsNewCommands()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
StaCommandDispatcher dispatcher = new(runtime, new RecordingCommandExecutor());
dispatcher.RequestShutdown();
MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Assert.Equal(ProtocolStatusCode.WorkerUnavailable, reply.ProtocolStatus.Code);
Assert.Equal("correlation-1", reply.CorrelationId);
}
[Fact]
public async Task PopulateHeartbeat_ReportsCurrentCorrelationAndPendingCount()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> current = dispatcher.DispatchAsync(CreateCommand("current", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
Task<MxCommandReply> pending = dispatcher.DispatchAsync(CreateCommand("pending", MxCommandKind.AddItem));
WorkerHeartbeat heartbeat = new();
dispatcher.PopulateHeartbeat(heartbeat);
Assert.Equal("current", heartbeat.CurrentCommandCorrelationId);
Assert.Equal(1u, heartbeat.PendingCommandCount);
executor.Release();
await Task.WhenAll(current, pending);
}
private static StaCommand CreateCommand(
string correlationId,
MxCommandKind kind,
CancellationToken cancellationToken = default)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = kind,
Ping = new PingCommand
{
Message = correlationId,
},
},
cancellationToken: cancellationToken);
}
private static StaRuntime CreateRuntime()
{
return new StaRuntime(
new NoopComApartmentInitializer(),
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class RecordingCommandExecutor : IStaCommandExecutor
{
private readonly object gate = new();
private readonly List<string> correlationIds = new();
private readonly List<int> threadIds = new();
public IReadOnlyList<string> CorrelationIds
{
get
{
lock (gate)
{
return correlationIds.ToArray();
}
}
}
public IReadOnlyList<int> ThreadIds
{
get
{
lock (gate)
{
return threadIds.ToArray();
}
}
}
public MxCommandReply Execute(StaCommand command)
{
lock (gate)
{
correlationIds.Add(command.CorrelationId);
threadIds.Add(Thread.CurrentThread.ManagedThreadId);
}
return new MxCommandReply
{
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
}
private sealed class BlockingCommandExecutor : IStaCommandExecutor
{
private readonly ManualResetEventSlim release = new(false);
private readonly object gate = new();
private readonly List<string> correlationIds = new();
public ManualResetEventSlim Started { get; } = new(false);
public IReadOnlyList<string> CorrelationIds
{
get
{
lock (gate)
{
return correlationIds.ToArray();
}
}
}
public MxCommandReply Execute(StaCommand command)
{
lock (gate)
{
correlationIds.Add(command.CorrelationId);
}
Started.Set();
release.Wait(TimeSpan.FromSeconds(5));
return new MxCommandReply
{
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
public void Release()
{
release.Set();
}
}
private sealed class ThrowingCommandExecutor : IStaCommandExecutor
{
private readonly Exception exception;
public ThrowingCommandExecutor(Exception exception)
{
this.exception = exception;
}
public MxCommandReply Execute(StaCommand command)
{
throw exception;
}
}
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
{
public void Initialize()
{
}
public void Uninitialize()
{
}
}
}
@@ -0,0 +1,8 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Sta;
public interface IStaCommandExecutor
{
MxCommandReply Execute(StaCommand command);
}
+47
View File
@@ -0,0 +1,47 @@
using System;
using System.Threading;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Sta;
public sealed class StaCommand
{
public StaCommand(
string sessionId,
string correlationId,
MxCommand command,
Timestamp? enqueueTimestamp = null,
CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("STA command requires a session id.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(correlationId))
{
throw new ArgumentException("STA command requires a correlation id.", nameof(correlationId));
}
SessionId = sessionId;
CorrelationId = correlationId;
Command = command ?? throw new ArgumentNullException(nameof(command));
EnqueueTimestamp = enqueueTimestamp ?? Timestamp.FromDateTime(DateTime.UtcNow);
CancellationToken = cancellationToken;
}
public string SessionId { get; }
public string CorrelationId { get; }
public MxCommand Command { get; }
public Timestamp EnqueueTimestamp { get; }
public CancellationToken CancellationToken { get; }
public MxCommandKind Kind => Command.Kind;
public string MethodName => Kind.ToString();
}
@@ -0,0 +1,267 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
namespace MxGateway.Worker.Sta;
public sealed class StaCommandDispatcher
{
private readonly HResultConverter hresultConverter;
private readonly IStaCommandExecutor commandExecutor;
private readonly Queue<QueuedStaCommand> commandQueue = new();
private readonly StaRuntime staRuntime;
private readonly object gate = new();
private bool drainActive;
private bool shutdownRequested;
private string currentCommandCorrelationId = string.Empty;
public StaCommandDispatcher(
StaRuntime staRuntime,
IStaCommandExecutor commandExecutor)
: this(staRuntime, commandExecutor, new HResultConverter())
{
}
public StaCommandDispatcher(
StaRuntime staRuntime,
IStaCommandExecutor commandExecutor,
HResultConverter hresultConverter)
{
this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime));
this.commandExecutor = commandExecutor ?? throw new ArgumentNullException(nameof(commandExecutor));
this.hresultConverter = hresultConverter ?? throw new ArgumentNullException(nameof(hresultConverter));
}
public int PendingCommandCount
{
get
{
lock (gate)
{
return commandQueue.Count;
}
}
}
public string CurrentCommandCorrelationId
{
get
{
lock (gate)
{
return currentCommandCorrelationId;
}
}
}
public Task<MxCommandReply> DispatchAsync(StaCommand command)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
lock (gate)
{
if (shutdownRequested)
{
return Task.FromResult(CreateRejectedReply(
command,
ProtocolStatusCode.WorkerUnavailable,
"The STA command dispatcher is shutting down."));
}
QueuedStaCommand queuedCommand = new(command);
commandQueue.Enqueue(queuedCommand);
if (!drainActive)
{
drainActive = true;
_ = DrainAsync();
}
return queuedCommand.Task;
}
}
public void RequestShutdown()
{
lock (gate)
{
shutdownRequested = true;
}
}
public void PopulateHeartbeat(WorkerHeartbeat heartbeat)
{
if (heartbeat is null)
{
throw new ArgumentNullException(nameof(heartbeat));
}
lock (gate)
{
heartbeat.PendingCommandCount = (uint)commandQueue.Count;
heartbeat.CurrentCommandCorrelationId = currentCommandCorrelationId;
}
}
private async Task DrainAsync()
{
while (true)
{
QueuedStaCommand queuedCommand;
lock (gate)
{
if (commandQueue.Count == 0)
{
drainActive = false;
return;
}
queuedCommand = commandQueue.Dequeue();
}
await ExecuteQueuedCommandAsync(queuedCommand).ConfigureAwait(false);
}
}
private async Task ExecuteQueuedCommandAsync(QueuedStaCommand queuedCommand)
{
StaCommand command = queuedCommand.Command;
if (command.CancellationToken.IsCancellationRequested)
{
queuedCommand.Complete(CreateRejectedReply(
command,
ProtocolStatusCode.Canceled,
"The STA command was canceled before execution."));
return;
}
SetCurrentCommand(command.CorrelationId);
try
{
MxCommandReply reply = await staRuntime
.InvokeAsync(() => commandExecutor.Execute(command))
.ConfigureAwait(false);
queuedCommand.Complete(NormalizeReply(command, reply));
}
catch (Exception exception)
{
queuedCommand.Complete(CreateExceptionReply(command, exception));
}
finally
{
SetCurrentCommand(string.Empty);
}
}
private void SetCurrentCommand(string correlationId)
{
lock (gate)
{
currentCommandCorrelationId = correlationId;
}
}
private MxCommandReply NormalizeReply(
StaCommand command,
MxCommandReply reply)
{
if (reply is null)
{
return CreateRejectedReply(
command,
ProtocolStatusCode.ProtocolViolation,
"STA command executor returned null.");
}
if (string.IsNullOrWhiteSpace(reply.SessionId))
{
reply.SessionId = command.SessionId;
}
if (string.IsNullOrWhiteSpace(reply.CorrelationId))
{
reply.CorrelationId = command.CorrelationId;
}
if (reply.Kind == MxCommandKind.Unspecified)
{
reply.Kind = command.Kind;
}
if (reply.ProtocolStatus is null)
{
reply.ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
};
}
return reply;
}
private MxCommandReply CreateExceptionReply(
StaCommand command,
Exception exception)
{
HResultConversion conversion = hresultConverter.Convert(exception);
MxCommandReply reply = CreateBaseReply(command);
reply.ProtocolStatus = conversion.ProtocolStatus;
reply.Hresult = conversion.HResult;
reply.DiagnosticMessage = conversion.DiagnosticMessage;
return reply;
}
private static MxCommandReply CreateRejectedReply(
StaCommand command,
ProtocolStatusCode statusCode,
string message)
{
MxCommandReply reply = CreateBaseReply(command);
reply.ProtocolStatus = new ProtocolStatus
{
Code = statusCode,
Message = message,
};
reply.DiagnosticMessage = message;
return reply;
}
private static MxCommandReply CreateBaseReply(StaCommand command)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
};
}
private sealed class QueuedStaCommand
{
private readonly TaskCompletionSource<MxCommandReply> completion = new(
TaskCreationOptions.RunContinuationsAsynchronously);
public QueuedStaCommand(StaCommand command)
{
Command = command;
}
public StaCommand Command { get; }
public Task<MxCommandReply> Task => completion.Task;
public void Complete(MxCommandReply reply)
{
completion.TrySetResult(reply);
}
}
}