Compare commits

..

20 Commits

Author SHA1 Message Date
Joseph Doherty 6559672fc1 Issue #30: implement value conversion 2026-04-26 17:26:36 -04:00
dohertj2 97c30b9d00 Merge PR #66: Issue #23 implement STA runtime and message pump
Verified with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 17:23:02 -04:00
dohertj2 603aff7004 Merge PR #65: Issue #22 implement pipe client and frame protocol
Verified with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 17:20:28 -04:00
Joseph Doherty e81682e367 Issue #23: implement sta runtime and message pump 2026-04-26 17:19:00 -04:00
Joseph Doherty d5a982152b Issue #22: implement pipe client and frame protocol 2026-04-26 17:16:49 -04:00
dohertj2 0b0be7098e Merge PR #64: Issue #11 implement gateway WorkerClient
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:14:03 -04:00
Joseph Doherty fce9e99553 Issue #11: implement gateway workerclient 2026-04-26 17:09:51 -04:00
dohertj2 c8fb3e91a3 Merge PR #63: Issue #8 add gRPC authentication and scope authorization
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:06:23 -04:00
dohertj2 8ce327e6f4 Merge PR #62: Issue #7 implement local api key admin cli
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:02:09 -04:00
Joseph Doherty fad0ac9948 Issue #8: add grpc authentication and scope authorization 2026-04-26 17:01:59 -04:00
dohertj2 9cb2f1c5cd Merge PR #61: Issue #21 implement worker bootstrap and options
Verified with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 16:56:52 -04:00
Joseph Doherty da9ffe0e11 Issue #7: implement local api key admin cli 2026-04-26 16:56:12 -04:00
Joseph Doherty 0af1427859 Issue #21: implement worker bootstrap and options 2026-04-26 16:53:06 -04:00
dohertj2 e2b4dfcb32 Merge PR #60: Issue #10 implement worker process launcher
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 16:50:00 -04:00
dohertj2 3b3e41acf4 Merge PR #59: Issue #6 implement API key hashing and verification
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 16:46:06 -04:00
Joseph Doherty c1188c6957 Issue #10: implement worker process launcher 2026-04-26 16:45:42 -04:00
dohertj2 4094e64ee0 Merge PR #58: Issue #20 scaffold worker project
Verified with dotnet build src\\MxGateway.sln, dotnet test src\\MxGateway.Worker.Tests\\MxGateway.Worker.Tests.csproj -p:Platform=x86, and dotnet test src\\MxGateway.sln --no-build.
2026-04-26 16:41:34 -04:00
Joseph Doherty 696be17139 Issue #6: implement api key hashing and verification 2026-04-26 16:40:46 -04:00
Joseph Doherty b42c3c8b3b Issue #20: scaffold worker project 2026-04-26 16:37:23 -04:00
dohertj2 420a813967 Merge PR #56: Issue #5 implement SQLite auth store and migrations
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 16:34:28 -04:00
120 changed files with 7759 additions and 28 deletions
+62
View File
@@ -0,0 +1,62 @@
# Worker Process Launcher
The gateway uses `WorkerProcessLauncher` to validate and start one worker
process for a gateway session. The launcher owns process start semantics only;
pipe handshaking and `WorkerReady` validation remain part of the worker client
startup path.
## Launch Inputs
`WorkerProcessLaunchRequest` carries the per-session bootstrap values:
- `SessionId`,
- `PipeName`,
- `ProtocolVersion`,
- `Nonce`,
- optional `PipeReservation` cleanup handle.
The launcher passes `SessionId`, `PipeName`, and `ProtocolVersion` as command
line arguments:
```text
--session-id <sessionId> --pipe-name <pipeName> --protocol-version <version>
```
The launcher sets the nonce through the `MXGATEWAY_WORKER_NONCE` environment
variable. The nonce is not included in `WorkerProcessCommandLine` so logs and
diagnostics can report the launch command without exposing the secret.
## Validation And Cleanup
Before starting the process, the launcher validates that the configured worker
path exists, has a `.exe` extension, contains a valid Windows Portable
Executable header, and matches the configured `RequiredArchitecture`.
After the process starts, `IWorkerStartupProbe` waits for startup readiness.
The default probe only verifies that the worker did not exit immediately. The
worker client replaces this probe when pipe connection, hello, and
`WorkerReady` handling are implemented.
If startup fails or exceeds `WorkerOptions.StartupTimeoutSeconds`, the launcher
kills the worker process tree, disposes the process handle, disposes the
optional pipe reservation, records a worker kill metric, and reports a
`WorkerProcessLaunchException`.
## Verification
Run the focused launcher tests after changing process launch behavior:
```bash
dotnet test src/MxGateway.Tests/MxGateway.Tests.csproj --filter WorkerProcessLauncherTests
```
Run the gateway build because the launcher is part of `MxGateway.Server`:
```bash
dotnet build src/MxGateway.Server/MxGateway.Server.csproj
```
## Related Documentation
- [Gateway Process Detailed Design](./gateway-process-design.md)
- [Worker Frame Protocol](./WorkerFrameProtocol.md)
+74 -2
View File
@@ -360,6 +360,15 @@ Before launch, validate:
- worker file version or product version is acceptable,
- worker is expected to be x86.
`WorkerProcessLauncher` implements the first validation layer now: it resolves
the worker executable path, requires a `.exe`, validates the Windows Portable
Executable header, and verifies the configured processor architecture. It passes
only `--session-id`, `--pipe-name`, and `--protocol-version` on the command
line. The per-session nonce is set through `MXGATEWAY_WORKER_NONCE` so the
command line remains safe to log. Startup failures and startup timeouts kill and
dispose the worker process and the pre-created pipe reservation before the
session manager observes the failure.
## Worker IPC
The gateway creates the pipe server before launching the worker.
@@ -402,7 +411,7 @@ session ids as protocol faults and close the session.
`WorkerClient` is the gateway-side object that owns one worker connection.
Suggested public shape:
Current public shape:
```csharp
public interface IWorkerClient : IAsyncDisposable
@@ -410,6 +419,7 @@ public interface IWorkerClient : IAsyncDisposable
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
@@ -429,12 +439,17 @@ Internally it owns:
- pipe stream,
- read loop,
- write loop,
- bounded outbound command/control channel,
- outbound command/control channel serialized by the write loop,
- bounded inbound event channel,
- pending command dictionary keyed by correlation id,
- heartbeat monitor,
- terminal fault source.
`StartAsync` sends `GatewayHello`, verifies the `WorkerHello` protocol version
and nonce, waits for `WorkerReady`, and only then exposes `Ready` state. The
read loop starts after readiness so the handshake has a single owner for its
ordered frames.
### Read Loop
The read loop:
@@ -589,6 +604,29 @@ The gateway should split the key into a stable key id and secret component,
load the key record by id, hash the presented secret, and compare using a
constant-time comparison.
`ApiKeyParser` accepts only `authorization: Bearer mxgw_<key-id>_<secret>`.
Malformed headers fail before any database lookup. The parsed raw secret is
kept only long enough for `ApiKeySecretHasher` to compute an HMAC-SHA256 hash
using the configured `Authentication:PepperSecretName` lookup in application
configuration. The raw secret is not stored in the auth database, identity
model, logs, or verification result.
`ApiKeyVerifier` loads the stored key record by key id, rejects revoked keys,
hashes the presented secret, and compares the stored and presented hashes with
`CryptographicOperations.FixedTimeEquals`. A successful verification returns an
`ApiKeyIdentity` with key id, key prefix, display name, and scopes. Failure
results distinguish malformed credentials, missing keys, revoked keys, missing
pepper configuration, and hash mismatch for internal authorization handling.
`GatewayGrpcAuthorizationInterceptor` enforces this authentication model for
public gRPC calls. Missing, malformed, revoked, unknown, or mismatched keys fail
with `Unauthenticated`. Authenticated calls missing the scope required by the
RPC fail with `PermissionDenied`. The interceptor applies to unary calls and
server-streaming calls and stores the authenticated `ApiKeyIdentity` in
`IGatewayRequestIdentityAccessor` for the duration of the request handler.
`Authentication:Mode` set to `Disabled` bypasses API-key verification for local
development only.
Recommended scopes:
- `session:open`
@@ -608,6 +646,23 @@ gRPC admin API. It should initialize the auth database, create keys, list keys
without secrets, revoke keys, rotate keys, and print raw secrets only once at
creation.
`MxGateway.Server` exposes local API-key administration as an `apikey`
subcommand before the web host starts:
```bash
MxGateway.Server apikey init-db --sqlite-path C:\ProgramData\MxGateway\gateway-auth.db
MxGateway.Server apikey create-key --key-id operator01 --display-name Operator --scopes session:open,events:read
MxGateway.Server apikey list-keys --json
MxGateway.Server apikey revoke-key --key-id operator01
MxGateway.Server apikey rotate-key --key-id operator01 --json
```
The subcommands accept `--sqlite-path`, `--pepper`, and `--json`. `--pepper`
sets the local `MxGateway:ApiKeyPepper` configuration value for the command
process; deployments should normally provide the pepper through the configured
secret source. `create-key` and `rotate-key` print the full raw API key exactly
once. `list-keys` never prints raw secrets or `secret_hash` values.
SQLite auth storage should use startup migrations with a `schema_version` table.
Migrations should run inside transactions and fail startup if the database
schema is newer than the running binary understands.
@@ -637,6 +692,20 @@ Commands requiring authorization:
- worker shutdown diagnostics,
- metadata queries if they expose sensitive plant structure.
Current gRPC scope mapping:
- `OpenSession` requires `session:open`.
- `CloseSession` requires `session:close`.
- `StreamEvents` and `DrainEvents` require `events:read`.
- read-style MXAccess commands such as `Register`, `AddItem`, `Advise`, and
`Ping` require `invoke:read`.
- `Write` and `Write2` require `invoke:write`.
- `WriteSecured`, `WriteSecured2`, and `AuthenticateUser` require
`invoke:secure`.
- metadata commands such as `ArchestrAUserToId`, `GetSessionState`, and
`GetWorkerInfo` require `metadata:read`.
- `ShutdownWorker` requires `admin`.
### Worker IPC
Named pipes should be local only. Pipe ACLs should restrict access to:
@@ -779,6 +848,9 @@ workers and fake transports.
Focused tests:
- session state transitions,
- gRPC API-key authentication for unary and streaming calls,
- gRPC scope mapping for sessions, invokes, events, metadata, and admin
commands,
- worker startup failures,
- protocol version mismatch,
- malformed frame handling,
+53
View File
@@ -26,6 +26,33 @@ Style guides:
- [C# Style Guide](./style-guides/CSharpStyleGuide.md)
- [Protobuf Style Guide](./style-guides/ProtobufStyleGuide.md)
## Build And Test
Build the SDK-style worker project with the .NET SDK MSBuild entry point. The
project targets .NET Framework 4.8, but the SDK resolver comes from the .NET SDK
installation:
```powershell
dotnet msbuild src\MxGateway.Worker\MxGateway.Worker.csproj /restore /p:Configuration=Debug /p:Platform=x86
```
`docs/toolchain-links.md` records the Visual Studio MSBuild executable for
classic .NET Framework and COM interop builds:
```powershell
& "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\MSBuild\Current\Bin\MSBuild.exe" src\MxGateway.Worker\MxGateway.Worker.csproj /p:Configuration=Debug /p:Platform=x86
```
Run the worker tests with the same platform target:
```powershell
dotnet test src\MxGateway.Worker.Tests\MxGateway.Worker.Tests.csproj -p:Platform=x86
```
The only MXAccess interop reference belongs in `MxGateway.Worker`. Gateway and
test projects may reference the worker project for metadata and scaffold tests,
but they must not reference `ArchestrA.MXAccess.dll` directly.
## Responsibilities
The worker owns:
@@ -87,6 +114,21 @@ Startup sequence:
If validation fails before MXAccess creation, exit quickly with a non-zero exit
code. If MXAccess creation fails, send `WorkerFault` when possible and exit.
The bootstrap layer returns structured exit codes before it creates pipes,
starts the STA, or touches MXAccess:
| Exit code | Name | Meaning |
|-----------|------|---------|
| `0` | `Success` | Required bootstrap options are valid. |
| `1` | `UnexpectedFailure` | A non-bootstrap exception reaches the process boundary. |
| `2` | `InvalidArguments` | Required arguments are missing or unknown arguments are present. |
| `3` | `InvalidProtocolVersion` | `--protocol-version` is not numeric or does not match the supported worker protocol. |
| `4` | `MissingNonce` | `MXGATEWAY_WORKER_NONCE` is absent or empty. |
Bootstrap logs use `WorkerConsoleLogger` key/value output. `WorkerLogRedactor`
redacts fields whose names indicate nonce, secret, password, token,
credential, or API key values before the message is written.
## Internal Components
```text
@@ -208,6 +250,17 @@ The loop should update a heartbeat timestamp after:
- finishing a command,
- processing an MXAccess event.
`StaRuntime` implements this runtime boundary in the worker. It starts one
background thread named `MxGateway.Worker.STA`, sets it to `ApartmentState.STA`,
initializes COM through `StaComApartmentInitializer`, and runs
`StaMessagePump`. Commands are scheduled through `InvokeAsync`; the command
queue signals an `AutoResetEvent` so `MsgWaitForMultipleObjectsEx` can wake the
STA without busy-waiting. `LastActivityUtc` records pump, command, startup, and
shutdown activity so the future heartbeat/watchdog can report whether the STA
is still responsive. Shutdown marks the runtime as closing, wakes the pump,
rejects new commands, cancels queued work, uninitializes COM on the STA, and
waits for the thread to exit.
## COM Creation
The MXAccess analysis source at `C:\Users\dohertj2\Desktop\mxaccess` identifies
+9 -3
View File
@@ -47,6 +47,8 @@ Detailed follow-up docs:
security, observability, and test strategy.
- `docs/WorkerFrameProtocol.md` covers the gateway-side named-pipe frame
reader/writer and `WorkerEnvelope` validation rules.
- `docs/WorkerProcessLauncher.md` covers worker executable validation, process
launch arguments, nonce handling, and startup cleanup behavior.
- `docs/mxaccess-worker-instance-design.md` covers each .NET Framework 4.8 x86
MXAccess worker instance, including STA ownership, message pumping, COM
lifetime, command dispatch, event sinks, conversion, and shutdown.
@@ -564,9 +566,13 @@ Because each client owns one worker, a crash or leak affects only that session.
External gateway:
- use TLS for remote gRPC if crossing machine boundaries,
- authenticate clients with Windows auth, mTLS, or a deployment-specific token,
- authorize access to commands that can write, authenticate users, or alter
runtime state.
- authenticate v1 gRPC clients with `authorization: Bearer
mxgw_<key-id>_<secret>` API-key metadata,
- reject missing or invalid API keys with gRPC `Unauthenticated`,
- reject valid keys that lack the required session, invoke, event, metadata, or
admin scope with gRPC `PermissionDenied`,
- authorize access to commands that can write, authenticate users, expose
metadata, stream events, or alter runtime state.
Internal worker IPC:
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<TargetFrameworks>net10.0;net48</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
@@ -17,6 +17,7 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="6.1.2" />
</ItemGroup>
</Project>
@@ -3,6 +3,8 @@ using MxGateway.Server.Configuration;
using MxGateway.Server.Diagnostics;
using MxGateway.Server.Metrics;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
using MxGateway.Server.Workers;
namespace MxGateway.Server;
@@ -25,8 +27,10 @@ public static class GatewayApplication
builder.Services.AddGatewayConfiguration();
builder.Services.AddSqliteAuthStore();
builder.Services.AddGatewayGrpcAuthorization();
builder.Services.AddHealthChecks();
builder.Services.AddSingleton<GatewayMetrics>();
builder.Services.AddWorkerProcessLauncher();
return builder;
}
@@ -5,6 +5,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Grpc.AspNetCore" Version="2.76.0" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="10.0.7" />
</ItemGroup>
+37 -1
View File
@@ -1,7 +1,43 @@
using MxGateway.Server;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
var app = GatewayApplication.Build(args);
ApiKeyAdminParseResult apiKeyAdminCommand = ApiKeyAdminCommandLineParser.Parse(args);
if (apiKeyAdminCommand.IsApiKeyCommand)
{
if (apiKeyAdminCommand.Command is null)
{
await Console.Error.WriteLineAsync(apiKeyAdminCommand.Error);
return 2;
}
WebApplicationBuilder builder = GatewayApplication.CreateBuilder([]);
ApplyApiKeyAdminOverrides(builder.Configuration, apiKeyAdminCommand.Command);
await using WebApplication cliApp = builder.Build();
await using AsyncServiceScope scope = cliApp.Services.CreateAsyncScope();
ApiKeyAdminCliRunner runner = scope.ServiceProvider.GetRequiredService<ApiKeyAdminCliRunner>();
return await runner.RunAsync(apiKeyAdminCommand.Command, Console.Out, CancellationToken.None);
}
WebApplication app = GatewayApplication.Build(args);
app.Run();
return 0;
static void ApplyApiKeyAdminOverrides(IConfiguration configuration, ApiKeyAdminCommand command)
{
if (!string.IsNullOrWhiteSpace(command.SqlitePath))
{
configuration[$"{GatewayOptions.SectionName}:Authentication:SqlitePath"] = command.SqlitePath;
}
if (!string.IsNullOrWhiteSpace(command.Pepper))
{
configuration["MxGateway:ApiKeyPepper"] = command.Pepper;
}
}
public partial class Program;
@@ -0,0 +1,180 @@
using System.Text.Json;
namespace MxGateway.Server.Security.Authentication;
public sealed class ApiKeyAdminCliRunner(
IAuthStoreMigrator migrator,
IApiKeyAdminStore adminStore,
IApiKeyAuditStore auditStore,
IApiKeySecretHasher hasher)
{
private static readonly JsonSerializerOptions JsonOptions = new()
{
WriteIndented = true
};
public async Task<int> RunAsync(
ApiKeyAdminCommand command,
TextWriter output,
CancellationToken cancellationToken)
{
ApiKeyAdminOutput result = command.Kind switch
{
ApiKeyAdminCommandKind.InitDb => await InitDbAsync(cancellationToken).ConfigureAwait(false),
ApiKeyAdminCommandKind.CreateKey => await CreateKeyAsync(command, cancellationToken).ConfigureAwait(false),
ApiKeyAdminCommandKind.ListKeys => await ListKeysAsync(cancellationToken).ConfigureAwait(false),
ApiKeyAdminCommandKind.RevokeKey => await RevokeKeyAsync(command, cancellationToken).ConfigureAwait(false),
ApiKeyAdminCommandKind.RotateKey => await RotateKeyAsync(command, cancellationToken).ConfigureAwait(false),
_ => throw new InvalidOperationException($"Unsupported API key command '{command.Kind}'.")
};
await WriteOutputAsync(command, result, output).ConfigureAwait(false);
return 0;
}
private async Task<ApiKeyAdminOutput> InitDbAsync(CancellationToken cancellationToken)
{
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
await AppendAuditAsync(null, "init-db", null, cancellationToken).ConfigureAwait(false);
return new ApiKeyAdminOutput("init-db", "initialized", null, []);
}
private async Task<ApiKeyAdminOutput> CreateKeyAsync(
ApiKeyAdminCommand command,
CancellationToken cancellationToken)
{
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
string keyId = Required(command.KeyId);
string secret = ApiKeySecretGenerator.Generate();
string apiKey = FormatApiKey(keyId, secret);
await adminStore.CreateAsync(
new ApiKeyCreateRequest(
KeyId: keyId,
KeyPrefix: $"mxgw_{keyId}",
SecretHash: hasher.HashSecret(secret),
DisplayName: Required(command.DisplayName),
Scopes: command.Scopes,
CreatedUtc: DateTimeOffset.UtcNow),
cancellationToken)
.ConfigureAwait(false);
await AppendAuditAsync(keyId, "create-key", null, cancellationToken).ConfigureAwait(false);
return new ApiKeyAdminOutput("create-key", "created", apiKey, []);
}
private async Task<ApiKeyAdminOutput> ListKeysAsync(CancellationToken cancellationToken)
{
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
IReadOnlyList<ApiKeyRecord> keys = await adminStore.ListAsync(cancellationToken).ConfigureAwait(false);
await AppendAuditAsync(null, "list-keys", null, cancellationToken).ConfigureAwait(false);
return new ApiKeyAdminOutput(
"list-keys",
"ok",
null,
keys.Select(ToListedKey).ToArray());
}
private async Task<ApiKeyAdminOutput> RevokeKeyAsync(
ApiKeyAdminCommand command,
CancellationToken cancellationToken)
{
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
string keyId = Required(command.KeyId);
bool revoked = await adminStore.RevokeAsync(keyId, DateTimeOffset.UtcNow, cancellationToken)
.ConfigureAwait(false);
await AppendAuditAsync(keyId, "revoke-key", revoked ? "revoked" : "not-found-or-already-revoked", cancellationToken)
.ConfigureAwait(false);
return new ApiKeyAdminOutput("revoke-key", revoked ? "revoked" : "not-found-or-already-revoked", null, []);
}
private async Task<ApiKeyAdminOutput> RotateKeyAsync(
ApiKeyAdminCommand command,
CancellationToken cancellationToken)
{
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
string keyId = Required(command.KeyId);
string secret = ApiKeySecretGenerator.Generate();
string apiKey = FormatApiKey(keyId, secret);
bool rotated = await adminStore.RotateAsync(keyId, hasher.HashSecret(secret), DateTimeOffset.UtcNow, cancellationToken)
.ConfigureAwait(false);
await AppendAuditAsync(keyId, "rotate-key", rotated ? "rotated" : "not-found", cancellationToken)
.ConfigureAwait(false);
return new ApiKeyAdminOutput("rotate-key", rotated ? "rotated" : "not-found", rotated ? apiKey : null, []);
}
private static async Task WriteOutputAsync(
ApiKeyAdminCommand command,
ApiKeyAdminOutput result,
TextWriter output)
{
if (command.Json)
{
await output.WriteLineAsync(JsonSerializer.Serialize(result, JsonOptions)).ConfigureAwait(false);
return;
}
await output.WriteLineAsync($"{result.Command}: {result.Status}").ConfigureAwait(false);
if (result.ApiKey is not null)
{
await output.WriteLineAsync($"API key: {result.ApiKey}").ConfigureAwait(false);
}
foreach (ApiKeyAdminListedKey key in result.Keys)
{
string revoked = key.RevokedUtc is null ? "active" : "revoked";
await output.WriteLineAsync($"{key.KeyId}\t{key.DisplayName}\t{revoked}\t{string.Join(',', key.Scopes)}")
.ConfigureAwait(false);
}
}
private async Task AppendAuditAsync(
string? keyId,
string eventType,
string? details,
CancellationToken cancellationToken)
{
await auditStore.AppendAsync(
new ApiKeyAuditEntry(
KeyId: keyId,
EventType: eventType,
RemoteAddress: null,
Details: details),
cancellationToken)
.ConfigureAwait(false);
}
private static ApiKeyAdminListedKey ToListedKey(ApiKeyRecord key)
{
return new ApiKeyAdminListedKey(
KeyId: key.KeyId,
KeyPrefix: key.KeyPrefix,
DisplayName: key.DisplayName,
Scopes: key.Scopes,
CreatedUtc: key.CreatedUtc,
LastUsedUtc: key.LastUsedUtc,
RevokedUtc: key.RevokedUtc);
}
private static string FormatApiKey(string keyId, string secret)
{
return $"mxgw_{keyId}_{secret}";
}
private static string Required(string? value)
{
return value ?? throw new InvalidOperationException("Required command value was not provided.");
}
}
@@ -0,0 +1,10 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyAdminCommand(
ApiKeyAdminCommandKind Kind,
bool Json,
string? SqlitePath,
string? Pepper,
string? KeyId,
string? DisplayName,
IReadOnlySet<string> Scopes);
@@ -0,0 +1,10 @@
namespace MxGateway.Server.Security.Authentication;
public enum ApiKeyAdminCommandKind
{
InitDb,
CreateKey,
ListKeys,
RevokeKey,
RotateKey
}
@@ -0,0 +1,159 @@
namespace MxGateway.Server.Security.Authentication;
public static class ApiKeyAdminCommandLineParser
{
public static ApiKeyAdminParseResult Parse(IReadOnlyList<string> args)
{
if (args.Count == 0 || !string.Equals(args[0], "apikey", StringComparison.OrdinalIgnoreCase))
{
return ApiKeyAdminParseResult.NotApiKeyCommand();
}
if (args.Count < 2)
{
return ApiKeyAdminParseResult.Fail("Missing apikey subcommand.");
}
if (!TryParseKind(args[1], out ApiKeyAdminCommandKind kind))
{
return ApiKeyAdminParseResult.Fail($"Unknown apikey subcommand '{args[1]}'.");
}
Dictionary<string, string?> options = new(StringComparer.OrdinalIgnoreCase);
bool json = false;
for (int index = 2; index < args.Count; index++)
{
string arg = args[index];
if (string.Equals(arg, "--json", StringComparison.OrdinalIgnoreCase))
{
json = true;
continue;
}
if (!arg.StartsWith("--", StringComparison.Ordinal))
{
return ApiKeyAdminParseResult.Fail($"Unexpected argument '{arg}'.");
}
string name = arg[2..];
string? value;
int equalsIndex = name.IndexOf('=', StringComparison.Ordinal);
if (equalsIndex >= 0)
{
value = name[(equalsIndex + 1)..];
name = name[..equalsIndex];
}
else
{
if (index + 1 >= args.Count || args[index + 1].StartsWith("--", StringComparison.Ordinal))
{
return ApiKeyAdminParseResult.Fail($"Option '--{name}' requires a value.");
}
value = args[++index];
}
options[name] = value;
}
string? keyId = GetOption(options, "key-id");
string? displayName = GetOption(options, "display-name");
IReadOnlySet<string> scopes = ParseScopes(GetOption(options, "scopes"));
string? validationError = Validate(kind, keyId, displayName);
if (validationError is not null)
{
return ApiKeyAdminParseResult.Fail(validationError);
}
return ApiKeyAdminParseResult.Success(new ApiKeyAdminCommand(
Kind: kind,
Json: json,
SqlitePath: GetOption(options, "sqlite-path"),
Pepper: GetOption(options, "pepper"),
KeyId: keyId,
DisplayName: displayName,
Scopes: scopes));
}
private static bool TryParseKind(string value, out ApiKeyAdminCommandKind kind)
{
switch (value.ToLowerInvariant())
{
case "init-db":
kind = ApiKeyAdminCommandKind.InitDb;
return true;
case "create-key":
kind = ApiKeyAdminCommandKind.CreateKey;
return true;
case "list-keys":
kind = ApiKeyAdminCommandKind.ListKeys;
return true;
case "revoke-key":
kind = ApiKeyAdminCommandKind.RevokeKey;
return true;
case "rotate-key":
kind = ApiKeyAdminCommandKind.RotateKey;
return true;
default:
kind = default;
return false;
}
}
private static string? Validate(ApiKeyAdminCommandKind kind, string? keyId, string? displayName)
{
if (kind is ApiKeyAdminCommandKind.CreateKey or ApiKeyAdminCommandKind.RevokeKey or ApiKeyAdminCommandKind.RotateKey
&& string.IsNullOrWhiteSpace(keyId))
{
return $"Subcommand '{KindName(kind)}' requires --key-id.";
}
if (!string.IsNullOrWhiteSpace(keyId) && !IsValidKeyId(keyId))
{
return "API key id may contain only letters, numbers, periods, and hyphens.";
}
if (kind == ApiKeyAdminCommandKind.CreateKey && string.IsNullOrWhiteSpace(displayName))
{
return "Subcommand 'create-key' requires --display-name.";
}
return null;
}
private static string KindName(ApiKeyAdminCommandKind kind)
{
return kind switch
{
ApiKeyAdminCommandKind.InitDb => "init-db",
ApiKeyAdminCommandKind.CreateKey => "create-key",
ApiKeyAdminCommandKind.ListKeys => "list-keys",
ApiKeyAdminCommandKind.RevokeKey => "revoke-key",
ApiKeyAdminCommandKind.RotateKey => "rotate-key",
_ => kind.ToString()
};
}
private static bool IsValidKeyId(string keyId)
{
return keyId.All(character =>
char.IsAsciiLetterOrDigit(character)
|| character is '.' or '-');
}
private static string? GetOption(Dictionary<string, string?> options, string name)
{
return options.TryGetValue(name, out string? value) ? value : null;
}
private static IReadOnlySet<string> ParseScopes(string? scopes)
{
return new HashSet<string>(
(scopes ?? string.Empty)
.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries),
StringComparer.Ordinal);
}
}
@@ -0,0 +1,10 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyAdminListedKey(
string KeyId,
string KeyPrefix,
string DisplayName,
IReadOnlySet<string> Scopes,
DateTimeOffset CreatedUtc,
DateTimeOffset? LastUsedUtc,
DateTimeOffset? RevokedUtc);
@@ -0,0 +1,7 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyAdminOutput(
string Command,
string Status,
string? ApiKey,
IReadOnlyList<ApiKeyAdminListedKey> Keys);
@@ -0,0 +1,22 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyAdminParseResult(
bool IsApiKeyCommand,
ApiKeyAdminCommand? Command,
string? Error)
{
public static ApiKeyAdminParseResult NotApiKeyCommand()
{
return new ApiKeyAdminParseResult(false, null, null);
}
public static ApiKeyAdminParseResult Success(ApiKeyAdminCommand command)
{
return new ApiKeyAdminParseResult(true, command, null);
}
public static ApiKeyAdminParseResult Fail(string error)
{
return new ApiKeyAdminParseResult(true, null, error);
}
}
@@ -0,0 +1,9 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyCreateRequest(
string KeyId,
string KeyPrefix,
byte[] SecretHash,
string DisplayName,
IReadOnlySet<string> Scopes,
DateTimeOffset CreatedUtc);
@@ -0,0 +1,7 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyIdentity(
string KeyId,
string KeyPrefix,
string DisplayName,
IReadOnlySet<string> Scopes);
@@ -0,0 +1,45 @@
namespace MxGateway.Server.Security.Authentication;
public sealed class ApiKeyParser : IApiKeyParser
{
private const string BearerPrefix = "Bearer ";
private const string TokenPrefix = "mxgw_";
public bool TryParseAuthorizationHeader(string? authorizationHeader, out ParsedApiKey? apiKey)
{
apiKey = null;
if (string.IsNullOrWhiteSpace(authorizationHeader)
|| !authorizationHeader.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase))
{
return false;
}
string token = authorizationHeader[BearerPrefix.Length..].Trim();
if (!token.StartsWith(TokenPrefix, StringComparison.OrdinalIgnoreCase))
{
return false;
}
string keyPayload = token[TokenPrefix.Length..];
int separatorIndex = keyPayload.IndexOf('_', StringComparison.Ordinal);
if (separatorIndex <= 0 || separatorIndex == keyPayload.Length - 1)
{
return false;
}
string keyId = keyPayload[..separatorIndex];
string secret = keyPayload[(separatorIndex + 1)..];
if (string.IsNullOrWhiteSpace(keyId) || string.IsNullOrWhiteSpace(secret))
{
return false;
}
apiKey = new ParsedApiKey(keyId, secret);
return true;
}
}
@@ -0,0 +1,4 @@
namespace MxGateway.Server.Security.Authentication;
public sealed class ApiKeyPepperUnavailableException(string pepperSecretName)
: InvalidOperationException($"API key pepper secret '{pepperSecretName}' is not configured.");
@@ -0,0 +1,26 @@
using Microsoft.Data.Sqlite;
namespace MxGateway.Server.Security.Authentication;
public static class ApiKeyRecordReader
{
public static ApiKeyRecord Read(SqliteDataReader reader)
{
return new ApiKeyRecord(
KeyId: reader.GetString(0),
KeyPrefix: reader.GetString(1),
SecretHash: (byte[])reader["secret_hash"],
DisplayName: reader.GetString(3),
Scopes: ApiKeyScopeSerializer.Deserialize(reader.GetString(4)),
CreatedUtc: DateTimeOffset.Parse(reader.GetString(5), System.Globalization.CultureInfo.InvariantCulture),
LastUsedUtc: ReadNullableDateTimeOffset(reader, 6),
RevokedUtc: ReadNullableDateTimeOffset(reader, 7));
}
private static DateTimeOffset? ReadNullableDateTimeOffset(SqliteDataReader reader, int ordinal)
{
return reader.IsDBNull(ordinal)
? null
: DateTimeOffset.Parse(reader.GetString(ordinal), System.Globalization.CultureInfo.InvariantCulture);
}
}
@@ -0,0 +1,17 @@
using System.Security.Cryptography;
namespace MxGateway.Server.Security.Authentication;
public static class ApiKeySecretGenerator
{
public static string Generate()
{
Span<byte> bytes = stackalloc byte[32];
RandomNumberGenerator.Fill(bytes);
return Convert.ToBase64String(bytes)
.TrimEnd('=')
.Replace('+', '-')
.Replace('/', '_');
}
}
@@ -0,0 +1,35 @@
using System.Security.Cryptography;
using System.Text;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
namespace MxGateway.Server.Security.Authentication;
public sealed class ApiKeySecretHasher(
IConfiguration configuration,
IOptions<GatewayOptions> options) : IApiKeySecretHasher
{
public byte[] HashSecret(string secret)
{
string pepper = GetPepper();
byte[] pepperBytes = Encoding.UTF8.GetBytes(pepper);
byte[] secretBytes = Encoding.UTF8.GetBytes(secret);
using HMACSHA256 hmac = new(pepperBytes);
return hmac.ComputeHash(secretBytes);
}
private string GetPepper()
{
string pepperSecretName = options.Value.Authentication.PepperSecretName;
string? pepper = configuration[pepperSecretName];
if (string.IsNullOrWhiteSpace(pepper))
{
throw new ApiKeyPepperUnavailableException(pepperSecretName);
}
return pepper;
}
}
@@ -0,0 +1,11 @@
namespace MxGateway.Server.Security.Authentication;
public enum ApiKeyVerificationFailure
{
None,
MissingOrMalformedCredentials,
PepperUnavailable,
KeyNotFound,
KeyRevoked,
SecretMismatch
}
@@ -0,0 +1,23 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ApiKeyVerificationResult(
bool Succeeded,
ApiKeyIdentity? Identity,
ApiKeyVerificationFailure Failure)
{
public static ApiKeyVerificationResult Success(ApiKeyIdentity identity)
{
return new ApiKeyVerificationResult(
Succeeded: true,
Identity: identity,
Failure: ApiKeyVerificationFailure.None);
}
public static ApiKeyVerificationResult Fail(ApiKeyVerificationFailure failure)
{
return new ApiKeyVerificationResult(
Succeeded: false,
Identity: null,
Failure: failure);
}
}
@@ -0,0 +1,57 @@
using System.Security.Cryptography;
namespace MxGateway.Server.Security.Authentication;
public sealed class ApiKeyVerifier(
IApiKeyParser parser,
IApiKeySecretHasher hasher,
IApiKeyStore keyStore) : IApiKeyVerifier
{
public async Task<ApiKeyVerificationResult> VerifyAsync(
string? authorizationHeader,
CancellationToken cancellationToken)
{
if (!parser.TryParseAuthorizationHeader(authorizationHeader, out ParsedApiKey? parsedKey)
|| parsedKey is null)
{
return ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.MissingOrMalformedCredentials);
}
ApiKeyRecord? storedKey = await keyStore.FindByKeyIdAsync(parsedKey.KeyId, cancellationToken)
.ConfigureAwait(false);
if (storedKey is null)
{
return ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.KeyNotFound);
}
if (storedKey.RevokedUtc is not null)
{
return ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.KeyRevoked);
}
byte[] presentedHash;
try
{
presentedHash = hasher.HashSecret(parsedKey.Secret);
}
catch (ApiKeyPepperUnavailableException)
{
return ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.PepperUnavailable);
}
if (!CryptographicOperations.FixedTimeEquals(presentedHash, storedKey.SecretHash))
{
return ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.SecretMismatch);
}
await keyStore.MarkKeyUsedAsync(storedKey.KeyId, DateTimeOffset.UtcNow, cancellationToken)
.ConfigureAwait(false);
return ApiKeyVerificationResult.Success(new ApiKeyIdentity(
KeyId: storedKey.KeyId,
KeyPrefix: storedKey.KeyPrefix,
DisplayName: storedKey.DisplayName,
Scopes: storedKey.Scopes));
}
}
@@ -4,9 +4,14 @@ public static class AuthStoreServiceCollectionExtensions
{
public static IServiceCollection AddSqliteAuthStore(this IServiceCollection services)
{
services.AddSingleton<IApiKeyParser, ApiKeyParser>();
services.AddSingleton<IApiKeySecretHasher, ApiKeySecretHasher>();
services.AddSingleton<IApiKeyVerifier, ApiKeyVerifier>();
services.AddSingleton<ApiKeyAdminCliRunner>();
services.AddSingleton<AuthSqliteConnectionFactory>();
services.AddSingleton<IAuthStoreMigrator, SqliteAuthStoreMigrator>();
services.AddSingleton<IApiKeyStore, SqliteApiKeyStore>();
services.AddSingleton<IApiKeyAdminStore, SqliteApiKeyAdminStore>();
services.AddSingleton<IApiKeyAuditStore, SqliteApiKeyAuditStore>();
services.AddHostedService<AuthStoreMigrationHostedService>();
@@ -0,0 +1,16 @@
namespace MxGateway.Server.Security.Authentication;
public interface IApiKeyAdminStore
{
Task CreateAsync(ApiKeyCreateRequest request, CancellationToken cancellationToken);
Task<IReadOnlyList<ApiKeyRecord>> ListAsync(CancellationToken cancellationToken);
Task<bool> RevokeAsync(string keyId, DateTimeOffset revokedUtc, CancellationToken cancellationToken);
Task<bool> RotateAsync(
string keyId,
byte[] secretHash,
DateTimeOffset rotatedUtc,
CancellationToken cancellationToken);
}
@@ -0,0 +1,6 @@
namespace MxGateway.Server.Security.Authentication;
public interface IApiKeyParser
{
bool TryParseAuthorizationHeader(string? authorizationHeader, out ParsedApiKey? apiKey);
}
@@ -0,0 +1,6 @@
namespace MxGateway.Server.Security.Authentication;
public interface IApiKeySecretHasher
{
byte[] HashSecret(string secret);
}
@@ -0,0 +1,8 @@
namespace MxGateway.Server.Security.Authentication;
public interface IApiKeyVerifier
{
Task<ApiKeyVerificationResult> VerifyAsync(
string? authorizationHeader,
CancellationToken cancellationToken);
}
@@ -0,0 +1,3 @@
namespace MxGateway.Server.Security.Authentication;
public sealed record ParsedApiKey(string KeyId, string Secret);
@@ -0,0 +1,116 @@
using Microsoft.Data.Sqlite;
namespace MxGateway.Server.Security.Authentication;
public sealed class SqliteApiKeyAdminStore(AuthSqliteConnectionFactory connectionFactory) : IApiKeyAdminStore
{
public async Task CreateAsync(ApiKeyCreateRequest request, CancellationToken cancellationToken)
{
await using SqliteConnection connection = connectionFactory.CreateConnection();
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
await using SqliteCommand command = connection.CreateCommand();
command.CommandText = """
INSERT INTO api_keys (
key_id,
key_prefix,
secret_hash,
display_name,
scopes,
created_utc,
last_used_utc,
revoked_utc)
VALUES (
$key_id,
$key_prefix,
$secret_hash,
$display_name,
$scopes,
$created_utc,
NULL,
NULL);
""";
AddCreateParameters(command, request);
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
public async Task<IReadOnlyList<ApiKeyRecord>> ListAsync(CancellationToken cancellationToken)
{
await using SqliteConnection connection = connectionFactory.CreateConnection();
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
await using SqliteCommand command = connection.CreateCommand();
command.CommandText = """
SELECT key_id, key_prefix, secret_hash, display_name, scopes, created_utc, last_used_utc, revoked_utc
FROM api_keys
ORDER BY key_id;
""";
List<ApiKeyRecord> records = [];
await using SqliteDataReader reader = await command.ExecuteReaderAsync(cancellationToken)
.ConfigureAwait(false);
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
records.Add(ApiKeyRecordReader.Read(reader));
}
return records;
}
public async Task<bool> RevokeAsync(string keyId, DateTimeOffset revokedUtc, CancellationToken cancellationToken)
{
await using SqliteConnection connection = connectionFactory.CreateConnection();
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
await using SqliteCommand command = connection.CreateCommand();
command.CommandText = """
UPDATE api_keys
SET revoked_utc = $revoked_utc
WHERE key_id = $key_id AND revoked_utc IS NULL;
""";
command.Parameters.AddWithValue("$key_id", keyId);
command.Parameters.AddWithValue("$revoked_utc", revokedUtc.ToString("O"));
int rows = await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
return rows > 0;
}
public async Task<bool> RotateAsync(
string keyId,
byte[] secretHash,
DateTimeOffset rotatedUtc,
CancellationToken cancellationToken)
{
await using SqliteConnection connection = connectionFactory.CreateConnection();
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
await using SqliteCommand command = connection.CreateCommand();
command.CommandText = """
UPDATE api_keys
SET secret_hash = $secret_hash,
last_used_utc = NULL,
revoked_utc = NULL
WHERE key_id = $key_id;
""";
command.Parameters.AddWithValue("$key_id", keyId);
command.Parameters.Add("$secret_hash", SqliteType.Blob).Value = secretHash;
int rows = await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
return rows > 0;
}
private static void AddCreateParameters(SqliteCommand command, ApiKeyCreateRequest request)
{
command.Parameters.AddWithValue("$key_id", request.KeyId);
command.Parameters.AddWithValue("$key_prefix", request.KeyPrefix);
command.Parameters.Add("$secret_hash", SqliteType.Blob).Value = request.SecretHash;
command.Parameters.AddWithValue("$display_name", request.DisplayName);
command.Parameters.AddWithValue("$scopes", ApiKeyScopeSerializer.Serialize(request.Scopes));
command.Parameters.AddWithValue("$created_utc", request.CreatedUtc.ToString("O"));
}
}
@@ -61,26 +61,6 @@ public sealed class SqliteApiKeyStore(AuthSqliteConnectionFactory connectionFact
return null;
}
return ReadApiKeyRecord(reader);
}
private static ApiKeyRecord ReadApiKeyRecord(SqliteDataReader reader)
{
return new ApiKeyRecord(
KeyId: reader.GetString(0),
KeyPrefix: reader.GetString(1),
SecretHash: (byte[])reader["secret_hash"],
DisplayName: reader.GetString(3),
Scopes: ApiKeyScopeSerializer.Deserialize(reader.GetString(4)),
CreatedUtc: DateTimeOffset.Parse(reader.GetString(5), System.Globalization.CultureInfo.InvariantCulture),
LastUsedUtc: ReadNullableDateTimeOffset(reader, 6),
RevokedUtc: ReadNullableDateTimeOffset(reader, 7));
}
private static DateTimeOffset? ReadNullableDateTimeOffset(SqliteDataReader reader, int ordinal)
{
return reader.IsDBNull(ordinal)
? null
: DateTimeOffset.Parse(reader.GetString(ordinal), System.Globalization.CultureInfo.InvariantCulture);
return ApiKeyRecordReader.Read(reader);
}
}
@@ -0,0 +1,74 @@
using Grpc.Core;
using Grpc.Core.Interceptors;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayGrpcAuthorizationInterceptor(
IApiKeyVerifier apiKeyVerifier,
GatewayGrpcScopeResolver scopeResolver,
IGatewayRequestIdentityAccessor identityAccessor,
IOptions<GatewayOptions> options) : Interceptor
{
public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
TRequest request,
ServerCallContext context,
UnaryServerMethod<TRequest, TResponse> continuation)
{
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
using (identityScope)
{
return await continuation(request, context).ConfigureAwait(false);
}
}
public override async Task ServerStreamingServerHandler<TRequest, TResponse>(
TRequest request,
IServerStreamWriter<TResponse> responseStream,
ServerCallContext context,
ServerStreamingServerMethod<TRequest, TResponse> continuation)
{
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
using (identityScope)
{
await continuation(request, responseStream, context).ConfigureAwait(false);
}
}
private async Task<ApiKeyIdentity?> AuthenticateAndAuthorizeAsync<TRequest>(
TRequest request,
ServerCallContext context)
where TRequest : class
{
if (options.Value.Authentication.Mode == AuthenticationMode.Disabled)
{
return null;
}
string? authorizationHeader = context.RequestHeaders.GetValue("authorization");
ApiKeyVerificationResult verificationResult = await apiKeyVerifier
.VerifyAsync(authorizationHeader, context.CancellationToken)
.ConfigureAwait(false);
if (!verificationResult.Succeeded || verificationResult.Identity is null)
{
throw new RpcException(new Status(
StatusCode.Unauthenticated,
"Missing or invalid API key."));
}
string requiredScope = scopeResolver.ResolveRequiredScope(request);
if (!verificationResult.Identity.Scopes.Contains(requiredScope))
{
throw new RpcException(new Status(
StatusCode.PermissionDenied,
$"API key is missing required scope '{requiredScope}'."));
}
return verificationResult.Identity;
}
}
@@ -0,0 +1,40 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayGrpcScopeResolver
{
public string ResolveRequiredScope(object request)
{
return request switch
{
OpenSessionRequest => GatewayScopes.SessionOpen,
CloseSessionRequest => GatewayScopes.SessionClose,
StreamEventsRequest => GatewayScopes.EventsRead,
MxCommandRequest commandRequest => ResolveCommandScope(commandRequest.Command?.Kind ?? MxCommandKind.Unspecified),
_ => GatewayScopes.Admin
};
}
private static string ResolveCommandScope(MxCommandKind kind)
{
return kind switch
{
MxCommandKind.Write or
MxCommandKind.Write2 => GatewayScopes.InvokeWrite,
MxCommandKind.WriteSecured or
MxCommandKind.WriteSecured2 or
MxCommandKind.AuthenticateUser => GatewayScopes.InvokeSecure,
MxCommandKind.ArchestraUserToId or
MxCommandKind.GetSessionState or
MxCommandKind.GetWorkerInfo => GatewayScopes.MetadataRead,
MxCommandKind.DrainEvents => GatewayScopes.EventsRead,
MxCommandKind.ShutdownWorker => GatewayScopes.Admin,
_ => GatewayScopes.InvokeRead
};
}
}
@@ -0,0 +1,38 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public sealed class GatewayRequestIdentityAccessor : IGatewayRequestIdentityAccessor
{
private readonly AsyncLocal<ApiKeyIdentity?> currentIdentity = new();
public ApiKeyIdentity? Current => currentIdentity.Value;
public IDisposable Push(ApiKeyIdentity identity)
{
ArgumentNullException.ThrowIfNull(identity);
ApiKeyIdentity? previousIdentity = currentIdentity.Value;
currentIdentity.Value = identity;
return new IdentityScope(this, previousIdentity);
}
private sealed class IdentityScope(
GatewayRequestIdentityAccessor accessor,
ApiKeyIdentity? previousIdentity) : IDisposable
{
private bool disposed;
public void Dispose()
{
if (disposed)
{
return;
}
accessor.currentIdentity.Value = previousIdentity;
disposed = true;
}
}
}
@@ -0,0 +1,13 @@
namespace MxGateway.Server.Security.Authorization;
public static class GatewayScopes
{
public const string SessionOpen = "session:open";
public const string SessionClose = "session:close";
public const string InvokeRead = "invoke:read";
public const string InvokeWrite = "invoke:write";
public const string InvokeSecure = "invoke:secure";
public const string EventsRead = "events:read";
public const string MetadataRead = "metadata:read";
public const string Admin = "admin";
}
@@ -0,0 +1,16 @@
using Grpc.Core.Interceptors;
namespace MxGateway.Server.Security.Authorization;
public static class GrpcAuthorizationServiceCollectionExtensions
{
public static IServiceCollection AddGatewayGrpcAuthorization(this IServiceCollection services)
{
services.AddSingleton<GatewayGrpcScopeResolver>();
services.AddSingleton<IGatewayRequestIdentityAccessor, GatewayRequestIdentityAccessor>();
services.AddSingleton<GatewayGrpcAuthorizationInterceptor>();
services.AddGrpc(options => options.Interceptors.Add<GatewayGrpcAuthorizationInterceptor>());
return services;
}
}
@@ -0,0 +1,10 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Server.Security.Authorization;
public interface IGatewayRequestIdentityAccessor
{
ApiKeyIdentity? Current { get; }
IDisposable Push(ApiKeyIdentity identity);
}
@@ -0,0 +1,27 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Workers;
public interface IWorkerClient : IAsyncDisposable
{
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken);
IAsyncEnumerable<WorkerEvent> ReadEventsAsync(CancellationToken cancellationToken);
Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken);
void Kill(string reason);
}
@@ -0,0 +1,14 @@
namespace MxGateway.Server.Workers;
public interface IWorkerProcess : IDisposable
{
int Id { get; }
bool HasExited { get; }
int? ExitCode { get; }
ValueTask WaitForExitAsync(CancellationToken cancellationToken);
void Kill(bool entireProcessTree);
}
@@ -0,0 +1,8 @@
using System.Diagnostics;
namespace MxGateway.Server.Workers;
public interface IWorkerProcessFactory
{
IWorkerProcess Start(ProcessStartInfo startInfo);
}
@@ -0,0 +1,8 @@
namespace MxGateway.Server.Workers;
public interface IWorkerProcessLauncher
{
Task<WorkerProcessHandle> LaunchAsync(
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken = default);
}
@@ -0,0 +1,9 @@
namespace MxGateway.Server.Workers;
public interface IWorkerStartupProbe
{
Task WaitUntilReadyAsync(
IWorkerProcess process,
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken);
}
@@ -0,0 +1,27 @@
using System.Diagnostics;
namespace MxGateway.Server.Workers;
internal sealed class SystemWorkerProcess(Process process) : IWorkerProcess
{
public int Id => process.Id;
public bool HasExited => process.HasExited;
public int? ExitCode => process.HasExited ? process.ExitCode : null;
public async ValueTask WaitForExitAsync(CancellationToken cancellationToken)
{
await process.WaitForExitAsync(cancellationToken).ConfigureAwait(false);
}
public void Kill(bool entireProcessTree)
{
process.Kill(entireProcessTree);
}
public void Dispose()
{
process.Dispose();
}
}
@@ -0,0 +1,22 @@
using System.Diagnostics;
namespace MxGateway.Server.Workers;
public sealed class SystemWorkerProcessFactory : IWorkerProcessFactory
{
public IWorkerProcess Start(ProcessStartInfo startInfo)
{
Process process = new()
{
StartInfo = startInfo,
};
if (!process.Start())
{
process.Dispose();
throw new InvalidOperationException("Worker process failed to start.");
}
return new SystemWorkerProcess(process);
}
}
@@ -0,0 +1,755 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Metrics;
namespace MxGateway.Server.Workers;
public sealed class WorkerClient : IWorkerClient
{
private const string GatewayVersionFallback = "unknown";
private readonly object _syncRoot = new();
private readonly WorkerClientConnection _connection;
private readonly WorkerClientOptions _options;
private readonly GatewayMetrics? _metrics;
private readonly TimeProvider _timeProvider;
private readonly ILogger<WorkerClient> _logger;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private readonly Channel<WorkerEnvelope> _outboundEnvelopes;
private readonly Channel<WorkerEvent> _events;
private readonly ConcurrentDictionary<string, PendingCommand> _pendingCommands = new(StringComparer.Ordinal);
private readonly CancellationTokenSource _stopCts = new();
private long _nextSequence;
private WorkerClientState _state;
private DateTimeOffset _lastHeartbeatAt;
private int? _processId;
private Task? _readLoopTask;
private Task? _writeLoopTask;
private Task? _heartbeatLoopTask;
private bool _disposed;
public WorkerClient(
WorkerClientConnection connection,
WorkerClientOptions? options = null,
GatewayMetrics? metrics = null,
TimeProvider? timeProvider = null,
ILogger<WorkerClient>? logger = null)
{
_connection = connection ?? throw new ArgumentNullException(nameof(connection));
_options = options ?? new WorkerClientOptions();
_metrics = metrics;
_timeProvider = timeProvider ?? TimeProvider.System;
_logger = logger ?? NullLogger<WorkerClient>.Instance;
_reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions);
_writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions);
_outboundEnvelopes = Channel.CreateUnbounded<WorkerEnvelope>(
new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = false,
AllowSynchronousContinuations = false,
});
_events = Channel.CreateBounded<WorkerEvent>(
new BoundedChannelOptions(_options.EventChannelCapacity)
{
SingleReader = false,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait,
AllowSynchronousContinuations = false,
});
_lastHeartbeatAt = _timeProvider.GetUtcNow();
}
public string SessionId => _connection.SessionId;
public int? ProcessId
{
get
{
lock (_syncRoot)
{
return _processId;
}
}
}
public WorkerClientState State
{
get
{
lock (_syncRoot)
{
return _state;
}
}
}
public DateTimeOffset LastHeartbeatAt
{
get
{
lock (_syncRoot)
{
return _lastHeartbeatAt;
}
}
}
public async Task StartAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
TransitionFromCreatedToHandshaking();
_writeLoopTask = Task.Run(WriteLoopAsync);
await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false);
WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerHello,
cancellationToken).ConfigureAwait(false);
ValidateWorkerHello(helloEnvelope.WorkerHello);
WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerReady,
cancellationToken).ConfigureAwait(false);
MarkReady(readyEnvelope.WorkerReady);
_readLoopTask = Task.Run(ReadLoopAsync);
_heartbeatLoopTask = Task.Run(HeartbeatLoopAsync);
}
public async Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(command);
ThrowIfDisposed();
EnsureReady();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero.");
}
string correlationId = Guid.NewGuid().ToString("N");
string method = GetCommandMethod(command);
PendingCommand pendingCommand = new(
correlationId,
method,
_timeProvider.GetTimestamp());
if (!_pendingCommands.TryAdd(correlationId, pendingCommand))
{
throw new InvalidOperationException("Generated a duplicate command correlation id.");
}
_metrics?.CommandStarted(method);
try
{
await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false);
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
Task timeoutTask = Task.Delay(timeout, timeoutCts.Token);
Task<WorkerCommandReply> replyTask = pendingCommand.Task;
Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false);
if (completedTask == replyTask)
{
await timeoutCts.CancelAsync().ConfigureAwait(false);
return await replyTask.ConfigureAwait(false);
}
if (cancellationToken.IsCancellationRequested)
{
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.GatewayShutdown,
"Command wait was canceled.");
cancellationToken.ThrowIfCancellationRequested();
}
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
throw new WorkerClientException(
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
}
catch
{
_pendingCommands.TryRemove(correlationId, out _);
throw;
}
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
yield return workerEvent;
}
}
public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero.");
}
WorkerClientState state = State;
if (state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
MarkClosing();
await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false);
_outboundEnvelopes.Writer.TryComplete();
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
try
{
await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false);
MarkClosed("shutdown");
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
SetFaulted(
WorkerClientErrorCode.ShutdownTimeout,
"Worker shutdown timed out.",
null);
throw new WorkerClientException(
WorkerClientErrorCode.ShutdownTimeout,
$"Worker shutdown timed out after {timeout}.");
}
}
public void Kill(string reason)
{
ThrowIfDisposed();
_connection.ProcessHandle?.Process.Kill(entireProcessTree: true);
_metrics?.WorkerKilled(reason);
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
$"Worker was killed by the gateway: {reason}.",
null);
}
public async ValueTask DisposeAsync()
{
if (_disposed)
{
return;
}
_disposed = true;
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
"Worker client was disposed."));
await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
await _connection.Stream.DisposeAsync().ConfigureAwait(false);
_connection.ProcessHandle?.Dispose();
_stopCts.Dispose();
}
private async Task WriteLoopAsync()
{
try
{
await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false))
{
await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.WriteFailed,
"Worker pipe write failed.",
exception);
}
}
private async Task ReadLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false);
await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream)
{
SetFaulted(
WorkerClientErrorCode.PipeDisconnected,
"Worker pipe disconnected.",
exception);
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker read loop failed.",
exception);
}
}
private async Task HeartbeatLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false);
if (State != WorkerClientState.Ready)
{
continue;
}
DateTimeOffset lastHeartbeatAt = LastHeartbeatAt;
DateTimeOffset now = _timeProvider.GetUtcNow();
if (now - lastHeartbeatAt <= _options.HeartbeatGrace)
{
continue;
}
_metrics?.HeartbeatFailed(SessionId);
SetFaulted(
WorkerClientErrorCode.HeartbeatExpired,
$"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.",
null);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
}
private async Task DispatchEnvelopeAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
switch (envelope.BodyCase)
{
case WorkerEnvelope.BodyOneofCase.WorkerCommandReply:
CompleteCommand(envelope);
break;
case WorkerEnvelope.BodyOneofCase.WorkerEvent:
await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false);
break;
case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat:
MarkHeartbeat(envelope.WorkerHeartbeat);
break;
case WorkerEnvelope.BodyOneofCase.WorkerFault:
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
CreateWorkerFaultMessage(envelope.WorkerFault),
null);
break;
case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck:
MarkClosed("worker-shutdown-ack");
break;
default:
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
$"Worker sent unexpected envelope body {envelope.BodyCase}.",
null);
break;
}
}
private async Task EnqueueWorkerEventAsync(
WorkerEvent workerEvent,
CancellationToken cancellationToken)
{
if (workerEvent.Event is not null)
{
_metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
}
if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false))
{
return;
}
if (!_events.Writer.TryWrite(workerEvent))
{
_metrics?.QueueOverflow("worker-events");
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker event channel rejected an event.",
null);
}
}
private void CompleteCommand(WorkerEnvelope envelope)
{
string correlationId = envelope.CorrelationId;
if (string.IsNullOrWhiteSpace(correlationId))
{
correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty;
}
if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand))
{
_logger.LogDebug(
"Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.",
SessionId,
correlationId);
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandSucceeded(pendingCommand.Method, duration);
pendingCommand.SetResult(envelope.WorkerCommandReply);
}
private void RemovePendingCommandAsFailed(
string correlationId,
PendingCommand pendingCommand,
WorkerClientErrorCode errorCode,
string message)
{
if (!_pendingCommands.TryRemove(correlationId, out _))
{
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration);
pendingCommand.SetException(new WorkerClientException(errorCode, message));
}
private async Task<WorkerEnvelope> ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase expectedBody,
CancellationToken cancellationToken)
{
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
if (envelope.BodyCase != expectedBody)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
$"Worker handshake expected {expectedBody} but received {envelope.BodyCase}.");
}
return envelope;
}
private void ValidateWorkerHello(WorkerHello workerHello)
{
if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello protocol version does not match the gateway protocol version.");
}
if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal))
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello nonce does not match the gateway nonce.");
}
lock (_syncRoot)
{
_processId = workerHello.WorkerProcessId == 0
? _connection.ProcessHandle?.ProcessId
: workerHello.WorkerProcessId;
}
}
private void MarkReady(WorkerReady ready)
{
lock (_syncRoot)
{
_processId = ready.WorkerProcessId == 0
? _processId ?? _connection.ProcessHandle?.ProcessId
: ready.WorkerProcessId;
_lastHeartbeatAt = _timeProvider.GetUtcNow();
_state = WorkerClientState.Ready;
}
DateTimeOffset readyAt = _timeProvider.GetUtcNow();
DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt;
_metrics?.WorkerStarted(readyAt - launchedAt);
}
private void MarkHeartbeat(WorkerHeartbeat heartbeat)
{
lock (_syncRoot)
{
_lastHeartbeatAt = _timeProvider.GetUtcNow();
if (heartbeat.WorkerProcessId != 0)
{
_processId = heartbeat.WorkerProcessId;
}
}
}
private void MarkClosing()
{
lock (_syncRoot)
{
if (_state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
_state = WorkerClientState.Closing;
}
}
private void MarkClosed(string reason)
{
lock (_syncRoot)
{
if (_state == WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Closed;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
$"Worker client closed because {reason}."));
_metrics?.WorkerStopped(reason);
}
private void SetFaulted(
WorkerClientErrorCode errorCode,
string message,
Exception? exception)
{
WorkerClientException fault = exception is null
? new WorkerClientException(errorCode, message)
: new WorkerClientException(errorCode, message, exception);
lock (_syncRoot)
{
if (_state is WorkerClientState.Faulted or WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Faulted;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete(fault);
_events.Writer.TryComplete(fault);
CompletePendingCommands(fault);
_metrics?.Fault(errorCode.ToString());
_logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message);
}
private void CompletePendingCommands(Exception exception)
{
foreach (KeyValuePair<string, PendingCommand> item in _pendingCommands.ToArray())
{
if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand))
{
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration);
pendingCommand.SetException(exception);
}
}
}
private void TransitionFromCreatedToHandshaking()
{
lock (_syncRoot)
{
if (_state != WorkerClientState.Created)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client cannot start from state {_state}.");
}
_state = WorkerClientState.Handshaking;
}
}
private void EnsureReady()
{
WorkerClientState state = State;
if (state != WorkerClientState.Ready)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client is not ready. Current state is {state}.");
}
}
private bool IsTerminalState()
{
WorkerClientState state = State;
return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted;
}
private async Task EnqueueAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
try
{
await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException exception)
{
throw new WorkerClientException(
WorkerClientErrorCode.WriteFailed,
"Worker outbound channel is closed.",
exception);
}
}
private WorkerEnvelope CreateGatewayHelloEnvelope()
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.GatewayHello = new GatewayHello
{
SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion,
Nonce = _connection.Nonce,
GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback,
});
}
private WorkerEnvelope CreateCommandEnvelope(
string correlationId,
WorkerCommand command)
{
return CreateEnvelope(
correlationId,
envelope => envelope.WorkerCommand = command.Clone());
}
private WorkerEnvelope CreateShutdownEnvelope(
TimeSpan timeout,
string reason)
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.WorkerShutdown = new WorkerShutdown
{
GracePeriod = Duration.FromTimeSpan(timeout),
Reason = reason,
});
}
private WorkerEnvelope CreateEnvelope(
string correlationId,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = _connection.FrameOptions.ProtocolVersion,
SessionId = SessionId,
Sequence = (ulong)Interlocked.Increment(ref _nextSequence),
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static string GetCommandMethod(WorkerCommand command)
{
return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString();
}
private static string CreateWorkerFaultMessage(WorkerFault fault)
{
return string.IsNullOrWhiteSpace(fault.DiagnosticMessage)
? $"Worker faulted with category {fault.Category}."
: $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}";
}
private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken)
{
Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask }
.Where(task => task is not null)
.Cast<Task>()
.ToArray();
if (tasks.Length == 0)
{
return;
}
await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
}
private void ThrowIfDisposed()
{
ObjectDisposedException.ThrowIf(_disposed, this);
}
private sealed class PendingCommand
{
private readonly TaskCompletionSource<WorkerCommandReply> _completion = new(TaskCreationOptions.RunContinuationsAsynchronously);
public PendingCommand(
string correlationId,
string method,
long startTimestamp)
{
CorrelationId = correlationId;
Method = method;
StartTimestamp = startTimestamp;
}
public string CorrelationId { get; }
public string Method { get; }
public long StartTimestamp { get; }
public Task<WorkerCommandReply> Task => _completion.Task;
public void SetResult(WorkerCommandReply reply)
{
_completion.TrySetResult(reply);
}
public void SetException(Exception exception)
{
_completion.TrySetException(exception);
}
}
}
@@ -0,0 +1,38 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientConnection
{
public WorkerClientConnection(
string sessionId,
string nonce,
Stream stream,
WorkerFrameProtocolOptions frameOptions,
WorkerProcessHandle? processHandle = null)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("Session id is required.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(nonce))
{
throw new ArgumentException("Worker nonce is required.", nameof(nonce));
}
SessionId = sessionId;
Nonce = nonce;
Stream = stream ?? throw new ArgumentNullException(nameof(stream));
FrameOptions = frameOptions ?? throw new ArgumentNullException(nameof(frameOptions));
ProcessHandle = processHandle;
}
public string SessionId { get; }
public string Nonce { get; }
public Stream Stream { get; }
public WorkerFrameProtocolOptions FrameOptions { get; }
public WorkerProcessHandle? ProcessHandle { get; }
}
@@ -0,0 +1,14 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientErrorCode
{
InvalidState,
ProtocolViolation,
PipeDisconnected,
CommandTimeout,
WorkerFaulted,
HeartbeatExpired,
ShutdownTimeout,
GatewayShutdown,
WriteFailed,
}
@@ -0,0 +1,23 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientException : Exception
{
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerClientErrorCode ErrorCode { get; }
}
@@ -0,0 +1,24 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientOptions
{
public static readonly TimeSpan DefaultHeartbeatGrace = TimeSpan.FromSeconds(15);
public static readonly TimeSpan DefaultHeartbeatCheckInterval = TimeSpan.FromSeconds(1);
public static readonly TimeSpan DefaultEventChannelFullModeTimeout = TimeSpan.FromSeconds(5);
public WorkerClientOptions()
{
HeartbeatGrace = DefaultHeartbeatGrace;
HeartbeatCheckInterval = DefaultHeartbeatCheckInterval;
EventChannelCapacity = 1_024;
EventChannelFullModeTimeout = DefaultEventChannelFullModeTimeout;
}
public TimeSpan HeartbeatGrace { get; init; }
public TimeSpan HeartbeatCheckInterval { get; init; }
public int EventChannelCapacity { get; init; }
public TimeSpan EventChannelFullModeTimeout { get; init; }
}
@@ -0,0 +1,11 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientState
{
Created,
Handshaking,
Ready,
Closing,
Closed,
Faulted,
}
@@ -0,0 +1,80 @@
using System.Buffers.Binary;
using MxGateway.Server.Configuration;
namespace MxGateway.Server.Workers;
internal static class WorkerExecutableValidator
{
private const ushort ImageFileMachineI386 = 0x014c;
private const ushort ImageFileMachineAmd64 = 0x8664;
private const int DosHeaderSignatureOffset = 0;
private const int PeHeaderOffsetPointer = 0x3c;
private const int PeSignatureSize = 4;
private const int MachineOffsetFromPeHeader = PeSignatureSize;
private const int MinimumHeaderSize = 0x40;
public static void Validate(
string executablePath,
WorkerArchitecture requiredArchitecture)
{
ushort machine = ReadMachineType(executablePath);
ushort expectedMachine = requiredArchitecture switch
{
WorkerArchitecture.X86 => ImageFileMachineI386,
WorkerArchitecture.X64 => ImageFileMachineAmd64,
_ => throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidExecutable,
"Worker executable required architecture is unsupported."),
};
if (machine != expectedMachine)
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidExecutable,
$"Worker executable architecture does not match required {requiredArchitecture} architecture.");
}
}
private static ushort ReadMachineType(string executablePath)
{
byte[] header = new byte[MinimumHeaderSize];
using FileStream stream = File.OpenRead(executablePath);
if (stream.Read(header) < header.Length)
{
throw InvalidExecutable("Worker executable is too small to contain a valid PE header.");
}
if (header[DosHeaderSignatureOffset] != 'M' || header[DosHeaderSignatureOffset + 1] != 'Z')
{
throw InvalidExecutable("Worker executable does not contain an MZ header.");
}
int peHeaderOffset = BinaryPrimitives.ReadInt32LittleEndian(header.AsSpan(PeHeaderOffsetPointer, sizeof(int)));
if (peHeaderOffset < MinimumHeaderSize)
{
throw InvalidExecutable("Worker executable PE header offset is invalid.");
}
byte[] peHeaderBytes = new byte[PeSignatureSize + sizeof(ushort)];
stream.Position = peHeaderOffset;
if (stream.Read(peHeaderBytes) < peHeaderBytes.Length)
{
throw InvalidExecutable("Worker executable PE header is missing.");
}
if (peHeaderBytes[0] != 'P' || peHeaderBytes[1] != 'E' || peHeaderBytes[2] != 0 || peHeaderBytes[3] != 0)
{
throw InvalidExecutable("Worker executable does not contain a PE header.");
}
return BinaryPrimitives.ReadUInt16LittleEndian(
peHeaderBytes.AsSpan(MachineOffsetFromPeHeader, sizeof(ushort)));
}
private static WorkerProcessLaunchException InvalidExecutable(string message)
{
return new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidExecutable,
message);
}
}
@@ -0,0 +1,30 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerProcessCommandLine
{
public WorkerProcessCommandLine(
string executablePath,
IReadOnlyList<string> arguments)
{
ExecutablePath = executablePath;
Arguments = arguments;
}
public string ExecutablePath { get; }
public IReadOnlyList<string> Arguments { get; }
public override string ToString()
{
return string.Join(
" ",
new[] { Quote(ExecutablePath) }.Concat(Arguments.Select(Quote)));
}
private static string Quote(string value)
{
return value.Contains(' ', StringComparison.Ordinal)
? $"\"{value}\""
: value;
}
}
@@ -0,0 +1,28 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerProcessHandle : IDisposable
{
public WorkerProcessHandle(
IWorkerProcess process,
WorkerProcessCommandLine commandLine,
DateTimeOffset launchedAt)
{
Process = process;
ProcessId = process.Id;
CommandLine = commandLine;
LaunchedAt = launchedAt;
}
public IWorkerProcess Process { get; }
public int ProcessId { get; }
public WorkerProcessCommandLine CommandLine { get; }
public DateTimeOffset LaunchedAt { get; }
public void Dispose()
{
Process.Dispose();
}
}
@@ -0,0 +1,13 @@
namespace MxGateway.Server.Workers;
public enum WorkerProcessLaunchErrorCode
{
Unknown = 0,
InvalidRequest = 1,
ExecutableNotFound = 2,
InvalidExecutable = 3,
InvalidWorkingDirectory = 4,
StartFailed = 5,
StartupTimeout = 6,
StartupFailed = 7,
}
@@ -0,0 +1,23 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerProcessLaunchException : Exception
{
public WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerProcessLaunchErrorCode ErrorCode { get; }
}
@@ -0,0 +1,8 @@
namespace MxGateway.Server.Workers;
public sealed record WorkerProcessLaunchRequest(
string SessionId,
string PipeName,
uint ProtocolVersion,
string Nonce,
IDisposable? PipeReservation = null);
@@ -0,0 +1,262 @@
using System.Diagnostics;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Metrics;
namespace MxGateway.Server.Workers;
public sealed class WorkerProcessLauncher : IWorkerProcessLauncher
{
public const string WorkerNonceEnvironmentVariableName = "MXGATEWAY_WORKER_NONCE";
private readonly IWorkerProcessFactory _processFactory;
private readonly IWorkerStartupProbe _startupProbe;
private readonly GatewayMetrics _metrics;
private readonly TimeProvider _timeProvider;
private readonly WorkerOptions _workerOptions;
public WorkerProcessLauncher(
IOptions<GatewayOptions> gatewayOptions,
IWorkerProcessFactory processFactory,
IWorkerStartupProbe startupProbe,
GatewayMetrics metrics,
TimeProvider? timeProvider = null)
{
ArgumentNullException.ThrowIfNull(gatewayOptions);
ArgumentNullException.ThrowIfNull(processFactory);
ArgumentNullException.ThrowIfNull(startupProbe);
ArgumentNullException.ThrowIfNull(metrics);
_workerOptions = gatewayOptions.Value.Worker;
_processFactory = processFactory;
_startupProbe = startupProbe;
_metrics = metrics;
_timeProvider = timeProvider ?? TimeProvider.System;
}
public async Task<WorkerProcessHandle> LaunchAsync(
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken = default)
{
try
{
return await LaunchCoreAsync(request, cancellationToken).ConfigureAwait(false);
}
catch
{
request.PipeReservation?.Dispose();
throw;
}
}
private async Task<WorkerProcessHandle> LaunchCoreAsync(
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken)
{
ValidateRequest(request);
DateTimeOffset startedAt = _timeProvider.GetUtcNow();
ProcessStartInfo startInfo = CreateStartInfo(request, out WorkerProcessCommandLine commandLine);
IWorkerProcess process;
try
{
process = _processFactory.Start(startInfo);
}
catch (Exception exception) when (exception is not WorkerProcessLaunchException)
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.StartFailed,
"Worker process failed to start.",
exception);
}
try
{
using CancellationTokenSource startupTimeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
startupTimeout.CancelAfter(TimeSpan.FromSeconds(_workerOptions.StartupTimeoutSeconds));
await _startupProbe
.WaitUntilReadyAsync(process, request, startupTimeout.Token)
.ConfigureAwait(false);
_metrics.WorkerStarted(_timeProvider.GetUtcNow() - startedAt);
return new WorkerProcessHandle(process, commandLine, startedAt);
}
catch (OperationCanceledException exception) when (!cancellationToken.IsCancellationRequested)
{
KillAndDispose(process, "StartupTimeout");
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.StartupTimeout,
"Worker process did not complete startup before the configured timeout.",
exception);
}
catch (OperationCanceledException)
{
KillAndDispose(process, "LaunchCanceled");
throw;
}
catch (Exception exception) when (exception is not WorkerProcessLaunchException)
{
KillAndDispose(process, "StartupFailed");
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.StartupFailed,
"Worker process failed during startup.",
exception);
}
catch (WorkerProcessLaunchException)
{
KillAndDispose(process, "StartupFailed");
throw;
}
}
private ProcessStartInfo CreateStartInfo(
WorkerProcessLaunchRequest request,
out WorkerProcessCommandLine commandLine)
{
string executablePath = ResolveExecutablePath();
string workingDirectory = ResolveWorkingDirectory(executablePath);
string[] arguments =
[
"--session-id",
request.SessionId,
"--pipe-name",
request.PipeName,
"--protocol-version",
request.ProtocolVersion.ToString(System.Globalization.CultureInfo.InvariantCulture),
];
ProcessStartInfo startInfo = new()
{
FileName = executablePath,
WorkingDirectory = workingDirectory,
UseShellExecute = false,
CreateNoWindow = true,
ErrorDialog = false,
};
foreach (string argument in arguments)
{
startInfo.ArgumentList.Add(argument);
}
startInfo.Environment[WorkerNonceEnvironmentVariableName] = request.Nonce;
commandLine = new WorkerProcessCommandLine(executablePath, arguments);
return startInfo;
}
private string ResolveExecutablePath()
{
string executablePath;
try
{
executablePath = Path.GetFullPath(_workerOptions.ExecutablePath);
}
catch (Exception exception) when (exception is ArgumentException or NotSupportedException or PathTooLongException)
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidExecutable,
"Worker executable path is not a valid filesystem path.",
exception);
}
if (!string.Equals(Path.GetExtension(executablePath), ".exe", StringComparison.OrdinalIgnoreCase))
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidExecutable,
"Worker executable path must point to a .exe file.");
}
if (!File.Exists(executablePath))
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.ExecutableNotFound,
"Worker executable does not exist.");
}
WorkerExecutableValidator.Validate(executablePath, _workerOptions.RequiredArchitecture);
return executablePath;
}
private string ResolveWorkingDirectory(string executablePath)
{
if (string.IsNullOrWhiteSpace(_workerOptions.WorkingDirectory))
{
return Path.GetDirectoryName(executablePath) ?? Environment.CurrentDirectory;
}
string workingDirectory;
try
{
workingDirectory = Path.GetFullPath(_workerOptions.WorkingDirectory);
}
catch (Exception exception) when (exception is ArgumentException or NotSupportedException or PathTooLongException)
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidWorkingDirectory,
"Worker working directory is not a valid filesystem path.",
exception);
}
if (!Directory.Exists(workingDirectory))
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidWorkingDirectory,
"Worker working directory does not exist.");
}
return workingDirectory;
}
private void KillAndDispose(IWorkerProcess process, string reason)
{
try
{
if (!process.HasExited)
{
process.Kill(entireProcessTree: true);
_metrics.WorkerKilled(reason);
}
}
finally
{
process.Dispose();
}
}
private static void ValidateRequest(WorkerProcessLaunchRequest request)
{
if (string.IsNullOrWhiteSpace(request.SessionId))
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidRequest,
"Worker launch requires a session id.");
}
if (string.IsNullOrWhiteSpace(request.PipeName))
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidRequest,
"Worker launch requires a pipe name.");
}
if (request.ProtocolVersion == 0)
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidRequest,
"Worker launch requires a non-zero protocol version.");
}
if (string.IsNullOrWhiteSpace(request.Nonce))
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.InvalidRequest,
"Worker launch requires a nonce.");
}
}
}
@@ -0,0 +1,19 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerProcessStartedProbe : IWorkerStartupProbe
{
public Task WaitUntilReadyAsync(
IWorkerProcess process,
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken)
{
if (process.HasExited)
{
throw new WorkerProcessLaunchException(
WorkerProcessLaunchErrorCode.StartupFailed,
$"Worker process exited before startup completed with exit code {process.ExitCode}.");
}
return Task.CompletedTask;
}
}
@@ -0,0 +1,13 @@
namespace MxGateway.Server.Workers;
public static class WorkerServiceCollectionExtensions
{
public static IServiceCollection AddWorkerProcessLauncher(this IServiceCollection services)
{
services.AddSingleton<IWorkerProcessFactory, SystemWorkerProcessFactory>();
services.AddSingleton<IWorkerStartupProbe, WorkerProcessStartedProbe>();
services.AddSingleton<IWorkerProcessLauncher, WorkerProcessLauncher>();
return services;
}
}
@@ -13,6 +13,15 @@ public sealed class GatewayLogRedactorTests
Assert.DoesNotContain("super-secret", redacted);
}
[Fact]
public void RedactApiKey_RemovesSecretContainingUnderscores()
{
string? redacted = GatewayLogRedactor.RedactApiKey("Bearer mxgw_operator01_super_secret_value");
Assert.Equal("Bearer mxgw_operator01_[redacted]", redacted);
Assert.DoesNotContain("super_secret_value", redacted);
}
[Theory]
[InlineData("AuthenticateUser")]
[InlineData("WriteSecured")]
@@ -0,0 +1,341 @@
using System.IO.Pipes;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Workers;
public sealed class WorkerClientTests
{
private const string SessionId = "session-worker-client";
private const string Nonce = "nonce-worker-client";
private const int WorkerProcessId = 4321;
private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5);
[Fact]
public async Task StartAsync_WithWorkerHelloAndReady_EntersReadyState()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(WorkerProcessId, client.ProcessId);
}
[Fact]
public async Task InvokeAsync_WithMatchingReply_CompletesPendingCommand()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> invokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TestTimeout,
CancellationToken.None);
WorkerEnvelope commandEnvelope = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerCommand, commandEnvelope.BodyCase);
Assert.False(string.IsNullOrWhiteSpace(commandEnvelope.CorrelationId));
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(commandEnvelope.CorrelationId, MxCommandKind.Ping));
WorkerCommandReply reply = await invokeTask.WaitAsync(TestTimeout);
Assert.Equal(commandEnvelope.CorrelationId, reply.Reply.CorrelationId);
Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
}
[Fact]
public async Task InvokeAsync_WithLateReply_IgnoresLateReplyAndKeepsClientReady()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> timedOutInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TimeSpan.FromMilliseconds(50),
CancellationToken.None);
WorkerEnvelope timedOutCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
WorkerClientException exception = await Assert.ThrowsAsync<WorkerClientException>(
async () => await timedOutInvokeTask);
Assert.Equal(WorkerClientErrorCode.CommandTimeout, exception.ErrorCode);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(timedOutCommand.CorrelationId, MxCommandKind.Ping));
await Task.Delay(TimeSpan.FromMilliseconds(50));
Task<WorkerCommandReply> secondInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.GetWorkerInfo),
TestTimeout,
CancellationToken.None);
WorkerEnvelope secondCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(secondCommand.CorrelationId, MxCommandKind.GetWorkerInfo));
WorkerCommandReply reply = await secondInvokeTask.WaitAsync(TestTimeout);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(MxCommandKind.GetWorkerInfo, reply.Reply.Kind);
}
[Fact]
public async Task ReadEventsAsync_WithWorkerEvents_YieldsEventsInPipeOrder()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
using CancellationTokenSource cancellationTokenSource = new(TestTimeout);
await using IAsyncEnumerator<WorkerEvent> events =
client.ReadEventsAsync(cancellationTokenSource.Token).GetAsyncEnumerator(cancellationTokenSource.Token);
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange));
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 12, MxEventFamily.OperationComplete));
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)11, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OnDataChange, events.Current.Event.Family);
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)12, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family);
}
[Fact]
public async Task ReadLoop_WhenPipeDisconnects_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
await pipePair.DisposeWorkerSideAsync();
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
[Fact]
public async Task HeartbeatMonitor_WhenHeartbeatExpires_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(
pipePair,
new WorkerClientOptions
{
HeartbeatGrace = TimeSpan.FromMilliseconds(80),
HeartbeatCheckInterval = TimeSpan.FromMilliseconds(20),
EventChannelCapacity = 8,
});
await CompleteHandshakeAsync(client, pipePair);
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
private static WorkerClient CreateClient(
PipePair pipePair,
WorkerClientOptions? options = null)
{
WorkerFrameProtocolOptions frameOptions = new(SessionId);
WorkerClientConnection connection = new(
SessionId,
Nonce,
pipePair.GatewayStream,
frameOptions);
return new WorkerClient(connection, options);
}
private static async Task CompleteHandshakeAsync(
WorkerClient client,
PipePair pipePair)
{
Task startTask = client.StartAsync(CancellationToken.None);
WorkerEnvelope gatewayHello = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.GatewayHello, gatewayHello.BodyCase);
Assert.Equal(Nonce, gatewayHello.GatewayHello.Nonce);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, gatewayHello.GatewayHello.SupportedProtocolVersion);
await pipePair.WorkerWriter.WriteAsync(CreateWorkerHelloEnvelope());
await pipePair.WorkerWriter.WriteAsync(CreateWorkerReadyEnvelope());
await startTask.WaitAsync(TestTimeout);
}
private static WorkerCommand CreateCommand(MxCommandKind kind)
{
return new WorkerCommand
{
Command = new MxCommand
{
Kind = kind,
},
};
}
private static WorkerEnvelope CreateWorkerHelloEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 1,
envelope => envelope.WorkerHello = new WorkerHello
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
WorkerProcessId = WorkerProcessId,
WorkerVersion = "fake-worker",
});
}
private static WorkerEnvelope CreateWorkerReadyEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 2,
envelope => envelope.WorkerReady = new WorkerReady
{
WorkerProcessId = WorkerProcessId,
MxaccessProgid = "LMXProxy.LMXProxyServer.1",
MxaccessClsid = "{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}",
});
}
private static WorkerEnvelope CreateCommandReplyEnvelope(
string correlationId,
MxCommandKind kind)
{
return CreateWorkerEnvelope(
correlationId,
sequence: 10,
envelope => envelope.WorkerCommandReply = new WorkerCommandReply
{
Reply = new MxCommandReply
{
SessionId = SessionId,
CorrelationId = correlationId,
Kind = kind,
},
});
}
private static WorkerEnvelope CreateEventEnvelope(
ulong sequence,
MxEventFamily family)
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence,
envelope => envelope.WorkerEvent = new WorkerEvent
{
Event = new MxEvent
{
SessionId = SessionId,
Family = family,
WorkerSequence = sequence,
},
});
}
private static WorkerEnvelope CreateWorkerEnvelope(
string correlationId,
ulong sequence,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static async Task WaitUntilAsync(
Func<bool> predicate,
TimeSpan timeout)
{
using CancellationTokenSource cancellationTokenSource = new(timeout);
while (!predicate())
{
await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token);
}
}
private sealed class PipePair : IAsyncDisposable
{
private readonly NamedPipeClientStream _workerStream;
private bool _workerSideDisposed;
private PipePair(
NamedPipeServerStream gatewayStream,
NamedPipeClientStream workerStream)
{
GatewayStream = gatewayStream;
_workerStream = workerStream;
WorkerReader = new WorkerFrameReader(_workerStream, new WorkerFrameProtocolOptions(SessionId));
WorkerWriter = new WorkerFrameWriter(_workerStream, new WorkerFrameProtocolOptions(SessionId));
}
public NamedPipeServerStream GatewayStream { get; }
public WorkerFrameReader WorkerReader { get; }
public WorkerFrameWriter WorkerWriter { get; }
public static async Task<PipePair> CreateAsync()
{
string pipeName = $"mxaccessgw-workerclient-tests-{Guid.NewGuid():N}";
NamedPipeServerStream gatewayStream = new(
pipeName,
PipeDirection.InOut,
maxNumberOfServerInstances: 1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
NamedPipeClientStream workerStream = new(
".",
pipeName,
PipeDirection.InOut,
PipeOptions.Asynchronous);
Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync();
await workerStream.ConnectAsync();
await waitForConnectionTask;
return new PipePair(gatewayStream, workerStream);
}
public async ValueTask DisposeWorkerSideAsync()
{
if (_workerSideDisposed)
{
return;
}
await _workerStream.DisposeAsync();
_workerSideDisposed = true;
}
public async ValueTask DisposeAsync()
{
await DisposeWorkerSideAsync();
await GatewayStream.DisposeAsync();
}
}
}
@@ -0,0 +1,307 @@
using System.Diagnostics;
using Microsoft.Extensions.Options;
using MxGateway.Contracts;
using MxGateway.Server.Configuration;
using MxGateway.Server.Metrics;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Workers;
public sealed class WorkerProcessLauncherTests
{
private const string SessionId = "session-1";
private const string PipeName = "mxaccess-gateway-123-session-1";
private const string Nonce = "super-secret-nonce";
[Fact]
public async Task LaunchAsync_WithValidWorker_StartsProcessWithBootstrapArgumentsAndNonceEnvironment()
{
using TestDirectory directory = TestDirectory.Create();
string executablePath = directory.CreateWorkerExecutable(machine: 0x014c);
FakeWorkerProcess process = new(processId: 1234);
FakePipeReservation pipeReservation = new();
FakeWorkerProcessFactory processFactory = new(process);
GatewayMetrics metrics = new();
WorkerProcessLauncher launcher = CreateLauncher(executablePath, processFactory, new SucceedingStartupProbe(), metrics);
using WorkerProcessHandle handle = await launcher.LaunchAsync(CreateRequest(pipeReservation));
Assert.Equal(1234, handle.ProcessId);
Assert.Same(process, handle.Process);
Assert.NotNull(processFactory.LastStartInfo);
Assert.Equal(Path.GetFullPath(executablePath), processFactory.LastStartInfo.FileName);
Assert.False(processFactory.LastStartInfo.UseShellExecute);
Assert.True(processFactory.LastStartInfo.CreateNoWindow);
Assert.Equal(
["--session-id", SessionId, "--pipe-name", PipeName, "--protocol-version", "1"],
processFactory.LastStartInfo.ArgumentList);
Assert.Equal(Nonce, processFactory.LastStartInfo.Environment[WorkerProcessLauncher.WorkerNonceEnvironmentVariableName]);
Assert.DoesNotContain(Nonce, handle.CommandLine.ToString(), StringComparison.Ordinal);
Assert.DoesNotContain(Nonce, string.Join(" ", handle.CommandLine.Arguments), StringComparison.Ordinal);
Assert.False(pipeReservation.DisposeCalled);
Assert.Equal(1, metrics.GetSnapshot().WorkersRunning);
}
[Fact]
public async Task LaunchAsync_WhenStartupProbeFails_KillsAndDisposesWorker()
{
using TestDirectory directory = TestDirectory.Create();
string executablePath = directory.CreateWorkerExecutable(machine: 0x014c);
FakeWorkerProcess process = new(processId: 1234);
FakePipeReservation pipeReservation = new();
GatewayMetrics metrics = new();
WorkerProcessLauncher launcher = CreateLauncher(
executablePath,
new FakeWorkerProcessFactory(process),
new FailingStartupProbe(),
metrics);
WorkerProcessLaunchException exception =
await Assert.ThrowsAsync<WorkerProcessLaunchException>(
async () => await launcher.LaunchAsync(CreateRequest(pipeReservation)));
Assert.Equal(WorkerProcessLaunchErrorCode.StartupFailed, exception.ErrorCode);
Assert.True(process.KillCalled);
Assert.True(process.DisposeCalled);
Assert.True(pipeReservation.DisposeCalled);
Assert.Equal(1, metrics.GetSnapshot().WorkerKills);
}
[Fact]
public async Task LaunchAsync_WhenStartupTimesOut_KillsAndDisposesWorker()
{
using TestDirectory directory = TestDirectory.Create();
string executablePath = directory.CreateWorkerExecutable(machine: 0x014c);
FakeWorkerProcess process = new(processId: 1234);
GatewayMetrics metrics = new();
WorkerProcessLauncher launcher = CreateLauncher(
executablePath,
new FakeWorkerProcessFactory(process),
new WaitingStartupProbe(),
metrics,
startupTimeoutSeconds: 1);
WorkerProcessLaunchException exception =
await Assert.ThrowsAsync<WorkerProcessLaunchException>(
async () => await launcher.LaunchAsync(CreateRequest()));
Assert.Equal(WorkerProcessLaunchErrorCode.StartupTimeout, exception.ErrorCode);
Assert.True(process.KillCalled);
Assert.True(process.DisposeCalled);
Assert.Equal(1, metrics.GetSnapshot().WorkerKills);
}
[Fact]
public async Task LaunchAsync_WhenExecutableDoesNotExist_FailsBeforeStartingProcess()
{
using TestDirectory directory = TestDirectory.Create();
string executablePath = Path.Combine(directory.Path, "missing-worker.exe");
FakeWorkerProcessFactory processFactory = new(new FakeWorkerProcess(processId: 1234));
WorkerProcessLauncher launcher = CreateLauncher(executablePath, processFactory, new SucceedingStartupProbe());
WorkerProcessLaunchException exception =
await Assert.ThrowsAsync<WorkerProcessLaunchException>(
async () => await launcher.LaunchAsync(CreateRequest()));
Assert.Equal(WorkerProcessLaunchErrorCode.ExecutableNotFound, exception.ErrorCode);
Assert.Null(processFactory.LastStartInfo);
}
[Fact]
public async Task LaunchAsync_WhenExecutableArchitectureDoesNotMatch_FailsBeforeStartingProcess()
{
using TestDirectory directory = TestDirectory.Create();
string executablePath = directory.CreateWorkerExecutable(machine: 0x8664);
FakeWorkerProcessFactory processFactory = new(new FakeWorkerProcess(processId: 1234));
WorkerProcessLauncher launcher = CreateLauncher(executablePath, processFactory, new SucceedingStartupProbe());
WorkerProcessLaunchException exception =
await Assert.ThrowsAsync<WorkerProcessLaunchException>(
async () => await launcher.LaunchAsync(CreateRequest()));
Assert.Equal(WorkerProcessLaunchErrorCode.InvalidExecutable, exception.ErrorCode);
Assert.Null(processFactory.LastStartInfo);
}
[Fact]
public async Task LaunchAsync_WhenWorkerAlreadyExited_FailsAndDisposesWorkerWithoutKill()
{
using TestDirectory directory = TestDirectory.Create();
string executablePath = directory.CreateWorkerExecutable(machine: 0x014c);
FakeWorkerProcess process = new(processId: 1234)
{
HasExited = true,
ExitCode = 42,
};
WorkerProcessLauncher launcher = CreateLauncher(
executablePath,
new FakeWorkerProcessFactory(process),
new WorkerProcessStartedProbe());
WorkerProcessLaunchException exception =
await Assert.ThrowsAsync<WorkerProcessLaunchException>(
async () => await launcher.LaunchAsync(CreateRequest()));
Assert.Equal(WorkerProcessLaunchErrorCode.StartupFailed, exception.ErrorCode);
Assert.False(process.KillCalled);
Assert.True(process.DisposeCalled);
}
private static WorkerProcessLauncher CreateLauncher(
string executablePath,
IWorkerProcessFactory processFactory,
IWorkerStartupProbe startupProbe,
GatewayMetrics? metrics = null,
int startupTimeoutSeconds = 30)
{
GatewayOptions options = new()
{
Worker = new WorkerOptions
{
ExecutablePath = executablePath,
RequiredArchitecture = WorkerArchitecture.X86,
StartupTimeoutSeconds = startupTimeoutSeconds,
},
};
return new WorkerProcessLauncher(
Options.Create(options),
processFactory,
startupProbe,
metrics ?? new GatewayMetrics());
}
private static WorkerProcessLaunchRequest CreateRequest(IDisposable? pipeReservation = null)
{
return new WorkerProcessLaunchRequest(
SessionId,
PipeName,
GatewayContractInfo.WorkerProtocolVersion,
Nonce,
pipeReservation);
}
private sealed class FakeWorkerProcessFactory(IWorkerProcess process) : IWorkerProcessFactory
{
public ProcessStartInfo? LastStartInfo { get; private set; }
public IWorkerProcess Start(ProcessStartInfo startInfo)
{
LastStartInfo = startInfo;
return process;
}
}
private sealed class FakeWorkerProcess(int processId) : IWorkerProcess
{
public int Id { get; } = processId;
public bool HasExited { get; set; }
public int? ExitCode { get; set; }
public bool DisposeCalled { get; private set; }
public bool KillCalled { get; private set; }
public ValueTask WaitForExitAsync(CancellationToken cancellationToken)
{
return ValueTask.CompletedTask;
}
public void Kill(bool entireProcessTree)
{
Assert.True(entireProcessTree);
KillCalled = true;
HasExited = true;
}
public void Dispose()
{
DisposeCalled = true;
}
}
private sealed class SucceedingStartupProbe : IWorkerStartupProbe
{
public Task WaitUntilReadyAsync(
IWorkerProcess process,
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
}
private sealed class FailingStartupProbe : IWorkerStartupProbe
{
public Task WaitUntilReadyAsync(
IWorkerProcess process,
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken)
{
throw new InvalidOperationException("Fake worker startup failed.");
}
}
private sealed class WaitingStartupProbe : IWorkerStartupProbe
{
public async Task WaitUntilReadyAsync(
IWorkerProcess process,
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken)
{
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
}
}
private sealed class FakePipeReservation : IDisposable
{
public bool DisposeCalled { get; private set; }
public void Dispose()
{
DisposeCalled = true;
}
}
private sealed class TestDirectory : IDisposable
{
private TestDirectory(string path)
{
Path = path;
}
public string Path { get; }
public static TestDirectory Create()
{
string path = System.IO.Path.Combine(System.IO.Path.GetTempPath(), $"mxgateway-tests-{Guid.NewGuid():N}");
Directory.CreateDirectory(path);
return new TestDirectory(path);
}
public string CreateWorkerExecutable(ushort machine)
{
string path = System.IO.Path.Combine(Path, "MxGateway.Worker.exe");
byte[] bytes = new byte[0x100];
bytes[0] = (byte)'M';
bytes[1] = (byte)'Z';
BitConverter.GetBytes(0x80).CopyTo(bytes, 0x3c);
bytes[0x80] = (byte)'P';
bytes[0x81] = (byte)'E';
bytes[0x82] = 0;
bytes[0x83] = 0;
BitConverter.GetBytes(machine).CopyTo(bytes, 0x84);
File.WriteAllBytes(path, bytes);
return path;
}
public void Dispose()
{
Directory.Delete(Path, recursive: true);
}
}
}
@@ -0,0 +1,242 @@
using System.Text.Json;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Tests.Security.Authentication;
public sealed class ApiKeyAdminCliRunnerTests
{
[Fact]
public async Task CreateKeyAsync_CreatesAuthenticatingKeyAndAudits()
{
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
StringWriter output = new();
await runner.RunAsync(
new ApiKeyAdminCommand(
Kind: ApiKeyAdminCommandKind.CreateKey,
Json: true,
SqlitePath: null,
Pepper: null,
KeyId: "operator01",
DisplayName: "Operator",
Scopes: new HashSet<string>(StringComparer.Ordinal) { "session:open", "events:read" }),
output,
CancellationToken.None);
string apiKey = ReadApiKey(output.ToString());
IApiKeyVerifier verifier = services.GetRequiredService<IApiKeyVerifier>();
ApiKeyVerificationResult verification = await verifier.VerifyAsync($"Bearer {apiKey}", CancellationToken.None);
Assert.True(verification.Succeeded);
Assert.NotNull(verification.Identity);
Assert.Equal("operator01", verification.Identity.KeyId);
Assert.Contains("session:open", verification.Identity.Scopes);
IReadOnlyList<ApiKeyAuditRecord> auditRecords = await services
.GetRequiredService<IApiKeyAuditStore>()
.ListRecentAsync(10, CancellationToken.None);
Assert.Contains(auditRecords, record => record.EventType == "create-key" && record.KeyId == "operator01");
}
[Fact]
public async Task ListKeysAsync_DoesNotPrintRawSecret()
{
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
string apiKey = await CreateKeyAsync(runner, "operator01");
StringWriter listOutput = new();
await runner.RunAsync(
new ApiKeyAdminCommand(
Kind: ApiKeyAdminCommandKind.ListKeys,
Json: true,
SqlitePath: null,
Pepper: null,
KeyId: null,
DisplayName: null,
Scopes: new HashSet<string>(StringComparer.Ordinal)),
listOutput,
CancellationToken.None);
string listJson = listOutput.ToString();
Assert.Contains("operator01", listJson, StringComparison.Ordinal);
Assert.DoesNotContain(apiKey, listJson, StringComparison.Ordinal);
Assert.DoesNotContain(ApiKeySecret(apiKey), listJson, StringComparison.Ordinal);
Assert.DoesNotContain("secret_hash", listJson, StringComparison.OrdinalIgnoreCase);
}
[Fact]
public async Task RevokeKeyAsync_RevokedKeyFailsVerificationAndAudits()
{
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
string apiKey = await CreateKeyAsync(runner, "operator01");
await runner.RunAsync(
new ApiKeyAdminCommand(
Kind: ApiKeyAdminCommandKind.RevokeKey,
Json: true,
SqlitePath: null,
Pepper: null,
KeyId: "operator01",
DisplayName: null,
Scopes: new HashSet<string>(StringComparer.Ordinal)),
TextWriter.Null,
CancellationToken.None);
ApiKeyVerificationResult verification = await services
.GetRequiredService<IApiKeyVerifier>()
.VerifyAsync($"Bearer {apiKey}", CancellationToken.None);
Assert.False(verification.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.KeyRevoked, verification.Failure);
IReadOnlyList<ApiKeyAuditRecord> auditRecords = await services
.GetRequiredService<IApiKeyAuditStore>()
.ListRecentAsync(10, CancellationToken.None);
Assert.Contains(auditRecords, record => record.EventType == "revoke-key" && record.KeyId == "operator01");
}
[Fact]
public async Task RotateKeyAsync_PrintsNewSecretOnceAndInvalidatesOldSecret()
{
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
string oldApiKey = await CreateKeyAsync(runner, "operator01");
StringWriter rotateOutput = new();
await runner.RunAsync(
new ApiKeyAdminCommand(
Kind: ApiKeyAdminCommandKind.RotateKey,
Json: true,
SqlitePath: null,
Pepper: null,
KeyId: "operator01",
DisplayName: null,
Scopes: new HashSet<string>(StringComparer.Ordinal)),
rotateOutput,
CancellationToken.None);
string rotateJson = rotateOutput.ToString();
string newApiKey = ReadApiKey(rotateJson);
Assert.NotEqual(oldApiKey, newApiKey);
Assert.Equal(1, CountOccurrences(rotateJson, newApiKey));
IApiKeyVerifier verifier = services.GetRequiredService<IApiKeyVerifier>();
ApiKeyVerificationResult oldVerification = await verifier.VerifyAsync($"Bearer {oldApiKey}", CancellationToken.None);
ApiKeyVerificationResult newVerification = await verifier.VerifyAsync($"Bearer {newApiKey}", CancellationToken.None);
Assert.False(oldVerification.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.SecretMismatch, oldVerification.Failure);
Assert.True(newVerification.Succeeded);
}
[Fact]
public async Task CreateKeyAsync_PrintsRawSecretExactlyOnce()
{
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
StringWriter output = new();
await runner.RunAsync(
new ApiKeyAdminCommand(
Kind: ApiKeyAdminCommandKind.CreateKey,
Json: true,
SqlitePath: null,
Pepper: null,
KeyId: "operator01",
DisplayName: "Operator",
Scopes: new HashSet<string>(StringComparer.Ordinal)),
output,
CancellationToken.None);
string json = output.ToString();
string apiKey = ReadApiKey(json);
Assert.Equal(1, CountOccurrences(json, apiKey));
Assert.Equal(1, CountOccurrences(json, ApiKeySecret(apiKey)));
}
private static async Task<string> CreateKeyAsync(ApiKeyAdminCliRunner runner, string keyId)
{
StringWriter output = new();
await runner.RunAsync(
new ApiKeyAdminCommand(
Kind: ApiKeyAdminCommandKind.CreateKey,
Json: true,
SqlitePath: null,
Pepper: null,
KeyId: keyId,
DisplayName: "Operator",
Scopes: new HashSet<string>(StringComparer.Ordinal) { "session:open" }),
output,
CancellationToken.None);
return ReadApiKey(output.ToString());
}
private static ServiceProvider BuildServices(string databasePath)
{
IConfigurationRoot configuration = new ConfigurationBuilder()
.AddInMemoryCollection(
new Dictionary<string, string?>
{
["MxGateway:Authentication:SqlitePath"] = databasePath,
["MxGateway:ApiKeyPepper"] = "test-pepper"
})
.Build();
ServiceCollection services = new();
services.AddSingleton<IConfiguration>(configuration);
services.AddGatewayConfiguration();
services.AddSqliteAuthStore();
return services.BuildServiceProvider(validateScopes: true);
}
private static string CreateTempDatabasePath()
{
string directory = Path.Combine(Path.GetTempPath(), "mxgateway-auth-cli-tests", Guid.NewGuid().ToString("N"));
Directory.CreateDirectory(directory);
return Path.Combine(directory, "gateway-auth.db");
}
private static string ReadApiKey(string json)
{
using JsonDocument document = JsonDocument.Parse(json);
return document.RootElement.GetProperty("ApiKey").GetString()
?? throw new InvalidOperationException("API key was not present in command output.");
}
private static string ApiKeySecret(string apiKey)
{
string[] parts = apiKey.Split('_', 3);
return parts[2];
}
private static int CountOccurrences(string value, string pattern)
{
int count = 0;
int index = 0;
while ((index = value.IndexOf(pattern, index, StringComparison.Ordinal)) >= 0)
{
count++;
index += pattern.Length;
}
return count;
}
}
@@ -0,0 +1,70 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Tests.Security.Authentication;
public sealed class ApiKeyAdminCommandLineParserTests
{
[Fact]
public void Parse_NonApiKeyCommand_ReturnsNotApiKeyCommand()
{
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(["--urls=http://localhost:5000"]);
Assert.False(result.IsApiKeyCommand);
Assert.Null(result.Command);
}
[Fact]
public void Parse_CreateKeyCommand_ReturnsOptions()
{
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(
[
"apikey",
"create-key",
"--key-id",
"operator01",
"--display-name",
"Operator",
"--scopes",
"session:open,events:read",
"--sqlite-path",
"auth.db",
"--pepper",
"pepper",
"--json"
]);
Assert.True(result.IsApiKeyCommand);
Assert.Null(result.Error);
Assert.NotNull(result.Command);
Assert.Equal(ApiKeyAdminCommandKind.CreateKey, result.Command.Kind);
Assert.True(result.Command.Json);
Assert.Equal("operator01", result.Command.KeyId);
Assert.Equal("Operator", result.Command.DisplayName);
Assert.Equal("auth.db", result.Command.SqlitePath);
Assert.Equal("pepper", result.Command.Pepper);
Assert.Contains("session:open", result.Command.Scopes);
Assert.Contains("events:read", result.Command.Scopes);
}
[Fact]
public void Parse_CreateKeyWithoutDisplayName_ReturnsError()
{
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(
["apikey", "create-key", "--key-id", "operator01"]);
Assert.True(result.IsApiKeyCommand);
Assert.Null(result.Command);
Assert.Contains("--display-name", result.Error, StringComparison.Ordinal);
}
[Fact]
public void Parse_KeyIdWithUnderscore_ReturnsError()
{
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(
["apikey", "revoke-key", "--key-id", "operator_01"]);
Assert.True(result.IsApiKeyCommand);
Assert.Null(result.Command);
Assert.Contains("letters, numbers, periods, and hyphens", result.Error, StringComparison.Ordinal);
}
}
@@ -0,0 +1,38 @@
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Tests.Security.Authentication;
public sealed class ApiKeyParserTests
{
[Fact]
public void TryParseAuthorizationHeader_ValidBearerToken_ReturnsKeyIdAndSecret()
{
ApiKeyParser parser = new();
bool parsed = parser.TryParseAuthorizationHeader(
"Bearer mxgw_operator01_secret_value",
out ParsedApiKey? apiKey);
Assert.True(parsed);
Assert.NotNull(apiKey);
Assert.Equal("operator01", apiKey.KeyId);
Assert.Equal("secret_value", apiKey.Secret);
}
[Theory]
[InlineData(null)]
[InlineData("")]
[InlineData("mxgw_operator01_secret")]
[InlineData("Bearer not-a-gateway-key")]
[InlineData("Bearer mxgw__secret")]
[InlineData("Bearer mxgw_operator01_")]
public void TryParseAuthorizationHeader_MalformedToken_ReturnsFalse(string? authorizationHeader)
{
ApiKeyParser parser = new();
bool parsed = parser.TryParseAuthorizationHeader(authorizationHeader, out ParsedApiKey? apiKey);
Assert.False(parsed);
Assert.Null(apiKey);
}
}
@@ -0,0 +1,62 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Tests.Security.Authentication;
public sealed class ApiKeySecretHasherTests
{
[Fact]
public void HashSecret_SamePepperAndSecret_ReturnsSameHash()
{
ApiKeySecretHasher hasher = CreateHasher("pepper-one");
byte[] firstHash = hasher.HashSecret("raw-secret");
byte[] secondHash = hasher.HashSecret("raw-secret");
Assert.Equal(firstHash, secondHash);
Assert.NotEqual("raw-secret"u8.ToArray(), firstHash);
}
[Fact]
public void HashSecret_DifferentPepper_ReturnsDifferentHash()
{
byte[] firstHash = CreateHasher("pepper-one").HashSecret("raw-secret");
byte[] secondHash = CreateHasher("pepper-two").HashSecret("raw-secret");
Assert.NotEqual(firstHash, secondHash);
}
[Fact]
public void HashSecret_MissingPepper_Throws()
{
ApiKeySecretHasher hasher = CreateHasher(pepper: null);
Assert.Throws<ApiKeyPepperUnavailableException>(() => hasher.HashSecret("raw-secret"));
}
private static ApiKeySecretHasher CreateHasher(string? pepper)
{
Dictionary<string, string?> values = [];
if (pepper is not null)
{
values["TestPepper"] = pepper;
}
IConfigurationRoot configuration = new ConfigurationBuilder()
.AddInMemoryCollection(values)
.Build();
GatewayOptions options = new()
{
Authentication = new AuthenticationOptions
{
PepperSecretName = "TestPepper"
}
};
return new ApiKeySecretHasher(configuration, Options.Create(options));
}
}
@@ -0,0 +1,193 @@
using System.Text.Json;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
namespace MxGateway.Tests.Security.Authentication;
public sealed class ApiKeyVerifierTests
{
[Fact]
public async Task VerifyAsync_ValidKey_ReturnsIdentityAndScopes()
{
ApiKeySecretHasher hasher = CreateHasher("pepper");
FakeApiKeyStore store = new(CreateRecord(hasher, revokedUtc: null));
ApiKeyVerifier verifier = new(new ApiKeyParser(), hasher, store);
ApiKeyVerificationResult result = await verifier.VerifyAsync(
"Bearer mxgw_operator01_correct-secret",
CancellationToken.None);
Assert.True(result.Succeeded);
Assert.NotNull(result.Identity);
Assert.Equal("operator01", result.Identity.KeyId);
Assert.Equal("Operator Key", result.Identity.DisplayName);
Assert.Contains("session:open", result.Identity.Scopes);
Assert.Contains("events:read", result.Identity.Scopes);
Assert.True(store.MarkedUsed);
}
[Fact]
public async Task VerifyAsync_ValidKey_DoesNotExposeRawSecretInResult()
{
ApiKeySecretHasher hasher = CreateHasher("pepper");
FakeApiKeyStore store = new(CreateRecord(hasher, revokedUtc: null));
ApiKeyVerifier verifier = new(new ApiKeyParser(), hasher, store);
ApiKeyVerificationResult result = await verifier.VerifyAsync(
"Bearer mxgw_operator01_correct-secret",
CancellationToken.None);
string serialized = JsonSerializer.Serialize(result);
Assert.DoesNotContain("correct-secret", serialized, StringComparison.Ordinal);
}
[Theory]
[InlineData(null)]
[InlineData("Bearer mxgw_operator01")]
[InlineData("Bearer wrong")]
public async Task VerifyAsync_MalformedKey_FailsUnauthenticated(string? authorizationHeader)
{
ApiKeyVerifier verifier = new(
new ApiKeyParser(),
CreateHasher("pepper"),
new FakeApiKeyStore(storedKey: null));
ApiKeyVerificationResult result = await verifier.VerifyAsync(
authorizationHeader,
CancellationToken.None);
Assert.False(result.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.MissingOrMalformedCredentials, result.Failure);
}
[Fact]
public async Task VerifyAsync_UnknownKey_Fails()
{
ApiKeyVerifier verifier = new(
new ApiKeyParser(),
CreateHasher("pepper"),
new FakeApiKeyStore(storedKey: null));
ApiKeyVerificationResult result = await verifier.VerifyAsync(
"Bearer mxgw_missing_secret",
CancellationToken.None);
Assert.False(result.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.KeyNotFound, result.Failure);
}
[Fact]
public async Task VerifyAsync_WrongSecret_Fails()
{
ApiKeySecretHasher hasher = CreateHasher("pepper");
FakeApiKeyStore store = new(CreateRecord(hasher, revokedUtc: null));
ApiKeyVerifier verifier = new(new ApiKeyParser(), hasher, store);
ApiKeyVerificationResult result = await verifier.VerifyAsync(
"Bearer mxgw_operator01_wrong-secret",
CancellationToken.None);
Assert.False(result.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.SecretMismatch, result.Failure);
Assert.False(store.MarkedUsed);
}
[Fact]
public async Task VerifyAsync_RevokedKey_Fails()
{
ApiKeySecretHasher hasher = CreateHasher("pepper");
FakeApiKeyStore store = new(CreateRecord(hasher, DateTimeOffset.UtcNow));
ApiKeyVerifier verifier = new(new ApiKeyParser(), hasher, store);
ApiKeyVerificationResult result = await verifier.VerifyAsync(
"Bearer mxgw_operator01_correct-secret",
CancellationToken.None);
Assert.False(result.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.KeyRevoked, result.Failure);
Assert.False(store.MarkedUsed);
}
[Fact]
public async Task VerifyAsync_MissingPepper_Fails()
{
FakeApiKeyStore store = new(CreateRecord(CreateHasher("pepper"), revokedUtc: null));
ApiKeyVerifier verifier = new(new ApiKeyParser(), CreateHasher(pepper: null), store);
ApiKeyVerificationResult result = await verifier.VerifyAsync(
"Bearer mxgw_operator01_correct-secret",
CancellationToken.None);
Assert.False(result.Succeeded);
Assert.Equal(ApiKeyVerificationFailure.PepperUnavailable, result.Failure);
}
private static ApiKeyRecord CreateRecord(ApiKeySecretHasher hasher, DateTimeOffset? revokedUtc)
{
return new ApiKeyRecord(
KeyId: "operator01",
KeyPrefix: "mxgw_operator01",
SecretHash: hasher.HashSecret("correct-secret"),
DisplayName: "Operator Key",
Scopes: new HashSet<string>(StringComparer.Ordinal)
{
"session:open",
"events:read"
},
CreatedUtc: DateTimeOffset.UtcNow,
LastUsedUtc: null,
RevokedUtc: revokedUtc);
}
private static ApiKeySecretHasher CreateHasher(string? pepper)
{
Dictionary<string, string?> values = [];
if (pepper is not null)
{
values["TestPepper"] = pepper;
}
IConfigurationRoot configuration = new ConfigurationBuilder()
.AddInMemoryCollection(values)
.Build();
GatewayOptions options = new()
{
Authentication = new AuthenticationOptions
{
PepperSecretName = "TestPepper"
}
};
return new ApiKeySecretHasher(configuration, Options.Create(options));
}
private sealed class FakeApiKeyStore(ApiKeyRecord? storedKey) : IApiKeyStore
{
public bool MarkedUsed { get; private set; }
public Task<ApiKeyRecord?> FindByKeyIdAsync(string keyId, CancellationToken cancellationToken)
{
return Task.FromResult(storedKey?.KeyId == keyId ? storedKey : null);
}
public Task<ApiKeyRecord?> FindActiveByKeyIdAsync(string keyId, CancellationToken cancellationToken)
{
return Task.FromResult(
storedKey?.KeyId == keyId && storedKey.RevokedUtc is null
? storedKey
: null);
}
public Task MarkKeyUsedAsync(string keyId, DateTimeOffset usedUtc, CancellationToken cancellationToken)
{
MarkedUsed = storedKey?.KeyId == keyId;
return Task.CompletedTask;
}
}
}
@@ -0,0 +1,267 @@
using Grpc.Core;
using Microsoft.Extensions.Options;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Security.Authentication;
using MxGateway.Server.Security.Authorization;
namespace MxGateway.Tests.Security.Authorization;
public sealed class GatewayGrpcAuthorizationInterceptorTests
{
[Fact]
public async Task UnaryServerHandler_MissingApiKey_ReturnsUnauthenticated()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(
ApiKeyVerificationFailure.MissingOrMalformedCredentials)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
new TestServerCallContext([]),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
Assert.DoesNotContain("secret", exception.Status.Detail, StringComparison.OrdinalIgnoreCase);
}
[Fact]
public async Task UnaryServerHandler_InvalidApiKey_DoesNotExposeRawCredentialInStatus()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.SecretMismatch)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_super-secret"),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
Assert.DoesNotContain("super-secret", exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task UnaryServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _) => Task.FromResult(new OpenSessionReply())));
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
Assert.Contains(GatewayScopes.SessionOpen, exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task UnaryServerHandler_ValidApiKeyWithScope_SetsRequestIdentity()
{
GatewayRequestIdentityAccessor identityAccessor = new();
ApiKeyIdentity? identitySeenByHandler = null;
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
identityAccessor);
OpenSessionReply reply = await interceptor.UnaryServerHandler(
new OpenSessionRequest(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _) =>
{
identitySeenByHandler = identityAccessor.Current;
return Task.FromResult(new OpenSessionReply { SessionId = "session-1" });
});
Assert.Equal("session-1", reply.SessionId);
Assert.NotNull(identitySeenByHandler);
Assert.Equal("operator01", identitySeenByHandler.KeyId);
Assert.Null(identityAccessor.Current);
}
[Fact]
public async Task ServerStreamingServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
{
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
new GatewayRequestIdentityAccessor());
RpcException exception = await Assert.ThrowsAsync<RpcException>(
() => interceptor.ServerStreamingServerHandler(
new StreamEventsRequest(),
new TestServerStreamWriter<MxEvent>(),
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
(_, _, _) => Task.CompletedTask));
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
Assert.Contains(GatewayScopes.EventsRead, exception.Status.Detail, StringComparison.Ordinal);
}
[Fact]
public async Task ServerStreamingServerHandler_ValidApiKeyWithScope_AllowsStream()
{
GatewayRequestIdentityAccessor identityAccessor = new();
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
identityAccessor);
TestServerStreamWriter<MxEvent> streamWriter = new();
await interceptor.ServerStreamingServerHandler(
new StreamEventsRequest(),
streamWriter,
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
async (_, writer, _) =>
{
Assert.Equal("operator01", identityAccessor.Current?.KeyId);
await writer.WriteAsync(new MxEvent { SessionId = "session-1" });
});
MxEvent eventMessage = Assert.Single(streamWriter.Messages);
Assert.Equal("session-1", eventMessage.SessionId);
Assert.Null(identityAccessor.Current);
}
[Fact]
public async Task UnaryServerHandler_AuthenticationDisabled_SkipsApiKeyVerification()
{
GatewayRequestIdentityAccessor identityAccessor = new();
FakeApiKeyVerifier verifier = new(ApiKeyVerificationResult.Fail(
ApiKeyVerificationFailure.MissingOrMalformedCredentials));
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
verifier,
identityAccessor,
AuthenticationMode.Disabled);
OpenSessionReply reply = await interceptor.UnaryServerHandler(
new OpenSessionRequest(),
new TestServerCallContext([]),
(_, _) => Task.FromResult(new OpenSessionReply { SessionId = "session-1" }));
Assert.Equal("session-1", reply.SessionId);
Assert.False(verifier.WasCalled);
Assert.Null(identityAccessor.Current);
}
private static GatewayGrpcAuthorizationInterceptor CreateInterceptor(
IApiKeyVerifier apiKeyVerifier,
IGatewayRequestIdentityAccessor identityAccessor,
AuthenticationMode authenticationMode = AuthenticationMode.ApiKey)
{
return new GatewayGrpcAuthorizationInterceptor(
apiKeyVerifier,
new GatewayGrpcScopeResolver(),
identityAccessor,
Options.Create(new GatewayOptions
{
Authentication = new AuthenticationOptions
{
Mode = authenticationMode
}
}));
}
private static ApiKeyVerificationResult SuccessWithScopes(params string[] scopes)
{
return ApiKeyVerificationResult.Success(new ApiKeyIdentity(
KeyId: "operator01",
KeyPrefix: "mxgw_operator01",
DisplayName: "Operator Key",
Scopes: new HashSet<string>(scopes, StringComparer.Ordinal)));
}
private static TestServerCallContext ContextWithAuthorization(string authorizationHeader)
{
return new TestServerCallContext([new Metadata.Entry("authorization", authorizationHeader)]);
}
private sealed class FakeApiKeyVerifier(ApiKeyVerificationResult result) : IApiKeyVerifier
{
public bool WasCalled { get; private set; }
public string? LastAuthorizationHeader { get; private set; }
public Task<ApiKeyVerificationResult> VerifyAsync(
string? authorizationHeader,
CancellationToken cancellationToken)
{
WasCalled = true;
LastAuthorizationHeader = authorizationHeader;
return Task.FromResult(result);
}
}
private sealed class TestServerStreamWriter<T> : IServerStreamWriter<T>
{
public List<T> Messages { get; } = [];
public WriteOptions? WriteOptions { get; set; }
public Task WriteAsync(T message)
{
Messages.Add(message);
return Task.CompletedTask;
}
}
private sealed class TestServerCallContext(
Metadata requestHeaders,
CancellationToken cancellationToken = default) : ServerCallContext
{
private readonly Metadata responseTrailers = [];
private readonly Dictionary<object, object> userState = [];
private Status status;
private WriteOptions? writeOptions;
protected override string MethodCore => "/mxaccess_gateway.v1.MxAccessGateway/Test";
protected override string HostCore => "localhost";
protected override string PeerCore => "ipv4:127.0.0.1:5000";
protected override DateTime DeadlineCore => DateTime.UtcNow.AddMinutes(1);
protected override Metadata RequestHeadersCore => requestHeaders;
protected override CancellationToken CancellationTokenCore => cancellationToken;
protected override Metadata ResponseTrailersCore => responseTrailers;
protected override Status StatusCore
{
get => status;
set => status = value;
}
protected override WriteOptions? WriteOptionsCore
{
get => writeOptions;
set => writeOptions = value;
}
protected override AuthContext AuthContextCore { get; } = new(
string.Empty,
new Dictionary<string, List<AuthProperty>>(StringComparer.Ordinal));
protected override IDictionary<object, object> UserStateCore => userState;
protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
{
return Task.CompletedTask;
}
protected override ContextPropagationToken CreatePropagationTokenCore(
ContextPropagationOptions? options)
{
throw new NotSupportedException();
}
}
}
@@ -0,0 +1,54 @@
using MxGateway.Contracts.Proto;
using MxGateway.Server.Security.Authorization;
namespace MxGateway.Tests.Security.Authorization;
public sealed class GatewayGrpcScopeResolverTests
{
[Theory]
[InlineData(typeof(OpenSessionRequest), GatewayScopes.SessionOpen)]
[InlineData(typeof(CloseSessionRequest), GatewayScopes.SessionClose)]
[InlineData(typeof(StreamEventsRequest), GatewayScopes.EventsRead)]
public void ResolveRequiredScope_KnownRpcRequest_ReturnsExpectedScope(
Type requestType,
string expectedScope)
{
GatewayGrpcScopeResolver resolver = new();
object request = Activator.CreateInstance(requestType)!;
string scope = resolver.ResolveRequiredScope(request);
Assert.Equal(expectedScope, scope);
}
[Theory]
[InlineData(MxCommandKind.Register, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.AddItem, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.Advise, GatewayScopes.InvokeRead)]
[InlineData(MxCommandKind.Write, GatewayScopes.InvokeWrite)]
[InlineData(MxCommandKind.Write2, GatewayScopes.InvokeWrite)]
[InlineData(MxCommandKind.WriteSecured, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.WriteSecured2, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.AuthenticateUser, GatewayScopes.InvokeSecure)]
[InlineData(MxCommandKind.ArchestraUserToId, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.GetSessionState, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.GetWorkerInfo, GatewayScopes.MetadataRead)]
[InlineData(MxCommandKind.DrainEvents, GatewayScopes.EventsRead)]
[InlineData(MxCommandKind.ShutdownWorker, GatewayScopes.Admin)]
public void ResolveRequiredScope_InvokeCommand_ReturnsExpectedScope(
MxCommandKind commandKind,
string expectedScope)
{
GatewayGrpcScopeResolver resolver = new();
string scope = resolver.ResolveRequiredScope(new MxCommandRequest
{
Command = new MxCommand
{
Kind = commandKind
}
});
Assert.Equal(expectedScope, scope);
}
}
@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Tests.Bootstrap;
internal sealed class MemoryWorkerEnvironment : IWorkerEnvironment
{
private readonly Dictionary<string, string> _values = new();
private readonly Exception? _exception;
public MemoryWorkerEnvironment()
{
}
public MemoryWorkerEnvironment(Exception exception)
{
_exception = exception;
}
public void Set(string name, string value)
{
_values[name] = value;
}
public string? GetEnvironmentVariable(string name)
{
if (_exception is not null)
{
throw _exception;
}
return _values.TryGetValue(name, out string value)
? value
: null;
}
}
@@ -0,0 +1,22 @@
using System.Collections.Generic;
namespace MxGateway.Worker.Tests.Bootstrap;
internal sealed class MemoryWorkerLogEntry
{
public MemoryWorkerLogEntry(
string level,
string eventName,
IReadOnlyDictionary<string, object?> fields)
{
Level = level;
EventName = eventName;
Fields = fields;
}
public string Level { get; }
public string EventName { get; }
public IReadOnlyDictionary<string, object?> Fields { get; }
}
@@ -0,0 +1,19 @@
using System.Collections.Generic;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Tests.Bootstrap;
internal sealed class MemoryWorkerLogger : IWorkerLogger
{
public List<MemoryWorkerLogEntry> Entries { get; } = new();
public void Information(string eventName, IReadOnlyDictionary<string, object?> fields)
{
Entries.Add(new MemoryWorkerLogEntry("Information", eventName, WorkerLogRedactor.RedactFields(fields)));
}
public void Error(string eventName, IReadOnlyDictionary<string, object?> fields)
{
Entries.Add(new MemoryWorkerLogEntry("Error", eventName, WorkerLogRedactor.RedactFields(fields)));
}
}
@@ -0,0 +1,164 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Bootstrap;
public sealed class WorkerApplicationTests
{
[Fact]
public void Run_WithValidBootstrapArguments_ReturnsSuccessAndLogsRedactedNonce()
{
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger,
new SucceedingPipeClient());
Assert.Equal((int)WorkerExitCode.Success, exitCode);
Assert.Equal(2, logger.Entries.Count);
MemoryWorkerLogEntry entry = logger.Entries[0];
Assert.Equal("Information", entry.Level);
Assert.Equal("WorkerBootstrapSucceeded", entry.EventName);
Assert.Equal("session-1", entry.Fields["session_id"]);
Assert.Equal("mxaccess-gateway-123-session-1", entry.Fields["pipe_name"]);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, entry.Fields["protocol_version"]);
Assert.Equal("[redacted]", entry.Fields["nonce"]);
Assert.Equal("WorkerPipeHandshakeSucceeded", logger.Entries[1].EventName);
}
[Fact]
public void Run_WithMissingRequiredArguments_ReturnsInvalidArguments()
{
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
[],
environment,
logger);
Assert.Equal((int)WorkerExitCode.InvalidArguments, exitCode);
MemoryWorkerLogEntry entry = Assert.Single(logger.Entries);
Assert.Equal("Error", entry.Level);
Assert.Equal("WorkerBootstrapFailed", entry.EventName);
Assert.Equal(WorkerExitCode.InvalidArguments, entry.Fields["exit_code"]);
}
[Fact]
public void Run_WithInvalidProtocolVersion_ReturnsInvalidProtocolVersion()
{
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(protocolVersion: "999"),
environment,
logger);
Assert.Equal((int)WorkerExitCode.InvalidProtocolVersion, exitCode);
}
[Fact]
public void Run_WithMissingNonce_ReturnsMissingNonce()
{
MemoryWorkerEnvironment environment = new();
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger);
Assert.Equal((int)WorkerExitCode.MissingNonce, exitCode);
}
[Fact]
public void Run_WithPipeProtocolFailure_ReturnsProtocolViolation()
{
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger,
new ThrowingPipeClient(new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.NonceMismatch,
"Bad nonce.")));
Assert.Equal((int)WorkerExitCode.ProtocolViolation, exitCode);
Assert.Equal("WorkerPipeProtocolFailure", logger.Entries[1].EventName);
}
[Fact]
public void Run_WithUnexpectedBootstrapFailure_ReturnsUnexpectedFailure()
{
MemoryWorkerEnvironment environment = new(new InvalidOperationException("environment failed"));
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger);
Assert.Equal((int)WorkerExitCode.UnexpectedFailure, exitCode);
MemoryWorkerLogEntry entry = Assert.Single(logger.Entries);
Assert.Equal("WorkerBootstrapUnexpectedFailure", entry.EventName);
Assert.Equal(WorkerExitCode.UnexpectedFailure, entry.Fields["exit_code"]);
Assert.Equal(typeof(InvalidOperationException).FullName, entry.Fields["exception_type"]);
}
private static string[] ValidArgs(string? protocolVersion = null)
{
return
[
"--session-id",
"session-1",
"--pipe-name",
"mxaccess-gateway-123-session-1",
"--protocol-version",
protocolVersion ?? GatewayContractInfo.WorkerProtocolVersion.ToString(),
];
}
private static MemoryWorkerEnvironment CreateEnvironment(string nonce)
{
MemoryWorkerEnvironment environment = new();
environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce);
return environment;
}
private sealed class SucceedingPipeClient : IWorkerPipeClient
{
public Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
}
private sealed class ThrowingPipeClient : IWorkerPipeClient
{
private readonly Exception _exception;
public ThrowingPipeClient(Exception exception)
{
_exception = exception;
}
public Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
throw _exception;
}
}
}
@@ -0,0 +1,28 @@
using System.Collections.Generic;
using System.IO;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Tests.Bootstrap;
public sealed class WorkerConsoleLoggerTests
{
[Fact]
public void Information_RedactsNonceInStructuredOutput()
{
StringWriter writer = new();
WorkerConsoleLogger logger = new(writer);
logger.Information("WorkerBootstrapSucceeded", new Dictionary<string, object?>
{
["session_id"] = "session-1",
["nonce"] = "nonce-secret",
});
string output = writer.ToString();
Assert.Contains("event=WorkerBootstrapSucceeded", output);
Assert.Contains("session_id=session-1", output);
Assert.Contains("nonce=[redacted]", output);
Assert.DoesNotContain("nonce-secret", output);
}
}
@@ -0,0 +1,32 @@
using System.Collections.Generic;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Tests.Bootstrap;
public sealed class WorkerLogRedactorTests
{
[Fact]
public void RedactFields_RedactsNonceSecretPasswordTokenCredentialAndApiKeyFields()
{
Dictionary<string, object?> fields = new()
{
["nonce"] = "nonce-secret",
["client_secret"] = "secret",
["password"] = "password",
["auth_token"] = "token",
["credential_value"] = "credential",
["api_key"] = "key",
["session_id"] = "session-1",
};
Dictionary<string, object?> redacted = WorkerLogRedactor.RedactFields(fields);
Assert.Equal("[redacted]", redacted["nonce"]);
Assert.Equal("[redacted]", redacted["client_secret"]);
Assert.Equal("[redacted]", redacted["password"]);
Assert.Equal("[redacted]", redacted["auth_token"]);
Assert.Equal("[redacted]", redacted["credential_value"]);
Assert.Equal("[redacted]", redacted["api_key"]);
Assert.Equal("session-1", redacted["session_id"]);
}
}
@@ -0,0 +1,115 @@
using MxGateway.Contracts;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Tests.Bootstrap;
public sealed class WorkerOptionsParserTests
{
[Fact]
public void Parse_WithAllRequiredInputs_ReturnsWorkerOptions()
{
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
WorkerBootstrapResult result = parser.Parse(ValidArgs());
Assert.True(result.Succeeded);
Assert.Equal(WorkerExitCode.Success, result.ExitCode);
Assert.NotNull(result.Options);
Assert.Equal("session-1", result.Options.SessionId);
Assert.Equal("mxaccess-gateway-123-session-1", result.Options.PipeName);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, result.Options.ProtocolVersion);
Assert.Equal("nonce-secret", result.Options.Nonce);
}
[Fact]
public void Parse_WithMissingSessionId_ReturnsInvalidArguments()
{
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
WorkerBootstrapResult result = parser.Parse(
[
"--pipe-name",
"mxaccess-gateway-123-session-1",
"--protocol-version",
GatewayContractInfo.WorkerProtocolVersion.ToString(),
]);
Assert.False(result.Succeeded);
Assert.Equal(WorkerExitCode.InvalidArguments, result.ExitCode);
Assert.Contains(result.Errors, error => error.Contains("--session-id"));
}
[Fact]
public void Parse_WithUnknownOption_ReturnsInvalidArguments()
{
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
WorkerBootstrapResult result = parser.Parse(
[
"--session-id",
"session-1",
"--pipe-name",
"mxaccess-gateway-123-session-1",
"--protocol-version",
GatewayContractInfo.WorkerProtocolVersion.ToString(),
"--unexpected",
"value",
]);
Assert.Equal(WorkerExitCode.InvalidArguments, result.ExitCode);
Assert.Contains(result.Errors, error => error.Contains("Unknown option"));
}
[Fact]
public void Parse_WithNonNumericProtocolVersion_ReturnsInvalidProtocolVersion()
{
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
WorkerBootstrapResult result = parser.Parse(ValidArgs(protocolVersion: "abc"));
Assert.False(result.Succeeded);
Assert.Equal(WorkerExitCode.InvalidProtocolVersion, result.ExitCode);
}
[Fact]
public void Parse_WithUnsupportedProtocolVersion_ReturnsInvalidProtocolVersion()
{
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
WorkerBootstrapResult result = parser.Parse(ValidArgs(protocolVersion: "999"));
Assert.False(result.Succeeded);
Assert.Equal(WorkerExitCode.InvalidProtocolVersion, result.ExitCode);
}
[Fact]
public void Parse_WithMissingNonce_ReturnsMissingNonce()
{
WorkerOptionsParser parser = new(new MemoryWorkerEnvironment());
WorkerBootstrapResult result = parser.Parse(ValidArgs());
Assert.False(result.Succeeded);
Assert.Equal(WorkerExitCode.MissingNonce, result.ExitCode);
}
private static string[] ValidArgs(string? protocolVersion = null)
{
return
[
"--session-id",
"session-1",
"--pipe-name",
"mxaccess-gateway-123-session-1",
"--protocol-version",
protocolVersion ?? GatewayContractInfo.WorkerProtocolVersion.ToString(),
];
}
private static MemoryWorkerEnvironment CreateEnvironment(string nonce)
{
MemoryWorkerEnvironment environment = new();
environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce);
return environment;
}
}
@@ -0,0 +1,19 @@
using MxGateway.Contracts;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Contracts;
public sealed class WorkerContractInfoTests
{
[Fact]
public void SupportedProtocolVersion_UsesSharedGatewayContractVersion()
{
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, WorkerContractInfo.SupportedProtocolVersion);
}
[Fact]
public void WorkerEnvelopeDescriptorName_UsesGeneratedWorkerContract()
{
Assert.Equal("mxaccess_worker.v1.WorkerEnvelope", WorkerContractInfo.WorkerEnvelopeDescriptorName);
}
}
@@ -0,0 +1,183 @@
using System;
using Google.Protobuf;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Conversion;
using ProtobufTimestamp = Google.Protobuf.WellKnownTypes.Timestamp;
namespace MxGateway.Worker.Tests.Conversion;
public sealed class VariantConverterTests
{
private readonly VariantConverter _converter = new();
[Theory]
[InlineData(true, MxDataType.Boolean, MxValue.KindOneofCase.BoolValue)]
[InlineData(42, MxDataType.Integer, MxValue.KindOneofCase.Int32Value)]
[InlineData(42L, MxDataType.Integer, MxValue.KindOneofCase.Int64Value)]
[InlineData(1.25f, MxDataType.Float, MxValue.KindOneofCase.FloatValue)]
[InlineData(2.5d, MxDataType.Double, MxValue.KindOneofCase.DoubleValue)]
[InlineData("value", MxDataType.String, MxValue.KindOneofCase.StringValue)]
public void Convert_WithSupportedScalar_ProjectsTypedValue(
object value,
MxDataType expectedDataType,
MxValue.KindOneofCase expectedKind)
{
MxValue converted = _converter.Convert(value);
Assert.Equal(expectedDataType, converted.DataType);
Assert.Equal(expectedKind, converted.KindCase);
Assert.False(string.IsNullOrWhiteSpace(converted.VariantType));
}
[Fact]
public void Convert_WithDateTime_ProjectsTimestamp()
{
DateTime dateTime = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
MxValue converted = _converter.Convert(dateTime);
Assert.Equal(MxDataType.Time, converted.DataType);
Assert.Equal(ProtobufTimestamp.FromDateTime(dateTime), converted.TimestampValue);
Assert.Equal("VT_DATE", converted.VariantType);
}
[Fact]
public void Convert_WithFileTimeAndExpectedTime_ProjectsTimestamp()
{
DateTime dateTime = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
MxValue converted = _converter.Convert(dateTime.ToFileTimeUtc(), MxDataType.Time);
Assert.Equal(MxDataType.Time, converted.DataType);
Assert.Equal(ProtobufTimestamp.FromDateTime(dateTime), converted.TimestampValue);
Assert.Equal("VT_I8", converted.VariantType);
}
[Theory]
[InlineData(null, "VT_EMPTY")]
[InlineData(typeof(DBNull), "VT_NULL")]
public void Convert_WithNullLikeValue_PreservesNull(
object? value,
string expectedVariantType)
{
object? actualValue = value is System.Type ? DBNull.Value : value;
MxValue converted = _converter.Convert(actualValue);
Assert.True(converted.IsNull);
Assert.Equal(MxDataType.NoData, converted.DataType);
Assert.Equal(expectedVariantType, converted.VariantType);
Assert.Equal(MxValue.KindOneofCase.None, converted.KindCase);
}
[Fact]
public void ConvertArray_WithSupportedArrays_ProjectsTypedValuesAndDimensions()
{
MxValue bools = _converter.Convert(new[] { true, false });
MxValue ints = _converter.Convert(new[] { 1, 2, 3 });
MxValue floats = _converter.Convert(new[] { 1.25f, 2.5f });
MxValue doubles = _converter.Convert(new[] { 1.25d, 2.5d });
MxValue strings = _converter.Convert(new[] { "one", "two" });
MxValue times = _converter.Convert(new[]
{
new DateTime(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc),
new DateTime(2026, 4, 26, 17, 46, 0, DateTimeKind.Utc),
});
Assert.Equal(new[] { true, false }, bools.ArrayValue.BoolValues.Values);
Assert.Equal(new[] { 1, 2, 3 }, ints.ArrayValue.Int32Values.Values);
Assert.Equal(new[] { 1.25f, 2.5f }, floats.ArrayValue.FloatValues.Values);
Assert.Equal(new[] { 1.25d, 2.5d }, doubles.ArrayValue.DoubleValues.Values);
Assert.Equal(new[] { "one", "two" }, strings.ArrayValue.StringValues.Values);
Assert.Equal(2, times.ArrayValue.TimestampValues.Values.Count);
Assert.Equal(new uint[] { 2 }, bools.ArrayValue.Dimensions);
Assert.Equal(MxDataType.Boolean, bools.ArrayValue.ElementDataType);
}
[Fact]
public void ConvertArray_WithMultidimensionalArray_PreservesRankAndDimensions()
{
int[,] values =
{
{ 1, 2, 3 },
{ 4, 5, 6 },
};
MxValue converted = _converter.Convert(values);
Assert.Equal(new uint[] { 2, 3 }, converted.ArrayValue.Dimensions);
Assert.Equal(new[] { 1, 2, 3, 4, 5, 6 }, converted.ArrayValue.Int32Values.Values);
}
[Fact]
public void ConvertArray_WithExpectedTimeAndFileTimeValues_ProjectsTimestampArray()
{
DateTime first = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
DateTime second = new(2026, 4, 26, 17, 46, 0, DateTimeKind.Utc);
MxValue converted = _converter.Convert(
new[] { first.ToFileTimeUtc(), second.ToFileTimeUtc() },
MxDataType.Time);
Assert.Equal(MxDataType.Time, converted.ArrayValue.ElementDataType);
Assert.Equal(
new[] { ProtobufTimestamp.FromDateTime(first), ProtobufTimestamp.FromDateTime(second) },
converted.ArrayValue.TimestampValues.Values);
}
[Fact]
public void Convert_WithUnknownScalar_PreservesRawMetadata()
{
UnsupportedVariant value = new("opaque");
MxValue converted = _converter.Convert(value);
Assert.Equal(MxDataType.Unknown, converted.DataType);
Assert.Equal(MxValue.KindOneofCase.RawValue, converted.KindCase);
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.VariantType);
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.RawDiagnostic);
Assert.Equal(ByteString.CopyFromUtf8("opaque"), converted.RawValue);
}
[Fact]
public void ConvertArray_WithUnknownArray_PreservesRawMetadata()
{
UnsupportedVariant[] values =
[
new("first"),
new("second"),
];
MxValue converted = _converter.Convert(values);
Assert.Equal(MxDataType.Unknown, converted.ArrayValue.ElementDataType);
Assert.Equal(MxArray.ValuesOneofCase.RawValues, converted.ArrayValue.ValuesCase);
Assert.Equal(new uint[] { 2 }, converted.ArrayValue.Dimensions);
Assert.Equal("first", converted.ArrayValue.RawValues.Values[0].ToStringUtf8());
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.ArrayValue.RawDiagnostic);
}
[Fact]
public void Redactor_WithCredentialBearingValueFields_RedactsBeforeLogging()
{
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("credential_value", "secret"));
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("password_value", "secret"));
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("secured_write_token", "secret"));
}
private sealed class UnsupportedVariant
{
private readonly string _value;
public UnsupportedVariant(string value)
{
_value = value;
}
public override string ToString()
{
return _value;
}
}
}
@@ -0,0 +1,163 @@
using System;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerFrameProtocolTests
{
private const string SessionId = "session-1";
private const string Nonce = "nonce-secret";
[Fact]
public async Task WriteAndReadAsync_WithValidEnvelope_RoundTripsFrame()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new();
WorkerEnvelope original = CreateGatewayHelloEnvelope();
WorkerFrameWriter writer = new(stream, options);
await writer.WriteAsync(original);
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
WorkerEnvelope parsed = await reader.ReadAsync();
Assert.Equal(original, parsed);
}
[Fact]
public async Task ReadAsync_WithWrongProtocolVersion_ThrowsProtocolVersionMismatch()
{
WorkerFrameProtocolOptions options = CreateOptions();
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
envelope.ProtocolVersion++;
MemoryStream stream = new(CreateFrame(envelope));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithWrongSessionId_ThrowsSessionMismatch()
{
WorkerFrameProtocolOptions options = CreateOptions();
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
envelope.SessionId = "different-session";
MemoryStream stream = new(CreateFrame(envelope));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.SessionMismatch, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithMalformedLength_ThrowsMalformedLength()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new(new byte[sizeof(uint)]);
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.MalformedLength, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithMalformedPayload_ThrowsInvalidEnvelope()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new(CreateFrame(new byte[] { 0x80 }));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
}
[Fact]
public async Task WriteAsync_WithConcurrentCalls_SerializesCompleteFrames()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new();
WorkerFrameWriter writer = new(stream, options);
await Task.WhenAll(
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 1)),
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 2)),
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 3)));
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
WorkerEnvelope first = await reader.ReadAsync();
WorkerEnvelope second = await reader.ReadAsync();
WorkerEnvelope third = await reader.ReadAsync();
Assert.Equal(new ulong[] { 1, 2, 3 }, new[] { first.Sequence, second.Sequence, third.Sequence }.OrderBy(sequence => sequence));
}
private static WorkerFrameProtocolOptions CreateOptions()
{
return new WorkerFrameProtocolOptions(
SessionId,
GatewayContractInfo.WorkerProtocolVersion,
Nonce);
}
private static WorkerEnvelope CreateGatewayHelloEnvelope(ulong sequence = 1)
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
GatewayVersion = "test-gateway",
},
};
}
private static byte[] CreateFrame(IMessage message)
{
return CreateFrame(message.ToByteArray());
}
private static byte[] CreateFrame(byte[] payload)
{
byte[] frame = new byte[sizeof(uint) + payload.Length];
WriteUInt32LittleEndian(frame, (uint)payload.Length);
payload.CopyTo(frame, sizeof(uint));
return frame;
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,61 @@
using System;
using System.IO.Pipes;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerPipeClientTests
{
[Fact]
public async Task RunAsync_ConnectsToPipeAndCompletesHandshake()
{
string pipeName = $"mxaccess-gateway-test-{Guid.NewGuid():N}";
WorkerOptions workerOptions = new(
"session-1",
pipeName,
GatewayContractInfo.WorkerProtocolVersion,
"nonce-secret");
WorkerFrameProtocolOptions frameOptions = new(workerOptions);
using NamedPipeServerStream server = new(
pipeName,
PipeDirection.InOut,
1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
WorkerPipeClient client = new(connectTimeoutMilliseconds: 5000);
Task clientTask = client.RunAsync(workerOptions);
await Task.Factory.FromAsync(server.BeginWaitForConnection, server.EndWaitForConnection, null);
WorkerFrameReader reader = new(server, frameOptions);
WorkerFrameWriter writer = new(server, frameOptions);
await writer.WriteAsync(new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = "session-1",
Sequence = 1,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = "nonce-secret",
GatewayVersion = "test-gateway",
},
});
WorkerEnvelope hello = await reader.ReadAsync();
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, hello.BodyCase);
Assert.Equal("nonce-secret", hello.WorkerHello.Nonce);
WorkerEnvelope ready = await reader.ReadAsync();
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, ready.BodyCase);
await clientTask;
}
}
@@ -0,0 +1,192 @@
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerPipeSessionTests
{
private const string SessionId = "session-1";
private const string Nonce = "nonce-secret";
[Fact]
public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope());
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
});
Assert.True(initialized);
WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
Assert.Equal(2, written.Length);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
Assert.Equal(Nonce, written[0].WorkerHello.Nonce);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(nonce: "wrong"));
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(supportedProtocolVersion: 999));
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 }));
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
}
private static WorkerPipeSession CreateSession(
Stream inbound,
Stream outbound,
WorkerFrameProtocolOptions options)
{
return new WorkerPipeSession(
new WorkerFrameReader(inbound, options),
new WorkerFrameWriter(outbound, options),
options,
() => 1234);
}
private static WorkerFrameProtocolOptions CreateOptions()
{
return new WorkerFrameProtocolOptions(
SessionId,
GatewayContractInfo.WorkerProtocolVersion,
Nonce);
}
private static WorkerEnvelope CreateGatewayHelloEnvelope(
string nonce = Nonce,
uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion)
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = 1,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = supportedProtocolVersion,
Nonce = nonce,
GatewayVersion = "test-gateway",
},
};
}
private static WorkerEnvelope[] ReadWrittenFrames(
MemoryStream stream,
WorkerFrameProtocolOptions options)
{
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
List<WorkerEnvelope> envelopes = new();
while (stream.Position < stream.Length)
{
envelopes.Add(reader.ReadAsync(CancellationToken.None).GetAwaiter().GetResult());
}
return envelopes.ToArray();
}
private static byte[] CreateFrame(byte[] payload)
{
byte[] frame = new byte[sizeof(uint) + payload.Length];
WriteUInt32LittleEndian(frame, (uint)payload.Length);
payload.CopyTo(frame, sizeof(uint));
return frame;
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,23 @@
using MxGateway.Worker.MxAccess;
namespace MxGateway.Worker.Tests.MxAccess;
public sealed class MxAccessInteropInfoTests
{
[Fact]
public void InteropInfo_IdentifiesInstalledMxAccessComTarget()
{
Assert.Equal("LMXProxy.LMXProxyServer.1", MxAccessInteropInfo.ProgId);
Assert.Equal("LMXProxy.LMXProxyServer", MxAccessInteropInfo.VersionIndependentProgId);
Assert.Equal("{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}", MxAccessInteropInfo.Clsid);
Assert.Equal("ArchestrA.MxAccess.LMXProxyServerClass", MxAccessInteropInfo.ComClassName);
}
[Fact]
public void InteropAssemblyName_ComesFromReferencedMxAccessAssembly()
{
Assert.Equal("ArchestrA.MxAccess", MxAccessInteropInfo.InteropAssemblyName);
Assert.Equal(3, MxAccessInteropInfo.InteropAssemblyVersion.Major);
Assert.Equal(2, MxAccessInteropInfo.InteropAssemblyVersion.Minor);
}
}
@@ -0,0 +1,28 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net48</TargetFramework>
<IsPackable>false</IsPackable>
<PlatformTarget>x86</PlatformTarget>
<Prefer32Bit>true</Prefer32Bit>
<ImplicitUsings>disable</ImplicitUsings>
<AutoGenerateBindingRedirects>true</AutoGenerateBindingRedirects>
<GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.1" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="6.1.2" />
<PackageReference Include="xunit" Version="2.9.3" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.1.4" />
</ItemGroup>
<ItemGroup>
<Using Include="Xunit" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\MxGateway.Worker\MxGateway.Worker.csproj" />
</ItemGroup>
</Project>
@@ -0,0 +1,94 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Xml.Linq;
namespace MxGateway.Worker.Tests.ProjectStructure;
public sealed class WorkerProjectReferenceTests
{
[Fact]
public void WorkerProject_TargetsNet48AndX86()
{
XDocument project = LoadProject("MxGateway.Worker");
Assert.Equal("net48", ElementValue(project, "TargetFramework"));
Assert.Equal("x86", ElementValue(project, "PlatformTarget"));
Assert.Equal("true", ElementValue(project, "Prefer32Bit"));
}
[Fact]
public void WorkerTestProject_TargetsNet48AndX86()
{
XDocument project = LoadProject("MxGateway.Worker.Tests");
Assert.Equal("net48", ElementValue(project, "TargetFramework"));
Assert.Equal("x86", ElementValue(project, "PlatformTarget"));
}
[Fact]
public void MxAccessInteropReference_ExistsOnlyInWorkerProject()
{
DirectoryInfo repositoryRoot = FindRepositoryRoot();
string[] projectFiles = Directory.GetFiles(repositoryRoot.FullName, "*.csproj", SearchOption.AllDirectories)
.Where(path => path.IndexOf($"{Path.DirectorySeparatorChar}bin{Path.DirectorySeparatorChar}", StringComparison.OrdinalIgnoreCase) < 0)
.Where(path => path.IndexOf($"{Path.DirectorySeparatorChar}obj{Path.DirectorySeparatorChar}", StringComparison.OrdinalIgnoreCase) < 0)
.ToArray();
IReadOnlyList<string> projectsWithMxAccessReference = projectFiles
.Where(ProjectReferencesMxAccess)
.Select(path => Path.GetFileNameWithoutExtension(path))
.ToArray();
Assert.Equal(["MxGateway.Worker"], projectsWithMxAccessReference);
}
private static bool ProjectReferencesMxAccess(string projectPath)
{
XDocument project = XDocument.Load(projectPath);
return project
.Descendants()
.Where(element => element.Name.LocalName is "Reference" or "COMReference" or "COMFileReference" or "PackageReference")
.Select(element => (string?)element.Attribute("Include") ?? string.Empty)
.Concat(project.Descendants().Where(element => element.Name.LocalName == "HintPath").Select(element => element.Value))
.Any(reference =>
reference.IndexOf("MxAccess", StringComparison.OrdinalIgnoreCase) >= 0
|| reference.IndexOf("ArchestrA.MXAccess", StringComparison.OrdinalIgnoreCase) >= 0
|| reference.IndexOf("LMXProxy", StringComparison.OrdinalIgnoreCase) >= 0);
}
private static XDocument LoadProject(string projectName)
{
DirectoryInfo repositoryRoot = FindRepositoryRoot();
string projectPath = Path.Combine(repositoryRoot.FullName, projectName, $"{projectName}.csproj");
return XDocument.Load(projectPath);
}
private static string ElementValue(XDocument project, string elementName)
{
return project
.Descendants()
.Single(element => element.Name.LocalName == elementName)
.Value;
}
private static DirectoryInfo FindRepositoryRoot()
{
DirectoryInfo? current = new(AppContext.BaseDirectory);
while (current is not null)
{
if (File.Exists(Path.Combine(current.FullName, "MxGateway.sln")))
{
return current;
}
current = current.Parent;
}
throw new DirectoryNotFoundException("Could not locate src/MxGateway.sln from the test output directory.");
}
}
@@ -0,0 +1,152 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaRuntimeTests
{
[Fact]
public async Task InvokeAsync_ExecutesCommandOnStaThread()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
StaCommandObservation observation = await runtime.InvokeAsync(
() => new StaCommandObservation(
Thread.CurrentThread.ManagedThreadId,
Thread.CurrentThread.GetApartmentState()));
Assert.Equal(runtime.StaThreadId, observation.ThreadId);
Assert.Equal(initializer.InitializeThreadId, observation.ThreadId);
Assert.Equal(ApartmentState.STA, observation.ApartmentState);
}
[Fact]
public async Task InvokeAsync_WakesIdlePumpForQueuedCommand()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = new(
initializer,
new StaMessagePump(),
TimeSpan.FromSeconds(30));
runtime.Start();
Stopwatch stopwatch = Stopwatch.StartNew();
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
stopwatch.Stop();
Assert.Equal(runtime.StaThreadId, threadId);
Assert.True(
stopwatch.Elapsed < TimeSpan.FromSeconds(2),
$"Command took {stopwatch.Elapsed} to execute, so the command wake event did not wake the STA promptly.");
}
[Fact]
public void Shutdown_StopsThreadAndUninitializesComApartment()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
bool stopped = runtime.Shutdown(TimeSpan.FromSeconds(2));
Assert.True(stopped);
Assert.False(runtime.IsRunning);
Assert.Equal(1, initializer.InitializeCount);
Assert.Equal(1, initializer.UninitializeCount);
Assert.Equal(initializer.InitializeThreadId, initializer.UninitializeThreadId);
}
[Fact]
public void LastActivityUtc_UpdatesWhilePumpIsIdle()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
DateTimeOffset firstActivity = runtime.LastActivityUtc;
bool updated = SpinWait.SpinUntil(
() => runtime.LastActivityUtc > firstActivity,
TimeSpan.FromSeconds(2));
Assert.True(updated);
}
[Fact]
public async Task InvokeAsync_CommandException_FaultsReturnedTaskWithoutStoppingRuntime()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync<int>(() => throw new InvalidOperationException("command failed")));
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
Assert.Equal("command failed", exception.Message);
Assert.Equal(runtime.StaThreadId, threadId);
}
[Fact]
public async Task InvokeAsync_AfterShutdown_ReturnsFaultedTask()
{
RecordingComApartmentInitializer initializer = new();
using StaRuntime runtime = CreateRuntime(initializer);
runtime.Start();
runtime.Shutdown(TimeSpan.FromSeconds(2));
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
() => runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId));
Assert.Contains("shutting down", exception.Message);
}
private static StaRuntime CreateRuntime(RecordingComApartmentInitializer initializer)
{
return new StaRuntime(
initializer,
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class StaCommandObservation
{
public StaCommandObservation(int threadId, ApartmentState apartmentState)
{
ThreadId = threadId;
ApartmentState = apartmentState;
}
public int ThreadId { get; }
public ApartmentState ApartmentState { get; }
}
private sealed class RecordingComApartmentInitializer : IStaComApartmentInitializer
{
public int InitializeCount { get; private set; }
public int UninitializeCount { get; private set; }
public int? InitializeThreadId { get; private set; }
public int? UninitializeThreadId { get; private set; }
public void Initialize()
{
InitializeCount++;
InitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
public void Uninitialize()
{
UninitializeCount++;
UninitializeThreadId = Thread.CurrentThread.ManagedThreadId;
}
}
}
@@ -0,0 +1,11 @@
using System;
namespace MxGateway.Worker.Bootstrap;
public sealed class EnvironmentVariableWorkerEnvironment : IWorkerEnvironment
{
public string? GetEnvironmentVariable(string name)
{
return Environment.GetEnvironmentVariable(name);
}
}
@@ -0,0 +1,6 @@
namespace MxGateway.Worker.Bootstrap;
public interface IWorkerEnvironment
{
string? GetEnvironmentVariable(string name);
}
@@ -0,0 +1,10 @@
using System.Collections.Generic;
namespace MxGateway.Worker.Bootstrap;
public interface IWorkerLogger
{
void Information(string eventName, IReadOnlyDictionary<string, object?> fields);
void Error(string eventName, IReadOnlyDictionary<string, object?> fields);
}
@@ -0,0 +1,35 @@
using System.Collections.Generic;
using System.Linq;
namespace MxGateway.Worker.Bootstrap;
public sealed class WorkerBootstrapResult
{
private WorkerBootstrapResult(
WorkerExitCode exitCode,
WorkerOptions? options,
IReadOnlyList<string> errors)
{
ExitCode = exitCode;
Options = options;
Errors = errors;
}
public WorkerExitCode ExitCode { get; }
public WorkerOptions? Options { get; }
public IReadOnlyList<string> Errors { get; }
public bool Succeeded => ExitCode == WorkerExitCode.Success;
public static WorkerBootstrapResult Success(WorkerOptions options)
{
return new WorkerBootstrapResult(WorkerExitCode.Success, options, []);
}
public static WorkerBootstrapResult Failure(WorkerExitCode exitCode, IEnumerable<string> errors)
{
return new WorkerBootstrapResult(exitCode, null, errors.ToArray());
}
}
@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace MxGateway.Worker.Bootstrap;
public sealed class WorkerConsoleLogger : IWorkerLogger
{
private readonly TextWriter _writer;
public WorkerConsoleLogger(TextWriter writer)
{
_writer = writer ?? throw new ArgumentNullException(nameof(writer));
}
public void Information(string eventName, IReadOnlyDictionary<string, object?> fields)
{
Write("Information", eventName, fields);
}
public void Error(string eventName, IReadOnlyDictionary<string, object?> fields)
{
Write("Error", eventName, fields);
}
private void Write(
string level,
string eventName,
IReadOnlyDictionary<string, object?> fields)
{
Dictionary<string, object?> redactedFields = WorkerLogRedactor.RedactFields(fields);
string fieldText = string.Join(
" ",
redactedFields.Select(field => $"{field.Key}={FormatValue(field.Value)}"));
_writer.WriteLine($"level={level} event={eventName} {fieldText}".TrimEnd());
}
private static string FormatValue(object? value)
{
return value?.ToString() ?? string.Empty;
}
}
@@ -0,0 +1,12 @@
namespace MxGateway.Worker.Bootstrap;
public enum WorkerExitCode
{
Success = 0,
UnexpectedFailure = 1,
InvalidArguments = 2,
InvalidProtocolVersion = 3,
MissingNonce = 4,
PipeConnectionFailed = 5,
ProtocolViolation = 6,
}
@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
namespace MxGateway.Worker.Bootstrap;
public static class WorkerLogRedactor
{
public const string RedactedValue = "[redacted]";
private static readonly string[] SensitiveFieldNameParts =
[
"nonce",
"secret",
"password",
"token",
"credential",
"apikey",
"api_key",
];
public static Dictionary<string, object?> RedactFields(IReadOnlyDictionary<string, object?> fields)
{
Dictionary<string, object?> redactedFields = [];
foreach (KeyValuePair<string, object?> field in fields)
{
redactedFields[field.Key] = RedactValue(field.Key, field.Value);
}
return redactedFields;
}
public static object? RedactValue(string fieldName, object? value)
{
if (value is null)
{
return null;
}
foreach (string sensitiveFieldNamePart in SensitiveFieldNameParts)
{
if (fieldName.IndexOf(sensitiveFieldNamePart, StringComparison.OrdinalIgnoreCase) >= 0)
{
return RedactedValue;
}
}
return value;
}
}
@@ -0,0 +1,26 @@
namespace MxGateway.Worker.Bootstrap;
public sealed class WorkerOptions
{
public const string NonceEnvironmentVariableName = "MXGATEWAY_WORKER_NONCE";
public WorkerOptions(
string sessionId,
string pipeName,
uint protocolVersion,
string nonce)
{
SessionId = sessionId;
PipeName = pipeName;
ProtocolVersion = protocolVersion;
Nonce = nonce;
}
public string SessionId { get; }
public string PipeName { get; }
public uint ProtocolVersion { get; }
public string Nonce { get; }
}
@@ -0,0 +1,101 @@
using System;
using System.Collections.Generic;
using MxGateway.Contracts;
namespace MxGateway.Worker.Bootstrap;
public sealed class WorkerOptionsParser
{
private const string SessionIdOptionName = "--session-id";
private const string PipeNameOptionName = "--pipe-name";
private const string ProtocolVersionOptionName = "--protocol-version";
private readonly IWorkerEnvironment _environment;
public WorkerOptionsParser(IWorkerEnvironment environment)
{
_environment = environment ?? throw new ArgumentNullException(nameof(environment));
}
public WorkerBootstrapResult Parse(string[] args)
{
if (args is null)
{
throw new ArgumentNullException(nameof(args));
}
Dictionary<string, string> values = new(StringComparer.OrdinalIgnoreCase);
List<string> errors = [];
for (int index = 0; index < args.Length; index++)
{
string arg = args[index];
if (!IsKnownOption(arg))
{
errors.Add($"Unknown option '{arg}'.");
continue;
}
if (index + 1 >= args.Length || args[index + 1].StartsWith("--", StringComparison.Ordinal))
{
errors.Add($"Option '{arg}' requires a value.");
continue;
}
values[arg] = args[index + 1];
index++;
}
string? sessionId = ReadRequired(values, SessionIdOptionName, errors);
string? pipeName = ReadRequired(values, PipeNameOptionName, errors);
string? protocolVersionText = ReadRequired(values, ProtocolVersionOptionName, errors);
if (errors.Count > 0)
{
return WorkerBootstrapResult.Failure(WorkerExitCode.InvalidArguments, errors);
}
if (!uint.TryParse(protocolVersionText, out uint protocolVersion)
|| protocolVersion != GatewayContractInfo.WorkerProtocolVersion)
{
return WorkerBootstrapResult.Failure(
WorkerExitCode.InvalidProtocolVersion,
[$"Unsupported protocol version '{protocolVersionText}'."]);
}
string? nonce = _environment.GetEnvironmentVariable(WorkerOptions.NonceEnvironmentVariableName);
if (string.IsNullOrWhiteSpace(nonce))
{
return WorkerBootstrapResult.Failure(
WorkerExitCode.MissingNonce,
["Required worker nonce environment variable is missing."]);
}
return WorkerBootstrapResult.Success(new WorkerOptions(
sessionId!,
pipeName!,
protocolVersion,
nonce!));
}
private static string? ReadRequired(
IReadOnlyDictionary<string, string> values,
string optionName,
List<string> errors)
{
if (!values.TryGetValue(optionName, out string value)
|| string.IsNullOrWhiteSpace(value))
{
errors.Add($"Required option '{optionName}' is missing.");
return null;
}
return value;
}
private static bool IsKnownOption(string optionName)
{
return optionName is SessionIdOptionName or PipeNameOptionName or ProtocolVersionOptionName;
}
}
+1
View File
@@ -0,0 +1 @@
@@ -0,0 +1,522 @@
using System;
using System.Globalization;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Conversion;
public sealed class VariantConverter
{
public MxValue Convert(object? value)
{
return Convert(value, MxDataType.Unspecified);
}
public MxValue Convert(
object? value,
MxDataType expectedDataType)
{
if (value is null || value is DBNull)
{
return CreateNullValue(value, expectedDataType);
}
if (value is Array array)
{
return new MxValue
{
DataType = MxDataType.Unspecified,
VariantType = CreateArrayVariantType(array),
ArrayValue = ConvertArray(array, expectedDataType),
};
}
return ConvertScalar(value, expectedDataType);
}
public MxArray ConvertArray(
Array array,
MxDataType expectedElementDataType = MxDataType.Unspecified)
{
if (array is null)
{
throw new ArgumentNullException(nameof(array));
}
MxArray mxArray = new()
{
VariantType = CreateArrayVariantType(array),
};
for (int dimension = 0; dimension < array.Rank; dimension++)
{
mxArray.Dimensions.Add((uint)array.GetLength(dimension));
}
System.Type? elementType = array.GetType().GetElementType();
MxDataType elementDataType = ResolveArrayElementDataType(elementType, expectedElementDataType);
mxArray.ElementDataType = elementDataType;
switch (elementDataType)
{
case MxDataType.Boolean:
mxArray.BoolValues = ConvertBoolArray(array);
return mxArray;
case MxDataType.Integer:
if (elementType == typeof(long) || elementType == typeof(ulong))
{
mxArray.Int64Values = ConvertInt64Array(array);
}
else
{
mxArray.Int32Values = ConvertInt32Array(array);
}
return mxArray;
case MxDataType.Float:
mxArray.FloatValues = ConvertFloatArray(array);
return mxArray;
case MxDataType.Double:
mxArray.DoubleValues = ConvertDoubleArray(array);
return mxArray;
case MxDataType.String:
mxArray.StringValues = ConvertStringArray(array);
return mxArray;
case MxDataType.Time:
mxArray.TimestampValues = ConvertTimestampArray(array);
return mxArray;
default:
mxArray.ElementDataType = MxDataType.Unknown;
mxArray.RawElementDataType = (int)expectedElementDataType;
mxArray.RawDiagnostic = CreateRawDiagnostic(array);
mxArray.RawValues = ConvertRawArray(array);
return mxArray;
}
}
private static MxValue ConvertScalar(
object value,
MxDataType expectedDataType)
{
System.Type valueType = value.GetType();
string variantType = GetVariantTypeName(valueType);
switch (System.Type.GetTypeCode(valueType))
{
case TypeCode.Boolean:
return new MxValue
{
DataType = MxDataType.Boolean,
VariantType = variantType,
BoolValue = (bool)value,
};
case TypeCode.Byte:
case TypeCode.SByte:
case TypeCode.Int16:
case TypeCode.UInt16:
case TypeCode.Int32:
return new MxValue
{
DataType = MxDataType.Integer,
VariantType = variantType,
Int32Value = System.Convert.ToInt32(value, CultureInfo.InvariantCulture),
};
case TypeCode.UInt32:
case TypeCode.Int64:
return ConvertInt64Scalar(value, variantType, expectedDataType);
case TypeCode.UInt64:
return ConvertUInt64Scalar((ulong)value, variantType, expectedDataType);
case TypeCode.Single:
return new MxValue
{
DataType = MxDataType.Float,
VariantType = variantType,
FloatValue = (float)value,
};
case TypeCode.Double:
return new MxValue
{
DataType = MxDataType.Double,
VariantType = variantType,
DoubleValue = (double)value,
};
case TypeCode.Decimal:
return new MxValue
{
DataType = MxDataType.Double,
VariantType = variantType,
DoubleValue = System.Convert.ToDouble(value, CultureInfo.InvariantCulture),
RawDiagnostic = "Decimal value projected to double.",
};
case TypeCode.String:
case TypeCode.Char:
return new MxValue
{
DataType = MxDataType.String,
VariantType = variantType,
StringValue = System.Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty,
};
case TypeCode.DateTime:
return new MxValue
{
DataType = MxDataType.Time,
VariantType = variantType,
TimestampValue = ToTimestamp((DateTime)value),
};
default:
return CreateRawValue(value, expectedDataType);
}
}
private static MxValue ConvertInt64Scalar(
object value,
string variantType,
MxDataType expectedDataType)
{
long longValue = System.Convert.ToInt64(value, CultureInfo.InvariantCulture);
if (expectedDataType == MxDataType.Time)
{
return new MxValue
{
DataType = MxDataType.Time,
VariantType = variantType,
TimestampValue = Timestamp.FromDateTime(DateTime.FromFileTimeUtc(longValue)),
};
}
return new MxValue
{
DataType = MxDataType.Integer,
VariantType = variantType,
Int64Value = longValue,
};
}
private static MxValue ConvertUInt64Scalar(
ulong value,
string variantType,
MxDataType expectedDataType)
{
if (expectedDataType == MxDataType.Time && value <= long.MaxValue)
{
return new MxValue
{
DataType = MxDataType.Time,
VariantType = variantType,
TimestampValue = Timestamp.FromDateTime(DateTime.FromFileTimeUtc((long)value)),
};
}
if (value <= long.MaxValue)
{
return new MxValue
{
DataType = MxDataType.Integer,
VariantType = variantType,
Int64Value = (long)value,
};
}
return CreateRawValue(value, expectedDataType, "UInt64 value exceeds Int64 range.");
}
private static MxValue CreateNullValue(
object? value,
MxDataType expectedDataType)
{
return new MxValue
{
DataType = expectedDataType == MxDataType.Unspecified ? MxDataType.NoData : expectedDataType,
VariantType = value is DBNull ? "VT_NULL" : "VT_EMPTY",
IsNull = true,
};
}
private static MxValue CreateRawValue(
object value,
MxDataType expectedDataType,
string? diagnosticPrefix = null)
{
string diagnostic = CreateRawDiagnostic(value);
if (!string.IsNullOrWhiteSpace(diagnosticPrefix))
{
diagnostic = $"{diagnosticPrefix} {diagnostic}";
}
return new MxValue
{
DataType = MxDataType.Unknown,
VariantType = GetVariantTypeName(value.GetType()),
RawDataType = (int)expectedDataType,
RawDiagnostic = diagnostic,
RawValue = ByteString.CopyFromUtf8(System.Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty),
};
}
private static BoolArray ConvertBoolArray(Array array)
{
BoolArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is not null && System.Convert.ToBoolean(item, CultureInfo.InvariantCulture));
}
return values;
}
private static Int32Array ConvertInt32Array(Array array)
{
Int32Array values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToInt32(item, CultureInfo.InvariantCulture));
}
return values;
}
private static Int64Array ConvertInt64Array(Array array)
{
Int64Array values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToInt64(item, CultureInfo.InvariantCulture));
}
return values;
}
private static FloatArray ConvertFloatArray(Array array)
{
FloatArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToSingle(item, CultureInfo.InvariantCulture));
}
return values;
}
private static DoubleArray ConvertDoubleArray(Array array)
{
DoubleArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? 0 : System.Convert.ToDouble(item, CultureInfo.InvariantCulture));
}
return values;
}
private static StringArray ConvertStringArray(Array array)
{
StringArray values = new();
foreach (object? item in array)
{
values.Values.Add(item is null ? string.Empty : System.Convert.ToString(item, CultureInfo.InvariantCulture) ?? string.Empty);
}
return values;
}
private static TimestampArray ConvertTimestampArray(Array array)
{
TimestampArray values = new();
foreach (object? item in array)
{
if (item is null)
{
values.Values.Add(Timestamp.FromDateTime(new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc)));
}
else if (item is DateTime dateTime)
{
values.Values.Add(ToTimestamp(dateTime));
}
else
{
long fileTime = System.Convert.ToInt64(item, CultureInfo.InvariantCulture);
values.Values.Add(Timestamp.FromDateTime(DateTime.FromFileTimeUtc(fileTime)));
}
}
return values;
}
private static RawArray ConvertRawArray(Array array)
{
RawArray values = new();
foreach (object? item in array)
{
string rawValue = item is null
? string.Empty
: System.Convert.ToString(item, CultureInfo.InvariantCulture) ?? string.Empty;
values.Values.Add(ByteString.CopyFromUtf8(rawValue));
}
return values;
}
private static MxDataType ResolveArrayElementDataType(
System.Type? elementType,
MxDataType expectedElementDataType)
{
if (expectedElementDataType != MxDataType.Unspecified)
{
return expectedElementDataType;
}
if (elementType == typeof(bool))
{
return MxDataType.Boolean;
}
if (elementType == typeof(byte)
|| elementType == typeof(sbyte)
|| elementType == typeof(short)
|| elementType == typeof(ushort)
|| elementType == typeof(int)
|| elementType == typeof(uint)
|| elementType == typeof(long)
|| elementType == typeof(ulong))
{
return MxDataType.Integer;
}
if (elementType == typeof(float))
{
return MxDataType.Float;
}
if (elementType == typeof(double) || elementType == typeof(decimal))
{
return MxDataType.Double;
}
if (elementType == typeof(string) || elementType == typeof(char))
{
return MxDataType.String;
}
if (elementType == typeof(DateTime))
{
return MxDataType.Time;
}
return MxDataType.Unknown;
}
private static Timestamp ToTimestamp(DateTime dateTime)
{
DateTime utcDateTime = dateTime.Kind switch
{
DateTimeKind.Utc => dateTime,
DateTimeKind.Local => dateTime.ToUniversalTime(),
_ => DateTime.SpecifyKind(dateTime, DateTimeKind.Utc),
};
return Timestamp.FromDateTime(utcDateTime);
}
private static string CreateArrayVariantType(Array array)
{
System.Type? elementType = array.GetType().GetElementType();
return $"SAFEARRAY({GetVariantTypeName(elementType)})";
}
private static string GetVariantTypeName(System.Type? type)
{
if (type is null)
{
return "VT_EMPTY";
}
System.Type nonNullableType = Nullable.GetUnderlyingType(type) ?? type;
if (nonNullableType == typeof(bool))
{
return "VT_BOOL";
}
if (nonNullableType == typeof(byte))
{
return "VT_UI1";
}
if (nonNullableType == typeof(sbyte))
{
return "VT_I1";
}
if (nonNullableType == typeof(short))
{
return "VT_I2";
}
if (nonNullableType == typeof(ushort))
{
return "VT_UI2";
}
if (nonNullableType == typeof(int))
{
return "VT_I4";
}
if (nonNullableType == typeof(uint))
{
return "VT_UI4";
}
if (nonNullableType == typeof(long))
{
return "VT_I8";
}
if (nonNullableType == typeof(ulong))
{
return "VT_UI8";
}
if (nonNullableType == typeof(float))
{
return "VT_R4";
}
if (nonNullableType == typeof(double) || nonNullableType == typeof(decimal))
{
return "VT_R8";
}
if (nonNullableType == typeof(string) || nonNullableType == typeof(char))
{
return "VT_BSTR";
}
if (nonNullableType == typeof(DateTime))
{
return "VT_DATE";
}
return $"CLR:{nonNullableType.FullName}";
}
private static string CreateRawDiagnostic(object value)
{
return $"Unsupported variant projection for CLR type '{value.GetType().FullName}'.";
}
}
@@ -0,0 +1,12 @@
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public interface IWorkerPipeClient
{
Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default);
}
@@ -0,0 +1,11 @@
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
public static class WorkerContractInfo
{
public static uint SupportedProtocolVersion => GatewayContractInfo.WorkerProtocolVersion;
public static string WorkerEnvelopeDescriptorName => WorkerEnvelope.Descriptor.FullName;
}

Some files were not shown because too many files have changed in this diff Show More