From d5a982152b72dfddaf7810c372810a21ac849393 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 17:16:49 -0400 Subject: [PATCH] Issue #22: implement pipe client and frame protocol --- .../Bootstrap/WorkerApplicationTests.cs | 55 ++++- .../Ipc/WorkerFrameProtocolTests.cs | 163 +++++++++++++ .../Ipc/WorkerPipeClientTests.cs | 61 +++++ .../Ipc/WorkerPipeSessionTests.cs | 192 +++++++++++++++ .../Bootstrap/WorkerExitCode.cs | 2 + src/MxGateway.Worker/Ipc/IWorkerPipeClient.cs | 12 + .../Ipc/WorkerEnvelopeValidator.cs | 33 +++ .../Ipc/WorkerFrameProtocolErrorCode.cs | 15 ++ .../Ipc/WorkerFrameProtocolException.cs | 25 ++ .../Ipc/WorkerFrameProtocolOptions.cs | 86 +++++++ src/MxGateway.Worker/Ipc/WorkerFrameReader.cs | 93 ++++++++ src/MxGateway.Worker/Ipc/WorkerFrameWriter.cs | 76 ++++++ src/MxGateway.Worker/Ipc/WorkerPipeClient.cs | 67 ++++++ src/MxGateway.Worker/Ipc/WorkerPipeSession.cs | 218 ++++++++++++++++++ src/MxGateway.Worker/WorkerApplication.cs | 53 ++++- 15 files changed, 1148 insertions(+), 3 deletions(-) create mode 100644 src/MxGateway.Worker.Tests/Ipc/WorkerFrameProtocolTests.cs create mode 100644 src/MxGateway.Worker.Tests/Ipc/WorkerPipeClientTests.cs create mode 100644 src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs create mode 100644 src/MxGateway.Worker/Ipc/IWorkerPipeClient.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerEnvelopeValidator.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerFrameProtocolErrorCode.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerFrameProtocolException.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerFrameProtocolOptions.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerFrameReader.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerFrameWriter.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerPipeClient.cs create mode 100644 src/MxGateway.Worker/Ipc/WorkerPipeSession.cs diff --git a/src/MxGateway.Worker.Tests/Bootstrap/WorkerApplicationTests.cs b/src/MxGateway.Worker.Tests/Bootstrap/WorkerApplicationTests.cs index f16aeba..a142266 100644 --- a/src/MxGateway.Worker.Tests/Bootstrap/WorkerApplicationTests.cs +++ b/src/MxGateway.Worker.Tests/Bootstrap/WorkerApplicationTests.cs @@ -1,6 +1,9 @@ using System; +using System.Threading; +using System.Threading.Tasks; using MxGateway.Contracts; using MxGateway.Worker.Bootstrap; +using MxGateway.Worker.Ipc; namespace MxGateway.Worker.Tests.Bootstrap; @@ -15,16 +18,19 @@ public sealed class WorkerApplicationTests int exitCode = MxGateway.Worker.WorkerApplication.Run( ValidArgs(), environment, - logger); + logger, + new SucceedingPipeClient()); Assert.Equal((int)WorkerExitCode.Success, exitCode); - MemoryWorkerLogEntry entry = Assert.Single(logger.Entries); + Assert.Equal(2, logger.Entries.Count); + MemoryWorkerLogEntry entry = logger.Entries[0]; Assert.Equal("Information", entry.Level); Assert.Equal("WorkerBootstrapSucceeded", entry.EventName); Assert.Equal("session-1", entry.Fields["session_id"]); Assert.Equal("mxaccess-gateway-123-session-1", entry.Fields["pipe_name"]); Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, entry.Fields["protocol_version"]); Assert.Equal("[redacted]", entry.Fields["nonce"]); + Assert.Equal("WorkerPipeHandshakeSucceeded", logger.Entries[1].EventName); } [Fact] @@ -73,6 +79,24 @@ public sealed class WorkerApplicationTests Assert.Equal((int)WorkerExitCode.MissingNonce, exitCode); } + [Fact] + public void Run_WithPipeProtocolFailure_ReturnsProtocolViolation() + { + MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret"); + MemoryWorkerLogger logger = new(); + + int exitCode = MxGateway.Worker.WorkerApplication.Run( + ValidArgs(), + environment, + logger, + new ThrowingPipeClient(new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.NonceMismatch, + "Bad nonce."))); + + Assert.Equal((int)WorkerExitCode.ProtocolViolation, exitCode); + Assert.Equal("WorkerPipeProtocolFailure", logger.Entries[1].EventName); + } + [Fact] public void Run_WithUnexpectedBootstrapFailure_ReturnsUnexpectedFailure() { @@ -110,4 +134,31 @@ public sealed class WorkerApplicationTests environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce); return environment; } + + private sealed class SucceedingPipeClient : IWorkerPipeClient + { + public Task RunAsync( + WorkerOptions options, + CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + } + + private sealed class ThrowingPipeClient : IWorkerPipeClient + { + private readonly Exception _exception; + + public ThrowingPipeClient(Exception exception) + { + _exception = exception; + } + + public Task RunAsync( + WorkerOptions options, + CancellationToken cancellationToken = default) + { + throw _exception; + } + } } diff --git a/src/MxGateway.Worker.Tests/Ipc/WorkerFrameProtocolTests.cs b/src/MxGateway.Worker.Tests/Ipc/WorkerFrameProtocolTests.cs new file mode 100644 index 0000000..5cd977b --- /dev/null +++ b/src/MxGateway.Worker.Tests/Ipc/WorkerFrameProtocolTests.cs @@ -0,0 +1,163 @@ +using System; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Google.Protobuf; +using MxGateway.Contracts; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.Ipc; + +namespace MxGateway.Worker.Tests.Ipc; + +public sealed class WorkerFrameProtocolTests +{ + private const string SessionId = "session-1"; + private const string Nonce = "nonce-secret"; + + [Fact] + public async Task WriteAndReadAsync_WithValidEnvelope_RoundTripsFrame() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream stream = new(); + WorkerEnvelope original = CreateGatewayHelloEnvelope(); + + WorkerFrameWriter writer = new(stream, options); + await writer.WriteAsync(original); + stream.Position = 0; + + WorkerFrameReader reader = new(stream, options); + WorkerEnvelope parsed = await reader.ReadAsync(); + + Assert.Equal(original, parsed); + } + + [Fact] + public async Task ReadAsync_WithWrongProtocolVersion_ThrowsProtocolVersionMismatch() + { + WorkerFrameProtocolOptions options = CreateOptions(); + WorkerEnvelope envelope = CreateGatewayHelloEnvelope(); + envelope.ProtocolVersion++; + MemoryStream stream = new(CreateFrame(envelope)); + + WorkerFrameReader reader = new(stream, options); + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await reader.ReadAsync()); + + Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode); + } + + [Fact] + public async Task ReadAsync_WithWrongSessionId_ThrowsSessionMismatch() + { + WorkerFrameProtocolOptions options = CreateOptions(); + WorkerEnvelope envelope = CreateGatewayHelloEnvelope(); + envelope.SessionId = "different-session"; + MemoryStream stream = new(CreateFrame(envelope)); + + WorkerFrameReader reader = new(stream, options); + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await reader.ReadAsync()); + + Assert.Equal(WorkerFrameProtocolErrorCode.SessionMismatch, exception.ErrorCode); + } + + [Fact] + public async Task ReadAsync_WithMalformedLength_ThrowsMalformedLength() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream stream = new(new byte[sizeof(uint)]); + + WorkerFrameReader reader = new(stream, options); + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await reader.ReadAsync()); + + Assert.Equal(WorkerFrameProtocolErrorCode.MalformedLength, exception.ErrorCode); + } + + [Fact] + public async Task ReadAsync_WithMalformedPayload_ThrowsInvalidEnvelope() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream stream = new(CreateFrame(new byte[] { 0x80 })); + + WorkerFrameReader reader = new(stream, options); + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await reader.ReadAsync()); + + Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode); + } + + [Fact] + public async Task WriteAsync_WithConcurrentCalls_SerializesCompleteFrames() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream stream = new(); + WorkerFrameWriter writer = new(stream, options); + + await Task.WhenAll( + writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 1)), + writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 2)), + writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 3))); + + stream.Position = 0; + WorkerFrameReader reader = new(stream, options); + + WorkerEnvelope first = await reader.ReadAsync(); + WorkerEnvelope second = await reader.ReadAsync(); + WorkerEnvelope third = await reader.ReadAsync(); + + Assert.Equal(new ulong[] { 1, 2, 3 }, new[] { first.Sequence, second.Sequence, third.Sequence }.OrderBy(sequence => sequence)); + } + + private static WorkerFrameProtocolOptions CreateOptions() + { + return new WorkerFrameProtocolOptions( + SessionId, + GatewayContractInfo.WorkerProtocolVersion, + Nonce); + } + + private static WorkerEnvelope CreateGatewayHelloEnvelope(ulong sequence = 1) + { + return new WorkerEnvelope + { + ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + SessionId = SessionId, + Sequence = sequence, + GatewayHello = new GatewayHello + { + SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + Nonce = Nonce, + GatewayVersion = "test-gateway", + }, + }; + } + + private static byte[] CreateFrame(IMessage message) + { + return CreateFrame(message.ToByteArray()); + } + + private static byte[] CreateFrame(byte[] payload) + { + byte[] frame = new byte[sizeof(uint) + payload.Length]; + WriteUInt32LittleEndian(frame, (uint)payload.Length); + payload.CopyTo(frame, sizeof(uint)); + + return frame; + } + + private static void WriteUInt32LittleEndian( + byte[] buffer, + uint value) + { + buffer[0] = (byte)value; + buffer[1] = (byte)(value >> 8); + buffer[2] = (byte)(value >> 16); + buffer[3] = (byte)(value >> 24); + } +} diff --git a/src/MxGateway.Worker.Tests/Ipc/WorkerPipeClientTests.cs b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeClientTests.cs new file mode 100644 index 0000000..36caa96 --- /dev/null +++ b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeClientTests.cs @@ -0,0 +1,61 @@ +using System; +using System.IO.Pipes; +using System.Threading.Tasks; +using MxGateway.Contracts; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.Bootstrap; +using MxGateway.Worker.Ipc; + +namespace MxGateway.Worker.Tests.Ipc; + +public sealed class WorkerPipeClientTests +{ + [Fact] + public async Task RunAsync_ConnectsToPipeAndCompletesHandshake() + { + string pipeName = $"mxaccess-gateway-test-{Guid.NewGuid():N}"; + WorkerOptions workerOptions = new( + "session-1", + pipeName, + GatewayContractInfo.WorkerProtocolVersion, + "nonce-secret"); + WorkerFrameProtocolOptions frameOptions = new(workerOptions); + + using NamedPipeServerStream server = new( + pipeName, + PipeDirection.InOut, + 1, + PipeTransmissionMode.Byte, + PipeOptions.Asynchronous); + + WorkerPipeClient client = new(connectTimeoutMilliseconds: 5000); + Task clientTask = client.RunAsync(workerOptions); + + await Task.Factory.FromAsync(server.BeginWaitForConnection, server.EndWaitForConnection, null); + + WorkerFrameReader reader = new(server, frameOptions); + WorkerFrameWriter writer = new(server, frameOptions); + + await writer.WriteAsync(new WorkerEnvelope + { + ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + SessionId = "session-1", + Sequence = 1, + GatewayHello = new GatewayHello + { + SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + Nonce = "nonce-secret", + GatewayVersion = "test-gateway", + }, + }); + + WorkerEnvelope hello = await reader.ReadAsync(); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, hello.BodyCase); + Assert.Equal("nonce-secret", hello.WorkerHello.Nonce); + + WorkerEnvelope ready = await reader.ReadAsync(); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, ready.BodyCase); + + await clientTask; + } +} diff --git a/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs new file mode 100644 index 0000000..7e5dd74 --- /dev/null +++ b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs @@ -0,0 +1,192 @@ +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using MxGateway.Contracts; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.Ipc; + +namespace MxGateway.Worker.Tests.Ipc; + +public sealed class WorkerPipeSessionTests +{ + private const string SessionId = "session-1"; + private const string Nonce = "nonce-secret"; + + [Fact] + public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream inbound = new(); + await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope()); + inbound.Position = 0; + MemoryStream outbound = new(); + WorkerPipeSession session = CreateSession(inbound, outbound, options); + bool initialized = false; + + await session.CompleteStartupHandshakeAsync( + _ => + { + initialized = true; + return Task.CompletedTask; + }); + + Assert.True(initialized); + WorkerEnvelope[] written = ReadWrittenFrames(outbound, options); + Assert.Equal(2, written.Length); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase); + Assert.Equal(Nonce, written[0].WorkerHello.Nonce); + } + + [Fact] + public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream inbound = new(); + await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(nonce: "wrong")); + inbound.Position = 0; + MemoryStream outbound = new(); + WorkerPipeSession session = CreateSession(inbound, outbound, options); + bool initialized = false; + + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await session.CompleteStartupHandshakeAsync( + _ => + { + initialized = true; + return Task.CompletedTask; + })); + + Assert.False(initialized); + Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode); + WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options)); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase); + Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category); + } + + [Fact] + public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream inbound = new(); + await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(supportedProtocolVersion: 999)); + inbound.Position = 0; + MemoryStream outbound = new(); + WorkerPipeSession session = CreateSession(inbound, outbound, options); + bool initialized = false; + + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await session.CompleteStartupHandshakeAsync( + _ => + { + initialized = true; + return Task.CompletedTask; + })); + + Assert.False(initialized); + Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode); + WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options)); + Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category); + } + + [Fact] + public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault() + { + WorkerFrameProtocolOptions options = CreateOptions(); + MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 })); + MemoryStream outbound = new(); + WorkerPipeSession session = CreateSession(inbound, outbound, options); + bool initialized = false; + + WorkerFrameProtocolException exception = + await Assert.ThrowsAsync( + async () => await session.CompleteStartupHandshakeAsync( + _ => + { + initialized = true; + return Task.CompletedTask; + })); + + Assert.False(initialized); + Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode); + WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options)); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase); + Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category); + } + + private static WorkerPipeSession CreateSession( + Stream inbound, + Stream outbound, + WorkerFrameProtocolOptions options) + { + return new WorkerPipeSession( + new WorkerFrameReader(inbound, options), + new WorkerFrameWriter(outbound, options), + options, + () => 1234); + } + + private static WorkerFrameProtocolOptions CreateOptions() + { + return new WorkerFrameProtocolOptions( + SessionId, + GatewayContractInfo.WorkerProtocolVersion, + Nonce); + } + + private static WorkerEnvelope CreateGatewayHelloEnvelope( + string nonce = Nonce, + uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion) + { + return new WorkerEnvelope + { + ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + SessionId = SessionId, + Sequence = 1, + GatewayHello = new GatewayHello + { + SupportedProtocolVersion = supportedProtocolVersion, + Nonce = nonce, + GatewayVersion = "test-gateway", + }, + }; + } + + private static WorkerEnvelope[] ReadWrittenFrames( + MemoryStream stream, + WorkerFrameProtocolOptions options) + { + stream.Position = 0; + WorkerFrameReader reader = new(stream, options); + List envelopes = new(); + + while (stream.Position < stream.Length) + { + envelopes.Add(reader.ReadAsync(CancellationToken.None).GetAwaiter().GetResult()); + } + + return envelopes.ToArray(); + } + + private static byte[] CreateFrame(byte[] payload) + { + byte[] frame = new byte[sizeof(uint) + payload.Length]; + WriteUInt32LittleEndian(frame, (uint)payload.Length); + payload.CopyTo(frame, sizeof(uint)); + + return frame; + } + + private static void WriteUInt32LittleEndian( + byte[] buffer, + uint value) + { + buffer[0] = (byte)value; + buffer[1] = (byte)(value >> 8); + buffer[2] = (byte)(value >> 16); + buffer[3] = (byte)(value >> 24); + } +} diff --git a/src/MxGateway.Worker/Bootstrap/WorkerExitCode.cs b/src/MxGateway.Worker/Bootstrap/WorkerExitCode.cs index 5379abd..67bbad1 100644 --- a/src/MxGateway.Worker/Bootstrap/WorkerExitCode.cs +++ b/src/MxGateway.Worker/Bootstrap/WorkerExitCode.cs @@ -7,4 +7,6 @@ public enum WorkerExitCode InvalidArguments = 2, InvalidProtocolVersion = 3, MissingNonce = 4, + PipeConnectionFailed = 5, + ProtocolViolation = 6, } diff --git a/src/MxGateway.Worker/Ipc/IWorkerPipeClient.cs b/src/MxGateway.Worker/Ipc/IWorkerPipeClient.cs new file mode 100644 index 0000000..8812f38 --- /dev/null +++ b/src/MxGateway.Worker/Ipc/IWorkerPipeClient.cs @@ -0,0 +1,12 @@ +using System.Threading; +using System.Threading.Tasks; +using MxGateway.Worker.Bootstrap; + +namespace MxGateway.Worker.Ipc; + +public interface IWorkerPipeClient +{ + Task RunAsync( + WorkerOptions options, + CancellationToken cancellationToken = default); +} diff --git a/src/MxGateway.Worker/Ipc/WorkerEnvelopeValidator.cs b/src/MxGateway.Worker/Ipc/WorkerEnvelopeValidator.cs new file mode 100644 index 0000000..ebe3823 --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerEnvelopeValidator.cs @@ -0,0 +1,33 @@ +using System; +using MxGateway.Contracts.Proto; + +namespace MxGateway.Worker.Ipc; + +internal static class WorkerEnvelopeValidator +{ + public static void Validate( + WorkerEnvelope envelope, + WorkerFrameProtocolOptions options) + { + if (envelope.ProtocolVersion != options.ProtocolVersion) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, + $"Worker envelope protocol version {envelope.ProtocolVersion} does not match expected version {options.ProtocolVersion}."); + } + + if (!string.Equals(envelope.SessionId, options.SessionId, StringComparison.Ordinal)) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.SessionMismatch, + "Worker envelope session id does not match the owning worker session."); + } + + if (envelope.BodyCase == WorkerEnvelope.BodyOneofCase.None) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidEnvelope, + "Worker envelope must include a typed body."); + } + } +} diff --git a/src/MxGateway.Worker/Ipc/WorkerFrameProtocolErrorCode.cs b/src/MxGateway.Worker/Ipc/WorkerFrameProtocolErrorCode.cs new file mode 100644 index 0000000..46e90ed --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerFrameProtocolErrorCode.cs @@ -0,0 +1,15 @@ +namespace MxGateway.Worker.Ipc; + +public enum WorkerFrameProtocolErrorCode +{ + Unknown = 0, + InvalidConfiguration = 1, + EndOfStream = 2, + MalformedLength = 3, + MessageTooLarge = 4, + InvalidEnvelope = 5, + ProtocolVersionMismatch = 6, + SessionMismatch = 7, + NonceMismatch = 8, + UnexpectedEnvelopeBody = 9, +} diff --git a/src/MxGateway.Worker/Ipc/WorkerFrameProtocolException.cs b/src/MxGateway.Worker/Ipc/WorkerFrameProtocolException.cs new file mode 100644 index 0000000..25e3b1d --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerFrameProtocolException.cs @@ -0,0 +1,25 @@ +using System; + +namespace MxGateway.Worker.Ipc; + +public sealed class WorkerFrameProtocolException : Exception +{ + public WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode errorCode, + string message) + : base(message) + { + ErrorCode = errorCode; + } + + public WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode errorCode, + string message, + Exception innerException) + : base(message, innerException) + { + ErrorCode = errorCode; + } + + public WorkerFrameProtocolErrorCode ErrorCode { get; } +} diff --git a/src/MxGateway.Worker/Ipc/WorkerFrameProtocolOptions.cs b/src/MxGateway.Worker/Ipc/WorkerFrameProtocolOptions.cs new file mode 100644 index 0000000..123b979 --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerFrameProtocolOptions.cs @@ -0,0 +1,86 @@ +using System; +using MxGateway.Contracts; +using MxGateway.Worker.Bootstrap; + +namespace MxGateway.Worker.Ipc; + +public sealed class WorkerFrameProtocolOptions +{ + public const int DefaultMaxMessageBytes = 16 * 1024 * 1024; + + public WorkerFrameProtocolOptions(WorkerOptions options) + : this( + options?.SessionId ?? throw new ArgumentNullException(nameof(options)), + options.ProtocolVersion, + options.Nonce, + DefaultMaxMessageBytes) + { + } + + public WorkerFrameProtocolOptions( + string sessionId, + uint protocolVersion, + string nonce) + : this( + sessionId, + protocolVersion, + nonce, + DefaultMaxMessageBytes) + { + } + + public WorkerFrameProtocolOptions( + string sessionId, + uint protocolVersion, + string nonce, + int maxMessageBytes) + { + if (string.IsNullOrWhiteSpace(sessionId)) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidConfiguration, + "Worker frame protocol requires a session id."); + } + + if (protocolVersion == 0) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidConfiguration, + "Worker frame protocol requires a non-zero protocol version."); + } + + if (protocolVersion != GatewayContractInfo.WorkerProtocolVersion) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, + $"Worker frame protocol version {protocolVersion} is not supported."); + } + + if (string.IsNullOrWhiteSpace(nonce)) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidConfiguration, + "Worker frame protocol requires a nonce."); + } + + if (maxMessageBytes <= 0) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidConfiguration, + "Worker frame protocol max message size must be greater than zero."); + } + + SessionId = sessionId; + ProtocolVersion = protocolVersion; + Nonce = nonce; + MaxMessageBytes = maxMessageBytes; + } + + public string SessionId { get; } + + public uint ProtocolVersion { get; } + + public string Nonce { get; } + + public int MaxMessageBytes { get; } +} diff --git a/src/MxGateway.Worker/Ipc/WorkerFrameReader.cs b/src/MxGateway.Worker/Ipc/WorkerFrameReader.cs new file mode 100644 index 0000000..1e7cbd1 --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerFrameReader.cs @@ -0,0 +1,93 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using MxGateway.Contracts.Proto; + +namespace MxGateway.Worker.Ipc; + +public sealed class WorkerFrameReader +{ + private readonly WorkerFrameProtocolOptions _options; + private readonly Stream _stream; + + public WorkerFrameReader( + Stream stream, + WorkerFrameProtocolOptions options) + { + _stream = stream ?? throw new ArgumentNullException(nameof(stream)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public async Task ReadAsync(CancellationToken cancellationToken = default) + { + byte[] lengthPrefix = new byte[sizeof(uint)]; + await ReadExactlyOrThrowAsync(lengthPrefix, cancellationToken).ConfigureAwait(false); + + uint payloadLength = ReadUInt32LittleEndian(lengthPrefix); + if (payloadLength == 0) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.MalformedLength, + "Worker frame payload length must be greater than zero."); + } + + if (payloadLength > _options.MaxMessageBytes) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.MessageTooLarge, + $"Worker frame payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes."); + } + + byte[] payload = new byte[payloadLength]; + await ReadExactlyOrThrowAsync(payload, cancellationToken).ConfigureAwait(false); + + WorkerEnvelope envelope; + try + { + envelope = WorkerEnvelope.Parser.ParseFrom(payload); + } + catch (InvalidProtocolBufferException exception) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidEnvelope, + "Worker frame payload is not a valid WorkerEnvelope protobuf message.", + exception); + } + + WorkerEnvelopeValidator.Validate(envelope, _options); + + return envelope; + } + + private static uint ReadUInt32LittleEndian(byte[] buffer) + { + return (uint)buffer[0] + | ((uint)buffer[1] << 8) + | ((uint)buffer[2] << 16) + | ((uint)buffer[3] << 24); + } + + private async Task ReadExactlyOrThrowAsync( + byte[] buffer, + CancellationToken cancellationToken) + { + int offset = 0; + while (offset < buffer.Length) + { + int bytesRead = await _stream + .ReadAsync(buffer, offset, buffer.Length - offset, cancellationToken) + .ConfigureAwait(false); + + if (bytesRead == 0) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.EndOfStream, + "Worker frame ended before the expected number of bytes were read."); + } + + offset += bytesRead; + } + } +} diff --git a/src/MxGateway.Worker/Ipc/WorkerFrameWriter.cs b/src/MxGateway.Worker/Ipc/WorkerFrameWriter.cs new file mode 100644 index 0000000..934faef --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerFrameWriter.cs @@ -0,0 +1,76 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using MxGateway.Contracts.Proto; + +namespace MxGateway.Worker.Ipc; + +public sealed class WorkerFrameWriter +{ + private readonly WorkerFrameProtocolOptions _options; + private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly Stream _stream; + + public WorkerFrameWriter( + Stream stream, + WorkerFrameProtocolOptions options) + { + _stream = stream ?? throw new ArgumentNullException(nameof(stream)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public async Task WriteAsync( + WorkerEnvelope envelope, + CancellationToken cancellationToken = default) + { + if (envelope is null) + { + throw new ArgumentNullException(nameof(envelope)); + } + + WorkerEnvelopeValidator.Validate(envelope, _options); + + int payloadLength = envelope.CalculateSize(); + if (payloadLength == 0) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.InvalidEnvelope, + "Worker envelope cannot serialize to an empty payload."); + } + + if (payloadLength > _options.MaxMessageBytes) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.MessageTooLarge, + $"Worker envelope payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes."); + } + + byte[] payload = envelope.ToByteArray(); + byte[] lengthPrefix = new byte[sizeof(uint)]; + WriteUInt32LittleEndian(lengthPrefix, (uint)payloadLength); + + await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await _stream.WriteAsync(lengthPrefix, 0, lengthPrefix.Length, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(payload, 0, payload.Length, cancellationToken).ConfigureAwait(false); + await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + _writeLock.Release(); + } + } + + private static void WriteUInt32LittleEndian( + byte[] buffer, + uint value) + { + buffer[0] = (byte)value; + buffer[1] = (byte)(value >> 8); + buffer[2] = (byte)(value >> 16); + buffer[3] = (byte)(value >> 24); + } +} diff --git a/src/MxGateway.Worker/Ipc/WorkerPipeClient.cs b/src/MxGateway.Worker/Ipc/WorkerPipeClient.cs new file mode 100644 index 0000000..e9e408f --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerPipeClient.cs @@ -0,0 +1,67 @@ +using System; +using System.IO.Pipes; +using System.Threading; +using System.Threading.Tasks; +using MxGateway.Worker.Bootstrap; + +namespace MxGateway.Worker.Ipc; + +public sealed class WorkerPipeClient : IWorkerPipeClient +{ + public const int DefaultConnectTimeoutMilliseconds = 30000; + + private readonly int _connectTimeoutMilliseconds; + + public WorkerPipeClient() + : this(DefaultConnectTimeoutMilliseconds) + { + } + + public WorkerPipeClient(int connectTimeoutMilliseconds) + { + if (connectTimeoutMilliseconds <= 0) + { + throw new ArgumentOutOfRangeException( + nameof(connectTimeoutMilliseconds), + "Worker pipe connect timeout must be greater than zero."); + } + + _connectTimeoutMilliseconds = connectTimeoutMilliseconds; + } + + public async Task RunAsync( + WorkerOptions options, + CancellationToken cancellationToken = default) + { + if (options is null) + { + throw new ArgumentNullException(nameof(options)); + } + + WorkerFrameProtocolOptions frameOptions = new(options); + + using NamedPipeClientStream pipe = new( + ".", + options.PipeName, + PipeDirection.InOut, + PipeOptions.Asynchronous); + + await ConnectAsync(pipe, cancellationToken).ConfigureAwait(false); + + WorkerPipeSession session = new(pipe, frameOptions); + await session.CompleteStartupHandshakeAsync(cancellationToken).ConfigureAwait(false); + } + + private Task ConnectAsync( + NamedPipeClientStream pipe, + CancellationToken cancellationToken) + { + return Task.Run( + () => + { + cancellationToken.ThrowIfCancellationRequested(); + pipe.Connect(_connectTimeoutMilliseconds); + }, + cancellationToken); + } +} diff --git a/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs b/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs new file mode 100644 index 0000000..6232cc6 --- /dev/null +++ b/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs @@ -0,0 +1,218 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf.WellKnownTypes; +using MxGateway.Contracts.Proto; +using MxGateway.Worker.MxAccess; + +namespace MxGateway.Worker.Ipc; + +public sealed class WorkerPipeSession +{ + private readonly WorkerFrameProtocolOptions _options; + private readonly Func _processIdProvider; + private readonly WorkerFrameReader _reader; + private readonly WorkerFrameWriter _writer; + private long _nextSequence; + + public WorkerPipeSession( + Stream stream, + WorkerFrameProtocolOptions options) + : this( + new WorkerFrameReader(stream, options), + new WorkerFrameWriter(stream, options), + options, + () => Process.GetCurrentProcess().Id) + { + } + + public WorkerPipeSession( + WorkerFrameReader reader, + WorkerFrameWriter writer, + WorkerFrameProtocolOptions options, + Func processIdProvider) + { + _reader = reader ?? throw new ArgumentNullException(nameof(reader)); + _writer = writer ?? throw new ArgumentNullException(nameof(writer)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + _processIdProvider = processIdProvider ?? throw new ArgumentNullException(nameof(processIdProvider)); + } + + public Task CompleteStartupHandshakeAsync(CancellationToken cancellationToken = default) + { + return CompleteStartupHandshakeAsync(_ => Task.CompletedTask, cancellationToken); + } + + public async Task CompleteStartupHandshakeAsync( + Func initializeMxAccessAsync, + CancellationToken cancellationToken = default) + { + if (initializeMxAccessAsync is null) + { + throw new ArgumentNullException(nameof(initializeMxAccessAsync)); + } + + try + { + WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false); + ValidateGatewayHello(envelope); + + await WriteWorkerHelloAsync(cancellationToken).ConfigureAwait(false); + await initializeMxAccessAsync(cancellationToken).ConfigureAwait(false); + await WriteWorkerReadyAsync(cancellationToken).ConfigureAwait(false); + } + catch (WorkerFrameProtocolException exception) + { + await TryWriteFaultAsync(exception, cancellationToken).ConfigureAwait(false); + throw; + } + } + + private void ValidateGatewayHello(WorkerEnvelope envelope) + { + if (envelope.BodyCase != WorkerEnvelope.BodyOneofCase.GatewayHello) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody, + "Worker expected GatewayHello during startup handshake."); + } + + GatewayHello gatewayHello = envelope.GatewayHello; + if (gatewayHello.SupportedProtocolVersion != _options.ProtocolVersion) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, + $"GatewayHello supported protocol version {gatewayHello.SupportedProtocolVersion} does not match expected version {_options.ProtocolVersion}."); + } + + if (!string.Equals(gatewayHello.Nonce, _options.Nonce, StringComparison.Ordinal)) + { + throw new WorkerFrameProtocolException( + WorkerFrameProtocolErrorCode.NonceMismatch, + "GatewayHello nonce does not match the worker launch nonce."); + } + } + + private Task WriteWorkerHelloAsync(CancellationToken cancellationToken) + { + return _writer.WriteAsync( + CreateEnvelope(new WorkerHello + { + ProtocolVersion = _options.ProtocolVersion, + Nonce = _options.Nonce, + WorkerProcessId = _processIdProvider(), + WorkerVersion = typeof(WorkerPipeSession).Assembly.GetName().Version?.ToString() ?? string.Empty, + }), + cancellationToken); + } + + private Task WriteWorkerReadyAsync(CancellationToken cancellationToken) + { + return _writer.WriteAsync( + CreateEnvelope(new WorkerReady + { + WorkerProcessId = _processIdProvider(), + MxaccessProgid = MxAccessInteropInfo.ProgId, + MxaccessClsid = MxAccessInteropInfo.Clsid, + ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow), + }), + cancellationToken); + } + + private async Task TryWriteFaultAsync( + WorkerFrameProtocolException exception, + CancellationToken cancellationToken) + { + try + { + await _writer + .WriteAsync(CreateEnvelope(CreateFault(exception)), cancellationToken) + .ConfigureAwait(false); + } + catch (Exception faultWriteException) when ( + faultWriteException is IOException + || faultWriteException is ObjectDisposedException + || faultWriteException is WorkerFrameProtocolException) + { + // The original protocol failure is the actionable error. + } + } + + private WorkerEnvelope CreateEnvelope(WorkerHello hello) + { + return CreateBaseEnvelope(hello); + } + + private WorkerEnvelope CreateEnvelope(WorkerReady ready) + { + return CreateBaseEnvelope(ready); + } + + private WorkerEnvelope CreateEnvelope(WorkerFault fault) + { + return CreateBaseEnvelope(fault); + } + + private WorkerEnvelope CreateBaseEnvelope(WorkerHello body) + { + WorkerEnvelope envelope = CreateBaseEnvelope(); + envelope.WorkerHello = body; + return envelope; + } + + private WorkerEnvelope CreateBaseEnvelope(WorkerReady body) + { + WorkerEnvelope envelope = CreateBaseEnvelope(); + envelope.WorkerReady = body; + return envelope; + } + + private WorkerEnvelope CreateBaseEnvelope(WorkerFault body) + { + WorkerEnvelope envelope = CreateBaseEnvelope(); + envelope.WorkerFault = body; + return envelope; + } + + private WorkerEnvelope CreateBaseEnvelope() + { + return new WorkerEnvelope + { + ProtocolVersion = _options.ProtocolVersion, + SessionId = _options.SessionId, + Sequence = NextSequence(), + }; + } + + private ulong NextSequence() + { + return unchecked((ulong)Interlocked.Increment(ref _nextSequence)); + } + + private static WorkerFault CreateFault(WorkerFrameProtocolException exception) + { + return new WorkerFault + { + Category = MapFaultCategory(exception.ErrorCode), + ExceptionType = exception.GetType().FullName ?? string.Empty, + DiagnosticMessage = exception.Message, + ProtocolStatus = new ProtocolStatus + { + Code = ProtocolStatusCode.ProtocolViolation, + Message = exception.Message, + }, + }; + } + + private static WorkerFaultCategory MapFaultCategory(WorkerFrameProtocolErrorCode errorCode) + { + return errorCode switch + { + WorkerFrameProtocolErrorCode.ProtocolVersionMismatch => WorkerFaultCategory.ProtocolMismatch, + WorkerFrameProtocolErrorCode.EndOfStream => WorkerFaultCategory.PipeDisconnected, + _ => WorkerFaultCategory.ProtocolViolation, + }; + } +} diff --git a/src/MxGateway.Worker/WorkerApplication.cs b/src/MxGateway.Worker/WorkerApplication.cs index 1c39a91..3bacc28 100644 --- a/src/MxGateway.Worker/WorkerApplication.cs +++ b/src/MxGateway.Worker/WorkerApplication.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.IO; using MxGateway.Worker.Bootstrap; +using MxGateway.Worker.Ipc; namespace MxGateway.Worker; @@ -11,13 +13,27 @@ public static class WorkerApplication return Run( args, new EnvironmentVariableWorkerEnvironment(), - new WorkerConsoleLogger(Console.Error)); + new WorkerConsoleLogger(Console.Error), + new WorkerPipeClient()); } public static int Run( string[] args, IWorkerEnvironment environment, IWorkerLogger logger) + { + return Run( + args, + environment, + logger, + new WorkerPipeClient()); + } + + public static int Run( + string[] args, + IWorkerEnvironment environment, + IWorkerLogger logger, + IWorkerPipeClient pipeClient) { if (args is null) { @@ -34,6 +50,11 @@ public static class WorkerApplication throw new ArgumentNullException(nameof(logger)); } + if (pipeClient is null) + { + throw new ArgumentNullException(nameof(pipeClient)); + } + try { WorkerOptionsParser parser = new(environment); @@ -61,8 +82,38 @@ public static class WorkerApplication ["nonce"] = options.Nonce, }); + pipeClient.RunAsync(options).GetAwaiter().GetResult(); + + logger.Information("WorkerPipeHandshakeSucceeded", new Dictionary + { + ["session_id"] = options.SessionId, + ["pipe_name"] = options.PipeName, + ["protocol_version"] = options.ProtocolVersion, + }); + return (int)WorkerExitCode.Success; } + catch (WorkerFrameProtocolException exception) + { + logger.Error("WorkerPipeProtocolFailure", new Dictionary + { + ["exit_code"] = WorkerExitCode.ProtocolViolation, + ["error_code"] = exception.ErrorCode, + ["exception_type"] = exception.GetType().FullName, + }); + + return (int)WorkerExitCode.ProtocolViolation; + } + catch (Exception exception) when (exception is IOException or TimeoutException) + { + logger.Error("WorkerPipeConnectionFailed", new Dictionary + { + ["exit_code"] = WorkerExitCode.PipeConnectionFailed, + ["exception_type"] = exception.GetType().FullName, + }); + + return (int)WorkerExitCode.PipeConnectionFailed; + } catch (Exception exception) { logger.Error("WorkerBootstrapUnexpectedFailure", new Dictionary