Files
mxaccessgw/src/MxGateway.Worker/Sta/StaCommandDispatcher.cs
T
2026-04-26 17:49:01 -04:00

268 lines
7.0 KiB
C#

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;
namespace MxGateway.Worker.Sta;
/// <summary>
/// Serializes execution of <see cref="StaCommand"/> instances. Commands are queued in arrival
/// order and executed one at a time through <see cref="StaRuntime.InvokeAsync"/>; every outcome
/// (success, rejection, or executor exception) is reported to the caller as an
/// <see cref="MxCommandReply"/>.
/// </summary>
public sealed class StaCommandDispatcher
{
    private readonly HResultConverter hresultConverter;
    private readonly IStaCommandExecutor commandExecutor;
    private readonly Queue<QueuedStaCommand> commandQueue = new();
    private readonly StaRuntime staRuntime;

    // Guards commandQueue, drainActive, shutdownRequested and currentCommandCorrelationId.
    private readonly object gate = new();

    // True while a DrainAsync loop owns the queue; ensures at most one loop runs at a time.
    private bool drainActive;
    private bool shutdownRequested;
    private string currentCommandCorrelationId = string.Empty;

    /// <summary>
    /// Creates a dispatcher that uses a default-constructed <see cref="HResultConverter"/>.
    /// </summary>
    /// <param name="staRuntime">Runtime used to invoke the executor.</param>
    /// <param name="commandExecutor">Executor that performs each command.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public StaCommandDispatcher(
        StaRuntime staRuntime,
        IStaCommandExecutor commandExecutor)
        : this(staRuntime, commandExecutor, new HResultConverter())
    {
    }

    /// <summary>
    /// Creates a dispatcher with an explicit exception-to-reply converter.
    /// </summary>
    /// <param name="staRuntime">Runtime used to invoke the executor.</param>
    /// <param name="commandExecutor">Executor that performs each command.</param>
    /// <param name="hresultConverter">Converter used to turn executor exceptions into replies.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public StaCommandDispatcher(
        StaRuntime staRuntime,
        IStaCommandExecutor commandExecutor,
        HResultConverter hresultConverter)
    {
        this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime));
        this.commandExecutor = commandExecutor ?? throw new ArgumentNullException(nameof(commandExecutor));
        this.hresultConverter = hresultConverter ?? throw new ArgumentNullException(nameof(hresultConverter));
    }

    /// <summary>
    /// Gets the number of commands still waiting in the queue. The command currently being
    /// executed has already been dequeued and is not counted.
    /// </summary>
    public int PendingCommandCount
    {
        get
        {
            lock (gate)
            {
                return commandQueue.Count;
            }
        }
    }

    /// <summary>
    /// Gets the correlation id of the command currently being executed, or an empty string
    /// when no command is executing.
    /// </summary>
    public string CurrentCommandCorrelationId
    {
        get
        {
            lock (gate)
            {
                return currentCommandCorrelationId;
            }
        }
    }

    /// <summary>
    /// Queues a command for execution and returns a task that completes with its reply.
    /// </summary>
    /// <param name="command">The command to execute.</param>
    /// <returns>
    /// A task producing the reply. If shutdown has been requested the task is already completed
    /// with a <see cref="ProtocolStatusCode.WorkerUnavailable"/> rejection and the command is
    /// not queued.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="command"/> is <c>null</c>.</exception>
    public Task<MxCommandReply> DispatchAsync(StaCommand command)
    {
        if (command is null)
        {
            throw new ArgumentNullException(nameof(command));
        }

        QueuedStaCommand queuedCommand;
        bool startDrain;
        lock (gate)
        {
            if (shutdownRequested)
            {
                return Task.FromResult(CreateRejectedReply(
                    command,
                    ProtocolStatusCode.WorkerUnavailable,
                    "The STA command dispatcher is shutting down."));
            }

            queuedCommand = new QueuedStaCommand(command);
            commandQueue.Enqueue(queuedCommand);

            // Claim the drain role under the lock so only one caller starts the loop.
            startDrain = !drainActive;
            drainActive = true;
        }

        if (startDrain)
        {
            // Started only after the lock is released: DrainAsync runs synchronously up to its
            // first incomplete await, and that prefix calls into the runtime/executor. Running it
            // under `gate` would hold the lock across foreign code and block heartbeat readers.
            _ = DrainAsync();
        }

        return queuedCommand.Task;
    }

    /// <summary>
    /// Marks the dispatcher as shutting down. Subsequent <see cref="DispatchAsync"/> calls are
    /// rejected; commands that are already queued are not removed by this call.
    /// </summary>
    public void RequestShutdown()
    {
        lock (gate)
        {
            shutdownRequested = true;
        }
    }

    /// <summary>
    /// Copies the pending-command count and current correlation id into a heartbeat message.
    /// Both values are read under the same lock so they form a consistent snapshot.
    /// </summary>
    /// <param name="heartbeat">The heartbeat to populate.</param>
    /// <exception cref="ArgumentNullException"><paramref name="heartbeat"/> is <c>null</c>.</exception>
    public void PopulateHeartbeat(WorkerHeartbeat heartbeat)
    {
        if (heartbeat is null)
        {
            throw new ArgumentNullException(nameof(heartbeat));
        }

        lock (gate)
        {
            heartbeat.PendingCommandCount = (uint)commandQueue.Count;
            heartbeat.CurrentCommandCorrelationId = currentCommandCorrelationId;
        }
    }

    /// <summary>
    /// Executes queued commands one at a time until the queue is empty, then releases the
    /// drain role so the next <see cref="DispatchAsync"/> call starts a new loop.
    /// </summary>
    private async Task DrainAsync()
    {
        while (true)
        {
            QueuedStaCommand queuedCommand;
            lock (gate)
            {
                if (commandQueue.Count == 0)
                {
                    // Cleared in the same critical section as the emptiness check, so a command
                    // enqueued afterwards always observes drainActive == false and restarts the loop.
                    drainActive = false;
                    return;
                }

                queuedCommand = commandQueue.Dequeue();
            }

            try
            {
                await ExecuteQueuedCommandAsync(queuedCommand).ConfigureAwait(false);
            }
            catch (Exception exception)
            {
                // Last-resort guard: reached only if building the reply itself failed (for
                // example, exception conversion threw). Without it this loop would fault with
                // drainActive stuck at true, stranding this command's task and every later one.
                queuedCommand.Fail(exception);
            }
        }
    }

    /// <summary>
    /// Runs a single queued command through the runtime and completes its task with the
    /// normalized reply, a cancellation rejection, or a converted exception reply.
    /// </summary>
    private async Task ExecuteQueuedCommandAsync(QueuedStaCommand queuedCommand)
    {
        StaCommand command = queuedCommand.Command;

        // Cancellation is honored only before execution starts; the token is not passed on.
        if (command.CancellationToken.IsCancellationRequested)
        {
            queuedCommand.Complete(CreateRejectedReply(
                command,
                ProtocolStatusCode.Canceled,
                "The STA command was canceled before execution."));
            return;
        }

        SetCurrentCommand(command.CorrelationId);
        try
        {
            MxCommandReply reply = await staRuntime
                .InvokeAsync(() => commandExecutor.Execute(command))
                .ConfigureAwait(false);
            queuedCommand.Complete(NormalizeReply(command, reply));
        }
        catch (Exception exception)
        {
            queuedCommand.Complete(CreateExceptionReply(command, exception));
        }
        finally
        {
            SetCurrentCommand(string.Empty);
        }
    }

    /// <summary>Records which command is currently executing (empty string for none).</summary>
    private void SetCurrentCommand(string correlationId)
    {
        lock (gate)
        {
            currentCommandCorrelationId = correlationId;
        }
    }

    /// <summary>
    /// Fills in reply fields the executor left unset (session id, correlation id, kind,
    /// protocol status) from the originating command. The reply is modified in place.
    /// A <c>null</c> reply is turned into a <see cref="ProtocolStatusCode.ProtocolViolation"/> rejection.
    /// </summary>
    private MxCommandReply NormalizeReply(
        StaCommand command,
        MxCommandReply reply)
    {
        if (reply is null)
        {
            return CreateRejectedReply(
                command,
                ProtocolStatusCode.ProtocolViolation,
                "STA command executor returned null.");
        }

        if (string.IsNullOrWhiteSpace(reply.SessionId))
        {
            reply.SessionId = command.SessionId;
        }

        if (string.IsNullOrWhiteSpace(reply.CorrelationId))
        {
            reply.CorrelationId = command.CorrelationId;
        }

        if (reply.Kind == MxCommandKind.Unspecified)
        {
            reply.Kind = command.Kind;
        }

        if (reply.ProtocolStatus is null)
        {
            reply.ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.Ok,
                Message = "OK",
            };
        }

        return reply;
    }

    /// <summary>
    /// Builds a reply for a command whose execution threw, using the configured
    /// <see cref="HResultConverter"/> to derive status, HRESULT and diagnostic text.
    /// </summary>
    private MxCommandReply CreateExceptionReply(
        StaCommand command,
        Exception exception)
    {
        HResultConversion conversion = hresultConverter.Convert(exception);
        MxCommandReply reply = CreateBaseReply(command);
        reply.ProtocolStatus = conversion.ProtocolStatus;
        reply.Hresult = conversion.HResult;
        reply.DiagnosticMessage = conversion.DiagnosticMessage;
        return reply;
    }

    /// <summary>
    /// Builds a reply that rejects the command with the given status; the message is used
    /// both as the status message and as the diagnostic message.
    /// </summary>
    private static MxCommandReply CreateRejectedReply(
        StaCommand command,
        ProtocolStatusCode statusCode,
        string message)
    {
        MxCommandReply reply = CreateBaseReply(command);
        reply.ProtocolStatus = new ProtocolStatus
        {
            Code = statusCode,
            Message = message,
        };
        reply.DiagnosticMessage = message;
        return reply;
    }

    /// <summary>Creates a reply pre-populated with the command's session id, correlation id and kind.</summary>
    private static MxCommandReply CreateBaseReply(StaCommand command)
    {
        return new MxCommandReply
        {
            SessionId = command.SessionId,
            CorrelationId = command.CorrelationId,
            Kind = command.Kind,
        };
    }

    /// <summary>
    /// Pairs a queued command with the completion source backing the task returned to the caller.
    /// </summary>
    private sealed class QueuedStaCommand
    {
        // Continuations run asynchronously so caller code never executes inline on the
        // thread that completes the command.
        private readonly TaskCompletionSource<MxCommandReply> completion = new(
            TaskCreationOptions.RunContinuationsAsynchronously);

        public QueuedStaCommand(StaCommand command)
        {
            Command = command;
        }

        /// <summary>Gets the command to execute.</summary>
        public StaCommand Command { get; }

        /// <summary>Gets the task observed by the dispatching caller.</summary>
        public Task<MxCommandReply> Task => completion.Task;

        /// <summary>Completes the caller's task with a reply; ignored if already completed.</summary>
        public void Complete(MxCommandReply reply)
        {
            completion.TrySetResult(reply);
        }

        /// <summary>
        /// Faults the caller's task when no reply could be produced; ignored if already completed.
        /// </summary>
        public void Fail(Exception exception)
        {
            completion.TrySetException(exception);
        }
    }
}