Compare commits

...

4 Commits

Author SHA1 Message Date
Joseph Doherty fce9e99553 Issue #11: implement gateway workerclient 2026-04-26 17:09:51 -04:00
dohertj2 c8fb3e91a3 Merge PR #63: Issue #8 add gRPC authentication and scope authorization
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:06:23 -04:00
dohertj2 8ce327e6f4 Merge PR #62: Issue #7 implement local api key admin cli
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:02:09 -04:00
Joseph Doherty fad0ac9948 Issue #8: add grpc authentication and scope authorization 2026-04-26 17:01:59 -04:00
20 changed files with 1789 additions and 5 deletions
+34 -2
View File
@@ -411,7 +411,7 @@ session ids as protocol faults and close the session.
`WorkerClient` is the gateway-side object that owns one worker connection.
Suggested public shape:
Current public shape:
```csharp
public interface IWorkerClient : IAsyncDisposable
@@ -419,6 +419,7 @@ public interface IWorkerClient : IAsyncDisposable
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
@@ -438,12 +439,17 @@ Internally it owns:
- pipe stream,
- read loop,
- write loop,
- bounded outbound command/control channel,
- outbound command/control channel serialized by the write loop,
- bounded inbound event channel,
- pending command dictionary keyed by correlation id,
- heartbeat monitor,
- terminal fault source.
`StartAsync` sends `GatewayHello`, verifies the `WorkerHello` protocol version
and nonce, waits for `WorkerReady`, and only then exposes `Ready` state. The
read loop starts after readiness so the handshake has a single owner for its
ordered frames.
### Read Loop
The read loop:
@@ -612,6 +618,15 @@ hashes the presented secret, and compares the stored and presented hashes with
results distinguish malformed credentials, missing keys, revoked keys, missing
pepper configuration, and hash mismatch for internal authorization handling.
`GatewayGrpcAuthorizationInterceptor` enforces this authentication model for
public gRPC calls. Missing, malformed, revoked, unknown, or mismatched keys fail
with `Unauthenticated`. Authenticated calls missing the scope required by the
RPC fail with `PermissionDenied`. The interceptor applies to unary calls and
server-streaming calls and stores the authenticated `ApiKeyIdentity` in
`IGatewayRequestIdentityAccessor` for the duration of the request handler.
`Authentication:Mode` set to `Disabled` bypasses API-key verification for local
development only.
Recommended scopes:
- `session:open`
@@ -677,6 +692,20 @@ Commands requiring authorization:
- worker shutdown diagnostics,
- metadata queries if they expose sensitive plant structure.
Current gRPC scope mapping:
- `OpenSession` requires `session:open`.
- `CloseSession` requires `session:close`.
- `StreamEvents` and `DrainEvents` require `events:read`.
- read-style MXAccess commands such as `Register`, `AddItem`, `Advise`, and
`Ping` require `invoke:read`.
- `Write` and `Write2` require `invoke:write`.
- `WriteSecured`, `WriteSecured2`, and `AuthenticateUser` require
`invoke:secure`.
- metadata commands such as `ArchestrAUserToId`, `GetSessionState`, and
`GetWorkerInfo` require `metadata:read`.
- `ShutdownWorker` requires `admin`.
### Worker IPC
Named pipes should be local only. Pipe ACLs should restrict access to:
@@ -819,6 +848,9 @@ workers and fake transports.
Focused tests:
- session state transitions,
- gRPC API-key authentication for unary and streaming calls,
- gRPC scope mapping for sessions, invokes, events, metadata, and admin
commands,
- worker startup failures,
- protocol version mismatch,
- malformed frame handling,
+7 -3
View File
@@ -566,9 +566,13 @@ Because each client owns one worker, a crash or leak affects only that session.
External gateway:
- use TLS for remote gRPC if crossing machine boundaries,
- authenticate clients with Windows auth, mTLS, or a deployment-specific token,
- authorize access to commands that can write, authenticate users, or alter
runtime state.
- authenticate v1 gRPC clients with `authorization: Bearer
mxgw_<key-id>_<secret>` API-key metadata,
- reject missing or invalid API keys with gRPC `Unauthenticated`,
- reject valid keys that lack the required session, invoke, event, metadata, or
admin scope with gRPC `PermissionDenied`,
- authorize access to commands that can write, authenticate users, expose
metadata, stream events, or alter runtime state.
Internal worker IPC:
@@ -3,6 +3,7 @@ using MxGateway.Server.Configuration;
using MxGateway.Server.Diagnostics;
using MxGateway.Server.Metrics;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
using MxGateway.Server.Workers;
namespace MxGateway.Server;
@@ -26,6 +27,7 @@ public static class GatewayApplication
builder.Services.AddGatewayConfiguration();
builder.Services.AddSqliteAuthStore();
builder.Services.AddGatewayGrpcAuthorization();
builder.Services.AddHealthChecks();
builder.Services.AddSingleton<GatewayMetrics>();
builder.Services.AddWorkerProcessLauncher();
@@ -5,6 +5,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Grpc.AspNetCore" Version="2.76.0" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="10.0.7" />
</ItemGroup>
@@ -0,0 +1,74 @@
using Grpc.Core;
using Grpc.Core.Interceptors;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayGrpcAuthorizationInterceptor(
IApiKeyVerifier apiKeyVerifier,
GatewayGrpcScopeResolver scopeResolver,
IGatewayRequestIdentityAccessor identityAccessor,
IOptions<GatewayOptions> options) : Interceptor
{
public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
TRequest request,
ServerCallContext context,
UnaryServerMethod<TRequest, TResponse> continuation)
{
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
using (identityScope)
{
return await continuation(request, context).ConfigureAwait(false);
}
}
public override async Task ServerStreamingServerHandler<TRequest, TResponse>(
TRequest request,
IServerStreamWriter<TResponse> responseStream,
ServerCallContext context,
ServerStreamingServerMethod<TRequest, TResponse> continuation)
{
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
using (identityScope)
{
await continuation(request, responseStream, context).ConfigureAwait(false);
}
}
private async Task<ApiKeyIdentity?> AuthenticateAndAuthorizeAsync<TRequest>(
TRequest request,
ServerCallContext context)
where TRequest : class
{
if (options.Value.Authentication.Mode == AuthenticationMode.Disabled)
{
return null;
}
string? authorizationHeader = context.RequestHeaders.GetValue("authorization");
ApiKeyVerificationResult verificationResult = await apiKeyVerifier
.VerifyAsync(authorizationHeader, context.CancellationToken)
.ConfigureAwait(false);
if (!verificationResult.Succeeded || verificationResult.Identity is null)
{
throw new RpcException(new Status(
StatusCode.Unauthenticated,
"Missing or invalid API key."));
}
string requiredScope = scopeResolver.ResolveRequiredScope(request);
if (!verificationResult.Identity.Scopes.Contains(requiredScope))
{
throw new RpcException(new Status(
StatusCode.PermissionDenied,
$"API key is missing required scope '{requiredScope}'."));
}
return verificationResult.Identity;
}
}
@@ -0,0 +1,40 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayGrpcScopeResolver
{
public string ResolveRequiredScope(object request)
{
return request switch
{
OpenSessionRequest => GatewayScopes.SessionOpen,
CloseSessionRequest => GatewayScopes.SessionClose,
StreamEventsRequest => GatewayScopes.EventsRead,
MxCommandRequest commandRequest => ResolveCommandScope(commandRequest.Command?.Kind ?? MxCommandKind.Unspecified),
_ => GatewayScopes.Admin
};
}
private static string ResolveCommandScope(MxCommandKind kind)
{
return kind switch
{
MxCommandKind.Write or
MxCommandKind.Write2 => GatewayScopes.InvokeWrite,
MxCommandKind.WriteSecured or
MxCommandKind.WriteSecured2 or
MxCommandKind.AuthenticateUser => GatewayScopes.InvokeSecure,
MxCommandKind.ArchestraUserToId or
MxCommandKind.GetSessionState or
MxCommandKind.GetWorkerInfo => GatewayScopes.MetadataRead,
MxCommandKind.DrainEvents => GatewayScopes.EventsRead,
MxCommandKind.ShutdownWorker => GatewayScopes.Admin,
_ => GatewayScopes.InvokeRead
};
}
}
@@ -0,0 +1,38 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayRequestIdentityAccessor : IGatewayRequestIdentityAccessor
{
private readonly AsyncLocal<ApiKeyIdentity?> currentIdentity = new();
public ApiKeyIdentity? Current => currentIdentity.Value;
public IDisposable Push(ApiKeyIdentity identity)
{
ArgumentNullException.ThrowIfNull(identity);
ApiKeyIdentity? previousIdentity = currentIdentity.Value;
currentIdentity.Value = identity;
return new IdentityScope(this, previousIdentity);
}
private sealed class IdentityScope(
GatewayRequestIdentityAccessor accessor,
ApiKeyIdentity? previousIdentity) : IDisposable
{
private bool disposed;
public void Dispose()
{
if (disposed)
{
return;
}
accessor.currentIdentity.Value = previousIdentity;
disposed = true;
}
}
}
@@ -0,0 +1,13 @@
namespace MxGateway.Server.Security.Authorization;
public static class GatewayScopes
{
public const string SessionOpen = "session:open";
public const string SessionClose = "session:close";
public const string InvokeRead = "invoke:read";
public const string InvokeWrite = "invoke:write";
public const string InvokeSecure = "invoke:secure";
public const string EventsRead = "events:read";
public const string MetadataRead = "metadata:read";
public const string Admin = "admin";
}
@@ -0,0 +1,16 @@
using Grpc.Core.Interceptors;
namespace MxGateway.Server.Security.Authorization;
public static class GrpcAuthorizationServiceCollectionExtensions
{
public static IServiceCollection AddGatewayGrpcAuthorization(this IServiceCollection services)
{
services.AddSingleton<GatewayGrpcScopeResolver>();
services.AddSingleton<IGatewayRequestIdentityAccessor, GatewayRequestIdentityAccessor>();
services.AddSingleton<GatewayGrpcAuthorizationInterceptor>();
services.AddGrpc(options => options.Interceptors.Add<GatewayGrpcAuthorizationInterceptor>());
return services;
}
}
@@ -0,0 +1,10 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public interface IGatewayRequestIdentityAccessor
{
ApiKeyIdentity? Current { get; }
IDisposable Push(ApiKeyIdentity identity);
}
@@ -0,0 +1,27 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Workers;
public interface IWorkerClient : IAsyncDisposable
{
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken);
IAsyncEnumerable<WorkerEvent> ReadEventsAsync(CancellationToken cancellationToken);
Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken);
void Kill(string reason);
}
@@ -0,0 +1,755 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Metrics;
namespace MxGateway.Server.Workers;
public sealed class WorkerClient : IWorkerClient
{
private const string GatewayVersionFallback = "unknown";
private readonly object _syncRoot = new();
private readonly WorkerClientConnection _connection;
private readonly WorkerClientOptions _options;
private readonly GatewayMetrics? _metrics;
private readonly TimeProvider _timeProvider;
private readonly ILogger<WorkerClient> _logger;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private readonly Channel<WorkerEnvelope> _outboundEnvelopes;
private readonly Channel<WorkerEvent> _events;
private readonly ConcurrentDictionary<string, PendingCommand> _pendingCommands = new(StringComparer.Ordinal);
private readonly CancellationTokenSource _stopCts = new();
private long _nextSequence;
private WorkerClientState _state;
private DateTimeOffset _lastHeartbeatAt;
private int? _processId;
private Task? _readLoopTask;
private Task? _writeLoopTask;
private Task? _heartbeatLoopTask;
private bool _disposed;
public WorkerClient(
WorkerClientConnection connection,
WorkerClientOptions? options = null,
GatewayMetrics? metrics = null,
TimeProvider? timeProvider = null,
ILogger<WorkerClient>? logger = null)
{
_connection = connection ?? throw new ArgumentNullException(nameof(connection));
_options = options ?? new WorkerClientOptions();
_metrics = metrics;
_timeProvider = timeProvider ?? TimeProvider.System;
_logger = logger ?? NullLogger<WorkerClient>.Instance;
_reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions);
_writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions);
_outboundEnvelopes = Channel.CreateUnbounded<WorkerEnvelope>(
new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = false,
AllowSynchronousContinuations = false,
});
_events = Channel.CreateBounded<WorkerEvent>(
new BoundedChannelOptions(_options.EventChannelCapacity)
{
SingleReader = false,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait,
AllowSynchronousContinuations = false,
});
_lastHeartbeatAt = _timeProvider.GetUtcNow();
}
public string SessionId => _connection.SessionId;
public int? ProcessId
{
get
{
lock (_syncRoot)
{
return _processId;
}
}
}
public WorkerClientState State
{
get
{
lock (_syncRoot)
{
return _state;
}
}
}
public DateTimeOffset LastHeartbeatAt
{
get
{
lock (_syncRoot)
{
return _lastHeartbeatAt;
}
}
}
public async Task StartAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
TransitionFromCreatedToHandshaking();
_writeLoopTask = Task.Run(WriteLoopAsync);
await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false);
WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerHello,
cancellationToken).ConfigureAwait(false);
ValidateWorkerHello(helloEnvelope.WorkerHello);
WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerReady,
cancellationToken).ConfigureAwait(false);
MarkReady(readyEnvelope.WorkerReady);
_readLoopTask = Task.Run(ReadLoopAsync);
_heartbeatLoopTask = Task.Run(HeartbeatLoopAsync);
}
public async Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(command);
ThrowIfDisposed();
EnsureReady();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero.");
}
string correlationId = Guid.NewGuid().ToString("N");
string method = GetCommandMethod(command);
PendingCommand pendingCommand = new(
correlationId,
method,
_timeProvider.GetTimestamp());
if (!_pendingCommands.TryAdd(correlationId, pendingCommand))
{
throw new InvalidOperationException("Generated a duplicate command correlation id.");
}
_metrics?.CommandStarted(method);
try
{
await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false);
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
Task timeoutTask = Task.Delay(timeout, timeoutCts.Token);
Task<WorkerCommandReply> replyTask = pendingCommand.Task;
Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false);
if (completedTask == replyTask)
{
await timeoutCts.CancelAsync().ConfigureAwait(false);
return await replyTask.ConfigureAwait(false);
}
if (cancellationToken.IsCancellationRequested)
{
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.GatewayShutdown,
"Command wait was canceled.");
cancellationToken.ThrowIfCancellationRequested();
}
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
throw new WorkerClientException(
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
}
catch
{
_pendingCommands.TryRemove(correlationId, out _);
throw;
}
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
yield return workerEvent;
}
}
public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero.");
}
WorkerClientState state = State;
if (state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
MarkClosing();
await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false);
_outboundEnvelopes.Writer.TryComplete();
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
try
{
await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false);
MarkClosed("shutdown");
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
SetFaulted(
WorkerClientErrorCode.ShutdownTimeout,
"Worker shutdown timed out.",
null);
throw new WorkerClientException(
WorkerClientErrorCode.ShutdownTimeout,
$"Worker shutdown timed out after {timeout}.");
}
}
public void Kill(string reason)
{
ThrowIfDisposed();
_connection.ProcessHandle?.Process.Kill(entireProcessTree: true);
_metrics?.WorkerKilled(reason);
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
$"Worker was killed by the gateway: {reason}.",
null);
}
public async ValueTask DisposeAsync()
{
if (_disposed)
{
return;
}
_disposed = true;
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
"Worker client was disposed."));
await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
await _connection.Stream.DisposeAsync().ConfigureAwait(false);
_connection.ProcessHandle?.Dispose();
_stopCts.Dispose();
}
private async Task WriteLoopAsync()
{
try
{
await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false))
{
await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.WriteFailed,
"Worker pipe write failed.",
exception);
}
}
private async Task ReadLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false);
await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream)
{
SetFaulted(
WorkerClientErrorCode.PipeDisconnected,
"Worker pipe disconnected.",
exception);
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker read loop failed.",
exception);
}
}
private async Task HeartbeatLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false);
if (State != WorkerClientState.Ready)
{
continue;
}
DateTimeOffset lastHeartbeatAt = LastHeartbeatAt;
DateTimeOffset now = _timeProvider.GetUtcNow();
if (now - lastHeartbeatAt <= _options.HeartbeatGrace)
{
continue;
}
_metrics?.HeartbeatFailed(SessionId);
SetFaulted(
WorkerClientErrorCode.HeartbeatExpired,
$"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.",
null);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
}
private async Task DispatchEnvelopeAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
switch (envelope.BodyCase)
{
case WorkerEnvelope.BodyOneofCase.WorkerCommandReply:
CompleteCommand(envelope);
break;
case WorkerEnvelope.BodyOneofCase.WorkerEvent:
await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false);
break;
case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat:
MarkHeartbeat(envelope.WorkerHeartbeat);
break;
case WorkerEnvelope.BodyOneofCase.WorkerFault:
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
CreateWorkerFaultMessage(envelope.WorkerFault),
null);
break;
case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck:
MarkClosed("worker-shutdown-ack");
break;
default:
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
$"Worker sent unexpected envelope body {envelope.BodyCase}.",
null);
break;
}
}
private async Task EnqueueWorkerEventAsync(
WorkerEvent workerEvent,
CancellationToken cancellationToken)
{
if (workerEvent.Event is not null)
{
_metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
}
if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false))
{
return;
}
if (!_events.Writer.TryWrite(workerEvent))
{
_metrics?.QueueOverflow("worker-events");
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker event channel rejected an event.",
null);
}
}
private void CompleteCommand(WorkerEnvelope envelope)
{
string correlationId = envelope.CorrelationId;
if (string.IsNullOrWhiteSpace(correlationId))
{
correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty;
}
if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand))
{
_logger.LogDebug(
"Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.",
SessionId,
correlationId);
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandSucceeded(pendingCommand.Method, duration);
pendingCommand.SetResult(envelope.WorkerCommandReply);
}
private void RemovePendingCommandAsFailed(
string correlationId,
PendingCommand pendingCommand,
WorkerClientErrorCode errorCode,
string message)
{
if (!_pendingCommands.TryRemove(correlationId, out _))
{
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration);
pendingCommand.SetException(new WorkerClientException(errorCode, message));
}
private async Task<WorkerEnvelope> ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase expectedBody,
CancellationToken cancellationToken)
{
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
if (envelope.BodyCase != expectedBody)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
$"Worker handshake expected {expectedBody} but received {envelope.BodyCase}.");
}
return envelope;
}
private void ValidateWorkerHello(WorkerHello workerHello)
{
if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello protocol version does not match the gateway protocol version.");
}
if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal))
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello nonce does not match the gateway nonce.");
}
lock (_syncRoot)
{
_processId = workerHello.WorkerProcessId == 0
? _connection.ProcessHandle?.ProcessId
: workerHello.WorkerProcessId;
}
}
private void MarkReady(WorkerReady ready)
{
lock (_syncRoot)
{
_processId = ready.WorkerProcessId == 0
? _processId ?? _connection.ProcessHandle?.ProcessId
: ready.WorkerProcessId;
_lastHeartbeatAt = _timeProvider.GetUtcNow();
_state = WorkerClientState.Ready;
}
DateTimeOffset readyAt = _timeProvider.GetUtcNow();
DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt;
_metrics?.WorkerStarted(readyAt - launchedAt);
}
private void MarkHeartbeat(WorkerHeartbeat heartbeat)
{
lock (_syncRoot)
{
_lastHeartbeatAt = _timeProvider.GetUtcNow();
if (heartbeat.WorkerProcessId != 0)
{
_processId = heartbeat.WorkerProcessId;
}
}
}
private void MarkClosing()
{
lock (_syncRoot)
{
if (_state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
_state = WorkerClientState.Closing;
}
}
private void MarkClosed(string reason)
{
lock (_syncRoot)
{
if (_state == WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Closed;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
$"Worker client closed because {reason}."));
_metrics?.WorkerStopped(reason);
}
private void SetFaulted(
WorkerClientErrorCode errorCode,
string message,
Exception? exception)
{
WorkerClientException fault = exception is null
? new WorkerClientException(errorCode, message)
: new WorkerClientException(errorCode, message, exception);
lock (_syncRoot)
{
if (_state is WorkerClientState.Faulted or WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Faulted;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete(fault);
_events.Writer.TryComplete(fault);
CompletePendingCommands(fault);
_metrics?.Fault(errorCode.ToString());
_logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message);
}
private void CompletePendingCommands(Exception exception)
{
foreach (KeyValuePair<string, PendingCommand> item in _pendingCommands.ToArray())
{
if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand))
{
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration);
pendingCommand.SetException(exception);
}
}
}
private void TransitionFromCreatedToHandshaking()
{
lock (_syncRoot)
{
if (_state != WorkerClientState.Created)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client cannot start from state {_state}.");
}
_state = WorkerClientState.Handshaking;
}
}
private void EnsureReady()
{
WorkerClientState state = State;
if (state != WorkerClientState.Ready)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client is not ready. Current state is {state}.");
}
}
private bool IsTerminalState()
{
WorkerClientState state = State;
return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted;
}
private async Task EnqueueAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
try
{
await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException exception)
{
throw new WorkerClientException(
WorkerClientErrorCode.WriteFailed,
"Worker outbound channel is closed.",
exception);
}
}
private WorkerEnvelope CreateGatewayHelloEnvelope()
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.GatewayHello = new GatewayHello
{
SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion,
Nonce = _connection.Nonce,
GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback,
});
}
private WorkerEnvelope CreateCommandEnvelope(
string correlationId,
WorkerCommand command)
{
return CreateEnvelope(
correlationId,
envelope => envelope.WorkerCommand = command.Clone());
}
private WorkerEnvelope CreateShutdownEnvelope(
TimeSpan timeout,
string reason)
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.WorkerShutdown = new WorkerShutdown
{
GracePeriod = Duration.FromTimeSpan(timeout),
Reason = reason,
});
}
private WorkerEnvelope CreateEnvelope(
string correlationId,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = _connection.FrameOptions.ProtocolVersion,
SessionId = SessionId,
Sequence = (ulong)Interlocked.Increment(ref _nextSequence),
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static string GetCommandMethod(WorkerCommand command)
{
return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString();
}
private static string CreateWorkerFaultMessage(WorkerFault fault)
{
return string.IsNullOrWhiteSpace(fault.DiagnosticMessage)
? $"Worker faulted with category {fault.Category}."
: $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}";
}
private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken)
{
Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask }
.Where(task => task is not null)
.Cast<Task>()
.ToArray();
if (tasks.Length == 0)
{
return;
}
await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
}
private void ThrowIfDisposed()
{
ObjectDisposedException.ThrowIf(_disposed, this);
}
private sealed class PendingCommand
{
private readonly TaskCompletionSource<WorkerCommandReply> _completion = new(TaskCreationOptions.RunContinuationsAsynchronously);
public PendingCommand(
string correlationId,
string method,
long startTimestamp)
{
CorrelationId = correlationId;
Method = method;
StartTimestamp = startTimestamp;
}
public string CorrelationId { get; }
public string Method { get; }
public long StartTimestamp { get; }
public Task<WorkerCommandReply> Task => _completion.Task;
public void SetResult(WorkerCommandReply reply)
{
_completion.TrySetResult(reply);
}
public void SetException(Exception exception)
{
_completion.TrySetException(exception);
}
}
}
@@ -0,0 +1,38 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientConnection
{
public WorkerClientConnection(
string sessionId,
string nonce,
Stream stream,
WorkerFrameProtocolOptions frameOptions,
WorkerProcessHandle? processHandle = null)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("Session id is required.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(nonce))
{
throw new ArgumentException("Worker nonce is required.", nameof(nonce));
}
SessionId = sessionId;
Nonce = nonce;
Stream = stream ?? throw new ArgumentNullException(nameof(stream));
FrameOptions = frameOptions ?? throw new ArgumentNullException(nameof(frameOptions));
ProcessHandle = processHandle;
}
public string SessionId { get; }
public string Nonce { get; }
public Stream Stream { get; }
public WorkerFrameProtocolOptions FrameOptions { get; }
public WorkerProcessHandle? ProcessHandle { get; }
}
@@ -0,0 +1,14 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientErrorCode
{
InvalidState,
ProtocolViolation,
PipeDisconnected,
CommandTimeout,
WorkerFaulted,
HeartbeatExpired,
ShutdownTimeout,
GatewayShutdown,
WriteFailed,
}
@@ -0,0 +1,23 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientException : Exception
{
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerClientErrorCode ErrorCode { get; }
}
@@ -0,0 +1,24 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientOptions
{
public static readonly TimeSpan DefaultHeartbeatGrace = TimeSpan.FromSeconds(15);
public static readonly TimeSpan DefaultHeartbeatCheckInterval = TimeSpan.FromSeconds(1);
public static readonly TimeSpan DefaultEventChannelFullModeTimeout = TimeSpan.FromSeconds(5);
public WorkerClientOptions()
{
HeartbeatGrace = DefaultHeartbeatGrace;
HeartbeatCheckInterval = DefaultHeartbeatCheckInterval;
EventChannelCapacity = 1_024;
EventChannelFullModeTimeout = DefaultEventChannelFullModeTimeout;
}
public TimeSpan HeartbeatGrace { get; init; }
public TimeSpan HeartbeatCheckInterval { get; init; }
public int EventChannelCapacity { get; init; }
public TimeSpan EventChannelFullModeTimeout { get; init; }
}
@@ -0,0 +1,11 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientState
{
Created,
Handshaking,
Ready,
Closing,
Closed,
Faulted,
}
@@ -0,0 +1,341 @@
using System.IO.Pipes;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Workers;
public sealed class WorkerClientTests
{
private const string SessionId = "session-worker-client";
private const string Nonce = "nonce-worker-client";
private const int WorkerProcessId = 4321;
private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5);
[Fact]
public async Task StartAsync_WithWorkerHelloAndReady_EntersReadyState()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(WorkerProcessId, client.ProcessId);
}
[Fact]
public async Task InvokeAsync_WithMatchingReply_CompletesPendingCommand()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> invokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TestTimeout,
CancellationToken.None);
WorkerEnvelope commandEnvelope = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerCommand, commandEnvelope.BodyCase);
Assert.False(string.IsNullOrWhiteSpace(commandEnvelope.CorrelationId));
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(commandEnvelope.CorrelationId, MxCommandKind.Ping));
WorkerCommandReply reply = await invokeTask.WaitAsync(TestTimeout);
Assert.Equal(commandEnvelope.CorrelationId, reply.Reply.CorrelationId);
Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
}
[Fact]
public async Task InvokeAsync_WithLateReply_IgnoresLateReplyAndKeepsClientReady()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> timedOutInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TimeSpan.FromMilliseconds(50),
CancellationToken.None);
WorkerEnvelope timedOutCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
WorkerClientException exception = await Assert.ThrowsAsync<WorkerClientException>(
async () => await timedOutInvokeTask);
Assert.Equal(WorkerClientErrorCode.CommandTimeout, exception.ErrorCode);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(timedOutCommand.CorrelationId, MxCommandKind.Ping));
await Task.Delay(TimeSpan.FromMilliseconds(50));
Task<WorkerCommandReply> secondInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.GetWorkerInfo),
TestTimeout,
CancellationToken.None);
WorkerEnvelope secondCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(secondCommand.CorrelationId, MxCommandKind.GetWorkerInfo));
WorkerCommandReply reply = await secondInvokeTask.WaitAsync(TestTimeout);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(MxCommandKind.GetWorkerInfo, reply.Reply.Kind);
}
[Fact]
public async Task ReadEventsAsync_WithWorkerEvents_YieldsEventsInPipeOrder()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
using CancellationTokenSource cancellationTokenSource = new(TestTimeout);
await using IAsyncEnumerator<WorkerEvent> events =
client.ReadEventsAsync(cancellationTokenSource.Token).GetAsyncEnumerator(cancellationTokenSource.Token);
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange));
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 12, MxEventFamily.OperationComplete));
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)11, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OnDataChange, events.Current.Event.Family);
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)12, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family);
}
[Fact]
public async Task ReadLoop_WhenPipeDisconnects_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
await pipePair.DisposeWorkerSideAsync();
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
[Fact]
public async Task HeartbeatMonitor_WhenHeartbeatExpires_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(
pipePair,
new WorkerClientOptions
{
HeartbeatGrace = TimeSpan.FromMilliseconds(80),
HeartbeatCheckInterval = TimeSpan.FromMilliseconds(20),
EventChannelCapacity = 8,
});
await CompleteHandshakeAsync(client, pipePair);
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
private static WorkerClient CreateClient(
PipePair pipePair,
WorkerClientOptions? options = null)
{
WorkerFrameProtocolOptions frameOptions = new(SessionId);
WorkerClientConnection connection = new(
SessionId,
Nonce,
pipePair.GatewayStream,
frameOptions);
return new WorkerClient(connection, options);
}
private static async Task CompleteHandshakeAsync(
WorkerClient client,
PipePair pipePair)
{
Task startTask = client.StartAsync(CancellationToken.None);
WorkerEnvelope gatewayHello = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.GatewayHello, gatewayHello.BodyCase);
Assert.Equal(Nonce, gatewayHello.GatewayHello.Nonce);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, gatewayHello.GatewayHello.SupportedProtocolVersion);
await pipePair.WorkerWriter.WriteAsync(CreateWorkerHelloEnvelope());
await pipePair.WorkerWriter.WriteAsync(CreateWorkerReadyEnvelope());
await startTask.WaitAsync(TestTimeout);
}
private static WorkerCommand CreateCommand(MxCommandKind kind)
{
return new WorkerCommand
{
Command = new MxCommand
{
Kind = kind,
},
};
}
private static WorkerEnvelope CreateWorkerHelloEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 1,
envelope => envelope.WorkerHello = new WorkerHello
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
WorkerProcessId = WorkerProcessId,
WorkerVersion = "fake-worker",
});
}
private static WorkerEnvelope CreateWorkerReadyEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 2,
envelope => envelope.WorkerReady = new WorkerReady
{
WorkerProcessId = WorkerProcessId,
MxaccessProgid = "LMXProxy.LMXProxyServer.1",
MxaccessClsid = "{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}",
});
}
private static WorkerEnvelope CreateCommandReplyEnvelope(
string correlationId,
MxCommandKind kind)
{
return CreateWorkerEnvelope(
correlationId,
sequence: 10,
envelope => envelope.WorkerCommandReply = new WorkerCommandReply
{
Reply = new MxCommandReply
{
SessionId = SessionId,
CorrelationId = correlationId,
Kind = kind,
},
});
}
private static WorkerEnvelope CreateEventEnvelope(
ulong sequence,
MxEventFamily family)
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence,
envelope => envelope.WorkerEvent = new WorkerEvent
{
Event = new MxEvent
{
SessionId = SessionId,
Family = family,
WorkerSequence = sequence,
},
});
}
private static WorkerEnvelope CreateWorkerEnvelope(
string correlationId,
ulong sequence,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static async Task WaitUntilAsync(
Func<bool> predicate,
TimeSpan timeout)
{
using CancellationTokenSource cancellationTokenSource = new(timeout);
while (!predicate())
{
await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token);
}
}
private sealed class PipePair : IAsyncDisposable
{
private readonly NamedPipeClientStream _workerStream;
private bool _workerSideDisposed;
private PipePair(
NamedPipeServerStream gatewayStream,
NamedPipeClientStream workerStream)
{
GatewayStream = gatewayStream;
_workerStream = workerStream;
WorkerReader = new WorkerFrameReader(_workerStream, new WorkerFrameProtocolOptions(SessionId));
WorkerWriter = new WorkerFrameWriter(_workerStream, new WorkerFrameProtocolOptions(SessionId));
}
public NamedPipeServerStream GatewayStream { get; }
public WorkerFrameReader WorkerReader { get; }
public WorkerFrameWriter WorkerWriter { get; }
public static async Task<PipePair> CreateAsync()
{
string pipeName = $"mxaccessgw-workerclient-tests-{Guid.NewGuid():N}";
NamedPipeServerStream gatewayStream = new(
pipeName,
PipeDirection.InOut,
maxNumberOfServerInstances: 1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
NamedPipeClientStream workerStream = new(
".",
pipeName,
PipeDirection.InOut,
PipeOptions.Asynchronous);
Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync();
await workerStream.ConnectAsync();
await waitForConnectionTask;
return new PipePair(gatewayStream, workerStream);
}
public async ValueTask DisposeWorkerSideAsync()
{
if (_workerSideDisposed)
{
return;
}
await _workerStream.DisposeAsync();
_workerSideDisposed = true;
}
public async ValueTask DisposeAsync()
{
await DisposeWorkerSideAsync();
await GatewayStream.DisposeAsync();
}
}
}
@@ -0,0 +1,267 @@
using Grpc.Core;
using Microsoft.Extensions.Options;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
namespace MxGateway.Tests.Security.Authorization;
public sealed class GatewayGrpcAuthorizationInterceptorTests
{
[Fact]
public async Task UnaryServerHandler_MissingApiKey_ReturnsUnauthenticated()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(
ApiKeyVerificationFailure.MissingOrMalformedCredentials)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
new TestServerCallContext([]),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
Assert.DoesNotContain("secret", exception.Status.Detail, StringComparison.OrdinalIgnoreCase);
}
[Fact]
public async Task UnaryServerHandler_InvalidApiKey_DoesNotExposeRawCredentialInStatus()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.SecretMismatch)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_super-secret"),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
Assert.DoesNotContain("super-secret", exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task UnaryServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
Assert.Contains(GatewayScopes.SessionOpen, exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task UnaryServerHandler_ValidApiKeyWithScope_SetsRequestIdentity()
{
GatewayRequestIdentityAccessor identityAccessor = new();
ApiKeyIdentity? identitySeenByHandler = null;
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
identityAccessor);
OpenSessionReply reply = await interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _) =>
{
identitySeenByHandler = identityAccessor.Current;
return Task.FromResult(new OpenSessionReply { SessionId = "session-1" });
});
Assert.Equal("session-1", reply.SessionId);
Assert.NotNull(identitySeenByHandler);
Assert.Equal("operator01", identitySeenByHandler.KeyId);
Assert.Null(identityAccessor.Current);
}
[Fact]
public async Task ServerStreamingServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.ServerStreamingServerHandler(
new StreamEventsRequest(),
new TestServerStreamWriter<MxEvent>(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _, _) => Task.CompletedTask));
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
Assert.Contains(GatewayScopes.EventsRead, exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task ServerStreamingServerHandler_ValidApiKeyWithScope_AllowsStream()
{
GatewayRequestIdentityAccessor identityAccessor = new();
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
identityAccessor);
TestServerStreamWriter<MxEvent> streamWriter = new();
await interceptor.ServerStreamingServerHandler(
new StreamEventsRequest(),
streamWriter,
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
async (_, writer, _) =>
{
Assert.Equal("operator01", identityAccessor.Current?.KeyId);
await writer.WriteAsync(new MxEvent { SessionId = "session-1" });
});
MxEvent eventMessage = Assert.Single(streamWriter.Messages);
Assert.Equal("session-1", eventMessage.SessionId);
Assert.Null(identityAccessor.Current);
}
[Fact]
public async Task UnaryServerHandler_AuthenticationDisabled_SkipsApiKeyVerification()
{
GatewayRequestIdentityAccessor identityAccessor = new();
FakeApiKeyVerifier verifier = new(ApiKeyVerificationResult.Fail(
ApiKeyVerificationFailure.MissingOrMalformedCredentials));
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
verifier,
identityAccessor,
AuthenticationMode.Disabled);
OpenSessionReply reply = await interceptor.UnaryServerHandler(
new OpenSessionRequest(),
new TestServerCallContext([]),
(_, _) => Task.FromResult(new OpenSessionReply { SessionId = "session-1" }));
Assert.Equal("session-1", reply.SessionId);
Assert.False(verifier.WasCalled);
Assert.Null(identityAccessor.Current);
}
private static GatewayGrpcAuthorizationInterceptor CreateInterceptor(
IApiKeyVerifier apiKeyVerifier,
IGatewayRequestIdentityAccessor identityAccessor,
AuthenticationMode authenticationMode = AuthenticationMode.ApiKey)
{
return new GatewayGrpcAuthorizationInterceptor(
apiKeyVerifier,
new GatewayGrpcScopeResolver(),
identityAccessor,
Options.Create(new GatewayOptions
{
Authentication = new AuthenticationOptions
{
Mode = authenticationMode
}
}));
}
private static ApiKeyVerificationResult SuccessWithScopes(params string[] scopes)
{
return ApiKeyVerificationResult.Success(new ApiKeyIdentity(
KeyId: "operator01",
KeyPrefix: "mxgw_operator01",
DisplayName: "Operator Key",
Scopes: new HashSet<string>(scopes, StringComparer.Ordinal)));
}
private static TestServerCallContext ContextWithAuthorization(string authorizationHeader)
{
return new TestServerCallContext([new Metadata.Entry("authorization", authorizationHeader)]);
}
private sealed class FakeApiKeyVerifier(ApiKeyVerificationResult result) : IApiKeyVerifier
{
public bool WasCalled { get; private set; }
public string? LastAuthorizationHeader { get; private set; }
public Task<ApiKeyVerificationResult> VerifyAsync(
string? authorizationHeader,
CancellationToken cancellationToken)
{
WasCalled = true;
LastAuthorizationHeader = authorizationHeader;
return Task.FromResult(result);
}
}
private sealed class TestServerStreamWriter<T> : IServerStreamWriter<T>
{
public List<T> Messages { get; } = [];
public WriteOptions? WriteOptions { get; set; }
public Task WriteAsync(T message)
{
Messages.Add(message);
return Task.CompletedTask;
}
}
private sealed class TestServerCallContext(
Metadata requestHeaders,
CancellationToken cancellationToken = default) : ServerCallContext
{
private readonly Metadata responseTrailers = [];
private readonly Dictionary<object, object> userState = [];
private Status status;
private WriteOptions? writeOptions;
protected override string MethodCore => "/mxaccess_gateway.v1.MxAccessGateway/Test";
protected override string HostCore => "localhost";
protected override string PeerCore => "ipv4:127.0.0.1:5000";
protected override DateTime DeadlineCore => DateTime.UtcNow.AddMinutes(1);
protected override Metadata RequestHeadersCore => requestHeaders;
protected override CancellationToken CancellationTokenCore => cancellationToken;
protected override Metadata ResponseTrailersCore => responseTrailers;
protected override Status StatusCore
{
get => status;
set => status = value;
}
protected override WriteOptions? WriteOptionsCore
{
get => writeOptions;
set => writeOptions = value;
}
protected override AuthContext AuthContextCore { get; } = new(
string.Empty,
new Dictionary<string, List<AuthProperty>>(StringComparer.Ordinal));
protected override IDictionary<object, object> UserStateCore => userState;
protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
{
return Task.CompletedTask;
}
protected override ContextPropagationToken CreatePropagationTokenCore(
ContextPropagationOptions? options)
{
throw new NotSupportedException();
}
}
}
@@ -0,0 +1,54 @@
using MxGateway.Contracts.Proto;
using MxGateway.Server.Security.Authorization;
namespace MxGateway.Tests.Security.Authorization;
public sealed class GatewayGrpcScopeResolverTests
{
[Theory]
[InlineData(typeof(OpenSessionRequest), GatewayScopes.SessionOpen)]
[InlineData(typeof(CloseSessionRequest), GatewayScopes.SessionClose)]
[InlineData(typeof(StreamEventsRequest), GatewayScopes.EventsRead)]
public void ResolveRequiredScope_KnownRpcRequest_ReturnsExpectedScope(
Type requestType,
string expectedScope)
{
GatewayGrpcScopeResolver resolver = new();
object request = Activator.CreateInstance(requestType)!;
string scope = resolver.ResolveRequiredScope(request);
Assert.Equal(expectedScope, scope);
}
[Theory]
[InlineData(MxCommandKind.Register, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.AddItem, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.Advise, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.Write, GatewayScopes.InvokeWrite)]
[InlineData(MxCommandKind.Write2, GatewayScopes.InvokeWrite)]
[InlineData(MxCommandKind.WriteSecured, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.WriteSecured2, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.AuthenticateUser, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.ArchestraUserToId, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.GetSessionState, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.GetWorkerInfo, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.DrainEvents, GatewayScopes.EventsRead)]
[InlineData(MxCommandKind.ShutdownWorker, GatewayScopes.Admin)]
public void ResolveRequiredScope_InvokeCommand_ReturnsExpectedScope(
MxCommandKind commandKind,
string expectedScope)
{
GatewayGrpcScopeResolver resolver = new();
string scope = resolver.ResolveRequiredScope(new MxCommandRequest
{
Command = new MxCommand
{
Kind = commandKind
}
});
Assert.Equal(expectedScope, scope);
}
}