using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;

namespace MxGateway.Worker.Ipc;

/// <summary>
/// Worker side of the gateway/worker frame session. Performs the startup handshake
/// (read <see cref="GatewayHello"/>, write <see cref="WorkerHello"/>, initialize MXAccess,
/// write <see cref="WorkerReady"/>) and reports startup failures to the gateway as
/// <see cref="WorkerFault"/> envelopes.
/// </summary>
public sealed class WorkerPipeSession
{
    private readonly WorkerFrameProtocolOptions _options;
    private readonly Func<int> _processIdProvider;
    private readonly WorkerFrameReader _reader;
    private readonly WorkerFrameWriter _writer;

    // Session created by the default MXAccess initializer; it is only cleared if startup fails.
    // NOTE(review): nothing in this class disposes it after a successful handshake, presumably
    // because it must live for the rest of the worker process; confirm that ownership.
    private MxAccessStaSession? _mxAccessStaSession;

    // Last sequence number issued. It is incremented before use, so the first envelope gets 1.
    private long _nextSequence;

    /// <summary>
    /// Creates a session that reads and writes frames on <paramref name="stream"/> and
    /// reports the id of the current process.
    /// </summary>
    /// <param name="stream">Duplex stream shared by the frame reader and writer.</param>
    /// <param name="options">Protocol version, session id and launch nonce for this worker.</param>
    public WorkerPipeSession(
        Stream stream,
        WorkerFrameProtocolOptions options)
        : this(
            new WorkerFrameReader(stream, options),
            new WorkerFrameWriter(stream, options),
            options,
            GetCurrentProcessId)
    {
    }

    /// <summary>
    /// Creates a session over an explicit frame reader/writer pair.
    /// </summary>
    /// <param name="reader">Source of incoming <see cref="WorkerEnvelope"/> frames.</param>
    /// <param name="writer">Sink for outgoing <see cref="WorkerEnvelope"/> frames.</param>
    /// <param name="options">Protocol version, session id and launch nonce for this worker.</param>
    /// <param name="processIdProvider">Supplies the process id reported to the gateway.</param>
    /// <exception cref="ArgumentNullException">Any argument is <see langword="null"/>.</exception>
    public WorkerPipeSession(
        WorkerFrameReader reader,
        WorkerFrameWriter writer,
        WorkerFrameProtocolOptions options,
        Func<int> processIdProvider)
    {
        _reader = reader ?? throw new ArgumentNullException(nameof(reader));
        _writer = writer ?? throw new ArgumentNullException(nameof(writer));
        _options = options ?? throw new ArgumentNullException(nameof(options));
        _processIdProvider = processIdProvider
            ?? throw new ArgumentNullException(nameof(processIdProvider));
    }

    /// <summary>
    /// Runs the startup handshake using the built-in MXAccess STA session as the initializer.
    /// </summary>
    /// <param name="cancellationToken">Cancels reading, writing and initialization.</param>
    public Task CompleteStartupHandshakeAsync(CancellationToken cancellationToken = default)
    {
        return CompleteStartupHandshakeAsync(InitializeMxAccessAsync, cancellationToken);
    }

    /// <summary>
    /// Runs the startup handshake with a caller-supplied initializer that does not produce its
    /// own <see cref="WorkerReady"/>. A default one is built after the initializer completes.
    /// </summary>
    /// <param name="initializeMxAccessAsync">Initializes MXAccess for this worker.</param>
    /// <param name="cancellationToken">Cancels reading, writing and initialization.</param>
    /// <exception cref="ArgumentNullException"><paramref name="initializeMxAccessAsync"/> is null.</exception>
    public async Task CompleteStartupHandshakeAsync(
        Func<CancellationToken, Task> initializeMxAccessAsync,
        CancellationToken cancellationToken = default)
    {
        if (initializeMxAccessAsync is null)
        {
            throw new ArgumentNullException(nameof(initializeMxAccessAsync));
        }

        await CompleteStartupHandshakeAsync(
                async innerCancellationToken =>
                {
                    await initializeMxAccessAsync(innerCancellationToken).ConfigureAwait(false);
                    return CreateWorkerReady();
                },
                cancellationToken)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Runs the startup handshake: reads and validates <see cref="GatewayHello"/>, writes
    /// <see cref="WorkerHello"/>, runs <paramref name="initializeMxAccessAsync"/>, then writes the
    /// <see cref="WorkerReady"/> it returns.
    /// </summary>
    /// <remarks>
    /// Protocol violations are reported to the gateway as a protocol fault. Failures thrown by the
    /// initializer (other than cancellation) are reported as an MXAccess creation fault. In both
    /// cases the original exception is rethrown. Transport failures while reading or writing the
    /// handshake frames are rethrown without writing a fault.
    /// </remarks>
    /// <param name="initializeMxAccessAsync">Initializes MXAccess and builds the ready message.</param>
    /// <param name="cancellationToken">Cancels reading, writing and initialization.</param>
    /// <exception cref="ArgumentNullException"><paramref name="initializeMxAccessAsync"/> is null.</exception>
    /// <exception cref="WorkerFrameProtocolException">The gateway's hello is invalid, or the reader/writer reports a protocol error.</exception>
    public async Task CompleteStartupHandshakeAsync(
        Func<CancellationToken, Task<WorkerReady>> initializeMxAccessAsync,
        CancellationToken cancellationToken = default)
    {
        if (initializeMxAccessAsync is null)
        {
            throw new ArgumentNullException(nameof(initializeMxAccessAsync));
        }

        try
        {
            WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
            ValidateGatewayHello(envelope);
            await WriteWorkerHelloAsync(cancellationToken).ConfigureAwait(false);

            WorkerReady ready = await RunMxAccessInitializationAsync(
                    initializeMxAccessAsync,
                    cancellationToken)
                .ConfigureAwait(false);
            await WriteWorkerReadyAsync(ready, cancellationToken).ConfigureAwait(false);
        }
        catch (WorkerFrameProtocolException exception)
        {
            await TryWriteFaultAsync(CreateFault(exception), cancellationToken).ConfigureAwait(false);
            throw;
        }
    }

    // Runs the MXAccess initializer and reports its failures as MxaccessCreationFailed faults.
    // The catch is scoped to this call only so that pipe I/O errors elsewhere in the handshake
    // are not misreported as MXAccess creation failures. Protocol exceptions are left for the
    // caller's protocol-fault handler.
    private async Task<WorkerReady> RunMxAccessInitializationAsync(
        Func<CancellationToken, Task<WorkerReady>> initializeMxAccessAsync,
        CancellationToken cancellationToken)
    {
        try
        {
            return await initializeMxAccessAsync(cancellationToken).ConfigureAwait(false);
        }
        catch (Exception exception) when (
            exception is not OperationCanceledException and not WorkerFrameProtocolException)
        {
            await TryWriteFaultAsync(
                    CreateFault(MxAccessCreationException.From(exception)),
                    cancellationToken)
                .ConfigureAwait(false);
            throw;
        }
    }

    // Throws WorkerFrameProtocolException unless the envelope is a GatewayHello whose protocol
    // version and nonce match this worker's launch options.
    private void ValidateGatewayHello(WorkerEnvelope envelope)
    {
        if (envelope.BodyCase != WorkerEnvelope.BodyOneofCase.GatewayHello)
        {
            throw new WorkerFrameProtocolException(
                WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody,
                "Worker expected GatewayHello during startup handshake.");
        }

        GatewayHello gatewayHello = envelope.GatewayHello;
        if (gatewayHello.SupportedProtocolVersion != _options.ProtocolVersion)
        {
            throw new WorkerFrameProtocolException(
                WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
                $"GatewayHello supported protocol version {gatewayHello.SupportedProtocolVersion} does not match expected version {_options.ProtocolVersion}.");
        }

        // Ordinal comparison: the nonce is an opaque token, not linguistic text.
        if (!string.Equals(gatewayHello.Nonce, _options.Nonce, StringComparison.Ordinal))
        {
            throw new WorkerFrameProtocolException(
                WorkerFrameProtocolErrorCode.NonceMismatch,
                "GatewayHello nonce does not match the worker launch nonce.");
        }
    }

    private Task WriteWorkerHelloAsync(CancellationToken cancellationToken)
    {
        WorkerHello hello = new()
        {
            ProtocolVersion = _options.ProtocolVersion,
            Nonce = _options.Nonce,
            WorkerProcessId = _processIdProvider(),
            WorkerVersion = typeof(WorkerPipeSession).Assembly.GetName().Version?.ToString() ?? string.Empty,
        };

        return _writer.WriteAsync(CreateEnvelope(hello), cancellationToken);
    }

    private Task WriteWorkerReadyAsync(
        WorkerReady ready,
        CancellationToken cancellationToken)
    {
        return _writer.WriteAsync(CreateEnvelope(ready), cancellationToken);
    }

    // Best-effort fault write. The failure that triggered the fault is the actionable error, so a
    // broken, closed or cancelled pipe while reporting it must not replace that error.
    private async Task TryWriteFaultAsync(
        WorkerFault fault,
        CancellationToken cancellationToken)
    {
        try
        {
            await _writer
                .WriteAsync(CreateEnvelope(fault), cancellationToken)
                .ConfigureAwait(false);
        }
        catch (Exception faultWriteException) when (
            faultWriteException is IOException
                or ObjectDisposedException
                or OperationCanceledException
                or WorkerFrameProtocolException)
        {
            // Intentionally ignored; see method comment.
        }
    }

    private WorkerEnvelope CreateEnvelope(WorkerHello hello)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerHello = hello;
        return envelope;
    }

    private WorkerEnvelope CreateEnvelope(WorkerReady ready)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerReady = ready;
        return envelope;
    }

    private WorkerEnvelope CreateEnvelope(WorkerFault fault)
    {
        WorkerEnvelope envelope = CreateBaseEnvelope();
        envelope.WorkerFault = fault;
        return envelope;
    }

    // Header shared by every outgoing envelope. Each call consumes a new sequence number.
    private WorkerEnvelope CreateBaseEnvelope()
    {
        return new WorkerEnvelope
        {
            ProtocolVersion = _options.ProtocolVersion,
            SessionId = _options.SessionId,
            Sequence = NextSequence(),
        };
    }

    private ulong NextSequence()
    {
        return unchecked((ulong)Interlocked.Increment(ref _nextSequence));
    }

    // Default initializer: starts an MXAccess STA session and keeps it only if startup succeeds.
    private async Task<WorkerReady> InitializeMxAccessAsync(CancellationToken cancellationToken)
    {
        MxAccessStaSession session = new MxAccessStaSession();
        try
        {
            WorkerReady ready = await session
                .StartAsync(_processIdProvider(), cancellationToken)
                .ConfigureAwait(false);
            _mxAccessStaSession = session;
            return ready;
        }
        catch
        {
            session.Dispose();
            _mxAccessStaSession = null;
            throw;
        }
    }

    private WorkerReady CreateWorkerReady()
    {
        return new WorkerReady
        {
            WorkerProcessId = _processIdProvider(),
            MxaccessProgid = MxAccessInteropInfo.ProgId,
            MxaccessClsid = MxAccessInteropInfo.Clsid,
            ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
        };
    }

    private static WorkerFault CreateFault(WorkerFrameProtocolException exception)
    {
        return new WorkerFault
        {
            Category = MapFaultCategory(exception.ErrorCode),
            ExceptionType = exception.GetType().FullName ?? string.Empty,
            DiagnosticMessage = exception.Message,
            ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.ProtocolViolation,
                Message = exception.Message,
            },
        };
    }

    private static WorkerFault CreateFault(MxAccessCreationException exception)
    {
        WorkerFault fault = new()
        {
            Category = WorkerFaultCategory.MxaccessCreationFailed,
            // Prefer the wrapped (root-cause) exception type when there is one.
            ExceptionType = exception.InnerException?.GetType().FullName
                ?? exception.GetType().FullName
                ?? string.Empty,
            DiagnosticMessage = exception.Message,
            ProtocolStatus = new ProtocolStatus
            {
                Code = ProtocolStatusCode.WorkerUnavailable,
                Message = exception.Message,
            },
        };

        int? hresult = MxAccessCreationException.ExtractHResult(exception);
        if (hresult.HasValue)
        {
            fault.Hresult = hresult.Value;
        }

        return fault;
    }

    private static WorkerFaultCategory MapFaultCategory(WorkerFrameProtocolErrorCode errorCode)
    {
        return errorCode switch
        {
            WorkerFrameProtocolErrorCode.ProtocolVersionMismatch => WorkerFaultCategory.ProtocolMismatch,
            WorkerFrameProtocolErrorCode.EndOfStream => WorkerFaultCategory.PipeDisconnected,
            _ => WorkerFaultCategory.ProtocolViolation,
        };
    }

    // Process is IDisposable; dispose it instead of leaking one instance per call.
    private static int GetCurrentProcessId()
    {
        using Process currentProcess = Process.GetCurrentProcess();
        return currentProcess.Id;
    }
}