using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;

namespace MxGateway.Worker.Ipc;

/// <summary>
/// Worker side of the gateway/worker pipe protocol. Performs the startup handshake
/// (GatewayHello -> WorkerHello -> WorkerReady), then reads gateway envelopes and dispatches
/// them to the runtime session while emitting periodic heartbeats and STA watchdog faults.
/// </summary>
public sealed class WorkerPipeSession
{
    private readonly WorkerFrameProtocolOptions _options;
    private readonly Func<int> _processIdProvider;
    private readonly Func<IWorkerRuntimeSession> _runtimeSessionFactory;
    private readonly WorkerPipeSessionOptions _sessionOptions;
    private readonly WorkerFrameReader _reader;
    private readonly WorkerFrameWriter _writer;
    private IWorkerRuntimeSession? _runtimeSession;
    private long _nextSequence;
    private WorkerState _state = WorkerState.Starting;

    // Set once a StaHung fault has been reported for the current stale episode; cleared when
    // STA activity is back within the grace window.
    private bool _watchdogFaultSent;

    /// <summary>
    /// Creates a session that reads and writes frames on <paramref name="stream"/> and reports the
    /// current process id.
    /// </summary>
    /// <param name="stream">Duplex stream shared by the frame reader and writer.</param>
    /// <param name="options">Protocol options (version, nonce, session id).</param>
    public WorkerPipeSession(
        Stream stream,
        WorkerFrameProtocolOptions options)
        : this(
            new WorkerFrameReader(stream, options),
            new WorkerFrameWriter(stream, options),
            options,
            CreateCurrentProcessIdProvider())
    {
    }

    /// <summary>
    /// Creates a session with default <see cref="WorkerPipeSessionOptions"/> and an
    /// <see cref="MxAccessStaSession"/> runtime session factory.
    /// </summary>
    public WorkerPipeSession(
        WorkerFrameReader reader,
        WorkerFrameWriter writer,
        WorkerFrameProtocolOptions options,
        Func<int> processIdProvider)
        : this(
            reader,
            writer,
            options,
            processIdProvider,
            new WorkerPipeSessionOptions(),
            () => new MxAccessStaSession())
    {
    }

    /// <summary>
    /// Creates a fully configured session.
    /// </summary>
    /// <param name="reader">Reads gateway envelopes.</param>
    /// <param name="writer">Writes worker envelopes.</param>
    /// <param name="options">Protocol options (version, nonce, session id).</param>
    /// <param name="processIdProvider">Supplies the worker process id reported in hello/ready/heartbeat.</param>
    /// <param name="sessionOptions">Heartbeat settings; validated here via <c>Validate()</c>.</param>
    /// <param name="runtimeSessionFactory">Creates the runtime session that executes commands.</param>
    /// <exception cref="ArgumentNullException">Any argument is <see langword="null"/>.</exception>
    public WorkerPipeSession(
        WorkerFrameReader reader,
        WorkerFrameWriter writer,
        WorkerFrameProtocolOptions options,
        Func<int> processIdProvider,
        WorkerPipeSessionOptions sessionOptions,
        Func<IWorkerRuntimeSession> runtimeSessionFactory)
    {
        _reader = reader ?? throw new ArgumentNullException(nameof(reader));
        _writer = writer ?? throw new ArgumentNullException(nameof(writer));
        _options = options ?? throw new ArgumentNullException(nameof(options));
        _processIdProvider = processIdProvider ?? throw new ArgumentNullException(nameof(processIdProvider));
        _sessionOptions = sessionOptions ?? throw new ArgumentNullException(nameof(sessionOptions));
        _runtimeSessionFactory = runtimeSessionFactory ?? throw new ArgumentNullException(nameof(runtimeSessionFactory));
        _sessionOptions.Validate();
    }

    /// <summary>
    /// Runs the whole worker lifecycle: creates the runtime session, completes the startup
    /// handshake, then processes gateway envelopes until shutdown or cancellation. The runtime
    /// session is always disposed and the state set to <see cref="WorkerState.Stopped"/> on exit.
    /// </summary>
    public async Task RunAsync(CancellationToken cancellationToken = default)
    {
        IWorkerRuntimeSession runtimeSession = _runtimeSessionFactory();
        _runtimeSession = runtimeSession;
        try
        {
            await CompleteStartupHandshakeAsync(
                token => runtimeSession.StartAsync(_options.SessionId, _processIdProvider(), token),
                cancellationToken).ConfigureAwait(false);
            await RunMessageLoopAsync(cancellationToken).ConfigureAwait(false);
        }
        finally
        {
            _runtimeSession?.Dispose();
            _runtimeSession = null;
            _state = WorkerState.Stopped;
        }
    }

    /// <summary>
    /// Completes the startup handshake, creating and starting the runtime session through the
    /// configured factory.
    /// </summary>
    public Task CompleteStartupHandshakeAsync(CancellationToken cancellationToken = default)
    {
        return CompleteStartupHandshakeAsync(InitializeMxAccessAsync, cancellationToken);
    }

    /// <summary>
    /// Completes the startup handshake using a caller-supplied initializer. The WorkerReady
    /// frame is built by this session after the initializer finishes.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="initializeMxAccessAsync"/> is null.</exception>
    public async Task CompleteStartupHandshakeAsync(
        Func<CancellationToken, Task> initializeMxAccessAsync,
        CancellationToken cancellationToken = default)
    {
        if (initializeMxAccessAsync is null)
        {
            throw new ArgumentNullException(nameof(initializeMxAccessAsync));
        }

        await CompleteStartupHandshakeAsync(
            async innerCancellationToken =>
            {
                await initializeMxAccessAsync(innerCancellationToken).ConfigureAwait(false);
                return CreateWorkerReady();
            },
            cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Core startup handshake: reads and validates GatewayHello, writes WorkerHello, runs the
    /// initializer, then writes the WorkerReady it produced.
    /// </summary>
    /// <remarks>
    /// Protocol failures are reported as a protocol fault. Any other non-cancellation failure is
    /// wrapped via <c>MxAccessCreationException.From</c> and reported as an MXAccess creation fault.
    /// In both cases the original exception is rethrown. Fault writes are best effort.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="initializeMxAccessAsync"/> is null.</exception>
    public async Task CompleteStartupHandshakeAsync(
        Func<CancellationToken, Task<WorkerReady>> initializeMxAccessAsync,
        CancellationToken cancellationToken = default)
    {
        if (initializeMxAccessAsync is null)
        {
            throw new ArgumentNullException(nameof(initializeMxAccessAsync));
        }

        try
        {
            WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
            _state = WorkerState.Handshaking;
            ValidateGatewayHello(envelope);
            await WriteWorkerHelloAsync(cancellationToken).ConfigureAwait(false);

            _state = WorkerState.InitializingSta;
            WorkerReady ready = await initializeMxAccessAsync(cancellationToken).ConfigureAwait(false);
            await WriteWorkerReadyAsync(ready, cancellationToken).ConfigureAwait(false);
            _state = WorkerState.Ready;
        }
        catch (WorkerFrameProtocolException exception)
        {
            await TryWriteFaultAsync(exception, cancellationToken).ConfigureAwait(false);
            throw;
        }
        catch (Exception exception) when (exception is not OperationCanceledException)
        {
            await TryWriteFaultAsync(MxAccessCreationException.From(exception), cancellationToken)
                .ConfigureAwait(false);
            throw;
        }
    }

    // Returns a provider for the current process id. The id is read once and the Process
    // instance is disposed, instead of allocating a new Process on every call.
    private static Func<int> CreateCurrentProcessIdProvider()
    {
        int processId;
        using (Process currentProcess = Process.GetCurrentProcess())
        {
            processId = currentProcess.Id;
        }

        return () => processId;
    }

    // Ensures the first frame is a GatewayHello with a matching protocol version and an
    // ordinal-equal nonce.
    private void ValidateGatewayHello(WorkerEnvelope envelope)
    {
        if (envelope.BodyCase != WorkerEnvelope.BodyOneofCase.GatewayHello)
        {
            throw new WorkerFrameProtocolException(
                WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody,
                "Worker expected GatewayHello during startup handshake.");
        }

        GatewayHello gatewayHello = envelope.GatewayHello;
        if (gatewayHello.SupportedProtocolVersion != _options.ProtocolVersion)
        {
            throw new WorkerFrameProtocolException(
                WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
                $"GatewayHello supported protocol version {gatewayHello.SupportedProtocolVersion} does not match expected version {_options.ProtocolVersion}.");
        }

        if (!string.Equals(gatewayHello.Nonce, _options.Nonce, StringComparison.Ordinal))
        {
            throw new WorkerFrameProtocolException(
                WorkerFrameProtocolErrorCode.NonceMismatch,
                "GatewayHello nonce does not match the worker launch nonce.");
        }
    }

    private Task WriteWorkerHelloAsync(CancellationToken cancellationToken)
    {
        return _writer.WriteAsync(
            CreateEnvelope(new WorkerHello
            {
                ProtocolVersion = _options.ProtocolVersion,
                Nonce = _options.Nonce,
                WorkerProcessId = _processIdProvider(),
                WorkerVersion = typeof(WorkerPipeSession).Assembly.GetName().Version?.ToString() ?? string.Empty,
            }),
            cancellationToken);
    }

    private Task WriteWorkerReadyAsync(
        WorkerReady ready,
        CancellationToken cancellationToken)
    {
        return _writer.WriteAsync(CreateEnvelope(ready), cancellationToken);
    }

    /// <summary>
    /// Reads gateway envelopes and dispatches them until shutdown or cancellation, while a
    /// heartbeat loop runs alongside. Protocol violations (including end of stream) are reported
    /// as a best-effort fault before being rethrown, matching the handshake path.
    /// </summary>
    private async Task RunMessageLoopAsync(CancellationToken cancellationToken)
    {
        using CancellationTokenSource heartbeatCancellation =
            CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        Task heartbeatTask = RunHeartbeatLoopAsync(heartbeatCancellation.Token);
        try
        {
            while (!cancellationToken.IsCancellationRequested)
            {
                Task<WorkerEnvelope> readTask = _reader.ReadAsync(cancellationToken);
                Task completedTask = await Task.WhenAny(readTask, heartbeatTask).ConfigureAwait(false);
                if (completedTask == heartbeatTask)
                {
                    // If the heartbeat loop ended first, surface its failure or cancellation.
                    // A normal completion falls through to the pending read.
                    await heartbeatTask.ConfigureAwait(false);
                }

                WorkerEnvelope envelope = await readTask.ConfigureAwait(false);
                bool keepReading = await DispatchGatewayEnvelopeAsync(envelope, cancellationToken)
                    .ConfigureAwait(false);
                if (!keepReading)
                {
                    return;
                }
            }
        }
        catch (WorkerFrameProtocolException exception)
        {
            await TryWriteFaultAsync(exception, cancellationToken).ConfigureAwait(false);
            throw;
        }
        finally
        {
            heartbeatCancellation.Cancel();
            try
            {
                await heartbeatTask.ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Expected: the heartbeat loop was cancelled above.
            }
        }
    }

    /// <summary>
    /// Routes one gateway envelope. Returns <see langword="true"/> to keep reading and
    /// <see langword="false"/> after a shutdown was acknowledged.
    /// </summary>
    /// <exception cref="WorkerFrameProtocolException">The envelope body is not expected from the gateway.</exception>
    private async Task<bool> DispatchGatewayEnvelopeAsync(
        WorkerEnvelope envelope,
        CancellationToken cancellationToken)
    {
        switch (envelope.BodyCase)
        {
            case WorkerEnvelope.BodyOneofCase.WorkerCommand:
                // Not awaited so the read loop keeps servicing the pipe while the command runs.
                _ = ProcessCommandAsync(envelope, cancellationToken);
                return true;
            case WorkerEnvelope.BodyOneofCase.WorkerShutdown:
                await ShutdownAsync(envelope.WorkerShutdown, cancellationToken).ConfigureAwait(false);
                return false;
            case WorkerEnvelope.BodyOneofCase.WorkerCancel:
                // Cancel requests are currently accepted and ignored.
                return true;
            default:
                throw new WorkerFrameProtocolException(
                    WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody,
                    $"Worker received unexpected gateway envelope body {envelope.BodyCase}.");
        }
    }

    /// <summary>
    /// Executes one gateway command on the runtime session and writes its reply. A
    /// non-cancellation failure marks the worker <see cref="WorkerState.Faulted"/> and is reported
    /// as an MxaccessCommandFailed fault.
    /// </summary>
    /// <remarks>
    /// This runs without being awaited. Exceptions the catch does not handle stay on the discarded
    /// task: cancellation, or the missing-session error raised before the try block.
    /// NOTE(review): this may write concurrently with the heartbeat loop and other commands;
    /// it presumably relies on WorkerFrameWriter serializing writes. Confirm.
    /// </remarks>
    private async Task ProcessCommandAsync(
        WorkerEnvelope envelope,
        CancellationToken cancellationToken)
    {
        IWorkerRuntimeSession runtimeSession = _runtimeSession
            ?? throw new InvalidOperationException("Worker runtime session has not been initialized.");
        WorkerCommand workerCommand = envelope.WorkerCommand;
        MxCommand command = workerCommand.Command;
        StaCommand staCommand = new(
            _options.SessionId,
            envelope.CorrelationId,
            command,
            workerCommand.EnqueueTimestamp,
            cancellationToken);
        try
        {
            MxCommandReply reply = await runtimeSession.DispatchAsync(staCommand).ConfigureAwait(false);
            await _writer
                .WriteAsync(
                    CreateEnvelope(new WorkerCommandReply
                    {
                        Reply = reply,
                        CompletedTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
                    }),
                    cancellationToken)
                .ConfigureAwait(false);
        }
        catch (Exception exception) when (exception is not OperationCanceledException)
        {
            _state = WorkerState.Faulted;
            await TryWriteFaultAsync(
                CreateFault(
                    WorkerFaultCategory.MxaccessCommandFailed,
                    staCommand.MethodName,
                    exception),
                cancellationToken).ConfigureAwait(false);
        }
    }

    // Marks the worker as shutting down, asks the runtime session to stop, and acknowledges
    // the request, echoing the gateway-supplied reason when present.
    private async Task ShutdownAsync(
        WorkerShutdown shutdown,
        CancellationToken cancellationToken)
    {
        _state = WorkerState.ShuttingDown;
        _runtimeSession?.RequestShutdown();
        await _writer
            .WriteAsync(
                CreateEnvelope(
                    new WorkerShutdownAck
                    {
                        Status = new ProtocolStatus
                        {
                            Code = ProtocolStatusCode.Ok,
                            Message = string.IsNullOrWhiteSpace(shutdown.Reason)
                                ? "Worker shutdown accepted."
                                : $"Worker shutdown accepted: {shutdown.Reason}",
                        },
                    }),
                cancellationToken)
            .ConfigureAwait(false);
    }

    // Every HeartbeatInterval: snapshot the runtime session, write a heartbeat, then run the
    // STA watchdog check. Ends by throwing OperationCanceledException from Task.Delay on cancel.
    private async Task RunHeartbeatLoopAsync(CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            await Task.Delay(_sessionOptions.HeartbeatInterval, cancellationToken).ConfigureAwait(false);
            IWorkerRuntimeSession? runtimeSession = _runtimeSession;
            if (runtimeSession is null)
            {
                continue;
            }

            WorkerRuntimeHeartbeatSnapshot snapshot = runtimeSession.CaptureHeartbeat();
            await _writer
                .WriteAsync(CreateEnvelope(CreateHeartbeat(snapshot)), cancellationToken)
                .ConfigureAwait(false);
            await ReportWatchdogFaultIfNeededAsync(snapshot, cancellationToken).ConfigureAwait(false);
        }
    }

    // Emits at most one StaHung fault per stale episode: a fault is sent the first time STA
    // activity is older than HeartbeatGrace, and re-armed once activity is within the grace window.
    private async Task ReportWatchdogFaultIfNeededAsync(
        WorkerRuntimeHeartbeatSnapshot snapshot,
        CancellationToken cancellationToken)
    {
        TimeSpan staleFor = DateTimeOffset.UtcNow - snapshot.LastStaActivityUtc;
        if (staleFor <= _sessionOptions.HeartbeatGrace)
        {
            _watchdogFaultSent = false;
            return;
        }

        if (_watchdogFaultSent)
        {
            return;
        }

        _watchdogFaultSent = true;

        // NOTE(review): the correlation id is passed in the commandMethod slot, so it ends up in
        // WorkerFault.CommandMethod. Confirm this is intended, or whether a correlation-id field exists.
        await TryWriteFaultAsync(
            CreateFault(
                WorkerFaultCategory.StaHung,
                snapshot.CurrentCommandCorrelationId,
                $"STA activity is stale by {staleFor}."),
            cancellationToken).ConfigureAwait(false);
    }

    // Best-effort report of a protocol failure; the original exception stays the actionable error.
    private Task TryWriteFaultAsync(
        WorkerFrameProtocolException exception,
        CancellationToken cancellationToken)
    {
        return TryWriteFaultAsync(CreateFault(exception), cancellationToken);
    }

    // Best-effort report of an MXAccess creation failure; that failure stays the actionable error.
    private Task TryWriteFaultAsync(
        MxAccessCreationException exception,
        CancellationToken cancellationToken)
    {
        return TryWriteFaultAsync(CreateFault(exception), cancellationToken);
    }

    /// <summary>
    /// Writes <paramref name="fault"/> on the pipe, swallowing only I/O, disposal, and frame-protocol
    /// failures of the write itself. A broken pipe must not mask the error being reported, which
    /// remains observable through the rethrown exception, worker exit, or pipe closure.
    /// </summary>
    private async Task TryWriteFaultAsync(
        WorkerFault fault,
        CancellationToken cancellationToken)
    {
        try
        {
            await _writer
                .WriteAsync(CreateEnvelope(fault), cancellationToken)
                .ConfigureAwait(false);
        }
        catch (Exception faultWriteException) when (IsFaultWriteFailure(faultWriteException))
        {
            // Intentionally ignored; see summary.
        }
    }

    private static bool IsFaultWriteFailure(Exception exception)
    {
        return exception is IOException or ObjectDisposedException or WorkerFrameProtocolException;
    }

    private WorkerEnvelope CreateEnvelope(WorkerHello hello)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerHello = hello;
        return envelope;
    }

    private WorkerEnvelope CreateEnvelope(WorkerReady ready)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerReady = ready;
        return envelope;
    }

    private WorkerEnvelope CreateEnvelope(WorkerFault fault)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerFault = fault;
        return envelope;
    }

    // Reply envelopes carry the reply's correlation id (empty when the reply is absent).
    private WorkerEnvelope CreateEnvelope(WorkerCommandReply reply)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.CorrelationId = reply.Reply?.CorrelationId ?? string.Empty;
        envelope.WorkerCommandReply = reply;
        return envelope;
    }

    private WorkerEnvelope CreateEnvelope(WorkerShutdownAck shutdownAck)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerShutdownAck = shutdownAck;
        return envelope;
    }

    private WorkerEnvelope CreateEnvelope(WorkerHeartbeat heartbeat)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerHeartbeat = heartbeat;
        return envelope;
    }

    // Common envelope header; every outbound frame consumes the next sequence number.
    private WorkerEnvelope CreateBaseEnvelope()
    {
        return new WorkerEnvelope
        {
            ProtocolVersion = _options.ProtocolVersion,
            SessionId = _options.SessionId,
            Sequence = NextSequence(),
        };
    }

    // Thread-safe, monotonically increasing sequence starting at 1.
    private ulong NextSequence()
    {
        return unchecked((ulong)Interlocked.Increment(ref _nextSequence));
    }

    /// <summary>
    /// Default initializer for <see cref="CompleteStartupHandshakeAsync(CancellationToken)"/>.
    /// It creates the runtime session through the injected factory (so it agrees with
    /// <see cref="RunAsync"/>) and starts it. On failure the session is disposed and cleared
    /// before the exception propagates.
    /// </summary>
    private async Task<WorkerReady> InitializeMxAccessAsync(CancellationToken cancellationToken)
    {
        IWorkerRuntimeSession runtimeSession = _runtimeSessionFactory();
        _runtimeSession = runtimeSession;
        try
        {
            return await runtimeSession
                .StartAsync(_options.SessionId, _processIdProvider(), cancellationToken)
                .ConfigureAwait(false);
        }
        catch
        {
            runtimeSession.Dispose();
            _runtimeSession = null;
            throw;
        }
    }

    // Reports ExecutingCommand while a command correlation id is active, otherwise the
    // session's current state.
    private WorkerHeartbeat CreateHeartbeat(WorkerRuntimeHeartbeatSnapshot snapshot)
    {
        WorkerState state = string.IsNullOrWhiteSpace(snapshot.CurrentCommandCorrelationId)
            ? _state
            : WorkerState.ExecutingCommand;
        return new WorkerHeartbeat
        {
            WorkerProcessId = _processIdProvider(),
            State = state,
            LastStaActivityTimestamp = Timestamp.FromDateTimeOffset(snapshot.LastStaActivityUtc),
            PendingCommandCount = snapshot.PendingCommandCount,
            OutboundEventQueueDepth = snapshot.OutboundEventQueueDepth,
            LastEventSequence = snapshot.LastEventSequence,
            CurrentCommandCorrelationId = snapshot.CurrentCommandCorrelationId,
        };
    }

    private WorkerReady CreateWorkerReady()
    {
        return new WorkerReady
        {
            WorkerProcessId = _processIdProvider(),
            MxaccessProgid = MxAccessInteropInfo.ProgId,
            MxaccessClsid = MxAccessInteropInfo.Clsid,
            ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
        };
    }

    private static WorkerFault CreateFault(WorkerFrameProtocolException exception)
    {
        return new WorkerFault
        {
            Category = MapFaultCategory(exception.ErrorCode),
            ExceptionType = exception.GetType().FullName ?? string.Empty,
            DiagnosticMessage = exception.Message,
            ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.ProtocolViolation,
                Message = exception.Message,
            },
        };
    }

    // Reports the inner exception's type when present, plus the HRESULT when one can be extracted.
    private static WorkerFault CreateFault(MxAccessCreationException exception)
    {
        WorkerFault fault = new()
        {
            Category = WorkerFaultCategory.MxaccessCreationFailed,
            ExceptionType = exception.InnerException?.GetType().FullName
                ?? exception.GetType().FullName
                ?? string.Empty,
            DiagnosticMessage = exception.Message,
            ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.WorkerUnavailable,
                Message = exception.Message,
            },
        };

        int? hresult = MxAccessCreationException.ExtractHResult(exception);
        if (hresult.HasValue)
        {
            fault.Hresult = hresult.Value;
        }

        return fault;
    }

    private static WorkerFault CreateFault(
        WorkerFaultCategory category,
        string commandMethod,
        Exception exception)
    {
        WorkerFault fault = CreateFault(category, commandMethod, exception.Message);
        fault.ExceptionType = exception.GetType().FullName ?? string.Empty;
        fault.ProtocolStatus = new ProtocolStatus
        {
            Code = ProtocolStatusCode.WorkerUnavailable,
            Message = exception.Message,
        };
        return fault;
    }

    private static WorkerFault CreateFault(
        WorkerFaultCategory category,
        string commandMethod,
        string diagnosticMessage)
    {
        return new WorkerFault
        {
            Category = category,
            CommandMethod = commandMethod ?? string.Empty,
            DiagnosticMessage = diagnosticMessage,
            ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.WorkerUnavailable,
                Message = diagnosticMessage,
            },
        };
    }

    private static WorkerFaultCategory MapFaultCategory(WorkerFrameProtocolErrorCode errorCode)
    {
        return errorCode switch
        {
            WorkerFrameProtocolErrorCode.ProtocolVersionMismatch => WorkerFaultCategory.ProtocolMismatch,
            WorkerFrameProtocolErrorCode.EndOfStream => WorkerFaultCategory.PipeDisconnected,
            _ => WorkerFaultCategory.ProtocolViolation,
        };
    }
}