Compare commits

...

4 Commits

Author SHA1 Message Date
Joseph Doherty d5a982152b Issue #22: implement pipe client and frame protocol 2026-04-26 17:16:49 -04:00
dohertj2 0b0be7098e Merge PR #64: Issue #11 implement gateway WorkerClient
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:14:03 -04:00
Joseph Doherty fce9e99553 Issue #11: implement gateway workerclient 2026-04-26 17:09:51 -04:00
dohertj2 c8fb3e91a3 Merge PR #63: Issue #8 add gRPC authentication and scope authorization
Verified with dotnet build src\\MxGateway.sln and dotnet test src\\MxGateway.sln.
2026-04-26 17:06:23 -04:00
24 changed files with 2389 additions and 5 deletions
+8 -2
View File
@@ -411,7 +411,7 @@ session ids as protocol faults and close the session.
`WorkerClient` is the gateway-side object that owns one worker connection.
Suggested public shape:
Current public shape:
```csharp
public interface IWorkerClient : IAsyncDisposable
@@ -419,6 +419,7 @@ public interface IWorkerClient : IAsyncDisposable
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
@@ -438,12 +439,17 @@ Internally it owns:
- pipe stream,
- read loop,
- write loop,
- bounded outbound command/control channel,
- outbound command/control channel serialized by the write loop,
- bounded inbound event channel,
- pending command dictionary keyed by correlation id,
- heartbeat monitor,
- terminal fault source.
`StartAsync` sends `GatewayHello`, verifies the `WorkerHello` protocol version
and nonce, waits for `WorkerReady`, and only then exposes `Ready` state. The
read loop starts after readiness so the handshake has a single owner for its
ordered frames.
### Read Loop
The read loop:
@@ -0,0 +1,27 @@
using MxGateway.Contracts.Proto;
namespace MxGateway.Server.Workers;
public interface IWorkerClient : IAsyncDisposable
{
string SessionId { get; }
int? ProcessId { get; }
WorkerClientState State { get; }
DateTimeOffset LastHeartbeatAt { get; }
Task StartAsync(CancellationToken cancellationToken);
Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken);
IAsyncEnumerable<WorkerEvent> ReadEventsAsync(CancellationToken cancellationToken);
Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken);
void Kill(string reason);
}
@@ -0,0 +1,755 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Metrics;
namespace MxGateway.Server.Workers;
public sealed class WorkerClient : IWorkerClient
{
private const string GatewayVersionFallback = "unknown";
private readonly object _syncRoot = new();
private readonly WorkerClientConnection _connection;
private readonly WorkerClientOptions _options;
private readonly GatewayMetrics? _metrics;
private readonly TimeProvider _timeProvider;
private readonly ILogger<WorkerClient> _logger;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private readonly Channel<WorkerEnvelope> _outboundEnvelopes;
private readonly Channel<WorkerEvent> _events;
private readonly ConcurrentDictionary<string, PendingCommand> _pendingCommands = new(StringComparer.Ordinal);
private readonly CancellationTokenSource _stopCts = new();
private long _nextSequence;
private WorkerClientState _state;
private DateTimeOffset _lastHeartbeatAt;
private int? _processId;
private Task? _readLoopTask;
private Task? _writeLoopTask;
private Task? _heartbeatLoopTask;
private bool _disposed;
public WorkerClient(
WorkerClientConnection connection,
WorkerClientOptions? options = null,
GatewayMetrics? metrics = null,
TimeProvider? timeProvider = null,
ILogger<WorkerClient>? logger = null)
{
_connection = connection ?? throw new ArgumentNullException(nameof(connection));
_options = options ?? new WorkerClientOptions();
_metrics = metrics;
_timeProvider = timeProvider ?? TimeProvider.System;
_logger = logger ?? NullLogger<WorkerClient>.Instance;
_reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions);
_writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions);
_outboundEnvelopes = Channel.CreateUnbounded<WorkerEnvelope>(
new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = false,
AllowSynchronousContinuations = false,
});
_events = Channel.CreateBounded<WorkerEvent>(
new BoundedChannelOptions(_options.EventChannelCapacity)
{
SingleReader = false,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait,
AllowSynchronousContinuations = false,
});
_lastHeartbeatAt = _timeProvider.GetUtcNow();
}
public string SessionId => _connection.SessionId;
public int? ProcessId
{
get
{
lock (_syncRoot)
{
return _processId;
}
}
}
public WorkerClientState State
{
get
{
lock (_syncRoot)
{
return _state;
}
}
}
public DateTimeOffset LastHeartbeatAt
{
get
{
lock (_syncRoot)
{
return _lastHeartbeatAt;
}
}
}
public async Task StartAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
TransitionFromCreatedToHandshaking();
_writeLoopTask = Task.Run(WriteLoopAsync);
await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false);
WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerHello,
cancellationToken).ConfigureAwait(false);
ValidateWorkerHello(helloEnvelope.WorkerHello);
WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase.WorkerReady,
cancellationToken).ConfigureAwait(false);
MarkReady(readyEnvelope.WorkerReady);
_readLoopTask = Task.Run(ReadLoopAsync);
_heartbeatLoopTask = Task.Run(HeartbeatLoopAsync);
}
public async Task<WorkerCommandReply> InvokeAsync(
WorkerCommand command,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(command);
ThrowIfDisposed();
EnsureReady();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero.");
}
string correlationId = Guid.NewGuid().ToString("N");
string method = GetCommandMethod(command);
PendingCommand pendingCommand = new(
correlationId,
method,
_timeProvider.GetTimestamp());
if (!_pendingCommands.TryAdd(correlationId, pendingCommand))
{
throw new InvalidOperationException("Generated a duplicate command correlation id.");
}
_metrics?.CommandStarted(method);
try
{
await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false);
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
Task timeoutTask = Task.Delay(timeout, timeoutCts.Token);
Task<WorkerCommandReply> replyTask = pendingCommand.Task;
Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false);
if (completedTask == replyTask)
{
await timeoutCts.CancelAsync().ConfigureAwait(false);
return await replyTask.ConfigureAwait(false);
}
if (cancellationToken.IsCancellationRequested)
{
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.GatewayShutdown,
"Command wait was canceled.");
cancellationToken.ThrowIfCancellationRequested();
}
RemovePendingCommandAsFailed(
correlationId,
pendingCommand,
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
throw new WorkerClientException(
WorkerClientErrorCode.CommandTimeout,
$"Worker command {method} timed out after {timeout}.");
}
catch
{
_pendingCommands.TryRemove(correlationId, out _);
throw;
}
}
public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
[EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
yield return workerEvent;
}
}
public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken)
{
ThrowIfDisposed();
if (timeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero.");
}
WorkerClientState state = State;
if (state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
MarkClosing();
await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false);
_outboundEnvelopes.Writer.TryComplete();
using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
try
{
await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false);
MarkClosed("shutdown");
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
SetFaulted(
WorkerClientErrorCode.ShutdownTimeout,
"Worker shutdown timed out.",
null);
throw new WorkerClientException(
WorkerClientErrorCode.ShutdownTimeout,
$"Worker shutdown timed out after {timeout}.");
}
}
public void Kill(string reason)
{
ThrowIfDisposed();
_connection.ProcessHandle?.Process.Kill(entireProcessTree: true);
_metrics?.WorkerKilled(reason);
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
$"Worker was killed by the gateway: {reason}.",
null);
}
public async ValueTask DisposeAsync()
{
if (_disposed)
{
return;
}
_disposed = true;
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
"Worker client was disposed."));
await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
await _connection.Stream.DisposeAsync().ConfigureAwait(false);
_connection.ProcessHandle?.Dispose();
_stopCts.Dispose();
}
private async Task WriteLoopAsync()
{
try
{
await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false))
{
await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.WriteFailed,
"Worker pipe write failed.",
exception);
}
}
private async Task ReadLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false);
await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream)
{
SetFaulted(
WorkerClientErrorCode.PipeDisconnected,
"Worker pipe disconnected.",
exception);
}
catch (Exception exception)
{
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker read loop failed.",
exception);
}
}
private async Task HeartbeatLoopAsync()
{
try
{
while (!_stopCts.IsCancellationRequested)
{
await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false);
if (State != WorkerClientState.Ready)
{
continue;
}
DateTimeOffset lastHeartbeatAt = LastHeartbeatAt;
DateTimeOffset now = _timeProvider.GetUtcNow();
if (now - lastHeartbeatAt <= _options.HeartbeatGrace)
{
continue;
}
_metrics?.HeartbeatFailed(SessionId);
SetFaulted(
WorkerClientErrorCode.HeartbeatExpired,
$"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.",
null);
}
}
catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
{
}
}
private async Task DispatchEnvelopeAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
switch (envelope.BodyCase)
{
case WorkerEnvelope.BodyOneofCase.WorkerCommandReply:
CompleteCommand(envelope);
break;
case WorkerEnvelope.BodyOneofCase.WorkerEvent:
await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false);
break;
case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat:
MarkHeartbeat(envelope.WorkerHeartbeat);
break;
case WorkerEnvelope.BodyOneofCase.WorkerFault:
SetFaulted(
WorkerClientErrorCode.WorkerFaulted,
CreateWorkerFaultMessage(envelope.WorkerFault),
null);
break;
case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck:
MarkClosed("worker-shutdown-ack");
break;
default:
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
$"Worker sent unexpected envelope body {envelope.BodyCase}.",
null);
break;
}
}
private async Task EnqueueWorkerEventAsync(
WorkerEvent workerEvent,
CancellationToken cancellationToken)
{
if (workerEvent.Event is not null)
{
_metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
}
if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false))
{
return;
}
if (!_events.Writer.TryWrite(workerEvent))
{
_metrics?.QueueOverflow("worker-events");
SetFaulted(
WorkerClientErrorCode.ProtocolViolation,
"Worker event channel rejected an event.",
null);
}
}
private void CompleteCommand(WorkerEnvelope envelope)
{
string correlationId = envelope.CorrelationId;
if (string.IsNullOrWhiteSpace(correlationId))
{
correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty;
}
if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand))
{
_logger.LogDebug(
"Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.",
SessionId,
correlationId);
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandSucceeded(pendingCommand.Method, duration);
pendingCommand.SetResult(envelope.WorkerCommandReply);
}
private void RemovePendingCommandAsFailed(
string correlationId,
PendingCommand pendingCommand,
WorkerClientErrorCode errorCode,
string message)
{
if (!_pendingCommands.TryRemove(correlationId, out _))
{
return;
}
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration);
pendingCommand.SetException(new WorkerClientException(errorCode, message));
}
private async Task<WorkerEnvelope> ReadHandshakeEnvelopeAsync(
WorkerEnvelope.BodyOneofCase expectedBody,
CancellationToken cancellationToken)
{
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
if (envelope.BodyCase != expectedBody)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
$"Worker handshake expected {expectedBody} but received {envelope.BodyCase}.");
}
return envelope;
}
private void ValidateWorkerHello(WorkerHello workerHello)
{
if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion)
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello protocol version does not match the gateway protocol version.");
}
if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal))
{
throw new WorkerClientException(
WorkerClientErrorCode.ProtocolViolation,
"Worker hello nonce does not match the gateway nonce.");
}
lock (_syncRoot)
{
_processId = workerHello.WorkerProcessId == 0
? _connection.ProcessHandle?.ProcessId
: workerHello.WorkerProcessId;
}
}
private void MarkReady(WorkerReady ready)
{
lock (_syncRoot)
{
_processId = ready.WorkerProcessId == 0
? _processId ?? _connection.ProcessHandle?.ProcessId
: ready.WorkerProcessId;
_lastHeartbeatAt = _timeProvider.GetUtcNow();
_state = WorkerClientState.Ready;
}
DateTimeOffset readyAt = _timeProvider.GetUtcNow();
DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt;
_metrics?.WorkerStarted(readyAt - launchedAt);
}
private void MarkHeartbeat(WorkerHeartbeat heartbeat)
{
lock (_syncRoot)
{
_lastHeartbeatAt = _timeProvider.GetUtcNow();
if (heartbeat.WorkerProcessId != 0)
{
_processId = heartbeat.WorkerProcessId;
}
}
}
private void MarkClosing()
{
lock (_syncRoot)
{
if (_state is WorkerClientState.Closed or WorkerClientState.Faulted)
{
return;
}
_state = WorkerClientState.Closing;
}
}
private void MarkClosed(string reason)
{
lock (_syncRoot)
{
if (_state == WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Closed;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete();
_events.Writer.TryComplete();
CompletePendingCommands(
new WorkerClientException(
WorkerClientErrorCode.GatewayShutdown,
$"Worker client closed because {reason}."));
_metrics?.WorkerStopped(reason);
}
private void SetFaulted(
WorkerClientErrorCode errorCode,
string message,
Exception? exception)
{
WorkerClientException fault = exception is null
? new WorkerClientException(errorCode, message)
: new WorkerClientException(errorCode, message, exception);
lock (_syncRoot)
{
if (_state is WorkerClientState.Faulted or WorkerClientState.Closed)
{
return;
}
_state = WorkerClientState.Faulted;
}
_stopCts.Cancel();
_outboundEnvelopes.Writer.TryComplete(fault);
_events.Writer.TryComplete(fault);
CompletePendingCommands(fault);
_metrics?.Fault(errorCode.ToString());
_logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message);
}
private void CompletePendingCommands(Exception exception)
{
foreach (KeyValuePair<string, PendingCommand> item in _pendingCommands.ToArray())
{
if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand))
{
TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
_metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration);
pendingCommand.SetException(exception);
}
}
}
private void TransitionFromCreatedToHandshaking()
{
lock (_syncRoot)
{
if (_state != WorkerClientState.Created)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client cannot start from state {_state}.");
}
_state = WorkerClientState.Handshaking;
}
}
private void EnsureReady()
{
WorkerClientState state = State;
if (state != WorkerClientState.Ready)
{
throw new WorkerClientException(
WorkerClientErrorCode.InvalidState,
$"Worker client is not ready. Current state is {state}.");
}
}
private bool IsTerminalState()
{
WorkerClientState state = State;
return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted;
}
private async Task EnqueueAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
try
{
await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException exception)
{
throw new WorkerClientException(
WorkerClientErrorCode.WriteFailed,
"Worker outbound channel is closed.",
exception);
}
}
private WorkerEnvelope CreateGatewayHelloEnvelope()
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.GatewayHello = new GatewayHello
{
SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion,
Nonce = _connection.Nonce,
GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback,
});
}
private WorkerEnvelope CreateCommandEnvelope(
string correlationId,
WorkerCommand command)
{
return CreateEnvelope(
correlationId,
envelope => envelope.WorkerCommand = command.Clone());
}
private WorkerEnvelope CreateShutdownEnvelope(
TimeSpan timeout,
string reason)
{
return CreateEnvelope(
correlationId: string.Empty,
envelope => envelope.WorkerShutdown = new WorkerShutdown
{
GracePeriod = Duration.FromTimeSpan(timeout),
Reason = reason,
});
}
private WorkerEnvelope CreateEnvelope(
string correlationId,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = _connection.FrameOptions.ProtocolVersion,
SessionId = SessionId,
Sequence = (ulong)Interlocked.Increment(ref _nextSequence),
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static string GetCommandMethod(WorkerCommand command)
{
return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString();
}
private static string CreateWorkerFaultMessage(WorkerFault fault)
{
return string.IsNullOrWhiteSpace(fault.DiagnosticMessage)
? $"Worker faulted with category {fault.Category}."
: $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}";
}
private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken)
{
Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask }
.Where(task => task is not null)
.Cast<Task>()
.ToArray();
if (tasks.Length == 0)
{
return;
}
await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
}
private void ThrowIfDisposed()
{
ObjectDisposedException.ThrowIf(_disposed, this);
}
private sealed class PendingCommand
{
private readonly TaskCompletionSource<WorkerCommandReply> _completion = new(TaskCreationOptions.RunContinuationsAsynchronously);
public PendingCommand(
string correlationId,
string method,
long startTimestamp)
{
CorrelationId = correlationId;
Method = method;
StartTimestamp = startTimestamp;
}
public string CorrelationId { get; }
public string Method { get; }
public long StartTimestamp { get; }
public Task<WorkerCommandReply> Task => _completion.Task;
public void SetResult(WorkerCommandReply reply)
{
_completion.TrySetResult(reply);
}
public void SetException(Exception exception)
{
_completion.TrySetException(exception);
}
}
}
@@ -0,0 +1,38 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientConnection
{
public WorkerClientConnection(
string sessionId,
string nonce,
Stream stream,
WorkerFrameProtocolOptions frameOptions,
WorkerProcessHandle? processHandle = null)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new ArgumentException("Session id is required.", nameof(sessionId));
}
if (string.IsNullOrWhiteSpace(nonce))
{
throw new ArgumentException("Worker nonce is required.", nameof(nonce));
}
SessionId = sessionId;
Nonce = nonce;
Stream = stream ?? throw new ArgumentNullException(nameof(stream));
FrameOptions = frameOptions ?? throw new ArgumentNullException(nameof(frameOptions));
ProcessHandle = processHandle;
}
public string SessionId { get; }
public string Nonce { get; }
public Stream Stream { get; }
public WorkerFrameProtocolOptions FrameOptions { get; }
public WorkerProcessHandle? ProcessHandle { get; }
}
@@ -0,0 +1,14 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientErrorCode
{
InvalidState,
ProtocolViolation,
PipeDisconnected,
CommandTimeout,
WorkerFaulted,
HeartbeatExpired,
ShutdownTimeout,
GatewayShutdown,
WriteFailed,
}
@@ -0,0 +1,23 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientException : Exception
{
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerClientException(
WorkerClientErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerClientErrorCode ErrorCode { get; }
}
@@ -0,0 +1,24 @@
namespace MxGateway.Server.Workers;
public sealed class WorkerClientOptions
{
public static readonly TimeSpan DefaultHeartbeatGrace = TimeSpan.FromSeconds(15);
public static readonly TimeSpan DefaultHeartbeatCheckInterval = TimeSpan.FromSeconds(1);
public static readonly TimeSpan DefaultEventChannelFullModeTimeout = TimeSpan.FromSeconds(5);
public WorkerClientOptions()
{
HeartbeatGrace = DefaultHeartbeatGrace;
HeartbeatCheckInterval = DefaultHeartbeatCheckInterval;
EventChannelCapacity = 1_024;
EventChannelFullModeTimeout = DefaultEventChannelFullModeTimeout;
}
public TimeSpan HeartbeatGrace { get; init; }
public TimeSpan HeartbeatCheckInterval { get; init; }
public int EventChannelCapacity { get; init; }
public TimeSpan EventChannelFullModeTimeout { get; init; }
}
@@ -0,0 +1,11 @@
namespace MxGateway.Server.Workers;
public enum WorkerClientState
{
Created,
Handshaking,
Ready,
Closing,
Closed,
Faulted,
}
@@ -0,0 +1,341 @@
using System.IO.Pipes;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Workers;
namespace MxGateway.Tests.Gateway.Workers;
public sealed class WorkerClientTests
{
private const string SessionId = "session-worker-client";
private const string Nonce = "nonce-worker-client";
private const int WorkerProcessId = 4321;
private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5);
[Fact]
public async Task StartAsync_WithWorkerHelloAndReady_EntersReadyState()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(WorkerProcessId, client.ProcessId);
}
[Fact]
public async Task InvokeAsync_WithMatchingReply_CompletesPendingCommand()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> invokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TestTimeout,
CancellationToken.None);
WorkerEnvelope commandEnvelope = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerCommand, commandEnvelope.BodyCase);
Assert.False(string.IsNullOrWhiteSpace(commandEnvelope.CorrelationId));
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(commandEnvelope.CorrelationId, MxCommandKind.Ping));
WorkerCommandReply reply = await invokeTask.WaitAsync(TestTimeout);
Assert.Equal(commandEnvelope.CorrelationId, reply.Reply.CorrelationId);
Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
}
[Fact]
public async Task InvokeAsync_WithLateReply_IgnoresLateReplyAndKeepsClientReady()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
Task<WorkerCommandReply> timedOutInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.Ping),
TimeSpan.FromMilliseconds(50),
CancellationToken.None);
WorkerEnvelope timedOutCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
WorkerClientException exception = await Assert.ThrowsAsync<WorkerClientException>(
async () => await timedOutInvokeTask);
Assert.Equal(WorkerClientErrorCode.CommandTimeout, exception.ErrorCode);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(timedOutCommand.CorrelationId, MxCommandKind.Ping));
await Task.Delay(TimeSpan.FromMilliseconds(50));
Task<WorkerCommandReply> secondInvokeTask = client.InvokeAsync(
CreateCommand(MxCommandKind.GetWorkerInfo),
TestTimeout,
CancellationToken.None);
WorkerEnvelope secondCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
await pipePair.WorkerWriter.WriteAsync(
CreateCommandReplyEnvelope(secondCommand.CorrelationId, MxCommandKind.GetWorkerInfo));
WorkerCommandReply reply = await secondInvokeTask.WaitAsync(TestTimeout);
Assert.Equal(WorkerClientState.Ready, client.State);
Assert.Equal(MxCommandKind.GetWorkerInfo, reply.Reply.Kind);
}
[Fact]
public async Task ReadEventsAsync_WithWorkerEvents_YieldsEventsInPipeOrder()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
using CancellationTokenSource cancellationTokenSource = new(TestTimeout);
await using IAsyncEnumerator<WorkerEvent> events =
client.ReadEventsAsync(cancellationTokenSource.Token).GetAsyncEnumerator(cancellationTokenSource.Token);
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange));
await pipePair.WorkerWriter.WriteAsync(
CreateEventEnvelope(sequence: 12, MxEventFamily.OperationComplete));
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)11, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OnDataChange, events.Current.Event.Family);
Assert.True(await events.MoveNextAsync());
Assert.Equal((ulong)12, events.Current.Event.WorkerSequence);
Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family);
}
[Fact]
public async Task ReadLoop_WhenPipeDisconnects_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
await pipePair.DisposeWorkerSideAsync();
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
[Fact]
public async Task HeartbeatMonitor_WhenHeartbeatExpires_FaultsClient()
{
await using PipePair pipePair = await PipePair.CreateAsync();
await using WorkerClient client = CreateClient(
pipePair,
new WorkerClientOptions
{
HeartbeatGrace = TimeSpan.FromMilliseconds(80),
HeartbeatCheckInterval = TimeSpan.FromMilliseconds(20),
EventChannelCapacity = 8,
});
await CompleteHandshakeAsync(client, pipePair);
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
TestTimeout);
Assert.Equal(WorkerClientState.Faulted, client.State);
}
private static WorkerClient CreateClient(
PipePair pipePair,
WorkerClientOptions? options = null)
{
WorkerFrameProtocolOptions frameOptions = new(SessionId);
WorkerClientConnection connection = new(
SessionId,
Nonce,
pipePair.GatewayStream,
frameOptions);
return new WorkerClient(connection, options);
}
private static async Task CompleteHandshakeAsync(
WorkerClient client,
PipePair pipePair)
{
Task startTask = client.StartAsync(CancellationToken.None);
WorkerEnvelope gatewayHello = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout);
Assert.Equal(WorkerEnvelope.BodyOneofCase.GatewayHello, gatewayHello.BodyCase);
Assert.Equal(Nonce, gatewayHello.GatewayHello.Nonce);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, gatewayHello.GatewayHello.SupportedProtocolVersion);
await pipePair.WorkerWriter.WriteAsync(CreateWorkerHelloEnvelope());
await pipePair.WorkerWriter.WriteAsync(CreateWorkerReadyEnvelope());
await startTask.WaitAsync(TestTimeout);
}
private static WorkerCommand CreateCommand(MxCommandKind kind)
{
return new WorkerCommand
{
Command = new MxCommand
{
Kind = kind,
},
};
}
private static WorkerEnvelope CreateWorkerHelloEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 1,
envelope => envelope.WorkerHello = new WorkerHello
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
WorkerProcessId = WorkerProcessId,
WorkerVersion = "fake-worker",
});
}
private static WorkerEnvelope CreateWorkerReadyEnvelope()
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence: 2,
envelope => envelope.WorkerReady = new WorkerReady
{
WorkerProcessId = WorkerProcessId,
MxaccessProgid = "LMXProxy.LMXProxyServer.1",
MxaccessClsid = "{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}",
});
}
private static WorkerEnvelope CreateCommandReplyEnvelope(
string correlationId,
MxCommandKind kind)
{
return CreateWorkerEnvelope(
correlationId,
sequence: 10,
envelope => envelope.WorkerCommandReply = new WorkerCommandReply
{
Reply = new MxCommandReply
{
SessionId = SessionId,
CorrelationId = correlationId,
Kind = kind,
},
});
}
private static WorkerEnvelope CreateEventEnvelope(
ulong sequence,
MxEventFamily family)
{
return CreateWorkerEnvelope(
correlationId: string.Empty,
sequence,
envelope => envelope.WorkerEvent = new WorkerEvent
{
Event = new MxEvent
{
SessionId = SessionId,
Family = family,
WorkerSequence = sequence,
},
});
}
private static WorkerEnvelope CreateWorkerEnvelope(
string correlationId,
ulong sequence,
Action<WorkerEnvelope> setBody)
{
WorkerEnvelope envelope = new()
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
CorrelationId = correlationId,
};
setBody(envelope);
return envelope;
}
private static async Task WaitUntilAsync(
Func<bool> predicate,
TimeSpan timeout)
{
using CancellationTokenSource cancellationTokenSource = new(timeout);
while (!predicate())
{
await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token);
}
}
private sealed class PipePair : IAsyncDisposable
{
private readonly NamedPipeClientStream _workerStream;
private bool _workerSideDisposed;
private PipePair(
NamedPipeServerStream gatewayStream,
NamedPipeClientStream workerStream)
{
GatewayStream = gatewayStream;
_workerStream = workerStream;
WorkerReader = new WorkerFrameReader(_workerStream, new WorkerFrameProtocolOptions(SessionId));
WorkerWriter = new WorkerFrameWriter(_workerStream, new WorkerFrameProtocolOptions(SessionId));
}
public NamedPipeServerStream GatewayStream { get; }
public WorkerFrameReader WorkerReader { get; }
public WorkerFrameWriter WorkerWriter { get; }
public static async Task<PipePair> CreateAsync()
{
string pipeName = $"mxaccessgw-workerclient-tests-{Guid.NewGuid():N}";
NamedPipeServerStream gatewayStream = new(
pipeName,
PipeDirection.InOut,
maxNumberOfServerInstances: 1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
NamedPipeClientStream workerStream = new(
".",
pipeName,
PipeDirection.InOut,
PipeOptions.Asynchronous);
Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync();
await workerStream.ConnectAsync();
await waitForConnectionTask;
return new PipePair(gatewayStream, workerStream);
}
public async ValueTask DisposeWorkerSideAsync()
{
if (_workerSideDisposed)
{
return;
}
await _workerStream.DisposeAsync();
_workerSideDisposed = true;
}
public async ValueTask DisposeAsync()
{
await DisposeWorkerSideAsync();
await GatewayStream.DisposeAsync();
}
}
}
@@ -1,6 +1,9 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Bootstrap;
@@ -15,16 +18,19 @@ public sealed class WorkerApplicationTests
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger);
logger,
new SucceedingPipeClient());
Assert.Equal((int)WorkerExitCode.Success, exitCode);
MemoryWorkerLogEntry entry = Assert.Single(logger.Entries);
Assert.Equal(2, logger.Entries.Count);
MemoryWorkerLogEntry entry = logger.Entries[0];
Assert.Equal("Information", entry.Level);
Assert.Equal("WorkerBootstrapSucceeded", entry.EventName);
Assert.Equal("session-1", entry.Fields["session_id"]);
Assert.Equal("mxaccess-gateway-123-session-1", entry.Fields["pipe_name"]);
Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, entry.Fields["protocol_version"]);
Assert.Equal("[redacted]", entry.Fields["nonce"]);
Assert.Equal("WorkerPipeHandshakeSucceeded", logger.Entries[1].EventName);
}
[Fact]
@@ -73,6 +79,24 @@ public sealed class WorkerApplicationTests
Assert.Equal((int)WorkerExitCode.MissingNonce, exitCode);
}
[Fact]
public void Run_WithPipeProtocolFailure_ReturnsProtocolViolation()
{
MemoryWorkerEnvironment environment = CreateEnvironment("nonce-secret");
MemoryWorkerLogger logger = new();
int exitCode = MxGateway.Worker.WorkerApplication.Run(
ValidArgs(),
environment,
logger,
new ThrowingPipeClient(new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.NonceMismatch,
"Bad nonce.")));
Assert.Equal((int)WorkerExitCode.ProtocolViolation, exitCode);
Assert.Equal("WorkerPipeProtocolFailure", logger.Entries[1].EventName);
}
[Fact]
public void Run_WithUnexpectedBootstrapFailure_ReturnsUnexpectedFailure()
{
@@ -110,4 +134,31 @@ public sealed class WorkerApplicationTests
environment.Set(WorkerOptions.NonceEnvironmentVariableName, nonce);
return environment;
}
private sealed class SucceedingPipeClient : IWorkerPipeClient
{
public Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
}
}
private sealed class ThrowingPipeClient : IWorkerPipeClient
{
private readonly Exception _exception;
public ThrowingPipeClient(Exception exception)
{
_exception = exception;
}
public Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
throw _exception;
}
}
}
@@ -0,0 +1,163 @@
using System;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerFrameProtocolTests
{
private const string SessionId = "session-1";
private const string Nonce = "nonce-secret";
[Fact]
public async Task WriteAndReadAsync_WithValidEnvelope_RoundTripsFrame()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new();
WorkerEnvelope original = CreateGatewayHelloEnvelope();
WorkerFrameWriter writer = new(stream, options);
await writer.WriteAsync(original);
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
WorkerEnvelope parsed = await reader.ReadAsync();
Assert.Equal(original, parsed);
}
[Fact]
public async Task ReadAsync_WithWrongProtocolVersion_ThrowsProtocolVersionMismatch()
{
WorkerFrameProtocolOptions options = CreateOptions();
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
envelope.ProtocolVersion++;
MemoryStream stream = new(CreateFrame(envelope));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithWrongSessionId_ThrowsSessionMismatch()
{
WorkerFrameProtocolOptions options = CreateOptions();
WorkerEnvelope envelope = CreateGatewayHelloEnvelope();
envelope.SessionId = "different-session";
MemoryStream stream = new(CreateFrame(envelope));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.SessionMismatch, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithMalformedLength_ThrowsMalformedLength()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new(new byte[sizeof(uint)]);
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.MalformedLength, exception.ErrorCode);
}
[Fact]
public async Task ReadAsync_WithMalformedPayload_ThrowsInvalidEnvelope()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new(CreateFrame(new byte[] { 0x80 }));
WorkerFrameReader reader = new(stream, options);
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await reader.ReadAsync());
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
}
[Fact]
public async Task WriteAsync_WithConcurrentCalls_SerializesCompleteFrames()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream stream = new();
WorkerFrameWriter writer = new(stream, options);
await Task.WhenAll(
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 1)),
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 2)),
writer.WriteAsync(CreateGatewayHelloEnvelope(sequence: 3)));
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
WorkerEnvelope first = await reader.ReadAsync();
WorkerEnvelope second = await reader.ReadAsync();
WorkerEnvelope third = await reader.ReadAsync();
Assert.Equal(new ulong[] { 1, 2, 3 }, new[] { first.Sequence, second.Sequence, third.Sequence }.OrderBy(sequence => sequence));
}
private static WorkerFrameProtocolOptions CreateOptions()
{
return new WorkerFrameProtocolOptions(
SessionId,
GatewayContractInfo.WorkerProtocolVersion,
Nonce);
}
private static WorkerEnvelope CreateGatewayHelloEnvelope(ulong sequence = 1)
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = sequence,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = Nonce,
GatewayVersion = "test-gateway",
},
};
}
private static byte[] CreateFrame(IMessage message)
{
return CreateFrame(message.ToByteArray());
}
private static byte[] CreateFrame(byte[] payload)
{
byte[] frame = new byte[sizeof(uint) + payload.Length];
WriteUInt32LittleEndian(frame, (uint)payload.Length);
payload.CopyTo(frame, sizeof(uint));
return frame;
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,61 @@
using System;
using System.IO.Pipes;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerPipeClientTests
{
[Fact]
public async Task RunAsync_ConnectsToPipeAndCompletesHandshake()
{
string pipeName = $"mxaccess-gateway-test-{Guid.NewGuid():N}";
WorkerOptions workerOptions = new(
"session-1",
pipeName,
GatewayContractInfo.WorkerProtocolVersion,
"nonce-secret");
WorkerFrameProtocolOptions frameOptions = new(workerOptions);
using NamedPipeServerStream server = new(
pipeName,
PipeDirection.InOut,
1,
PipeTransmissionMode.Byte,
PipeOptions.Asynchronous);
WorkerPipeClient client = new(connectTimeoutMilliseconds: 5000);
Task clientTask = client.RunAsync(workerOptions);
await Task.Factory.FromAsync(server.BeginWaitForConnection, server.EndWaitForConnection, null);
WorkerFrameReader reader = new(server, frameOptions);
WorkerFrameWriter writer = new(server, frameOptions);
await writer.WriteAsync(new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = "session-1",
Sequence = 1,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
Nonce = "nonce-secret",
GatewayVersion = "test-gateway",
},
});
WorkerEnvelope hello = await reader.ReadAsync();
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, hello.BodyCase);
Assert.Equal("nonce-secret", hello.WorkerHello.Nonce);
WorkerEnvelope ready = await reader.ReadAsync();
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, ready.BodyCase);
await clientTask;
}
}
@@ -0,0 +1,192 @@
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker.Tests.Ipc;
public sealed class WorkerPipeSessionTests
{
private const string SessionId = "session-1";
private const string Nonce = "nonce-secret";
[Fact]
public async Task CompleteStartupHandshakeAsync_WithValidGatewayHello_SendsHelloThenReady()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope());
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
});
Assert.True(initialized);
WorkerEnvelope[] written = ReadWrittenFrames(outbound, options);
Assert.Equal(2, written.Length);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerHello, written[0].BodyCase);
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerReady, written[1].BodyCase);
Assert.Equal(Nonce, written[0].WorkerHello.Nonce);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithWrongNonce_FaultsBeforeInitialization()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(nonce: "wrong"));
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.NonceMismatch, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithWrongProtocol_FaultsBeforeInitialization()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new();
await new WorkerFrameWriter(inbound, options).WriteAsync(CreateGatewayHelloEnvelope(supportedProtocolVersion: 999));
inbound.Position = 0;
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.ProtocolVersionMismatch, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerFaultCategory.ProtocolMismatch, fault.WorkerFault.Category);
}
[Fact]
public async Task CompleteStartupHandshakeAsync_WithMalformedFrame_WritesWorkerFault()
{
WorkerFrameProtocolOptions options = CreateOptions();
MemoryStream inbound = new(CreateFrame(new byte[] { 0x80 }));
MemoryStream outbound = new();
WorkerPipeSession session = CreateSession(inbound, outbound, options);
bool initialized = false;
WorkerFrameProtocolException exception =
await Assert.ThrowsAsync<WorkerFrameProtocolException>(
async () => await session.CompleteStartupHandshakeAsync(
_ =>
{
initialized = true;
return Task.CompletedTask;
}));
Assert.False(initialized);
Assert.Equal(WorkerFrameProtocolErrorCode.InvalidEnvelope, exception.ErrorCode);
WorkerEnvelope fault = Assert.Single(ReadWrittenFrames(outbound, options));
Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerFault, fault.BodyCase);
Assert.Equal(WorkerFaultCategory.ProtocolViolation, fault.WorkerFault.Category);
}
private static WorkerPipeSession CreateSession(
Stream inbound,
Stream outbound,
WorkerFrameProtocolOptions options)
{
return new WorkerPipeSession(
new WorkerFrameReader(inbound, options),
new WorkerFrameWriter(outbound, options),
options,
() => 1234);
}
private static WorkerFrameProtocolOptions CreateOptions()
{
return new WorkerFrameProtocolOptions(
SessionId,
GatewayContractInfo.WorkerProtocolVersion,
Nonce);
}
private static WorkerEnvelope CreateGatewayHelloEnvelope(
string nonce = Nonce,
uint supportedProtocolVersion = GatewayContractInfo.WorkerProtocolVersion)
{
return new WorkerEnvelope
{
ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion,
SessionId = SessionId,
Sequence = 1,
GatewayHello = new GatewayHello
{
SupportedProtocolVersion = supportedProtocolVersion,
Nonce = nonce,
GatewayVersion = "test-gateway",
},
};
}
private static WorkerEnvelope[] ReadWrittenFrames(
MemoryStream stream,
WorkerFrameProtocolOptions options)
{
stream.Position = 0;
WorkerFrameReader reader = new(stream, options);
List<WorkerEnvelope> envelopes = new();
while (stream.Position < stream.Length)
{
envelopes.Add(reader.ReadAsync(CancellationToken.None).GetAwaiter().GetResult());
}
return envelopes.ToArray();
}
private static byte[] CreateFrame(byte[] payload)
{
byte[] frame = new byte[sizeof(uint) + payload.Length];
WriteUInt32LittleEndian(frame, (uint)payload.Length);
payload.CopyTo(frame, sizeof(uint));
return frame;
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -7,4 +7,6 @@ public enum WorkerExitCode
InvalidArguments = 2,
InvalidProtocolVersion = 3,
MissingNonce = 4,
PipeConnectionFailed = 5,
ProtocolViolation = 6,
}
@@ -0,0 +1,12 @@
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public interface IWorkerPipeClient
{
Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default);
}
@@ -0,0 +1,33 @@
using System;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
internal static class WorkerEnvelopeValidator
{
public static void Validate(
WorkerEnvelope envelope,
WorkerFrameProtocolOptions options)
{
if (envelope.ProtocolVersion != options.ProtocolVersion)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
$"Worker envelope protocol version {envelope.ProtocolVersion} does not match expected version {options.ProtocolVersion}.");
}
if (!string.Equals(envelope.SessionId, options.SessionId, StringComparison.Ordinal))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.SessionMismatch,
"Worker envelope session id does not match the owning worker session.");
}
if (envelope.BodyCase == WorkerEnvelope.BodyOneofCase.None)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidEnvelope,
"Worker envelope must include a typed body.");
}
}
}
@@ -0,0 +1,15 @@
namespace MxGateway.Worker.Ipc;
public enum WorkerFrameProtocolErrorCode
{
Unknown = 0,
InvalidConfiguration = 1,
EndOfStream = 2,
MalformedLength = 3,
MessageTooLarge = 4,
InvalidEnvelope = 5,
ProtocolVersionMismatch = 6,
SessionMismatch = 7,
NonceMismatch = 8,
UnexpectedEnvelopeBody = 9,
}
@@ -0,0 +1,25 @@
using System;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameProtocolException : Exception
{
public WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode errorCode,
string message)
: base(message)
{
ErrorCode = errorCode;
}
public WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode errorCode,
string message,
Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public WorkerFrameProtocolErrorCode ErrorCode { get; }
}
@@ -0,0 +1,86 @@
using System;
using MxGateway.Contracts;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameProtocolOptions
{
public const int DefaultMaxMessageBytes = 16 * 1024 * 1024;
public WorkerFrameProtocolOptions(WorkerOptions options)
: this(
options?.SessionId ?? throw new ArgumentNullException(nameof(options)),
options.ProtocolVersion,
options.Nonce,
DefaultMaxMessageBytes)
{
}
public WorkerFrameProtocolOptions(
string sessionId,
uint protocolVersion,
string nonce)
: this(
sessionId,
protocolVersion,
nonce,
DefaultMaxMessageBytes)
{
}
public WorkerFrameProtocolOptions(
string sessionId,
uint protocolVersion,
string nonce,
int maxMessageBytes)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol requires a session id.");
}
if (protocolVersion == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol requires a non-zero protocol version.");
}
if (protocolVersion != GatewayContractInfo.WorkerProtocolVersion)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
$"Worker frame protocol version {protocolVersion} is not supported.");
}
if (string.IsNullOrWhiteSpace(nonce))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol requires a nonce.");
}
if (maxMessageBytes <= 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidConfiguration,
"Worker frame protocol max message size must be greater than zero.");
}
SessionId = sessionId;
ProtocolVersion = protocolVersion;
Nonce = nonce;
MaxMessageBytes = maxMessageBytes;
}
public string SessionId { get; }
public uint ProtocolVersion { get; }
public string Nonce { get; }
public int MaxMessageBytes { get; }
}
@@ -0,0 +1,93 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameReader
{
private readonly WorkerFrameProtocolOptions _options;
private readonly Stream _stream;
public WorkerFrameReader(
Stream stream,
WorkerFrameProtocolOptions options)
{
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
_options = options ?? throw new ArgumentNullException(nameof(options));
}
public async Task<WorkerEnvelope> ReadAsync(CancellationToken cancellationToken = default)
{
byte[] lengthPrefix = new byte[sizeof(uint)];
await ReadExactlyOrThrowAsync(lengthPrefix, cancellationToken).ConfigureAwait(false);
uint payloadLength = ReadUInt32LittleEndian(lengthPrefix);
if (payloadLength == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.MalformedLength,
"Worker frame payload length must be greater than zero.");
}
if (payloadLength > _options.MaxMessageBytes)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.MessageTooLarge,
$"Worker frame payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes.");
}
byte[] payload = new byte[payloadLength];
await ReadExactlyOrThrowAsync(payload, cancellationToken).ConfigureAwait(false);
WorkerEnvelope envelope;
try
{
envelope = WorkerEnvelope.Parser.ParseFrom(payload);
}
catch (InvalidProtocolBufferException exception)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidEnvelope,
"Worker frame payload is not a valid WorkerEnvelope protobuf message.",
exception);
}
WorkerEnvelopeValidator.Validate(envelope, _options);
return envelope;
}
private static uint ReadUInt32LittleEndian(byte[] buffer)
{
return (uint)buffer[0]
| ((uint)buffer[1] << 8)
| ((uint)buffer[2] << 16)
| ((uint)buffer[3] << 24);
}
private async Task ReadExactlyOrThrowAsync(
byte[] buffer,
CancellationToken cancellationToken)
{
int offset = 0;
while (offset < buffer.Length)
{
int bytesRead = await _stream
.ReadAsync(buffer, offset, buffer.Length - offset, cancellationToken)
.ConfigureAwait(false);
if (bytesRead == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.EndOfStream,
"Worker frame ended before the expected number of bytes were read.");
}
offset += bytesRead;
}
}
}
@@ -0,0 +1,76 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using MxGateway.Contracts.Proto;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerFrameWriter
{
private readonly WorkerFrameProtocolOptions _options;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private readonly Stream _stream;
public WorkerFrameWriter(
Stream stream,
WorkerFrameProtocolOptions options)
{
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
_options = options ?? throw new ArgumentNullException(nameof(options));
}
public async Task WriteAsync(
WorkerEnvelope envelope,
CancellationToken cancellationToken = default)
{
if (envelope is null)
{
throw new ArgumentNullException(nameof(envelope));
}
WorkerEnvelopeValidator.Validate(envelope, _options);
int payloadLength = envelope.CalculateSize();
if (payloadLength == 0)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.InvalidEnvelope,
"Worker envelope cannot serialize to an empty payload.");
}
if (payloadLength > _options.MaxMessageBytes)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.MessageTooLarge,
$"Worker envelope payload length {payloadLength} exceeds the configured maximum of {_options.MaxMessageBytes} bytes.");
}
byte[] payload = envelope.ToByteArray();
byte[] lengthPrefix = new byte[sizeof(uint)];
WriteUInt32LittleEndian(lengthPrefix, (uint)payloadLength);
await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await _stream.WriteAsync(lengthPrefix, 0, lengthPrefix.Length, cancellationToken).ConfigureAwait(false);
await _stream.WriteAsync(payload, 0, payload.Length, cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
finally
{
_writeLock.Release();
}
}
private static void WriteUInt32LittleEndian(
byte[] buffer,
uint value)
{
buffer[0] = (byte)value;
buffer[1] = (byte)(value >> 8);
buffer[2] = (byte)(value >> 16);
buffer[3] = (byte)(value >> 24);
}
}
@@ -0,0 +1,67 @@
using System;
using System.IO.Pipes;
using System.Threading;
using System.Threading.Tasks;
using MxGateway.Worker.Bootstrap;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerPipeClient : IWorkerPipeClient
{
public const int DefaultConnectTimeoutMilliseconds = 30000;
private readonly int _connectTimeoutMilliseconds;
public WorkerPipeClient()
: this(DefaultConnectTimeoutMilliseconds)
{
}
public WorkerPipeClient(int connectTimeoutMilliseconds)
{
if (connectTimeoutMilliseconds <= 0)
{
throw new ArgumentOutOfRangeException(
nameof(connectTimeoutMilliseconds),
"Worker pipe connect timeout must be greater than zero.");
}
_connectTimeoutMilliseconds = connectTimeoutMilliseconds;
}
public async Task RunAsync(
WorkerOptions options,
CancellationToken cancellationToken = default)
{
if (options is null)
{
throw new ArgumentNullException(nameof(options));
}
WorkerFrameProtocolOptions frameOptions = new(options);
using NamedPipeClientStream pipe = new(
".",
options.PipeName,
PipeDirection.InOut,
PipeOptions.Asynchronous);
await ConnectAsync(pipe, cancellationToken).ConfigureAwait(false);
WorkerPipeSession session = new(pipe, frameOptions);
await session.CompleteStartupHandshakeAsync(cancellationToken).ConfigureAwait(false);
}
private Task ConnectAsync(
NamedPipeClientStream pipe,
CancellationToken cancellationToken)
{
return Task.Run(
() =>
{
cancellationToken.ThrowIfCancellationRequested();
pipe.Connect(_connectTimeoutMilliseconds);
},
cancellationToken);
}
}
@@ -0,0 +1,218 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf.WellKnownTypes;
using MxGateway.Contracts.Proto;
using MxGateway.Worker.MxAccess;
namespace MxGateway.Worker.Ipc;
public sealed class WorkerPipeSession
{
private readonly WorkerFrameProtocolOptions _options;
private readonly Func<int> _processIdProvider;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private long _nextSequence;
public WorkerPipeSession(
Stream stream,
WorkerFrameProtocolOptions options)
: this(
new WorkerFrameReader(stream, options),
new WorkerFrameWriter(stream, options),
options,
() => Process.GetCurrentProcess().Id)
{
}
public WorkerPipeSession(
WorkerFrameReader reader,
WorkerFrameWriter writer,
WorkerFrameProtocolOptions options,
Func<int> processIdProvider)
{
_reader = reader ?? throw new ArgumentNullException(nameof(reader));
_writer = writer ?? throw new ArgumentNullException(nameof(writer));
_options = options ?? throw new ArgumentNullException(nameof(options));
_processIdProvider = processIdProvider ?? throw new ArgumentNullException(nameof(processIdProvider));
}
public Task CompleteStartupHandshakeAsync(CancellationToken cancellationToken = default)
{
return CompleteStartupHandshakeAsync(_ => Task.CompletedTask, cancellationToken);
}
public async Task CompleteStartupHandshakeAsync(
Func<CancellationToken, Task> initializeMxAccessAsync,
CancellationToken cancellationToken = default)
{
if (initializeMxAccessAsync is null)
{
throw new ArgumentNullException(nameof(initializeMxAccessAsync));
}
try
{
WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
ValidateGatewayHello(envelope);
await WriteWorkerHelloAsync(cancellationToken).ConfigureAwait(false);
await initializeMxAccessAsync(cancellationToken).ConfigureAwait(false);
await WriteWorkerReadyAsync(cancellationToken).ConfigureAwait(false);
}
catch (WorkerFrameProtocolException exception)
{
await TryWriteFaultAsync(exception, cancellationToken).ConfigureAwait(false);
throw;
}
}
private void ValidateGatewayHello(WorkerEnvelope envelope)
{
if (envelope.BodyCase != WorkerEnvelope.BodyOneofCase.GatewayHello)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.UnexpectedEnvelopeBody,
"Worker expected GatewayHello during startup handshake.");
}
GatewayHello gatewayHello = envelope.GatewayHello;
if (gatewayHello.SupportedProtocolVersion != _options.ProtocolVersion)
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch,
$"GatewayHello supported protocol version {gatewayHello.SupportedProtocolVersion} does not match expected version {_options.ProtocolVersion}.");
}
if (!string.Equals(gatewayHello.Nonce, _options.Nonce, StringComparison.Ordinal))
{
throw new WorkerFrameProtocolException(
WorkerFrameProtocolErrorCode.NonceMismatch,
"GatewayHello nonce does not match the worker launch nonce.");
}
}
private Task WriteWorkerHelloAsync(CancellationToken cancellationToken)
{
return _writer.WriteAsync(
CreateEnvelope(new WorkerHello
{
ProtocolVersion = _options.ProtocolVersion,
Nonce = _options.Nonce,
WorkerProcessId = _processIdProvider(),
WorkerVersion = typeof(WorkerPipeSession).Assembly.GetName().Version?.ToString() ?? string.Empty,
}),
cancellationToken);
}
private Task WriteWorkerReadyAsync(CancellationToken cancellationToken)
{
return _writer.WriteAsync(
CreateEnvelope(new WorkerReady
{
WorkerProcessId = _processIdProvider(),
MxaccessProgid = MxAccessInteropInfo.ProgId,
MxaccessClsid = MxAccessInteropInfo.Clsid,
ReadyTimestamp = Timestamp.FromDateTime(DateTime.UtcNow),
}),
cancellationToken);
}
private async Task TryWriteFaultAsync(
WorkerFrameProtocolException exception,
CancellationToken cancellationToken)
{
try
{
await _writer
.WriteAsync(CreateEnvelope(CreateFault(exception)), cancellationToken)
.ConfigureAwait(false);
}
catch (Exception faultWriteException) when (
faultWriteException is IOException
|| faultWriteException is ObjectDisposedException
|| faultWriteException is WorkerFrameProtocolException)
{
// The original protocol failure is the actionable error.
}
}
private WorkerEnvelope CreateEnvelope(WorkerHello hello)
{
return CreateBaseEnvelope(hello);
}
private WorkerEnvelope CreateEnvelope(WorkerReady ready)
{
return CreateBaseEnvelope(ready);
}
private WorkerEnvelope CreateEnvelope(WorkerFault fault)
{
return CreateBaseEnvelope(fault);
}
private WorkerEnvelope CreateBaseEnvelope(WorkerHello body)
{
WorkerEnvelope envelope = CreateBaseEnvelope();
envelope.WorkerHello = body;
return envelope;
}
private WorkerEnvelope CreateBaseEnvelope(WorkerReady body)
{
WorkerEnvelope envelope = CreateBaseEnvelope();
envelope.WorkerReady = body;
return envelope;
}
private WorkerEnvelope CreateBaseEnvelope(WorkerFault body)
{
WorkerEnvelope envelope = CreateBaseEnvelope();
envelope.WorkerFault = body;
return envelope;
}
private WorkerEnvelope CreateBaseEnvelope()
{
return new WorkerEnvelope
{
ProtocolVersion = _options.ProtocolVersion,
SessionId = _options.SessionId,
Sequence = NextSequence(),
};
}
private ulong NextSequence()
{
return unchecked((ulong)Interlocked.Increment(ref _nextSequence));
}
private static WorkerFault CreateFault(WorkerFrameProtocolException exception)
{
return new WorkerFault
{
Category = MapFaultCategory(exception.ErrorCode),
ExceptionType = exception.GetType().FullName ?? string.Empty,
DiagnosticMessage = exception.Message,
ProtocolStatus = new ProtocolStatus
{
Code = ProtocolStatusCode.ProtocolViolation,
Message = exception.Message,
},
};
}
private static WorkerFaultCategory MapFaultCategory(WorkerFrameProtocolErrorCode errorCode)
{
return errorCode switch
{
WorkerFrameProtocolErrorCode.ProtocolVersionMismatch => WorkerFaultCategory.ProtocolMismatch,
WorkerFrameProtocolErrorCode.EndOfStream => WorkerFaultCategory.PipeDisconnected,
_ => WorkerFaultCategory.ProtocolViolation,
};
}
}
+52 -1
View File
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.IO;
using MxGateway.Worker.Bootstrap;
using MxGateway.Worker.Ipc;
namespace MxGateway.Worker;
@@ -11,13 +13,27 @@ public static class WorkerApplication
return Run(
args,
new EnvironmentVariableWorkerEnvironment(),
new WorkerConsoleLogger(Console.Error));
new WorkerConsoleLogger(Console.Error),
new WorkerPipeClient());
}
public static int Run(
string[] args,
IWorkerEnvironment environment,
IWorkerLogger logger)
{
return Run(
args,
environment,
logger,
new WorkerPipeClient());
}
public static int Run(
string[] args,
IWorkerEnvironment environment,
IWorkerLogger logger,
IWorkerPipeClient pipeClient)
{
if (args is null)
{
@@ -34,6 +50,11 @@ public static class WorkerApplication
throw new ArgumentNullException(nameof(logger));
}
if (pipeClient is null)
{
throw new ArgumentNullException(nameof(pipeClient));
}
try
{
WorkerOptionsParser parser = new(environment);
@@ -61,8 +82,38 @@ public static class WorkerApplication
["nonce"] = options.Nonce,
});
pipeClient.RunAsync(options).GetAwaiter().GetResult();
logger.Information("WorkerPipeHandshakeSucceeded", new Dictionary<string, object?>
{
["session_id"] = options.SessionId,
["pipe_name"] = options.PipeName,
["protocol_version"] = options.ProtocolVersion,
});
return (int)WorkerExitCode.Success;
}
catch (WorkerFrameProtocolException exception)
{
logger.Error("WorkerPipeProtocolFailure", new Dictionary<string, object?>
{
["exit_code"] = WorkerExitCode.ProtocolViolation,
["error_code"] = exception.ErrorCode,
["exception_type"] = exception.GetType().FullName,
});
return (int)WorkerExitCode.ProtocolViolation;
}
catch (Exception exception) when (exception is IOException or TimeoutException)
{
logger.Error("WorkerPipeConnectionFailed", new Dictionary<string, object?>
{
["exit_code"] = WorkerExitCode.PipeConnectionFailed,
["exception_type"] = exception.GetType().FullName,
});
return (int)WorkerExitCode.PipeConnectionFailed;
}
catch (Exception exception)
{
logger.Error("WorkerBootstrapUnexpectedFailure", new Dictionary<string, object?>