Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 451dccf7e3 | |||
| 5511609880 | |||
| cde9c89386 | |||
| d496f1fd75 | |||
| 6559672fc1 | |||
| 97c30b9d00 | |||
| 603aff7004 | |||
| e81682e367 | |||
| d5a982152b | |||
| 0b0be7098e | |||
| fce9e99553 | |||
| c8fb3e91a3 | |||
| 8ce327e6f4 | |||
| fad0ac9948 | |||
| 9cb2f1c5cd | |||
| da9ffe0e11 | |||
| 0af1427859 | |||
| e2b4dfcb32 |
@@ -330,6 +330,20 @@ The worker remains authoritative for MXAccess handles. The gateway may keep a
|
||||
shadow state for diagnostics, but it must not invent, rewrite, or recycle
|
||||
MXAccess handles.
|
||||
|
||||
`SessionManager` owns the current in-memory session registry. It allocates a
|
||||
session id, creates the worker pipe name and nonce, registers the session before
|
||||
worker startup, and removes the session if startup fails. A successful
|
||||
`OpenSession` attaches the ready `IWorkerClient` and transitions the session to
|
||||
`Ready`.
|
||||
|
||||
Only `Ready` sessions accept command and event operations. `CloseSession` is
|
||||
idempotent for sessions still known to the registry: the first close shuts down
|
||||
the worker, and later closes return the final `Closed` state. Lease handling is
|
||||
exposed as a session hook so a monitor can close expired sessions without
|
||||
embedding lease policy in the worker client. Gateway shutdown walks the
|
||||
registry, closes each known session, and kills a worker if graceful shutdown
|
||||
fails.
|
||||
|
||||
## Worker Launch
|
||||
|
||||
The gateway should launch the worker using explicit configuration:
|
||||
@@ -411,7 +425,7 @@ session ids as protocol faults and close the session.
|
||||
|
||||
`WorkerClient` is the gateway-side object that owns one worker connection.
|
||||
|
||||
Suggested public shape:
|
||||
Current public shape:
|
||||
|
||||
```csharp
|
||||
public interface IWorkerClient : IAsyncDisposable
|
||||
@@ -419,6 +433,7 @@ public interface IWorkerClient : IAsyncDisposable
|
||||
string SessionId { get; }
|
||||
int? ProcessId { get; }
|
||||
WorkerClientState State { get; }
|
||||
DateTimeOffset LastHeartbeatAt { get; }
|
||||
|
||||
Task StartAsync(CancellationToken cancellationToken);
|
||||
Task<WorkerCommandReply> InvokeAsync(
|
||||
@@ -438,12 +453,17 @@ Internally it owns:
|
||||
- pipe stream,
|
||||
- read loop,
|
||||
- write loop,
|
||||
- bounded outbound command/control channel,
|
||||
- outbound command/control channel serialized by the write loop,
|
||||
- bounded inbound event channel,
|
||||
- pending command dictionary keyed by correlation id,
|
||||
- heartbeat monitor,
|
||||
- terminal fault source.
|
||||
|
||||
`StartAsync` sends `GatewayHello`, verifies the `WorkerHello` protocol version
|
||||
and nonce, waits for `WorkerReady`, and only then exposes `Ready` state. The
|
||||
read loop starts after readiness so the handshake has a single owner for its
|
||||
ordered frames.
|
||||
|
||||
### Read Loop
|
||||
|
||||
The read loop:
|
||||
@@ -612,6 +632,15 @@ hashes the presented secret, and compares the stored and presented hashes with
|
||||
results distinguish malformed credentials, missing keys, revoked keys, missing
|
||||
pepper configuration, and hash mismatch for internal authorization handling.
|
||||
|
||||
`GatewayGrpcAuthorizationInterceptor` enforces this authentication model for
|
||||
public gRPC calls. Missing, malformed, revoked, unknown, or mismatched keys fail
|
||||
with `Unauthenticated`. Authenticated calls missing the scope required by the
|
||||
RPC fail with `PermissionDenied`. The interceptor applies to unary calls and
|
||||
server-streaming calls and stores the authenticated `ApiKeyIdentity` in
|
||||
`IGatewayRequestIdentityAccessor` for the duration of the request handler.
|
||||
`Authentication:Mode` set to `Disabled` bypasses API-key verification for local
|
||||
development only.
|
||||
|
||||
Recommended scopes:
|
||||
|
||||
- `session:open`
|
||||
@@ -631,6 +660,23 @@ gRPC admin API. It should initialize the auth database, create keys, list keys
|
||||
without secrets, revoke keys, rotate keys, and print raw secrets only once at
|
||||
creation.
|
||||
|
||||
`MxGateway.Server` exposes local API-key administration as an `apikey`
|
||||
subcommand before the web host starts:
|
||||
|
||||
```bash
|
||||
MxGateway.Server apikey init-db --sqlite-path C:\ProgramData\MxGateway\gateway-auth.db
|
||||
MxGateway.Server apikey create-key --key-id operator01 --display-name Operator --scopes session:open,events:read
|
||||
MxGateway.Server apikey list-keys --json
|
||||
MxGateway.Server apikey revoke-key --key-id operator01
|
||||
MxGateway.Server apikey rotate-key --key-id operator01 --json
|
||||
```
|
||||
|
||||
The subcommands accept `--sqlite-path`, `--pepper`, and `--json`. `--pepper`
|
||||
sets the local `MxGateway:ApiKeyPepper` configuration value for the command
|
||||
process; deployments should normally provide the pepper through the configured
|
||||
secret source. `create-key` and `rotate-key` print the full raw API key exactly
|
||||
once. `list-keys` never prints raw secrets or `secret_hash` values.
|
||||
|
||||
SQLite auth storage should use startup migrations with a `schema_version` table.
|
||||
Migrations should run inside transactions and fail startup if the database
|
||||
schema is newer than the running binary understands.
|
||||
@@ -660,6 +706,20 @@ Commands requiring authorization:
|
||||
- worker shutdown diagnostics,
|
||||
- metadata queries if they expose sensitive plant structure.
|
||||
|
||||
Current gRPC scope mapping:
|
||||
|
||||
- `OpenSession` requires `session:open`.
|
||||
- `CloseSession` requires `session:close`.
|
||||
- `StreamEvents` and `DrainEvents` require `events:read`.
|
||||
- read-style MXAccess commands such as `Register`, `AddItem`, `Advise`, and
|
||||
`Ping` require `invoke:read`.
|
||||
- `Write` and `Write2` require `invoke:write`.
|
||||
- `WriteSecured`, `WriteSecured2`, and `AuthenticateUser` require
|
||||
`invoke:secure`.
|
||||
- metadata commands such as `ArchestrAUserToId`, `GetSessionState`, and
|
||||
`GetWorkerInfo` require `metadata:read`.
|
||||
- `ShutdownWorker` requires `admin`.
|
||||
|
||||
### Worker IPC
|
||||
|
||||
Named pipes should be local only. Pipe ACLs should restrict access to:
|
||||
@@ -802,6 +862,9 @@ workers and fake transports.
|
||||
Focused tests:
|
||||
|
||||
- session state transitions,
|
||||
- gRPC API-key authentication for unary and streaming calls,
|
||||
- gRPC scope mapping for sessions, invokes, events, metadata, and admin
|
||||
commands,
|
||||
- worker startup failures,
|
||||
- protocol version mismatch,
|
||||
- malformed frame handling,
|
||||
|
||||
@@ -114,6 +114,21 @@ Startup sequence:
|
||||
If validation fails before MXAccess creation, exit quickly with a non-zero exit
|
||||
code. If MXAccess creation fails, send `WorkerFault` when possible and exit.
|
||||
|
||||
The bootstrap layer returns structured exit codes before it creates pipes,
|
||||
starts the STA, or touches MXAccess:
|
||||
|
||||
| Exit code | Name | Meaning |
|
||||
|-----------|------|---------|
|
||||
| `0` | `Success` | Required bootstrap options are valid. |
|
||||
| `1` | `UnexpectedFailure` | A non-bootstrap exception reaches the process boundary. |
|
||||
| `2` | `InvalidArguments` | Required arguments are missing or unknown arguments are present. |
|
||||
| `3` | `InvalidProtocolVersion` | `--protocol-version` is not numeric or does not match the supported worker protocol. |
|
||||
| `4` | `MissingNonce` | `MXGATEWAY_WORKER_NONCE` is absent or empty. |
|
||||
|
||||
Bootstrap logs use `WorkerConsoleLogger` key/value output. `WorkerLogRedactor`
|
||||
redacts fields whose names indicate nonce, secret, password, token,
|
||||
credential, or API key values before the message is written.
|
||||
|
||||
## Internal Components
|
||||
|
||||
```text
|
||||
@@ -235,6 +250,17 @@ The loop should update a heartbeat timestamp after:
|
||||
- finishing a command,
|
||||
- processing an MXAccess event.
|
||||
|
||||
`StaRuntime` implements this runtime boundary in the worker. It starts one
|
||||
background thread named `MxGateway.Worker.STA`, sets it to `ApartmentState.STA`,
|
||||
initializes COM through `StaComApartmentInitializer`, and runs
|
||||
`StaMessagePump`. Commands are scheduled through `InvokeAsync`; the command
|
||||
queue signals an `AutoResetEvent` so `MsgWaitForMultipleObjectsEx` can wake the
|
||||
STA without busy-waiting. `LastActivityUtc` records pump, command, startup, and
|
||||
shutdown activity so the future heartbeat/watchdog can report whether the STA
|
||||
is still responsive. Shutdown marks the runtime as closing, wakes the pump,
|
||||
rejects new commands, cancels queued work, uninitializes COM on the STA, and
|
||||
waits for the thread to exit.
|
||||
|
||||
## COM Creation
|
||||
|
||||
The MXAccess analysis source at `C:\Users\dohertj2\Desktop\mxaccess` identifies
|
||||
@@ -263,6 +289,13 @@ The worker should reference the interop assembly and instantiate
|
||||
`LMXProxyServerClass` on the dedicated STA thread. Keep the ProgID and assembly
|
||||
path configurable for diagnostics, but this COM class is the v1 default.
|
||||
|
||||
`MxAccessStaSession` owns the initial COM creation path. It starts `StaRuntime`,
|
||||
creates `LMXProxyServerClass` through `MxAccessComObjectFactory` on the STA,
|
||||
attaches `MxAccessBaseEventSink`, and returns `WorkerReady` only after those
|
||||
steps succeed. `MxAccessSession` keeps the raw COM object private, records the
|
||||
STA managed thread id that created it, detaches the base event sink during
|
||||
disposal, and releases the COM reference on the STA.
|
||||
|
||||
Creation rules:
|
||||
|
||||
- Create COM object only on the STA.
|
||||
@@ -280,6 +313,11 @@ If COM creation fails, the worker should send a structured fault with:
|
||||
- worker process id,
|
||||
- session id.
|
||||
|
||||
`WorkerPipeSession` maps startup exceptions from this path to
|
||||
`WorkerFaultCategory.MxaccessCreationFailed`, includes the captured HRESULT
|
||||
when the exception exposes one, and does not send `WorkerReady` after a failed
|
||||
COM creation attempt.
|
||||
|
||||
## Event Sink
|
||||
|
||||
The worker must subscribe to every public MXAccess event family:
|
||||
|
||||
+13
-3
@@ -566,9 +566,13 @@ Because each client owns one worker, a crash or leak affects only that session.
|
||||
External gateway:
|
||||
|
||||
- use TLS for remote gRPC if crossing machine boundaries,
|
||||
- authenticate clients with Windows auth, mTLS, or a deployment-specific token,
|
||||
- authorize access to commands that can write, authenticate users, or alter
|
||||
runtime state.
|
||||
- authenticate v1 gRPC clients with `authorization: Bearer
|
||||
mxgw_<key-id>_<secret>` API-key metadata,
|
||||
- reject missing or invalid API keys with gRPC `Unauthenticated`,
|
||||
- reject valid keys that lack the required session, invoke, event, metadata, or
|
||||
admin scope with gRPC `PermissionDenied`,
|
||||
- authorize access to commands that can write, authenticate users, expose
|
||||
metadata, stream events, or alter runtime state.
|
||||
|
||||
Internal worker IPC:
|
||||
|
||||
@@ -795,6 +799,12 @@ Core operations:
|
||||
- track worker state,
|
||||
- close or kill worker.
|
||||
|
||||
The gateway implementation keeps sessions in an in-memory `SessionRegistry`
|
||||
keyed by session id. `SessionManager` owns the state machine, creates
|
||||
per-session pipe names and nonces, starts the worker through the worker-client
|
||||
factory, gates commands to `Ready` sessions, exposes lease-close hooks, and
|
||||
cleans up workers during gateway shutdown.
|
||||
|
||||
State machine:
|
||||
|
||||
```text
|
||||
|
||||
@@ -3,6 +3,8 @@ using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Diagnostics;
|
||||
using MxGateway.Server.Metrics;
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
using MxGateway.Server.Security.Authorization;
|
||||
using MxGateway.Server.Sessions;
|
||||
using MxGateway.Server.Workers;
|
||||
|
||||
namespace MxGateway.Server;
|
||||
@@ -26,9 +28,11 @@ public static class GatewayApplication
|
||||
|
||||
builder.Services.AddGatewayConfiguration();
|
||||
builder.Services.AddSqliteAuthStore();
|
||||
builder.Services.AddGatewayGrpcAuthorization();
|
||||
builder.Services.AddHealthChecks();
|
||||
builder.Services.AddSingleton<GatewayMetrics>();
|
||||
builder.Services.AddWorkerProcessLauncher();
|
||||
builder.Services.AddGatewaySessions();
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Grpc.AspNetCore" Version="2.76.0" />
|
||||
<PackageReference Include="Microsoft.Data.Sqlite" Version="10.0.7" />
|
||||
</ItemGroup>
|
||||
|
||||
|
||||
@@ -1,7 +1,43 @@
|
||||
using MxGateway.Server;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
|
||||
var app = GatewayApplication.Build(args);
|
||||
ApiKeyAdminParseResult apiKeyAdminCommand = ApiKeyAdminCommandLineParser.Parse(args);
|
||||
if (apiKeyAdminCommand.IsApiKeyCommand)
|
||||
{
|
||||
if (apiKeyAdminCommand.Command is null)
|
||||
{
|
||||
await Console.Error.WriteLineAsync(apiKeyAdminCommand.Error);
|
||||
return 2;
|
||||
}
|
||||
|
||||
WebApplicationBuilder builder = GatewayApplication.CreateBuilder([]);
|
||||
ApplyApiKeyAdminOverrides(builder.Configuration, apiKeyAdminCommand.Command);
|
||||
await using WebApplication cliApp = builder.Build();
|
||||
await using AsyncServiceScope scope = cliApp.Services.CreateAsyncScope();
|
||||
|
||||
ApiKeyAdminCliRunner runner = scope.ServiceProvider.GetRequiredService<ApiKeyAdminCliRunner>();
|
||||
|
||||
return await runner.RunAsync(apiKeyAdminCommand.Command, Console.Out, CancellationToken.None);
|
||||
}
|
||||
|
||||
WebApplication app = GatewayApplication.Build(args);
|
||||
|
||||
app.Run();
|
||||
|
||||
return 0;
|
||||
|
||||
static void ApplyApiKeyAdminOverrides(IConfiguration configuration, ApiKeyAdminCommand command)
|
||||
{
|
||||
if (!string.IsNullOrWhiteSpace(command.SqlitePath))
|
||||
{
|
||||
configuration[$"{GatewayOptions.SectionName}:Authentication:SqlitePath"] = command.SqlitePath;
|
||||
}
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(command.Pepper))
|
||||
{
|
||||
configuration["MxGateway:ApiKeyPepper"] = command.Pepper;
|
||||
}
|
||||
}
|
||||
|
||||
public partial class Program;
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
using System.Text.Json;
|
||||
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed class ApiKeyAdminCliRunner(
|
||||
IAuthStoreMigrator migrator,
|
||||
IApiKeyAdminStore adminStore,
|
||||
IApiKeyAuditStore auditStore,
|
||||
IApiKeySecretHasher hasher)
|
||||
{
|
||||
private static readonly JsonSerializerOptions JsonOptions = new()
|
||||
{
|
||||
WriteIndented = true
|
||||
};
|
||||
|
||||
public async Task<int> RunAsync(
|
||||
ApiKeyAdminCommand command,
|
||||
TextWriter output,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ApiKeyAdminOutput result = command.Kind switch
|
||||
{
|
||||
ApiKeyAdminCommandKind.InitDb => await InitDbAsync(cancellationToken).ConfigureAwait(false),
|
||||
ApiKeyAdminCommandKind.CreateKey => await CreateKeyAsync(command, cancellationToken).ConfigureAwait(false),
|
||||
ApiKeyAdminCommandKind.ListKeys => await ListKeysAsync(cancellationToken).ConfigureAwait(false),
|
||||
ApiKeyAdminCommandKind.RevokeKey => await RevokeKeyAsync(command, cancellationToken).ConfigureAwait(false),
|
||||
ApiKeyAdminCommandKind.RotateKey => await RotateKeyAsync(command, cancellationToken).ConfigureAwait(false),
|
||||
_ => throw new InvalidOperationException($"Unsupported API key command '{command.Kind}'.")
|
||||
};
|
||||
|
||||
await WriteOutputAsync(command, result, output).ConfigureAwait(false);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
private async Task<ApiKeyAdminOutput> InitDbAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
|
||||
await AppendAuditAsync(null, "init-db", null, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return new ApiKeyAdminOutput("init-db", "initialized", null, []);
|
||||
}
|
||||
|
||||
private async Task<ApiKeyAdminOutput> CreateKeyAsync(
|
||||
ApiKeyAdminCommand command,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
string keyId = Required(command.KeyId);
|
||||
string secret = ApiKeySecretGenerator.Generate();
|
||||
string apiKey = FormatApiKey(keyId, secret);
|
||||
|
||||
await adminStore.CreateAsync(
|
||||
new ApiKeyCreateRequest(
|
||||
KeyId: keyId,
|
||||
KeyPrefix: $"mxgw_{keyId}",
|
||||
SecretHash: hasher.HashSecret(secret),
|
||||
DisplayName: Required(command.DisplayName),
|
||||
Scopes: command.Scopes,
|
||||
CreatedUtc: DateTimeOffset.UtcNow),
|
||||
cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
await AppendAuditAsync(keyId, "create-key", null, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return new ApiKeyAdminOutput("create-key", "created", apiKey, []);
|
||||
}
|
||||
|
||||
private async Task<ApiKeyAdminOutput> ListKeysAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
|
||||
IReadOnlyList<ApiKeyRecord> keys = await adminStore.ListAsync(cancellationToken).ConfigureAwait(false);
|
||||
await AppendAuditAsync(null, "list-keys", null, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return new ApiKeyAdminOutput(
|
||||
"list-keys",
|
||||
"ok",
|
||||
null,
|
||||
keys.Select(ToListedKey).ToArray());
|
||||
}
|
||||
|
||||
private async Task<ApiKeyAdminOutput> RevokeKeyAsync(
|
||||
ApiKeyAdminCommand command,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
string keyId = Required(command.KeyId);
|
||||
bool revoked = await adminStore.RevokeAsync(keyId, DateTimeOffset.UtcNow, cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
await AppendAuditAsync(keyId, "revoke-key", revoked ? "revoked" : "not-found-or-already-revoked", cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
return new ApiKeyAdminOutput("revoke-key", revoked ? "revoked" : "not-found-or-already-revoked", null, []);
|
||||
}
|
||||
|
||||
private async Task<ApiKeyAdminOutput> RotateKeyAsync(
|
||||
ApiKeyAdminCommand command,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
await migrator.MigrateAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
string keyId = Required(command.KeyId);
|
||||
string secret = ApiKeySecretGenerator.Generate();
|
||||
string apiKey = FormatApiKey(keyId, secret);
|
||||
|
||||
bool rotated = await adminStore.RotateAsync(keyId, hasher.HashSecret(secret), DateTimeOffset.UtcNow, cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
await AppendAuditAsync(keyId, "rotate-key", rotated ? "rotated" : "not-found", cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
return new ApiKeyAdminOutput("rotate-key", rotated ? "rotated" : "not-found", rotated ? apiKey : null, []);
|
||||
}
|
||||
|
||||
private static async Task WriteOutputAsync(
|
||||
ApiKeyAdminCommand command,
|
||||
ApiKeyAdminOutput result,
|
||||
TextWriter output)
|
||||
{
|
||||
if (command.Json)
|
||||
{
|
||||
await output.WriteLineAsync(JsonSerializer.Serialize(result, JsonOptions)).ConfigureAwait(false);
|
||||
return;
|
||||
}
|
||||
|
||||
await output.WriteLineAsync($"{result.Command}: {result.Status}").ConfigureAwait(false);
|
||||
|
||||
if (result.ApiKey is not null)
|
||||
{
|
||||
await output.WriteLineAsync($"API key: {result.ApiKey}").ConfigureAwait(false);
|
||||
}
|
||||
|
||||
foreach (ApiKeyAdminListedKey key in result.Keys)
|
||||
{
|
||||
string revoked = key.RevokedUtc is null ? "active" : "revoked";
|
||||
await output.WriteLineAsync($"{key.KeyId}\t{key.DisplayName}\t{revoked}\t{string.Join(',', key.Scopes)}")
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task AppendAuditAsync(
|
||||
string? keyId,
|
||||
string eventType,
|
||||
string? details,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
await auditStore.AppendAsync(
|
||||
new ApiKeyAuditEntry(
|
||||
KeyId: keyId,
|
||||
EventType: eventType,
|
||||
RemoteAddress: null,
|
||||
Details: details),
|
||||
cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private static ApiKeyAdminListedKey ToListedKey(ApiKeyRecord key)
|
||||
{
|
||||
return new ApiKeyAdminListedKey(
|
||||
KeyId: key.KeyId,
|
||||
KeyPrefix: key.KeyPrefix,
|
||||
DisplayName: key.DisplayName,
|
||||
Scopes: key.Scopes,
|
||||
CreatedUtc: key.CreatedUtc,
|
||||
LastUsedUtc: key.LastUsedUtc,
|
||||
RevokedUtc: key.RevokedUtc);
|
||||
}
|
||||
|
||||
private static string FormatApiKey(string keyId, string secret)
|
||||
{
|
||||
return $"mxgw_{keyId}_{secret}";
|
||||
}
|
||||
|
||||
private static string Required(string? value)
|
||||
{
|
||||
return value ?? throw new InvalidOperationException("Required command value was not provided.");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed record ApiKeyAdminCommand(
|
||||
ApiKeyAdminCommandKind Kind,
|
||||
bool Json,
|
||||
string? SqlitePath,
|
||||
string? Pepper,
|
||||
string? KeyId,
|
||||
string? DisplayName,
|
||||
IReadOnlySet<string> Scopes);
|
||||
@@ -0,0 +1,10 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public enum ApiKeyAdminCommandKind
|
||||
{
|
||||
InitDb,
|
||||
CreateKey,
|
||||
ListKeys,
|
||||
RevokeKey,
|
||||
RotateKey
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public static class ApiKeyAdminCommandLineParser
|
||||
{
|
||||
public static ApiKeyAdminParseResult Parse(IReadOnlyList<string> args)
|
||||
{
|
||||
if (args.Count == 0 || !string.Equals(args[0], "apikey", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
return ApiKeyAdminParseResult.NotApiKeyCommand();
|
||||
}
|
||||
|
||||
if (args.Count < 2)
|
||||
{
|
||||
return ApiKeyAdminParseResult.Fail("Missing apikey subcommand.");
|
||||
}
|
||||
|
||||
if (!TryParseKind(args[1], out ApiKeyAdminCommandKind kind))
|
||||
{
|
||||
return ApiKeyAdminParseResult.Fail($"Unknown apikey subcommand '{args[1]}'.");
|
||||
}
|
||||
|
||||
Dictionary<string, string?> options = new(StringComparer.OrdinalIgnoreCase);
|
||||
bool json = false;
|
||||
|
||||
for (int index = 2; index < args.Count; index++)
|
||||
{
|
||||
string arg = args[index];
|
||||
if (string.Equals(arg, "--json", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
json = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!arg.StartsWith("--", StringComparison.Ordinal))
|
||||
{
|
||||
return ApiKeyAdminParseResult.Fail($"Unexpected argument '{arg}'.");
|
||||
}
|
||||
|
||||
string name = arg[2..];
|
||||
string? value;
|
||||
|
||||
int equalsIndex = name.IndexOf('=', StringComparison.Ordinal);
|
||||
if (equalsIndex >= 0)
|
||||
{
|
||||
value = name[(equalsIndex + 1)..];
|
||||
name = name[..equalsIndex];
|
||||
}
|
||||
else
|
||||
{
|
||||
if (index + 1 >= args.Count || args[index + 1].StartsWith("--", StringComparison.Ordinal))
|
||||
{
|
||||
return ApiKeyAdminParseResult.Fail($"Option '--{name}' requires a value.");
|
||||
}
|
||||
|
||||
value = args[++index];
|
||||
}
|
||||
|
||||
options[name] = value;
|
||||
}
|
||||
|
||||
string? keyId = GetOption(options, "key-id");
|
||||
string? displayName = GetOption(options, "display-name");
|
||||
IReadOnlySet<string> scopes = ParseScopes(GetOption(options, "scopes"));
|
||||
|
||||
string? validationError = Validate(kind, keyId, displayName);
|
||||
if (validationError is not null)
|
||||
{
|
||||
return ApiKeyAdminParseResult.Fail(validationError);
|
||||
}
|
||||
|
||||
return ApiKeyAdminParseResult.Success(new ApiKeyAdminCommand(
|
||||
Kind: kind,
|
||||
Json: json,
|
||||
SqlitePath: GetOption(options, "sqlite-path"),
|
||||
Pepper: GetOption(options, "pepper"),
|
||||
KeyId: keyId,
|
||||
DisplayName: displayName,
|
||||
Scopes: scopes));
|
||||
}
|
||||
|
||||
private static bool TryParseKind(string value, out ApiKeyAdminCommandKind kind)
|
||||
{
|
||||
switch (value.ToLowerInvariant())
|
||||
{
|
||||
case "init-db":
|
||||
kind = ApiKeyAdminCommandKind.InitDb;
|
||||
return true;
|
||||
case "create-key":
|
||||
kind = ApiKeyAdminCommandKind.CreateKey;
|
||||
return true;
|
||||
case "list-keys":
|
||||
kind = ApiKeyAdminCommandKind.ListKeys;
|
||||
return true;
|
||||
case "revoke-key":
|
||||
kind = ApiKeyAdminCommandKind.RevokeKey;
|
||||
return true;
|
||||
case "rotate-key":
|
||||
kind = ApiKeyAdminCommandKind.RotateKey;
|
||||
return true;
|
||||
default:
|
||||
kind = default;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static string? Validate(ApiKeyAdminCommandKind kind, string? keyId, string? displayName)
|
||||
{
|
||||
if (kind is ApiKeyAdminCommandKind.CreateKey or ApiKeyAdminCommandKind.RevokeKey or ApiKeyAdminCommandKind.RotateKey
|
||||
&& string.IsNullOrWhiteSpace(keyId))
|
||||
{
|
||||
return $"Subcommand '{KindName(kind)}' requires --key-id.";
|
||||
}
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(keyId) && !IsValidKeyId(keyId))
|
||||
{
|
||||
return "API key id may contain only letters, numbers, periods, and hyphens.";
|
||||
}
|
||||
|
||||
if (kind == ApiKeyAdminCommandKind.CreateKey && string.IsNullOrWhiteSpace(displayName))
|
||||
{
|
||||
return "Subcommand 'create-key' requires --display-name.";
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static string KindName(ApiKeyAdminCommandKind kind)
|
||||
{
|
||||
return kind switch
|
||||
{
|
||||
ApiKeyAdminCommandKind.InitDb => "init-db",
|
||||
ApiKeyAdminCommandKind.CreateKey => "create-key",
|
||||
ApiKeyAdminCommandKind.ListKeys => "list-keys",
|
||||
ApiKeyAdminCommandKind.RevokeKey => "revoke-key",
|
||||
ApiKeyAdminCommandKind.RotateKey => "rotate-key",
|
||||
_ => kind.ToString()
|
||||
};
|
||||
}
|
||||
|
||||
private static bool IsValidKeyId(string keyId)
|
||||
{
|
||||
return keyId.All(character =>
|
||||
char.IsAsciiLetterOrDigit(character)
|
||||
|| character is '.' or '-');
|
||||
}
|
||||
|
||||
private static string? GetOption(Dictionary<string, string?> options, string name)
|
||||
{
|
||||
return options.TryGetValue(name, out string? value) ? value : null;
|
||||
}
|
||||
|
||||
private static IReadOnlySet<string> ParseScopes(string? scopes)
|
||||
{
|
||||
return new HashSet<string>(
|
||||
(scopes ?? string.Empty)
|
||||
.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries),
|
||||
StringComparer.Ordinal);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed record ApiKeyAdminListedKey(
|
||||
string KeyId,
|
||||
string KeyPrefix,
|
||||
string DisplayName,
|
||||
IReadOnlySet<string> Scopes,
|
||||
DateTimeOffset CreatedUtc,
|
||||
DateTimeOffset? LastUsedUtc,
|
||||
DateTimeOffset? RevokedUtc);
|
||||
@@ -0,0 +1,7 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed record ApiKeyAdminOutput(
|
||||
string Command,
|
||||
string Status,
|
||||
string? ApiKey,
|
||||
IReadOnlyList<ApiKeyAdminListedKey> Keys);
|
||||
@@ -0,0 +1,22 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed record ApiKeyAdminParseResult(
|
||||
bool IsApiKeyCommand,
|
||||
ApiKeyAdminCommand? Command,
|
||||
string? Error)
|
||||
{
|
||||
public static ApiKeyAdminParseResult NotApiKeyCommand()
|
||||
{
|
||||
return new ApiKeyAdminParseResult(false, null, null);
|
||||
}
|
||||
|
||||
public static ApiKeyAdminParseResult Success(ApiKeyAdminCommand command)
|
||||
{
|
||||
return new ApiKeyAdminParseResult(true, command, null);
|
||||
}
|
||||
|
||||
public static ApiKeyAdminParseResult Fail(string error)
|
||||
{
|
||||
return new ApiKeyAdminParseResult(true, null, error);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed record ApiKeyCreateRequest(
|
||||
string KeyId,
|
||||
string KeyPrefix,
|
||||
byte[] SecretHash,
|
||||
string DisplayName,
|
||||
IReadOnlySet<string> Scopes,
|
||||
DateTimeOffset CreatedUtc);
|
||||
@@ -0,0 +1,26 @@
|
||||
using Microsoft.Data.Sqlite;
|
||||
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public static class ApiKeyRecordReader
|
||||
{
|
||||
public static ApiKeyRecord Read(SqliteDataReader reader)
|
||||
{
|
||||
return new ApiKeyRecord(
|
||||
KeyId: reader.GetString(0),
|
||||
KeyPrefix: reader.GetString(1),
|
||||
SecretHash: (byte[])reader["secret_hash"],
|
||||
DisplayName: reader.GetString(3),
|
||||
Scopes: ApiKeyScopeSerializer.Deserialize(reader.GetString(4)),
|
||||
CreatedUtc: DateTimeOffset.Parse(reader.GetString(5), System.Globalization.CultureInfo.InvariantCulture),
|
||||
LastUsedUtc: ReadNullableDateTimeOffset(reader, 6),
|
||||
RevokedUtc: ReadNullableDateTimeOffset(reader, 7));
|
||||
}
|
||||
|
||||
private static DateTimeOffset? ReadNullableDateTimeOffset(SqliteDataReader reader, int ordinal)
|
||||
{
|
||||
return reader.IsDBNull(ordinal)
|
||||
? null
|
||||
: DateTimeOffset.Parse(reader.GetString(ordinal), System.Globalization.CultureInfo.InvariantCulture);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
using System.Security.Cryptography;
|
||||
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public static class ApiKeySecretGenerator
|
||||
{
|
||||
public static string Generate()
|
||||
{
|
||||
Span<byte> bytes = stackalloc byte[32];
|
||||
RandomNumberGenerator.Fill(bytes);
|
||||
|
||||
return Convert.ToBase64String(bytes)
|
||||
.TrimEnd('=')
|
||||
.Replace('+', '-')
|
||||
.Replace('/', '_');
|
||||
}
|
||||
}
|
||||
@@ -7,9 +7,11 @@ public static class AuthStoreServiceCollectionExtensions
|
||||
services.AddSingleton<IApiKeyParser, ApiKeyParser>();
|
||||
services.AddSingleton<IApiKeySecretHasher, ApiKeySecretHasher>();
|
||||
services.AddSingleton<IApiKeyVerifier, ApiKeyVerifier>();
|
||||
services.AddSingleton<ApiKeyAdminCliRunner>();
|
||||
services.AddSingleton<AuthSqliteConnectionFactory>();
|
||||
services.AddSingleton<IAuthStoreMigrator, SqliteAuthStoreMigrator>();
|
||||
services.AddSingleton<IApiKeyStore, SqliteApiKeyStore>();
|
||||
services.AddSingleton<IApiKeyAdminStore, SqliteApiKeyAdminStore>();
|
||||
services.AddSingleton<IApiKeyAuditStore, SqliteApiKeyAuditStore>();
|
||||
services.AddHostedService<AuthStoreMigrationHostedService>();
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public interface IApiKeyAdminStore
|
||||
{
|
||||
Task CreateAsync(ApiKeyCreateRequest request, CancellationToken cancellationToken);
|
||||
|
||||
Task<IReadOnlyList<ApiKeyRecord>> ListAsync(CancellationToken cancellationToken);
|
||||
|
||||
Task<bool> RevokeAsync(string keyId, DateTimeOffset revokedUtc, CancellationToken cancellationToken);
|
||||
|
||||
Task<bool> RotateAsync(
|
||||
string keyId,
|
||||
byte[] secretHash,
|
||||
DateTimeOffset rotatedUtc,
|
||||
CancellationToken cancellationToken);
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
using Microsoft.Data.Sqlite;
|
||||
|
||||
namespace MxGateway.Server.Security.Authentication;
|
||||
|
||||
public sealed class SqliteApiKeyAdminStore(AuthSqliteConnectionFactory connectionFactory) : IApiKeyAdminStore
|
||||
{
|
||||
public async Task CreateAsync(ApiKeyCreateRequest request, CancellationToken cancellationToken)
|
||||
{
|
||||
await using SqliteConnection connection = connectionFactory.CreateConnection();
|
||||
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await using SqliteCommand command = connection.CreateCommand();
|
||||
command.CommandText = """
|
||||
INSERT INTO api_keys (
|
||||
key_id,
|
||||
key_prefix,
|
||||
secret_hash,
|
||||
display_name,
|
||||
scopes,
|
||||
created_utc,
|
||||
last_used_utc,
|
||||
revoked_utc)
|
||||
VALUES (
|
||||
$key_id,
|
||||
$key_prefix,
|
||||
$secret_hash,
|
||||
$display_name,
|
||||
$scopes,
|
||||
$created_utc,
|
||||
NULL,
|
||||
NULL);
|
||||
""";
|
||||
AddCreateParameters(command, request);
|
||||
|
||||
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
public async Task<IReadOnlyList<ApiKeyRecord>> ListAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
await using SqliteConnection connection = connectionFactory.CreateConnection();
|
||||
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await using SqliteCommand command = connection.CreateCommand();
|
||||
command.CommandText = """
|
||||
SELECT key_id, key_prefix, secret_hash, display_name, scopes, created_utc, last_used_utc, revoked_utc
|
||||
FROM api_keys
|
||||
ORDER BY key_id;
|
||||
""";
|
||||
|
||||
List<ApiKeyRecord> records = [];
|
||||
|
||||
await using SqliteDataReader reader = await command.ExecuteReaderAsync(cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
records.Add(ApiKeyRecordReader.Read(reader));
|
||||
}
|
||||
|
||||
return records;
|
||||
}
|
||||
|
||||
public async Task<bool> RevokeAsync(string keyId, DateTimeOffset revokedUtc, CancellationToken cancellationToken)
|
||||
{
|
||||
await using SqliteConnection connection = connectionFactory.CreateConnection();
|
||||
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await using SqliteCommand command = connection.CreateCommand();
|
||||
command.CommandText = """
|
||||
UPDATE api_keys
|
||||
SET revoked_utc = $revoked_utc
|
||||
WHERE key_id = $key_id AND revoked_utc IS NULL;
|
||||
""";
|
||||
command.Parameters.AddWithValue("$key_id", keyId);
|
||||
command.Parameters.AddWithValue("$revoked_utc", revokedUtc.ToString("O"));
|
||||
|
||||
int rows = await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return rows > 0;
|
||||
}
|
||||
|
||||
public async Task<bool> RotateAsync(
|
||||
string keyId,
|
||||
byte[] secretHash,
|
||||
DateTimeOffset rotatedUtc,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
await using SqliteConnection connection = connectionFactory.CreateConnection();
|
||||
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await using SqliteCommand command = connection.CreateCommand();
|
||||
command.CommandText = """
|
||||
UPDATE api_keys
|
||||
SET secret_hash = $secret_hash,
|
||||
last_used_utc = NULL,
|
||||
revoked_utc = NULL
|
||||
WHERE key_id = $key_id;
|
||||
""";
|
||||
command.Parameters.AddWithValue("$key_id", keyId);
|
||||
command.Parameters.Add("$secret_hash", SqliteType.Blob).Value = secretHash;
|
||||
|
||||
int rows = await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return rows > 0;
|
||||
}
|
||||
|
||||
private static void AddCreateParameters(SqliteCommand command, ApiKeyCreateRequest request)
|
||||
{
|
||||
command.Parameters.AddWithValue("$key_id", request.KeyId);
|
||||
command.Parameters.AddWithValue("$key_prefix", request.KeyPrefix);
|
||||
command.Parameters.Add("$secret_hash", SqliteType.Blob).Value = request.SecretHash;
|
||||
command.Parameters.AddWithValue("$display_name", request.DisplayName);
|
||||
command.Parameters.AddWithValue("$scopes", ApiKeyScopeSerializer.Serialize(request.Scopes));
|
||||
command.Parameters.AddWithValue("$created_utc", request.CreatedUtc.ToString("O"));
|
||||
}
|
||||
}
|
||||
@@ -61,26 +61,6 @@ public sealed class SqliteApiKeyStore(AuthSqliteConnectionFactory connectionFact
|
||||
return null;
|
||||
}
|
||||
|
||||
return ReadApiKeyRecord(reader);
|
||||
}
|
||||
|
||||
private static ApiKeyRecord ReadApiKeyRecord(SqliteDataReader reader)
|
||||
{
|
||||
return new ApiKeyRecord(
|
||||
KeyId: reader.GetString(0),
|
||||
KeyPrefix: reader.GetString(1),
|
||||
SecretHash: (byte[])reader["secret_hash"],
|
||||
DisplayName: reader.GetString(3),
|
||||
Scopes: ApiKeyScopeSerializer.Deserialize(reader.GetString(4)),
|
||||
CreatedUtc: DateTimeOffset.Parse(reader.GetString(5), System.Globalization.CultureInfo.InvariantCulture),
|
||||
LastUsedUtc: ReadNullableDateTimeOffset(reader, 6),
|
||||
RevokedUtc: ReadNullableDateTimeOffset(reader, 7));
|
||||
}
|
||||
|
||||
private static DateTimeOffset? ReadNullableDateTimeOffset(SqliteDataReader reader, int ordinal)
|
||||
{
|
||||
return reader.IsDBNull(ordinal)
|
||||
? null
|
||||
: DateTimeOffset.Parse(reader.GetString(ordinal), System.Globalization.CultureInfo.InvariantCulture);
|
||||
return ApiKeyRecordReader.Read(reader);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
using Grpc.Core;
|
||||
using Grpc.Core.Interceptors;
|
||||
using Microsoft.Extensions.Options;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
|
||||
namespace MxGateway.Server.Security.Authorization;
|
||||
|
||||
public sealed class GatewayGrpcAuthorizationInterceptor(
|
||||
IApiKeyVerifier apiKeyVerifier,
|
||||
GatewayGrpcScopeResolver scopeResolver,
|
||||
IGatewayRequestIdentityAccessor identityAccessor,
|
||||
IOptions<GatewayOptions> options) : Interceptor
|
||||
{
|
||||
public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
|
||||
TRequest request,
|
||||
ServerCallContext context,
|
||||
UnaryServerMethod<TRequest, TResponse> continuation)
|
||||
{
|
||||
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
|
||||
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
|
||||
using (identityScope)
|
||||
{
|
||||
return await continuation(request, context).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
public override async Task ServerStreamingServerHandler<TRequest, TResponse>(
|
||||
TRequest request,
|
||||
IServerStreamWriter<TResponse> responseStream,
|
||||
ServerCallContext context,
|
||||
ServerStreamingServerMethod<TRequest, TResponse> continuation)
|
||||
{
|
||||
ApiKeyIdentity? identity = await AuthenticateAndAuthorizeAsync(request, context).ConfigureAwait(false);
|
||||
IDisposable? identityScope = identity is null ? null : identityAccessor.Push(identity);
|
||||
using (identityScope)
|
||||
{
|
||||
await continuation(request, responseStream, context).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<ApiKeyIdentity?> AuthenticateAndAuthorizeAsync<TRequest>(
|
||||
TRequest request,
|
||||
ServerCallContext context)
|
||||
where TRequest : class
|
||||
{
|
||||
if (options.Value.Authentication.Mode == AuthenticationMode.Disabled)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
string? authorizationHeader = context.RequestHeaders.GetValue("authorization");
|
||||
ApiKeyVerificationResult verificationResult = await apiKeyVerifier
|
||||
.VerifyAsync(authorizationHeader, context.CancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
if (!verificationResult.Succeeded || verificationResult.Identity is null)
|
||||
{
|
||||
throw new RpcException(new Status(
|
||||
StatusCode.Unauthenticated,
|
||||
"Missing or invalid API key."));
|
||||
}
|
||||
|
||||
string requiredScope = scopeResolver.ResolveRequiredScope(request);
|
||||
if (!verificationResult.Identity.Scopes.Contains(requiredScope))
|
||||
{
|
||||
throw new RpcException(new Status(
|
||||
StatusCode.PermissionDenied,
|
||||
$"API key is missing required scope '{requiredScope}'."));
|
||||
}
|
||||
|
||||
return verificationResult.Identity;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Server.Security.Authorization;
|
||||
|
||||
public sealed class GatewayGrpcScopeResolver
|
||||
{
|
||||
public string ResolveRequiredScope(object request)
|
||||
{
|
||||
return request switch
|
||||
{
|
||||
OpenSessionRequest => GatewayScopes.SessionOpen,
|
||||
CloseSessionRequest => GatewayScopes.SessionClose,
|
||||
StreamEventsRequest => GatewayScopes.EventsRead,
|
||||
MxCommandRequest commandRequest => ResolveCommandScope(commandRequest.Command?.Kind ?? MxCommandKind.Unspecified),
|
||||
_ => GatewayScopes.Admin
|
||||
};
|
||||
}
|
||||
|
||||
private static string ResolveCommandScope(MxCommandKind kind)
|
||||
{
|
||||
return kind switch
|
||||
{
|
||||
MxCommandKind.Write or
|
||||
MxCommandKind.Write2 => GatewayScopes.InvokeWrite,
|
||||
|
||||
MxCommandKind.WriteSecured or
|
||||
MxCommandKind.WriteSecured2 or
|
||||
MxCommandKind.AuthenticateUser => GatewayScopes.InvokeSecure,
|
||||
|
||||
MxCommandKind.ArchestraUserToId or
|
||||
MxCommandKind.GetSessionState or
|
||||
MxCommandKind.GetWorkerInfo => GatewayScopes.MetadataRead,
|
||||
|
||||
MxCommandKind.DrainEvents => GatewayScopes.EventsRead,
|
||||
MxCommandKind.ShutdownWorker => GatewayScopes.Admin,
|
||||
|
||||
_ => GatewayScopes.InvokeRead
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
|
||||
namespace MxGateway.Server.Security.Authorization;
|
||||
|
||||
public sealed class GatewayRequestIdentityAccessor : IGatewayRequestIdentityAccessor
|
||||
{
|
||||
private readonly AsyncLocal<ApiKeyIdentity?> currentIdentity = new();
|
||||
|
||||
public ApiKeyIdentity? Current => currentIdentity.Value;
|
||||
|
||||
public IDisposable Push(ApiKeyIdentity identity)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(identity);
|
||||
|
||||
ApiKeyIdentity? previousIdentity = currentIdentity.Value;
|
||||
currentIdentity.Value = identity;
|
||||
|
||||
return new IdentityScope(this, previousIdentity);
|
||||
}
|
||||
|
||||
private sealed class IdentityScope(
|
||||
GatewayRequestIdentityAccessor accessor,
|
||||
ApiKeyIdentity? previousIdentity) : IDisposable
|
||||
{
|
||||
private bool disposed;
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (disposed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
accessor.currentIdentity.Value = previousIdentity;
|
||||
disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
namespace MxGateway.Server.Security.Authorization;
|
||||
|
||||
public static class GatewayScopes
|
||||
{
|
||||
public const string SessionOpen = "session:open";
|
||||
public const string SessionClose = "session:close";
|
||||
public const string InvokeRead = "invoke:read";
|
||||
public const string InvokeWrite = "invoke:write";
|
||||
public const string InvokeSecure = "invoke:secure";
|
||||
public const string EventsRead = "events:read";
|
||||
public const string MetadataRead = "metadata:read";
|
||||
public const string Admin = "admin";
|
||||
}
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
using Grpc.Core.Interceptors;
|
||||
|
||||
namespace MxGateway.Server.Security.Authorization;
|
||||
|
||||
public static class GrpcAuthorizationServiceCollectionExtensions
|
||||
{
|
||||
public static IServiceCollection AddGatewayGrpcAuthorization(this IServiceCollection services)
|
||||
{
|
||||
services.AddSingleton<GatewayGrpcScopeResolver>();
|
||||
services.AddSingleton<IGatewayRequestIdentityAccessor, GatewayRequestIdentityAccessor>();
|
||||
services.AddSingleton<GatewayGrpcAuthorizationInterceptor>();
|
||||
services.AddGrpc(options => options.Interceptors.Add<GatewayGrpcAuthorizationInterceptor>());
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
|
||||
namespace MxGateway.Server.Security.Authorization;
|
||||
|
||||
public interface IGatewayRequestIdentityAccessor
|
||||
{
|
||||
ApiKeyIdentity? Current { get; }
|
||||
|
||||
IDisposable Push(ApiKeyIdentity identity);
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Workers;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed class GatewaySession
|
||||
{
|
||||
private readonly object _syncRoot = new();
|
||||
private readonly SemaphoreSlim _closeLock = new(1, 1);
|
||||
private IWorkerClient? _workerClient;
|
||||
private SessionState _state = SessionState.Creating;
|
||||
private string? _finalFault;
|
||||
private DateTimeOffset _lastClientActivityAt;
|
||||
private DateTimeOffset? _leaseExpiresAt;
|
||||
private bool _closeStarted;
|
||||
|
||||
public GatewaySession(
|
||||
string sessionId,
|
||||
string backendName,
|
||||
string pipeName,
|
||||
string nonce,
|
||||
string? clientIdentity,
|
||||
string? clientSessionName,
|
||||
string? clientCorrelationId,
|
||||
TimeSpan commandTimeout,
|
||||
TimeSpan startupTimeout,
|
||||
TimeSpan shutdownTimeout,
|
||||
DateTimeOffset openedAt)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(sessionId))
|
||||
{
|
||||
throw new ArgumentException("Session id is required.", nameof(sessionId));
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(backendName))
|
||||
{
|
||||
throw new ArgumentException("Backend name is required.", nameof(backendName));
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(pipeName))
|
||||
{
|
||||
throw new ArgumentException("Pipe name is required.", nameof(pipeName));
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(nonce))
|
||||
{
|
||||
throw new ArgumentException("Nonce is required.", nameof(nonce));
|
||||
}
|
||||
|
||||
SessionId = sessionId;
|
||||
BackendName = backendName;
|
||||
PipeName = pipeName;
|
||||
Nonce = nonce;
|
||||
ClientIdentity = clientIdentity;
|
||||
ClientSessionName = clientSessionName;
|
||||
ClientCorrelationId = clientCorrelationId;
|
||||
CommandTimeout = commandTimeout;
|
||||
StartupTimeout = startupTimeout;
|
||||
ShutdownTimeout = shutdownTimeout;
|
||||
OpenedAt = openedAt;
|
||||
_lastClientActivityAt = openedAt;
|
||||
}
|
||||
|
||||
public string SessionId { get; }
|
||||
|
||||
public string BackendName { get; }
|
||||
|
||||
public string PipeName { get; }
|
||||
|
||||
public string Nonce { get; }
|
||||
|
||||
public string? ClientIdentity { get; }
|
||||
|
||||
public string? ClientSessionName { get; }
|
||||
|
||||
public string? ClientCorrelationId { get; }
|
||||
|
||||
public TimeSpan CommandTimeout { get; }
|
||||
|
||||
public TimeSpan StartupTimeout { get; }
|
||||
|
||||
public TimeSpan ShutdownTimeout { get; }
|
||||
|
||||
public DateTimeOffset OpenedAt { get; }
|
||||
|
||||
public int? WorkerProcessId => _workerClient?.ProcessId;
|
||||
|
||||
public IWorkerClient? WorkerClient => _workerClient;
|
||||
|
||||
public SessionState State
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _state;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public DateTimeOffset LastClientActivityAt
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _lastClientActivityAt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public DateTimeOffset? LeaseExpiresAt
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _leaseExpiresAt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public string? FinalFault
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _finalFault;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void AttachWorkerClient(IWorkerClient workerClient)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(workerClient);
|
||||
|
||||
lock (_syncRoot)
|
||||
{
|
||||
_workerClient = workerClient;
|
||||
}
|
||||
}
|
||||
|
||||
public void TransitionTo(SessionState nextState)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state is SessionState.Closed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (_state is SessionState.Faulted && nextState is not SessionState.Closed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_state = nextState;
|
||||
}
|
||||
}
|
||||
|
||||
public void MarkReady()
|
||||
{
|
||||
TransitionTo(SessionState.Ready);
|
||||
}
|
||||
|
||||
public void MarkFaulted(string reason)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state is SessionState.Closed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_finalFault = reason;
|
||||
_state = SessionState.Faulted;
|
||||
}
|
||||
}
|
||||
|
||||
public void TouchClientActivity(DateTimeOffset activityAt)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
_lastClientActivityAt = activityAt;
|
||||
}
|
||||
}
|
||||
|
||||
public void ExtendLease(DateTimeOffset leaseExpiresAt)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
_leaseExpiresAt = leaseExpiresAt;
|
||||
}
|
||||
}
|
||||
|
||||
public bool IsLeaseExpired(DateTimeOffset now)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _leaseExpiresAt is not null && _leaseExpiresAt <= now;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<WorkerCommandReply> InvokeAsync(
|
||||
WorkerCommand command,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
IWorkerClient workerClient = GetReadyWorkerClient();
|
||||
TouchClientActivity(DateTimeOffset.UtcNow);
|
||||
|
||||
return await workerClient.InvokeAsync(command, CommandTimeout, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<WorkerEvent> ReadEventsAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
IWorkerClient workerClient = GetReadyWorkerClient();
|
||||
TouchClientActivity(DateTimeOffset.UtcNow);
|
||||
|
||||
return workerClient.ReadEventsAsync(cancellationToken);
|
||||
}
|
||||
|
||||
public async Task<SessionCloseResult> CloseAsync(
|
||||
string reason,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
await _closeLock.WaitAsync(cancellationToken).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
if (_state is SessionState.Closed)
|
||||
{
|
||||
return new SessionCloseResult(SessionId, SessionState.Closed, AlreadyClosed: true);
|
||||
}
|
||||
|
||||
bool alreadyClosing = _closeStarted;
|
||||
_closeStarted = true;
|
||||
_state = SessionState.Closing;
|
||||
|
||||
if (_workerClient is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _workerClient.ShutdownAsync(ShutdownTimeout, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch
|
||||
{
|
||||
_workerClient.Kill(reason);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
_state = SessionState.Closed;
|
||||
return new SessionCloseResult(SessionId, SessionState.Closed, alreadyClosing);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_closeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
public void KillWorker(string reason)
|
||||
{
|
||||
_workerClient?.Kill(reason);
|
||||
TransitionTo(SessionState.Closed);
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
_closeLock.Dispose();
|
||||
if (_workerClient is not null)
|
||||
{
|
||||
await _workerClient.DisposeAsync().ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
private IWorkerClient GetReadyWorkerClient()
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state != SessionState.Ready || _workerClient?.State != WorkerClientState.Ready)
|
||||
{
|
||||
throw new SessionManagerException(
|
||||
SessionManagerErrorCode.SessionNotReady,
|
||||
$"Session {SessionId} is not ready. Current state is {_state}.");
|
||||
}
|
||||
|
||||
return _workerClient;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public interface ISessionManager
|
||||
{
|
||||
Task<GatewaySession> OpenSessionAsync(
|
||||
SessionOpenRequest request,
|
||||
string? clientIdentity,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
bool TryGetSession(
|
||||
string sessionId,
|
||||
out GatewaySession session);
|
||||
|
||||
Task<WorkerCommandReply> InvokeAsync(
|
||||
string sessionId,
|
||||
WorkerCommand command,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
|
||||
string sessionId,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
Task<SessionCloseResult> CloseSessionAsync(
|
||||
string sessionId,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
Task<int> CloseExpiredLeasesAsync(
|
||||
DateTimeOffset now,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
Task ShutdownAsync(CancellationToken cancellationToken);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public interface ISessionRegistry
|
||||
{
|
||||
int Count { get; }
|
||||
|
||||
int ActiveCount { get; }
|
||||
|
||||
bool TryAdd(GatewaySession session);
|
||||
|
||||
bool TryGet(string sessionId, out GatewaySession session);
|
||||
|
||||
bool TryRemove(string sessionId, out GatewaySession session);
|
||||
|
||||
IReadOnlyCollection<GatewaySession> Snapshot();
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public interface ISessionWorkerClientFactory
|
||||
{
|
||||
Task<MxGateway.Server.Workers.IWorkerClient> CreateAsync(
|
||||
GatewaySession session,
|
||||
CancellationToken cancellationToken);
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed record SessionCloseResult(
|
||||
string SessionId,
|
||||
SessionState FinalState,
|
||||
bool AlreadyClosed);
|
||||
@@ -0,0 +1,287 @@
|
||||
using System.Security.Cryptography;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Metrics;
|
||||
using MxGateway.Server.Workers;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed class SessionManager : ISessionManager
|
||||
{
|
||||
public const string DefaultCloseReason = "client-close";
|
||||
public const string GatewayShutdownReason = "gateway-shutdown";
|
||||
public const string LeaseExpiredReason = "lease-expired";
|
||||
|
||||
private readonly ISessionRegistry _registry;
|
||||
private readonly ISessionWorkerClientFactory _workerClientFactory;
|
||||
private readonly GatewayMetrics _metrics;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
private readonly ILogger<SessionManager> _logger;
|
||||
private readonly GatewayOptions _options;
|
||||
|
||||
public SessionManager(
|
||||
ISessionRegistry registry,
|
||||
ISessionWorkerClientFactory workerClientFactory,
|
||||
IOptions<GatewayOptions> options,
|
||||
GatewayMetrics metrics,
|
||||
TimeProvider? timeProvider = null,
|
||||
ILogger<SessionManager>? logger = null)
|
||||
{
|
||||
_registry = registry ?? throw new ArgumentNullException(nameof(registry));
|
||||
_workerClientFactory = workerClientFactory ?? throw new ArgumentNullException(nameof(workerClientFactory));
|
||||
ArgumentNullException.ThrowIfNull(options);
|
||||
_metrics = metrics ?? throw new ArgumentNullException(nameof(metrics));
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
_logger = logger ?? NullLogger<SessionManager>.Instance;
|
||||
_options = options.Value;
|
||||
}
|
||||
|
||||
public async Task<GatewaySession> OpenSessionAsync(
|
||||
SessionOpenRequest request,
|
||||
string? clientIdentity,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(request);
|
||||
EnsureSessionCapacity();
|
||||
|
||||
GatewaySession session = CreateSession(request, clientIdentity);
|
||||
if (!_registry.TryAdd(session))
|
||||
{
|
||||
throw new SessionManagerException(
|
||||
SessionManagerErrorCode.OpenFailed,
|
||||
$"Session id collision while opening session {session.SessionId}.");
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
session.TransitionTo(SessionState.StartingWorker);
|
||||
IWorkerClient workerClient = await _workerClientFactory
|
||||
.CreateAsync(session, cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
session.AttachWorkerClient(workerClient);
|
||||
session.MarkReady();
|
||||
_metrics.SessionOpened();
|
||||
|
||||
return session;
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
session.MarkFaulted(exception.Message);
|
||||
_registry.TryRemove(session.SessionId, out _);
|
||||
await session.DisposeAsync().ConfigureAwait(false);
|
||||
_metrics.Fault(SessionManagerErrorCode.OpenFailed.ToString());
|
||||
_logger.LogWarning(
|
||||
exception,
|
||||
"Failed to open gateway session {SessionId}.",
|
||||
session.SessionId);
|
||||
|
||||
throw new SessionManagerException(
|
||||
SessionManagerErrorCode.OpenFailed,
|
||||
$"Failed to open session {session.SessionId}.",
|
||||
exception);
|
||||
}
|
||||
}
|
||||
|
||||
public bool TryGetSession(
|
||||
string sessionId,
|
||||
out GatewaySession session)
|
||||
{
|
||||
return _registry.TryGet(sessionId, out session);
|
||||
}
|
||||
|
||||
public async Task<WorkerCommandReply> InvokeAsync(
|
||||
string sessionId,
|
||||
WorkerCommand command,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
GatewaySession session = GetRequiredSession(sessionId);
|
||||
|
||||
try
|
||||
{
|
||||
return await session.InvokeAsync(command, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (SessionManagerException)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
if (session.WorkerClient?.State == WorkerClientState.Faulted)
|
||||
{
|
||||
session.MarkFaulted(exception.Message);
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
|
||||
string sessionId,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
GatewaySession session = GetRequiredSession(sessionId);
|
||||
|
||||
return session.ReadEventsAsync(cancellationToken);
|
||||
}
|
||||
|
||||
public async Task<SessionCloseResult> CloseSessionAsync(
|
||||
string sessionId,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
GatewaySession session = GetRequiredSession(sessionId);
|
||||
SessionCloseResult result = await CloseSessionCoreAsync(
|
||||
session,
|
||||
DefaultCloseReason,
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public async Task<int> CloseExpiredLeasesAsync(
|
||||
DateTimeOffset now,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
int closedCount = 0;
|
||||
foreach (GatewaySession session in _registry.Snapshot())
|
||||
{
|
||||
if (!session.IsLeaseExpired(now))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
await CloseSessionCoreAsync(session, LeaseExpiredReason, cancellationToken).ConfigureAwait(false);
|
||||
closedCount++;
|
||||
}
|
||||
|
||||
return closedCount;
|
||||
}
|
||||
|
||||
public async Task ShutdownAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
foreach (GatewaySession session in _registry.Snapshot())
|
||||
{
|
||||
try
|
||||
{
|
||||
await CloseSessionCoreAsync(session, GatewayShutdownReason, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
_logger.LogWarning(
|
||||
exception,
|
||||
"Graceful shutdown failed for session {SessionId}; killing worker.",
|
||||
session.SessionId);
|
||||
session.KillWorker(GatewayShutdownReason);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<SessionCloseResult> CloseSessionCoreAsync(
|
||||
GatewaySession session,
|
||||
string reason,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
bool wasClosed = session.State == SessionState.Closed;
|
||||
try
|
||||
{
|
||||
SessionCloseResult result = await session.CloseAsync(reason, cancellationToken).ConfigureAwait(false);
|
||||
if (!wasClosed && !result.AlreadyClosed)
|
||||
{
|
||||
_metrics.SessionClosed();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
session.MarkFaulted(exception.Message);
|
||||
_metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString());
|
||||
throw new SessionManagerException(
|
||||
SessionManagerErrorCode.CloseFailed,
|
||||
$"Failed to close session {session.SessionId}.",
|
||||
exception);
|
||||
}
|
||||
}
|
||||
|
||||
private GatewaySession GetRequiredSession(string sessionId)
|
||||
{
|
||||
if (!_registry.TryGet(sessionId, out GatewaySession session))
|
||||
{
|
||||
throw new SessionManagerException(
|
||||
SessionManagerErrorCode.SessionNotFound,
|
||||
$"Session {sessionId} was not found.");
|
||||
}
|
||||
|
||||
return session;
|
||||
}
|
||||
|
||||
private void EnsureSessionCapacity()
|
||||
{
|
||||
if (_registry.ActiveCount >= _options.Sessions.MaxSessions)
|
||||
{
|
||||
throw new SessionManagerException(
|
||||
SessionManagerErrorCode.SessionLimitExceeded,
|
||||
$"Gateway session limit {_options.Sessions.MaxSessions} has been reached.");
|
||||
}
|
||||
}
|
||||
|
||||
private GatewaySession CreateSession(
|
||||
SessionOpenRequest request,
|
||||
string? clientIdentity)
|
||||
{
|
||||
string sessionId = CreateSessionId();
|
||||
string backendName = string.IsNullOrWhiteSpace(request.RequestedBackend)
|
||||
? GatewayContractInfo.DefaultBackendName
|
||||
: request.RequestedBackend!;
|
||||
TimeSpan commandTimeout = ResolveCommandTimeout(request.CommandTimeout);
|
||||
TimeSpan startupTimeout = TimeSpan.FromSeconds(_options.Worker.StartupTimeoutSeconds);
|
||||
TimeSpan shutdownTimeout = TimeSpan.FromSeconds(_options.Worker.ShutdownTimeoutSeconds);
|
||||
string pipeName = $"mxaccess-gateway-{Environment.ProcessId}-{sessionId}";
|
||||
string nonce = CreateNonce();
|
||||
DateTimeOffset openedAt = _timeProvider.GetUtcNow();
|
||||
|
||||
return new GatewaySession(
|
||||
sessionId,
|
||||
backendName,
|
||||
pipeName,
|
||||
nonce,
|
||||
clientIdentity,
|
||||
request.ClientSessionName,
|
||||
request.ClientCorrelationId,
|
||||
commandTimeout,
|
||||
startupTimeout,
|
||||
shutdownTimeout,
|
||||
openedAt);
|
||||
}
|
||||
|
||||
private TimeSpan ResolveCommandTimeout(Duration? requestedTimeout)
|
||||
{
|
||||
if (requestedTimeout is null)
|
||||
{
|
||||
return TimeSpan.FromSeconds(_options.Sessions.DefaultCommandTimeoutSeconds);
|
||||
}
|
||||
|
||||
TimeSpan timeout = requestedTimeout.ToTimeSpan();
|
||||
return timeout <= TimeSpan.Zero
|
||||
? TimeSpan.FromSeconds(_options.Sessions.DefaultCommandTimeoutSeconds)
|
||||
: timeout;
|
||||
}
|
||||
|
||||
private static string CreateSessionId()
|
||||
{
|
||||
return $"session-{Guid.NewGuid():N}";
|
||||
}
|
||||
|
||||
private static string CreateNonce()
|
||||
{
|
||||
Span<byte> bytes = stackalloc byte[32];
|
||||
RandomNumberGenerator.Fill(bytes);
|
||||
|
||||
return Convert.ToBase64String(bytes);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public enum SessionManagerErrorCode
|
||||
{
|
||||
SessionNotFound,
|
||||
SessionNotReady,
|
||||
SessionLimitExceeded,
|
||||
OpenFailed,
|
||||
CloseFailed,
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed class SessionManagerException : Exception
|
||||
{
|
||||
public SessionManagerException(
|
||||
SessionManagerErrorCode errorCode,
|
||||
string message)
|
||||
: base(message)
|
||||
{
|
||||
ErrorCode = errorCode;
|
||||
}
|
||||
|
||||
public SessionManagerException(
|
||||
SessionManagerErrorCode errorCode,
|
||||
string message,
|
||||
Exception innerException)
|
||||
: base(message, innerException)
|
||||
{
|
||||
ErrorCode = errorCode;
|
||||
}
|
||||
|
||||
public SessionManagerErrorCode ErrorCode { get; }
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed record SessionOpenRequest(
|
||||
string? RequestedBackend,
|
||||
string? ClientSessionName,
|
||||
string? ClientCorrelationId,
|
||||
Duration? CommandTimeout)
|
||||
{
|
||||
public static SessionOpenRequest FromContract(OpenSessionRequest request)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(request);
|
||||
|
||||
return new SessionOpenRequest(
|
||||
request.RequestedBackend,
|
||||
request.ClientSessionName,
|
||||
request.ClientCorrelationId,
|
||||
request.CommandTimeout);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
using System.Collections.Concurrent;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed class SessionRegistry : ISessionRegistry
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, GatewaySession> _sessions = new(StringComparer.Ordinal);
|
||||
|
||||
public int Count => _sessions.Count;
|
||||
|
||||
public int ActiveCount => _sessions.Values.Count(session => session.State is not SessionState.Closed);
|
||||
|
||||
public bool TryAdd(GatewaySession session)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(session);
|
||||
|
||||
return _sessions.TryAdd(session.SessionId, session);
|
||||
}
|
||||
|
||||
public bool TryGet(
|
||||
string sessionId,
|
||||
out GatewaySession session)
|
||||
{
|
||||
return _sessions.TryGetValue(sessionId, out session!);
|
||||
}
|
||||
|
||||
public bool TryRemove(
|
||||
string sessionId,
|
||||
out GatewaySession session)
|
||||
{
|
||||
return _sessions.TryRemove(sessionId, out session!);
|
||||
}
|
||||
|
||||
public IReadOnlyCollection<GatewaySession> Snapshot()
|
||||
{
|
||||
return _sessions.Values.ToArray();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public static class SessionServiceCollectionExtensions
|
||||
{
|
||||
public static IServiceCollection AddGatewaySessions(this IServiceCollection services)
|
||||
{
|
||||
services.AddSingleton<ISessionRegistry, SessionRegistry>();
|
||||
services.AddSingleton<ISessionWorkerClientFactory, SessionWorkerClientFactory>();
|
||||
services.AddSingleton<ISessionManager, SessionManager>();
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
using System.IO.Pipes;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Metrics;
|
||||
using MxGateway.Server.Workers;
|
||||
|
||||
namespace MxGateway.Server.Sessions;
|
||||
|
||||
public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory
|
||||
{
|
||||
private readonly IWorkerProcessLauncher _workerProcessLauncher;
|
||||
private readonly GatewayMetrics _metrics;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly GatewayOptions _options;
|
||||
|
||||
public SessionWorkerClientFactory(
|
||||
IWorkerProcessLauncher workerProcessLauncher,
|
||||
IOptions<GatewayOptions> options,
|
||||
GatewayMetrics metrics,
|
||||
ILoggerFactory loggerFactory,
|
||||
TimeProvider? timeProvider = null)
|
||||
{
|
||||
_workerProcessLauncher = workerProcessLauncher ?? throw new ArgumentNullException(nameof(workerProcessLauncher));
|
||||
ArgumentNullException.ThrowIfNull(options);
|
||||
_metrics = metrics ?? throw new ArgumentNullException(nameof(metrics));
|
||||
_loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
_options = options.Value;
|
||||
}
|
||||
|
||||
public async Task<IWorkerClient> CreateAsync(
|
||||
GatewaySession session,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(session);
|
||||
|
||||
NamedPipeServerStream? pipe = CreatePipe(session.PipeName);
|
||||
WorkerProcessHandle? processHandle = null;
|
||||
IWorkerClient? workerClient = null;
|
||||
try
|
||||
{
|
||||
session.TransitionTo(SessionState.StartingWorker);
|
||||
processHandle = await _workerProcessLauncher
|
||||
.LaunchAsync(
|
||||
new WorkerProcessLaunchRequest(
|
||||
session.SessionId,
|
||||
session.PipeName,
|
||||
GatewayContractInfo.WorkerProtocolVersion,
|
||||
session.Nonce,
|
||||
pipe),
|
||||
cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
session.TransitionTo(SessionState.WaitingForPipe);
|
||||
await WaitForPipeConnectionAsync(pipe, session.StartupTimeout, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
session.TransitionTo(SessionState.Handshaking);
|
||||
WorkerFrameProtocolOptions frameOptions = new(
|
||||
session.SessionId,
|
||||
GatewayContractInfo.WorkerProtocolVersion,
|
||||
_options.Worker.MaxMessageBytes);
|
||||
WorkerClientConnection connection = new(
|
||||
session.SessionId,
|
||||
session.Nonce,
|
||||
pipe,
|
||||
frameOptions,
|
||||
processHandle);
|
||||
WorkerClientOptions clientOptions = new()
|
||||
{
|
||||
HeartbeatGrace = TimeSpan.FromSeconds(_options.Worker.HeartbeatGraceSeconds),
|
||||
HeartbeatCheckInterval = TimeSpan.FromSeconds(_options.Worker.HeartbeatIntervalSeconds),
|
||||
EventChannelCapacity = _options.Events.QueueCapacity,
|
||||
};
|
||||
|
||||
workerClient = new WorkerClient(
|
||||
connection,
|
||||
clientOptions,
|
||||
_metrics,
|
||||
_timeProvider,
|
||||
_loggerFactory.CreateLogger<WorkerClient>());
|
||||
|
||||
pipe = null;
|
||||
processHandle = null;
|
||||
|
||||
session.TransitionTo(SessionState.InitializingWorker);
|
||||
await workerClient.StartAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return workerClient;
|
||||
}
|
||||
catch
|
||||
{
|
||||
if (workerClient is not null)
|
||||
{
|
||||
await workerClient.DisposeAsync().ConfigureAwait(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (processHandle is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (!processHandle.Process.HasExited)
|
||||
{
|
||||
processHandle.Process.Kill(entireProcessTree: true);
|
||||
_metrics.WorkerKilled("OpenSessionFailed");
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
processHandle.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
pipe?.Dispose();
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private static NamedPipeServerStream CreatePipe(string pipeName)
|
||||
{
|
||||
return new NamedPipeServerStream(
|
||||
pipeName,
|
||||
PipeDirection.InOut,
|
||||
maxNumberOfServerInstances: 1,
|
||||
PipeTransmissionMode.Byte,
|
||||
PipeOptions.Asynchronous);
|
||||
}
|
||||
|
||||
private static async Task WaitForPipeConnectionAsync(
|
||||
NamedPipeServerStream pipe,
|
||||
TimeSpan startupTimeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
using CancellationTokenSource timeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeout.CancelAfter(startupTimeout);
|
||||
await pipe.WaitForConnectionAsync(timeout.Token).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public interface IWorkerClient : IAsyncDisposable
|
||||
{
|
||||
string SessionId { get; }
|
||||
|
||||
int? ProcessId { get; }
|
||||
|
||||
WorkerClientState State { get; }
|
||||
|
||||
DateTimeOffset LastHeartbeatAt { get; }
|
||||
|
||||
Task StartAsync(CancellationToken cancellationToken);
|
||||
|
||||
Task<WorkerCommandReply> InvokeAsync(
|
||||
WorkerCommand command,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
IAsyncEnumerable<WorkerEvent> ReadEventsAsync(CancellationToken cancellationToken);
|
||||
|
||||
Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken);
|
||||
|
||||
void Kill(string reason);
|
||||
}
|
||||
@@ -0,0 +1,755 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Channels;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Metrics;
|
||||
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public sealed class WorkerClient : IWorkerClient
|
||||
{
|
||||
private const string GatewayVersionFallback = "unknown";
|
||||
private readonly object _syncRoot = new();
|
||||
private readonly WorkerClientConnection _connection;
|
||||
private readonly WorkerClientOptions _options;
|
||||
private readonly GatewayMetrics? _metrics;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
private readonly ILogger<WorkerClient> _logger;
|
||||
private readonly WorkerFrameReader _reader;
|
||||
private readonly WorkerFrameWriter _writer;
|
||||
private readonly Channel<WorkerEnvelope> _outboundEnvelopes;
|
||||
private readonly Channel<WorkerEvent> _events;
|
||||
private readonly ConcurrentDictionary<string, PendingCommand> _pendingCommands = new(StringComparer.Ordinal);
|
||||
private readonly CancellationTokenSource _stopCts = new();
|
||||
private long _nextSequence;
|
||||
private WorkerClientState _state;
|
||||
private DateTimeOffset _lastHeartbeatAt;
|
||||
private int? _processId;
|
||||
private Task? _readLoopTask;
|
||||
private Task? _writeLoopTask;
|
||||
private Task? _heartbeatLoopTask;
|
||||
private bool _disposed;
|
||||
|
||||
public WorkerClient(
|
||||
WorkerClientConnection connection,
|
||||
WorkerClientOptions? options = null,
|
||||
GatewayMetrics? metrics = null,
|
||||
TimeProvider? timeProvider = null,
|
||||
ILogger<WorkerClient>? logger = null)
|
||||
{
|
||||
_connection = connection ?? throw new ArgumentNullException(nameof(connection));
|
||||
_options = options ?? new WorkerClientOptions();
|
||||
_metrics = metrics;
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
_logger = logger ?? NullLogger<WorkerClient>.Instance;
|
||||
_reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions);
|
||||
_writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions);
|
||||
_outboundEnvelopes = Channel.CreateUnbounded<WorkerEnvelope>(
|
||||
new UnboundedChannelOptions
|
||||
{
|
||||
SingleReader = true,
|
||||
SingleWriter = false,
|
||||
AllowSynchronousContinuations = false,
|
||||
});
|
||||
_events = Channel.CreateBounded<WorkerEvent>(
|
||||
new BoundedChannelOptions(_options.EventChannelCapacity)
|
||||
{
|
||||
SingleReader = false,
|
||||
SingleWriter = true,
|
||||
FullMode = BoundedChannelFullMode.Wait,
|
||||
AllowSynchronousContinuations = false,
|
||||
});
|
||||
_lastHeartbeatAt = _timeProvider.GetUtcNow();
|
||||
}
|
||||
|
||||
public string SessionId => _connection.SessionId;
|
||||
|
||||
public int? ProcessId
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _processId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public WorkerClientState State
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _state;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public DateTimeOffset LastHeartbeatAt
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
return _lastHeartbeatAt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
ThrowIfDisposed();
|
||||
TransitionFromCreatedToHandshaking();
|
||||
|
||||
_writeLoopTask = Task.Run(WriteLoopAsync);
|
||||
await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false);
|
||||
|
||||
WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync(
|
||||
WorkerEnvelope.BodyOneofCase.WorkerHello,
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
ValidateWorkerHello(helloEnvelope.WorkerHello);
|
||||
|
||||
WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync(
|
||||
WorkerEnvelope.BodyOneofCase.WorkerReady,
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
MarkReady(readyEnvelope.WorkerReady);
|
||||
|
||||
_readLoopTask = Task.Run(ReadLoopAsync);
|
||||
_heartbeatLoopTask = Task.Run(HeartbeatLoopAsync);
|
||||
}
|
||||
|
||||
public async Task<WorkerCommandReply> InvokeAsync(
|
||||
WorkerCommand command,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(command);
|
||||
ThrowIfDisposed();
|
||||
EnsureReady();
|
||||
|
||||
if (timeout <= TimeSpan.Zero)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero.");
|
||||
}
|
||||
|
||||
string correlationId = Guid.NewGuid().ToString("N");
|
||||
string method = GetCommandMethod(command);
|
||||
PendingCommand pendingCommand = new(
|
||||
correlationId,
|
||||
method,
|
||||
_timeProvider.GetTimestamp());
|
||||
|
||||
if (!_pendingCommands.TryAdd(correlationId, pendingCommand))
|
||||
{
|
||||
throw new InvalidOperationException("Generated a duplicate command correlation id.");
|
||||
}
|
||||
|
||||
_metrics?.CommandStarted(method);
|
||||
|
||||
try
|
||||
{
|
||||
await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false);
|
||||
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
Task timeoutTask = Task.Delay(timeout, timeoutCts.Token);
|
||||
Task<WorkerCommandReply> replyTask = pendingCommand.Task;
|
||||
Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false);
|
||||
|
||||
if (completedTask == replyTask)
|
||||
{
|
||||
await timeoutCts.CancelAsync().ConfigureAwait(false);
|
||||
return await replyTask.ConfigureAwait(false);
|
||||
}
|
||||
|
||||
if (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
RemovePendingCommandAsFailed(
|
||||
correlationId,
|
||||
pendingCommand,
|
||||
WorkerClientErrorCode.GatewayShutdown,
|
||||
"Command wait was canceled.");
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
}
|
||||
|
||||
RemovePendingCommandAsFailed(
|
||||
correlationId,
|
||||
pendingCommand,
|
||||
WorkerClientErrorCode.CommandTimeout,
|
||||
$"Worker command {method} timed out after {timeout}.");
|
||||
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.CommandTimeout,
|
||||
$"Worker command {method} timed out after {timeout}.");
|
||||
}
|
||||
catch
|
||||
{
|
||||
_pendingCommands.TryRemove(correlationId, out _);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
{
|
||||
await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
yield return workerEvent;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken)
|
||||
{
|
||||
ThrowIfDisposed();
|
||||
if (timeout <= TimeSpan.Zero)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero.");
|
||||
}
|
||||
|
||||
WorkerClientState state = State;
|
||||
if (state is WorkerClientState.Closed or WorkerClientState.Faulted)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
MarkClosing();
|
||||
await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false);
|
||||
_outboundEnvelopes.Writer.TryComplete();
|
||||
|
||||
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(timeout);
|
||||
try
|
||||
{
|
||||
await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false);
|
||||
MarkClosed("shutdown");
|
||||
}
|
||||
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.ShutdownTimeout,
|
||||
"Worker shutdown timed out.",
|
||||
null);
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.ShutdownTimeout,
|
||||
$"Worker shutdown timed out after {timeout}.");
|
||||
}
|
||||
}
|
||||
|
||||
public void Kill(string reason)
|
||||
{
|
||||
ThrowIfDisposed();
|
||||
_connection.ProcessHandle?.Process.Kill(entireProcessTree: true);
|
||||
_metrics?.WorkerKilled(reason);
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.WorkerFaulted,
|
||||
$"Worker was killed by the gateway: {reason}.",
|
||||
null);
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
_stopCts.Cancel();
|
||||
_outboundEnvelopes.Writer.TryComplete();
|
||||
_events.Writer.TryComplete();
|
||||
CompletePendingCommands(
|
||||
new WorkerClientException(
|
||||
WorkerClientErrorCode.GatewayShutdown,
|
||||
"Worker client was disposed."));
|
||||
|
||||
await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
|
||||
await _connection.Stream.DisposeAsync().ConfigureAwait(false);
|
||||
_connection.ProcessHandle?.Dispose();
|
||||
_stopCts.Dispose();
|
||||
}
|
||||
|
||||
private async Task WriteLoopAsync()
|
||||
{
|
||||
try
|
||||
{
|
||||
await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false))
|
||||
{
|
||||
await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
|
||||
{
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.WriteFailed,
|
||||
"Worker pipe write failed.",
|
||||
exception);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ReadLoopAsync()
|
||||
{
|
||||
try
|
||||
{
|
||||
while (!_stopCts.IsCancellationRequested)
|
||||
{
|
||||
WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false);
|
||||
await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
|
||||
{
|
||||
}
|
||||
catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream)
|
||||
{
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.PipeDisconnected,
|
||||
"Worker pipe disconnected.",
|
||||
exception);
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.ProtocolViolation,
|
||||
"Worker read loop failed.",
|
||||
exception);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HeartbeatLoopAsync()
|
||||
{
|
||||
try
|
||||
{
|
||||
while (!_stopCts.IsCancellationRequested)
|
||||
{
|
||||
await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false);
|
||||
if (State != WorkerClientState.Ready)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
DateTimeOffset lastHeartbeatAt = LastHeartbeatAt;
|
||||
DateTimeOffset now = _timeProvider.GetUtcNow();
|
||||
if (now - lastHeartbeatAt <= _options.HeartbeatGrace)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
_metrics?.HeartbeatFailed(SessionId);
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.HeartbeatExpired,
|
||||
$"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.",
|
||||
null);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
private async Task DispatchEnvelopeAsync(
|
||||
WorkerEnvelope envelope,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
switch (envelope.BodyCase)
|
||||
{
|
||||
case WorkerEnvelope.BodyOneofCase.WorkerCommandReply:
|
||||
CompleteCommand(envelope);
|
||||
break;
|
||||
case WorkerEnvelope.BodyOneofCase.WorkerEvent:
|
||||
await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false);
|
||||
break;
|
||||
case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat:
|
||||
MarkHeartbeat(envelope.WorkerHeartbeat);
|
||||
break;
|
||||
case WorkerEnvelope.BodyOneofCase.WorkerFault:
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.WorkerFaulted,
|
||||
CreateWorkerFaultMessage(envelope.WorkerFault),
|
||||
null);
|
||||
break;
|
||||
case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck:
|
||||
MarkClosed("worker-shutdown-ack");
|
||||
break;
|
||||
default:
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.ProtocolViolation,
|
||||
$"Worker sent unexpected envelope body {envelope.BodyCase}.",
|
||||
null);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task EnqueueWorkerEventAsync(
|
||||
WorkerEvent workerEvent,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
if (workerEvent.Event is not null)
|
||||
{
|
||||
_metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
|
||||
}
|
||||
|
||||
if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (!_events.Writer.TryWrite(workerEvent))
|
||||
{
|
||||
_metrics?.QueueOverflow("worker-events");
|
||||
SetFaulted(
|
||||
WorkerClientErrorCode.ProtocolViolation,
|
||||
"Worker event channel rejected an event.",
|
||||
null);
|
||||
}
|
||||
}
|
||||
|
||||
private void CompleteCommand(WorkerEnvelope envelope)
|
||||
{
|
||||
string correlationId = envelope.CorrelationId;
|
||||
if (string.IsNullOrWhiteSpace(correlationId))
|
||||
{
|
||||
correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty;
|
||||
}
|
||||
|
||||
if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand))
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.",
|
||||
SessionId,
|
||||
correlationId);
|
||||
return;
|
||||
}
|
||||
|
||||
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
|
||||
_metrics?.CommandSucceeded(pendingCommand.Method, duration);
|
||||
pendingCommand.SetResult(envelope.WorkerCommandReply);
|
||||
}
|
||||
|
||||
private void RemovePendingCommandAsFailed(
|
||||
string correlationId,
|
||||
PendingCommand pendingCommand,
|
||||
WorkerClientErrorCode errorCode,
|
||||
string message)
|
||||
{
|
||||
if (!_pendingCommands.TryRemove(correlationId, out _))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
|
||||
_metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration);
|
||||
pendingCommand.SetException(new WorkerClientException(errorCode, message));
|
||||
}
|
||||
|
||||
private async Task<WorkerEnvelope> ReadHandshakeEnvelopeAsync(
|
||||
WorkerEnvelope.BodyOneofCase expectedBody,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
|
||||
if (envelope.BodyCase != expectedBody)
|
||||
{
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.ProtocolViolation,
|
||||
$"Worker handshake expected {expectedBody} but received {envelope.BodyCase}.");
|
||||
}
|
||||
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private void ValidateWorkerHello(WorkerHello workerHello)
|
||||
{
|
||||
if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion)
|
||||
{
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.ProtocolViolation,
|
||||
"Worker hello protocol version does not match the gateway protocol version.");
|
||||
}
|
||||
|
||||
if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal))
|
||||
{
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.ProtocolViolation,
|
||||
"Worker hello nonce does not match the gateway nonce.");
|
||||
}
|
||||
|
||||
lock (_syncRoot)
|
||||
{
|
||||
_processId = workerHello.WorkerProcessId == 0
|
||||
? _connection.ProcessHandle?.ProcessId
|
||||
: workerHello.WorkerProcessId;
|
||||
}
|
||||
}
|
||||
|
||||
private void MarkReady(WorkerReady ready)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
_processId = ready.WorkerProcessId == 0
|
||||
? _processId ?? _connection.ProcessHandle?.ProcessId
|
||||
: ready.WorkerProcessId;
|
||||
_lastHeartbeatAt = _timeProvider.GetUtcNow();
|
||||
_state = WorkerClientState.Ready;
|
||||
}
|
||||
|
||||
DateTimeOffset readyAt = _timeProvider.GetUtcNow();
|
||||
DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt;
|
||||
_metrics?.WorkerStarted(readyAt - launchedAt);
|
||||
}
|
||||
|
||||
private void MarkHeartbeat(WorkerHeartbeat heartbeat)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
_lastHeartbeatAt = _timeProvider.GetUtcNow();
|
||||
if (heartbeat.WorkerProcessId != 0)
|
||||
{
|
||||
_processId = heartbeat.WorkerProcessId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void MarkClosing()
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state is WorkerClientState.Closed or WorkerClientState.Faulted)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_state = WorkerClientState.Closing;
|
||||
}
|
||||
}
|
||||
|
||||
private void MarkClosed(string reason)
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state == WorkerClientState.Closed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_state = WorkerClientState.Closed;
|
||||
}
|
||||
|
||||
_stopCts.Cancel();
|
||||
_outboundEnvelopes.Writer.TryComplete();
|
||||
_events.Writer.TryComplete();
|
||||
CompletePendingCommands(
|
||||
new WorkerClientException(
|
||||
WorkerClientErrorCode.GatewayShutdown,
|
||||
$"Worker client closed because {reason}."));
|
||||
_metrics?.WorkerStopped(reason);
|
||||
}
|
||||
|
||||
private void SetFaulted(
|
||||
WorkerClientErrorCode errorCode,
|
||||
string message,
|
||||
Exception? exception)
|
||||
{
|
||||
WorkerClientException fault = exception is null
|
||||
? new WorkerClientException(errorCode, message)
|
||||
: new WorkerClientException(errorCode, message, exception);
|
||||
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state is WorkerClientState.Faulted or WorkerClientState.Closed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_state = WorkerClientState.Faulted;
|
||||
}
|
||||
|
||||
_stopCts.Cancel();
|
||||
_outboundEnvelopes.Writer.TryComplete(fault);
|
||||
_events.Writer.TryComplete(fault);
|
||||
CompletePendingCommands(fault);
|
||||
_metrics?.Fault(errorCode.ToString());
|
||||
_logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message);
|
||||
}
|
||||
|
||||
private void CompletePendingCommands(Exception exception)
|
||||
{
|
||||
foreach (KeyValuePair<string, PendingCommand> item in _pendingCommands.ToArray())
|
||||
{
|
||||
if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand))
|
||||
{
|
||||
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
|
||||
_metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration);
|
||||
pendingCommand.SetException(exception);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void TransitionFromCreatedToHandshaking()
|
||||
{
|
||||
lock (_syncRoot)
|
||||
{
|
||||
if (_state != WorkerClientState.Created)
|
||||
{
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.InvalidState,
|
||||
$"Worker client cannot start from state {_state}.");
|
||||
}
|
||||
|
||||
_state = WorkerClientState.Handshaking;
|
||||
}
|
||||
}
|
||||
|
||||
private void EnsureReady()
|
||||
{
|
||||
WorkerClientState state = State;
|
||||
if (state != WorkerClientState.Ready)
|
||||
{
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.InvalidState,
|
||||
$"Worker client is not ready. Current state is {state}.");
|
||||
}
|
||||
}
|
||||
|
||||
private bool IsTerminalState()
|
||||
{
|
||||
WorkerClientState state = State;
|
||||
return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted;
|
||||
}
|
||||
|
||||
private async Task EnqueueAsync(
|
||||
WorkerEnvelope envelope,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (ChannelClosedException exception)
|
||||
{
|
||||
throw new WorkerClientException(
|
||||
WorkerClientErrorCode.WriteFailed,
|
||||
"Worker outbound channel is closed.",
|
||||
exception);
|
||||
}
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateGatewayHelloEnvelope()
|
||||
{
|
||||
return CreateEnvelope(
|
||||
correlationId: string.Empty,
|
||||
envelope => envelope.GatewayHello = new GatewayHello
|
||||
{
|
||||
SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion,
|
||||
Nonce = _connection.Nonce,
|
||||
GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback,
|
||||
});
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateCommandEnvelope(
|
||||
string correlationId,
|
||||
WorkerCommand command)
|
||||
{
|
||||
return CreateEnvelope(
|
||||
correlationId,
|
||||
envelope => envelope.WorkerCommand = command.Clone());
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateShutdownEnvelope(
|
||||
TimeSpan timeout,
|
||||
string reason)
|
||||
{
|
||||
return CreateEnvelope(
|
||||
correlationId: string.Empty,
|
||||
envelope => envelope.WorkerShutdown = new WorkerShutdown
|
||||
{
|
||||
GracePeriod = Duration.FromTimeSpan(timeout),
|
||||
Reason = reason,
|
||||
});
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateEnvelope(
|
||||
string correlationId,
|
||||
Action<WorkerEnvelope> setBody)
|
||||
{
|
||||
WorkerEnvelope envelope = new()
|
||||
{
|
||||
ProtocolVersion = _connection.FrameOptions.ProtocolVersion,
|
||||
SessionId = SessionId,
|
||||
Sequence = (ulong)Interlocked.Increment(ref _nextSequence),
|
||||
CorrelationId = correlationId,
|
||||
};
|
||||
setBody(envelope);
|
||||
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private static string GetCommandMethod(WorkerCommand command)
|
||||
{
|
||||
return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString();
|
||||
}
|
||||
|
||||
private static string CreateWorkerFaultMessage(WorkerFault fault)
|
||||
{
|
||||
return string.IsNullOrWhiteSpace(fault.DiagnosticMessage)
|
||||
? $"Worker faulted with category {fault.Category}."
|
||||
: $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}";
|
||||
}
|
||||
|
||||
private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask }
|
||||
.Where(task => task is not null)
|
||||
.Cast<Task>()
|
||||
.ToArray();
|
||||
|
||||
if (tasks.Length == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private void ThrowIfDisposed()
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
}
|
||||
|
||||
private sealed class PendingCommand
|
||||
{
|
||||
private readonly TaskCompletionSource<WorkerCommandReply> _completion = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
public PendingCommand(
|
||||
string correlationId,
|
||||
string method,
|
||||
long startTimestamp)
|
||||
{
|
||||
CorrelationId = correlationId;
|
||||
Method = method;
|
||||
StartTimestamp = startTimestamp;
|
||||
}
|
||||
|
||||
public string CorrelationId { get; }
|
||||
|
||||
public string Method { get; }
|
||||
|
||||
public long StartTimestamp { get; }
|
||||
|
||||
public Task<WorkerCommandReply> Task => _completion.Task;
|
||||
|
||||
public void SetResult(WorkerCommandReply reply)
|
||||
{
|
||||
_completion.TrySetResult(reply);
|
||||
}
|
||||
|
||||
public void SetException(Exception exception)
|
||||
{
|
||||
_completion.TrySetException(exception);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public sealed class WorkerClientConnection
|
||||
{
|
||||
public WorkerClientConnection(
|
||||
string sessionId,
|
||||
string nonce,
|
||||
Stream stream,
|
||||
WorkerFrameProtocolOptions frameOptions,
|
||||
WorkerProcessHandle? processHandle = null)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(sessionId))
|
||||
{
|
||||
throw new ArgumentException("Session id is required.", nameof(sessionId));
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(nonce))
|
||||
{
|
||||
throw new ArgumentException("Worker nonce is required.", nameof(nonce));
|
||||
}
|
||||
|
||||
SessionId = sessionId;
|
||||
Nonce = nonce;
|
||||
Stream = stream ?? throw new ArgumentNullException(nameof(stream));
|
||||
FrameOptions = frameOptions ?? throw new ArgumentNullException(nameof(frameOptions));
|
||||
ProcessHandle = processHandle;
|
||||
}
|
||||
|
||||
public string SessionId { get; }
|
||||
|
||||
public string Nonce { get; }
|
||||
|
||||
public Stream Stream { get; }
|
||||
|
||||
public WorkerFrameProtocolOptions FrameOptions { get; }
|
||||
|
||||
public WorkerProcessHandle? ProcessHandle { get; }
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public enum WorkerClientErrorCode
|
||||
{
|
||||
InvalidState,
|
||||
ProtocolViolation,
|
||||
PipeDisconnected,
|
||||
CommandTimeout,
|
||||
WorkerFaulted,
|
||||
HeartbeatExpired,
|
||||
ShutdownTimeout,
|
||||
GatewayShutdown,
|
||||
WriteFailed,
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public sealed class WorkerClientException : Exception
|
||||
{
|
||||
public WorkerClientException(
|
||||
WorkerClientErrorCode errorCode,
|
||||
string message)
|
||||
: base(message)
|
||||
{
|
||||
ErrorCode = errorCode;
|
||||
}
|
||||
|
||||
public WorkerClientException(
|
||||
WorkerClientErrorCode errorCode,
|
||||
string message,
|
||||
Exception innerException)
|
||||
: base(message, innerException)
|
||||
{
|
||||
ErrorCode = errorCode;
|
||||
}
|
||||
|
||||
public WorkerClientErrorCode ErrorCode { get; }
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public sealed class WorkerClientOptions
|
||||
{
|
||||
public static readonly TimeSpan DefaultHeartbeatGrace = TimeSpan.FromSeconds(15);
|
||||
public static readonly TimeSpan DefaultHeartbeatCheckInterval = TimeSpan.FromSeconds(1);
|
||||
public static readonly TimeSpan DefaultEventChannelFullModeTimeout = TimeSpan.FromSeconds(5);
|
||||
|
||||
public WorkerClientOptions()
|
||||
{
|
||||
HeartbeatGrace = DefaultHeartbeatGrace;
|
||||
HeartbeatCheckInterval = DefaultHeartbeatCheckInterval;
|
||||
EventChannelCapacity = 1_024;
|
||||
EventChannelFullModeTimeout = DefaultEventChannelFullModeTimeout;
|
||||
}
|
||||
|
||||
public TimeSpan HeartbeatGrace { get; init; }
|
||||
|
||||
public TimeSpan HeartbeatCheckInterval { get; init; }
|
||||
|
||||
public int EventChannelCapacity { get; init; }
|
||||
|
||||
public TimeSpan EventChannelFullModeTimeout { get; init; }
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
namespace MxGateway.Server.Workers;
|
||||
|
||||
public enum WorkerClientState
|
||||
{
|
||||
Created,
|
||||
Handshaking,
|
||||
Ready,
|
||||
Closing,
|
||||
Closed,
|
||||
Faulted,
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using Microsoft.Extensions.Options;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Metrics;
|
||||
using MxGateway.Server.Sessions;
|
||||
using MxGateway.Server.Workers;
|
||||
|
||||
namespace MxGateway.Tests.Gateway.Sessions;
|
||||
|
||||
public sealed class SessionManagerTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task OpenSessionAsync_WithWorkerReady_RegistersReadySession()
|
||||
{
|
||||
FakeWorkerClient workerClient = new();
|
||||
FakeSessionWorkerClientFactory factory = new(workerClient)
|
||||
{
|
||||
ApplyLifecycleTransitions = true,
|
||||
};
|
||||
using GatewayMetrics metrics = new();
|
||||
SessionManager manager = CreateManager(factory, metrics: metrics);
|
||||
|
||||
GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
|
||||
|
||||
Assert.True(manager.TryGetSession(session.SessionId, out GatewaySession registered));
|
||||
Assert.Same(session, registered);
|
||||
Assert.Equal(SessionState.Ready, session.State);
|
||||
Assert.Equal("client-1", session.ClientIdentity);
|
||||
Assert.Equal(["StartingWorker", "WaitingForPipe", "Handshaking", "InitializingWorker"], factory.ObservedStates);
|
||||
Assert.Equal(1, metrics.GetSnapshot().OpenSessions);
|
||||
Assert.Equal(1, metrics.GetSnapshot().SessionsOpened);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_WhenSessionReady_ForwardsCommandToWorker()
|
||||
{
|
||||
FakeWorkerClient workerClient = new();
|
||||
SessionManager manager = CreateManager(new FakeSessionWorkerClientFactory(workerClient));
|
||||
GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
|
||||
|
||||
WorkerCommandReply reply = await manager.InvokeAsync(
|
||||
session.SessionId,
|
||||
CreateCommand(MxCommandKind.Ping),
|
||||
CancellationToken.None);
|
||||
|
||||
Assert.Equal(1, workerClient.InvokeCount);
|
||||
Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_WhenSessionFaulted_RejectsCommand()
|
||||
{
|
||||
FakeWorkerClient workerClient = new();
|
||||
SessionManager manager = CreateManager(new FakeSessionWorkerClientFactory(workerClient));
|
||||
GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
|
||||
session.MarkFaulted("test fault");
|
||||
|
||||
SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
|
||||
async () => await manager.InvokeAsync(
|
||||
session.SessionId,
|
||||
CreateCommand(MxCommandKind.Ping),
|
||||
CancellationToken.None));
|
||||
|
||||
Assert.Equal(SessionManagerErrorCode.SessionNotReady, exception.ErrorCode);
|
||||
Assert.Equal(0, workerClient.InvokeCount);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CloseSessionAsync_WhenCalledTwice_IsIdempotent()
|
||||
{
|
||||
FakeWorkerClient workerClient = new();
|
||||
using GatewayMetrics metrics = new();
|
||||
SessionManager manager = CreateManager(new FakeSessionWorkerClientFactory(workerClient), metrics: metrics);
|
||||
GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
|
||||
|
||||
SessionCloseResult firstClose = await manager.CloseSessionAsync(session.SessionId, CancellationToken.None);
|
||||
SessionCloseResult secondClose = await manager.CloseSessionAsync(session.SessionId, CancellationToken.None);
|
||||
|
||||
Assert.False(firstClose.AlreadyClosed);
|
||||
Assert.True(secondClose.AlreadyClosed);
|
||||
Assert.Equal(SessionState.Closed, firstClose.FinalState);
|
||||
Assert.Equal(SessionState.Closed, secondClose.FinalState);
|
||||
Assert.Equal(1, workerClient.ShutdownCount);
|
||||
Assert.Equal(1, metrics.GetSnapshot().SessionsClosed);
|
||||
Assert.Equal(0, metrics.GetSnapshot().OpenSessions);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task OpenSessionAsync_WhenWorkerCreationFails_RemovesSessionFromRegistry()
|
||||
{
|
||||
SessionRegistry registry = new();
|
||||
using GatewayMetrics metrics = new();
|
||||
SessionManager manager = CreateManager(
|
||||
new FailingSessionWorkerClientFactory(),
|
||||
registry,
|
||||
metrics);
|
||||
|
||||
SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
|
||||
async () => await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None));
|
||||
|
||||
Assert.Equal(SessionManagerErrorCode.OpenFailed, exception.ErrorCode);
|
||||
Assert.Equal(0, registry.Count);
|
||||
Assert.Equal(0, metrics.GetSnapshot().SessionsOpened);
|
||||
Assert.Equal(1, metrics.GetSnapshot().Faults);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CloseExpiredLeasesAsync_ClosesExpiredSessionsOnly()
|
||||
{
|
||||
FakeWorkerClient expiredClient = new();
|
||||
FakeWorkerClient activeClient = new();
|
||||
QueueingSessionWorkerClientFactory factory = new(expiredClient, activeClient);
|
||||
SessionManager manager = CreateManager(factory);
|
||||
GatewaySession expiredSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
|
||||
GatewaySession activeSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-2", CancellationToken.None);
|
||||
DateTimeOffset now = DateTimeOffset.UtcNow;
|
||||
expiredSession.ExtendLease(now.AddSeconds(-1));
|
||||
activeSession.ExtendLease(now.AddMinutes(5));
|
||||
|
||||
int closedCount = await manager.CloseExpiredLeasesAsync(now, CancellationToken.None);
|
||||
|
||||
Assert.Equal(1, closedCount);
|
||||
Assert.Equal(SessionState.Closed, expiredSession.State);
|
||||
Assert.Equal(SessionState.Ready, activeSession.State);
|
||||
Assert.Equal(1, expiredClient.ShutdownCount);
|
||||
Assert.Equal(0, activeClient.ShutdownCount);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ShutdownAsync_ClosesAllRegisteredSessions()
|
||||
{
|
||||
FakeWorkerClient firstClient = new();
|
||||
FakeWorkerClient secondClient = new();
|
||||
QueueingSessionWorkerClientFactory factory = new(firstClient, secondClient);
|
||||
using GatewayMetrics metrics = new();
|
||||
SessionManager manager = CreateManager(factory, metrics: metrics);
|
||||
GatewaySession firstSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
|
||||
GatewaySession secondSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-2", CancellationToken.None);
|
||||
|
||||
await manager.ShutdownAsync(CancellationToken.None);
|
||||
|
||||
Assert.Equal(SessionState.Closed, firstSession.State);
|
||||
Assert.Equal(SessionState.Closed, secondSession.State);
|
||||
Assert.Equal(1, firstClient.ShutdownCount);
|
||||
Assert.Equal(1, secondClient.ShutdownCount);
|
||||
Assert.Equal(2, metrics.GetSnapshot().SessionsClosed);
|
||||
Assert.Equal(0, metrics.GetSnapshot().OpenSessions);
|
||||
}
|
||||
|
||||
private static SessionManager CreateManager(
|
||||
ISessionWorkerClientFactory factory,
|
||||
ISessionRegistry? registry = null,
|
||||
GatewayMetrics? metrics = null,
|
||||
GatewayOptions? options = null)
|
||||
{
|
||||
return new SessionManager(
|
||||
registry ?? new SessionRegistry(),
|
||||
factory,
|
||||
Options.Create(options ?? CreateOptions()),
|
||||
metrics ?? new GatewayMetrics());
|
||||
}
|
||||
|
||||
private static GatewayOptions CreateOptions()
|
||||
{
|
||||
return new GatewayOptions
|
||||
{
|
||||
Sessions = new SessionOptions
|
||||
{
|
||||
DefaultCommandTimeoutSeconds = 30,
|
||||
MaxSessions = 64,
|
||||
},
|
||||
Worker = new WorkerOptions
|
||||
{
|
||||
StartupTimeoutSeconds = 30,
|
||||
ShutdownTimeoutSeconds = 10,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private static SessionOpenRequest CreateOpenRequest()
|
||||
{
|
||||
return new SessionOpenRequest(
|
||||
RequestedBackend: null,
|
||||
ClientSessionName: "test-session",
|
||||
ClientCorrelationId: "client-correlation-1",
|
||||
CommandTimeout: Duration.FromTimeSpan(TimeSpan.FromSeconds(5)));
|
||||
}
|
||||
|
||||
private static WorkerCommand CreateCommand(MxCommandKind kind)
|
||||
{
|
||||
return new WorkerCommand
|
||||
{
|
||||
Command = new MxCommand
|
||||
{
|
||||
Kind = kind,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private sealed class FakeSessionWorkerClientFactory(IWorkerClient workerClient) : ISessionWorkerClientFactory
|
||||
{
|
||||
public List<string> ObservedStates { get; } = [];
|
||||
|
||||
public bool ApplyLifecycleTransitions { get; init; }
|
||||
|
||||
public Task<IWorkerClient> CreateAsync(
|
||||
GatewaySession session,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObservedStates.Add(session.State.ToString());
|
||||
if (ApplyLifecycleTransitions)
|
||||
{
|
||||
session.TransitionTo(SessionState.WaitingForPipe);
|
||||
ObservedStates.Add(session.State.ToString());
|
||||
session.TransitionTo(SessionState.Handshaking);
|
||||
ObservedStates.Add(session.State.ToString());
|
||||
session.TransitionTo(SessionState.InitializingWorker);
|
||||
ObservedStates.Add(session.State.ToString());
|
||||
}
|
||||
|
||||
return Task.FromResult(workerClient);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class QueueingSessionWorkerClientFactory : ISessionWorkerClientFactory
|
||||
{
|
||||
private readonly Queue<IWorkerClient> _workerClients;
|
||||
|
||||
public QueueingSessionWorkerClientFactory(params IWorkerClient[] workerClients)
|
||||
{
|
||||
_workerClients = new Queue<IWorkerClient>(workerClients);
|
||||
}
|
||||
|
||||
public Task<IWorkerClient> CreateAsync(
|
||||
GatewaySession session,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
return Task.FromResult(_workerClients.Dequeue());
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class FailingSessionWorkerClientFactory : ISessionWorkerClientFactory
|
||||
{
|
||||
public Task<IWorkerClient> CreateAsync(
|
||||
GatewaySession session,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
throw new InvalidOperationException("worker startup failed");
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class FakeWorkerClient : IWorkerClient
|
||||
{
|
||||
public string SessionId { get; init; } = "session-1";
|
||||
|
||||
public int? ProcessId { get; init; } = 1234;
|
||||
|
||||
public WorkerClientState State { get; set; } = WorkerClientState.Ready;
|
||||
|
||||
public DateTimeOffset LastHeartbeatAt { get; init; } = DateTimeOffset.UtcNow;
|
||||
|
||||
public int InvokeCount { get; private set; }
|
||||
|
||||
public int ShutdownCount { get; private set; }
|
||||
|
||||
public int KillCount { get; private set; }
|
||||
|
||||
public Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task<WorkerCommandReply> InvokeAsync(
|
||||
WorkerCommand command,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
InvokeCount++;
|
||||
MxCommandKind kind = command.Command?.Kind ?? MxCommandKind.Unspecified;
|
||||
|
||||
return Task.FromResult(new WorkerCommandReply
|
||||
{
|
||||
Reply = new MxCommandReply
|
||||
{
|
||||
SessionId = SessionId,
|
||||
CorrelationId = "correlation-1",
|
||||
Kind = kind,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
|
||||
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
{
|
||||
await Task.CompletedTask;
|
||||
yield break;
|
||||
}
|
||||
|
||||
public Task ShutdownAsync(
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ShutdownCount++;
|
||||
State = WorkerClientState.Closed;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public void Kill(string reason)
|
||||
{
|
||||
KillCount++;
|
||||
State = WorkerClientState.Faulted;
|
||||
}
|
||||
|
||||
public ValueTask DisposeAsync()
|
||||
{
|
||||
return ValueTask.CompletedTask;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
using System.IO.Pipes;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Workers;
|
||||
|
||||
namespace MxGateway.Tests.Gateway.Workers;
|
||||
|
||||
public sealed class WorkerClientTests
|
||||
{
|
||||
private const string SessionId = "session-worker-client";
|
||||
private const string Nonce = "nonce-worker-client";
|
||||
private const int WorkerProcessId = 4321;
|
||||
private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5);
|
||||
|
||||
[Fact]
|
||||
public async Task StartAsync_WithWorkerHelloAndReady_EntersReadyState()
|
||||
{
|
||||
await using PipePair pipePair = await PipePair.CreateAsync();
|
||||
await using WorkerClient client = CreateClient(pipePair);
|
||||
|
||||
await CompleteHandshakeAsync(client, pipePair);
|
||||
|
||||
Assert.Equal(WorkerClientState.Ready, client.State);
|
||||
Assert.Equal(WorkerProcessId, client.ProcessId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_WithMatchingReply_CompletesPendingCommand()
|
||||
{
|
||||
await using PipePair pipePair = await PipePair.CreateAsync();
|
||||
await using WorkerClient client = CreateClient(pipePair);
|
||||
await CompleteHandshakeAsync(client, pipePair);
|
||||
|
||||
Task<WorkerCommandReply> invokeTask = client.InvokeAsync(
|
||||
CreateCommand(MxCommandKind.Ping),
|
||||
TestTimeout,
|
||||
CancellationToken.None);
|
||||
|
||||
WorkerEnvelope commandEnvelope = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerCommand, commandEnvelope.BodyCase);
|
||||
Assert.False(string.IsNullOrWhiteSpace(commandEnvelope.CorrelationId));
|
||||
|
||||
await pipePair.WorkerWriter.WriteAsync(
|
||||
CreateCommandReplyEnvelope(commandEnvelope.CorrelationId, MxCommandKind.Ping));
|
||||
|
||||
WorkerCommandReply reply = await invokeTask.WaitAsync(TestTimeout);
|
||||
|
||||
Assert.Equal(commandEnvelope.CorrelationId, reply.Reply.CorrelationId);
|
||||
Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_WithLateReply_IgnoresLateReplyAndKeepsClientReady()
|
||||
{
|
||||
await using PipePair pipePair = await PipePair.CreateAsync();
|
||||
await using WorkerClient client = CreateClient(pipePair);
|
||||
await CompleteHandshakeAsync(client, pipePair);
|
||||
|
||||
Task<WorkerCommandReply> timedOutInvokeTask = client.InvokeAsync(
|
||||
CreateCommand(MxCommandKind.Ping),
|
||||
TimeSpan.FromMilliseconds(50),
|
||||
CancellationToken.None);
|
||||
WorkerEnvelope timedOutCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
|
||||
|
||||
WorkerClientException exception = await Assert.ThrowsAsync<WorkerClientException>(
|
||||
async () => await timedOutInvokeTask);
|
||||
Assert.Equal(WorkerClientErrorCode.CommandTimeout, exception.ErrorCode);
|
||||
|
||||
await pipePair.WorkerWriter.WriteAsync(
|
||||
CreateCommandReplyEnvelope(timedOutCommand.CorrelationId, MxCommandKind.Ping));
|
||||
await Task.Delay(TimeSpan.FromMilliseconds(50));
|
||||
|
||||
Task<WorkerCommandReply> secondInvokeTask = client.InvokeAsync(
|
||||
CreateCommand(MxCommandKind.GetWorkerInfo),
|
||||
TestTimeout,
|
||||
CancellationToken.None);
|
||||
WorkerEnvelope secondCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
|
||||
await pipePair.WorkerWriter.WriteAsync(
|
||||
CreateCommandReplyEnvelope(secondCommand.CorrelationId, MxCommandKind.GetWorkerInfo));
|
||||
|
||||
WorkerCommandReply reply = await secondInvokeTask.WaitAsync(TestTimeout);
|
||||
|
||||
Assert.Equal(WorkerClientState.Ready, client.State);
|
||||
Assert.Equal(MxCommandKind.GetWorkerInfo, reply.Reply.Kind);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadEventsAsync_WithWorkerEvents_YieldsEventsInPipeOrder()
|
||||
{
|
||||
await using PipePair pipePair = await PipePair.CreateAsync();
|
||||
await using WorkerClient client = CreateClient(pipePair);
|
||||
await CompleteHandshakeAsync(client, pipePair);
|
||||
using CancellationTokenSource cancellationTokenSource = new(TestTimeout);
|
||||
|
||||
await using IAsyncEnumerator<WorkerEvent> events =
|
||||
client.ReadEventsAsync(cancellationTokenSource.Token).GetAsyncEnumerator(cancellationTokenSource.Token);
|
||||
|
||||
await pipePair.WorkerWriter.WriteAsync(
|
||||
CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange));
|
||||
await pipePair.WorkerWriter.WriteAsync(
|
||||
CreateEventEnvelope(sequence: 12, MxEventFamily.OperationComplete));
|
||||
|
||||
Assert.True(await events.MoveNextAsync());
|
||||
Assert.Equal((ulong)11, events.Current.Event.WorkerSequence);
|
||||
Assert.Equal(MxEventFamily.OnDataChange, events.Current.Event.Family);
|
||||
|
||||
Assert.True(await events.MoveNextAsync());
|
||||
Assert.Equal((ulong)12, events.Current.Event.WorkerSequence);
|
||||
Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadLoop_WhenPipeDisconnects_FaultsClient()
|
||||
{
|
||||
await using PipePair pipePair = await PipePair.CreateAsync();
|
||||
await using WorkerClient client = CreateClient(pipePair);
|
||||
await CompleteHandshakeAsync(client, pipePair);
|
||||
|
||||
await pipePair.DisposeWorkerSideAsync();
|
||||
|
||||
await WaitUntilAsync(
|
||||
() => client.State == WorkerClientState.Faulted,
|
||||
TestTimeout);
|
||||
|
||||
Assert.Equal(WorkerClientState.Faulted, client.State);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HeartbeatMonitor_WhenHeartbeatExpires_FaultsClient()
|
||||
{
|
||||
await using PipePair pipePair = await PipePair.CreateAsync();
|
||||
await using WorkerClient client = CreateClient(
|
||||
pipePair,
|
||||
new WorkerClientOptions
|
||||
{
|
||||
HeartbeatGrace = TimeSpan.FromMilliseconds(80),
|
||||
HeartbeatCheckInterval = TimeSpan.FromMilliseconds(20),
|
||||
EventChannelCapacity = 8,
|
||||
});
|
||||
await CompleteHandshakeAsync(client, pipePair);
|
||||
|
||||
await WaitUntilAsync(
|
||||
() => client.State == WorkerClientState.Faulted,
|
||||
TestTimeout);
|
||||
|
||||
Assert.Equal(WorkerClientState.Faulted, client.State);
|
||||
}
|
||||
|
||||
private static WorkerClient CreateClient(
|
||||
PipePair pipePair,
|
||||
WorkerClientOptions? options = null)
|
||||
{
|
||||
WorkerFrameProtocolOptions frameOptions = new(SessionId);
|
||||
WorkerClientConnection connection = new(
|
||||
SessionId,
|
||||
Nonce,
|
||||
pipePair.GatewayStream,
|
||||
frameOptions);
|
||||
|
||||
return new WorkerClient(connection, options);
|
||||
}
|
||||
|
||||
private static async Task CompleteHandshakeAsync(
|
||||
WorkerClient client,
|
||||
PipePair pipePair)
|
||||
{
|
||||
Task startTask = client.StartAsync(CancellationToken.None);
|
||||
|
||||
WorkerEnvelope gatewayHello = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.GatewayHello, gatewayHello.BodyCase);
|
||||
Assert.Equal(Nonce, gatewayHello.GatewayHello.Nonce);
|
||||
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, gatewayHello.GatewayHello.SupportedProtocolVersion);
|
||||
|
||||
await pipePair.WorkerWriter.WriteAsync(CreateWorkerHelloEnvelope());
|
||||
await pipePair.WorkerWriter.WriteAsync(CreateWorkerReadyEnvelope());
|
||||
await startTask.WaitAsync(TestTimeout);
|
||||
}
|
||||
|
||||
private static WorkerCommand CreateCommand(MxCommandKind kind)
|
||||
{
|
||||
return new WorkerCommand
|
||||
{
|
||||
Command = new MxCommand
|
||||
{
|
||||
Kind = kind,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateWorkerHelloEnvelope()
|
||||
{
|
||||
return CreateWorkerEnvelope(
|
||||
correlationId: string.Empty,
|
||||
sequence: 1,
|
||||
envelope => envelope.WorkerHello = new WorkerHello
|
||||
{
|
||||
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
Nonce = Nonce,
|
||||
WorkerProcessId = WorkerProcessId,
|
||||
WorkerVersion = "fake-worker",
|
||||
});
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateWorkerReadyEnvelope()
|
||||
{
|
||||
return CreateWorkerEnvelope(
|
||||
correlationId: string.Empty,
|
||||
sequence: 2,
|
||||
envelope => envelope.WorkerReady = new WorkerReady
|
||||
{
|
||||
WorkerProcessId = WorkerProcessId,
|
||||
MxaccessProgid = "LMXProxy.LMXProxyServer.1",
|
||||
MxaccessClsid = "{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}",
|
||||
});
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateCommandReplyEnvelope(
|
||||
string correlationId,
|
||||
MxCommandKind kind)
|
||||
{
|
||||
return CreateWorkerEnvelope(
|
||||
correlationId,
|
||||
sequence: 10,
|
||||
envelope => envelope.WorkerCommandReply = new WorkerCommandReply
|
||||
{
|
||||
Reply = new MxCommandReply
|
||||
{
|
||||
SessionId = SessionId,
|
||||
CorrelationId = correlationId,
|
||||
Kind = kind,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateEventEnvelope(
|
||||
ulong sequence,
|
||||
MxEventFamily family)
|
||||
{
|
||||
return CreateWorkerEnvelope(
|
||||
correlationId: string.Empty,
|
||||
sequence,
|
||||
envelope => envelope.WorkerEvent = new WorkerEvent
|
||||
{
|
||||
Event = new MxEvent
|
||||
{
|
||||
SessionId = SessionId,
|
||||
Family = family,
|
||||
WorkerSequence = sequence,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateWorkerEnvelope(
|
||||
string correlationId,
|
||||
ulong sequence,
|
||||
Action<WorkerEnvelope> setBody)
|
||||
{
|
||||
WorkerEnvelope envelope = new()
|
||||
{
|
||||
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
SessionId = SessionId,
|
||||
Sequence = sequence,
|
||||
CorrelationId = correlationId,
|
||||
};
|
||||
setBody(envelope);
|
||||
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private static async Task WaitUntilAsync(
|
||||
Func<bool> predicate,
|
||||
TimeSpan timeout)
|
||||
{
|
||||
using CancellationTokenSource cancellationTokenSource = new(timeout);
|
||||
while (!predicate())
|
||||
{
|
||||
await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class PipePair : IAsyncDisposable
|
||||
{
|
||||
private readonly NamedPipeClientStream _workerStream;
|
||||
private bool _workerSideDisposed;
|
||||
|
||||
private PipePair(
|
||||
NamedPipeServerStream gatewayStream,
|
||||
NamedPipeClientStream workerStream)
|
||||
{
|
||||
GatewayStream = gatewayStream;
|
||||
_workerStream = workerStream;
|
||||
WorkerReader = new WorkerFrameReader(_workerStream, new WorkerFrameProtocolOptions(SessionId));
|
||||
WorkerWriter = new WorkerFrameWriter(_workerStream, new WorkerFrameProtocolOptions(SessionId));
|
||||
}
|
||||
|
||||
public NamedPipeServerStream GatewayStream { get; }
|
||||
|
||||
public WorkerFrameReader WorkerReader { get; }
|
||||
|
||||
public WorkerFrameWriter WorkerWriter { get; }
|
||||
|
||||
public static async Task<PipePair> CreateAsync()
|
||||
{
|
||||
string pipeName = $"mxaccessgw-workerclient-tests-{Guid.NewGuid():N}";
|
||||
NamedPipeServerStream gatewayStream = new(
|
||||
pipeName,
|
||||
PipeDirection.InOut,
|
||||
maxNumberOfServerInstances: 1,
|
||||
PipeTransmissionMode.Byte,
|
||||
PipeOptions.Asynchronous);
|
||||
NamedPipeClientStream workerStream = new(
|
||||
".",
|
||||
pipeName,
|
||||
PipeDirection.InOut,
|
||||
PipeOptions.Asynchronous);
|
||||
|
||||
Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync();
|
||||
await workerStream.ConnectAsync();
|
||||
await waitForConnectionTask;
|
||||
|
||||
return new PipePair(gatewayStream, workerStream);
|
||||
}
|
||||
|
||||
public async ValueTask DisposeWorkerSideAsync()
|
||||
{
|
||||
if (_workerSideDisposed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
await _workerStream.DisposeAsync();
|
||||
_workerSideDisposed = true;
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
await DisposeWorkerSideAsync();
|
||||
await GatewayStream.DisposeAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
|
||||
namespace MxGateway.Tests.Security.Authentication;
|
||||
|
||||
public sealed class ApiKeyAdminCliRunnerTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task CreateKeyAsync_CreatesAuthenticatingKeyAndAudits()
|
||||
{
|
||||
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
|
||||
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
|
||||
StringWriter output = new();
|
||||
|
||||
await runner.RunAsync(
|
||||
new ApiKeyAdminCommand(
|
||||
Kind: ApiKeyAdminCommandKind.CreateKey,
|
||||
Json: true,
|
||||
SqlitePath: null,
|
||||
Pepper: null,
|
||||
KeyId: "operator01",
|
||||
DisplayName: "Operator",
|
||||
Scopes: new HashSet<string>(StringComparer.Ordinal) { "session:open", "events:read" }),
|
||||
output,
|
||||
CancellationToken.None);
|
||||
|
||||
string apiKey = ReadApiKey(output.ToString());
|
||||
|
||||
IApiKeyVerifier verifier = services.GetRequiredService<IApiKeyVerifier>();
|
||||
ApiKeyVerificationResult verification = await verifier.VerifyAsync($"Bearer {apiKey}", CancellationToken.None);
|
||||
|
||||
Assert.True(verification.Succeeded);
|
||||
Assert.NotNull(verification.Identity);
|
||||
Assert.Equal("operator01", verification.Identity.KeyId);
|
||||
Assert.Contains("session:open", verification.Identity.Scopes);
|
||||
|
||||
IReadOnlyList<ApiKeyAuditRecord> auditRecords = await services
|
||||
.GetRequiredService<IApiKeyAuditStore>()
|
||||
.ListRecentAsync(10, CancellationToken.None);
|
||||
|
||||
Assert.Contains(auditRecords, record => record.EventType == "create-key" && record.KeyId == "operator01");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ListKeysAsync_DoesNotPrintRawSecret()
|
||||
{
|
||||
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
|
||||
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
|
||||
string apiKey = await CreateKeyAsync(runner, "operator01");
|
||||
StringWriter listOutput = new();
|
||||
|
||||
await runner.RunAsync(
|
||||
new ApiKeyAdminCommand(
|
||||
Kind: ApiKeyAdminCommandKind.ListKeys,
|
||||
Json: true,
|
||||
SqlitePath: null,
|
||||
Pepper: null,
|
||||
KeyId: null,
|
||||
DisplayName: null,
|
||||
Scopes: new HashSet<string>(StringComparer.Ordinal)),
|
||||
listOutput,
|
||||
CancellationToken.None);
|
||||
|
||||
string listJson = listOutput.ToString();
|
||||
|
||||
Assert.Contains("operator01", listJson, StringComparison.Ordinal);
|
||||
Assert.DoesNotContain(apiKey, listJson, StringComparison.Ordinal);
|
||||
Assert.DoesNotContain(ApiKeySecret(apiKey), listJson, StringComparison.Ordinal);
|
||||
Assert.DoesNotContain("secret_hash", listJson, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RevokeKeyAsync_RevokedKeyFailsVerificationAndAudits()
|
||||
{
|
||||
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
|
||||
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
|
||||
string apiKey = await CreateKeyAsync(runner, "operator01");
|
||||
|
||||
await runner.RunAsync(
|
||||
new ApiKeyAdminCommand(
|
||||
Kind: ApiKeyAdminCommandKind.RevokeKey,
|
||||
Json: true,
|
||||
SqlitePath: null,
|
||||
Pepper: null,
|
||||
KeyId: "operator01",
|
||||
DisplayName: null,
|
||||
Scopes: new HashSet<string>(StringComparer.Ordinal)),
|
||||
TextWriter.Null,
|
||||
CancellationToken.None);
|
||||
|
||||
ApiKeyVerificationResult verification = await services
|
||||
.GetRequiredService<IApiKeyVerifier>()
|
||||
.VerifyAsync($"Bearer {apiKey}", CancellationToken.None);
|
||||
|
||||
Assert.False(verification.Succeeded);
|
||||
Assert.Equal(ApiKeyVerificationFailure.KeyRevoked, verification.Failure);
|
||||
|
||||
IReadOnlyList<ApiKeyAuditRecord> auditRecords = await services
|
||||
.GetRequiredService<IApiKeyAuditStore>()
|
||||
.ListRecentAsync(10, CancellationToken.None);
|
||||
|
||||
Assert.Contains(auditRecords, record => record.EventType == "revoke-key" && record.KeyId == "operator01");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RotateKeyAsync_PrintsNewSecretOnceAndInvalidatesOldSecret()
|
||||
{
|
||||
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
|
||||
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
|
||||
string oldApiKey = await CreateKeyAsync(runner, "operator01");
|
||||
StringWriter rotateOutput = new();
|
||||
|
||||
await runner.RunAsync(
|
||||
new ApiKeyAdminCommand(
|
||||
Kind: ApiKeyAdminCommandKind.RotateKey,
|
||||
Json: true,
|
||||
SqlitePath: null,
|
||||
Pepper: null,
|
||||
KeyId: "operator01",
|
||||
DisplayName: null,
|
||||
Scopes: new HashSet<string>(StringComparer.Ordinal)),
|
||||
rotateOutput,
|
||||
CancellationToken.None);
|
||||
|
||||
string rotateJson = rotateOutput.ToString();
|
||||
string newApiKey = ReadApiKey(rotateJson);
|
||||
|
||||
Assert.NotEqual(oldApiKey, newApiKey);
|
||||
Assert.Equal(1, CountOccurrences(rotateJson, newApiKey));
|
||||
|
||||
IApiKeyVerifier verifier = services.GetRequiredService<IApiKeyVerifier>();
|
||||
ApiKeyVerificationResult oldVerification = await verifier.VerifyAsync($"Bearer {oldApiKey}", CancellationToken.None);
|
||||
ApiKeyVerificationResult newVerification = await verifier.VerifyAsync($"Bearer {newApiKey}", CancellationToken.None);
|
||||
|
||||
Assert.False(oldVerification.Succeeded);
|
||||
Assert.Equal(ApiKeyVerificationFailure.SecretMismatch, oldVerification.Failure);
|
||||
Assert.True(newVerification.Succeeded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CreateKeyAsync_PrintsRawSecretExactlyOnce()
|
||||
{
|
||||
await using ServiceProvider services = BuildServices(CreateTempDatabasePath());
|
||||
ApiKeyAdminCliRunner runner = services.GetRequiredService<ApiKeyAdminCliRunner>();
|
||||
StringWriter output = new();
|
||||
|
||||
await runner.RunAsync(
|
||||
new ApiKeyAdminCommand(
|
||||
Kind: ApiKeyAdminCommandKind.CreateKey,
|
||||
Json: true,
|
||||
SqlitePath: null,
|
||||
Pepper: null,
|
||||
KeyId: "operator01",
|
||||
DisplayName: "Operator",
|
||||
Scopes: new HashSet<string>(StringComparer.Ordinal)),
|
||||
output,
|
||||
CancellationToken.None);
|
||||
|
||||
string json = output.ToString();
|
||||
string apiKey = ReadApiKey(json);
|
||||
|
||||
Assert.Equal(1, CountOccurrences(json, apiKey));
|
||||
Assert.Equal(1, CountOccurrences(json, ApiKeySecret(apiKey)));
|
||||
}
|
||||
|
||||
private static async Task<string> CreateKeyAsync(ApiKeyAdminCliRunner runner, string keyId)
|
||||
{
|
||||
StringWriter output = new();
|
||||
await runner.RunAsync(
|
||||
new ApiKeyAdminCommand(
|
||||
Kind: ApiKeyAdminCommandKind.CreateKey,
|
||||
Json: true,
|
||||
SqlitePath: null,
|
||||
Pepper: null,
|
||||
KeyId: keyId,
|
||||
DisplayName: "Operator",
|
||||
Scopes: new HashSet<string>(StringComparer.Ordinal) { "session:open" }),
|
||||
output,
|
||||
CancellationToken.None);
|
||||
|
||||
return ReadApiKey(output.ToString());
|
||||
}
|
||||
|
||||
private static ServiceProvider BuildServices(string databasePath)
|
||||
{
|
||||
IConfigurationRoot configuration = new ConfigurationBuilder()
|
||||
.AddInMemoryCollection(
|
||||
new Dictionary<string, string?>
|
||||
{
|
||||
["MxGateway:Authentication:SqlitePath"] = databasePath,
|
||||
["MxGateway:ApiKeyPepper"] = "test-pepper"
|
||||
})
|
||||
.Build();
|
||||
|
||||
ServiceCollection services = new();
|
||||
services.AddSingleton<IConfiguration>(configuration);
|
||||
services.AddGatewayConfiguration();
|
||||
services.AddSqliteAuthStore();
|
||||
|
||||
return services.BuildServiceProvider(validateScopes: true);
|
||||
}
|
||||
|
||||
private static string CreateTempDatabasePath()
|
||||
{
|
||||
string directory = Path.Combine(Path.GetTempPath(), "mxgateway-auth-cli-tests", Guid.NewGuid().ToString("N"));
|
||||
Directory.CreateDirectory(directory);
|
||||
|
||||
return Path.Combine(directory, "gateway-auth.db");
|
||||
}
|
||||
|
||||
private static string ReadApiKey(string json)
|
||||
{
|
||||
using JsonDocument document = JsonDocument.Parse(json);
|
||||
|
||||
return document.RootElement.GetProperty("ApiKey").GetString()
|
||||
?? throw new InvalidOperationException("API key was not present in command output.");
|
||||
}
|
||||
|
||||
private static string ApiKeySecret(string apiKey)
|
||||
{
|
||||
string[] parts = apiKey.Split('_', 3);
|
||||
|
||||
return parts[2];
|
||||
}
|
||||
|
||||
private static int CountOccurrences(string value, string pattern)
|
||||
{
|
||||
int count = 0;
|
||||
int index = 0;
|
||||
|
||||
while ((index = value.IndexOf(pattern, index, StringComparison.Ordinal)) >= 0)
|
||||
{
|
||||
count++;
|
||||
index += pattern.Length;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
|
||||
namespace MxGateway.Tests.Security.Authentication;
|
||||
|
||||
public sealed class ApiKeyAdminCommandLineParserTests
|
||||
{
|
||||
[Fact]
|
||||
public void Parse_NonApiKeyCommand_ReturnsNotApiKeyCommand()
|
||||
{
|
||||
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(["--urls=http://localhost:5000"]);
|
||||
|
||||
Assert.False(result.IsApiKeyCommand);
|
||||
Assert.Null(result.Command);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_CreateKeyCommand_ReturnsOptions()
|
||||
{
|
||||
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(
|
||||
[
|
||||
"apikey",
|
||||
"create-key",
|
||||
"--key-id",
|
||||
"operator01",
|
||||
"--display-name",
|
||||
"Operator",
|
||||
"--scopes",
|
||||
"session:open,events:read",
|
||||
"--sqlite-path",
|
||||
"auth.db",
|
||||
"--pepper",
|
||||
"pepper",
|
||||
"--json"
|
||||
]);
|
||||
|
||||
Assert.True(result.IsApiKeyCommand);
|
||||
Assert.Null(result.Error);
|
||||
Assert.NotNull(result.Command);
|
||||
Assert.Equal(ApiKeyAdminCommandKind.CreateKey, result.Command.Kind);
|
||||
Assert.True(result.Command.Json);
|
||||
Assert.Equal("operator01", result.Command.KeyId);
|
||||
Assert.Equal("Operator", result.Command.DisplayName);
|
||||
Assert.Equal("auth.db", result.Command.SqlitePath);
|
||||
Assert.Equal("pepper", result.Command.Pepper);
|
||||
Assert.Contains("session:open", result.Command.Scopes);
|
||||
Assert.Contains("events:read", result.Command.Scopes);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_CreateKeyWithoutDisplayName_ReturnsError()
|
||||
{
|
||||
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(
|
||||
["apikey", "create-key", "--key-id", "operator01"]);
|
||||
|
||||
Assert.True(result.IsApiKeyCommand);
|
||||
Assert.Null(result.Command);
|
||||
Assert.Contains("--display-name", result.Error, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_KeyIdWithUnderscore_ReturnsError()
|
||||
{
|
||||
ApiKeyAdminParseResult result = ApiKeyAdminCommandLineParser.Parse(
|
||||
["apikey", "revoke-key", "--key-id", "operator_01"]);
|
||||
|
||||
Assert.True(result.IsApiKeyCommand);
|
||||
Assert.Null(result.Command);
|
||||
Assert.Contains("letters, numbers, periods, and hyphens", result.Error, StringComparison.Ordinal);
|
||||
}
|
||||
}
|
||||
+267
@@ -0,0 +1,267 @@
|
||||
using Grpc.Core;
|
||||
using Microsoft.Extensions.Options;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Configuration;
|
||||
using MxGateway.Server.Security.Authentication;
|
||||
using MxGateway.Server.Security.Authorization;
|
||||
|
||||
namespace MxGateway.Tests.Security.Authorization;
|
||||
|
||||
public sealed class GatewayGrpcAuthorizationInterceptorTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task UnaryServerHandler_MissingApiKey_ReturnsUnauthenticated()
|
||||
{
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(
|
||||
ApiKeyVerificationFailure.MissingOrMalformedCredentials)),
|
||||
new GatewayRequestIdentityAccessor());
|
||||
|
||||
RpcException exception = await Assert.ThrowsAsync<RpcException>(
|
||||
() => interceptor.UnaryServerHandler(
|
||||
new OpenSessionRequest(),
|
||||
new TestServerCallContext([]),
|
||||
(_, _) => Task.FromResult(new OpenSessionReply())));
|
||||
|
||||
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
|
||||
Assert.DoesNotContain("secret", exception.Status.Detail, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UnaryServerHandler_InvalidApiKey_DoesNotExposeRawCredentialInStatus()
|
||||
{
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
new FakeApiKeyVerifier(ApiKeyVerificationResult.Fail(ApiKeyVerificationFailure.SecretMismatch)),
|
||||
new GatewayRequestIdentityAccessor());
|
||||
|
||||
RpcException exception = await Assert.ThrowsAsync<RpcException>(
|
||||
() => interceptor.UnaryServerHandler(
|
||||
new OpenSessionRequest(),
|
||||
ContextWithAuthorization("Bearer mxgw_operator01_super-secret"),
|
||||
(_, _) => Task.FromResult(new OpenSessionReply())));
|
||||
|
||||
Assert.Equal(StatusCode.Unauthenticated, exception.StatusCode);
|
||||
Assert.DoesNotContain("super-secret", exception.Status.Detail, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UnaryServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
|
||||
{
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
|
||||
new GatewayRequestIdentityAccessor());
|
||||
|
||||
RpcException exception = await Assert.ThrowsAsync<RpcException>(
|
||||
() => interceptor.UnaryServerHandler(
|
||||
new OpenSessionRequest(),
|
||||
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
|
||||
(_, _) => Task.FromResult(new OpenSessionReply())));
|
||||
|
||||
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
|
||||
Assert.Contains(GatewayScopes.SessionOpen, exception.Status.Detail, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UnaryServerHandler_ValidApiKeyWithScope_SetsRequestIdentity()
|
||||
{
|
||||
GatewayRequestIdentityAccessor identityAccessor = new();
|
||||
ApiKeyIdentity? identitySeenByHandler = null;
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
|
||||
identityAccessor);
|
||||
|
||||
OpenSessionReply reply = await interceptor.UnaryServerHandler(
|
||||
new OpenSessionRequest(),
|
||||
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
|
||||
(_, _) =>
|
||||
{
|
||||
identitySeenByHandler = identityAccessor.Current;
|
||||
|
||||
return Task.FromResult(new OpenSessionReply { SessionId = "session-1" });
|
||||
});
|
||||
|
||||
Assert.Equal("session-1", reply.SessionId);
|
||||
Assert.NotNull(identitySeenByHandler);
|
||||
Assert.Equal("operator01", identitySeenByHandler.KeyId);
|
||||
Assert.Null(identityAccessor.Current);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ServerStreamingServerHandler_ValidApiKeyMissingScope_ReturnsPermissionDenied()
|
||||
{
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.SessionOpen)),
|
||||
new GatewayRequestIdentityAccessor());
|
||||
|
||||
RpcException exception = await Assert.ThrowsAsync<RpcException>(
|
||||
() => interceptor.ServerStreamingServerHandler(
|
||||
new StreamEventsRequest(),
|
||||
new TestServerStreamWriter<MxEvent>(),
|
||||
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
|
||||
(_, _, _) => Task.CompletedTask));
|
||||
|
||||
Assert.Equal(StatusCode.PermissionDenied, exception.StatusCode);
|
||||
Assert.Contains(GatewayScopes.EventsRead, exception.Status.Detail, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ServerStreamingServerHandler_ValidApiKeyWithScope_AllowsStream()
|
||||
{
|
||||
GatewayRequestIdentityAccessor identityAccessor = new();
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
new FakeApiKeyVerifier(SuccessWithScopes(GatewayScopes.EventsRead)),
|
||||
identityAccessor);
|
||||
TestServerStreamWriter<MxEvent> streamWriter = new();
|
||||
|
||||
await interceptor.ServerStreamingServerHandler(
|
||||
new StreamEventsRequest(),
|
||||
streamWriter,
|
||||
ContextWithAuthorization("Bearer mxgw_operator01_secret"),
|
||||
async (_, writer, _) =>
|
||||
{
|
||||
Assert.Equal("operator01", identityAccessor.Current?.KeyId);
|
||||
await writer.WriteAsync(new MxEvent { SessionId = "session-1" });
|
||||
});
|
||||
|
||||
MxEvent eventMessage = Assert.Single(streamWriter.Messages);
|
||||
Assert.Equal("session-1", eventMessage.SessionId);
|
||||
Assert.Null(identityAccessor.Current);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UnaryServerHandler_AuthenticationDisabled_SkipsApiKeyVerification()
|
||||
{
|
||||
GatewayRequestIdentityAccessor identityAccessor = new();
|
||||
FakeApiKeyVerifier verifier = new(ApiKeyVerificationResult.Fail(
|
||||
ApiKeyVerificationFailure.MissingOrMalformedCredentials));
|
||||
GatewayGrpcAuthorizationInterceptor interceptor = CreateInterceptor(
|
||||
verifier,
|
||||
identityAccessor,
|
||||
AuthenticationMode.Disabled);
|
||||
|
||||
OpenSessionReply reply = await interceptor.UnaryServerHandler(
|
||||
new OpenSessionRequest(),
|
||||
new TestServerCallContext([]),
|
||||
(_, _) => Task.FromResult(new OpenSessionReply { SessionId = "session-1" }));
|
||||
|
||||
Assert.Equal("session-1", reply.SessionId);
|
||||
Assert.False(verifier.WasCalled);
|
||||
Assert.Null(identityAccessor.Current);
|
||||
}
|
||||
|
||||
private static GatewayGrpcAuthorizationInterceptor CreateInterceptor(
|
||||
IApiKeyVerifier apiKeyVerifier,
|
||||
IGatewayRequestIdentityAccessor identityAccessor,
|
||||
AuthenticationMode authenticationMode = AuthenticationMode.ApiKey)
|
||||
{
|
||||
return new GatewayGrpcAuthorizationInterceptor(
|
||||
apiKeyVerifier,
|
||||
new GatewayGrpcScopeResolver(),
|
||||
identityAccessor,
|
||||
Options.Create(new GatewayOptions
|
||||
{
|
||||
Authentication = new AuthenticationOptions
|
||||
{
|
||||
Mode = authenticationMode
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
private static ApiKeyVerificationResult SuccessWithScopes(params string[] scopes)
|
||||
{
|
||||
return ApiKeyVerificationResult.Success(new ApiKeyIdentity(
|
||||
KeyId: "operator01",
|
||||
KeyPrefix: "mxgw_operator01",
|
||||
DisplayName: "Operator Key",
|
||||
Scopes: new HashSet<string>(scopes, StringComparer.Ordinal)));
|
||||
}
|
||||
|
||||
private static TestServerCallContext ContextWithAuthorization(string authorizationHeader)
|
||||
{
|
||||
return new TestServerCallContext([new Metadata.Entry("authorization", authorizationHeader)]);
|
||||
}
|
||||
|
||||
private sealed class FakeApiKeyVerifier(ApiKeyVerificationResult result) : IApiKeyVerifier
|
||||
{
|
||||
public bool WasCalled { get; private set; }
|
||||
|
||||
public string? LastAuthorizationHeader { get; private set; }
|
||||
|
||||
public Task<ApiKeyVerificationResult> VerifyAsync(
|
||||
string? authorizationHeader,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
WasCalled = true;
|
||||
LastAuthorizationHeader = authorizationHeader;
|
||||
|
||||
return Task.FromResult(result);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class TestServerStreamWriter<T> : IServerStreamWriter<T>
|
||||
{
|
||||
public List<T> Messages { get; } = [];
|
||||
|
||||
public WriteOptions? WriteOptions { get; set; }
|
||||
|
||||
public Task WriteAsync(T message)
|
||||
{
|
||||
Messages.Add(message);
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class TestServerCallContext(
|
||||
Metadata requestHeaders,
|
||||
CancellationToken cancellationToken = default) : ServerCallContext
|
||||
{
|
||||
private readonly Metadata responseTrailers = [];
|
||||
private readonly Dictionary<object, object> userState = [];
|
||||
private Status status;
|
||||
private WriteOptions? writeOptions;
|
||||
|
||||
protected override string MethodCore => "/mxaccess_gateway.v1.MxAccessGateway/Test";
|
||||
|
||||
protected override string HostCore => "localhost";
|
||||
|
||||
protected override string PeerCore => "ipv4:127.0.0.1:5000";
|
||||
|
||||
protected override DateTime DeadlineCore => DateTime.UtcNow.AddMinutes(1);
|
||||
|
||||
protected override Metadata RequestHeadersCore => requestHeaders;
|
||||
|
||||
protected override CancellationToken CancellationTokenCore => cancellationToken;
|
||||
|
||||
protected override Metadata ResponseTrailersCore => responseTrailers;
|
||||
|
||||
protected override Status StatusCore
|
||||
{
|
||||
get => status;
|
||||
set => status = value;
|
||||
}
|
||||
|
||||
protected override WriteOptions? WriteOptionsCore
|
||||
{
|
||||
get => writeOptions;
|
||||
set => writeOptions = value;
|
||||
}
|
||||
|
||||
protected override AuthContext AuthContextCore { get; } = new(
|
||||
string.Empty,
|
||||
new Dictionary<string, List<AuthProperty>>(StringComparer.Ordinal));
|
||||
|
||||
protected override IDictionary<object, object> UserStateCore => userState;
|
||||
|
||||
protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
protected override ContextPropagationToken CreatePropagationTokenCore(
|
||||
ContextPropagationOptions? options)
|
||||
{
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Server.Security.Authorization;
|
||||
|
||||
namespace MxGateway.Tests.Security.Authorization;
|
||||
|
||||
public sealed class GatewayGrpcScopeResolverTests
|
||||
{
|
||||
[Theory]
|
||||
[InlineData(typeof(OpenSessionRequest), GatewayScopes.SessionOpen)]
|
||||
[InlineData(typeof(CloseSessionRequest), GatewayScopes.SessionClose)]
|
||||
[InlineData(typeof(StreamEventsRequest), GatewayScopes.EventsRead)]
|
||||
public void ResolveRequiredScope_KnownRpcRequest_ReturnsExpectedScope(
|
||||
Type requestType,
|
||||
string expectedScope)
|
||||
{
|
||||
GatewayGrpcScopeResolver resolver = new();
|
||||
object request = Activator.CreateInstance(requestType)!;
|
||||
|
||||
string scope = resolver.ResolveRequiredScope(request);
|
||||
|
||||
Assert.Equal(expectedScope, scope);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(MxCommandKind.Register, GatewayScopes.InvokeRead)]
|
||||
[InlineData(MxCommandKind.AddItem, GatewayScopes.InvokeRead)]
|
||||
[InlineData(MxCommandKind.Advise, GatewayScopes.InvokeRead)]
|
||||
[InlineData(MxCommandKind.Write, GatewayScopes.InvokeWrite)]
|
||||
[InlineData(MxCommandKind.Write2, GatewayScopes.InvokeWrite)]
|
||||
[InlineData(MxCommandKind.WriteSecured, GatewayScopes.InvokeSecure)]
|
||||
[InlineData(MxCommandKind.WriteSecured2, GatewayScopes.InvokeSecure)]
|
||||
[InlineData(MxCommandKind.AuthenticateUser, GatewayScopes.InvokeSecure)]
|
||||
[InlineData(MxCommandKind.ArchestraUserToId, GatewayScopes.MetadataRead)]
|
||||
[InlineData(MxCommandKind.GetSessionState, GatewayScopes.MetadataRead)]
|
||||
[InlineData(MxCommandKind.GetWorkerInfo, GatewayScopes.MetadataRead)]
|
||||
[InlineData(MxCommandKind.DrainEvents, GatewayScopes.EventsRead)]
|
||||
[InlineData(MxCommandKind.ShutdownWorker, GatewayScopes.Admin)]
|
||||
public void ResolveRequiredScope_InvokeCommand_ReturnsExpectedScope(
|
||||
MxCommandKind commandKind,
|
||||
string expectedScope)
|
||||
{
|
||||
GatewayGrpcScopeResolver resolver = new();
|
||||
|
||||
string scope = resolver.ResolveRequiredScope(new MxCommandRequest
|
||||
{
|
||||
Command = new MxCommand
|
||||
{
|
||||
Kind = commandKind
|
||||
}
|
||||
});
|
||||
|
||||
Assert.Equal(expectedScope, scope);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
internal sealed class MemoryWorkerEnvironment : IWorkerEnvironment
|
||||
{
|
||||
private readonly Dictionary<string, string> _values = new();
|
||||
private readonly Exception? _exception;
|
||||
|
||||
public MemoryWorkerEnvironment()
|
||||
{
|
||||
}
|
||||
|
||||
public MemoryWorkerEnvironment(Exception exception)
|
||||
{
|
||||
_exception = exception;
|
||||
}
|
||||
|
||||
public void Set(string name, string value)
|
||||
{
|
||||
_values[name] = value;
|
||||
}
|
||||
|
||||
public string? GetEnvironmentVariable(string name)
|
||||
{
|
||||
if (_exception is not null)
|
||||
{
|
||||
throw _exception;
|
||||
}
|
||||
|
||||
return _values.TryGetValue(name, out string value)
|
||||
? value
|
||||
: null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
internal sealed class MemoryWorkerLogEntry
|
||||
{
|
||||
public MemoryWorkerLogEntry(
|
||||
string level,
|
||||
string eventName,
|
||||
IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Level = level;
|
||||
EventName = eventName;
|
||||
Fields = fields;
|
||||
}
|
||||
|
||||
public string Level { get; }
|
||||
|
||||
public string EventName { get; }
|
||||
|
||||
public IReadOnlyDictionary<string, object?> Fields { get; }
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
using System.Collections.Generic;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
internal sealed class MemoryWorkerLogger : IWorkerLogger
|
||||
{
|
||||
public List<MemoryWorkerLogEntry> Entries { get; } = new();
|
||||
|
||||
public void Information(string eventName, IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Entries.Add(new MemoryWorkerLogEntry("Information", eventName, WorkerLogRedactor.RedactFields(fields)));
|
||||
}
|
||||
|
||||
public void Error(string eventName, IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Entries.Add(new MemoryWorkerLogEntry("Error", eventName, WorkerLogRedactor.RedactFields(fields)));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
using MxGateway.Worker.Ipc;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
public sealed class WorkerApplicationTests
|
||||
{
|
||||
[Fact]
|
||||
public void Run_WithValidBootstrapArguments_ReturnsSuccessAndLogsRedactedNonce()
|
||||
{
|
||||
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
|
||||
MemoryWorkerLogger logger = new();
|
||||
|
||||
int exitCode = MxGateway.Worker.WorkerApplication.Run(
|
||||
ValidArgs(),
|
||||
environment,
|
||||
logger,
|
||||
new SucceedingPipeClient());
|
||||
|
||||
Assert.Equal((int)WorkerExitCode.Success, exitCode);
|
||||
Assert.Equal(2, logger.Entries.Count);
|
||||
MemoryWorkerLogEntry entry = logger.Entries[0];
|
||||
Assert.Equal("Information", entry.Level);
|
||||
Assert.Equal("WorkerBootstrapSucceeded", entry.EventName);
|
||||
Assert.Equal("session-1", entry.Fields["session_id"]);
|
||||
Assert.Equal("mxaccess-gateway-123-session-1", entry.Fields["pipe_name"]);
|
||||
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, entry.Fields["protocol_version"]);
|
||||
Assert.Equal("[redacted]", entry.Fields["nonce"]);
|
||||
Assert.Equal("WorkerPipeHandshakeSucceeded", logger.Entries[1].EventName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Run_WithMissingRequiredArguments_ReturnsInvalidArguments()
|
||||
{
|
||||
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
|
||||
MemoryWorkerLogger logger = new();
|
||||
|
||||
int exitCode = MxGateway.Worker.WorkerApplication.Run(
|
||||
[],
|
||||
environment,
|
||||
logger);
|
||||
|
||||
Assert.Equal((int)WorkerExitCode.InvalidArguments, exitCode);
|
||||
MemoryWorkerLogEntry entry = Assert.Single(logger.Entries);
|
||||
Assert.Equal("Error", entry.Level);
|
||||
Assert.Equal("WorkerBootstrapFailed", entry.EventName);
|
||||
Assert.Equal(WorkerExitCode.InvalidArguments, entry.Fields["exit_code"]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Run_WithInvalidProtocolVersion_ReturnsInvalidProtocolVersion()
|
||||
{
|
||||
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
|
||||
MemoryWorkerLogger logger = new();
|
||||
|
||||
int exitCode = MxGateway.Worker.WorkerApplication.Run(
|
||||
ValidArgs(protocolVersion: "999"),
|
||||
environment,
|
||||
logger);
|
||||
|
||||
Assert.Equal((int)WorkerExitCode.InvalidProtocolVersion, exitCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Run_WithMissingNonce_ReturnsMissingNonce()
|
||||
{
|
||||
MemoryWorkerEnvironment environment = new();
|
||||
MemoryWorkerLogger logger = new();
|
||||
|
||||
int exitCode = MxGateway.Worker.WorkerApplication.Run(
|
||||
ValidArgs(),
|
||||
environment,
|
||||
logger);
|
||||
|
||||
Assert.Equal((int)WorkerExitCode.MissingNonce, exitCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Run_WithPipeProtocolFailure_ReturnsProtocolViolation()
|
||||
{
|
||||
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
|
||||
MemoryWorkerLogger logger = new();
|
||||
|
||||
int exitCode = MxGateway.Worker.WorkerApplication.Run(
|
||||
ValidArgs(),
|
||||
environment,
|
||||
logger,
|
||||
new ThrowingPipeClient(new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.NonceMismatch,
|
||||
"Bad nonce.")));
|
||||
|
||||
Assert.Equal((int)WorkerExitCode.ProtocolViolation, exitCode);
|
||||
Assert.Equal("WorkerPipeProtocolFailure", logger.Entries[1].EventName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Run_WithUnexpectedBootstrapFailure_ReturnsUnexpectedFailure()
|
||||
{
|
||||
MemoryWorkerEnvironment environment = new(new InvalidOperationException("environment failed"));
|
||||
MemoryWorkerLogger logger = new();
|
||||
|
||||
int exitCode = MxGateway.Worker.WorkerApplication.Run(
|
||||
ValidArgs(),
|
||||
environment,
|
||||
logger);
|
||||
|
||||
Assert.Equal((int)WorkerExitCode.UnexpectedFailure, exitCode);
|
||||
MemoryWorkerLogEntry entry = Assert.Single(logger.Entries);
|
||||
Assert.Equal("WorkerBootstrapUnexpectedFailure", entry.EventName);
|
||||
Assert.Equal(WorkerExitCode.UnexpectedFailure, entry.Fields["exit_code"]);
|
||||
Assert.Equal(typeof(InvalidOperationException).FullName, entry.Fields["exception_type"]);
|
||||
}
|
||||
|
||||
private static string[] ValidArgs(string? protocolVersion = null)
|
||||
{
|
||||
return
|
||||
[
|
||||
"--session-id",
|
||||
"session-1",
|
||||
"--pipe-name",
|
||||
"mxaccess-gateway-123-session-1",
|
||||
"--protocol-version",
|
||||
protocolVersion ?? GatewayContractInfo.WorkerProtocolVersion.ToString(),
|
||||
];
|
||||
}
|
||||
|
||||
private static MemoryWorkerEnvironment CreateEnvironment(string nonce)
|
||||
{
|
||||
MemoryWorkerEnvironment environment = new();
|
||||
environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce);
|
||||
return environment;
|
||||
}
|
||||
|
||||
private sealed class SucceedingPipeClient : IWorkerPipeClient
|
||||
{
|
||||
public Task RunAsync(
|
||||
WorkerOptions options,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class ThrowingPipeClient : IWorkerPipeClient
|
||||
{
|
||||
private readonly Exception _exception;
|
||||
|
||||
public ThrowingPipeClient(Exception exception)
|
||||
{
|
||||
_exception = exception;
|
||||
}
|
||||
|
||||
public Task RunAsync(
|
||||
WorkerOptions options,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
throw _exception;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
public sealed class WorkerConsoleLoggerTests
|
||||
{
|
||||
[Fact]
|
||||
public void Information_RedactsNonceInStructuredOutput()
|
||||
{
|
||||
StringWriter writer = new();
|
||||
WorkerConsoleLogger logger = new(writer);
|
||||
|
||||
logger.Information("WorkerBootstrapSucceeded", new Dictionary<string, object?>
|
||||
{
|
||||
["session_id"] = "session-1",
|
||||
["nonce"] = "nonce-secret",
|
||||
});
|
||||
|
||||
string output = writer.ToString();
|
||||
|
||||
Assert.Contains("event=WorkerBootstrapSucceeded", output);
|
||||
Assert.Contains("session_id=session-1", output);
|
||||
Assert.Contains("nonce=[redacted]", output);
|
||||
Assert.DoesNotContain("nonce-secret", output);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
using System.Collections.Generic;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
public sealed class WorkerLogRedactorTests
|
||||
{
|
||||
[Fact]
|
||||
public void RedactFields_RedactsNonceSecretPasswordTokenCredentialAndApiKeyFields()
|
||||
{
|
||||
Dictionary<string, object?> fields = new()
|
||||
{
|
||||
["nonce"] = "nonce-secret",
|
||||
["client_secret"] = "secret",
|
||||
["password"] = "password",
|
||||
["auth_token"] = "token",
|
||||
["credential_value"] = "credential",
|
||||
["api_key"] = "key",
|
||||
["session_id"] = "session-1",
|
||||
};
|
||||
|
||||
Dictionary<string, object?> redacted = WorkerLogRedactor.RedactFields(fields);
|
||||
|
||||
Assert.Equal("[redacted]", redacted["nonce"]);
|
||||
Assert.Equal("[redacted]", redacted["client_secret"]);
|
||||
Assert.Equal("[redacted]", redacted["password"]);
|
||||
Assert.Equal("[redacted]", redacted["auth_token"]);
|
||||
Assert.Equal("[redacted]", redacted["credential_value"]);
|
||||
Assert.Equal("[redacted]", redacted["api_key"]);
|
||||
Assert.Equal("session-1", redacted["session_id"]);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Bootstrap;
|
||||
|
||||
public sealed class WorkerOptionsParserTests
|
||||
{
|
||||
[Fact]
|
||||
public void Parse_WithAllRequiredInputs_ReturnsWorkerOptions()
|
||||
{
|
||||
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
|
||||
|
||||
WorkerBootstrapResult result = parser.Parse(ValidArgs());
|
||||
|
||||
Assert.True(result.Succeeded);
|
||||
Assert.Equal(WorkerExitCode.Success, result.ExitCode);
|
||||
Assert.NotNull(result.Options);
|
||||
Assert.Equal("session-1", result.Options.SessionId);
|
||||
Assert.Equal("mxaccess-gateway-123-session-1", result.Options.PipeName);
|
||||
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, result.Options.ProtocolVersion);
|
||||
Assert.Equal("nonce-secret", result.Options.Nonce);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_WithMissingSessionId_ReturnsInvalidArguments()
|
||||
{
|
||||
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
|
||||
|
||||
WorkerBootstrapResult result = parser.Parse(
|
||||
[
|
||||
"--pipe-name",
|
||||
"mxaccess-gateway-123-session-1",
|
||||
"--protocol-version",
|
||||
GatewayContractInfo.WorkerProtocolVersion.ToString(),
|
||||
]);
|
||||
|
||||
Assert.False(result.Succeeded);
|
||||
Assert.Equal(WorkerExitCode.InvalidArguments, result.ExitCode);
|
||||
Assert.Contains(result.Errors, error => error.Contains("--session-id"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_WithUnknownOption_ReturnsInvalidArguments()
|
||||
{
|
||||
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
|
||||
|
||||
WorkerBootstrapResult result = parser.Parse(
|
||||
[
|
||||
"--session-id",
|
||||
"session-1",
|
||||
"--pipe-name",
|
||||
"mxaccess-gateway-123-session-1",
|
||||
"--protocol-version",
|
||||
GatewayContractInfo.WorkerProtocolVersion.ToString(),
|
||||
"--unexpected",
|
||||
"value",
|
||||
]);
|
||||
|
||||
Assert.Equal(WorkerExitCode.InvalidArguments, result.ExitCode);
|
||||
Assert.Contains(result.Errors, error => error.Contains("Unknown option"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_WithNonNumericProtocolVersion_ReturnsInvalidProtocolVersion()
|
||||
{
|
||||
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
|
||||
|
||||
WorkerBootstrapResult result = parser.Parse(ValidArgs(protocolVersion: "abc"));
|
||||
|
||||
Assert.False(result.Succeeded);
|
||||
Assert.Equal(WorkerExitCode.InvalidProtocolVersion, result.ExitCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_WithUnsupportedProtocolVersion_ReturnsInvalidProtocolVersion()
|
||||
{
|
||||
WorkerOptionsParser parser = new(CreateEnvironment("nonce-secret"));
|
||||
|
||||
WorkerBootstrapResult result = parser.Parse(ValidArgs(protocolVersion: "999"));
|
||||
|
||||
Assert.False(result.Succeeded);
|
||||
Assert.Equal(WorkerExitCode.InvalidProtocolVersion, result.ExitCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_WithMissingNonce_ReturnsMissingNonce()
|
||||
{
|
||||
WorkerOptionsParser parser = new(new MemoryWorkerEnvironment());
|
||||
|
||||
WorkerBootstrapResult result = parser.Parse(ValidArgs());
|
||||
|
||||
Assert.False(result.Succeeded);
|
||||
Assert.Equal(WorkerExitCode.MissingNonce, result.ExitCode);
|
||||
}
|
||||
|
||||
private static string[] ValidArgs(string? protocolVersion = null)
|
||||
{
|
||||
return
|
||||
[
|
||||
"--session-id",
|
||||
"session-1",
|
||||
"--pipe-name",
|
||||
"mxaccess-gateway-123-session-1",
|
||||
"--protocol-version",
|
||||
protocolVersion ?? GatewayContractInfo.WorkerProtocolVersion.ToString(),
|
||||
];
|
||||
}
|
||||
|
||||
private static MemoryWorkerEnvironment CreateEnvironment(string nonce)
|
||||
{
|
||||
MemoryWorkerEnvironment environment = new();
|
||||
environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce);
|
||||
return environment;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
using System;
|
||||
using Google.Protobuf;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
using MxGateway.Worker.Conversion;
|
||||
using ProtobufTimestamp = Google.Protobuf.WellKnownTypes.Timestamp;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Conversion;
|
||||
|
||||
public sealed class VariantConverterTests
|
||||
{
|
||||
private readonly VariantConverter _converter = new();
|
||||
|
||||
[Theory]
|
||||
[InlineData(true, MxDataType.Boolean, MxValue.KindOneofCase.BoolValue)]
|
||||
[InlineData(42, MxDataType.Integer, MxValue.KindOneofCase.Int32Value)]
|
||||
[InlineData(42L, MxDataType.Integer, MxValue.KindOneofCase.Int64Value)]
|
||||
[InlineData(1.25f, MxDataType.Float, MxValue.KindOneofCase.FloatValue)]
|
||||
[InlineData(2.5d, MxDataType.Double, MxValue.KindOneofCase.DoubleValue)]
|
||||
[InlineData("value", MxDataType.String, MxValue.KindOneofCase.StringValue)]
|
||||
public void Convert_WithSupportedScalar_ProjectsTypedValue(
|
||||
object value,
|
||||
MxDataType expectedDataType,
|
||||
MxValue.KindOneofCase expectedKind)
|
||||
{
|
||||
MxValue converted = _converter.Convert(value);
|
||||
|
||||
Assert.Equal(expectedDataType, converted.DataType);
|
||||
Assert.Equal(expectedKind, converted.KindCase);
|
||||
Assert.False(string.IsNullOrWhiteSpace(converted.VariantType));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Convert_WithDateTime_ProjectsTimestamp()
|
||||
{
|
||||
DateTime dateTime = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
|
||||
|
||||
MxValue converted = _converter.Convert(dateTime);
|
||||
|
||||
Assert.Equal(MxDataType.Time, converted.DataType);
|
||||
Assert.Equal(ProtobufTimestamp.FromDateTime(dateTime), converted.TimestampValue);
|
||||
Assert.Equal("VT_DATE", converted.VariantType);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Convert_WithFileTimeAndExpectedTime_ProjectsTimestamp()
|
||||
{
|
||||
DateTime dateTime = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
|
||||
|
||||
MxValue converted = _converter.Convert(dateTime.ToFileTimeUtc(), MxDataType.Time);
|
||||
|
||||
Assert.Equal(MxDataType.Time, converted.DataType);
|
||||
Assert.Equal(ProtobufTimestamp.FromDateTime(dateTime), converted.TimestampValue);
|
||||
Assert.Equal("VT_I8", converted.VariantType);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(null, "VT_EMPTY")]
|
||||
[InlineData(typeof(DBNull), "VT_NULL")]
|
||||
public void Convert_WithNullLikeValue_PreservesNull(
|
||||
object? value,
|
||||
string expectedVariantType)
|
||||
{
|
||||
object? actualValue = value is System.Type ? DBNull.Value : value;
|
||||
|
||||
MxValue converted = _converter.Convert(actualValue);
|
||||
|
||||
Assert.True(converted.IsNull);
|
||||
Assert.Equal(MxDataType.NoData, converted.DataType);
|
||||
Assert.Equal(expectedVariantType, converted.VariantType);
|
||||
Assert.Equal(MxValue.KindOneofCase.None, converted.KindCase);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConvertArray_WithSupportedArrays_ProjectsTypedValuesAndDimensions()
|
||||
{
|
||||
MxValue bools = _converter.Convert(new[] { true, false });
|
||||
MxValue ints = _converter.Convert(new[] { 1, 2, 3 });
|
||||
MxValue floats = _converter.Convert(new[] { 1.25f, 2.5f });
|
||||
MxValue doubles = _converter.Convert(new[] { 1.25d, 2.5d });
|
||||
MxValue strings = _converter.Convert(new[] { "one", "two" });
|
||||
MxValue times = _converter.Convert(new[]
|
||||
{
|
||||
new DateTime(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc),
|
||||
new DateTime(2026, 4, 26, 17, 46, 0, DateTimeKind.Utc),
|
||||
});
|
||||
|
||||
Assert.Equal(new[] { true, false }, bools.ArrayValue.BoolValues.Values);
|
||||
Assert.Equal(new[] { 1, 2, 3 }, ints.ArrayValue.Int32Values.Values);
|
||||
Assert.Equal(new[] { 1.25f, 2.5f }, floats.ArrayValue.FloatValues.Values);
|
||||
Assert.Equal(new[] { 1.25d, 2.5d }, doubles.ArrayValue.DoubleValues.Values);
|
||||
Assert.Equal(new[] { "one", "two" }, strings.ArrayValue.StringValues.Values);
|
||||
Assert.Equal(2, times.ArrayValue.TimestampValues.Values.Count);
|
||||
Assert.Equal(new uint[] { 2 }, bools.ArrayValue.Dimensions);
|
||||
Assert.Equal(MxDataType.Boolean, bools.ArrayValue.ElementDataType);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConvertArray_WithMultidimensionalArray_PreservesRankAndDimensions()
|
||||
{
|
||||
int[,] values =
|
||||
{
|
||||
{ 1, 2, 3 },
|
||||
{ 4, 5, 6 },
|
||||
};
|
||||
|
||||
MxValue converted = _converter.Convert(values);
|
||||
|
||||
Assert.Equal(new uint[] { 2, 3 }, converted.ArrayValue.Dimensions);
|
||||
Assert.Equal(new[] { 1, 2, 3, 4, 5, 6 }, converted.ArrayValue.Int32Values.Values);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConvertArray_WithExpectedTimeAndFileTimeValues_ProjectsTimestampArray()
|
||||
{
|
||||
DateTime first = new(2026, 4, 26, 17, 45, 0, DateTimeKind.Utc);
|
||||
DateTime second = new(2026, 4, 26, 17, 46, 0, DateTimeKind.Utc);
|
||||
|
||||
MxValue converted = _converter.Convert(
|
||||
new[] { first.ToFileTimeUtc(), second.ToFileTimeUtc() },
|
||||
MxDataType.Time);
|
||||
|
||||
Assert.Equal(MxDataType.Time, converted.ArrayValue.ElementDataType);
|
||||
Assert.Equal(
|
||||
new[] { ProtobufTimestamp.FromDateTime(first), ProtobufTimestamp.FromDateTime(second) },
|
||||
converted.ArrayValue.TimestampValues.Values);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Convert_WithUnknownScalar_PreservesRawMetadata()
|
||||
{
|
||||
UnsupportedVariant value = new("opaque");
|
||||
|
||||
MxValue converted = _converter.Convert(value);
|
||||
|
||||
Assert.Equal(MxDataType.Unknown, converted.DataType);
|
||||
Assert.Equal(MxValue.KindOneofCase.RawValue, converted.KindCase);
|
||||
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.VariantType);
|
||||
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.RawDiagnostic);
|
||||
Assert.Equal(ByteString.CopyFromUtf8("opaque"), converted.RawValue);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConvertArray_WithUnknownArray_PreservesRawMetadata()
|
||||
{
|
||||
UnsupportedVariant[] values =
|
||||
[
|
||||
new("first"),
|
||||
new("second"),
|
||||
];
|
||||
|
||||
MxValue converted = _converter.Convert(values);
|
||||
|
||||
Assert.Equal(MxDataType.Unknown, converted.ArrayValue.ElementDataType);
|
||||
Assert.Equal(MxArray.ValuesOneofCase.RawValues, converted.ArrayValue.ValuesCase);
|
||||
Assert.Equal(new uint[] { 2 }, converted.ArrayValue.Dimensions);
|
||||
Assert.Equal("first", converted.ArrayValue.RawValues.Values[0].ToStringUtf8());
|
||||
Assert.Contains(typeof(UnsupportedVariant).FullName!, converted.ArrayValue.RawDiagnostic);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Redactor_WithCredentialBearingValueFields_RedactsBeforeLogging()
|
||||
{
|
||||
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("credential_value", "secret"));
|
||||
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("password_value", "secret"));
|
||||
Assert.Equal(WorkerLogRedactor.RedactedValue, WorkerLogRedactor.RedactValue("secured_write_token", "secret"));
|
||||
}
|
||||
|
||||
private sealed class UnsupportedVariant
|
||||
{
|
||||
private readonly string _value;
|
||||
|
||||
public UnsupportedVariant(string value)
|
||||
{
|
||||
_value = value;
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return _value;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Protobuf;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.Ipc;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Ipc;
|
||||
|
||||
public sealed class WorkerFrameProtocolTests
|
||||
{
|
||||
private const string SessionId = "session-1";
|
||||
private const string Nonce = "nonce-secret";
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAndReadAsync_WithValidEnvelope_RoundTripsFrame()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream stream = new();
|
||||
WorkerEnvelope original = CreateGatewayHelloEnvelope();
|
||||
|
||||
WorkerFrameWriter writer = new(stream, options);
|
||||
await writer.WriteAsync(original);
|
||||
stream.Position = 0;
|
||||
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
WorkerEnvelope parsed = await reader.ReadAsync();
|
||||
|
||||
Assert.Equal(original, parsed);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_WithWrongProtocolVersion_ThrowsProtocolVersionMismatch()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
|
||||
envelope.ProtocolVersion++;
|
||||
MemoryStream stream = new(CreateFrame(envelope));
|
||||
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await reader.ReadAsync());
|
||||
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_WithWrongSessionId_ThrowsSessionMismatch()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
|
||||
envelope.SessionId = "different-session";
|
||||
MemoryStream stream = new(CreateFrame(envelope));
|
||||
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await reader.ReadAsync());
|
||||
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.SessionMismatch, exception.ErrorCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_WithMalformedLength_ThrowsMalformedLength()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream stream = new(new byte[sizeof(uint)]);
|
||||
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await reader.ReadAsync());
|
||||
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.MalformedLength, exception.ErrorCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_WithMalformedPayload_ThrowsInvalidEnvelope()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream stream = new(CreateFrame(new byte[] { 0x80 }));
|
||||
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await reader.ReadAsync());
|
||||
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_WithConcurrentCalls_SerializesCompleteFrames()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream stream = new();
|
||||
WorkerFrameWriter writer = new(stream, options);
|
||||
|
||||
await Task.WhenAll(
|
||||
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 1)),
|
||||
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 2)),
|
||||
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 3)));
|
||||
|
||||
stream.Position = 0;
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
|
||||
WorkerEnvelope first = await reader.ReadAsync();
|
||||
WorkerEnvelope second = await reader.ReadAsync();
|
||||
WorkerEnvelope third = await reader.ReadAsync();
|
||||
|
||||
Assert.Equal(new ulong[] { 1, 2, 3 }, new[] { first.Sequence, second.Sequence, third.Sequence }.OrderBy(sequence => sequence));
|
||||
}
|
||||
|
||||
private static WorkerFrameProtocolOptions CreateOptions()
|
||||
{
|
||||
return new WorkerFrameProtocolOptions(
|
||||
SessionId,
|
||||
GatewayContractInfo.WorkerProtocolVersion,
|
||||
Nonce);
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateGatewayHelloEnvelope(ulong sequence = 1)
|
||||
{
|
||||
return new WorkerEnvelope
|
||||
{
|
||||
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
SessionId = SessionId,
|
||||
Sequence = sequence,
|
||||
GatewayHello = new GatewayHello
|
||||
{
|
||||
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
Nonce = Nonce,
|
||||
GatewayVersion = "test-gateway",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private static byte[] CreateFrame(IMessage message)
|
||||
{
|
||||
return CreateFrame(message.ToByteArray());
|
||||
}
|
||||
|
||||
private static byte[] CreateFrame(byte[] payload)
|
||||
{
|
||||
byte[] frame = new byte[sizeof(uint) + payload.Length];
|
||||
WriteUInt32LittleEndian(frame, (uint)payload.Length);
|
||||
payload.CopyTo(frame, sizeof(uint));
|
||||
|
||||
return frame;
|
||||
}
|
||||
|
||||
private static void WriteUInt32LittleEndian(
|
||||
byte[] buffer,
|
||||
uint value)
|
||||
{
|
||||
buffer[0] = (byte)value;
|
||||
buffer[1] = (byte)(value >> 8);
|
||||
buffer[2] = (byte)(value >> 16);
|
||||
buffer[3] = (byte)(value >> 24);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
using System;
|
||||
using System.IO.Pipes;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
using MxGateway.Worker.Ipc;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Ipc;
|
||||
|
||||
public sealed class WorkerPipeClientTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task RunAsync_ConnectsToPipeAndCompletesHandshake()
|
||||
{
|
||||
string pipeName = $"mxaccess-gateway-test-{Guid.NewGuid():N}";
|
||||
WorkerOptions workerOptions = new(
|
||||
"session-1",
|
||||
pipeName,
|
||||
GatewayContractInfo.WorkerProtocolVersion,
|
||||
"nonce-secret");
|
||||
WorkerFrameProtocolOptions frameOptions = new(workerOptions);
|
||||
|
||||
using NamedPipeServerStream server = new(
|
||||
pipeName,
|
||||
PipeDirection.InOut,
|
||||
1,
|
||||
PipeTransmissionMode.Byte,
|
||||
PipeOptions.Asynchronous);
|
||||
|
||||
WorkerPipeClient client = new(connectTimeoutMilliseconds: 5000);
|
||||
Task clientTask = client.RunAsync(workerOptions);
|
||||
|
||||
await Task.Factory.FromAsync(server.BeginWaitForConnection, server.EndWaitForConnection, null);
|
||||
|
||||
WorkerFrameReader reader = new(server, frameOptions);
|
||||
WorkerFrameWriter writer = new(server, frameOptions);
|
||||
|
||||
await writer.WriteAsync(new WorkerEnvelope
|
||||
{
|
||||
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
SessionId = "session-1",
|
||||
Sequence = 1,
|
||||
GatewayHello = new GatewayHello
|
||||
{
|
||||
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
Nonce = "nonce-secret",
|
||||
GatewayVersion = "test-gateway",
|
||||
},
|
||||
});
|
||||
|
||||
WorkerEnvelope hello = await reader.ReadAsync();
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, hello.BodyCase);
|
||||
Assert.Equal("nonce-secret", hello.WorkerHello.Nonce);
|
||||
|
||||
WorkerEnvelope ready = await reader.ReadAsync();
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, ready.BodyCase);
|
||||
|
||||
await clientTask;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.Ipc;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Ipc;
|
||||
|
||||
public sealed class WorkerPipeSessionTests
|
||||
{
|
||||
private const string SessionId = "session-1";
|
||||
private const string Nonce = "nonce-secret";
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream inbound = new();
|
||||
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope());
|
||||
inbound.Position = 0;
|
||||
MemoryStream outbound = new();
|
||||
WorkerPipeSession session = CreateSession(inbound, outbound, options);
|
||||
bool initialized = false;
|
||||
|
||||
await session.CompleteStartupHandshakeAsync(
|
||||
_ =>
|
||||
{
|
||||
initialized = true;
|
||||
return Task.CompletedTask;
|
||||
});
|
||||
|
||||
Assert.True(initialized);
|
||||
WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
|
||||
Assert.Equal(2, written.Length);
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
|
||||
Assert.Equal(Nonce, written[0].WorkerHello.Nonce);
|
||||
Assert.Equal(1234, written[1].WorkerReady.WorkerProcessId);
|
||||
Assert.Equal(MxGateway.Worker.MxAccess.MxAccessInteropInfo.ProgId, written[1].WorkerReady.MxaccessProgid);
|
||||
Assert.Equal(MxGateway.Worker.MxAccess.MxAccessInteropInfo.Clsid, written[1].WorkerReady.MxaccessClsid);
|
||||
Assert.NotNull(written[1].WorkerReady.ReadyTimestamp);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream inbound = new();
|
||||
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(nonce: "wrong"));
|
||||
inbound.Position = 0;
|
||||
MemoryStream outbound = new();
|
||||
WorkerPipeSession session = CreateSession(inbound, outbound, options);
|
||||
bool initialized = false;
|
||||
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await session.CompleteStartupHandshakeAsync(
|
||||
_ =>
|
||||
{
|
||||
initialized = true;
|
||||
return Task.CompletedTask;
|
||||
}));
|
||||
|
||||
Assert.False(initialized);
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode);
|
||||
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
|
||||
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream inbound = new();
|
||||
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(supportedProtocolVersion: 999));
|
||||
inbound.Position = 0;
|
||||
MemoryStream outbound = new();
|
||||
WorkerPipeSession session = CreateSession(inbound, outbound, options);
|
||||
bool initialized = false;
|
||||
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await session.CompleteStartupHandshakeAsync(
|
||||
_ =>
|
||||
{
|
||||
initialized = true;
|
||||
return Task.CompletedTask;
|
||||
}));
|
||||
|
||||
Assert.False(initialized);
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
|
||||
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
|
||||
Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault()
|
||||
{
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 }));
|
||||
MemoryStream outbound = new();
|
||||
WorkerPipeSession session = CreateSession(inbound, outbound, options);
|
||||
bool initialized = false;
|
||||
|
||||
WorkerFrameProtocolException exception =
|
||||
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
|
||||
async () => await session.CompleteStartupHandshakeAsync(
|
||||
_ =>
|
||||
{
|
||||
initialized = true;
|
||||
return Task.CompletedTask;
|
||||
}));
|
||||
|
||||
Assert.False(initialized);
|
||||
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
|
||||
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
|
||||
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteStartupHandshakeAsync_WhenMxAccessCreationFails_WritesFaultInsteadOfReady()
|
||||
{
|
||||
const int hresult = unchecked((int)0x80040154);
|
||||
WorkerFrameProtocolOptions options = CreateOptions();
|
||||
MemoryStream inbound = new();
|
||||
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope());
|
||||
inbound.Position = 0;
|
||||
MemoryStream outbound = new();
|
||||
WorkerPipeSession session = CreateSession(inbound, outbound, options);
|
||||
|
||||
await Assert.ThrowsAsync<COMException>(
|
||||
async () => await session.CompleteStartupHandshakeAsync(
|
||||
_ => Task.FromException<WorkerReady>(new COMException("Class not registered.", hresult))));
|
||||
|
||||
WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
|
||||
Assert.Equal(2, written.Length);
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
|
||||
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, written[1].BodyCase);
|
||||
Assert.Equal(WorkerFaultCategory.MxaccessCreationFailed, written[1].WorkerFault.Category);
|
||||
Assert.Equal(hresult, written[1].WorkerFault.Hresult);
|
||||
Assert.Equal(typeof(COMException).FullName, written[1].WorkerFault.ExceptionType);
|
||||
Assert.Equal(ProtocolStatusCode.WorkerUnavailable, written[1].WorkerFault.ProtocolStatus.Code);
|
||||
}
|
||||
|
||||
private static WorkerPipeSession CreateSession(
|
||||
Stream inbound,
|
||||
Stream outbound,
|
||||
WorkerFrameProtocolOptions options)
|
||||
{
|
||||
return new WorkerPipeSession(
|
||||
new WorkerFrameReader(inbound, options),
|
||||
new WorkerFrameWriter(outbound, options),
|
||||
options,
|
||||
() => 1234);
|
||||
}
|
||||
|
||||
private static WorkerFrameProtocolOptions CreateOptions()
|
||||
{
|
||||
return new WorkerFrameProtocolOptions(
|
||||
SessionId,
|
||||
GatewayContractInfo.WorkerProtocolVersion,
|
||||
Nonce);
|
||||
}
|
||||
|
||||
private static WorkerEnvelope CreateGatewayHelloEnvelope(
|
||||
string nonce = Nonce,
|
||||
uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion)
|
||||
{
|
||||
return new WorkerEnvelope
|
||||
{
|
||||
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
|
||||
SessionId = SessionId,
|
||||
Sequence = 1,
|
||||
GatewayHello = new GatewayHello
|
||||
{
|
||||
SupportedProtocolVersion = supportedProtocolVersion,
|
||||
Nonce = nonce,
|
||||
GatewayVersion = "test-gateway",
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private static WorkerEnvelope[] ReadWrittenFrames(
|
||||
MemoryStream stream,
|
||||
WorkerFrameProtocolOptions options)
|
||||
{
|
||||
stream.Position = 0;
|
||||
WorkerFrameReader reader = new(stream, options);
|
||||
List<WorkerEnvelope> envelopes = new();
|
||||
|
||||
while (stream.Position < stream.Length)
|
||||
{
|
||||
envelopes.Add(reader.ReadAsync(CancellationToken.None).GetAwaiter().GetResult());
|
||||
}
|
||||
|
||||
return envelopes.ToArray();
|
||||
}
|
||||
|
||||
private static byte[] CreateFrame(byte[] payload)
|
||||
{
|
||||
byte[] frame = new byte[sizeof(uint) + payload.Length];
|
||||
WriteUInt32LittleEndian(frame, (uint)payload.Length);
|
||||
payload.CopyTo(frame, sizeof(uint));
|
||||
|
||||
return frame;
|
||||
}
|
||||
|
||||
private static void WriteUInt32LittleEndian(
|
||||
byte[] buffer,
|
||||
uint value)
|
||||
{
|
||||
buffer[0] = (byte)value;
|
||||
buffer[1] = (byte)(value >> 8);
|
||||
buffer[2] = (byte)(value >> 16);
|
||||
buffer[3] = (byte)(value >> 24);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Worker.MxAccess;
|
||||
|
||||
namespace MxGateway.Worker.Tests.MxAccess;
|
||||
|
||||
public sealed class MxAccessLiveComCreationTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task StartAsync_WhenOptedIn_CreatesInstalledMxAccessComObjectOnSta()
|
||||
{
|
||||
if (!string.Equals(
|
||||
Environment.GetEnvironmentVariable("MXGATEWAY_RUN_LIVE_MXACCESS_TESTS"),
|
||||
"1",
|
||||
StringComparison.Ordinal))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
using MxAccessStaSession session = new();
|
||||
|
||||
await session.StartAsync(workerProcessId: 1234);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.MxAccess;
|
||||
using MxGateway.Worker.Sta;
|
||||
|
||||
namespace MxGateway.Worker.Tests.MxAccess;
|
||||
|
||||
public sealed class MxAccessStaSessionTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task StartAsync_CreatesComObjectAndAttachesEventSinkOnStaThread()
|
||||
{
|
||||
FakeMxAccessComObjectFactory factory = new();
|
||||
FakeMxAccessEventSink eventSink = new();
|
||||
using StaRuntime runtime = CreateRuntime();
|
||||
using MxAccessStaSession session = new(runtime, factory, eventSink);
|
||||
|
||||
WorkerReady ready = await session.StartAsync(workerProcessId: 1234);
|
||||
|
||||
Assert.Equal(1234, ready.WorkerProcessId);
|
||||
Assert.Equal(MxAccessInteropInfo.ProgId, ready.MxaccessProgid);
|
||||
Assert.Equal(MxAccessInteropInfo.Clsid, ready.MxaccessClsid);
|
||||
Assert.NotNull(ready.ReadyTimestamp);
|
||||
Assert.Equal(runtime.StaThreadId, factory.CreateThreadId);
|
||||
Assert.Equal(runtime.StaThreadId, eventSink.AttachThreadId);
|
||||
Assert.Equal(ApartmentState.STA, factory.CreateApartmentState);
|
||||
Assert.Same(factory.CreatedObject, eventSink.AttachedObject);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StartAsync_WhenFactoryFails_MapsCreationExceptionWithHResult()
|
||||
{
|
||||
const int hresult = unchecked((int)0x80040154);
|
||||
FakeMxAccessComObjectFactory factory = new(new COMException("Class not registered.", hresult));
|
||||
FakeMxAccessEventSink eventSink = new();
|
||||
using StaRuntime runtime = CreateRuntime();
|
||||
using MxAccessStaSession session = new(runtime, factory, eventSink);
|
||||
|
||||
MxAccessCreationException exception = await Assert.ThrowsAsync<MxAccessCreationException>(
|
||||
() => session.StartAsync(workerProcessId: 1234));
|
||||
|
||||
Assert.Equal(hresult, exception.CapturedHResult);
|
||||
Assert.Equal(MxAccessInteropInfo.ProgId, exception.AttemptedProgId);
|
||||
Assert.Equal(MxAccessInteropInfo.Clsid, exception.AttemptedClsid);
|
||||
Assert.Null(eventSink.AttachedObject);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Dispose_DetachesEventSinkOnStaThread()
|
||||
{
|
||||
FakeMxAccessComObjectFactory factory = new();
|
||||
FakeMxAccessEventSink eventSink = new();
|
||||
using StaRuntime runtime = CreateRuntime();
|
||||
MxAccessStaSession session = new(runtime, factory, eventSink);
|
||||
await session.StartAsync(workerProcessId: 1234);
|
||||
|
||||
session.Dispose();
|
||||
|
||||
Assert.Equal(runtime.StaThreadId, eventSink.DetachThreadId);
|
||||
}
|
||||
|
||||
private static StaRuntime CreateRuntime()
|
||||
{
|
||||
return new StaRuntime(
|
||||
new NoopComApartmentInitializer(),
|
||||
new StaMessagePump(),
|
||||
TimeSpan.FromMilliseconds(25));
|
||||
}
|
||||
|
||||
private sealed class FakeMxAccessComObjectFactory : IMxAccessComObjectFactory
|
||||
{
|
||||
private readonly Exception? exception;
|
||||
|
||||
public FakeMxAccessComObjectFactory(Exception? exception = null)
|
||||
{
|
||||
this.exception = exception;
|
||||
}
|
||||
|
||||
public object CreatedObject { get; } = new();
|
||||
|
||||
public int? CreateThreadId { get; private set; }
|
||||
|
||||
public ApartmentState? CreateApartmentState { get; private set; }
|
||||
|
||||
public object Create()
|
||||
{
|
||||
CreateThreadId = Thread.CurrentThread.ManagedThreadId;
|
||||
CreateApartmentState = Thread.CurrentThread.GetApartmentState();
|
||||
|
||||
if (exception is not null)
|
||||
{
|
||||
throw exception;
|
||||
}
|
||||
|
||||
return CreatedObject;
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class FakeMxAccessEventSink : IMxAccessEventSink
|
||||
{
|
||||
public object? AttachedObject { get; private set; }
|
||||
|
||||
public int? AttachThreadId { get; private set; }
|
||||
|
||||
public int? DetachThreadId { get; private set; }
|
||||
|
||||
public void Attach(object mxAccessComObject)
|
||||
{
|
||||
AttachedObject = mxAccessComObject;
|
||||
AttachThreadId = Thread.CurrentThread.ManagedThreadId;
|
||||
}
|
||||
|
||||
public void Detach()
|
||||
{
|
||||
DetachThreadId = Thread.CurrentThread.ManagedThreadId;
|
||||
AttachedObject = null;
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
|
||||
{
|
||||
public void Initialize()
|
||||
{
|
||||
}
|
||||
|
||||
public void Uninitialize()
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Worker.Sta;
|
||||
|
||||
namespace MxGateway.Worker.Tests.Sta;
|
||||
|
||||
public sealed class StaRuntimeTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task InvokeAsync_ExecutesCommandOnStaThread()
|
||||
{
|
||||
RecordingComApartmentInitializer initializer = new();
|
||||
using StaRuntime runtime = CreateRuntime(initializer);
|
||||
|
||||
runtime.Start();
|
||||
|
||||
StaCommandObservation observation = await runtime.InvokeAsync(
|
||||
() => new StaCommandObservation(
|
||||
Thread.CurrentThread.ManagedThreadId,
|
||||
Thread.CurrentThread.GetApartmentState()));
|
||||
|
||||
Assert.Equal(runtime.StaThreadId, observation.ThreadId);
|
||||
Assert.Equal(initializer.InitializeThreadId, observation.ThreadId);
|
||||
Assert.Equal(ApartmentState.STA, observation.ApartmentState);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_WakesIdlePumpForQueuedCommand()
|
||||
{
|
||||
RecordingComApartmentInitializer initializer = new();
|
||||
using StaRuntime runtime = new(
|
||||
initializer,
|
||||
new StaMessagePump(),
|
||||
TimeSpan.FromSeconds(30));
|
||||
runtime.Start();
|
||||
Stopwatch stopwatch = Stopwatch.StartNew();
|
||||
|
||||
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
|
||||
|
||||
stopwatch.Stop();
|
||||
Assert.Equal(runtime.StaThreadId, threadId);
|
||||
Assert.True(
|
||||
stopwatch.Elapsed < TimeSpan.FromSeconds(2),
|
||||
$"Command took {stopwatch.Elapsed} to execute, so the command wake event did not wake the STA promptly.");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Shutdown_StopsThreadAndUninitializesComApartment()
|
||||
{
|
||||
RecordingComApartmentInitializer initializer = new();
|
||||
using StaRuntime runtime = CreateRuntime(initializer);
|
||||
runtime.Start();
|
||||
|
||||
bool stopped = runtime.Shutdown(TimeSpan.FromSeconds(2));
|
||||
|
||||
Assert.True(stopped);
|
||||
Assert.False(runtime.IsRunning);
|
||||
Assert.Equal(1, initializer.InitializeCount);
|
||||
Assert.Equal(1, initializer.UninitializeCount);
|
||||
Assert.Equal(initializer.InitializeThreadId, initializer.UninitializeThreadId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LastActivityUtc_UpdatesWhilePumpIsIdle()
|
||||
{
|
||||
RecordingComApartmentInitializer initializer = new();
|
||||
using StaRuntime runtime = CreateRuntime(initializer);
|
||||
runtime.Start();
|
||||
DateTimeOffset firstActivity = runtime.LastActivityUtc;
|
||||
|
||||
bool updated = SpinWait.SpinUntil(
|
||||
() => runtime.LastActivityUtc > firstActivity,
|
||||
TimeSpan.FromSeconds(2));
|
||||
|
||||
Assert.True(updated);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_CommandException_FaultsReturnedTaskWithoutStoppingRuntime()
|
||||
{
|
||||
RecordingComApartmentInitializer initializer = new();
|
||||
using StaRuntime runtime = CreateRuntime(initializer);
|
||||
runtime.Start();
|
||||
|
||||
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
|
||||
() => runtime.InvokeAsync<int>(() => throw new InvalidOperationException("command failed")));
|
||||
|
||||
int threadId = await runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId);
|
||||
Assert.Equal("command failed", exception.Message);
|
||||
Assert.Equal(runtime.StaThreadId, threadId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAsync_AfterShutdown_ReturnsFaultedTask()
|
||||
{
|
||||
RecordingComApartmentInitializer initializer = new();
|
||||
using StaRuntime runtime = CreateRuntime(initializer);
|
||||
runtime.Start();
|
||||
runtime.Shutdown(TimeSpan.FromSeconds(2));
|
||||
|
||||
InvalidOperationException exception = await Assert.ThrowsAsync<InvalidOperationException>(
|
||||
() => runtime.InvokeAsync(() => Thread.CurrentThread.ManagedThreadId));
|
||||
|
||||
Assert.Contains("shutting down", exception.Message);
|
||||
}
|
||||
|
||||
private static StaRuntime CreateRuntime(RecordingComApartmentInitializer initializer)
|
||||
{
|
||||
return new StaRuntime(
|
||||
initializer,
|
||||
new StaMessagePump(),
|
||||
TimeSpan.FromMilliseconds(25));
|
||||
}
|
||||
|
||||
private sealed class StaCommandObservation
|
||||
{
|
||||
public StaCommandObservation(int threadId, ApartmentState apartmentState)
|
||||
{
|
||||
ThreadId = threadId;
|
||||
ApartmentState = apartmentState;
|
||||
}
|
||||
|
||||
public int ThreadId { get; }
|
||||
|
||||
public ApartmentState ApartmentState { get; }
|
||||
}
|
||||
|
||||
private sealed class RecordingComApartmentInitializer : IStaComApartmentInitializer
|
||||
{
|
||||
public int InitializeCount { get; private set; }
|
||||
|
||||
public int UninitializeCount { get; private set; }
|
||||
|
||||
public int? InitializeThreadId { get; private set; }
|
||||
|
||||
public int? UninitializeThreadId { get; private set; }
|
||||
|
||||
public void Initialize()
|
||||
{
|
||||
InitializeCount++;
|
||||
InitializeThreadId = Thread.CurrentThread.ManagedThreadId;
|
||||
}
|
||||
|
||||
public void Uninitialize()
|
||||
{
|
||||
UninitializeCount++;
|
||||
UninitializeThreadId = Thread.CurrentThread.ManagedThreadId;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
using System;
|
||||
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public sealed class EnvironmentVariableWorkerEnvironment : IWorkerEnvironment
|
||||
{
|
||||
public string? GetEnvironmentVariable(string name)
|
||||
{
|
||||
return Environment.GetEnvironmentVariable(name);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public interface IWorkerEnvironment
|
||||
{
|
||||
string? GetEnvironmentVariable(string name);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public interface IWorkerLogger
|
||||
{
|
||||
void Information(string eventName, IReadOnlyDictionary<string, object?> fields);
|
||||
|
||||
void Error(string eventName, IReadOnlyDictionary<string, object?> fields);
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public sealed class WorkerBootstrapResult
|
||||
{
|
||||
private WorkerBootstrapResult(
|
||||
WorkerExitCode exitCode,
|
||||
WorkerOptions? options,
|
||||
IReadOnlyList<string> errors)
|
||||
{
|
||||
ExitCode = exitCode;
|
||||
Options = options;
|
||||
Errors = errors;
|
||||
}
|
||||
|
||||
public WorkerExitCode ExitCode { get; }
|
||||
|
||||
public WorkerOptions? Options { get; }
|
||||
|
||||
public IReadOnlyList<string> Errors { get; }
|
||||
|
||||
public bool Succeeded => ExitCode == WorkerExitCode.Success;
|
||||
|
||||
public static WorkerBootstrapResult Success(WorkerOptions options)
|
||||
{
|
||||
return new WorkerBootstrapResult(WorkerExitCode.Success, options, []);
|
||||
}
|
||||
|
||||
public static WorkerBootstrapResult Failure(WorkerExitCode exitCode, IEnumerable<string> errors)
|
||||
{
|
||||
return new WorkerBootstrapResult(exitCode, null, errors.ToArray());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public sealed class WorkerConsoleLogger : IWorkerLogger
|
||||
{
|
||||
private readonly TextWriter _writer;
|
||||
|
||||
public WorkerConsoleLogger(TextWriter writer)
|
||||
{
|
||||
_writer = writer ?? throw new ArgumentNullException(nameof(writer));
|
||||
}
|
||||
|
||||
public void Information(string eventName, IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Write("Information", eventName, fields);
|
||||
}
|
||||
|
||||
public void Error(string eventName, IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Write("Error", eventName, fields);
|
||||
}
|
||||
|
||||
private void Write(
|
||||
string level,
|
||||
string eventName,
|
||||
IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Dictionary<string, object?> redactedFields = WorkerLogRedactor.RedactFields(fields);
|
||||
string fieldText = string.Join(
|
||||
" ",
|
||||
redactedFields.Select(field => $"{field.Key}={FormatValue(field.Value)}"));
|
||||
|
||||
_writer.WriteLine($"level={level} event={eventName} {fieldText}".TrimEnd());
|
||||
}
|
||||
|
||||
private static string FormatValue(object? value)
|
||||
{
|
||||
return value?.ToString() ?? string.Empty;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public enum WorkerExitCode
|
||||
{
|
||||
Success = 0,
|
||||
UnexpectedFailure = 1,
|
||||
InvalidArguments = 2,
|
||||
InvalidProtocolVersion = 3,
|
||||
MissingNonce = 4,
|
||||
PipeConnectionFailed = 5,
|
||||
ProtocolViolation = 6,
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public static class WorkerLogRedactor
|
||||
{
|
||||
public const string RedactedValue = "[redacted]";
|
||||
|
||||
private static readonly string[] SensitiveFieldNameParts =
|
||||
[
|
||||
"nonce",
|
||||
"secret",
|
||||
"password",
|
||||
"token",
|
||||
"credential",
|
||||
"apikey",
|
||||
"api_key",
|
||||
];
|
||||
|
||||
public static Dictionary<string, object?> RedactFields(IReadOnlyDictionary<string, object?> fields)
|
||||
{
|
||||
Dictionary<string, object?> redactedFields = [];
|
||||
|
||||
foreach (KeyValuePair<string, object?> field in fields)
|
||||
{
|
||||
redactedFields[field.Key] = RedactValue(field.Key, field.Value);
|
||||
}
|
||||
|
||||
return redactedFields;
|
||||
}
|
||||
|
||||
public static object? RedactValue(string fieldName, object? value)
|
||||
{
|
||||
if (value is null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
foreach (string sensitiveFieldNamePart in SensitiveFieldNameParts)
|
||||
{
|
||||
if (fieldName.IndexOf(sensitiveFieldNamePart, StringComparison.OrdinalIgnoreCase) >= 0)
|
||||
{
|
||||
return RedactedValue;
|
||||
}
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public sealed class WorkerOptions
|
||||
{
|
||||
public const string NonceEnvironmentVariableName = "MXGATEWAY_WORKER_NONCE";
|
||||
|
||||
public WorkerOptions(
|
||||
string sessionId,
|
||||
string pipeName,
|
||||
uint protocolVersion,
|
||||
string nonce)
|
||||
{
|
||||
SessionId = sessionId;
|
||||
PipeName = pipeName;
|
||||
ProtocolVersion = protocolVersion;
|
||||
Nonce = nonce;
|
||||
}
|
||||
|
||||
public string SessionId { get; }
|
||||
|
||||
public string PipeName { get; }
|
||||
|
||||
public uint ProtocolVersion { get; }
|
||||
|
||||
public string Nonce { get; }
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using MxGateway.Contracts;
|
||||
|
||||
namespace MxGateway.Worker.Bootstrap;
|
||||
|
||||
public sealed class WorkerOptionsParser
|
||||
{
|
||||
private const string SessionIdOptionName = "--session-id";
|
||||
private const string PipeNameOptionName = "--pipe-name";
|
||||
private const string ProtocolVersionOptionName = "--protocol-version";
|
||||
|
||||
private readonly IWorkerEnvironment _environment;
|
||||
|
||||
public WorkerOptionsParser(IWorkerEnvironment environment)
|
||||
{
|
||||
_environment = environment ?? throw new ArgumentNullException(nameof(environment));
|
||||
}
|
||||
|
||||
public WorkerBootstrapResult Parse(string[] args)
|
||||
{
|
||||
if (args is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(args));
|
||||
}
|
||||
|
||||
Dictionary<string, string> values = new(StringComparer.OrdinalIgnoreCase);
|
||||
List<string> errors = [];
|
||||
|
||||
for (int index = 0; index < args.Length; index++)
|
||||
{
|
||||
string arg = args[index];
|
||||
if (!IsKnownOption(arg))
|
||||
{
|
||||
errors.Add($"Unknown option '{arg}'.");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (index + 1 >= args.Length || args[index + 1].StartsWith("--", StringComparison.Ordinal))
|
||||
{
|
||||
errors.Add($"Option '{arg}' requires a value.");
|
||||
continue;
|
||||
}
|
||||
|
||||
values[arg] = args[index + 1];
|
||||
index++;
|
||||
}
|
||||
|
||||
string? sessionId = ReadRequired(values, SessionIdOptionName, errors);
|
||||
string? pipeName = ReadRequired(values, PipeNameOptionName, errors);
|
||||
string? protocolVersionText = ReadRequired(values, ProtocolVersionOptionName, errors);
|
||||
|
||||
if (errors.Count > 0)
|
||||
{
|
||||
return WorkerBootstrapResult.Failure(WorkerExitCode.InvalidArguments, errors);
|
||||
}
|
||||
|
||||
if (!uint.TryParse(protocolVersionText, out uint protocolVersion)
|
||||
|| protocolVersion != GatewayContractInfo.WorkerProtocolVersion)
|
||||
{
|
||||
return WorkerBootstrapResult.Failure(
|
||||
WorkerExitCode.InvalidProtocolVersion,
|
||||
[$"Unsupported protocol version '{protocolVersionText}'."]);
|
||||
}
|
||||
|
||||
string? nonce = _environment.GetEnvironmentVariable(WorkerOptions.NonceEnvironmentVariableName);
|
||||
|
||||
if (string.IsNullOrWhiteSpace(nonce))
|
||||
{
|
||||
return WorkerBootstrapResult.Failure(
|
||||
WorkerExitCode.MissingNonce,
|
||||
["Required worker nonce environment variable is missing."]);
|
||||
}
|
||||
|
||||
return WorkerBootstrapResult.Success(new WorkerOptions(
|
||||
sessionId!,
|
||||
pipeName!,
|
||||
protocolVersion,
|
||||
nonce!));
|
||||
}
|
||||
|
||||
private static string? ReadRequired(
|
||||
IReadOnlyDictionary<string, string> values,
|
||||
string optionName,
|
||||
List<string> errors)
|
||||
{
|
||||
if (!values.TryGetValue(optionName, out string value)
|
||||
|| string.IsNullOrWhiteSpace(value))
|
||||
{
|
||||
errors.Add($"Required option '{optionName}' is missing.");
|
||||
return null;
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
private static bool IsKnownOption(string optionName)
|
||||
{
|
||||
return optionName is SessionIdOptionName or PipeNameOptionName or ProtocolVersionOptionName;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,522 @@
|
||||
using System;
|
||||
using System.Globalization;
|
||||
using Google.Protobuf;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Worker.Conversion;
|
||||
|
||||
public sealed class VariantConverter
|
||||
{
|
||||
public MxValue Convert(object? value)
|
||||
{
|
||||
return Convert(value, MxDataType.Unspecified);
|
||||
}
|
||||
|
||||
public MxValue Convert(
|
||||
object? value,
|
||||
MxDataType expectedDataType)
|
||||
{
|
||||
if (value is null || value is DBNull)
|
||||
{
|
||||
return CreateNullValue(value, expectedDataType);
|
||||
}
|
||||
|
||||
if (value is Array array)
|
||||
{
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Unspecified,
|
||||
VariantType = CreateArrayVariantType(array),
|
||||
ArrayValue = ConvertArray(array, expectedDataType),
|
||||
};
|
||||
}
|
||||
|
||||
return ConvertScalar(value, expectedDataType);
|
||||
}
|
||||
|
||||
public MxArray ConvertArray(
|
||||
Array array,
|
||||
MxDataType expectedElementDataType = MxDataType.Unspecified)
|
||||
{
|
||||
if (array is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(array));
|
||||
}
|
||||
|
||||
MxArray mxArray = new()
|
||||
{
|
||||
VariantType = CreateArrayVariantType(array),
|
||||
};
|
||||
|
||||
for (int dimension = 0; dimension < array.Rank; dimension++)
|
||||
{
|
||||
mxArray.Dimensions.Add((uint)array.GetLength(dimension));
|
||||
}
|
||||
|
||||
System.Type? elementType = array.GetType().GetElementType();
|
||||
MxDataType elementDataType = ResolveArrayElementDataType(elementType, expectedElementDataType);
|
||||
mxArray.ElementDataType = elementDataType;
|
||||
|
||||
switch (elementDataType)
|
||||
{
|
||||
case MxDataType.Boolean:
|
||||
mxArray.BoolValues = ConvertBoolArray(array);
|
||||
return mxArray;
|
||||
|
||||
case MxDataType.Integer:
|
||||
if (elementType == typeof(long) || elementType == typeof(ulong))
|
||||
{
|
||||
mxArray.Int64Values = ConvertInt64Array(array);
|
||||
}
|
||||
else
|
||||
{
|
||||
mxArray.Int32Values = ConvertInt32Array(array);
|
||||
}
|
||||
|
||||
return mxArray;
|
||||
|
||||
case MxDataType.Float:
|
||||
mxArray.FloatValues = ConvertFloatArray(array);
|
||||
return mxArray;
|
||||
|
||||
case MxDataType.Double:
|
||||
mxArray.DoubleValues = ConvertDoubleArray(array);
|
||||
return mxArray;
|
||||
|
||||
case MxDataType.String:
|
||||
mxArray.StringValues = ConvertStringArray(array);
|
||||
return mxArray;
|
||||
|
||||
case MxDataType.Time:
|
||||
mxArray.TimestampValues = ConvertTimestampArray(array);
|
||||
return mxArray;
|
||||
|
||||
default:
|
||||
mxArray.ElementDataType = MxDataType.Unknown;
|
||||
mxArray.RawElementDataType = (int)expectedElementDataType;
|
||||
mxArray.RawDiagnostic = CreateRawDiagnostic(array);
|
||||
mxArray.RawValues = ConvertRawArray(array);
|
||||
return mxArray;
|
||||
}
|
||||
}
|
||||
|
||||
private static MxValue ConvertScalar(
|
||||
object value,
|
||||
MxDataType expectedDataType)
|
||||
{
|
||||
System.Type valueType = value.GetType();
|
||||
string variantType = GetVariantTypeName(valueType);
|
||||
|
||||
switch (System.Type.GetTypeCode(valueType))
|
||||
{
|
||||
case TypeCode.Boolean:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Boolean,
|
||||
VariantType = variantType,
|
||||
BoolValue = (bool)value,
|
||||
};
|
||||
|
||||
case TypeCode.Byte:
|
||||
case TypeCode.SByte:
|
||||
case TypeCode.Int16:
|
||||
case TypeCode.UInt16:
|
||||
case TypeCode.Int32:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Integer,
|
||||
VariantType = variantType,
|
||||
Int32Value = System.Convert.ToInt32(value, CultureInfo.InvariantCulture),
|
||||
};
|
||||
|
||||
case TypeCode.UInt32:
|
||||
case TypeCode.Int64:
|
||||
return ConvertInt64Scalar(value, variantType, expectedDataType);
|
||||
|
||||
case TypeCode.UInt64:
|
||||
return ConvertUInt64Scalar((ulong)value, variantType, expectedDataType);
|
||||
|
||||
case TypeCode.Single:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Float,
|
||||
VariantType = variantType,
|
||||
FloatValue = (float)value,
|
||||
};
|
||||
|
||||
case TypeCode.Double:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Double,
|
||||
VariantType = variantType,
|
||||
DoubleValue = (double)value,
|
||||
};
|
||||
|
||||
case TypeCode.Decimal:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Double,
|
||||
VariantType = variantType,
|
||||
DoubleValue = System.Convert.ToDouble(value, CultureInfo.InvariantCulture),
|
||||
RawDiagnostic = "Decimal value projected to double.",
|
||||
};
|
||||
|
||||
case TypeCode.String:
|
||||
case TypeCode.Char:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.String,
|
||||
VariantType = variantType,
|
||||
StringValue = System.Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty,
|
||||
};
|
||||
|
||||
case TypeCode.DateTime:
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Time,
|
||||
VariantType = variantType,
|
||||
TimestampValue = ToTimestamp((DateTime)value),
|
||||
};
|
||||
|
||||
default:
|
||||
return CreateRawValue(value, expectedDataType);
|
||||
}
|
||||
}
|
||||
|
||||
private static MxValue ConvertInt64Scalar(
|
||||
object value,
|
||||
string variantType,
|
||||
MxDataType expectedDataType)
|
||||
{
|
||||
long longValue = System.Convert.ToInt64(value, CultureInfo.InvariantCulture);
|
||||
if (expectedDataType == MxDataType.Time)
|
||||
{
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Time,
|
||||
VariantType = variantType,
|
||||
TimestampValue = Timestamp.FromDateTime(DateTime.FromFileTimeUtc(longValue)),
|
||||
};
|
||||
}
|
||||
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Integer,
|
||||
VariantType = variantType,
|
||||
Int64Value = longValue,
|
||||
};
|
||||
}
|
||||
|
||||
private static MxValue ConvertUInt64Scalar(
|
||||
ulong value,
|
||||
string variantType,
|
||||
MxDataType expectedDataType)
|
||||
{
|
||||
if (expectedDataType == MxDataType.Time && value <= long.MaxValue)
|
||||
{
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Time,
|
||||
VariantType = variantType,
|
||||
TimestampValue = Timestamp.FromDateTime(DateTime.FromFileTimeUtc((long)value)),
|
||||
};
|
||||
}
|
||||
|
||||
if (value <= long.MaxValue)
|
||||
{
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Integer,
|
||||
VariantType = variantType,
|
||||
Int64Value = (long)value,
|
||||
};
|
||||
}
|
||||
|
||||
return CreateRawValue(value, expectedDataType, "UInt64 value exceeds Int64 range.");
|
||||
}
|
||||
|
||||
private static MxValue CreateNullValue(
|
||||
object? value,
|
||||
MxDataType expectedDataType)
|
||||
{
|
||||
return new MxValue
|
||||
{
|
||||
DataType = expectedDataType == MxDataType.Unspecified ? MxDataType.NoData : expectedDataType,
|
||||
VariantType = value is DBNull ? "VT_NULL" : "VT_EMPTY",
|
||||
IsNull = true,
|
||||
};
|
||||
}
|
||||
|
||||
private static MxValue CreateRawValue(
|
||||
object value,
|
||||
MxDataType expectedDataType,
|
||||
string? diagnosticPrefix = null)
|
||||
{
|
||||
string diagnostic = CreateRawDiagnostic(value);
|
||||
if (!string.IsNullOrWhiteSpace(diagnosticPrefix))
|
||||
{
|
||||
diagnostic = $"{diagnosticPrefix} {diagnostic}";
|
||||
}
|
||||
|
||||
return new MxValue
|
||||
{
|
||||
DataType = MxDataType.Unknown,
|
||||
VariantType = GetVariantTypeName(value.GetType()),
|
||||
RawDataType = (int)expectedDataType,
|
||||
RawDiagnostic = diagnostic,
|
||||
RawValue = ByteString.CopyFromUtf8(System.Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty),
|
||||
};
|
||||
}
|
||||
|
||||
private static BoolArray ConvertBoolArray(Array array)
|
||||
{
|
||||
BoolArray values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
values.Values.Add(item is not null && System.Convert.ToBoolean(item, CultureInfo.InvariantCulture));
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static Int32Array ConvertInt32Array(Array array)
|
||||
{
|
||||
Int32Array values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
values.Values.Add(item is null ? 0 : System.Convert.ToInt32(item, CultureInfo.InvariantCulture));
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static Int64Array ConvertInt64Array(Array array)
|
||||
{
|
||||
Int64Array values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
values.Values.Add(item is null ? 0 : System.Convert.ToInt64(item, CultureInfo.InvariantCulture));
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static FloatArray ConvertFloatArray(Array array)
|
||||
{
|
||||
FloatArray values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
values.Values.Add(item is null ? 0 : System.Convert.ToSingle(item, CultureInfo.InvariantCulture));
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static DoubleArray ConvertDoubleArray(Array array)
|
||||
{
|
||||
DoubleArray values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
values.Values.Add(item is null ? 0 : System.Convert.ToDouble(item, CultureInfo.InvariantCulture));
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static StringArray ConvertStringArray(Array array)
|
||||
{
|
||||
StringArray values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
values.Values.Add(item is null ? string.Empty : System.Convert.ToString(item, CultureInfo.InvariantCulture) ?? string.Empty);
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static TimestampArray ConvertTimestampArray(Array array)
|
||||
{
|
||||
TimestampArray values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
if (item is null)
|
||||
{
|
||||
values.Values.Add(Timestamp.FromDateTime(new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc)));
|
||||
}
|
||||
else if (item is DateTime dateTime)
|
||||
{
|
||||
values.Values.Add(ToTimestamp(dateTime));
|
||||
}
|
||||
else
|
||||
{
|
||||
long fileTime = System.Convert.ToInt64(item, CultureInfo.InvariantCulture);
|
||||
values.Values.Add(Timestamp.FromDateTime(DateTime.FromFileTimeUtc(fileTime)));
|
||||
}
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static RawArray ConvertRawArray(Array array)
|
||||
{
|
||||
RawArray values = new();
|
||||
foreach (object? item in array)
|
||||
{
|
||||
string rawValue = item is null
|
||||
? string.Empty
|
||||
: System.Convert.ToString(item, CultureInfo.InvariantCulture) ?? string.Empty;
|
||||
values.Values.Add(ByteString.CopyFromUtf8(rawValue));
|
||||
}
|
||||
|
||||
return values;
|
||||
}
|
||||
|
||||
private static MxDataType ResolveArrayElementDataType(
|
||||
System.Type? elementType,
|
||||
MxDataType expectedElementDataType)
|
||||
{
|
||||
if (expectedElementDataType != MxDataType.Unspecified)
|
||||
{
|
||||
return expectedElementDataType;
|
||||
}
|
||||
|
||||
if (elementType == typeof(bool))
|
||||
{
|
||||
return MxDataType.Boolean;
|
||||
}
|
||||
|
||||
if (elementType == typeof(byte)
|
||||
|| elementType == typeof(sbyte)
|
||||
|| elementType == typeof(short)
|
||||
|| elementType == typeof(ushort)
|
||||
|| elementType == typeof(int)
|
||||
|| elementType == typeof(uint)
|
||||
|| elementType == typeof(long)
|
||||
|| elementType == typeof(ulong))
|
||||
{
|
||||
return MxDataType.Integer;
|
||||
}
|
||||
|
||||
if (elementType == typeof(float))
|
||||
{
|
||||
return MxDataType.Float;
|
||||
}
|
||||
|
||||
if (elementType == typeof(double) || elementType == typeof(decimal))
|
||||
{
|
||||
return MxDataType.Double;
|
||||
}
|
||||
|
||||
if (elementType == typeof(string) || elementType == typeof(char))
|
||||
{
|
||||
return MxDataType.String;
|
||||
}
|
||||
|
||||
if (elementType == typeof(DateTime))
|
||||
{
|
||||
return MxDataType.Time;
|
||||
}
|
||||
|
||||
return MxDataType.Unknown;
|
||||
}
|
||||
|
||||
private static Timestamp ToTimestamp(DateTime dateTime)
|
||||
{
|
||||
DateTime utcDateTime = dateTime.Kind switch
|
||||
{
|
||||
DateTimeKind.Utc => dateTime,
|
||||
DateTimeKind.Local => dateTime.ToUniversalTime(),
|
||||
_ => DateTime.SpecifyKind(dateTime, DateTimeKind.Utc),
|
||||
};
|
||||
|
||||
return Timestamp.FromDateTime(utcDateTime);
|
||||
}
|
||||
|
||||
private static string CreateArrayVariantType(Array array)
|
||||
{
|
||||
System.Type? elementType = array.GetType().GetElementType();
|
||||
return $"SAFEARRAY({GetVariantTypeName(elementType)})";
|
||||
}
|
||||
|
||||
private static string GetVariantTypeName(System.Type? type)
|
||||
{
|
||||
if (type is null)
|
||||
{
|
||||
return "VT_EMPTY";
|
||||
}
|
||||
|
||||
System.Type nonNullableType = Nullable.GetUnderlyingType(type) ?? type;
|
||||
if (nonNullableType == typeof(bool))
|
||||
{
|
||||
return "VT_BOOL";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(byte))
|
||||
{
|
||||
return "VT_UI1";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(sbyte))
|
||||
{
|
||||
return "VT_I1";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(short))
|
||||
{
|
||||
return "VT_I2";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(ushort))
|
||||
{
|
||||
return "VT_UI2";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(int))
|
||||
{
|
||||
return "VT_I4";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(uint))
|
||||
{
|
||||
return "VT_UI4";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(long))
|
||||
{
|
||||
return "VT_I8";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(ulong))
|
||||
{
|
||||
return "VT_UI8";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(float))
|
||||
{
|
||||
return "VT_R4";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(double) || nonNullableType == typeof(decimal))
|
||||
{
|
||||
return "VT_R8";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(string) || nonNullableType == typeof(char))
|
||||
{
|
||||
return "VT_BSTR";
|
||||
}
|
||||
|
||||
if (nonNullableType == typeof(DateTime))
|
||||
{
|
||||
return "VT_DATE";
|
||||
}
|
||||
|
||||
return $"CLR:{nonNullableType.FullName}";
|
||||
}
|
||||
|
||||
private static string CreateRawDiagnostic(object value)
|
||||
{
|
||||
return $"Unsupported variant projection for CLR type '{value.GetType().FullName}'.";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public interface IWorkerPipeClient
|
||||
{
|
||||
Task RunAsync(
|
||||
WorkerOptions options,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
using System;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
internal static class WorkerEnvelopeValidator
|
||||
{
|
||||
public static void Validate(
|
||||
WorkerEnvelope envelope,
|
||||
WorkerFrameProtocolOptions options)
|
||||
{
|
||||
if (envelope.ProtocolVersion != options.ProtocolVersion)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
|
||||
$"Worker envelope protocol version {envelope.ProtocolVersion} does not match expected version {options.ProtocolVersion}.");
|
||||
}
|
||||
|
||||
if (!string.Equals(envelope.SessionId, options.SessionId, StringComparison.Ordinal))
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.SessionMismatch,
|
||||
"Worker envelope session id does not match the owning worker session.");
|
||||
}
|
||||
|
||||
if (envelope.BodyCase == WorkerEnvelope.BodyOneofCase.None)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidEnvelope,
|
||||
"Worker envelope must include a typed body.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public enum WorkerFrameProtocolErrorCode
|
||||
{
|
||||
Unknown = 0,
|
||||
InvalidConfiguration = 1,
|
||||
EndOfStream = 2,
|
||||
MalformedLength = 3,
|
||||
MessageTooLarge = 4,
|
||||
InvalidEnvelope = 5,
|
||||
ProtocolVersionMismatch = 6,
|
||||
SessionMismatch = 7,
|
||||
NonceMismatch = 8,
|
||||
UnexpectedEnvelopeBody = 9,
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
using System;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public sealed class WorkerFrameProtocolException : Exception
|
||||
{
|
||||
public WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode errorCode,
|
||||
string message)
|
||||
: base(message)
|
||||
{
|
||||
ErrorCode = errorCode;
|
||||
}
|
||||
|
||||
public WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode errorCode,
|
||||
string message,
|
||||
Exception innerException)
|
||||
: base(message, innerException)
|
||||
{
|
||||
ErrorCode = errorCode;
|
||||
}
|
||||
|
||||
public WorkerFrameProtocolErrorCode ErrorCode { get; }
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
using System;
|
||||
using MxGateway.Contracts;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public sealed class WorkerFrameProtocolOptions
|
||||
{
|
||||
public const int DefaultMaxMessageBytes = 16 * 1024 * 1024;
|
||||
|
||||
public WorkerFrameProtocolOptions(WorkerOptions options)
|
||||
: this(
|
||||
options?.SessionId ?? throw new ArgumentNullException(nameof(options)),
|
||||
options.ProtocolVersion,
|
||||
options.Nonce,
|
||||
DefaultMaxMessageBytes)
|
||||
{
|
||||
}
|
||||
|
||||
public WorkerFrameProtocolOptions(
|
||||
string sessionId,
|
||||
uint protocolVersion,
|
||||
string nonce)
|
||||
: this(
|
||||
sessionId,
|
||||
protocolVersion,
|
||||
nonce,
|
||||
DefaultMaxMessageBytes)
|
||||
{
|
||||
}
|
||||
|
||||
public WorkerFrameProtocolOptions(
|
||||
string sessionId,
|
||||
uint protocolVersion,
|
||||
string nonce,
|
||||
int maxMessageBytes)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(sessionId))
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidConfiguration,
|
||||
"Worker frame protocol requires a session id.");
|
||||
}
|
||||
|
||||
if (protocolVersion == 0)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidConfiguration,
|
||||
"Worker frame protocol requires a non-zero protocol version.");
|
||||
}
|
||||
|
||||
if (protocolVersion != GatewayContractInfo.WorkerProtocolVersion)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
|
||||
$"Worker frame protocol version {protocolVersion} is not supported.");
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(nonce))
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidConfiguration,
|
||||
"Worker frame protocol requires a nonce.");
|
||||
}
|
||||
|
||||
if (maxMessageBytes <= 0)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidConfiguration,
|
||||
"Worker frame protocol max message size must be greater than zero.");
|
||||
}
|
||||
|
||||
SessionId = sessionId;
|
||||
ProtocolVersion = protocolVersion;
|
||||
Nonce = nonce;
|
||||
MaxMessageBytes = maxMessageBytes;
|
||||
}
|
||||
|
||||
public string SessionId { get; }
|
||||
|
||||
public uint ProtocolVersion { get; }
|
||||
|
||||
public string Nonce { get; }
|
||||
|
||||
public int MaxMessageBytes { get; }
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Protobuf;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public sealed class WorkerFrameReader
|
||||
{
|
||||
private readonly WorkerFrameProtocolOptions _options;
|
||||
private readonly Stream _stream;
|
||||
|
||||
public WorkerFrameReader(
|
||||
Stream stream,
|
||||
WorkerFrameProtocolOptions options)
|
||||
{
|
||||
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
|
||||
_options = options ?? throw new ArgumentNullException(nameof(options));
|
||||
}
|
||||
|
||||
public async Task<WorkerEnvelope> ReadAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
byte[] lengthPrefix = new byte[sizeof(uint)];
|
||||
await ReadExactlyOrThrowAsync(lengthPrefix, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
uint payloadLength = ReadUInt32LittleEndian(lengthPrefix);
|
||||
if (payloadLength == 0)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.MalformedLength,
|
||||
"Worker frame payload length must be greater than zero.");
|
||||
}
|
||||
|
||||
if (payloadLength > _options.MaxMessageBytes)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.MessageTooLarge,
|
||||
$"Worker frame payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes.");
|
||||
}
|
||||
|
||||
byte[] payload = new byte[payloadLength];
|
||||
await ReadExactlyOrThrowAsync(payload, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
WorkerEnvelope envelope;
|
||||
try
|
||||
{
|
||||
envelope = WorkerEnvelope.Parser.ParseFrom(payload);
|
||||
}
|
||||
catch (InvalidProtocolBufferException exception)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidEnvelope,
|
||||
"Worker frame payload is not a valid WorkerEnvelope protobuf message.",
|
||||
exception);
|
||||
}
|
||||
|
||||
WorkerEnvelopeValidator.Validate(envelope, _options);
|
||||
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private static uint ReadUInt32LittleEndian(byte[] buffer)
|
||||
{
|
||||
return (uint)buffer[0]
|
||||
| ((uint)buffer[1] << 8)
|
||||
| ((uint)buffer[2] << 16)
|
||||
| ((uint)buffer[3] << 24);
|
||||
}
|
||||
|
||||
private async Task ReadExactlyOrThrowAsync(
|
||||
byte[] buffer,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
int offset = 0;
|
||||
while (offset < buffer.Length)
|
||||
{
|
||||
int bytesRead = await _stream
|
||||
.ReadAsync(buffer, offset, buffer.Length - offset, cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
if (bytesRead == 0)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.EndOfStream,
|
||||
"Worker frame ended before the expected number of bytes were read.");
|
||||
}
|
||||
|
||||
offset += bytesRead;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Protobuf;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public sealed class WorkerFrameWriter
|
||||
{
|
||||
private readonly WorkerFrameProtocolOptions _options;
|
||||
private readonly SemaphoreSlim _writeLock = new(1, 1);
|
||||
private readonly Stream _stream;
|
||||
|
||||
public WorkerFrameWriter(
|
||||
Stream stream,
|
||||
WorkerFrameProtocolOptions options)
|
||||
{
|
||||
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
|
||||
_options = options ?? throw new ArgumentNullException(nameof(options));
|
||||
}
|
||||
|
||||
public async Task WriteAsync(
|
||||
WorkerEnvelope envelope,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (envelope is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(envelope));
|
||||
}
|
||||
|
||||
WorkerEnvelopeValidator.Validate(envelope, _options);
|
||||
|
||||
int payloadLength = envelope.CalculateSize();
|
||||
if (payloadLength == 0)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.InvalidEnvelope,
|
||||
"Worker envelope cannot serialize to an empty payload.");
|
||||
}
|
||||
|
||||
if (payloadLength > _options.MaxMessageBytes)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.MessageTooLarge,
|
||||
$"Worker envelope payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes.");
|
||||
}
|
||||
|
||||
byte[] payload = envelope.ToByteArray();
|
||||
byte[] lengthPrefix = new byte[sizeof(uint)];
|
||||
WriteUInt32LittleEndian(lengthPrefix, (uint)payloadLength);
|
||||
|
||||
await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
await _stream.WriteAsync(lengthPrefix, 0, lengthPrefix.Length, cancellationToken).ConfigureAwait(false);
|
||||
await _stream.WriteAsync(payload, 0, payload.Length, cancellationToken).ConfigureAwait(false);
|
||||
await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_writeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
private static void WriteUInt32LittleEndian(
|
||||
byte[] buffer,
|
||||
uint value)
|
||||
{
|
||||
buffer[0] = (byte)value;
|
||||
buffer[1] = (byte)(value >> 8);
|
||||
buffer[2] = (byte)(value >> 16);
|
||||
buffer[3] = (byte)(value >> 24);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
using System;
|
||||
using System.IO.Pipes;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public sealed class WorkerPipeClient : IWorkerPipeClient
|
||||
{
|
||||
public const int DefaultConnectTimeoutMilliseconds = 30000;
|
||||
|
||||
private readonly int _connectTimeoutMilliseconds;
|
||||
|
||||
public WorkerPipeClient()
|
||||
: this(DefaultConnectTimeoutMilliseconds)
|
||||
{
|
||||
}
|
||||
|
||||
public WorkerPipeClient(int connectTimeoutMilliseconds)
|
||||
{
|
||||
if (connectTimeoutMilliseconds <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(
|
||||
nameof(connectTimeoutMilliseconds),
|
||||
"Worker pipe connect timeout must be greater than zero.");
|
||||
}
|
||||
|
||||
_connectTimeoutMilliseconds = connectTimeoutMilliseconds;
|
||||
}
|
||||
|
||||
public async Task RunAsync(
|
||||
WorkerOptions options,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (options is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(options));
|
||||
}
|
||||
|
||||
WorkerFrameProtocolOptions frameOptions = new(options);
|
||||
|
||||
using NamedPipeClientStream pipe = new(
|
||||
".",
|
||||
options.PipeName,
|
||||
PipeDirection.InOut,
|
||||
PipeOptions.Asynchronous);
|
||||
|
||||
await ConnectAsync(pipe, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
WorkerPipeSession session = new(pipe, frameOptions);
|
||||
await session.CompleteStartupHandshakeAsync(cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private Task ConnectAsync(
|
||||
NamedPipeClientStream pipe,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
return Task.Run(
|
||||
() =>
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
pipe.Connect(_connectTimeoutMilliseconds);
|
||||
},
|
||||
cancellationToken);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.MxAccess;
|
||||
|
||||
namespace MxGateway.Worker.Ipc;
|
||||
|
||||
public sealed class WorkerPipeSession
|
||||
{
|
||||
private readonly WorkerFrameProtocolOptions _options;
|
||||
private readonly Func<int> _processIdProvider;
|
||||
private readonly WorkerFrameReader _reader;
|
||||
private readonly WorkerFrameWriter _writer;
|
||||
private MxAccessStaSession? _mxAccessStaSession;
|
||||
private long _nextSequence;
|
||||
|
||||
public WorkerPipeSession(
|
||||
Stream stream,
|
||||
WorkerFrameProtocolOptions options)
|
||||
: this(
|
||||
new WorkerFrameReader(stream, options),
|
||||
new WorkerFrameWriter(stream, options),
|
||||
options,
|
||||
() => Process.GetCurrentProcess().Id)
|
||||
{
|
||||
}
|
||||
|
||||
public WorkerPipeSession(
|
||||
WorkerFrameReader reader,
|
||||
WorkerFrameWriter writer,
|
||||
WorkerFrameProtocolOptions options,
|
||||
Func<int> processIdProvider)
|
||||
{
|
||||
_reader = reader ?? throw new ArgumentNullException(nameof(reader));
|
||||
_writer = writer ?? throw new ArgumentNullException(nameof(writer));
|
||||
_options = options ?? throw new ArgumentNullException(nameof(options));
|
||||
_processIdProvider = processIdProvider ?? throw new ArgumentNullException(nameof(processIdProvider));
|
||||
}
|
||||
|
||||
public Task CompleteStartupHandshakeAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
return CompleteStartupHandshakeAsync(InitializeMxAccessAsync, cancellationToken);
|
||||
}
|
||||
|
||||
public async Task CompleteStartupHandshakeAsync(
|
||||
Func<CancellationToken, Task> initializeMxAccessAsync,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (initializeMxAccessAsync is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(initializeMxAccessAsync));
|
||||
}
|
||||
|
||||
await CompleteStartupHandshakeAsync(
|
||||
async innerCancellationToken =>
|
||||
{
|
||||
await initializeMxAccessAsync(innerCancellationToken).ConfigureAwait(false);
|
||||
return CreateWorkerReady();
|
||||
},
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
public async Task CompleteStartupHandshakeAsync(
|
||||
Func<CancellationToken, Task<WorkerReady>> initializeMxAccessAsync,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (initializeMxAccessAsync is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(initializeMxAccessAsync));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
|
||||
ValidateGatewayHello(envelope);
|
||||
|
||||
await WriteWorkerHelloAsync(cancellationToken).ConfigureAwait(false);
|
||||
WorkerReady ready = await initializeMxAccessAsync(cancellationToken).ConfigureAwait(false);
|
||||
await WriteWorkerReadyAsync(ready, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (WorkerFrameProtocolException exception)
|
||||
{
|
||||
await TryWriteFaultAsync(exception, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
catch (Exception exception) when (exception is not OperationCanceledException)
|
||||
{
|
||||
await TryWriteFaultAsync(MxAccessCreationException.From(exception), cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private void ValidateGatewayHello(WorkerEnvelope envelope)
|
||||
{
|
||||
if (envelope.BodyCase != WorkerEnvelope.BodyOneofCase.GatewayHello)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody,
|
||||
"Worker expected GatewayHello during startup handshake.");
|
||||
}
|
||||
|
||||
GatewayHello gatewayHello = envelope.GatewayHello;
|
||||
if (gatewayHello.SupportedProtocolVersion != _options.ProtocolVersion)
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
|
||||
$"GatewayHello supported protocol version {gatewayHello.SupportedProtocolVersion} does not match expected version {_options.ProtocolVersion}.");
|
||||
}
|
||||
|
||||
if (!string.Equals(gatewayHello.Nonce, _options.Nonce, StringComparison.Ordinal))
|
||||
{
|
||||
throw new WorkerFrameProtocolException(
|
||||
WorkerFrameProtocolErrorCode.NonceMismatch,
|
||||
"GatewayHello nonce does not match the worker launch nonce.");
|
||||
}
|
||||
}
|
||||
|
||||
private Task WriteWorkerHelloAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return _writer.WriteAsync(
|
||||
CreateEnvelope(new WorkerHello
|
||||
{
|
||||
ProtocolVersion = _options.ProtocolVersion,
|
||||
Nonce = _options.Nonce,
|
||||
WorkerProcessId = _processIdProvider(),
|
||||
WorkerVersion = typeof(WorkerPipeSession).Assembly.GetName().Version?.ToString() ?? string.Empty,
|
||||
}),
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
private Task WriteWorkerReadyAsync(
|
||||
WorkerReady ready,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
return _writer.WriteAsync(CreateEnvelope(ready), cancellationToken);
|
||||
}
|
||||
|
||||
private async Task TryWriteFaultAsync(
|
||||
WorkerFrameProtocolException exception,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _writer
|
||||
.WriteAsync(CreateEnvelope(CreateFault(exception)), cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception faultWriteException) when (
|
||||
faultWriteException is IOException
|
||||
|| faultWriteException is ObjectDisposedException
|
||||
|| faultWriteException is WorkerFrameProtocolException)
|
||||
{
|
||||
// The original protocol failure is the actionable error.
|
||||
}
|
||||
}
|
||||
|
||||
private async Task TryWriteFaultAsync(
|
||||
MxAccessCreationException exception,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _writer
|
||||
.WriteAsync(CreateEnvelope(CreateFault(exception)), cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception faultWriteException) when (
|
||||
faultWriteException is IOException
|
||||
|| faultWriteException is ObjectDisposedException
|
||||
|| faultWriteException is WorkerFrameProtocolException)
|
||||
{
|
||||
// The MXAccess creation failure is the actionable error.
|
||||
}
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateEnvelope(WorkerHello hello)
|
||||
{
|
||||
return CreateBaseEnvelope(hello);
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateEnvelope(WorkerReady ready)
|
||||
{
|
||||
return CreateBaseEnvelope(ready);
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateEnvelope(WorkerFault fault)
|
||||
{
|
||||
return CreateBaseEnvelope(fault);
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateBaseEnvelope(WorkerHello body)
|
||||
{
|
||||
WorkerEnvelope envelope = CreateBaseEnvelope();
|
||||
envelope.WorkerHello = body;
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateBaseEnvelope(WorkerReady body)
|
||||
{
|
||||
WorkerEnvelope envelope = CreateBaseEnvelope();
|
||||
envelope.WorkerReady = body;
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateBaseEnvelope(WorkerFault body)
|
||||
{
|
||||
WorkerEnvelope envelope = CreateBaseEnvelope();
|
||||
envelope.WorkerFault = body;
|
||||
return envelope;
|
||||
}
|
||||
|
||||
private WorkerEnvelope CreateBaseEnvelope()
|
||||
{
|
||||
return new WorkerEnvelope
|
||||
{
|
||||
ProtocolVersion = _options.ProtocolVersion,
|
||||
SessionId = _options.SessionId,
|
||||
Sequence = NextSequence(),
|
||||
};
|
||||
}
|
||||
|
||||
private ulong NextSequence()
|
||||
{
|
||||
return unchecked((ulong)Interlocked.Increment(ref _nextSequence));
|
||||
}
|
||||
|
||||
private async Task<WorkerReady> InitializeMxAccessAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
_mxAccessStaSession = new MxAccessStaSession();
|
||||
try
|
||||
{
|
||||
return await _mxAccessStaSession
|
||||
.StartAsync(_processIdProvider(), cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
catch
|
||||
{
|
||||
_mxAccessStaSession.Dispose();
|
||||
_mxAccessStaSession = null;
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private WorkerReady CreateWorkerReady()
|
||||
{
|
||||
return new WorkerReady
|
||||
{
|
||||
WorkerProcessId = _processIdProvider(),
|
||||
MxaccessProgid = MxAccessInteropInfo.ProgId,
|
||||
MxaccessClsid = MxAccessInteropInfo.Clsid,
|
||||
ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
|
||||
};
|
||||
}
|
||||
|
||||
private static WorkerFault CreateFault(WorkerFrameProtocolException exception)
|
||||
{
|
||||
return new WorkerFault
|
||||
{
|
||||
Category = MapFaultCategory(exception.ErrorCode),
|
||||
ExceptionType = exception.GetType().FullName ?? string.Empty,
|
||||
DiagnosticMessage = exception.Message,
|
||||
ProtocolStatus = new ProtocolStatus
|
||||
{
|
||||
Code = ProtocolStatusCode.ProtocolViolation,
|
||||
Message = exception.Message,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private static WorkerFault CreateFault(MxAccessCreationException exception)
|
||||
{
|
||||
WorkerFault fault = new()
|
||||
{
|
||||
Category = WorkerFaultCategory.MxaccessCreationFailed,
|
||||
ExceptionType = exception.InnerException?.GetType().FullName ?? exception.GetType().FullName ?? string.Empty,
|
||||
DiagnosticMessage = exception.Message,
|
||||
ProtocolStatus = new ProtocolStatus
|
||||
{
|
||||
Code = ProtocolStatusCode.WorkerUnavailable,
|
||||
Message = exception.Message,
|
||||
},
|
||||
};
|
||||
|
||||
int? hresult = MxAccessCreationException.ExtractHResult(exception);
|
||||
if (hresult.HasValue)
|
||||
{
|
||||
fault.Hresult = hresult.Value;
|
||||
}
|
||||
|
||||
return fault;
|
||||
}
|
||||
|
||||
private static WorkerFaultCategory MapFaultCategory(WorkerFrameProtocolErrorCode errorCode)
|
||||
{
|
||||
return errorCode switch
|
||||
{
|
||||
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch => WorkerFaultCategory.ProtocolMismatch,
|
||||
WorkerFrameProtocolErrorCode.EndOfStream => WorkerFaultCategory.PipeDisconnected,
|
||||
_ => WorkerFaultCategory.ProtocolViolation,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public interface IMxAccessComObjectFactory
|
||||
{
|
||||
object Create();
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public interface IMxAccessEventSink
|
||||
{
|
||||
void Attach(object mxAccessComObject);
|
||||
|
||||
void Detach();
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
using ArchestrA.MxAccess;
|
||||
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public sealed class MxAccessBaseEventSink : IMxAccessEventSink
|
||||
{
|
||||
private LMXProxyServerClass? server;
|
||||
|
||||
public void Attach(object mxAccessComObject)
|
||||
{
|
||||
server = (LMXProxyServerClass)mxAccessComObject;
|
||||
server.OnDataChange += OnDataChange;
|
||||
server.OnWriteComplete += OnWriteComplete;
|
||||
server.OperationComplete += OperationComplete;
|
||||
server.OnBufferedDataChange += OnBufferedDataChange;
|
||||
}
|
||||
|
||||
public void Detach()
|
||||
{
|
||||
if (server is null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
server.OnDataChange -= OnDataChange;
|
||||
server.OnWriteComplete -= OnWriteComplete;
|
||||
server.OperationComplete -= OperationComplete;
|
||||
server.OnBufferedDataChange -= OnBufferedDataChange;
|
||||
server = null;
|
||||
}
|
||||
|
||||
private static void OnDataChange(
|
||||
int hLMXServerHandle,
|
||||
int phItemHandle,
|
||||
object pvItemValue,
|
||||
int pwItemQuality,
|
||||
object pftItemTimeStamp,
|
||||
ref MXSTATUS_PROXY[] pVars)
|
||||
{
|
||||
}
|
||||
|
||||
private static void OnWriteComplete(
|
||||
int hLMXServerHandle,
|
||||
int phItemHandle,
|
||||
ref MXSTATUS_PROXY[] pVars)
|
||||
{
|
||||
}
|
||||
|
||||
private static void OperationComplete(
|
||||
int hLMXServerHandle,
|
||||
int phItemHandle,
|
||||
ref MXSTATUS_PROXY[] pVars)
|
||||
{
|
||||
}
|
||||
|
||||
private static void OnBufferedDataChange(
|
||||
int hLMXServerHandle,
|
||||
int phItemHandle,
|
||||
MxDataType dtDataType,
|
||||
object pvItemValue,
|
||||
object pwItemQuality,
|
||||
object pftItemTimeStamp,
|
||||
ref MXSTATUS_PROXY[] pVars)
|
||||
{
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
using ArchestrA.MxAccess;
|
||||
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public sealed class MxAccessComObjectFactory : IMxAccessComObjectFactory
|
||||
{
|
||||
public object Create()
|
||||
{
|
||||
return new LMXProxyServerClass();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public sealed class MxAccessCreationException : Exception
|
||||
{
|
||||
public MxAccessCreationException(Exception innerException)
|
||||
: base(
|
||||
$"Failed to create MXAccess COM object {MxAccessInteropInfo.ComClassName} ({MxAccessInteropInfo.ProgId}).",
|
||||
innerException)
|
||||
{
|
||||
AttemptedProgId = MxAccessInteropInfo.ProgId;
|
||||
AttemptedClsid = MxAccessInteropInfo.Clsid;
|
||||
AttemptedComClassName = MxAccessInteropInfo.ComClassName;
|
||||
HResult = innerException.HResult;
|
||||
}
|
||||
|
||||
public string AttemptedProgId { get; }
|
||||
|
||||
public string AttemptedClsid { get; }
|
||||
|
||||
public string AttemptedComClassName { get; }
|
||||
|
||||
public int? CapturedHResult => HResult == 0 ? null : HResult;
|
||||
|
||||
public static MxAccessCreationException From(Exception exception)
|
||||
{
|
||||
return exception is MxAccessCreationException creationException
|
||||
? creationException
|
||||
: new MxAccessCreationException(exception);
|
||||
}
|
||||
|
||||
public static int? ExtractHResult(Exception exception)
|
||||
{
|
||||
if (exception is MxAccessCreationException creationException)
|
||||
{
|
||||
return creationException.CapturedHResult;
|
||||
}
|
||||
|
||||
if (exception is COMException comException)
|
||||
{
|
||||
return comException.HResult;
|
||||
}
|
||||
|
||||
return exception.HResult == 0 ? null : exception.HResult;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using MxGateway.Contracts.Proto;
|
||||
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public sealed class MxAccessSession : IDisposable
|
||||
{
|
||||
private readonly object mxAccessComObject;
|
||||
private readonly IMxAccessEventSink eventSink;
|
||||
private bool disposed;
|
||||
|
||||
private MxAccessSession(
|
||||
object mxAccessComObject,
|
||||
IMxAccessEventSink eventSink,
|
||||
int creationThreadId)
|
||||
{
|
||||
this.mxAccessComObject = mxAccessComObject ?? throw new ArgumentNullException(nameof(mxAccessComObject));
|
||||
this.eventSink = eventSink ?? throw new ArgumentNullException(nameof(eventSink));
|
||||
CreationThreadId = creationThreadId;
|
||||
}
|
||||
|
||||
public int CreationThreadId { get; }
|
||||
|
||||
public WorkerReady CreateWorkerReady(int workerProcessId)
|
||||
{
|
||||
return new WorkerReady
|
||||
{
|
||||
WorkerProcessId = workerProcessId,
|
||||
MxaccessProgid = MxAccessInteropInfo.ProgId,
|
||||
MxaccessClsid = MxAccessInteropInfo.Clsid,
|
||||
ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
|
||||
};
|
||||
}
|
||||
|
||||
public static MxAccessSession Create(
|
||||
IMxAccessComObjectFactory factory,
|
||||
IMxAccessEventSink eventSink)
|
||||
{
|
||||
if (factory is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(factory));
|
||||
}
|
||||
|
||||
if (eventSink is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(eventSink));
|
||||
}
|
||||
|
||||
object? mxAccessComObject = null;
|
||||
|
||||
try
|
||||
{
|
||||
mxAccessComObject = factory.Create();
|
||||
if (mxAccessComObject is null)
|
||||
{
|
||||
throw new InvalidOperationException("MXAccess COM factory returned null.");
|
||||
}
|
||||
|
||||
eventSink.Attach(mxAccessComObject);
|
||||
|
||||
return new MxAccessSession(
|
||||
mxAccessComObject,
|
||||
eventSink,
|
||||
Environment.CurrentManagedThreadId);
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
eventSink.Detach();
|
||||
|
||||
if (mxAccessComObject is not null && Marshal.IsComObject(mxAccessComObject))
|
||||
{
|
||||
Marshal.FinalReleaseComObject(mxAccessComObject);
|
||||
}
|
||||
|
||||
throw MxAccessCreationException.From(exception);
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (disposed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
eventSink.Detach();
|
||||
|
||||
if (Marshal.IsComObject(mxAccessComObject))
|
||||
{
|
||||
Marshal.FinalReleaseComObject(mxAccessComObject);
|
||||
}
|
||||
|
||||
disposed = true;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using MxGateway.Worker.Sta;
|
||||
|
||||
namespace MxGateway.Worker.MxAccess;
|
||||
|
||||
public sealed class MxAccessStaSession : IDisposable
|
||||
{
|
||||
private readonly IMxAccessComObjectFactory factory;
|
||||
private readonly IMxAccessEventSink eventSink;
|
||||
private readonly StaRuntime staRuntime;
|
||||
private MxAccessSession? session;
|
||||
private bool disposed;
|
||||
|
||||
public MxAccessStaSession()
|
||||
: this(
|
||||
new StaRuntime(),
|
||||
new MxAccessComObjectFactory(),
|
||||
new MxAccessBaseEventSink())
|
||||
{
|
||||
}
|
||||
|
||||
public MxAccessStaSession(
|
||||
StaRuntime staRuntime,
|
||||
IMxAccessComObjectFactory factory,
|
||||
IMxAccessEventSink eventSink)
|
||||
{
|
||||
this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime));
|
||||
this.factory = factory ?? throw new ArgumentNullException(nameof(factory));
|
||||
this.eventSink = eventSink ?? throw new ArgumentNullException(nameof(eventSink));
|
||||
}
|
||||
|
||||
public Task<WorkerReady> StartAsync(
|
||||
int workerProcessId,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
staRuntime.Start();
|
||||
|
||||
return staRuntime.InvokeAsync(
|
||||
() =>
|
||||
{
|
||||
if (session is not null)
|
||||
{
|
||||
throw new InvalidOperationException("MXAccess COM session has already been created.");
|
||||
}
|
||||
|
||||
session = MxAccessSession.Create(factory, eventSink);
|
||||
return session.CreateWorkerReady(workerProcessId);
|
||||
},
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (disposed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (session is not null)
|
||||
{
|
||||
staRuntime.InvokeAsync(() => session.Dispose()).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
staRuntime.Dispose();
|
||||
disposed = true;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
namespace MxGateway.Worker.Sta;
|
||||
|
||||
public interface IStaComApartmentInitializer
|
||||
{
|
||||
void Initialize();
|
||||
|
||||
void Uninitialize();
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
namespace MxGateway.Worker.Sta;
|
||||
|
||||
internal interface IStaWorkItem
|
||||
{
|
||||
void CancelBeforeExecution();
|
||||
|
||||
void Execute();
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace MxGateway.Worker.Sta;
|
||||
|
||||
public sealed class StaComApartmentInitializer : IStaComApartmentInitializer
|
||||
{
|
||||
private const uint CoInitializeApartmentThreaded = 0x2;
|
||||
private const int SOk = 0;
|
||||
private const int SFalse = 1;
|
||||
|
||||
public void Initialize()
|
||||
{
|
||||
int hresult = CoInitializeEx(IntPtr.Zero, CoInitializeApartmentThreaded);
|
||||
if (hresult != SOk && hresult != SFalse)
|
||||
{
|
||||
throw new COMException("Failed to initialize the worker STA COM apartment.", hresult);
|
||||
}
|
||||
}
|
||||
|
||||
public void Uninitialize()
|
||||
{
|
||||
CoUninitialize();
|
||||
}
|
||||
|
||||
[DllImport("ole32.dll")]
|
||||
private static extern int CoInitializeEx(IntPtr reserved, uint coInit);
|
||||
|
||||
[DllImport("ole32.dll")]
|
||||
private static extern void CoUninitialize();
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Threading;
|
||||
using Microsoft.Win32.SafeHandles;
|
||||
|
||||
namespace MxGateway.Worker.Sta;
|
||||
|
||||
public sealed class StaMessagePump
|
||||
{
|
||||
private const uint Infinite = 0xFFFFFFFF;
|
||||
private const uint MsgWaitFailed = 0xFFFFFFFF;
|
||||
private const uint MwmoInputAvailable = 0x0004;
|
||||
private const uint PmRemove = 0x0001;
|
||||
private const uint QsAllInput = 0x04FF;
|
||||
|
||||
public void WaitForWorkOrMessages(WaitHandle commandWakeEvent, TimeSpan timeout)
|
||||
{
|
||||
if (commandWakeEvent is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(commandWakeEvent));
|
||||
}
|
||||
|
||||
uint timeoutMilliseconds = ToTimeoutMilliseconds(timeout);
|
||||
|
||||
SafeWaitHandle safeHandle = commandWakeEvent.SafeWaitHandle;
|
||||
IntPtr[] handles = [safeHandle.DangerousGetHandle()];
|
||||
uint result = MsgWaitForMultipleObjectsEx(
|
||||
(uint)handles.Length,
|
||||
handles,
|
||||
timeoutMilliseconds,
|
||||
QsAllInput,
|
||||
MwmoInputAvailable);
|
||||
|
||||
if (result == MsgWaitFailed)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"The worker STA message pump failed while waiting for command work or Windows messages.");
|
||||
}
|
||||
}
|
||||
|
||||
public int PumpPendingMessages()
|
||||
{
|
||||
int pumpedMessages = 0;
|
||||
|
||||
while (PeekMessage(out NativeMessage message, IntPtr.Zero, 0, 0, PmRemove))
|
||||
{
|
||||
TranslateMessage(ref message);
|
||||
DispatchMessage(ref message);
|
||||
pumpedMessages++;
|
||||
}
|
||||
|
||||
return pumpedMessages;
|
||||
}
|
||||
|
||||
private static uint ToTimeoutMilliseconds(TimeSpan timeout)
|
||||
{
|
||||
if (timeout == Timeout.InfiniteTimeSpan)
|
||||
{
|
||||
return Infinite;
|
||||
}
|
||||
|
||||
if (timeout <= TimeSpan.Zero)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
return timeout.TotalMilliseconds >= uint.MaxValue
|
||||
? uint.MaxValue - 1
|
||||
: (uint)Math.Ceiling(timeout.TotalMilliseconds);
|
||||
}
|
||||
|
||||
[DllImport("user32.dll", SetLastError = true)]
|
||||
private static extern uint MsgWaitForMultipleObjectsEx(
|
||||
uint count,
|
||||
IntPtr[] handles,
|
||||
uint milliseconds,
|
||||
uint wakeMask,
|
||||
uint flags);
|
||||
|
||||
[DllImport("user32.dll", SetLastError = true)]
|
||||
private static extern bool PeekMessage(
|
||||
out NativeMessage message,
|
||||
IntPtr windowHandle,
|
||||
uint messageFilterMin,
|
||||
uint messageFilterMax,
|
||||
uint removeMessage);
|
||||
|
||||
[DllImport("user32.dll")]
|
||||
private static extern bool TranslateMessage(ref NativeMessage message);
|
||||
|
||||
[DllImport("user32.dll")]
|
||||
private static extern IntPtr DispatchMessage(ref NativeMessage message);
|
||||
|
||||
[StructLayout(LayoutKind.Sequential)]
|
||||
private struct NativeMessage
|
||||
{
|
||||
public IntPtr WindowHandle;
|
||||
public uint Message;
|
||||
public UIntPtr WParam;
|
||||
public IntPtr LParam;
|
||||
public uint Time;
|
||||
public NativePoint Point;
|
||||
}
|
||||
|
||||
[StructLayout(LayoutKind.Sequential)]
|
||||
private struct NativePoint
|
||||
{
|
||||
public int X;
|
||||
public int Y;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace MxGateway.Worker.Sta;
|
||||
|
||||
public sealed class StaRuntime : IDisposable
|
||||
{
|
||||
private readonly IStaComApartmentInitializer comApartmentInitializer;
|
||||
private readonly StaMessagePump messagePump;
|
||||
private readonly ConcurrentQueue<IStaWorkItem> commandQueue = new();
|
||||
private readonly AutoResetEvent commandWakeEvent = new(false);
|
||||
private readonly ManualResetEventSlim startedEvent = new(false);
|
||||
private readonly ManualResetEventSlim stoppedEvent = new(false);
|
||||
private readonly object gate = new();
|
||||
private readonly Thread staThread;
|
||||
private readonly TimeSpan idlePumpInterval;
|
||||
private bool disposed;
|
||||
private bool startRequested;
|
||||
private bool shutdownRequested;
|
||||
private Exception? startupException;
|
||||
private long lastActivityUtcTicks;
|
||||
private bool comInitialized;
|
||||
|
||||
public StaRuntime()
|
||||
: this(new StaComApartmentInitializer(), new StaMessagePump(), TimeSpan.FromMilliseconds(50))
|
||||
{
|
||||
}
|
||||
|
||||
public StaRuntime(
|
||||
IStaComApartmentInitializer comApartmentInitializer,
|
||||
StaMessagePump messagePump,
|
||||
TimeSpan idlePumpInterval)
|
||||
{
|
||||
this.comApartmentInitializer = comApartmentInitializer
|
||||
?? throw new ArgumentNullException(nameof(comApartmentInitializer));
|
||||
this.messagePump = messagePump ?? throw new ArgumentNullException(nameof(messagePump));
|
||||
|
||||
if (idlePumpInterval <= TimeSpan.Zero)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(
|
||||
nameof(idlePumpInterval),
|
||||
"The idle pump interval must be greater than zero.");
|
||||
}
|
||||
|
||||
this.idlePumpInterval = idlePumpInterval;
|
||||
lastActivityUtcTicks = DateTimeOffset.UtcNow.UtcTicks;
|
||||
staThread = new Thread(ThreadMain)
|
||||
{
|
||||
IsBackground = true,
|
||||
Name = "MxGateway.Worker.STA"
|
||||
};
|
||||
staThread.SetApartmentState(ApartmentState.STA);
|
||||
}
|
||||
|
||||
public int? StaThreadId { get; private set; }
|
||||
|
||||
public DateTimeOffset LastActivityUtc =>
|
||||
new(new DateTime(Volatile.Read(ref lastActivityUtcTicks), DateTimeKind.Utc));
|
||||
|
||||
public bool IsRunning => startedEvent.IsSet && !stoppedEvent.IsSet;
|
||||
|
||||
public void Start()
|
||||
{
|
||||
ThrowIfDisposed();
|
||||
|
||||
lock (gate)
|
||||
{
|
||||
if (shutdownRequested)
|
||||
{
|
||||
throw new InvalidOperationException("The worker STA runtime is shutting down.");
|
||||
}
|
||||
|
||||
if (!startRequested)
|
||||
{
|
||||
startRequested = true;
|
||||
staThread.Start();
|
||||
}
|
||||
}
|
||||
|
||||
startedEvent.Wait();
|
||||
if (startupException is not null)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"The worker STA runtime failed to initialize.",
|
||||
startupException);
|
||||
}
|
||||
}
|
||||
|
||||
public Task InvokeAsync(Action command, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (command is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(command));
|
||||
}
|
||||
|
||||
return InvokeAsync(
|
||||
() =>
|
||||
{
|
||||
command();
|
||||
return true;
|
||||
},
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
public Task<T> InvokeAsync<T>(Func<T> command, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (command is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(command));
|
||||
}
|
||||
|
||||
ThrowIfDisposed();
|
||||
|
||||
if (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
return Task.FromCanceled<T>(cancellationToken);
|
||||
}
|
||||
|
||||
StaWorkItem<T> workItem = new(command, cancellationToken);
|
||||
|
||||
lock (gate)
|
||||
{
|
||||
if (shutdownRequested)
|
||||
{
|
||||
return Task.FromException<T>(
|
||||
new InvalidOperationException("The worker STA runtime is shutting down."));
|
||||
}
|
||||
|
||||
commandQueue.Enqueue(workItem);
|
||||
}
|
||||
|
||||
commandWakeEvent.Set();
|
||||
return workItem.Task;
|
||||
}
|
||||
|
||||
public bool Shutdown(TimeSpan timeout)
|
||||
{
|
||||
if (timeout < TimeSpan.Zero && timeout != Timeout.InfiniteTimeSpan)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(timeout));
|
||||
}
|
||||
|
||||
lock (gate)
|
||||
{
|
||||
shutdownRequested = true;
|
||||
}
|
||||
|
||||
commandWakeEvent.Set();
|
||||
|
||||
if (!startedEvent.IsSet && !staThread.IsAlive)
|
||||
{
|
||||
CancelQueuedCommands();
|
||||
stoppedEvent.Set();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool stopped = stoppedEvent.Wait(timeout);
|
||||
if (stopped)
|
||||
{
|
||||
CancelQueuedCommands();
|
||||
}
|
||||
|
||||
return stopped;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (disposed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bool stopped = Shutdown(TimeSpan.FromSeconds(5));
|
||||
if (stopped)
|
||||
{
|
||||
commandWakeEvent.Dispose();
|
||||
startedEvent.Dispose();
|
||||
stoppedEvent.Dispose();
|
||||
}
|
||||
|
||||
disposed = true;
|
||||
}
|
||||
|
||||
private void ThreadMain()
|
||||
{
|
||||
try
|
||||
{
|
||||
StaThreadId = Thread.CurrentThread.ManagedThreadId;
|
||||
comApartmentInitializer.Initialize();
|
||||
comInitialized = true;
|
||||
MarkActivity();
|
||||
startedEvent.Set();
|
||||
|
||||
while (!IsShutdownRequested())
|
||||
{
|
||||
ProcessQueuedCommands();
|
||||
messagePump.WaitForWorkOrMessages(commandWakeEvent, idlePumpInterval);
|
||||
messagePump.PumpPendingMessages();
|
||||
MarkActivity();
|
||||
}
|
||||
|
||||
ProcessQueuedCommands();
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
startupException = exception;
|
||||
startedEvent.Set();
|
||||
}
|
||||
finally
|
||||
{
|
||||
CancelQueuedCommands();
|
||||
try
|
||||
{
|
||||
if (comInitialized)
|
||||
{
|
||||
comApartmentInitializer.Uninitialize();
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
MarkActivity();
|
||||
stoppedEvent.Set();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void ProcessQueuedCommands()
|
||||
{
|
||||
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
|
||||
{
|
||||
MarkActivity();
|
||||
workItem.Execute();
|
||||
MarkActivity();
|
||||
}
|
||||
}
|
||||
|
||||
private void CancelQueuedCommands()
|
||||
{
|
||||
while (commandQueue.TryDequeue(out IStaWorkItem? workItem))
|
||||
{
|
||||
workItem.CancelBeforeExecution();
|
||||
}
|
||||
}
|
||||
|
||||
private bool IsShutdownRequested()
|
||||
{
|
||||
lock (gate)
|
||||
{
|
||||
return shutdownRequested;
|
||||
}
|
||||
}
|
||||
|
||||
private void MarkActivity()
|
||||
{
|
||||
Volatile.Write(ref lastActivityUtcTicks, DateTimeOffset.UtcNow.UtcTicks);
|
||||
}
|
||||
|
||||
private void ThrowIfDisposed()
|
||||
{
|
||||
if (disposed)
|
||||
{
|
||||
throw new ObjectDisposedException(nameof(StaRuntime));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace MxGateway.Worker.Sta;
|
||||
|
||||
internal sealed class StaWorkItem<T> : IStaWorkItem
|
||||
{
|
||||
private readonly Func<T> command;
|
||||
private readonly CancellationToken cancellationToken;
|
||||
private readonly CancellationTokenRegistration cancellationRegistration;
|
||||
private int started;
|
||||
|
||||
public StaWorkItem(Func<T> command, CancellationToken cancellationToken)
|
||||
{
|
||||
this.command = command ?? throw new ArgumentNullException(nameof(command));
|
||||
this.cancellationToken = cancellationToken;
|
||||
Completion = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
if (cancellationToken.CanBeCanceled)
|
||||
{
|
||||
cancellationRegistration = cancellationToken.Register(
|
||||
() =>
|
||||
{
|
||||
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
|
||||
{
|
||||
Completion.TrySetCanceled(cancellationToken);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public Task<T> Task => Completion.Task;
|
||||
|
||||
private TaskCompletionSource<T> Completion { get; }
|
||||
|
||||
public void CancelBeforeExecution()
|
||||
{
|
||||
if (Interlocked.CompareExchange(ref started, 1, 0) == 0)
|
||||
{
|
||||
Completion.TrySetCanceled(cancellationToken);
|
||||
cancellationRegistration.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
public void Execute()
|
||||
{
|
||||
if (Interlocked.CompareExchange(ref started, 1, 0) != 0)
|
||||
{
|
||||
cancellationRegistration.Dispose();
|
||||
return;
|
||||
}
|
||||
|
||||
cancellationRegistration.Dispose();
|
||||
|
||||
if (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
Completion.TrySetCanceled(cancellationToken);
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
Completion.TrySetResult(command());
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
Completion.TrySetException(exception);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,128 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using MxGateway.Worker.Bootstrap;
|
||||
using MxGateway.Worker.Ipc;
|
||||
|
||||
namespace MxGateway.Worker;
|
||||
|
||||
public static class WorkerApplication
|
||||
{
|
||||
public static int Run(string[] args)
|
||||
{
|
||||
return Run(
|
||||
args,
|
||||
new EnvironmentVariableWorkerEnvironment(),
|
||||
new WorkerConsoleLogger(Console.Error),
|
||||
new WorkerPipeClient());
|
||||
}
|
||||
|
||||
public static int Run(
|
||||
string[] args,
|
||||
IWorkerEnvironment environment,
|
||||
IWorkerLogger logger)
|
||||
{
|
||||
return Run(
|
||||
args,
|
||||
environment,
|
||||
logger,
|
||||
new WorkerPipeClient());
|
||||
}
|
||||
|
||||
public static int Run(
|
||||
string[] args,
|
||||
IWorkerEnvironment environment,
|
||||
IWorkerLogger logger,
|
||||
IWorkerPipeClient pipeClient)
|
||||
{
|
||||
if (args is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(args));
|
||||
}
|
||||
|
||||
return 0;
|
||||
if (environment is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(environment));
|
||||
}
|
||||
|
||||
if (logger is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(logger));
|
||||
}
|
||||
|
||||
if (pipeClient is null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(pipeClient));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
WorkerOptionsParser parser = new(environment);
|
||||
WorkerBootstrapResult result = parser.Parse(args);
|
||||
|
||||
if (!result.Succeeded)
|
||||
{
|
||||
logger.Error("WorkerBootstrapFailed", new Dictionary<string, object?>
|
||||
{
|
||||
["exit_code"] = result.ExitCode,
|
||||
["errors"] = string.Join(";", result.Errors),
|
||||
});
|
||||
|
||||
return (int)result.ExitCode;
|
||||
}
|
||||
|
||||
WorkerOptions options = result.Options
|
||||
?? throw new InvalidOperationException("Successful bootstrap result did not include worker options.");
|
||||
|
||||
logger.Information("WorkerBootstrapSucceeded", new Dictionary<string, object?>
|
||||
{
|
||||
["session_id"] = options.SessionId,
|
||||
["pipe_name"] = options.PipeName,
|
||||
["protocol_version"] = options.ProtocolVersion,
|
||||
["nonce"] = options.Nonce,
|
||||
});
|
||||
|
||||
pipeClient.RunAsync(options).GetAwaiter().GetResult();
|
||||
|
||||
logger.Information("WorkerPipeHandshakeSucceeded", new Dictionary<string, object?>
|
||||
{
|
||||
["session_id"] = options.SessionId,
|
||||
["pipe_name"] = options.PipeName,
|
||||
["protocol_version"] = options.ProtocolVersion,
|
||||
});
|
||||
|
||||
return (int)WorkerExitCode.Success;
|
||||
}
|
||||
catch (WorkerFrameProtocolException exception)
|
||||
{
|
||||
logger.Error("WorkerPipeProtocolFailure", new Dictionary<string, object?>
|
||||
{
|
||||
["exit_code"] = WorkerExitCode.ProtocolViolation,
|
||||
["error_code"] = exception.ErrorCode,
|
||||
["exception_type"] = exception.GetType().FullName,
|
||||
});
|
||||
|
||||
return (int)WorkerExitCode.ProtocolViolation;
|
||||
}
|
||||
catch (Exception exception) when (exception is IOException or TimeoutException)
|
||||
{
|
||||
logger.Error("WorkerPipeConnectionFailed", new Dictionary<string, object?>
|
||||
{
|
||||
["exit_code"] = WorkerExitCode.PipeConnectionFailed,
|
||||
["exception_type"] = exception.GetType().FullName,
|
||||
});
|
||||
|
||||
return (int)WorkerExitCode.PipeConnectionFailed;
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
logger.Error("WorkerBootstrapUnexpectedFailure", new Dictionary<string, object?>
|
||||
{
|
||||
["exit_code"] = WorkerExitCode.UnexpectedFailure,
|
||||
["exception_type"] = exception.GetType().FullName,
|
||||
});
|
||||
|
||||
return (int)WorkerExitCode.UnexpectedFailure;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user