using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipes;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
using MxGateway.Worker.MxAccess;
using MxGateway.Worker.Sta;

namespace MxGateway.Worker.Tests.Ipc;

/// <summary>
/// Tests for <see cref="WorkerPipeSession"/>. The startup handshake is exercised over in-memory
/// streams; the run loop (heartbeats, command dispatch, STA watchdog, shutdown) is exercised over
/// a real named pipe whose server end plays the gateway.
/// </summary>
public sealed class WorkerPipeSessionTests
{
    private const string SessionId = "session-1";
    private const string Nonce = "nonce-secret";

    // Process id reported by the injected process-id provider; asserted back in Ready/Heartbeat frames.
    private const int WorkerProcessId = 1234;

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundStreamAsync(CreateGatewayHelloEnvelope(), options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        await session.CompleteStartupHandshakeAsync(
            _ =>
            {
                initialized = true;
                return Task.CompletedTask;
            });

        Assert.True(initialized);
        WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
        Assert.Equal(2, written.Length);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
        Assert.Equal(Nonce, written[0].WorkerHello.Nonce);
        Assert.Equal(WorkerProcessId, written[1].WorkerReady.WorkerProcessId);
        Assert.Equal(MxGateway.Worker.MxAccess.MxAccessInteropInfo.ProgId, written[1].WorkerReady.MxaccessProgid);
        Assert.Equal(MxGateway.Worker.MxAccess.MxAccessInteropInfo.Clsid, written[1].WorkerReady.MxaccessClsid);
        Assert.NotNull(written[1].WorkerReady.ReadyTimestamp);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundStreamAsync(CreateGatewayHelloEnvelope(nonce: "wrong"), options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        WorkerFrameProtocolException exception = await Assert.ThrowsAsync<WorkerFrameProtocolException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ =>
                {
                    initialized = true;
                    return Task.CompletedTask;
                }));

        Assert.False(initialized);
        Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode);
        WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
        Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundStreamAsync(
            CreateGatewayHelloEnvelope(supportedProtocolVersion: 999),
            options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        WorkerFrameProtocolException exception = await Assert.ThrowsAsync<WorkerFrameProtocolException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ =>
                {
                    initialized = true;
                    return Task.CompletedTask;
                }));

        Assert.False(initialized);
        Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
        WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
        Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault()
    {
        WorkerFrameProtocolOptions options = CreateOptions();

        // A lone 0x80 byte is an unterminated protobuf varint, so the payload cannot be parsed.
        MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 }));
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        WorkerFrameProtocolException exception = await Assert.ThrowsAsync<WorkerFrameProtocolException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ =>
                {
                    initialized = true;
                    return Task.CompletedTask;
                }));

        Assert.False(initialized);
        Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
        WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
        Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WhenMxAccessCreationFails_WritesFaultInsteadOfReady()
    {
        // 0x80040154 is REGDB_E_CLASSNOTREG ("Class not registered").
        const int hresult = unchecked((int)0x80040154);
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundStreamAsync(CreateGatewayHelloEnvelope(), options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);

        // NOTE(review): the propagated exception type is not visible from here (the original COMException
        // or a wrapper). The fault frame asserted below pins the captured exception type.
        await Assert.ThrowsAnyAsync<Exception>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ => Task.FromException(new COMException("Class not registered.", hresult))));

        WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
        Assert.Equal(2, written.Length);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, written[1].BodyCase);
        Assert.Equal(WorkerFaultCategory.MxaccessCreationFailed, written[1].WorkerFault.Category);
        Assert.Equal(hresult, written[1].WorkerFault.Hresult);
        Assert.Equal(typeof(COMException).FullName, written[1].WorkerFault.ExceptionType);
        Assert.Equal(ProtocolStatusCode.WorkerUnavailable, written[1].WorkerFault.ProtocolStatus.Code);
    }

    [Fact]
    public async Task RunAsync_SendsHeartbeatPayloadFromRuntimeSnapshot()
    {
        using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5));
        using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token);
        FakeRuntimeSession runtime = new();
        runtime.SetSnapshot(new WorkerRuntimeHeartbeatSnapshot(
            DateTimeOffset.UtcNow,
            pendingCommandCount: 2,
            outboundEventQueueDepth: 3,
            lastEventSequence: 42,
            currentCommandCorrelationId: "current-command"));
        WorkerPipeSession session = CreatePipeSession(
            pipePair.WorkerStream,
            runtime,
            new WorkerPipeSessionOptions
            {
                HeartbeatInterval = TimeSpan.FromMilliseconds(20),
                HeartbeatGrace = TimeSpan.FromSeconds(5),
            });

        Task runTask = session.RunAsync(cancellation.Token);
        await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token);
        await ThrowIfCompletedAsync(runTask);
        WorkerEnvelope heartbeat = await ReadUntilAsync(
            pipePair.GatewayReader,
            WorkerEnvelope.BodyOneofCase.WorkerHeartbeat,
            cancellation.Token);

        Assert.Equal(WorkerState.ExecutingCommand, heartbeat.WorkerHeartbeat.State);
        Assert.Equal(WorkerProcessId, heartbeat.WorkerHeartbeat.WorkerProcessId);
        Assert.Equal(2u, heartbeat.WorkerHeartbeat.PendingCommandCount);
        Assert.Equal(3u, heartbeat.WorkerHeartbeat.OutboundEventQueueDepth);
        Assert.Equal(42UL, heartbeat.WorkerHeartbeat.LastEventSequence);
        Assert.Equal("current-command", heartbeat.WorkerHeartbeat.CurrentCommandCorrelationId);

        await SendShutdownAndWaitAsync(pipePair, runTask, cancellation.Token);
    }

    [Fact]
    public async Task RunAsync_WhenCommandIsExecuting_HeartbeatReportsCurrentCorrelation()
    {
        using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5));
        using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token);
        FakeRuntimeSession runtime = new()
        {
            BlockDispatch = true,
        };
        WorkerPipeSession session = CreatePipeSession(
            pipePair.WorkerStream,
            runtime,
            new WorkerPipeSessionOptions
            {
                HeartbeatInterval = TimeSpan.FromMilliseconds(20),
                HeartbeatGrace = TimeSpan.FromSeconds(5),
            });

        Task runTask = session.RunAsync(cancellation.Token);
        await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token);
        await pipePair.GatewayWriter.WriteAsync(CreateCommandEnvelope("command-1"), cancellation.Token);
        Assert.True(runtime.DispatchStarted.Wait(TimeSpan.FromSeconds(2)));

        // While dispatch is blocked, heartbeats must report the in-flight command.
        WorkerEnvelope heartbeat = await ReadUntilAsync(
            pipePair.GatewayReader,
            WorkerEnvelope.BodyOneofCase.WorkerHeartbeat,
            envelope => envelope.WorkerHeartbeat.CurrentCommandCorrelationId == "command-1",
            cancellation.Token);
        Assert.Equal("command-1", heartbeat.WorkerHeartbeat.CurrentCommandCorrelationId);
        Assert.Equal(WorkerState.ExecutingCommand, heartbeat.WorkerHeartbeat.State);

        runtime.ReleaseDispatch();
        WorkerEnvelope reply = await ReadUntilAsync(
            pipePair.GatewayReader,
            WorkerEnvelope.BodyOneofCase.WorkerCommandReply,
            cancellation.Token);
        Assert.Equal("command-1", reply.CorrelationId);
        Assert.Equal(ProtocolStatusCode.Ok, reply.WorkerCommandReply.Reply.ProtocolStatus.Code);

        await SendShutdownAndWaitAsync(pipePair, runTask, cancellation.Token);
    }

    [Fact]
    public async Task RunAsync_WhenStaActivityIsStale_WritesWatchdogFault()
    {
        using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5));
        using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token);
        FakeRuntimeSession runtime = new();

        // The last-activity timestamp is 5 s old, well past the 50 ms grace configured below.
        runtime.SetSnapshot(new WorkerRuntimeHeartbeatSnapshot(
            DateTimeOffset.UtcNow - TimeSpan.FromSeconds(5),
            pendingCommandCount: 0,
            outboundEventQueueDepth: 0,
            lastEventSequence: 0,
            currentCommandCorrelationId: "stuck-command"));
        WorkerPipeSession session = CreatePipeSession(
            pipePair.WorkerStream,
            runtime,
            new WorkerPipeSessionOptions
            {
                HeartbeatInterval = TimeSpan.FromMilliseconds(20),
                HeartbeatGrace = TimeSpan.FromMilliseconds(50),
            });

        Task runTask = session.RunAsync(cancellation.Token);
        await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token);
        WorkerEnvelope fault = await ReadUntilAsync(
            pipePair.GatewayReader,
            WorkerEnvelope.BodyOneofCase.WorkerFault,
            cancellation.Token);

        Assert.Equal(WorkerFaultCategory.StaHung, fault.WorkerFault.Category);
        Assert.Equal("stuck-command", fault.WorkerFault.CommandMethod);
        Assert.Contains("STA activity is stale", fault.WorkerFault.DiagnosticMessage);

        await SendShutdownAndWaitAsync(pipePair, runTask, cancellation.Token);
    }

    /// <summary>Creates a handshake-only session that reads <paramref name="inbound"/> and writes <paramref name="outbound"/>.</summary>
    private static WorkerPipeSession CreateSession(
        Stream inbound,
        Stream outbound,
        WorkerFrameProtocolOptions options)
    {
        return new WorkerPipeSession(
            new WorkerFrameReader(inbound, options),
            new WorkerFrameWriter(outbound, options),
            options,
            () => WorkerProcessId);
    }

    /// <summary>
    /// Creates a session over a single duplex <paramref name="stream"/>. The session is given a
    /// factory that always returns <paramref name="runtime"/>.
    /// </summary>
    private static WorkerPipeSession CreatePipeSession(
        Stream stream,
        FakeRuntimeSession runtime,
        WorkerPipeSessionOptions sessionOptions)
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        return new WorkerPipeSession(
            new WorkerFrameReader(stream, options),
            new WorkerFrameWriter(stream, options),
            options,
            () => WorkerProcessId,
            sessionOptions,
            () => runtime);
    }

    /// <summary>Frame options shared by the worker and the simulated gateway.</summary>
    private static WorkerFrameProtocolOptions CreateOptions()
    {
        return new WorkerFrameProtocolOptions(
            SessionId,
            GatewayContractInfo.WorkerProtocolVersion,
            Nonce);
    }

    /// <summary>Writes <paramref name="envelope"/> as one frame into a new stream rewound to position 0.</summary>
    private static async Task<MemoryStream> CreateInboundStreamAsync(
        WorkerEnvelope envelope,
        WorkerFrameProtocolOptions options)
    {
        MemoryStream inbound = new();
        await new WorkerFrameWriter(inbound, options).WriteAsync(envelope);
        inbound.Position = 0;
        return inbound;
    }

    /// <summary>Builds the gateway's opening hello; parameters allow injecting a bad nonce or protocol version.</summary>
    private static WorkerEnvelope CreateGatewayHelloEnvelope(
        string nonce = Nonce,
        uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion)
    {
        return new WorkerEnvelope
        {
            ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
            SessionId = SessionId,
            Sequence = 1,
            GatewayHello = new GatewayHello
            {
                SupportedProtocolVersion = supportedProtocolVersion,
                Nonce = nonce,
                GatewayVersion = "test-gateway",
            },
        };
    }

    /// <summary>Builds a Ping command envelope carrying <paramref name="correlationId"/>.</summary>
    private static WorkerEnvelope CreateCommandEnvelope(string correlationId)
    {
        return new WorkerEnvelope
        {
            ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
            SessionId = SessionId,
            Sequence = 2,
            CorrelationId = correlationId,
            WorkerCommand = new WorkerCommand
            {
                Command = new MxCommand
                {
                    Kind = MxCommandKind.Ping,
                    Ping = new PingCommand
                    {
                        Message = "ping",
                    },
                },
                EnqueueTimestamp = Timestamp.FromDateTimeOffset(DateTimeOffset.UtcNow),
            },
        };
    }

    /// <summary>Builds a shutdown request with a one-second grace period.</summary>
    private static WorkerEnvelope CreateShutdownEnvelope()
    {
        return new WorkerEnvelope
        {
            ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
            SessionId = SessionId,
            Sequence = 3,
            WorkerShutdown = new WorkerShutdown
            {
                GracePeriod = Duration.FromTimeSpan(TimeSpan.FromSeconds(1)),
                Reason = "test-complete",
            },
        };
    }

    /// <summary>
    /// Plays the gateway side of the handshake: sends GatewayHello, then expects WorkerHello followed
    /// by WorkerReady as the next two frames.
    /// </summary>
    private static async Task CompleteGatewayHandshakeAsync(
        PipePair pipePair,
        CancellationToken cancellationToken)
    {
        await pipePair.GatewayWriter
            .WriteAsync(CreateGatewayHelloEnvelope(), cancellationToken)
            .ConfigureAwait(false);

        WorkerEnvelope hello = await pipePair.GatewayReader.ReadAsync(cancellationToken).ConfigureAwait(false);
        WorkerEnvelope ready = await pipePair.GatewayReader.ReadAsync(cancellationToken).ConfigureAwait(false);

        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, hello.BodyCase);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, ready.BodyCase);
    }

    /// <summary>
    /// Sends a shutdown request, waits for an Ok acknowledgement, then requires that
    /// <paramref name="runTask"/> finishes within two seconds. Its completion is awaited, so any
    /// exception from the run loop fails the test.
    /// </summary>
    private static async Task SendShutdownAndWaitAsync(
        PipePair pipePair,
        Task runTask,
        CancellationToken cancellationToken)
    {
        await pipePair.GatewayWriter
            .WriteAsync(CreateShutdownEnvelope(), cancellationToken)
            .ConfigureAwait(false);

        WorkerEnvelope shutdownAck = await ReadUntilAsync(
            pipePair.GatewayReader,
            WorkerEnvelope.BodyOneofCase.WorkerShutdownAck,
            cancellationToken);
        Assert.Equal(ProtocolStatusCode.Ok, shutdownAck.WorkerShutdownAck.Status.Code);

        // Linked so the watchdog delay is torn down as soon as the run task wins the race.
        using CancellationTokenSource timeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        Task completedTask = await Task
            .WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), timeout.Token))
            .ConfigureAwait(false);
        timeout.Cancel();

        Assert.Same(runTask, completedTask);
        await runTask.ConfigureAwait(false);
    }

    /// <summary>
    /// Gives the run loop ~100 ms to fail fast. If it has already completed, awaiting it rethrows
    /// its exception, so the test reports the real cause instead of a later read timeout.
    /// </summary>
    private static async Task ThrowIfCompletedAsync(Task task)
    {
        await Task.Delay(TimeSpan.FromMilliseconds(100));
        if (task.IsCompleted)
        {
            await task;
        }
    }

    /// <summary>Reads frames until one with body <paramref name="expectedBody"/> arrives, skipping everything else.</summary>
    private static Task<WorkerEnvelope> ReadUntilAsync(
        WorkerFrameReader reader,
        WorkerEnvelope.BodyOneofCase expectedBody,
        CancellationToken cancellationToken)
    {
        return ReadUntilAsync(reader, expectedBody, _ => true, cancellationToken);
    }

    /// <summary>
    /// Reads frames until one matches both <paramref name="expectedBody"/> and <paramref name="predicate"/>.
    /// The loop has no bound of its own; it relies on <paramref name="cancellationToken"/> to stop waiting.
    /// </summary>
    private static async Task<WorkerEnvelope> ReadUntilAsync(
        WorkerFrameReader reader,
        WorkerEnvelope.BodyOneofCase expectedBody,
        Func<WorkerEnvelope, bool> predicate,
        CancellationToken cancellationToken)
    {
        while (true)
        {
            WorkerEnvelope envelope = await reader.ReadAsync(cancellationToken).ConfigureAwait(false);
            if (envelope.BodyCase == expectedBody && predicate(envelope))
            {
                return envelope;
            }
        }
    }

    /// <summary>Rewinds <paramref name="stream"/> and decodes every frame the session wrote into it.</summary>
    private static WorkerEnvelope[] ReadWrittenFrames(
        MemoryStream stream,
        WorkerFrameProtocolOptions options)
    {
        stream.Position = 0;
        WorkerFrameReader reader = new(stream, options);
        List<WorkerEnvelope> envelopes = new();
        while (stream.Position < stream.Length)
        {
            // Blocking is acceptable here because MemoryStream reads complete synchronously.
            // Assumes WorkerFrameReader adds no real asynchrony of its own.
            envelopes.Add(reader.ReadAsync(CancellationToken.None).GetAwaiter().GetResult());
        }

        return envelopes.ToArray();
    }

    /// <summary>
    /// Wraps <paramref name="payload"/> in a raw frame consisting of a 4-byte little-endian length
    /// followed by the payload. This bypasses the writer so malformed payloads can be injected.
    /// NOTE(review): mirrors the framing WorkerFrameReader is assumed to expect.
    /// </summary>
    private static byte[] CreateFrame(byte[] payload)
    {
        byte[] frame = new byte[sizeof(uint) + payload.Length];
        WriteUInt32LittleEndian(frame, (uint)payload.Length);
        payload.CopyTo(frame, sizeof(uint));
        return frame;
    }

    /// <summary>Writes <paramref name="value"/> little-endian into the first four bytes of <paramref name="buffer"/>.</summary>
    private static void WriteUInt32LittleEndian(
        byte[] buffer,
        uint value)
    {
        buffer[0] = (byte)value;
        buffer[1] = (byte)(value >> 8);
        buffer[2] = (byte)(value >> 16);
        buffer[3] = (byte)(value >> 24);
    }

    /// <summary>
    /// Scriptable runtime. The heartbeat snapshot can be set directly, and dispatch can be held open
    /// until <see cref="ReleaseDispatch"/> is called (bounded by a 5 s safety timeout).
    /// </summary>
    private sealed class FakeRuntimeSession : IWorkerRuntimeSession
    {
        private readonly ManualResetEventSlim releaseDispatch = new(false);
        private readonly object gate = new();
        private WorkerRuntimeHeartbeatSnapshot snapshot = new(
            DateTimeOffset.UtcNow,
            pendingCommandCount: 0,
            outboundEventQueueDepth: 0,
            lastEventSequence: 0,
            currentCommandCorrelationId: string.Empty);
        private int disposed;

        /// <summary>Signalled once a dispatched command has published its correlation id.</summary>
        public ManualResetEventSlim DispatchStarted { get; } = new(false);

        /// <summary>When true, dispatch waits for <see cref="ReleaseDispatch"/> before replying.</summary>
        public bool BlockDispatch { get; set; }

        public Task<WorkerReady> StartAsync(
            string sessionId,
            int workerProcessId,
            CancellationToken cancellationToken = default)
        {
            return Task.FromResult(new WorkerReady
            {
                WorkerProcessId = workerProcessId,
                MxaccessProgid = MxGateway.Worker.MxAccess.MxAccessInteropInfo.ProgId,
                MxaccessClsid = MxGateway.Worker.MxAccess.MxAccessInteropInfo.Clsid,
                ReadyTimestamp = Timestamp.FromDateTimeOffset(DateTimeOffset.UtcNow),
            });
        }

        public Task<MxCommandReply> DispatchAsync(StaCommand command)
        {
            return Task.Run(
                () =>
                {
                    // Publish the in-flight command so heartbeats can report it.
                    SetSnapshot(new WorkerRuntimeHeartbeatSnapshot(
                        DateTimeOffset.UtcNow,
                        pendingCommandCount: 0,
                        outboundEventQueueDepth: 0,
                        lastEventSequence: 0,
                        currentCommandCorrelationId: command.CorrelationId));
                    DispatchStarted.Set();

                    if (BlockDispatch)
                    {
                        releaseDispatch.Wait(TimeSpan.FromSeconds(5));
                    }

                    SetSnapshot(new WorkerRuntimeHeartbeatSnapshot(
                        DateTimeOffset.UtcNow,
                        pendingCommandCount: 0,
                        outboundEventQueueDepth: 0,
                        lastEventSequence: 0,
                        currentCommandCorrelationId: string.Empty));

                    return new MxCommandReply
                    {
                        SessionId = command.SessionId,
                        CorrelationId = command.CorrelationId,
                        Kind = command.Kind,
                        ProtocolStatus = new ProtocolStatus
                        {
                            Code = ProtocolStatusCode.Ok,
                            Message = "OK",
                        },
                    };
                });
        }

        public WorkerRuntimeHeartbeatSnapshot CaptureHeartbeat()
        {
            lock (gate)
            {
                return snapshot;
            }
        }

        public void RequestShutdown()
        {
            // Unblock any held dispatch so shutdown is not stuck behind it.
            releaseDispatch.Set();
        }

        public void ReleaseDispatch()
        {
            releaseDispatch.Set();
        }

        public void SetSnapshot(WorkerRuntimeHeartbeatSnapshot value)
        {
            lock (gate)
            {
                snapshot = value;
            }
        }

        public void Dispose()
        {
            // Idempotent, in case both the session under test and a test dispose this instance.
            if (Interlocked.Exchange(ref disposed, 1) != 0)
            {
                return;
            }

            releaseDispatch.Set();
            releaseDispatch.Dispose();
            DispatchStarted.Dispose();
        }
    }

    /// <summary>
    /// A connected local named-pipe pair. The server end acts as the gateway (exposed through
    /// <see cref="GatewayReader"/> / <see cref="GatewayWriter"/>); the client end is handed to the worker session.
    /// </summary>
    private sealed class PipePair : IDisposable
    {
        private readonly NamedPipeServerStream gatewayStream;

        private PipePair(
            NamedPipeServerStream gatewayStream,
            NamedPipeClientStream workerStream)
        {
            this.gatewayStream = gatewayStream;
            WorkerStream = workerStream;
            WorkerFrameProtocolOptions options = CreateOptions();
            GatewayReader = new WorkerFrameReader(gatewayStream, options);
            GatewayWriter = new WorkerFrameWriter(gatewayStream, options);
        }

        public Stream WorkerStream { get; }

        public WorkerFrameReader GatewayReader { get; }

        public WorkerFrameWriter GatewayWriter { get; }

        /// <summary>
        /// Creates a uniquely named pipe and connects both ends. The client connect is bounded to 5 s.
        /// Both streams are disposed if connecting fails or is cancelled.
        /// </summary>
        public static async Task<PipePair> CreateAsync(CancellationToken cancellationToken)
        {
            string pipeName = $"mxaccessgw-worker-session-tests-{Guid.NewGuid():N}";
            NamedPipeServerStream gatewayStream = new(
                pipeName,
                PipeDirection.InOut,
                maxNumberOfServerInstances: 1,
                PipeTransmissionMode.Byte,
                PipeOptions.Asynchronous);
            NamedPipeClientStream workerStream = new(
                ".",
                pipeName,
                PipeDirection.InOut,
                PipeOptions.Asynchronous);

            try
            {
                Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync(cancellationToken);
                await Task
                    .Run(() => workerStream.Connect(5000), cancellationToken)
                    .ConfigureAwait(false);
                await waitForConnectionTask.ConfigureAwait(false);
                return new PipePair(gatewayStream, workerStream);
            }
            catch
            {
                // Do not leak pipe handles when connecting fails.
                workerStream.Dispose();
                gatewayStream.Dispose();
                throw;
            }
        }

        public void Dispose()
        {
            WorkerStream.Dispose();
            gatewayStream.Dispose();
        }
    }
}