Compare commits

...

5 Commits

Author SHA1 Message Date
Joseph Doherty 0d96963c99 Merge remote-tracking branch 'origin/main' into agent-1/issue-15-implement-dashboard-snapshot-service 2026-04-26 17:52:58 -04:00
dohertj2 dede407304 Merge pull request #72 from agent-2/issue-25-implement-sta-command-dispatcher
Issue #25: implement STA command dispatcher
2026-04-26 17:53:32 -04:00
Joseph Doherty 3661420f0a Issue #15: implement dashboard snapshot service 2026-04-26 17:49:59 -04:00
Joseph Doherty 14419853c7 Issue #25: implement sta command dispatcher 2026-04-26 17:49:01 -04:00
dohertj2 a20517f5ad Merge pull request #71 from agent-3/issue-13-implement-public-grpc-service
Issue #13: implement public grpc service
2026-04-26 17:48:35 -04:00
16 changed files with 1203 additions and 0 deletions
+2
View File
@@ -107,6 +107,8 @@ worker, correlation, command, and client identity fields with redaction applied
before values enter log state. `GatewayMetrics` exposes counters, gauges, and
histograms through .NET `Meter` and a snapshot API that dashboard services can
project without binding to a metrics exporter.
`DashboardSnapshotService` projects sessions, workers, metrics, faults, and
effective configuration into immutable DTOs for read-only dashboard rendering.
### Worker Process
@@ -0,0 +1,9 @@
namespace MxGateway.Server.Dashboard;
public sealed record DashboardFaultSummary(
string Source,
string? SessionId,
int? WorkerProcessId,
string State,
string Message,
DateTimeOffset ObservedAt);
@@ -0,0 +1,6 @@
namespace MxGateway.Server.Dashboard;
public sealed record DashboardMetricSummary(
string Name,
long Value,
string? Dimension = null);
@@ -0,0 +1,34 @@
using MxGateway.Server.Diagnostics;
namespace MxGateway.Server.Dashboard;
internal static class DashboardRedactor
{
private static readonly string[] SensitiveTextMarkers =
[
"apikey",
"api_key",
"authorization",
"credential",
"password",
"secret",
"token",
];
public static string? Redact(string? value)
{
if (string.IsNullOrWhiteSpace(value))
{
return value;
}
if (value.Contains("mxgw_", StringComparison.OrdinalIgnoreCase))
{
return GatewayLogRedactor.RedactClientIdentity(value);
}
return SensitiveTextMarkers.Any(marker => value.Contains(marker, StringComparison.OrdinalIgnoreCase))
? GatewayLogRedactor.RedactedValue
: value;
}
}
@@ -0,0 +1,11 @@
namespace MxGateway.Server.Dashboard;
public static class DashboardServiceCollectionExtensions
{
public static IServiceCollection AddGatewayDashboard(this IServiceCollection services)
{
services.AddSingleton<IDashboardSnapshotService, DashboardSnapshotService>();
return services;
}
}
@@ -0,0 +1,19 @@
using MxGateway.Contracts.Proto;
using MxGateway.Server.Workers;
namespace MxGateway.Server.Dashboard;
public sealed record DashboardSessionSummary(
string SessionId,
string BackendName,
SessionState State,
string? ClientIdentity,
string? ClientSessionName,
string? ClientCorrelationId,
DateTimeOffset OpenedAt,
DateTimeOffset LastClientActivityAt,
DateTimeOffset? LeaseExpiresAt,
int? WorkerProcessId,
WorkerClientState? WorkerState,
DateTimeOffset? LastWorkerHeartbeatAt,
string? LastFault);
@@ -0,0 +1,15 @@
using MxGateway.Server.Configuration;
namespace MxGateway.Server.Dashboard;
public sealed record DashboardSnapshot(
DateTimeOffset GeneratedAt,
DateTimeOffset GatewayStartedAt,
TimeSpan GatewayUptime,
string GatewayStatus,
string GatewayVersion,
IReadOnlyList<DashboardSessionSummary> Sessions,
IReadOnlyList<DashboardWorkerSummary> Workers,
IReadOnlyList<DashboardMetricSummary> Metrics,
IReadOnlyList<DashboardFaultSummary> Faults,
EffectiveGatewayConfiguration Configuration);
@@ -0,0 +1,196 @@
using System.Runtime.CompilerServices;
using Microsoft.Extensions.Options;
using MxGateway.Server.Configuration;
using MxGateway.Server.Metrics;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;
namespace MxGateway.Server.Dashboard;
public sealed class DashboardSnapshotService : IDashboardSnapshotService
{
private const string HealthyStatus = "Healthy";
private readonly ISessionRegistry _sessionRegistry;
private readonly GatewayMetrics _metrics;
private readonly IGatewayConfigurationProvider _configurationProvider;
private readonly TimeProvider _timeProvider;
private readonly DateTimeOffset _gatewayStartedAt;
private readonly TimeSpan _snapshotInterval;
private readonly int _recentFaultLimit;
private readonly int _recentSessionLimit;
public DashboardSnapshotService(
ISessionRegistry sessionRegistry,
GatewayMetrics metrics,
IGatewayConfigurationProvider configurationProvider,
IOptions<GatewayOptions> options,
TimeProvider? timeProvider = null)
{
_sessionRegistry = sessionRegistry ?? throw new ArgumentNullException(nameof(sessionRegistry));
_metrics = metrics ?? throw new ArgumentNullException(nameof(metrics));
_configurationProvider = configurationProvider ?? throw new ArgumentNullException(nameof(configurationProvider));
ArgumentNullException.ThrowIfNull(options);
_timeProvider = timeProvider ?? TimeProvider.System;
_gatewayStartedAt = _timeProvider.GetUtcNow();
_snapshotInterval = TimeSpan.FromMilliseconds(options.Value.Dashboard.SnapshotIntervalMilliseconds);
_recentFaultLimit = options.Value.Dashboard.RecentFaultLimit;
_recentSessionLimit = options.Value.Dashboard.RecentSessionLimit;
}
public DashboardSnapshot GetSnapshot()
{
DateTimeOffset generatedAt = _timeProvider.GetUtcNow();
IReadOnlyList<GatewaySession> sessions = _sessionRegistry.Snapshot()
.OrderByDescending(session => session.OpenedAt)
.ToArray();
IReadOnlyList<DashboardSessionSummary> sessionSummaries = sessions
.Take(ResolveLimit(_recentSessionLimit))
.Select(CreateSessionSummary)
.ToArray();
IReadOnlyList<DashboardWorkerSummary> workerSummaries = sessions
.Where(session => session.WorkerClient is not null)
.Select(CreateWorkerSummary)
.ToArray();
GatewayMetricsSnapshot metricsSnapshot = _metrics.GetSnapshot();
return new DashboardSnapshot(
GeneratedAt: generatedAt,
GatewayStartedAt: _gatewayStartedAt,
GatewayUptime: generatedAt - _gatewayStartedAt,
GatewayStatus: HealthyStatus,
GatewayVersion: typeof(DashboardSnapshotService).Assembly.GetName().Version?.ToString() ?? "unknown",
Sessions: sessionSummaries,
Workers: workerSummaries,
Metrics: CreateMetricSummaries(metricsSnapshot),
Faults: CreateFaultSummaries(sessions, generatedAt),
Configuration: _configurationProvider.GetEffectiveConfiguration());
}
public async IAsyncEnumerable<DashboardSnapshot> WatchSnapshotsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
yield break;
}
yield return GetSnapshot();
using PeriodicTimer timer = new(_snapshotInterval, _timeProvider);
while (true)
{
bool hasNext;
try
{
hasNext = await timer.WaitForNextTickAsync(cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
yield break;
}
if (!hasNext)
{
yield break;
}
yield return GetSnapshot();
}
}
private static DashboardSessionSummary CreateSessionSummary(GatewaySession session)
{
IWorkerClient? workerClient = session.WorkerClient;
return new DashboardSessionSummary(
SessionId: session.SessionId,
BackendName: session.BackendName,
State: session.State,
ClientIdentity: DashboardRedactor.Redact(session.ClientIdentity),
ClientSessionName: DashboardRedactor.Redact(session.ClientSessionName),
ClientCorrelationId: DashboardRedactor.Redact(session.ClientCorrelationId),
OpenedAt: session.OpenedAt,
LastClientActivityAt: session.LastClientActivityAt,
LeaseExpiresAt: session.LeaseExpiresAt,
WorkerProcessId: workerClient?.ProcessId,
WorkerState: workerClient?.State,
LastWorkerHeartbeatAt: workerClient?.LastHeartbeatAt,
LastFault: DashboardRedactor.Redact(session.FinalFault));
}
private static DashboardWorkerSummary CreateWorkerSummary(GatewaySession session)
{
IWorkerClient workerClient = session.WorkerClient!;
return new DashboardWorkerSummary(
SessionId: session.SessionId,
ProcessId: workerClient.ProcessId,
State: workerClient.State,
LastHeartbeatAt: workerClient.LastHeartbeatAt,
LastFault: DashboardRedactor.Redact(session.FinalFault));
}
private static IReadOnlyList<DashboardMetricSummary> CreateMetricSummaries(GatewayMetricsSnapshot snapshot)
{
List<DashboardMetricSummary> metrics =
[
new("mxgateway.sessions.open", snapshot.OpenSessions),
new("mxgateway.workers.running", snapshot.WorkersRunning),
new("mxgateway.events.queue.depth", snapshot.EventQueueDepth),
new("mxgateway.sessions.opened", snapshot.SessionsOpened),
new("mxgateway.sessions.closed", snapshot.SessionsClosed),
new("mxgateway.commands.started", snapshot.CommandsStarted),
new("mxgateway.commands.succeeded", snapshot.CommandsSucceeded),
new("mxgateway.commands.failed", snapshot.CommandsFailed),
new("mxgateway.events.received", snapshot.EventsReceived),
new("mxgateway.queues.overflows", snapshot.QueueOverflows),
new("mxgateway.faults", snapshot.Faults),
new("mxgateway.workers.killed", snapshot.WorkerKills),
new("mxgateway.workers.exited", snapshot.WorkerExits),
new("mxgateway.heartbeats.failed", snapshot.HeartbeatFailures),
new("mxgateway.grpc.streams.disconnected", snapshot.StreamDisconnects),
];
metrics.AddRange(snapshot.CommandFailuresByMethod
.OrderBy(entry => entry.Key, StringComparer.OrdinalIgnoreCase)
.Select(entry => new DashboardMetricSummary("mxgateway.commands.failed", entry.Value, entry.Key)));
metrics.AddRange(snapshot.EventsByFamily
.OrderBy(entry => entry.Key, StringComparer.OrdinalIgnoreCase)
.Select(entry => new DashboardMetricSummary("mxgateway.events.received", entry.Value, entry.Key)));
return metrics;
}
private IReadOnlyList<DashboardFaultSummary> CreateFaultSummaries(
IReadOnlyList<GatewaySession> sessions,
DateTimeOffset generatedAt)
{
return sessions
.Where(HasFault)
.Take(ResolveLimit(_recentFaultLimit))
.Select(session => new DashboardFaultSummary(
Source: session.WorkerClient?.State == WorkerClientState.Faulted ? "Worker" : "Session",
SessionId: session.SessionId,
WorkerProcessId: session.WorkerProcessId,
State: session.WorkerClient?.State == WorkerClientState.Faulted
? WorkerClientState.Faulted.ToString()
: session.State.ToString(),
Message: DashboardRedactor.Redact(session.FinalFault) ?? "Faulted",
ObservedAt: generatedAt))
.ToArray();
}
private static bool HasFault(GatewaySession session)
{
return session.State == MxGateway.Contracts.Proto.SessionState.Faulted
|| session.WorkerClient?.State == WorkerClientState.Faulted
|| !string.IsNullOrWhiteSpace(session.FinalFault);
}
private static int ResolveLimit(int configuredLimit)
{
return configuredLimit < 0 ? 0 : configuredLimit;
}
}
@@ -0,0 +1,10 @@
using MxGateway.Server.Workers;
namespace MxGateway.Server.Dashboard;
public sealed record DashboardWorkerSummary(
string SessionId,
int? ProcessId,
WorkerClientState State,
DateTimeOffset LastHeartbeatAt,
string? LastFault);
@@ -0,0 +1,8 @@
namespace MxGateway.Server.Dashboard;
public interface IDashboardSnapshotService
{
DashboardSnapshot GetSnapshot();
IAsyncEnumerable<DashboardSnapshot> WatchSnapshotsAsync(CancellationToken cancellationToken);
}
@@ -1,5 +1,6 @@
using MxGateway.Contracts;
using MxGateway.Server.Configuration;
using MxGateway.Server.Dashboard;
using MxGateway.Server.Diagnostics;
using MxGateway.Server.Grpc;
using MxGateway.Server.Metrics;
@@ -36,6 +37,7 @@ public static class GatewayApplication
builder.Services.AddSingleton<MxAccessGrpcRequestValidator>();
builder.Services.AddWorkerProcessLauncher();
builder.Services.AddGatewaySessions();
builder.Services.AddGatewayDashboard();
return builder;
}
@@ -0,0 +1,290 @@
using Microsoft.Extensions.Options;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Dashboard;
using MxGateway.Server.Metrics;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Dashboard;
public sealed class DashboardSnapshotServiceTests
{
[Fact]
public void GetSnapshot_WhenRegistryEmpty_ReturnsEmptyOperationalState()
{
using GatewayMetrics metrics = new();
DashboardSnapshotService service = CreateService(new SessionRegistry(), metrics);
DashboardSnapshot snapshot = service.GetSnapshot();
Assert.Empty(snapshot.Sessions);
Assert.Empty(snapshot.Workers);
Assert.Empty(snapshot.Faults);
Assert.Contains(snapshot.Metrics, metric => metric.Name == "mxgateway.sessions.open" && metric.Value == 0);
Assert.Equal("Healthy", snapshot.GatewayStatus);
Assert.NotNull(snapshot.Configuration);
}
[Fact]
public void GetSnapshot_ProjectsActiveAndFaultedSessionsWorkersMetricsAndFaults()
{
SessionRegistry registry = new();
GatewaySession activeSession = CreateSession(
"session-active",
"client-one",
DateTimeOffset.Parse("2026-04-26T10:00:00Z"));
activeSession.AttachWorkerClient(new FakeWorkerClient("session-active", 1201, WorkerClientState.Ready));
activeSession.MarkReady();
GatewaySession faultedSession = CreateSession(
"session-faulted",
"client-two",
DateTimeOffset.Parse("2026-04-26T10:01:00Z"));
faultedSession.AttachWorkerClient(new FakeWorkerClient("session-faulted", 1202, WorkerClientState.Faulted));
faultedSession.MarkFaulted("worker pipe disconnected");
registry.TryAdd(activeSession);
registry.TryAdd(faultedSession);
using GatewayMetrics metrics = new();
metrics.SessionOpened();
metrics.SessionOpened();
metrics.CommandStarted("Register");
metrics.CommandFailed("Register", "WorkerFaulted", TimeSpan.FromMilliseconds(7));
metrics.EventReceived("session-active", "OnDataChange");
metrics.Fault("WorkerFaulted");
DashboardSnapshotService service = CreateService(registry, metrics);
DashboardSnapshot snapshot = service.GetSnapshot();
Assert.Equal(2, snapshot.Sessions.Count);
Assert.Equal("session-faulted", snapshot.Sessions[0].SessionId);
Assert.Equal(SessionState.Faulted, snapshot.Sessions[0].State);
Assert.Equal(2, snapshot.Workers.Count);
Assert.Contains(snapshot.Metrics, metric => metric.Name == "mxgateway.commands.started" && metric.Value == 1);
Assert.Contains(
snapshot.Metrics,
metric => metric.Name == "mxgateway.events.received"
&& metric.Dimension == "OnDataChange"
&& metric.Value == 1);
DashboardFaultSummary fault = Assert.Single(snapshot.Faults);
Assert.Equal("Worker", fault.Source);
Assert.Equal("session-faulted", fault.SessionId);
Assert.Equal("worker pipe disconnected", fault.Message);
}
[Fact]
public void GetSnapshot_RedactsSecretsFromSessionAndFaultFields()
{
SessionRegistry registry = new();
GatewaySession session = CreateSession(
"session-redacted",
"Bearer mxgw_admin_super-secret",
DateTimeOffset.Parse("2026-04-26T10:00:00Z"),
clientSessionName: "password=hunter2",
clientCorrelationId: "token=abc123");
session.MarkFaulted("secret=credential-value");
registry.TryAdd(session);
using GatewayMetrics metrics = new();
DashboardSnapshotService service = CreateService(registry, metrics);
DashboardSnapshot snapshot = service.GetSnapshot();
DashboardSessionSummary summary = Assert.Single(snapshot.Sessions);
Assert.Equal("Bearer mxgw_admin_[redacted]", summary.ClientIdentity);
Assert.Equal("[redacted]", summary.ClientSessionName);
Assert.Equal("[redacted]", summary.ClientCorrelationId);
Assert.Equal("[redacted]", summary.LastFault);
Assert.Equal("[redacted]", Assert.Single(snapshot.Faults).Message);
Assert.Equal("[redacted]", snapshot.Configuration.Authentication.PepperSecretName);
}
[Fact]
public void GetSnapshot_DoesNotMutateSessionOrWorkerState()
{
SessionRegistry registry = new();
GatewaySession session = CreateSession(
"session-active",
"client-one",
DateTimeOffset.Parse("2026-04-26T10:00:00Z"));
FakeWorkerClient workerClient = new("session-active", 1201, WorkerClientState.Ready);
session.AttachWorkerClient(workerClient);
session.MarkReady();
registry.TryAdd(session);
using GatewayMetrics metrics = new();
DashboardSnapshotService service = CreateService(registry, metrics);
service.GetSnapshot();
service.GetSnapshot();
Assert.Equal(1, registry.ActiveCount);
Assert.Equal(SessionState.Ready, session.State);
Assert.Equal(WorkerClientState.Ready, workerClient.State);
Assert.Equal(0, workerClient.StartCount);
Assert.Equal(0, workerClient.ShutdownCount);
Assert.Equal(0, workerClient.KillCount);
}
[Fact]
public void GetSnapshot_AppliesRecentSessionAndFaultLimits()
{
SessionRegistry registry = new();
GatewaySession olderSession = CreateSession(
"session-older",
"client-one",
DateTimeOffset.Parse("2026-04-26T10:00:00Z"));
GatewaySession newerSession = CreateSession(
"session-newer",
"client-two",
DateTimeOffset.Parse("2026-04-26T10:01:00Z"));
olderSession.MarkFaulted("older fault");
newerSession.MarkFaulted("newer fault");
registry.TryAdd(olderSession);
registry.TryAdd(newerSession);
using GatewayMetrics metrics = new();
DashboardSnapshotService service = CreateService(
registry,
metrics,
new GatewayOptions
{
Dashboard = new DashboardOptions
{
SnapshotIntervalMilliseconds = 1,
RecentSessionLimit = 1,
RecentFaultLimit = 1,
},
});
DashboardSnapshot snapshot = service.GetSnapshot();
Assert.Equal("session-newer", Assert.Single(snapshot.Sessions).SessionId);
Assert.Equal("session-newer", Assert.Single(snapshot.Faults).SessionId);
}
[Fact]
public async Task WatchSnapshotsAsync_WhenSubscriberCancels_DisposesCleanly()
{
using GatewayMetrics metrics = new();
DashboardSnapshotService service = CreateService(
new SessionRegistry(),
metrics,
new GatewayOptions
{
Dashboard = new DashboardOptions
{
SnapshotIntervalMilliseconds = 1,
},
});
using CancellationTokenSource cancellation = new();
await using IAsyncEnumerator<DashboardSnapshot> enumerator = service
.WatchSnapshotsAsync(cancellation.Token)
.GetAsyncEnumerator();
Assert.True(await enumerator.MoveNextAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(1)));
await cancellation.CancelAsync();
bool hasNext = await enumerator.MoveNextAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(1));
Assert.False(hasNext);
}
private static DashboardSnapshotService CreateService(
SessionRegistry registry,
GatewayMetrics metrics,
GatewayOptions? options = null)
{
GatewayOptions resolvedOptions = options ?? new GatewayOptions
{
Dashboard = new DashboardOptions
{
SnapshotIntervalMilliseconds = 1,
},
};
GatewayConfigurationProvider configurationProvider = new(Options.Create(resolvedOptions));
return new DashboardSnapshotService(
registry,
metrics,
configurationProvider,
Options.Create(resolvedOptions));
}
private static GatewaySession CreateSession(
string sessionId,
string? clientIdentity,
DateTimeOffset openedAt,
string? clientSessionName = "test-session",
string? clientCorrelationId = "client-correlation")
{
return new GatewaySession(
sessionId,
"mxaccess",
$"mxaccess-gateway-1-{sessionId}",
"nonce",
clientIdentity,
clientSessionName,
clientCorrelationId,
TimeSpan.FromSeconds(30),
TimeSpan.FromSeconds(5),
TimeSpan.FromSeconds(5),
openedAt);
}
private sealed class FakeWorkerClient(
string sessionId,
int? processId,
WorkerClientState state) : IWorkerClient
{
public string SessionId { get; } = sessionId;
public int? ProcessId { get; } = processId;
public WorkerClientState State { get; private set; } = state;
public DateTimeOffset LastHeartbeatAt { get; } = DateTimeOffset.Parse("2026-04-26T10:02:00Z");
public int StartCount { get; private set; }
public int ShutdownCount { get; private set; }
public int KillCount { get; private set; }
public Task StartAsync(CancellationToken cancellationToken)
{
StartCount++;
return Task.CompletedTask;
}
public Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
return Task.FromResult(new WorkerCommandReply());
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
{
await Task.CompletedTask;
yield break;
}
public Task ShutdownAsync(
TimeSpan timeout,
CancellationToken cancellationToken)
{
ShutdownCount++;
State = WorkerClientState.Closed;
return Task.CompletedTask;
}
public void Kill(string reason)
{
KillCount++;
State = WorkerClientState.Faulted;
}
public ValueTask DisposeAsync()
{
return ValueTask.CompletedTask;
}
}
}
@@ -0,0 +1,279 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Sta;
namespace MxGateway.Worker.Tests.Sta;
public sealed class StaCommandDispatcherTests
{
[Fact]
public async Task DispatchAsync_ExecutesCommandsOnStaInQueueOrder()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
RecordingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> first = dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Task<MxCommandReply> second = dispatcher.DispatchAsync(CreateCommand("correlation-2", MxCommandKind.AddItem));
MxCommandReply[] replies = await Task.WhenAll(first, second);
Assert.Equal(new[] { "correlation-1", "correlation-2" }, executor.CorrelationIds);
Assert.All(executor.ThreadIds, threadId => Assert.Equal(runtime.StaThreadId, threadId));
Assert.Equal("correlation-1", replies[0].CorrelationId);
Assert.Equal("correlation-2", replies[1].CorrelationId);
Assert.Equal(ProtocolStatusCode.Ok, replies[0].ProtocolStatus.Code);
}
[Fact]
public async Task DispatchAsync_WhenExecutorThrows_ReturnsFailureReplyWithHResult()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
StaCommandDispatcher dispatcher = new(
runtime,
new ThrowingCommandExecutor(new COMException("provider detail", unchecked((int)0x80070057))));
MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Assert.Equal("session-1", reply.SessionId);
Assert.Equal("correlation-1", reply.CorrelationId);
Assert.Equal(MxCommandKind.Register, reply.Kind);
Assert.Equal(ProtocolStatusCode.MxaccessFailure, reply.ProtocolStatus.Code);
Assert.Equal(unchecked((int)0x80070057), reply.Hresult);
Assert.Contains("0x80070057", reply.DiagnosticMessage);
Assert.DoesNotContain("provider detail", reply.DiagnosticMessage);
}
[Fact]
public async Task DispatchAsync_WhenCanceledBeforeExecution_ReturnsCanceledReplyWithoutExecuting()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> blocked = dispatcher.DispatchAsync(CreateCommand("blocked", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
using CancellationTokenSource cancellation = new();
Task<MxCommandReply> canceled = dispatcher.DispatchAsync(
CreateCommand("canceled", MxCommandKind.AddItem, cancellation.Token));
cancellation.Cancel();
executor.Release();
MxCommandReply canceledReply = await canceled;
await blocked;
Assert.Equal(ProtocolStatusCode.Canceled, canceledReply.ProtocolStatus.Code);
Assert.DoesNotContain("canceled", executor.CorrelationIds);
}
[Fact]
public async Task DispatchAsync_WhenCanceledAfterExecutionStarts_StillReturnsLateReply()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
using CancellationTokenSource cancellation = new();
Task<MxCommandReply> replyTask = dispatcher.DispatchAsync(
CreateCommand("late-reply", MxCommandKind.Register, cancellation.Token));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
cancellation.Cancel();
executor.Release();
MxCommandReply reply = await replyTask;
Assert.Equal(ProtocolStatusCode.Ok, reply.ProtocolStatus.Code);
Assert.Contains("late-reply", executor.CorrelationIds);
}
[Fact]
public async Task DispatchAsync_WhenShutdownRequested_RejectsNewCommands()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
StaCommandDispatcher dispatcher = new(runtime, new RecordingCommandExecutor());
dispatcher.RequestShutdown();
MxCommandReply reply = await dispatcher.DispatchAsync(CreateCommand("correlation-1", MxCommandKind.Register));
Assert.Equal(ProtocolStatusCode.WorkerUnavailable, reply.ProtocolStatus.Code);
Assert.Equal("correlation-1", reply.CorrelationId);
}
[Fact]
public async Task PopulateHeartbeat_ReportsCurrentCorrelationAndPendingCount()
{
using StaRuntime runtime = CreateRuntime();
runtime.Start();
BlockingCommandExecutor executor = new();
StaCommandDispatcher dispatcher = new(runtime, executor);
Task<MxCommandReply> current = dispatcher.DispatchAsync(CreateCommand("current", MxCommandKind.Register));
Assert.True(executor.Started.Wait(TimeSpan.FromSeconds(2)));
Task<MxCommandReply> pending = dispatcher.DispatchAsync(CreateCommand("pending", MxCommandKind.AddItem));
WorkerHeartbeat heartbeat = new();
dispatcher.PopulateHeartbeat(heartbeat);
Assert.Equal("current", heartbeat.CurrentCommandCorrelationId);
Assert.Equal(1u, heartbeat.PendingCommandCount);
executor.Release();
await Task.WhenAll(current, pending);
}
private static StaCommand CreateCommand(
string correlationId,
MxCommandKind kind,
CancellationToken cancellationToken = default)
{
return new StaCommand(
"session-1",
correlationId,
new MxCommand
{
Kind = kind,
Ping = new PingCommand
{
Message = correlationId,
},
},
cancellationToken: cancellationToken);
}
private static StaRuntime CreateRuntime()
{
return new StaRuntime(
new NoopComApartmentInitializer(),
new StaMessagePump(),
TimeSpan.FromMilliseconds(25));
}
private sealed class RecordingCommandExecutor : IStaCommandExecutor
{
private readonly object gate = new();
private readonly List<string> correlationIds = new();
private readonly List<int> threadIds = new();
public IReadOnlyList<string> CorrelationIds
{
get
{
lock (gate)
{
return correlationIds.ToArray();
}
}
}
public IReadOnlyList<int> ThreadIds
{
get
{
lock (gate)
{
return threadIds.ToArray();
}
}
}
public MxCommandReply Execute(StaCommand command)
{
lock (gate)
{
correlationIds.Add(command.CorrelationId);
threadIds.Add(Thread.CurrentThread.ManagedThreadId);
}
return new MxCommandReply
{
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
}
private sealed class BlockingCommandExecutor : IStaCommandExecutor
{
private readonly ManualResetEventSlim release = new(false);
private readonly object gate = new();
private readonly List<string> correlationIds = new();
public ManualResetEventSlim Started { get; } = new(false);
public IReadOnlyList<string> CorrelationIds
{
get
{
lock (gate)
{
return correlationIds.ToArray();
}
}
}
public MxCommandReply Execute(StaCommand command)
{
lock (gate)
{
correlationIds.Add(command.CorrelationId);
}
Started.Set();
release.Wait(TimeSpan.FromSeconds(5));
return new MxCommandReply
{
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
},
};
}
public void Release()
{
release.Set();
}
}
private sealed class ThrowingCommandExecutor : IStaCommandExecutor
{
private readonly Exception exception;
public ThrowingCommandExecutor(Exception exception)
{
this.exception = exception;
}
public MxCommandReply Execute(StaCommand command)
{
throw exception;
}
}
private sealed class NoopComApartmentInitializer : IStaComApartmentInitializer
{
public void Initialize()
{
}
public void Uninitialize()
{
}
}
}
@@ -0,0 +1,8 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Sta;
public interface IStaCommandExecutor
{
MxCommandReply Execute(StaCommand command);
}
+47
View File
@@ -0,0 +1,47 @@
using System;
using System.Threading;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Sta;
public sealed class StaCommand
{
public StaCommand(
string sessionId,
string correlationId,
MxCommand command,
Timestamp? enqueueTimestamp = null,
CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("STA command requires a session id.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(correlationId))
{
throw new ArgumentException("STA command requires a correlation id.", nameof(correlationId));
}
SessionId = sessionId;
CorrelationId = correlationId;
Command = command ?? throw new ArgumentNullException(nameof(command));
EnqueueTimestamp = enqueueTimestamp ?? Timestamp.FromDateTime(DateTime.UtcNow);
CancellationToken = cancellationToken;
}
public string SessionId { get; }
public string CorrelationId { get; }
public MxCommand Command { get; }
public Timestamp EnqueueTimestamp { get; }
public CancellationToken CancellationToken { get; }
public MxCommandKind Kind => Command.Kind;
public string MethodName => Kind.ToString();
}
@@ -0,0 +1,267 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
namespace MxGateway.Worker.Sta;
public sealed class StaCommandDispatcher
{
private readonly HResultConverter hresultConverter;
private readonly IStaCommandExecutor commandExecutor;
private readonly Queue<QueuedStaCommand> commandQueue = new();
private readonly StaRuntime staRuntime;
private readonly object gate = new();
private bool drainActive;
private bool shutdownRequested;
private string currentCommandCorrelationId = string.Empty;
public StaCommandDispatcher(
StaRuntime staRuntime,
IStaCommandExecutor commandExecutor)
: this(staRuntime, commandExecutor, new HResultConverter())
{
}
public StaCommandDispatcher(
StaRuntime staRuntime,
IStaCommandExecutor commandExecutor,
HResultConverter hresultConverter)
{
this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime));
this.commandExecutor = commandExecutor ?? throw new ArgumentNullException(nameof(commandExecutor));
this.hresultConverter = hresultConverter ?? throw new ArgumentNullException(nameof(hresultConverter));
}
public int PendingCommandCount
{
get
{
lock (gate)
{
return commandQueue.Count;
}
}
}
public string CurrentCommandCorrelationId
{
get
{
lock (gate)
{
return currentCommandCorrelationId;
}
}
}
public Task<MxCommandReply> DispatchAsync(StaCommand command)
{
if (command is null)
{
throw new ArgumentNullException(nameof(command));
}
lock (gate)
{
if (shutdownRequested)
{
return Task.FromResult(CreateRejectedReply(
command,
ProtocolStatusCode.WorkerUnavailable,
"The STA command dispatcher is shutting down."));
}
QueuedStaCommand queuedCommand = new(command);
commandQueue.Enqueue(queuedCommand);
if (!drainActive)
{
drainActive = true;
_ = DrainAsync();
}
return queuedCommand.Task;
}
}
public void RequestShutdown()
{
lock (gate)
{
shutdownRequested = true;
}
}
public void PopulateHeartbeat(WorkerHeartbeat heartbeat)
{
if (heartbeat is null)
{
throw new ArgumentNullException(nameof(heartbeat));
}
lock (gate)
{
heartbeat.PendingCommandCount = (uint)commandQueue.Count;
heartbeat.CurrentCommandCorrelationId = currentCommandCorrelationId;
}
}
private async Task DrainAsync()
{
while (true)
{
QueuedStaCommand queuedCommand;
lock (gate)
{
if (commandQueue.Count == 0)
{
drainActive = false;
return;
}
queuedCommand = commandQueue.Dequeue();
}
await ExecuteQueuedCommandAsync(queuedCommand).ConfigureAwait(false);
}
}
private async Task ExecuteQueuedCommandAsync(QueuedStaCommand queuedCommand)
{
StaCommand command = queuedCommand.Command;
if (command.CancellationToken.IsCancellationRequested)
{
queuedCommand.Complete(CreateRejectedReply(
command,
ProtocolStatusCode.Canceled,
"The STA command was canceled before execution."));
return;
}
SetCurrentCommand(command.CorrelationId);
try
{
MxCommandReply reply = await staRuntime
.InvokeAsync(() => commandExecutor.Execute(command))
.ConfigureAwait(false);
queuedCommand.Complete(NormalizeReply(command, reply));
}
catch (Exception exception)
{
queuedCommand.Complete(CreateExceptionReply(command, exception));
}
finally
{
SetCurrentCommand(string.Empty);
}
}
private void SetCurrentCommand(string correlationId)
{
lock (gate)
{
currentCommandCorrelationId = correlationId;
}
}
private MxCommandReply NormalizeReply(
StaCommand command,
MxCommandReply reply)
{
if (reply is null)
{
return CreateRejectedReply(
command,
ProtocolStatusCode.ProtocolViolation,
"STA command executor returned null.");
}
if (string.IsNullOrWhiteSpace(reply.SessionId))
{
reply.SessionId = command.SessionId;
}
if (string.IsNullOrWhiteSpace(reply.CorrelationId))
{
reply.CorrelationId = command.CorrelationId;
}
if (reply.Kind == MxCommandKind.Unspecified)
{
reply.Kind = command.Kind;
}
if (reply.ProtocolStatus is null)
{
reply.ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.Ok,
Message = "OK",
};
}
return reply;
}
private MxCommandReply CreateExceptionReply(
StaCommand command,
Exception exception)
{
HResultConversion conversion = hresultConverter.Convert(exception);
MxCommandReply reply = CreateBaseReply(command);
reply.ProtocolStatus = conversion.ProtocolStatus;
reply.Hresult = conversion.HResult;
reply.DiagnosticMessage = conversion.DiagnosticMessage;
return reply;
}
private static MxCommandReply CreateRejectedReply(
StaCommand command,
ProtocolStatusCode statusCode,
string message)
{
MxCommandReply reply = CreateBaseReply(command);
reply.ProtocolStatus = new ProtocolStatus
{
Code = statusCode,
Message = message,
};
reply.DiagnosticMessage = message;
return reply;
}
private static MxCommandReply CreateBaseReply(StaCommand command)
{
return new MxCommandReply
{
SessionId = command.SessionId,
CorrelationId = command.CorrelationId,
Kind = command.Kind,
};
}
private sealed class QueuedStaCommand
{
private readonly TaskCompletionSource<MxCommandReply> completion = new(
TaskCreationOptions.RunContinuationsAsynchronously);
public QueuedStaCommand(StaCommand command)
{
Command = command;
}
public StaCommand Command { get; }
public Task<MxCommandReply> Task => completion.Task;
public void Complete(MxCommandReply reply)
{
completion.TrySetResult(reply);
}
}
}