Compare commits

...

10 Commits

Author SHA1 Message Date
Joseph Doherty 6559672fc1 Issue #30: implement value conversion 2026-04-26 17:26:36 -04:00
dohertj2 97c30b9d00 Merge PR #66: Issue #23 implement STA runtime and message pump
Verified with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 17:23:02 -04:00
dohertj2 603aff7004 Merge PR #65: Issue #22 implement pipe client and frame protocol
Verified with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 17:20:28 -04:00
Joseph Doherty e81682e367 Issue #23: implement sta runtime and message pump 2026-04-26 17:19:00 -04:00
Joseph Doherty d5a982152b Issue #22: implement pipe client and frame protocol 2026-04-26 17:16:49 -04:00
dohertj2 0b0be7098e Merge PR #64: Issue #11 implement gateway WorkerClient
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:14:03 -04:00
Joseph Doherty fce9e99553 Issue #11: implement gateway workerclient 2026-04-26 17:09:51 -04:00
dohertj2 c8fb3e91a3 Merge PR #63: Issue #8 add gRPC authentication and scope authorization
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:06:23 -04:00
dohertj2 8ce327e6f4 Merge PR #62: Issue #7 implement local api key admin cli
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:02:09 -04:00
Joseph Doherty fad0ac9948 Issue #8: add grpc authentication and scope authorization 2026-04-26 17:01:59 -04:00
45 changed files with 4301 additions and 8 deletions
+34 -2
View File
@@ -411,7 +411,7 @@ session ids as protocol faults and close the session.
`WorkerClient` is the gateway-side object that owns one worker connection.
Suggested public shape:
Current public shape:
```csharp
public interface IWorkerClient : IAsyncDisposable
@@ -419,6 +419,7 @@ public interface IWorkerClient : IAsyncDisposable
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
@@ -438,12 +439,17 @@ Internally it owns:
- pipe stream,
- read loop,
- write loop,
- bounded outbound command/control channel,
- outbound command/control channel serialized by the write loop,
- bounded inbound event channel,
- pending command dictionary keyed by correlation id,
- heartbeat monitor,
- terminal fault source.
`StartAsync` sends `GatewayHello`, verifies the `WorkerHello` protocol version
and nonce, waits for `WorkerReady`, and only then exposes `Ready` state. The
read loop starts after readiness so the handshake has a single owner for its
ordered frames.
### Read Loop
The read loop:
@@ -612,6 +618,15 @@ hashes the presented secret, and compares the stored and presented hashes with
results distinguish malformed credentials, missing keys, revoked keys, missing
pepper configuration, and hash mismatch for internal authorization handling.
`GatewayGrpcAuthorizationInterceptor` enforces this authentication model for
public gRPC calls. Missing, malformed, revoked, unknown, or mismatched keys fail
with `Unauthenticated`. Authenticated calls missing the scope required by the
RPC fail with `PermissionDenied`. The interceptor applies to unary calls and
server-streaming calls and stores the authenticated `ApiKeyIdentity` in
`IGatewayRequestIdentityAccessor` for the duration of the request handler.
`Authentication:Mode` set to `Disabled` bypasses API-key verification for local
development only.
Recommended scopes:
- `session:open`
@@ -677,6 +692,20 @@ Commands requiring authorization:
- worker shutdown diagnostics,
- metadata queries if they expose sensitive plant structure.
Current gRPC scope mapping:
- `OpenSession` requires `session:open`.
- `CloseSession` requires `session:close`.
- `StreamEvents` and `DrainEvents` require `events:read`.
- read-style MXAccess commands such as `Register`, `AddItem`, `Advise`, and
`Ping` require `invoke:read`.
- `Write` and `Write2` require `invoke:write`.
- `WriteSecured`, `WriteSecured2`, and `AuthenticateUser` require
`invoke:secure`.
- metadata commands such as `ArchestrAUserToId`, `GetSessionState`, and
`GetWorkerInfo` require `metadata:read`.
- `ShutdownWorker` requires `admin`.
### Worker IPC
Named pipes should be local only. Pipe ACLs should restrict access to:
@@ -819,6 +848,9 @@ workers and fake transports.
Focused tests:
- session state transitions,
- gRPC API-key authentication for unary and streaming calls,
- gRPC scope mapping for sessions, invokes, events, metadata, and admin
commands,
- worker startup failures,
- protocol version mismatch,
- malformed frame handling,
+11
View File
@@ -250,6 +250,17 @@ The loop should update a heartbeat timestamp after:
- finishing a command,
- processing an MXAccess event.
`StaRuntime` implements this runtime boundary in the worker. It starts one
background thread named `MxGateway.Worker.STA`, sets it to `ApartmentState.STA`,
initializes COM through `StaComApartmentInitializer`, and runs
`StaMessagePump`. Commands are scheduled through `InvokeAsync`; the command
queue signals an `AutoResetEvent` so `MsgWaitForMultipleObjectsEx` can wake the
STA without busy-waiting. `LastActivityUtc` records pump, command, startup, and
shutdown activity so the future heartbeat/watchdog can report whether the STA
is still responsive. Shutdown marks the runtime as closing, wakes the pump,
rejects new commands, cancels queued work, uninitializes COM on the STA, and
waits for the thread to exit.
## COM Creation
The MXAccess analysis source at `C:\Users\dohertj2\Desktop\mxaccess` identifies
+7 -3
View File
@@ -566,9 +566,13 @@ Because each client owns one worker, a crash or leak affects only that session.
External gateway:
- use TLS for remote gRPC if crossing machine boundaries,
- authenticate clients with Windows auth, mTLS, or a deployment-specific token,
- authorize access to commands that can write, authenticate users, or alter
runtime state.
- authenticate v1 gRPC clients with `authorization: Bearer
mxgw_<key-id>_<secret>` API-key metadata,
- reject missing or invalid API keys with gRPC `Unauthenticated`,
- reject valid keys that lack the required session, invoke, event, metadata, or
admin scope with gRPC `PermissionDenied`,
- authorize access to commands that can write, authenticate users, expose
metadata, stream events, or alter runtime state.
Internal worker IPC:
@@ -3,6 +3,7 @@ using MxGateway.Server.Configuration;
using MxGateway.Server.Diagnostics;
using MxGateway.Server.Metrics;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
using MxGateway.Server.Workers;
namespace MxGateway.Server;
@@ -26,6 +27,7 @@ public static class GatewayApplication
builder.Services.AddGatewayConfiguration();
builder.Services.AddSqliteAuthStore();
builder.Services.AddGatewayGrpcAuthorization();
builder.Services.AddHealthChecks();
builder.Services.AddSingleton<GatewayMetrics>();
builder.Services.AddWorkerProcessLauncher();
@@ -5,6 +5,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Grpc.AspNetCore" Version="2.76.0" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="10.0.7" />
</ItemGroup>
@@ -0,0 +1,74 @@
using Grpc.Core;
using Grpc.Core.Interceptors;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayGrpcAuthorizationInterceptor(
IApiKeyVerifier apiKeyVerifier,
GatewayGrpcScopeResolver scopeResolver,
IGatewayRequestIdentityAccessor identityAccessor,
IOptions<GatewayOptions> options) : Interceptor
{
public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
TRequest request,
ServerCallContext context,
UnaryServerMethod<TRequest, TResponse> continuation)
{
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
using (identityScope)
{
return await continuation(request, context).ConfigureAwait(false);
}
}
public override async Task ServerStreamingServerHandler<TRequest, TResponse>(
TRequest request,
IServerStreamWriter<TResponse> responseStream,
ServerCallContext context,
ServerStreamingServerMethod<TRequest, TResponse> continuation)
{
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
using (identityScope)
{
await continuation(request, responseStream, context).ConfigureAwait(false);
}
}
private async Task<ApiKeyIdentity?> AuthenticateAndAuthorizeAsync<TRequest>(
TRequest request,
ServerCallContext context)
where TRequest : class
{
if (options.Value.Authentication.Mode == AuthenticationMode.Disabled)
{
return null;
}
string? authorizationHeader = context.RequestHeaders.GetValue("authorization");
ApiKeyVerificationResult verificationResult = await apiKeyVerifier
.VerifyAsync(authorizationHeader, context.CancellationToken)
.ConfigureAwait(false);
if (!verificationResult.Succeeded || verificationResult.Identity is null)
{
throw new RpcException(new Status(
StatusCode.Unauthenticated,
"Missing or invalid API key."));
}
string requiredScope = scopeResolver.ResolveRequiredScope(request);
if (!verificationResult.Identity.Scopes.Contains(requiredScope))
{
throw new RpcException(new Status(
StatusCode.PermissionDenied,
$"API key is missing required scope '{requiredScope}'."));
}
return verificationResult.Identity;
}
}
@@ -0,0 +1,40 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayGrpcScopeResolver
{
public string ResolveRequiredScope(object request)
{
return request switch
{
OpenSessionRequest => GatewayScopes.SessionOpen,
CloseSessionRequest => GatewayScopes.SessionClose,
StreamEventsRequest => GatewayScopes.EventsRead,
MxCommandRequest commandRequest => ResolveCommandScope(commandRequest.Command?.Kind ?? MxCommandKind.Unspecified),
_ => GatewayScopes.Admin
};
}
private static string ResolveCommandScope(MxCommandKind kind)
{
return kind switch
{
MxCommandKind.Write or
MxCommandKind.Write2 => GatewayScopes.InvokeWrite,
MxCommandKind.WriteSecured or
MxCommandKind.WriteSecured2 or
MxCommandKind.AuthenticateUser => GatewayScopes.InvokeSecure,
MxCommandKind.ArchestraUserToId or
MxCommandKind.GetSessionState or
MxCommandKind.GetWorkerInfo => GatewayScopes.MetadataRead,
MxCommandKind.DrainEvents => GatewayScopes.EventsRead,
MxCommandKind.ShutdownWorker => GatewayScopes.Admin,
_ => GatewayScopes.InvokeRead
};
}
}
@@ -0,0 +1,38 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayRequestIdentityAccessor : IGatewayRequestIdentityAccessor
{
private readonly AsyncLocal<ApiKeyIdentity?> currentIdentity = new();
public ApiKeyIdentity? Current => currentIdentity.Value;
public IDisposable Push(ApiKeyIdentity identity)
{
ArgumentNullException.ThrowIfNull(identity);
ApiKeyIdentity? previousIdentity = currentIdentity.Value;
currentIdentity.Value = identity;
return new IdentityScope(this, previousIdentity);
}
private sealed class IdentityScope(
GatewayRequestIdentityAccessor accessor,
ApiKeyIdentity? previousIdentity) : IDisposable
{
private bool disposed;
public void Dispose()
{
if (disposed)
{
return;
}
accessor.currentIdentity.Value = previousIdentity;
disposed = true;
}
}
}
@@ -0,0 +1,13 @@
namespace MxGateway.Server.Security.Authorization;
public static class GatewayScopes
{
public const string SessionOpen = "session:open";
public const string SessionClose = "session:close";
public const string InvokeRead = "invoke:read";
public const string InvokeWrite = "invoke:write";
public const string InvokeSecure = "invoke:secure";
public const string EventsRead = "events:read";
public const string MetadataRead = "metadata:read";
public const string Admin = "admin";
}
@@ -0,0 +1,16 @@
using Grpc.Core.Interceptors;
namespace MxGateway.Server.Security.Authorization;
public static class GrpcAuthorizationServiceCollectionExtensions
{
public static IServiceCollection AddGatewayGrpcAuthorization(this IServiceCollection services)
{
services.AddSingleton<GatewayGrpcScopeResolver>();
services.AddSingleton<IGatewayRequestIdentityAccessor, GatewayRequestIdentityAccessor>();
services.AddSingleton<GatewayGrpcAuthorizationInterceptor>();
services.AddGrpc(options => options.Interceptors.Add<GatewayGrpcAuthorizationInterceptor>());
return services;
}
}
@@ -0,0 +1,10 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public interface IGatewayRequestIdentityAccessor
{
ApiKeyIdentity? Current { get; }
IDisposable Push(ApiKeyIdentity identity);
}
@@ -0,0 +1,27 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Workers;
public interface IWorkerClient : IAsyncDisposable
{
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken);
IAsyncEnumerable<WorkerEvent> ReadEventsAsync(CancellationToken cancellationToken);
Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken);
void Kill(string reason);
}
@@ -0,0 +1,755 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Metrics;
namespace MxGateway.Server.Workers;
public sealed class WorkerClient : IWorkerClient
{
private const string GatewayVersionFallback = "unknown";
private readonly object _syncRoot = new();
private readonly WorkerClientConnection _connection;
private readonly WorkerClientOptions _options;
private readonly GatewayMetrics? _metrics;
private readonly TimeProvider _timeProvider;
private readonly ILogger<WorkerClient> _logger;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private readonly Channel<WorkerEnvelope> _outboundEnvelopes;
private readonly Channel<WorkerEvent> _events;
private readonly ConcurrentDictionary<string, PendingCommand> _pendingCommands = new(StringComparer.Ordinal);
private readonly CancellationTokenSource _stopCts = new();
private long _nextSequence;
private WorkerClientState _state;
private DateTimeOffset _lastHeartbeatAt;
private int? _processId;
private Task? _readLoopTask;
private Task? _writeLoopTask;
private Task? _heartbeatLoopTask;
private bool _disposed;
public WorkerClient(
WorkerClientConnection connection,
WorkerClientOptions? options = null,
GatewayMetrics? metrics = null,
TimeProvider? timeProvider = null,
ILogger<WorkerClient>? logger = null)
{
_connection = connection ?? throw new ArgumentNullException(nameof(connection));
_options = options ?? new WorkerClientOptions();
_metrics = metrics;
_timeProvider = timeProvider ?? TimeProvider.System;
_logger = logger ?? NullLogger<WorkerClient>.Instance;
_reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions);
_writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions);
_outboundEnvelopes = Channel.CreateUnbounded<WorkerEnvelope>(
new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = false,
AllowSynchronousContinuations = false,
});
_events = Channel.CreateBounded<WorkerEvent>(
new BoundedChannelOptions(_options.EventChannelCapacity)
{
SingleReader = false,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait,
AllowSynchronousContinuations = false,
});
_lastHeartbeatAt = _timeProvider.GetUtcNow();
}
public string SessionId => _connection.SessionId;
public int? ProcessId
{
get
{
lock (_syncRoot)
{
return _processId;
}
}
}
public WorkerClientState State
{
get
{
lock (_syncRoot)
{
return _state;
}
}
}
public DateTimeOffset LastHeartbeatAt
{
get
{
lock (_syncRoot)
{
return _lastHeartbeatAt;
}
}
}
public async Task StartAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
TransitionFromCreatedToHandshaking();
_writeLoopTask = Task.Run(WriteLoopAsync);
await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false);
WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerHello,
cancellationToken).ConfigureAwait(false);
ValidateWorkerHello(helloEnvelope.WorkerHello);
WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerReady,
cancellationToken).ConfigureAwait(false);
MarkReady(readyEnvelope.WorkerReady);
_readLoopTask = Task.Run(ReadLoopAsync);
_heartbeatLoopTask = Task.Run(HeartbeatLoopAsync);
}
public async Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(command);
ThrowIfDisposed();
EnsureReady();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero.");
}
string correlationId = Guid.NewGuid().ToString("N");
string method = GetCommandMethod(command);
PendingCommand pendingCommand = new(
correlationId,
method,
_timeProvider.GetTimestamp());
if (!_pendingCommands.TryAdd(correlationId, pendingCommand))
{
throw new InvalidOperationException("Generated a duplicate command correlation id.");
}
_metrics?.CommandStarted(method);
try
{
await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false);
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
Task timeoutTask = Task.Delay(timeout, timeoutCts.Token);
Task<WorkerCommandReply> replyTask = pendingCommand.Task;
Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false);
if (completedTask == replyTask)
{
await timeoutCts.CancelAsync().ConfigureAwait(false);
return await replyTask.ConfigureAwait(false);
}
if (cancellationToken.IsCancellationRequested)
{
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.GatewayShutdown,
"Command wait was canceled.");
cancellationToken.ThrowIfCancellationRequested();
}
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
throw new WorkerClientException(
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
}
catch
{
_pendingCommands.TryRemove(correlationId, out _);
throw;
}
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
yield return workerEvent;
}
}
public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero.");
}
WorkerClientState state = State;
if (state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
MarkClosing();
await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false);
_outboundEnvelopes.Writer.TryComplete();
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
try
{
await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false);
MarkClosed("shutdown");
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
SetFaulted(
WorkerClientErrorCode.ShutdownTimeout,
"Worker shutdown timed out.",
null);
throw new WorkerClientException(
WorkerClientErrorCode.ShutdownTimeout,
$"Worker shutdown timed out after {timeout}.");
}
}
public void Kill(string reason)
{
ThrowIfDisposed();
_connection.ProcessHandle?.Process.Kill(entireProcessTree: true);
_metrics?.WorkerKilled(reason);
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
$"Worker was killed by the gateway: {reason}.",
null);
}
public async ValueTask DisposeAsync()
{
if (_disposed)
{
return;
}
_disposed = true;
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
"Worker client was disposed."));
await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
await _connection.Stream.DisposeAsync().ConfigureAwait(false);
_connection.ProcessHandle?.Dispose();
_stopCts.Dispose();
}
private async Task WriteLoopAsync()
{
try
{
await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false))
{
await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.WriteFailed,
"Worker pipe write failed.",
exception);
}
}
private async Task ReadLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false);
await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream)
{
SetFaulted(
WorkerClientErrorCode.PipeDisconnected,
"Worker pipe disconnected.",
exception);
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker read loop failed.",
exception);
}
}
private async Task HeartbeatLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false);
if (State != WorkerClientState.Ready)
{
continue;
}
DateTimeOffset lastHeartbeatAt = LastHeartbeatAt;
DateTimeOffset now = _timeProvider.GetUtcNow();
if (now - lastHeartbeatAt <= _options.HeartbeatGrace)
{
continue;
}
_metrics?.HeartbeatFailed(SessionId);
SetFaulted(
WorkerClientErrorCode.HeartbeatExpired,
$"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.",
null);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
}
private async Task DispatchEnvelopeAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
switch (envelope.BodyCase)
{
case WorkerEnvelope.BodyOneofCase.WorkerCommandReply:
CompleteCommand(envelope);
break;
case WorkerEnvelope.BodyOneofCase.WorkerEvent:
await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false);
break;
case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat:
MarkHeartbeat(envelope.WorkerHeartbeat);
break;
case WorkerEnvelope.BodyOneofCase.WorkerFault:
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
CreateWorkerFaultMessage(envelope.WorkerFault),
null);
break;
case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck:
MarkClosed("worker-shutdown-ack");
break;
default:
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
$"Worker sent unexpected envelope body {envelope.BodyCase}.",
null);
break;
}
}
private async Task EnqueueWorkerEventAsync(
WorkerEvent workerEvent,
CancellationToken cancellationToken)
{
if (workerEvent.Event is not null)
{
_metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
}
if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false))
{
return;
}
if (!_events.Writer.TryWrite(workerEvent))
{
_metrics?.QueueOverflow("worker-events");
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker event channel rejected an event.",
null);
}
}
private void CompleteCommand(WorkerEnvelope envelope)
{
string correlationId = envelope.CorrelationId;
if (string.IsNullOrWhiteSpace(correlationId))
{
correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty;
}
if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand))
{
_logger.LogDebug(
"Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.",
SessionId,
correlationId);
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandSucceeded(pendingCommand.Method, duration);
pendingCommand.SetResult(envelope.WorkerCommandReply);
}
private void RemovePendingCommandAsFailed(
string correlationId,
PendingCommand pendingCommand,
WorkerClientErrorCode errorCode,
string message)
{
if (!_pendingCommands.TryRemove(correlationId, out _))
{
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration);
pendingCommand.SetException(new WorkerClientException(errorCode, message));
}
private async Task<WorkerEnvelope> ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase expectedBody,
CancellationToken cancellationToken)
{
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
if (envelope.BodyCase != expectedBody)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
$"Worker handshake expected {expectedBody} but received {envelope.BodyCase}.");
}
return envelope;
}
private void ValidateWorkerHello(WorkerHello workerHello)
{
if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello protocol version does not match the gateway protocol version.");
}
if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal))
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello nonce does not match the gateway nonce.");
}
lock (_syncRoot)
{
_processId = workerHello.WorkerProcessId == 0
? _connection.ProcessHandle?.ProcessId
: workerHello.WorkerProcessId;
}
}
private void MarkReady(WorkerReady ready)
{
lock (_syncRoot)
{
_processId = ready.WorkerProcessId == 0
? _processId ?? _connection.ProcessHandle?.ProcessId
: ready.WorkerProcessId;
_lastHeartbeatAt = _timeProvider.GetUtcNow();
_state = WorkerClientState.Ready;
}
DateTimeOffset readyAt = _timeProvider.GetUtcNow();
DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt;
_metrics?.WorkerStarted(readyAt - launchedAt);
}
private void MarkHeartbeat(WorkerHeartbeat heartbeat)
{
lock (_syncRoot)
{
_lastHeartbeatAt = _timeProvider.GetUtcNow();
if (heartbeat.WorkerProcessId != 0)
{
_processId = heartbeat.WorkerProcessId;
}
}
}
private void MarkClosing()
{
lock (_syncRoot)
{
if (_state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
_state = WorkerClientState.Closing;
}
}
private void MarkClosed(string reason)
{
lock (_syncRoot)
{
if (_state == WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Closed;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
$"Worker client closed because {reason}."));
_metrics?.WorkerStopped(reason);
}
private void SetFaulted(
WorkerClientErrorCode errorCode,
string message,
Exception? exception)
{
WorkerClientException fault = exception is null
? new WorkerClientException(errorCode, message)
: new WorkerClientException(errorCode, message, exception);
lock (_syncRoot)
{
if (_state is WorkerClientState.Faulted or WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Faulted;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete(fault);
_events.Writer.TryComplete(fault);
CompletePendingCommands(fault);
_metrics?.Fault(errorCode.ToString());
_logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message);
}
private void CompletePendingCommands(Exception exception)
{
foreach (KeyValuePair<string, PendingCommand> item in _pendingCommands.ToArray())
{
if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand))
{
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration);
pendingCommand.SetException(exception);
}
}
}
private void TransitionFromCreatedToHandshaking()
{
lock (_syncRoot)
{
if (_state != WorkerClientState.Created)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client cannot start from state {_state}.");
}
_state = WorkerClientState.Handshaking;
}
}
private void EnsureReady()
{
WorkerClientState state = State;
if (state != WorkerClientState.Ready)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client is not ready. Current state is {state}.");
}
}
private bool IsTerminalState()
{
WorkerClientState state = State;
return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted;
}
private async Task EnqueueAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
try
{
await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException exception)
{
throw new WorkerClientException(
WorkerClientErrorCode.WriteFailed,
"Worker outbound channel is closed.",
exception);
}
}
private WorkerEnvelope CreateGatewayHelloEnvelope()
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.GatewayHello = new GatewayHello
{
SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion,
Nonce = _connection.Nonce,
GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback,
});
}
private WorkerEnvelope CreateCommandEnvelope(
string correlationId,
WorkerCommand command)
{
return CreateEnvelope(
correlationId,
envelope => envelope.WorkerCommand = command.Clone());
}
private WorkerEnvelope CreateShutdownEnvelope(
TimeSpan timeout,
string reason)
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.WorkerShutdown = new WorkerShutdown
{
GracePeriod = Duration.FromTimeSpan(timeout),
Reason = reason,
});
}
private WorkerEnvelope CreateEnvelope(
string correlationId,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = _connection.FrameOptions.ProtocolVersion,
SessionId = SessionId,
Sequence = (ulong)Interlocked.Increment(ref _nextSequence),
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static string GetCommandMethod(WorkerCommand command)
{
return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString();
}
private static string CreateWorkerFaultMessage(WorkerFault fault)
{
return string.IsNullOrWhiteSpace(fault.DiagnosticMessage)
? $"Worker faulted with category {fault.Category}."
: $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}";
}
private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken)
{
Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask }
.Where(task => task is not null)
.Cast<Task>()
.ToArray();
if (tasks.Length == 0)
{
return;
}
await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
}
private void ThrowIfDisposed()
{
ObjectDisposedException.ThrowIf(_disposed, this);
}
private sealed class PendingCommand
{
private readonly TaskCompletionSource<WorkerCommandReply> _completion = new(TaskCreationOptions.RunContinuationsAsynchronously);
public PendingCommand(
string correlationId,
string method,
long startTimestamp)
{
CorrelationId = correlationId;
Method = method;
StartTimestamp = startTimestamp;
}
public string CorrelationId { get; }
public string Method { get; }
public long StartTimestamp { get; }
public Task<WorkerCommandReply> Task => _completion.Task;
public void SetResult(WorkerCommandReply reply)
{
_completion.TrySetResult(reply);
}
public void SetException(Exception exception)
{
_completion.TrySetException(exception);
}
}
}
@@ -0,0 +1,38 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientConnection
{
public WorkerClientConnection(
string sessionId,
string nonce,
Stream stream,
WorkerFrameProtocolOptions frameOptions,
WorkerProcessHandle? processHandle = null)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("Session id is required.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(nonce))
{
throw new ArgumentException("Worker nonce is required.", nameof(nonce));
}
SessionId = sessionId;
Nonce = nonce;
Stream = stream ?? throw new ArgumentNullException(nameof(stream));
FrameOptions = frameOptions ?? throw new ArgumentNullException(nameof(frameOptions));
ProcessHandle = processHandle;
}
public string SessionId { get; }
public string Nonce { get; }
public Stream Stream { get; }
public WorkerFrameProtocolOptions FrameOptions { get; }
public WorkerProcessHandle? ProcessHandle { get; }
}
@@ -0,0 +1,14 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientErrorCode
{
InvalidState,
ProtocolViolation,
PipeDisconnected,
CommandTimeout,
WorkerFaulted,
HeartbeatExpired,
ShutdownTimeout,
GatewayShutdown,
WriteFailed,
}
@@ -0,0 +1,23 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientException : Exception
{
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerClientErrorCode ErrorCode { get; }
}
@@ -0,0 +1,24 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientOptions
{
public static readonly TimeSpan DefaultHeartbeatGrace = TimeSpan.FromSeconds(15);
public static readonly TimeSpan DefaultHeartbeatCheckInterval = TimeSpan.FromSeconds(1);
public static readonly TimeSpan DefaultEventChannelFullModeTimeout = TimeSpan.FromSeconds(5);
public WorkerClientOptions()
{
HeartbeatGrace = DefaultHeartbeatGrace;
HeartbeatCheckInterval = DefaultHeartbeatCheckInterval;
EventChannelCapacity = 1_024;
EventChannelFullModeTimeout = DefaultEventChannelFullModeTimeout;
}
public TimeSpan HeartbeatGrace { get; init; }
public TimeSpan HeartbeatCheckInterval { get; init; }
public int EventChannelCapacity { get; init; }
public TimeSpan EventChannelFullModeTimeout { get; init; }
}
@@ -0,0 +1,11 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientState
{
Created,
Handshaking,
Ready,
Closing,
Closed,
Faulted,
}
@@ -0,0 +1,341 @@
using System.IO.Pipes;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Workers;
public sealed class WorkerClientTests
{
private const string SessionId = "session-worker-client";
private const string Nonce = "nonce-worker-client";
private const int WorkerProcessId = 4321;
private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5);
[Fact]
public async Task StartAsync_WithWorkerHelloAndReady_EntersReadyState()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(WorkerProcessId, client.ProcessId);
}
[Fact]
public async Task InvokeAsync_WithMatchingReply_CompletesPendingCommand()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> invokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TestTimeout,
CancellationToken.None);
WorkerEnvelope commandEnvelope = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerCommand, commandEnvelope.BodyCase);
Assert.False(string.IsNullOrWhiteSpace(commandEnvelope.CorrelationId));
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(commandEnvelope.CorrelationId, MxCommandKind.Ping));
WorkerCommandReply reply = await invokeTask.WaitAsync(TestTimeout);
Assert.Equal(commandEnvelope.CorrelationId, reply.Reply.CorrelationId);
Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
}
[Fact]
public async Task InvokeAsync_WithLateReply_IgnoresLateReplyAndKeepsClientReady()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> timedOutInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TimeSpan.FromMilliseconds(50),
CancellationToken.None);
WorkerEnvelope timedOutCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
WorkerClientException exception = await Assert.ThrowsAsync<WorkerClientException>(
async () => await timedOutInvokeTask);
Assert.Equal(WorkerClientErrorCode.CommandTimeout, exception.ErrorCode);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(timedOutCommand.CorrelationId, MxCommandKind.Ping));
await Task.Delay(TimeSpan.FromMilliseconds(50));
Task<WorkerCommandReply> secondInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.GetWorkerInfo),
TestTimeout,
CancellationToken.None);
WorkerEnvelope secondCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(secondCommand.CorrelationId, MxCommandKind.GetWorkerInfo));
WorkerCommandReply reply = await secondInvokeTask.WaitAsync(TestTimeout);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(MxCommandKind.GetWorkerInfo, reply.Reply.Kind);
}
[Fact]
public async Task ReadEventsAsync_WithWorkerEvents_YieldsEventsInPipeOrder()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
using CancellationTokenSource cancellationTokenSource = new(TestTimeout);
await using IAsyncEnumerator<WorkerEvent> events =
client.ReadEventsAsync(cancellationTokenSource.Token).GetAsyncEnumerator(cancellationTokenSource.Token);
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange));
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 12, MxEventFamily.OperationComplete));
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)11, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OnDataChange, events.Current.Event.Family);
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)12, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family);
}
[Fact]
public async Task ReadLoop_WhenPipeDisconnects_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
await pipePair.DisposeWorkerSideAsync();
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
[Fact]
public async Task HeartbeatMonitor_WhenHeartbeatExpires_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(
pipePair,
new WorkerClientOptions
{
HeartbeatGrace = TimeSpan.FromMilliseconds(80),
HeartbeatCheckInterval = TimeSpan.FromMilliseconds(20),
EventChannelCapacity = 8,
});
await CompleteHandshakeAsync(client, pipePair);
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
private static WorkerClient CreateClient(
PipePair pipePair,
WorkerClientOptions? options = null)
{
WorkerFrameProtocolOptions frameOptions = new(SessionId);
WorkerClientConnection connection = new(
SessionId,
Nonce,
pipePair.GatewayStream,
frameOptions);
return new WorkerClient(connection, options);
}
private static async Task CompleteHandshakeAsync(
WorkerClient client,
PipePair pipePair)
{
Task startTask = client.StartAsync(CancellationToken.None);
WorkerEnvelope gatewayHello = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.GatewayHello, gatewayHello.BodyCase);
Assert.Equal(Nonce, gatewayHello.GatewayHello.Nonce);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, gatewayHello.GatewayHello.SupportedProtocolVersion);
await pipePair.WorkerWriter.WriteAsync(CreateWorkerHelloEnvelope());
await pipePair.WorkerWriter.WriteAsync(CreateWorkerReadyEnvelope());
await startTask.WaitAsync(TestTimeout);
}
private static WorkerCommand CreateCommand(MxCommandKind kind)
{
return new WorkerCommand
{
Command = new MxCommand
{
Kind = kind,
},
};
}
private static WorkerEnvelope CreateWorkerHelloEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 1,
envelope => envelope.WorkerHello = new WorkerHello
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
WorkerProcessId = WorkerProcessId,
WorkerVersion = "fake-worker",
});
}
private static WorkerEnvelope CreateWorkerReadyEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 2,
envelope => envelope.WorkerReady = new WorkerReady
{
WorkerProcessId = WorkerProcessId,
MxaccessProgid = "LMXProxy.LMXProxyServer.1",
MxaccessClsid = "{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}",
});
}
private static WorkerEnvelope CreateCommandReplyEnvelope(
string correlationId,
MxCommandKind kind)
{
return CreateWorkerEnvelope(
correlationId,
sequence: 10,
envelope => envelope.WorkerCommandReply = new WorkerCommandReply
{
Reply = new MxCommandReply
{
SessionId = SessionId,
CorrelationId = correlationId,
Kind = kind,
},
});
}
private static WorkerEnvelope CreateEventEnvelope(
ulong sequence,
MxEventFamily family)
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence,
envelope => envelope.WorkerEvent = new WorkerEvent
{
Event = new MxEvent
{
SessionId = SessionId,
Family = family,
WorkerSequence = sequence,
},
});
}
private static WorkerEnvelope CreateWorkerEnvelope(
string correlationId,
ulong sequence,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static async Task WaitUntilAsync(
Func<bool> predicate,
TimeSpan timeout)
{
using CancellationTokenSource cancellationTokenSource = new(timeout);
while (!predicate())
{
await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token);
}
}
private sealed class PipePair : IAsyncDisposable
{
private readonly NamedPipeClientStream _workerStream;
private bool _workerSideDisposed;
private PipePair(
NamedPipeServerStream gatewayStream,
NamedPipeClientStream workerStream)
{
GatewayStream = gatewayStream;
_workerStream = workerStream;
WorkerReader = new WorkerFrameReader(_workerStream, new WorkerFrameProtocolOptions(SessionId));
WorkerWriter = new WorkerFrameWriter(_workerStream, new WorkerFrameProtocolOptions(SessionId));
}
public NamedPipeServerStream GatewayStream { get; }
public WorkerFrameReader WorkerReader { get; }
public WorkerFrameWriter WorkerWriter { get; }
public static async Task<PipePair> CreateAsync()
{
string pipeName = $"mxaccessgw-workerclient-tests-{Guid.NewGuid():N}";
NamedPipeServerStream gatewayStream = new(
pipeName,
PipeDirection.InOut,
maxNumberOfServerInstances: 1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
NamedPipeClientStream workerStream = new(
".",
pipeName,
PipeDirection.InOut,
PipeOptions.Asynchronous);
Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync();
await workerStream.ConnectAsync();
await waitForConnectionTask;
return new PipePair(gatewayStream, workerStream);
}
public async ValueTask DisposeWorkerSideAsync()
{
if (_workerSideDisposed)
{
return;
}
await _workerStream.DisposeAsync();
_workerSideDisposed = true;
}
public async ValueTask DisposeAsync()
{
await DisposeWorkerSideAsync();
await GatewayStream.DisposeAsync();
}
}
}
@@ -0,0 +1,267 @@
using Grpc.Core;
using Microsoft.Extensions.Options;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
namespace MxGateway.Tests.Security.Authorization;
public sealed class GatewayGrpcAuthorizationInterceptorTests
{
[Fact]
public async Task UnaryServerHandler_MissingApiKey_ReturnsUnauthenticated()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(
ApiKeyVerificationFailure.MissingOrMalformedCredentials)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
new TestServerCallContext([]),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
Assert.DoesNotContain("secret", exception.Status.Detail, StringComparison.OrdinalIgnoreCase);
}
[Fact]
public async Task UnaryServerHandler_InvalidApiKey_DoesNotExposeRawCredentialInStatus()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.SecretMismatch)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_super-secret"),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
Assert.DoesNotContain("super-secret", exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task UnaryServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
Assert.Contains(GatewayScopes.SessionOpen, exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task UnaryServerHandler_ValidApiKeyWithScope_SetsRequestIdentity()
{
GatewayRequestIdentityAccessor identityAccessor = new();
ApiKeyIdentity? identitySeenByHandler = null;
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
identityAccessor);
OpenSessionReply reply = await interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _) =>
{
identitySeenByHandler = identityAccessor.Current;
return Task.FromResult(new OpenSessionReply { SessionId = "session-1" });
});
Assert.Equal("session-1", reply.SessionId);
Assert.NotNull(identitySeenByHandler);
Assert.Equal("operator01", identitySeenByHandler.KeyId);
Assert.Null(identityAccessor.Current);
}
[Fact]
public async Task ServerStreamingServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.ServerStreamingServerHandler(
new StreamEventsRequest(),
new TestServerStreamWriter<MxEvent>(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _, _) => Task.CompletedTask));
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
Assert.Contains(GatewayScopes.EventsRead, exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task ServerStreamingServerHandler_ValidApiKeyWithScope_AllowsStream()
{
GatewayRequestIdentityAccessor identityAccessor = new();
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
identityAccessor);
TestServerStreamWriter<MxEvent> streamWriter = new();
await interceptor.ServerStreamingServerHandler(
new StreamEventsRequest(),
streamWriter,
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
async (_, writer, _) =>
{
Assert.Equal("operator01", identityAccessor.Current?.KeyId);
await writer.WriteAsync(new MxEvent { SessionId = "session-1" });
});
MxEvent eventMessage = Assert.Single(streamWriter.Messages);
Assert.Equal("session-1", eventMessage.SessionId);
Assert.Null(identityAccessor.Current);
}
[Fact]
public async Task UnaryServerHandler_AuthenticationDisabled_SkipsApiKeyVerification()
{
GatewayRequestIdentityAccessor identityAccessor = new();
FakeApiKeyVerifier verifier = new(ApiKeyVerificationResult.Fail(
ApiKeyVerificationFailure.MissingOrMalformedCredentials));
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
verifier,
identityAccessor,
AuthenticationMode.Disabled);
OpenSessionReply reply = await interceptor.UnaryServerHandler(
new OpenSessionRequest(),
new TestServerCallContext([]),
(_, _) => Task.FromResult(new OpenSessionReply { SessionId = "session-1" }));
Assert.Equal("session-1", reply.SessionId);
Assert.False(verifier.WasCalled);
Assert.Null(identityAccessor.Current);
}
private static GatewayGrpcAuthorizationInterceptor CreateInterceptor(
IApiKeyVerifier apiKeyVerifier,
IGatewayRequestIdentityAccessor identityAccessor,
AuthenticationMode authenticationMode = AuthenticationMode.ApiKey)
{
return new GatewayGrpcAuthorizationInterceptor(
apiKeyVerifier,
new GatewayGrpcScopeResolver(),
identityAccessor,
Options.Create(new GatewayOptions
{
Authentication = new AuthenticationOptions
{
Mode = authenticationMode
}
}));
}
private static ApiKeyVerificationResult SuccessWithScopes(params string[] scopes)
{
return ApiKeyVerificationResult.Success(new ApiKeyIdentity(
KeyId: "operator01",
KeyPrefix: "mxgw_operator01",
DisplayName: "Operator Key",
Scopes: new HashSet<string>(scopes, StringComparer.Ordinal)));
}
private static TestServerCallContext ContextWithAuthorization(string authorizationHeader)
{
return new TestServerCallContext([new Metadata.Entry("authorization", authorizationHeader)]);
}
private sealed class FakeApiKeyVerifier(ApiKeyVerificationResult result) : IApiKeyVerifier
{
public bool WasCalled { get; private set; }
public string? LastAuthorizationHeader { get; private set; }
public Task<ApiKeyVerificationResult> VerifyAsync(
string? authorizationHeader,
CancellationToken cancellationToken)
{
WasCalled = true;
LastAuthorizationHeader = authorizationHeader;
return Task.FromResult(result);
}
}
private sealed class TestServerStreamWriter<T> : IServerStreamWriter<T>
{
public List<T> Messages { get; } = [];
public WriteOptions? WriteOptions { get; set; }
public Task WriteAsync(T message)
{
Messages.Add(message);
return Task.CompletedTask;
}
}
private sealed class TestServerCallContext(
Metadata requestHeaders,
CancellationToken cancellationToken = default) : ServerCallContext
{
private readonly Metadata responseTrailers = [];
private readonly Dictionary<object, object> userState = [];
private Status status;
private WriteOptions? writeOptions;
protected override string MethodCore => "/mxaccess_gateway.v1.MxAccessGateway/Test";
protected override string HostCore => "localhost";
protected override string PeerCore => "ipv4:127.0.0.1:5000";
protected override DateTime DeadlineCore => DateTime.UtcNow.AddMinutes(1);
protected override Metadata RequestHeadersCore => requestHeaders;
protected override CancellationToken CancellationTokenCore => cancellationToken;
protected override Metadata ResponseTrailersCore => responseTrailers;
protected override Status StatusCore
{
get => status;
set => status = value;
}
protected override WriteOptions? WriteOptionsCore
{
get => writeOptions;
set => writeOptions = value;
}
protected override AuthContext AuthContextCore { get; } = new(
string.Empty,
new Dictionary<string, List<AuthProperty>>(StringComparer.Ordinal));
protected override IDictionary<object, object> UserStateCore => userState;
protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
{
return Task.CompletedTask;
}
protected override ContextPropagationToken CreatePropagationTokenCore(
ContextPropagationOptions? options)
{
throw new NotSupportedException();
}
}
}
@@ -0,0 +1,54 @@
using MxGateway.Contracts.Proto;
using MxGateway.Server.Security.Authorization;
namespace MxGateway.Tests.Security.Authorization;
public sealed class GatewayGrpcScopeResolverTests
{
[Theory]
[InlineData(typeof(OpenSessionRequest), GatewayScopes.SessionOpen)]
[InlineData(typeof(CloseSessionRequest), GatewayScopes.SessionClose)]
[InlineData(typeof(StreamEventsRequest), GatewayScopes.EventsRead)]
public void ResolveRequiredScope_KnownRpcRequest_ReturnsExpectedScope(
Type requestType,
string expectedScope)
{
GatewayGrpcScopeResolver resolver = new();
object request = Activator.CreateInstance(requestType)!;
string scope = resolver.ResolveRequiredScope(request);
Assert.Equal(expectedScope, scope);
}
[Theory]
[InlineData(MxCommandKind.Register, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.AddItem, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.Advise, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.Write, GatewayScopes.InvokeWrite)]
[InlineData(MxCommandKind.Write2, GatewayScopes.InvokeWrite)]
[InlineData(MxCommandKind.WriteSecured, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.WriteSecured2, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.AuthenticateUser, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.ArchestraUserToId, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.GetSessionState, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.GetWorkerInfo, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.DrainEvents, GatewayScopes.EventsRead)]
[InlineData(MxCommandKind.ShutdownWorker, GatewayScopes.Admin)]
public void ResolveRequiredScope_InvokeCommand_ReturnsExpectedScope(
MxCommandKind commandKind,
string expectedScope)
{
GatewayGrpcScopeResolver resolver = new();
string scope = resolver.ResolveRequiredScope(new MxCommandRequest
{
Command = new MxCommand
{
Kind = commandKind
}
});
Assert.Equal(expectedScope, scope);
}
}
@@ -1,6 +1,9 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Bootstrap;
@@ -15,16 +18,19 @@ public sealed class WorkerApplicationTests
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger);
logger,
new SucceedingPipeClient());
Assert.Equal((int)WorkerExitCode.Success, exitCode);
MemoryWorkerLogEntry entry = Assert.Single(logger.Entries);
Assert.Equal(2, logger.Entries.Count);
MemoryWorkerLogEntry entry = logger.Entries[0];
Assert.Equal("Information", entry.Level);
Assert.Equal("WorkerBootstrapSucceeded", entry.EventName);
Assert.Equal("session-1", entry.Fields["session_id"]);
Assert.Equal("mxaccess-gateway-123-session-1", entry.Fields["pipe_name"]);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, entry.Fields["protocol_version"]);
Assert.Equal("[redacted]", entry.Fields["nonce"]);
Assert.Equal("WorkerPipeHandshakeSucceeded", logger.Entries[1].EventName);
}
[Fact]
@@ -73,6 +79,24 @@ public sealed class WorkerApplicationTests
Assert.Equal((int)WorkerExitCode.MissingNonce, exitCode);
}
[Fact]
public void Run_WithPipeProtocolFailure_ReturnsProtocolViolation()
{
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger,
new ThrowingPipeClient(new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.NonceMismatch,
"Bad nonce.")));
Assert.Equal((int)WorkerExitCode.ProtocolViolation, exitCode);
Assert.Equal("WorkerPipeProtocolFailure", logger.Entries[1].EventName);
}
[Fact]
public void Run_WithUnexpectedBootstrapFailure_ReturnsUnexpectedFailure()
{
@@ -110,4 +134,31 @@ public sealed class WorkerApplicationTests
environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce);
return environment;
}
private sealed class SucceedingPipeClient : IWorkerPipeClient
{
public Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
}
private sealed class ThrowingPipeClient : IWorkerPipeClient
{
private readonly Exception _exception;
public ThrowingPipeClient(Exception exception)
{
_exception = exception;
}
public Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
throw _exception;
}
}
}
@@ -0,0 +1,183 @@
using System;
using Google.Protobuf;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Conversion;
using ProtobufTimestamp = Google.Protobuf.WellKnownTypes.Timestamp;
namespace MxGateway.Worker.Tests.Conversion;
public sealed class VariantConverterTests
{
private readonly VariantConverter _converter = new();
[Theory]
[InlineData(true, MxDataType.Boolean, MxValue.KindOneofCase.BoolValue)]
[InlineData(42, MxDataType.Integer, MxValue.KindOneofCase.Int32Value)]
[InlineData(42L, MxDataType.Integer, MxValue.KindOneofCase.Int64Value)]
[InlineData(1.25f, MxDataType.Float, MxValue.KindOneofCase.FloatValue)]
[InlineData(2.5d, MxDataType.Double, MxValue.KindOneofCase.DoubleValue)]
[InlineData("value", MxDataType.String, MxValue.KindOneofCase.StringValue)]
public void Convert_WithSupportedScalar_ProjectsTypedValue(
object value,
MxDataType expectedDataType,
MxValue.KindOneofCase expectedKind)
{
MxValue converted = _converter.Convert(value);
Assert.Equal(expectedDataType, converted.DataType);
Assert.Equal(expectedKind, converted.KindCase);
Assert.False(string.IsNullOrWhiteSpace(converted.VariantType));
}
[Fact]
public void Convert_WithDateTime_ProjectsTimestamp()
{
DateTime dateTime = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
MxValue converted = _converter.Convert(dateTime);
Assert.Equal(MxDataType.Time, converted.DataType);
Assert.Equal(ProtobufTimestamp.FromDateTime(dateTime), converted.TimestampValue);
Assert.Equal("VT_DATE", converted.VariantType);
}
[Fact]
public void Convert_WithFileTimeAndExpectedTime_ProjectsTimestamp()
{
DateTime dateTime = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
MxValue converted = _converter.Convert(dateTime.ToFileTimeUtc(), MxDataType.Time);
Assert.Equal(MxDataType.Time, converted.DataType);
Assert.Equal(ProtobufTimestamp.FromDateTime(dateTime), converted.TimestampValue);
Assert.Equal("VT_I8", converted.VariantType);
}
[Theory]
[InlineData(null, "VT_EMPTY")]
[InlineData(typeof(DBNull), "VT_NULL")]
public void Convert_WithNullLikeValue_PreservesNull(
object? value,
string expectedVariantType)
{
object? actualValue = value is System.Type ? DBNull.Value : value;
MxValue converted = _converter.Convert(actualValue);
Assert.True(converted.IsNull);
Assert.Equal(MxDataType.NoData, converted.DataType);
Assert.Equal(expectedVariantType, converted.VariantType);
Assert.Equal(MxValue.KindOneofCase.None, converted.KindCase);
}
[Fact]
public void ConvertArray_WithSupportedArrays_ProjectsTypedValuesAndDimensions()
{
MxValue bools = _converter.Convert(new[] { true, false });
MxValue ints = _converter.Convert(new[] { 1, 2, 3 });
MxValue floats = _converter.Convert(new[] { 1.25f, 2.5f });
MxValue doubles = _converter.Convert(new[] { 1.25d, 2.5d });
MxValue strings = _converter.Convert(new[] { "one", "two" });
MxValue times = _converter.Convert(new[]
{
new DateTime(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc),
new DateTime(2026, 4, 26, 17, 46, 0, DateTimeKind.Utc),
});
Assert.Equal(new[] { true, false }, bools.ArrayValue.BoolValues.Values);
Assert.Equal(new[] { 1, 2, 3 }, ints.ArrayValue.Int32Values.Values);
Assert.Equal(new[] { 1.25f, 2.5f }, floats.ArrayValue.FloatValues.Values);
Assert.Equal(new[] { 1.25d, 2.5d }, doubles.ArrayValue.DoubleValues.Values);
Assert.Equal(new[] { "one", "two" }, strings.ArrayValue.StringValues.Values);
Assert.Equal(2, times.ArrayValue.TimestampValues.Values.Count);
Assert.Equal(new uint[] { 2 }, bools.ArrayValue.Dimensions);
Assert.Equal(MxDataType.Boolean, bools.ArrayValue.ElementDataType);
}
[Fact]
public void ConvertArray_WithMultidimensionalArray_PreservesRankAndDimensions()
{
int[,] values =
{
{ 1, 2, 3 },
{ 4, 5, 6 },
};
MxValue converted = _converter.Convert(values);
Assert.Equal(new uint[] { 2, 3 }, converted.ArrayValue.Dimensions);
Assert.Equal(new[] { 1, 2, 3, 4, 5, 6 }, converted.ArrayValue.Int32Values.Values);
}
[Fact]
public void ConvertArray_WithExpectedTimeAndFileTimeValues_ProjectsTimestampArray()
{
DateTime first = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
DateTime second = new(2026, 4, 26, 17, 46, 0, DateTimeKind.Utc);
MxValue converted = _converter.Convert(
new[] { first.ToFileTimeUtc(), second.ToFileTimeUtc() },
MxDataType.Time);
Assert.Equal(MxDataType.Time, converted.ArrayValue.ElementDataType);
Assert.Equal(
new[] { ProtobufTimestamp.FromDateTime(first), ProtobufTimestamp.FromDateTime(second) },
converted.ArrayValue.TimestampValues.Values);
}
[Fact]
public void Convert_WithUnknownScalar_PreservesRawMetadata()
{
UnsupportedVariant value = new("opaque");
MxValue converted = _converter.Convert(value);
Assert.Equal(MxDataType.Unknown, converted.DataType);
Assert.Equal(MxValue.KindOneofCase.RawValue, converted.KindCase);
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.VariantType);
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.RawDiagnostic);
Assert.Equal(ByteString.CopyFromUtf8("opaque"), converted.RawValue);
}
[Fact]
public void ConvertArray_WithUnknownArray_PreservesRawMetadata()
{
UnsupportedVariant[] values =
[
new("first"),
new("second"),
];
MxValue converted = _converter.Convert(values);
Assert.Equal(MxDataType.Unknown, converted.ArrayValue.ElementDataType);
Assert.Equal(MxArray.ValuesOneofCase.RawValues, converted.ArrayValue.ValuesCase);
Assert.Equal(new uint[] { 2 }, converted.ArrayValue.Dimensions);
Assert.Equal("first", converted.ArrayValue.RawValues.Values[0].ToStringUtf8());
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.ArrayValue.RawDiagnostic);
}
[Fact]
public void Redactor_WithCredentialBearingValueFields_RedactsBeforeLogging()
{
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("credential_value", "secret"));
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("password_value", "secret"));
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("secured_write_token", "secret"));
}
private sealed class UnsupportedVariant
{
private readonly string _value;
public UnsupportedVariant(string value)
{
_value = value;
}
public override string ToString()
{
return _value;
}
}
}
@@ -0,0 +1,163 @@
using System;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerFrameProtocolTests
{
private const string SessionId = "session-1";
private const string Nonce = "nonce-secret";
[Fact]
public async Task WriteAndReadAsync_WithValidEnvelope_RoundTripsFrame()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new();
WorkerEnvelope original = CreateGatewayHelloEnvelope();
WorkerFrameWriter writer = new(stream, options);
await writer.WriteAsync(original);
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
WorkerEnvelope parsed = await reader.ReadAsync();
Assert.Equal(original, parsed);
}
[Fact]
public async Task ReadAsync_WithWrongProtocolVersion_ThrowsProtocolVersionMismatch()
{
WorkerFrameProtocolOptions options = CreateOptions();
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
envelope.ProtocolVersion++;
MemoryStream stream = new(CreateFrame(envelope));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithWrongSessionId_ThrowsSessionMismatch()
{
WorkerFrameProtocolOptions options = CreateOptions();
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
envelope.SessionId = "different-session";
MemoryStream stream = new(CreateFrame(envelope));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.SessionMismatch, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithMalformedLength_ThrowsMalformedLength()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new(new byte[sizeof(uint)]);
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.MalformedLength, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithMalformedPayload_ThrowsInvalidEnvelope()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new(CreateFrame(new byte[] { 0x80 }));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
}
[Fact]
public async Task WriteAsync_WithConcurrentCalls_SerializesCompleteFrames()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new();
WorkerFrameWriter writer = new(stream, options);
await Task.WhenAll(
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 1)),
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 2)),
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 3)));
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
WorkerEnvelope first = await reader.ReadAsync();
WorkerEnvelope second = await reader.ReadAsync();
WorkerEnvelope third = await reader.ReadAsync();
Assert.Equal(new ulong[] { 1, 2, 3 }, new[] { first.Sequence, second.Sequence, third.Sequence }.OrderBy(sequence => sequence));
}
private static WorkerFrameProtocolOptions CreateOptions()
{
return new WorkerFrameProtocolOptions(
SessionId,
GatewayContractInfo.WorkerProtocolVersion,
Nonce);
}
private static WorkerEnvelope CreateGatewayHelloEnvelope(ulong sequence = 1)
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
GatewayVersion = "test-gateway",
},
};
}
private static byte[] CreateFrame(IMessage message)
{
return CreateFrame(message.ToByteArray());
}
private static byte[] CreateFrame(byte[] payload)
{
byte[] frame = new byte[sizeof(uint) + payload.Length];
WriteUInt32LittleEndian(frame, (uint)payload.Length);
payload.CopyTo(frame, sizeof(uint));
return frame;
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,61 @@
using System;
using System.IO.Pipes;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerPipeClientTests
{
[Fact]
public async Task RunAsync_ConnectsToPipeAndCompletesHandshake()
{
string pipeName = $"mxaccess-gateway-test-{Guid.NewGuid():N}";
WorkerOptions workerOptions = new(
"session-1",
pipeName,
GatewayContractInfo.WorkerProtocolVersion,
"nonce-secret");
WorkerFrameProtocolOptions frameOptions = new(workerOptions);
using NamedPipeServerStream server = new(
pipeName,
PipeDirection.InOut,
1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
WorkerPipeClient client = new(connectTimeoutMilliseconds: 5000);
Task clientTask = client.RunAsync(workerOptions);
await Task.Factory.FromAsync(server.BeginWaitForConnection, server.EndWaitForConnection, null);
WorkerFrameReader reader = new(server, frameOptions);
WorkerFrameWriter writer = new(server, frameOptions);
await writer.WriteAsync(new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = "session-1",
Sequence = 1,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = "nonce-secret",
GatewayVersion = "test-gateway",
},
});
WorkerEnvelope hello = await reader.ReadAsync();
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, hello.BodyCase);
Assert.Equal("nonce-secret", hello.WorkerHello.Nonce);
WorkerEnvelope ready = await reader.ReadAsync();
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, ready.BodyCase);
await clientTask;
}
}
@@ -0,0 +1,192 @@
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerPipeSessionTests
{
private const string SessionId = "session-1";
private const string Nonce = "nonce-secret";
[Fact]
public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope());
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
});
Assert.True(initialized);
WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
Assert.Equal(2, written.Length);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
Assert.Equal(Nonce, written[0].WorkerHello.Nonce);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(nonce: "wrong"));
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(supportedProtocolVersion: 999));
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 }));
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
}
private static WorkerPipeSession CreateSession(
Stream inbound,
Stream outbound,
WorkerFrameProtocolOptions options)
{
return new WorkerPipeSession(
new WorkerFrameReader(inbound, options),
new WorkerFrameWriter(outbound, options),
options,
() => 1234);
}
private static WorkerFrameProtocolOptions CreateOptions()
{
return new WorkerFrameProtocolOptions(
SessionId,
GatewayContractInfo.WorkerProtocolVersion,
Nonce);
}
private static WorkerEnvelope CreateGatewayHelloEnvelope(
string nonce = Nonce,
uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion)
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = 1,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = supportedProtocolVersion,
Nonce = nonce,
GatewayVersion = "test-gateway",
},
};
}
private static WorkerEnvelope[] ReadWrittenFrames(
MemoryStream stream,
WorkerFrameProtocolOptions options)
{
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
List<WorkerEnvelope> envelopes = new();
while (stream.Position < stream.Length)
{
envelopes.Add(reader.ReadAsync(CancellationToken.None).GetAwaiter().GetResult());
}
return envelopes.ToArray();
}
private static byte[] CreateFrame(byte[] payload)
{
byte[] frame = new byte[sizeof(uint) + payload.Length];
WriteUInt32LittleEndian(frame, (uint)payload.Length);
payload.CopyTo(frame, sizeof(uint));
return frame;
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,152 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaRuntimeTests
{
[Fact]
public async Task InvokeAsync_ExecutesCommandOnStaThread()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
StaCommandObservation observation = await runtime.InvokeAsync(
() => new StaCommandObservation(
Thread.CurrentThread.ManagedThreadId,
Thread.CurrentThread.GetApartmentState()));
Assert.Equal(runtime.StaThreadId, observation.ThreadId);
Assert.Equal(initializer.InitializeThreadId, observation.ThreadId);
Assert.Equal(ApartmentState.STA, observation.ApartmentState);
}
[Fact]
public async Task InvokeAsync_WakesIdlePumpForQueuedCommand()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = new(
initializer,
new StaMessagePump(),
TimeSpan.FromSeconds(30));
runtime.Start();
Stopwatch stopwatch = Stopwatch.StartNew();
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
stopwatch.Stop();
Assert.Equal(runtime.StaThreadId, threadId);
Assert.True(
stopwatch.Elapsed < TimeSpan.FromSeconds(2),
$"Command took {stopwatch.Elapsed} to execute, so the command wake event did not wake the STA promptly.");
}
[Fact]
public void Shutdown_StopsThreadAndUninitializesComApartment()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
bool stopped = runtime.Shutdown(TimeSpan.FromSeconds(2));
Assert.True(stopped);
Assert.False(runtime.IsRunning);
Assert.Equal(1, initializer.InitializeCount);
Assert.Equal(1, initializer.UninitializeCount);
Assert.Equal(initializer.InitializeThreadId, initializer.UninitializeThreadId);
}
[Fact]
public void LastActivityUtc_UpdatesWhilePumpIsIdle()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
DateTimeOffset firstActivity = runtime.LastActivityUtc;
bool updated = SpinWait.SpinUntil(
() => runtime.LastActivityUtc > firstActivity,
TimeSpan.FromSeconds(2));
Assert.True(updated);
}
[Fact]
public async Task InvokeAsync_CommandException_FaultsReturnedTaskWithoutStoppingRuntime()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync<int>(() => throw new InvalidOperationException("command failed")));
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
Assert.Equal("command failed", exception.Message);
Assert.Equal(runtime.StaThreadId, threadId);
}
[Fact]
public async Task InvokeAsync_AfterShutdown_ReturnsFaultedTask()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
runtime.Shutdown(TimeSpan.FromSeconds(2));
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId));
Assert.Contains("shutting down", exception.Message);
}
private static StaRuntime CreateRuntime(RecordingComApartmentInitializer initializer)
{
return new StaRuntime(
initializer,
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class StaCommandObservation
{
public StaCommandObservation(int threadId, ApartmentState apartmentState)
{
ThreadId = threadId;
ApartmentState = apartmentState;
}
public int ThreadId { get; }
public ApartmentState ApartmentState { get; }
}
private sealed class RecordingComApartmentInitializer : IStaComApartmentInitializer
{
public int InitializeCount { get; private set; }
public int UninitializeCount { get; private set; }
public int? InitializeThreadId { get; private set; }
public int? UninitializeThreadId { get; private set; }
public void Initialize()
{
InitializeCount++;
InitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
public void Uninitialize()
{
UninitializeCount++;
UninitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
}
}
@@ -7,4 +7,6 @@ public enum WorkerExitCode
InvalidArguments = 2,
InvalidProtocolVersion = 3,
MissingNonce = 4,
PipeConnectionFailed = 5,
ProtocolViolation = 6,
}
@@ -0,0 +1,522 @@
using System;
using System.Globalization;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Conversion;
public sealed class VariantConverter
{
public MxValue Convert(object? value)
{
return Convert(value, MxDataType.Unspecified);
}
public MxValue Convert(
object? value,
MxDataType expectedDataType)
{
if (value is null || value is DBNull)
{
return CreateNullValue(value, expectedDataType);
}
if (value is Array array)
{
return new MxValue
{
DataType = MxDataType.Unspecified,
VariantType = CreateArrayVariantType(array),
ArrayValue = ConvertArray(array, expectedDataType),
};
}
return ConvertScalar(value, expectedDataType);
}
public MxArray ConvertArray(
Array array,
MxDataType expectedElementDataType = MxDataType.Unspecified)
{
if (array is null)
{
throw new ArgumentNullException(nameof(array));
}
MxArray mxArray = new()
{
VariantType = CreateArrayVariantType(array),
};
for (int dimension = 0; dimension < array.Rank; dimension++)
{
mxArray.Dimensions.Add((uint)array.GetLength(dimension));
}
System.Type? elementType = array.GetType().GetElementType();
MxDataType elementDataType = ResolveArrayElementDataType(elementType, expectedElementDataType);
mxArray.ElementDataType = elementDataType;
switch (elementDataType)
{
case MxDataType.Boolean:
mxArray.BoolValues = ConvertBoolArray(array);
return mxArray;
case MxDataType.Integer:
if (elementType == typeof(long) || elementType == typeof(ulong))
{
mxArray.Int64Values = ConvertInt64Array(array);
}
else
{
mxArray.Int32Values = ConvertInt32Array(array);
}
return mxArray;
case MxDataType.Float:
mxArray.FloatValues = ConvertFloatArray(array);
return mxArray;
case MxDataType.Double:
mxArray.DoubleValues = ConvertDoubleArray(array);
return mxArray;
case MxDataType.String:
mxArray.StringValues = ConvertStringArray(array);
return mxArray;
case MxDataType.Time:
mxArray.TimestampValues = ConvertTimestampArray(array);
return mxArray;
default:
mxArray.ElementDataType = MxDataType.Unknown;
mxArray.RawElementDataType = (int)expectedElementDataType;
mxArray.RawDiagnostic = CreateRawDiagnostic(array);
mxArray.RawValues = ConvertRawArray(array);
return mxArray;
}
}
private static MxValue ConvertScalar(
object value,
MxDataType expectedDataType)
{
System.Type valueType = value.GetType();
string variantType = GetVariantTypeName(valueType);
switch (System.Type.GetTypeCode(valueType))
{
case TypeCode.Boolean:
return new MxValue
{
DataType = MxDataType.Boolean,
VariantType = variantType,
BoolValue = (bool)value,
};
case TypeCode.Byte:
case TypeCode.SByte:
case TypeCode.Int16:
case TypeCode.UInt16:
case TypeCode.Int32:
return new MxValue
{
DataType = MxDataType.Integer,
VariantType = variantType,
Int32Value = System.Convert.ToInt32(value, CultureInfo.InvariantCulture),
};
case TypeCode.UInt32:
case TypeCode.Int64:
return ConvertInt64Scalar(value, variantType, expectedDataType);
case TypeCode.UInt64:
return ConvertUInt64Scalar((ulong)value, variantType, expectedDataType);
case TypeCode.Single:
return new MxValue
{
DataType = MxDataType.Float,
VariantType = variantType,
FloatValue = (float)value,
};
case TypeCode.Double:
return new MxValue
{
DataType = MxDataType.Double,
VariantType = variantType,
DoubleValue = (double)value,
};
case TypeCode.Decimal:
return new MxValue
{
DataType = MxDataType.Double,
VariantType = variantType,
DoubleValue = System.Convert.ToDouble(value, CultureInfo.InvariantCulture),
RawDiagnostic = "Decimal value projected to double.",
};
case TypeCode.String:
case TypeCode.Char:
return new MxValue
{
DataType = MxDataType.String,
VariantType = variantType,
StringValue = System.Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty,
};
case TypeCode.DateTime:
return new MxValue
{
DataType = MxDataType.Time,
VariantType = variantType,
TimestampValue = ToTimestamp((DateTime)value),
};
default:
return CreateRawValue(value, expectedDataType);
}
}
private static MxValue ConvertInt64Scalar(
object value,
string variantType,
MxDataType expectedDataType)
{
long longValue = System.Convert.ToInt64(value, CultureInfo.InvariantCulture);
if (expectedDataType == MxDataType.Time)
{
return new MxValue
{
DataType = MxDataType.Time,
VariantType = variantType,
TimestampValue = Timestamp.FromDateTime(DateTime.FromFileTimeUtc(longValue)),
};
}
return new MxValue
{
DataType = MxDataType.Integer,
VariantType = variantType,
Int64Value = longValue,
};
}
private static MxValue ConvertUInt64Scalar(
ulong value,
string variantType,
MxDataType expectedDataType)
{
if (expectedDataType == MxDataType.Time && value <= long.MaxValue)
{
return new MxValue
{
DataType = MxDataType.Time,
VariantType = variantType,
TimestampValue = Timestamp.FromDateTime(DateTime.FromFileTimeUtc((long)value)),
};
}
if (value <= long.MaxValue)
{
return new MxValue
{
DataType = MxDataType.Integer,
VariantType = variantType,
Int64Value = (long)value,
};
}
return CreateRawValue(value, expectedDataType, "UInt64 value exceeds Int64 range.");
}
private static MxValue CreateNullValue(
object? value,
MxDataType expectedDataType)
{
return new MxValue
{
DataType = expectedDataType == MxDataType.Unspecified ? MxDataType.NoData : expectedDataType,
VariantType = value is DBNull ? "VT_NULL" : "VT_EMPTY",
IsNull = true,
};
}
private static MxValue CreateRawValue(
object value,
MxDataType expectedDataType,
string? diagnosticPrefix = null)
{
string diagnostic = CreateRawDiagnostic(value);
if (!string.IsNullOrWhiteSpace(diagnosticPrefix))
{
diagnostic = $"{diagnosticPrefix} {diagnostic}";
}
return new MxValue
{
DataType = MxDataType.Unknown,
VariantType = GetVariantTypeName(value.GetType()),
RawDataType = (int)expectedDataType,
RawDiagnostic = diagnostic,
RawValue = ByteString.CopyFromUtf8(System.Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty),
};
}
private static BoolArray ConvertBoolArray(Array array)
{
BoolArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is not null && System.Convert.ToBoolean(item, CultureInfo.InvariantCulture));
}
return values;
}
private static Int32Array ConvertInt32Array(Array array)
{
Int32Array values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToInt32(item, CultureInfo.InvariantCulture));
}
return values;
}
private static Int64Array ConvertInt64Array(Array array)
{
Int64Array values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToInt64(item, CultureInfo.InvariantCulture));
}
return values;
}
private static FloatArray ConvertFloatArray(Array array)
{
FloatArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToSingle(item, CultureInfo.InvariantCulture));
}
return values;
}
private static DoubleArray ConvertDoubleArray(Array array)
{
DoubleArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToDouble(item, CultureInfo.InvariantCulture));
}
return values;
}
private static StringArray ConvertStringArray(Array array)
{
StringArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? string.Empty : System.Convert.ToString(item, CultureInfo.InvariantCulture) ?? string.Empty);
}
return values;
}
private static TimestampArray ConvertTimestampArray(Array array)
{
TimestampArray values = new();
foreach (object? item in array)
{
if (item is null)
{
values.Values.Add(Timestamp.FromDateTime(new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc)));
}
else if (item is DateTime dateTime)
{
values.Values.Add(ToTimestamp(dateTime));
}
else
{
long fileTime = System.Convert.ToInt64(item, CultureInfo.InvariantCulture);
values.Values.Add(Timestamp.FromDateTime(DateTime.FromFileTimeUtc(fileTime)));
}
}
return values;
}
private static RawArray ConvertRawArray(Array array)
{
RawArray values = new();
foreach (object? item in array)
{
string rawValue = item is null
? string.Empty
: System.Convert.ToString(item, CultureInfo.InvariantCulture) ?? string.Empty;
values.Values.Add(ByteString.CopyFromUtf8(rawValue));
}
return values;
}
private static MxDataType ResolveArrayElementDataType(
System.Type? elementType,
MxDataType expectedElementDataType)
{
if (expectedElementDataType != MxDataType.Unspecified)
{
return expectedElementDataType;
}
if (elementType == typeof(bool))
{
return MxDataType.Boolean;
}
if (elementType == typeof(byte)
|| elementType == typeof(sbyte)
|| elementType == typeof(short)
|| elementType == typeof(ushort)
|| elementType == typeof(int)
|| elementType == typeof(uint)
|| elementType == typeof(long)
|| elementType == typeof(ulong))
{
return MxDataType.Integer;
}
if (elementType == typeof(float))
{
return MxDataType.Float;
}
if (elementType == typeof(double) || elementType == typeof(decimal))
{
return MxDataType.Double;
}
if (elementType == typeof(string) || elementType == typeof(char))
{
return MxDataType.String;
}
if (elementType == typeof(DateTime))
{
return MxDataType.Time;
}
return MxDataType.Unknown;
}
private static Timestamp ToTimestamp(DateTime dateTime)
{
DateTime utcDateTime = dateTime.Kind switch
{
DateTimeKind.Utc => dateTime,
DateTimeKind.Local => dateTime.ToUniversalTime(),
_ => DateTime.SpecifyKind(dateTime, DateTimeKind.Utc),
};
return Timestamp.FromDateTime(utcDateTime);
}
private static string CreateArrayVariantType(Array array)
{
System.Type? elementType = array.GetType().GetElementType();
return $"SAFEARRAY({GetVariantTypeName(elementType)})";
}
private static string GetVariantTypeName(System.Type? type)
{
if (type is null)
{
return "VT_EMPTY";
}
System.Type nonNullableType = Nullable.GetUnderlyingType(type) ?? type;
if (nonNullableType == typeof(bool))
{
return "VT_BOOL";
}
if (nonNullableType == typeof(byte))
{
return "VT_UI1";
}
if (nonNullableType == typeof(sbyte))
{
return "VT_I1";
}
if (nonNullableType == typeof(short))
{
return "VT_I2";
}
if (nonNullableType == typeof(ushort))
{
return "VT_UI2";
}
if (nonNullableType == typeof(int))
{
return "VT_I4";
}
if (nonNullableType == typeof(uint))
{
return "VT_UI4";
}
if (nonNullableType == typeof(long))
{
return "VT_I8";
}
if (nonNullableType == typeof(ulong))
{
return "VT_UI8";
}
if (nonNullableType == typeof(float))
{
return "VT_R4";
}
if (nonNullableType == typeof(double) || nonNullableType == typeof(decimal))
{
return "VT_R8";
}
if (nonNullableType == typeof(string) || nonNullableType == typeof(char))
{
return "VT_BSTR";
}
if (nonNullableType == typeof(DateTime))
{
return "VT_DATE";
}
return $"CLR:{nonNullableType.FullName}";
}
private static string CreateRawDiagnostic(object value)
{
return $"Unsupported variant projection for CLR type '{value.GetType().FullName}'.";
}
}
@@ -0,0 +1,12 @@
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public interface IWorkerPipeClient
{
Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default);
}
@@ -0,0 +1,33 @@
using System;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
internal static class WorkerEnvelopeValidator
{
public static void Validate(
WorkerEnvelope envelope,
WorkerFrameProtocolOptions options)
{
if (envelope.ProtocolVersion != options.ProtocolVersion)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
$"Worker envelope protocol version {envelope.ProtocolVersion} does not match expected version {options.ProtocolVersion}.");
}
if (!string.Equals(envelope.SessionId, options.SessionId, StringComparison.Ordinal))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.SessionMismatch,
"Worker envelope session id does not match the owning worker session.");
}
if (envelope.BodyCase == WorkerEnvelope.BodyOneofCase.None)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidEnvelope,
"Worker envelope must include a typed body.");
}
}
}
@@ -0,0 +1,15 @@
namespace MxGateway.Worker.Ipc;
public enum WorkerFrameProtocolErrorCode
{
Unknown = 0,
InvalidConfiguration = 1,
EndOfStream = 2,
MalformedLength = 3,
MessageTooLarge = 4,
InvalidEnvelope = 5,
ProtocolVersionMismatch = 6,
SessionMismatch = 7,
NonceMismatch = 8,
UnexpectedEnvelopeBody = 9,
}
@@ -0,0 +1,25 @@
using System;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameProtocolException : Exception
{
public WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerFrameProtocolErrorCode ErrorCode { get; }
}
@@ -0,0 +1,86 @@
using System;
using MxGateway.Contracts;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameProtocolOptions
{
public const int DefaultMaxMessageBytes = 16 * 1024 * 1024;
public WorkerFrameProtocolOptions(WorkerOptions options)
: this(
options?.SessionId ?? throw new ArgumentNullException(nameof(options)),
options.ProtocolVersion,
options.Nonce,
DefaultMaxMessageBytes)
{
}
public WorkerFrameProtocolOptions(
string sessionId,
uint protocolVersion,
string nonce)
: this(
sessionId,
protocolVersion,
nonce,
DefaultMaxMessageBytes)
{
}
public WorkerFrameProtocolOptions(
string sessionId,
uint protocolVersion,
string nonce,
int maxMessageBytes)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol requires a session id.");
}
if (protocolVersion == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol requires a non-zero protocol version.");
}
if (protocolVersion != GatewayContractInfo.WorkerProtocolVersion)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
$"Worker frame protocol version {protocolVersion} is not supported.");
}
if (string.IsNullOrWhiteSpace(nonce))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol requires a nonce.");
}
if (maxMessageBytes <= 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol max message size must be greater than zero.");
}
SessionId = sessionId;
ProtocolVersion = protocolVersion;
Nonce = nonce;
MaxMessageBytes = maxMessageBytes;
}
public string SessionId { get; }
public uint ProtocolVersion { get; }
public string Nonce { get; }
public int MaxMessageBytes { get; }
}
@@ -0,0 +1,93 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameReader
{
private readonly WorkerFrameProtocolOptions _options;
private readonly Stream _stream;
public WorkerFrameReader(
Stream stream,
WorkerFrameProtocolOptions options)
{
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
_options = options ?? throw new ArgumentNullException(nameof(options));
}
public async Task<WorkerEnvelope> ReadAsync(CancellationToken cancellationToken = default)
{
byte[] lengthPrefix = new byte[sizeof(uint)];
await ReadExactlyOrThrowAsync(lengthPrefix, cancellationToken).ConfigureAwait(false);
uint payloadLength = ReadUInt32LittleEndian(lengthPrefix);
if (payloadLength == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.MalformedLength,
"Worker frame payload length must be greater than zero.");
}
if (payloadLength > _options.MaxMessageBytes)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.MessageTooLarge,
$"Worker frame payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes.");
}
byte[] payload = new byte[payloadLength];
await ReadExactlyOrThrowAsync(payload, cancellationToken).ConfigureAwait(false);
WorkerEnvelope envelope;
try
{
envelope = WorkerEnvelope.Parser.ParseFrom(payload);
}
catch (InvalidProtocolBufferException exception)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidEnvelope,
"Worker frame payload is not a valid WorkerEnvelope protobuf message.",
exception);
}
WorkerEnvelopeValidator.Validate(envelope, _options);
return envelope;
}
private static uint ReadUInt32LittleEndian(byte[] buffer)
{
return (uint)buffer[0]
| ((uint)buffer[1] << 8)
| ((uint)buffer[2] << 16)
| ((uint)buffer[3] << 24);
}
private async Task ReadExactlyOrThrowAsync(
byte[] buffer,
CancellationToken cancellationToken)
{
int offset = 0;
while (offset < buffer.Length)
{
int bytesRead = await _stream
.ReadAsync(buffer, offset, buffer.Length - offset, cancellationToken)
.ConfigureAwait(false);
if (bytesRead == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.EndOfStream,
"Worker frame ended before the expected number of bytes were read.");
}
offset += bytesRead;
}
}
}
@@ -0,0 +1,76 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameWriter
{
private readonly WorkerFrameProtocolOptions _options;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private readonly Stream _stream;
public WorkerFrameWriter(
Stream stream,
WorkerFrameProtocolOptions options)
{
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
_options = options ?? throw new ArgumentNullException(nameof(options));
}
public async Task WriteAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken = default)
{
if (envelope is null)
{
throw new ArgumentNullException(nameof(envelope));
}
WorkerEnvelopeValidator.Validate(envelope, _options);
int payloadLength = envelope.CalculateSize();
if (payloadLength == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidEnvelope,
"Worker envelope cannot serialize to an empty payload.");
}
if (payloadLength > _options.MaxMessageBytes)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.MessageTooLarge,
$"Worker envelope payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes.");
}
byte[] payload = envelope.ToByteArray();
byte[] lengthPrefix = new byte[sizeof(uint)];
WriteUInt32LittleEndian(lengthPrefix, (uint)payloadLength);
await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await _stream.WriteAsync(lengthPrefix, 0, lengthPrefix.Length, cancellationToken).ConfigureAwait(false);
await _stream.WriteAsync(payload, 0, payload.Length, cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
finally
{
_writeLock.Release();
}
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,67 @@
using System;
using System.IO.Pipes;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerPipeClient : IWorkerPipeClient
{
public const int DefaultConnectTimeoutMilliseconds = 30000;
private readonly int _connectTimeoutMilliseconds;
public WorkerPipeClient()
: this(DefaultConnectTimeoutMilliseconds)
{
}
public WorkerPipeClient(int connectTimeoutMilliseconds)
{
if (connectTimeoutMilliseconds <= 0)
{
throw new ArgumentOutOfRangeException(
nameof(connectTimeoutMilliseconds),
"Worker pipe connect timeout must be greater than zero.");
}
_connectTimeoutMilliseconds = connectTimeoutMilliseconds;
}
public async Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
if (options is null)
{
throw new ArgumentNullException(nameof(options));
}
WorkerFrameProtocolOptions frameOptions = new(options);
using NamedPipeClientStream pipe = new(
".",
options.PipeName,
PipeDirection.InOut,
PipeOptions.Asynchronous);
await ConnectAsync(pipe, cancellationToken).ConfigureAwait(false);
WorkerPipeSession session = new(pipe, frameOptions);
await session.CompleteStartupHandshakeAsync(cancellationToken).ConfigureAwait(false);
}
private Task ConnectAsync(
NamedPipeClientStream pipe,
CancellationToken cancellationToken)
{
return Task.Run(
() =>
{
cancellationToken.ThrowIfCancellationRequested();
pipe.Connect(_connectTimeoutMilliseconds);
},
cancellationToken);
}
}
@@ -0,0 +1,218 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerPipeSession
{
private readonly WorkerFrameProtocolOptions _options;
private readonly Func<int> _processIdProvider;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private long _nextSequence;
public WorkerPipeSession(
Stream stream,
WorkerFrameProtocolOptions options)
: this(
new WorkerFrameReader(stream, options),
new WorkerFrameWriter(stream, options),
options,
() => Process.GetCurrentProcess().Id)
{
}
public WorkerPipeSession(
WorkerFrameReader reader,
WorkerFrameWriter writer,
WorkerFrameProtocolOptions options,
Func<int> processIdProvider)
{
_reader = reader ?? throw new ArgumentNullException(nameof(reader));
_writer = writer ?? throw new ArgumentNullException(nameof(writer));
_options = options ?? throw new ArgumentNullException(nameof(options));
_processIdProvider = processIdProvider ?? throw new ArgumentNullException(nameof(processIdProvider));
}
public Task CompleteStartupHandshakeAsync(CancellationToken cancellationToken = default)
{
return CompleteStartupHandshakeAsync(_ => Task.CompletedTask, cancellationToken);
}
public async Task CompleteStartupHandshakeAsync(
Func<CancellationToken, Task> initializeMxAccessAsync,
CancellationToken cancellationToken = default)
{
if (initializeMxAccessAsync is null)
{
throw new ArgumentNullException(nameof(initializeMxAccessAsync));
}
try
{
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
ValidateGatewayHello(envelope);
await WriteWorkerHelloAsync(cancellationToken).ConfigureAwait(false);
await initializeMxAccessAsync(cancellationToken).ConfigureAwait(false);
await WriteWorkerReadyAsync(cancellationToken).ConfigureAwait(false);
}
catch (WorkerFrameProtocolException exception)
{
await TryWriteFaultAsync(exception, cancellationToken).ConfigureAwait(false);
throw;
}
}
private void ValidateGatewayHello(WorkerEnvelope envelope)
{
if (envelope.BodyCase != WorkerEnvelope.BodyOneofCase.GatewayHello)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody,
"Worker expected GatewayHello during startup handshake.");
}
GatewayHello gatewayHello = envelope.GatewayHello;
if (gatewayHello.SupportedProtocolVersion != _options.ProtocolVersion)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
$"GatewayHello supported protocol version {gatewayHello.SupportedProtocolVersion} does not match expected version {_options.ProtocolVersion}.");
}
if (!string.Equals(gatewayHello.Nonce, _options.Nonce, StringComparison.Ordinal))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.NonceMismatch,
"GatewayHello nonce does not match the worker launch nonce.");
}
}
private Task WriteWorkerHelloAsync(CancellationToken cancellationToken)
{
return _writer.WriteAsync(
CreateEnvelope(new WorkerHello
{
ProtocolVersion = _options.ProtocolVersion,
Nonce = _options.Nonce,
WorkerProcessId = _processIdProvider(),
WorkerVersion = typeof(WorkerPipeSession).Assembly.GetName().Version?.ToString() ?? string.Empty,
}),
cancellationToken);
}
private Task WriteWorkerReadyAsync(CancellationToken cancellationToken)
{
return _writer.WriteAsync(
CreateEnvelope(new WorkerReady
{
WorkerProcessId = _processIdProvider(),
MxaccessProgid = MxAccessInteropInfo.ProgId,
MxaccessClsid = MxAccessInteropInfo.Clsid,
ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
}),
cancellationToken);
}
private async Task TryWriteFaultAsync(
WorkerFrameProtocolException exception,
CancellationToken cancellationToken)
{
try
{
await _writer
.WriteAsync(CreateEnvelope(CreateFault(exception)), cancellationToken)
.ConfigureAwait(false);
}
catch (Exception faultWriteException) when (
faultWriteException is IOException
|| faultWriteException is ObjectDisposedException
|| faultWriteException is WorkerFrameProtocolException)
{
// The original protocol failure is the actionable error.
}
}
private WorkerEnvelope CreateEnvelope(WorkerHello hello)
{
return CreateBaseEnvelope(hello);
}
private WorkerEnvelope CreateEnvelope(WorkerReady ready)
{
return CreateBaseEnvelope(ready);
}
private WorkerEnvelope CreateEnvelope(WorkerFault fault)
{
return CreateBaseEnvelope(fault);
}
private WorkerEnvelope CreateBaseEnvelope(WorkerHello body)
{
WorkerEnvelope envelope = CreateBaseEnvelope();
envelope.WorkerHello = body;
return envelope;
}
private WorkerEnvelope CreateBaseEnvelope(WorkerReady body)
{
WorkerEnvelope envelope = CreateBaseEnvelope();
envelope.WorkerReady = body;
return envelope;
}
private WorkerEnvelope CreateBaseEnvelope(WorkerFault body)
{
WorkerEnvelope envelope = CreateBaseEnvelope();
envelope.WorkerFault = body;
return envelope;
}
private WorkerEnvelope CreateBaseEnvelope()
{
return new WorkerEnvelope
{
ProtocolVersion = _options.ProtocolVersion,
SessionId = _options.SessionId,
Sequence = NextSequence(),
};
}
private ulong NextSequence()
{
return unchecked((ulong)Interlocked.Increment(ref _nextSequence));
}
private static WorkerFault CreateFault(WorkerFrameProtocolException exception)
{
return new WorkerFault
{
Category = MapFaultCategory(exception.ErrorCode),
ExceptionType = exception.GetType().FullName ?? string.Empty,
DiagnosticMessage = exception.Message,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.ProtocolViolation,
Message = exception.Message,
},
};
}
private static WorkerFaultCategory MapFaultCategory(WorkerFrameProtocolErrorCode errorCode)
{
return errorCode switch
{
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch => WorkerFaultCategory.ProtocolMismatch,
WorkerFrameProtocolErrorCode.EndOfStream => WorkerFaultCategory.PipeDisconnected,
_ => WorkerFaultCategory.ProtocolViolation,
};
}
}
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.Sta;
public interface IStaComApartmentInitializer
{
void Initialize();
void Uninitialize();
}
+8
View File
@@ -0,0 +1,8 @@
namespace MxGateway.Worker.Sta;
internal interface IStaWorkItem
{
void CancelBeforeExecution();
void Execute();
}
@@ -0,0 +1,31 @@
using System;
using System.Runtime.InteropServices;
namespace MxGateway.Worker.Sta;
public sealed class StaComApartmentInitializer : IStaComApartmentInitializer
{
private const uint CoInitializeApartmentThreaded = 0x2;
private const int SOk = 0;
private const int SFalse = 1;
public void Initialize()
{
int hresult = CoInitializeEx(IntPtr.Zero, CoInitializeApartmentThreaded);
if (hresult != SOk && hresult != SFalse)
{
throw new COMException("Failed to initialize the worker STA COM apartment.", hresult);
}
}
public void Uninitialize()
{
CoUninitialize();
}
[DllImport("ole32.dll")]
private static extern int CoInitializeEx(IntPtr reserved, uint coInit);
[DllImport("ole32.dll")]
private static extern void CoUninitialize();
}
+111
View File
@@ -0,0 +1,111 @@
using System;
using System.Runtime.InteropServices;
using System.Threading;
using Microsoft.Win32.SafeHandles;
namespace MxGateway.Worker.Sta;
public sealed class StaMessagePump
{
private const uint Infinite = 0xFFFFFFFF;
private const uint MsgWaitFailed = 0xFFFFFFFF;
private const uint MwmoInputAvailable = 0x0004;
private const uint PmRemove = 0x0001;
private const uint QsAllInput = 0x04FF;
public void WaitForWorkOrMessages(WaitHandle commandWakeEvent, TimeSpan timeout)
{
if (commandWakeEvent is null)
{
throw new ArgumentNullException(nameof(commandWakeEvent));
}
uint timeoutMilliseconds = ToTimeoutMilliseconds(timeout);
SafeWaitHandle safeHandle = commandWakeEvent.SafeWaitHandle;
IntPtr[] handles = [safeHandle.DangerousGetHandle()];
uint result = MsgWaitForMultipleObjectsEx(
(uint)handles.Length,
handles,
timeoutMilliseconds,
QsAllInput,
MwmoInputAvailable);
if (result == MsgWaitFailed)
{
throw new InvalidOperationException(
"The worker STA message pump failed while waiting for command work or Windows messages.");
}
}
public int PumpPendingMessages()
{
int pumpedMessages = 0;
while (PeekMessage(out NativeMessage message, IntPtr.Zero, 0, 0, PmRemove))
{
TranslateMessage(ref message);
DispatchMessage(ref message);
pumpedMessages++;
}
return pumpedMessages;
}
private static uint ToTimeoutMilliseconds(TimeSpan timeout)
{
if (timeout == Timeout.InfiniteTimeSpan)
{
return Infinite;
}
if (timeout <= TimeSpan.Zero)
{
return 0;
}
return timeout.TotalMilliseconds >= uint.MaxValue
? uint.MaxValue - 1
: (uint)Math.Ceiling(timeout.TotalMilliseconds);
}
[DllImport("user32.dll", SetLastError = true)]
private static extern uint MsgWaitForMultipleObjectsEx(
uint count,
IntPtr[] handles,
uint milliseconds,
uint wakeMask,
uint flags);
[DllImport("user32.dll", SetLastError = true)]
private static extern bool PeekMessage(
out NativeMessage message,
IntPtr windowHandle,
uint messageFilterMin,
uint messageFilterMax,
uint removeMessage);
[DllImport("user32.dll")]
private static extern bool TranslateMessage(ref NativeMessage message);
[DllImport("user32.dll")]
private static extern IntPtr DispatchMessage(ref NativeMessage message);
[StructLayout(LayoutKind.Sequential)]
private struct NativeMessage
{
public IntPtr WindowHandle;
public uint Message;
public UIntPtr WParam;
public IntPtr LParam;
public uint Time;
public NativePoint Point;
}
[StructLayout(LayoutKind.Sequential)]
private struct NativePoint
{
public int X;
public int Y;
}
}
+267
View File
@@ -0,0 +1,267 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
namespace MxGateway.Worker.Sta;
public sealed class StaRuntime : IDisposable
{
private readonly IStaComApartmentInitializer comApartmentInitializer;
private readonly StaMessagePump messagePump;
private readonly ConcurrentQueue<IStaWorkItem> commandQueue = new();
private readonly AutoResetEvent commandWakeEvent = new(false);
private readonly ManualResetEventSlim startedEvent = new(false);
private readonly ManualResetEventSlim stoppedEvent = new(false);
private readonly object gate = new();
private readonly Thread staThread;
private readonly TimeSpan idlePumpInterval;
private bool disposed;
private bool startRequested;
private bool shutdownRequested;
private Exception? startupException;
private long lastActivityUtcTicks;
private bool comInitialized;
public StaRuntime()
: this(new StaComApartmentInitializer(), new StaMessagePump(), TimeSpan.FromMilliseconds(50))
{
}
public StaRuntime(
IStaComApartmentInitializer comApartmentInitializer,
StaMessagePump messagePump,
TimeSpan idlePumpInterval)
{
this.comApartmentInitializer = comApartmentInitializer
?? throw new ArgumentNullException(nameof(comApartmentInitializer));
this.messagePump = messagePump ?? throw new ArgumentNullException(nameof(messagePump));
if (idlePumpInterval <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(
nameof(idlePumpInterval),
"The idle pump interval must be greater than zero.");
}
this.idlePumpInterval = idlePumpInterval;
lastActivityUtcTicks = DateTimeOffset.UtcNow.UtcTicks;
staThread = new Thread(ThreadMain)
{
IsBackground = true,
Name = "MxGateway.Worker.STA"
};
staThread.SetApartmentState(ApartmentState.STA);
}
public int? StaThreadId { get; private set; }
public DateTimeOffset LastActivityUtc =>
new(new DateTime(Volatile.Read(ref lastActivityUtcTicks), DateTimeKind.Utc));
public bool IsRunning => startedEvent.IsSet && !stoppedEvent.IsSet;
public void Start()
{
ThrowIfDisposed();
lock (gate)
{
if (shutdownRequested)
{
throw new InvalidOperationException("The worker STA runtime is shutting down.");
}
if (!startRequested)
{
startRequested = true;
staThread.Start();
}
}
startedEvent.Wait();
if (startupException is not null)
{
throw new InvalidOperationException(
"The worker STA runtime failed to initialize.",
startupException);
}
}
public Task InvokeAsync(Action command, CancellationToken cancellationToken = default)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
return InvokeAsync(
() =>
{
command();
return true;
},
cancellationToken);
}
public Task<T> InvokeAsync<T>(Func<T> command, CancellationToken cancellationToken = default)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
ThrowIfDisposed();
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled<T>(cancellationToken);
}
StaWorkItem<T> workItem = new(command, cancellationToken);
lock (gate)
{
if (shutdownRequested)
{
return Task.FromException<T>(
new InvalidOperationException("The worker STA runtime is shutting down."));
}
commandQueue.Enqueue(workItem);
}
commandWakeEvent.Set();
return workItem.Task;
}
public bool Shutdown(TimeSpan timeout)
{
if (timeout < TimeSpan.Zero && timeout != Timeout.InfiniteTimeSpan)
{
throw new ArgumentOutOfRangeException(nameof(timeout));
}
lock (gate)
{
shutdownRequested = true;
}
commandWakeEvent.Set();
if (!startedEvent.IsSet && !staThread.IsAlive)
{
CancelQueuedCommands();
stoppedEvent.Set();
return true;
}
bool stopped = stoppedEvent.Wait(timeout);
if (stopped)
{
CancelQueuedCommands();
}
return stopped;
}
public void Dispose()
{
if (disposed)
{
return;
}
bool stopped = Shutdown(TimeSpan.FromSeconds(5));
if (stopped)
{
commandWakeEvent.Dispose();
startedEvent.Dispose();
stoppedEvent.Dispose();
}
disposed = true;
}
private void ThreadMain()
{
try
{
StaThreadId = Thread.CurrentThread.ManagedThreadId;
comApartmentInitializer.Initialize();
comInitialized = true;
MarkActivity();
startedEvent.Set();
while (!IsShutdownRequested())
{
ProcessQueuedCommands();
messagePump.WaitForWorkOrMessages(commandWakeEvent, idlePumpInterval);
messagePump.PumpPendingMessages();
MarkActivity();
}
ProcessQueuedCommands();
}
catch (Exception exception)
{
startupException = exception;
startedEvent.Set();
}
finally
{
CancelQueuedCommands();
try
{
if (comInitialized)
{
comApartmentInitializer.Uninitialize();
}
}
finally
{
MarkActivity();
stoppedEvent.Set();
}
}
}
private void ProcessQueuedCommands()
{
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
{
MarkActivity();
workItem.Execute();
MarkActivity();
}
}
private void CancelQueuedCommands()
{
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
{
workItem.CancelBeforeExecution();
}
}
private bool IsShutdownRequested()
{
lock (gate)
{
return shutdownRequested;
}
}
private void MarkActivity()
{
Volatile.Write(ref lastActivityUtcTicks, DateTimeOffset.UtcNow.UtcTicks);
}
private void ThrowIfDisposed()
{
if (disposed)
{
throw new ObjectDisposedException(nameof(StaRuntime));
}
}
}
+71
View File
@@ -0,0 +1,71 @@
using System;
using System.Threading;
using System.Threading.Tasks;
namespace MxGateway.Worker.Sta;
internal sealed class StaWorkItem<T> : IStaWorkItem
{
private readonly Func<T> command;
private readonly CancellationToken cancellationToken;
private readonly CancellationTokenRegistration cancellationRegistration;
private int started;
public StaWorkItem(Func<T> command, CancellationToken cancellationToken)
{
this.command = command ?? throw new ArgumentNullException(nameof(command));
this.cancellationToken = cancellationToken;
Completion = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
if (cancellationToken.CanBeCanceled)
{
cancellationRegistration = cancellationToken.Register(
() =>
{
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
{
Completion.TrySetCanceled(cancellationToken);
}
});
}
}
public Task<T> Task => Completion.Task;
private TaskCompletionSource<T> Completion { get; }
public void CancelBeforeExecution()
{
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
{
Completion.TrySetCanceled(cancellationToken);
cancellationRegistration.Dispose();
}
}
public void Execute()
{
if (Interlocked.CompareExchange(ref started, 1, 0) != 0)
{
cancellationRegistration.Dispose();
return;
}
cancellationRegistration.Dispose();
if (cancellationToken.IsCancellationRequested)
{
Completion.TrySetCanceled(cancellationToken);
return;
}
try
{
Completion.TrySetResult(command());
}
catch (Exception exception)
{
Completion.TrySetException(exception);
}
}
}
+52 -1
View File
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.IO;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker;
@@ -11,13 +13,27 @@ public static class WorkerApplication
return Run(
args,
new EnvironmentVariableWorkerEnvironment(),
new WorkerConsoleLogger(Console.Error));
new WorkerConsoleLogger(Console.Error),
new WorkerPipeClient());
}
public static int Run(
string[] args,
IWorkerEnvironment environment,
IWorkerLogger logger)
{
return Run(
args,
environment,
logger,
new WorkerPipeClient());
}
public static int Run(
string[] args,
IWorkerEnvironment environment,
IWorkerLogger logger,
IWorkerPipeClient pipeClient)
{
if (args is null)
{
@@ -34,6 +50,11 @@ public static class WorkerApplication
throw new ArgumentNullException(nameof(logger));
}
if (pipeClient is null)
{
throw new ArgumentNullException(nameof(pipeClient));
}
try
{
WorkerOptionsParser parser = new(environment);
@@ -61,8 +82,38 @@ public static class WorkerApplication
["nonce"] = options.Nonce,
});
pipeClient.RunAsync(options).GetAwaiter().GetResult();
logger.Information("WorkerPipeHandshakeSucceeded", new Dictionary<string, object?>
{
["session_id"] = options.SessionId,
["pipe_name"] = options.PipeName,
["protocol_version"] = options.ProtocolVersion,
});
return (int)WorkerExitCode.Success;
}
catch (WorkerFrameProtocolException exception)
{
logger.Error("WorkerPipeProtocolFailure", new Dictionary<string, object?>
{
["exit_code"] = WorkerExitCode.ProtocolViolation,
["error_code"] = exception.ErrorCode,
["exception_type"] = exception.GetType().FullName,
});
return (int)WorkerExitCode.ProtocolViolation;
}
catch (Exception exception) when (exception is IOException or TimeoutException)
{
logger.Error("WorkerPipeConnectionFailed", new Dictionary<string, object?>
{
["exit_code"] = WorkerExitCode.PipeConnectionFailed,
["exception_type"] = exception.GetType().FullName,
});
return (int)WorkerExitCode.PipeConnectionFailed;
}
catch (Exception exception)
{
logger.Error("WorkerBootstrapUnexpectedFailure", new Dictionary<string, object?>