using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;

namespace MxGateway.Worker.Tests.Ipc;

/// <summary>
/// Tests for <see cref="WorkerPipeSession.CompleteStartupHandshakeAsync"/>.
/// Each test feeds a pre-encoded inbound stream to the session and inspects the
/// frames the session wrote to its outbound stream.
/// </summary>
public sealed class WorkerPipeSessionTests
{
    private const string SessionId = "session-1";
    private const string Nonce = "nonce-secret";

    // Process id reported by the session's process-id provider in CreateSession;
    // shared so the WorkerReady assertion cannot drift from the injected value.
    private const int WorkerProcessId = 1234;

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundAsync(CreateGatewayHelloEnvelope(), options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        await session.CompleteStartupHandshakeAsync(
            _ =>
            {
                initialized = true;
                return Task.CompletedTask;
            });

        Assert.True(initialized);
        WorkerEnvelope[] written = await ReadWrittenFramesAsync(outbound, options);
        Assert.Equal(2, written.Length);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
        Assert.Equal(Nonce, written[0].WorkerHello.Nonce);
        Assert.Equal(WorkerProcessId, written[1].WorkerReady.WorkerProcessId);
        Assert.Equal(MxGateway.Worker.MxAccess.MxAccessInteropInfo.ProgId, written[1].WorkerReady.MxaccessProgid);
        Assert.Equal(MxGateway.Worker.MxAccess.MxAccessInteropInfo.Clsid, written[1].WorkerReady.MxaccessClsid);
        Assert.NotNull(written[1].WorkerReady.ReadyTimestamp);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundAsync(CreateGatewayHelloEnvelope(nonce: "wrong"), options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        WorkerFrameProtocolException exception = await Assert.ThrowsAsync<WorkerFrameProtocolException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ =>
                {
                    initialized = true;
                    return Task.CompletedTask;
                }));

        Assert.False(initialized);
        Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode);
        WorkerEnvelope fault = Assert.Single(await ReadWrittenFramesAsync(outbound, options));
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
        Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundAsync(
            CreateGatewayHelloEnvelope(supportedProtocolVersion: 999),
            options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        WorkerFrameProtocolException exception = await Assert.ThrowsAsync<WorkerFrameProtocolException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ =>
                {
                    initialized = true;
                    return Task.CompletedTask;
                }));

        Assert.False(initialized);
        Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
        WorkerEnvelope fault = Assert.Single(await ReadWrittenFramesAsync(outbound, options));
        // Check the oneof case first so a wrong frame type fails with a clear
        // assertion rather than a NullReferenceException on fault.WorkerFault.
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
        Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault()
    {
        WorkerFrameProtocolOptions options = CreateOptions();
        // A single 0x80 byte has the protobuf varint continuation bit set with
        // nothing following it, so the payload cannot be parsed as an envelope.
        MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 }));
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);
        bool initialized = false;

        WorkerFrameProtocolException exception = await Assert.ThrowsAsync<WorkerFrameProtocolException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ =>
                {
                    initialized = true;
                    return Task.CompletedTask;
                }));

        Assert.False(initialized);
        Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
        WorkerEnvelope fault = Assert.Single(await ReadWrittenFramesAsync(outbound, options));
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
        Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
    }

    [Fact]
    public async Task CompleteStartupHandshakeAsync_WhenMxAccessCreationFails_WritesFaultInsteadOfReady()
    {
        // 0x80040154 is REGDB_E_CLASSNOTREG ("Class not registered").
        const int hresult = unchecked((int)0x80040154);
        WorkerFrameProtocolOptions options = CreateOptions();
        MemoryStream inbound = await CreateInboundAsync(CreateGatewayHelloEnvelope(), options);
        MemoryStream outbound = new();
        WorkerPipeSession session = CreateSession(inbound, outbound, options);

        await Assert.ThrowsAsync<COMException>(
            async () => await session.CompleteStartupHandshakeAsync(
                _ => Task.FromException(new COMException("Class not registered.", hresult))));

        WorkerEnvelope[] written = await ReadWrittenFramesAsync(outbound, options);
        Assert.Equal(2, written.Length);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
        Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, written[1].BodyCase);
        Assert.Equal(WorkerFaultCategory.MxaccessCreationFailed, written[1].WorkerFault.Category);
        Assert.Equal(hresult, written[1].WorkerFault.Hresult);
        Assert.Equal(typeof(COMException).FullName, written[1].WorkerFault.ExceptionType);
        Assert.Equal(ProtocolStatusCode.WorkerUnavailable, written[1].WorkerFault.ProtocolStatus.Code);
    }

    /// <summary>
    /// Builds a session that reads from <paramref name="inbound"/>, writes to
    /// <paramref name="outbound"/>, and reports <see cref="WorkerProcessId"/> as its process id.
    /// </summary>
    private static WorkerPipeSession CreateSession(
        Stream inbound,
        Stream outbound,
        WorkerFrameProtocolOptions options)
    {
        return new WorkerPipeSession(
            new WorkerFrameReader(inbound, options),
            new WorkerFrameWriter(outbound, options),
            options,
            () => WorkerProcessId);
    }

    /// <summary>Protocol options for the test session id, current protocol version and nonce.</summary>
    private static WorkerFrameProtocolOptions CreateOptions()
    {
        return new WorkerFrameProtocolOptions(
            SessionId,
            GatewayContractInfo.WorkerProtocolVersion,
            Nonce);
    }

    /// <summary>
    /// Encodes <paramref name="envelope"/> as a single frame into a new stream and
    /// rewinds it so the session under test reads it from the start.
    /// </summary>
    private static async Task<MemoryStream> CreateInboundAsync(
        WorkerEnvelope envelope,
        WorkerFrameProtocolOptions options)
    {
        MemoryStream inbound = new();
        await new WorkerFrameWriter(inbound, options).WriteAsync(envelope);
        inbound.Position = 0;
        return inbound;
    }

    /// <summary>
    /// Creates the gateway hello envelope the worker expects first. The nonce and
    /// supported protocol version can be overridden to drive the failure paths.
    /// </summary>
    private static WorkerEnvelope CreateGatewayHelloEnvelope(
        string nonce = Nonce,
        uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion)
    {
        return new WorkerEnvelope
        {
            ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
            SessionId = SessionId,
            Sequence = 1,
            GatewayHello = new GatewayHello
            {
                SupportedProtocolVersion = supportedProtocolVersion,
                Nonce = nonce,
                GatewayVersion = "test-gateway",
            },
        };
    }

    /// <summary>
    /// Rewinds <paramref name="stream"/> and decodes every frame in it, in order.
    /// </summary>
    /// <remarks>
    /// Relies on the reader consuming exactly one frame per call, so that the
    /// stream position marks the start of the next frame.
    /// </remarks>
    private static async Task<WorkerEnvelope[]> ReadWrittenFramesAsync(
        MemoryStream stream,
        WorkerFrameProtocolOptions options)
    {
        stream.Position = 0;
        WorkerFrameReader reader = new(stream, options);
        List<WorkerEnvelope> envelopes = new();
        while (stream.Position < stream.Length)
        {
            envelopes.Add(await reader.ReadAsync(CancellationToken.None));
        }

        return envelopes.ToArray();
    }

    /// <summary>
    /// Builds a raw frame: a 4-byte little-endian payload length followed by the payload.
    /// Assumed to match the layout WorkerFrameReader expects, so arbitrary
    /// (e.g. malformed) payloads can be injected.
    /// </summary>
    private static byte[] CreateFrame(byte[] payload)
    {
        byte[] frame = new byte[sizeof(uint) + payload.Length];
        WriteUInt32LittleEndian(frame, (uint)payload.Length);
        payload.CopyTo(frame, sizeof(uint));
        return frame;
    }

    /// <summary>Writes <paramref name="value"/> into buffer[0..3], least-significant byte first.</summary>
    private static void WriteUInt32LittleEndian(
        byte[] buffer,
        uint value)
    {
        buffer[0] = (byte)value;
        buffer[1] = (byte)(value >> 8);
        buffer[2] = (byte)(value >> 16);
        buffer[3] = (byte)(value >> 24);
    }
}