using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Conversion;

namespace MxGateway.Worker.Sta;

/// <summary>
/// Serializes <see cref="StaCommand"/> instances through a FIFO queue and executes them one at a
/// time by handing each to <c>StaRuntime.InvokeAsync</c>.
/// </summary>
/// <remarks>
/// At most one drain loop runs at a time (tracked by <c>drainActive</c>), and the loop awaits each
/// command before dequeuing the next, so commands never execute concurrently. Exceptions thrown by
/// the executor are converted into replies via <see cref="HResultConverter"/> instead of faulting
/// the caller's task.
/// </remarks>
public sealed class StaCommandDispatcher
{
    private readonly HResultConverter hresultConverter;
    private readonly IStaCommandExecutor commandExecutor;
    private readonly Queue<QueuedStaCommand> commandQueue = new();
    private readonly StaRuntime staRuntime;

    // Guards commandQueue, drainActive, shutdownRequested and currentCommandCorrelationId.
    private readonly object gate = new();

    // True while a DrainAsync loop owns the queue; ensures only one loop runs at a time.
    private bool drainActive;

    // Once set, DispatchAsync rejects new commands; already-queued commands still drain.
    private bool shutdownRequested;

    // Correlation id of the command currently executing on the STA runtime, or empty when idle.
    private string currentCommandCorrelationId = string.Empty;

    /// <summary>Creates a dispatcher that uses a default <see cref="HResultConverter"/>.</summary>
    /// <param name="staRuntime">Runtime on which commands are invoked.</param>
    /// <param name="commandExecutor">Executes a single command and produces its reply.</param>
    /// <exception cref="ArgumentNullException">A required argument is null.</exception>
    public StaCommandDispatcher(
        StaRuntime staRuntime,
        IStaCommandExecutor commandExecutor)
        : this(staRuntime, commandExecutor, new HResultConverter())
    {
    }

    /// <summary>Creates a dispatcher with an explicit exception-to-reply converter.</summary>
    /// <param name="staRuntime">Runtime on which commands are invoked.</param>
    /// <param name="commandExecutor">Executes a single command and produces its reply.</param>
    /// <param name="hresultConverter">Converts executor exceptions into reply status fields.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public StaCommandDispatcher(
        StaRuntime staRuntime,
        IStaCommandExecutor commandExecutor,
        HResultConverter hresultConverter)
    {
        this.staRuntime = staRuntime ?? throw new ArgumentNullException(nameof(staRuntime));
        this.commandExecutor = commandExecutor ?? throw new ArgumentNullException(nameof(commandExecutor));
        this.hresultConverter = hresultConverter ?? throw new ArgumentNullException(nameof(hresultConverter));
    }

    /// <summary>
    /// Number of commands waiting in the queue. The command currently executing has already been
    /// dequeued and is not counted.
    /// </summary>
    public int PendingCommandCount
    {
        get
        {
            lock (gate)
            {
                return commandQueue.Count;
            }
        }
    }

    /// <summary>
    /// Correlation id of the command currently executing on the STA runtime, or
    /// <see cref="string.Empty"/> when no command is executing.
    /// </summary>
    public string CurrentCommandCorrelationId
    {
        get
        {
            lock (gate)
            {
                return currentCommandCorrelationId;
            }
        }
    }

    /// <summary>Queues <paramref name="command"/> for execution on the STA runtime.</summary>
    /// <param name="command">The command to execute.</param>
    /// <returns>
    /// A task that completes with the command's reply. If shutdown has been requested, the command
    /// is not queued and an already-completed task carrying a <c>WorkerUnavailable</c> reply is
    /// returned. A command whose cancellation token is signaled before execution starts completes
    /// with a <c>Canceled</c> reply without being executed. If building the reply itself throws,
    /// the task is faulted with that exception.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="command"/> is null.</exception>
    public Task<MxCommandReply> DispatchAsync(StaCommand command)
    {
        if (command is null)
        {
            throw new ArgumentNullException(nameof(command));
        }

        QueuedStaCommand queuedCommand;
        bool startDrain;

        lock (gate)
        {
            if (shutdownRequested)
            {
                return Task.FromResult(CreateRejectedReply(
                    command,
                    ProtocolStatusCode.WorkerUnavailable,
                    "The STA command dispatcher is shutting down."));
            }

            queuedCommand = new QueuedStaCommand(command);
            commandQueue.Enqueue(queuedCommand);

            // Claim drain ownership under the lock so exactly one caller starts the loop.
            startDrain = !drainActive;
            drainActive = true;
        }

        // Start the loop only after releasing the lock. An async method runs synchronously up to
        // its first incomplete await. Starting it under the lock would run the dequeue, canceled
        // commands, and any inline-completing runtime work on this thread while holding gate.
        if (startDrain)
        {
            _ = DrainAsync();
        }

        return queuedCommand.Task;
    }

    /// <summary>
    /// Stops accepting new commands. Commands already in the queue are still executed.
    /// </summary>
    public void RequestShutdown()
    {
        lock (gate)
        {
            shutdownRequested = true;
        }
    }

    /// <summary>Copies the queue depth and the current command's correlation id into a heartbeat.</summary>
    /// <param name="heartbeat">Heartbeat message to populate.</param>
    /// <exception cref="ArgumentNullException"><paramref name="heartbeat"/> is null.</exception>
    public void PopulateHeartbeat(WorkerHeartbeat heartbeat)
    {
        if (heartbeat is null)
        {
            throw new ArgumentNullException(nameof(heartbeat));
        }

        // Read both values under one lock so the heartbeat is a consistent snapshot.
        lock (gate)
        {
            heartbeat.PendingCommandCount = (uint)commandQueue.Count;
            heartbeat.CurrentCommandCorrelationId = currentCommandCorrelationId;
        }
    }

    // Executes queued commands in FIFO order until the queue is empty, then releases drain ownership.
    private async Task DrainAsync()
    {
        while (true)
        {
            QueuedStaCommand queuedCommand;
            lock (gate)
            {
                if (commandQueue.Count == 0)
                {
                    // Clear ownership in the same critical section as the emptiness check, so a
                    // concurrent DispatchAsync either sees this loop active or starts a new one.
                    drainActive = false;
                    return;
                }

                queuedCommand = commandQueue.Dequeue();
            }

            try
            {
                await ExecuteQueuedCommandAsync(queuedCommand).ConfigureAwait(false);
            }
            catch (Exception exception)
            {
                // Building the reply itself failed (e.g. the converter threw). Surface the failure
                // on the caller's task and keep draining. Letting it escape would fault this
                // discarded task, leave drainActive stuck at true, and stall the queue forever.
                queuedCommand.Fail(exception);
            }
        }
    }

    // Runs one command on the STA runtime and completes its task with a normalized reply.
    private async Task ExecuteQueuedCommandAsync(QueuedStaCommand queuedCommand)
    {
        StaCommand command = queuedCommand.Command;

        if (command.CancellationToken.IsCancellationRequested)
        {
            queuedCommand.Complete(CreateRejectedReply(
                command,
                ProtocolStatusCode.Canceled,
                "The STA command was canceled before execution."));
            return;
        }

        SetCurrentCommand(command.CorrelationId);
        try
        {
            MxCommandReply reply = await staRuntime
                .InvokeAsync(() => commandExecutor.Execute(command))
                .ConfigureAwait(false);
            queuedCommand.Complete(NormalizeReply(command, reply));
        }
        catch (Exception exception)
        {
            // Executor failures are reported as a reply, not as a faulted task.
            queuedCommand.Complete(CreateExceptionReply(command, exception));
        }
        finally
        {
            SetCurrentCommand(string.Empty);
        }
    }

    private void SetCurrentCommand(string correlationId)
    {
        lock (gate)
        {
            currentCommandCorrelationId = correlationId;
        }
    }

    // Fills identity fields the executor left blank from the command, and defaults a missing
    // status to OK. A null reply becomes a ProtocolViolation rejection.
    private static MxCommandReply NormalizeReply(
        StaCommand command,
        MxCommandReply reply)
    {
        if (reply is null)
        {
            return CreateRejectedReply(
                command,
                ProtocolStatusCode.ProtocolViolation,
                "STA command executor returned null.");
        }

        if (string.IsNullOrWhiteSpace(reply.SessionId))
        {
            reply.SessionId = command.SessionId;
        }

        if (string.IsNullOrWhiteSpace(reply.CorrelationId))
        {
            reply.CorrelationId = command.CorrelationId;
        }

        if (reply.Kind == MxCommandKind.Unspecified)
        {
            reply.Kind = command.Kind;
        }

        if (reply.ProtocolStatus is null)
        {
            reply.ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.Ok,
                Message = "OK",
            };
        }

        return reply;
    }

    // Translates an executor exception into a reply carrying the converter's status, HRESULT
    // and diagnostic message.
    private MxCommandReply CreateExceptionReply(
        StaCommand command,
        Exception exception)
    {
        HResultConversion conversion = hresultConverter.Convert(exception);
        MxCommandReply reply = CreateBaseReply(command);
        reply.ProtocolStatus = conversion.ProtocolStatus;
        reply.Hresult = conversion.HResult;
        reply.DiagnosticMessage = conversion.DiagnosticMessage;
        return reply;
    }

    // Builds a reply with the given status; the message is also copied into DiagnosticMessage.
    private static MxCommandReply CreateRejectedReply(
        StaCommand command,
        ProtocolStatusCode statusCode,
        string message)
    {
        MxCommandReply reply = CreateBaseReply(command);
        reply.ProtocolStatus = new ProtocolStatus
        {
            Code = statusCode,
            Message = message,
        };
        reply.DiagnosticMessage = message;
        return reply;
    }

    // Copies the command's identity (session, correlation id, kind) into a new reply.
    private static MxCommandReply CreateBaseReply(StaCommand command)
    {
        return new MxCommandReply
        {
            SessionId = command.SessionId,
            CorrelationId = command.CorrelationId,
            Kind = command.Kind,
        };
    }

    // Pairs a command with the completion source backing the task returned to its caller.
    private sealed class QueuedStaCommand
    {
        // RunContinuationsAsynchronously keeps caller continuations off the drain loop's thread.
        private readonly TaskCompletionSource<MxCommandReply> completion = new(
            TaskCreationOptions.RunContinuationsAsynchronously);

        public QueuedStaCommand(StaCommand command)
        {
            Command = command;
        }

        public StaCommand Command { get; }

        public Task<MxCommandReply> Task => completion.Task;

        // Try* so a late second completion attempt is a harmless no-op.
        public void Complete(MxCommandReply reply)
        {
            completion.TrySetResult(reply);
        }

        public void Fail(Exception exception)
        {
            completion.TrySetException(exception);
        }
    }
}