From fce9e9955367a21e94b1a29b2fec144fe71d92ac Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 17:09:39 -0400 Subject: [PATCH] Issue #11: implement gateway workerclient --- docs/gateway-process-design.md | 10 +- src/MxGateway.Server/Workers/IWorkerClient.cs | 27 + src/MxGateway.Server/Workers/WorkerClient.cs | 755 ++++++++++++++++++ .../Workers/WorkerClientConnection.cs | 38 + .../Workers/WorkerClientErrorCode.cs | 14 + .../Workers/WorkerClientException.cs | 23 + .../Workers/WorkerClientOptions.cs | 24 + .../Workers/WorkerClientState.cs | 11 + .../Gateway/Workers/WorkerClientTests.cs | 341 ++++++++ 9 files changed, 1241 insertions(+), 2 deletions(-) create mode 100644 src/MxGateway.Server/Workers/IWorkerClient.cs create mode 100644 src/MxGateway.Server/Workers/WorkerClient.cs create mode 100644 src/MxGateway.Server/Workers/WorkerClientConnection.cs create mode 100644 src/MxGateway.Server/Workers/WorkerClientErrorCode.cs create mode 100644 src/MxGateway.Server/Workers/WorkerClientException.cs create mode 100644 src/MxGateway.Server/Workers/WorkerClientOptions.cs create mode 100644 src/MxGateway.Server/Workers/WorkerClientState.cs create mode 100644 src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs diff --git a/docs/gateway-process-design.md b/docs/gateway-process-design.md index 6d35349..a300590 100644 --- a/docs/gateway-process-design.md +++ b/docs/gateway-process-design.md @@ -411,7 +411,7 @@ session ids as protocol faults and close the session. `WorkerClient` is the gateway-side object that owns one worker connection. -Suggested public shape: +Current public shape: ```csharp public interface IWorkerClient : IAsyncDisposable @@ -419,6 +419,7 @@ public interface IWorkerClient : IAsyncDisposable string SessionId { get; } int? ProcessId { get; } WorkerClientState State { get; } + DateTimeOffset LastHeartbeatAt { get; } Task StartAsync(CancellationToken cancellationToken); Task InvokeAsync( @@ -438,12 +439,17 @@ Internally it owns: - pipe stream, - read loop, - write loop, -- bounded outbound command/control channel, +- outbound command/control channel serialized by the write loop, - bounded inbound event channel, - pending command dictionary keyed by correlation id, - heartbeat monitor, - terminal fault source. +`StartAsync` sends `GatewayHello`, verifies the `WorkerHello` protocol version +and nonce, waits for `WorkerReady`, and only then exposes `Ready` state. The +read loop starts after readiness so the handshake has a single owner for its +ordered frames. + ### Read Loop The read loop: diff --git a/src/MxGateway.Server/Workers/IWorkerClient.cs b/src/MxGateway.Server/Workers/IWorkerClient.cs new file mode 100644 index 0000000..7d8c69d --- /dev/null +++ b/src/MxGateway.Server/Workers/IWorkerClient.cs @@ -0,0 +1,27 @@ +using MxGateway.Contracts.Proto; + +namespace MxGateway.Server.Workers; + +public interface IWorkerClient : IAsyncDisposable +{ + string SessionId { get; } + + int? ProcessId { get; } + + WorkerClientState State { get; } + + DateTimeOffset LastHeartbeatAt { get; } + + Task StartAsync(CancellationToken cancellationToken); + + Task InvokeAsync( + WorkerCommand command, + TimeSpan timeout, + CancellationToken cancellationToken); + + IAsyncEnumerable ReadEventsAsync(CancellationToken cancellationToken); + + Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken); + + void Kill(string reason); +} diff --git a/src/MxGateway.Server/Workers/WorkerClient.cs b/src/MxGateway.Server/Workers/WorkerClient.cs new file mode 100644 index 0000000..9e0daa9 --- /dev/null +++ b/src/MxGateway.Server/Workers/WorkerClient.cs @@ -0,0 +1,755 @@ +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; +using System.Threading.Channels; +using Google.Protobuf.WellKnownTypes; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using MxGateway.Contracts; +using MxGateway.Contracts.Proto; +using MxGateway.Server.Metrics; + +namespace MxGateway.Server.Workers; + +public sealed class WorkerClient : IWorkerClient +{ + private const string GatewayVersionFallback = "unknown"; + private readonly object _syncRoot = new(); + private readonly WorkerClientConnection _connection; + private readonly WorkerClientOptions _options; + private readonly GatewayMetrics? _metrics; + private readonly TimeProvider _timeProvider; + private readonly ILogger _logger; + private readonly WorkerFrameReader _reader; + private readonly WorkerFrameWriter _writer; + private readonly Channel _outboundEnvelopes; + private readonly Channel _events; + private readonly ConcurrentDictionary _pendingCommands = new(StringComparer.Ordinal); + private readonly CancellationTokenSource _stopCts = new(); + private long _nextSequence; + private WorkerClientState _state; + private DateTimeOffset _lastHeartbeatAt; + private int? _processId; + private Task? _readLoopTask; + private Task? _writeLoopTask; + private Task? _heartbeatLoopTask; + private bool _disposed; + + public WorkerClient( + WorkerClientConnection connection, + WorkerClientOptions? options = null, + GatewayMetrics? metrics = null, + TimeProvider? timeProvider = null, + ILogger? logger = null) + { + _connection = connection ?? throw new ArgumentNullException(nameof(connection)); + _options = options ?? new WorkerClientOptions(); + _metrics = metrics; + _timeProvider = timeProvider ?? TimeProvider.System; + _logger = logger ?? NullLogger.Instance; + _reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions); + _writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions); + _outboundEnvelopes = Channel.CreateUnbounded( + new UnboundedChannelOptions + { + SingleReader = true, + SingleWriter = false, + AllowSynchronousContinuations = false, + }); + _events = Channel.CreateBounded( + new BoundedChannelOptions(_options.EventChannelCapacity) + { + SingleReader = false, + SingleWriter = true, + FullMode = BoundedChannelFullMode.Wait, + AllowSynchronousContinuations = false, + }); + _lastHeartbeatAt = _timeProvider.GetUtcNow(); + } + + public string SessionId => _connection.SessionId; + + public int? ProcessId + { + get + { + lock (_syncRoot) + { + return _processId; + } + } + } + + public WorkerClientState State + { + get + { + lock (_syncRoot) + { + return _state; + } + } + } + + public DateTimeOffset LastHeartbeatAt + { + get + { + lock (_syncRoot) + { + return _lastHeartbeatAt; + } + } + } + + public async Task StartAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); + TransitionFromCreatedToHandshaking(); + + _writeLoopTask = Task.Run(WriteLoopAsync); + await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false); + + WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync( + WorkerEnvelope.BodyOneofCase.WorkerHello, + cancellationToken).ConfigureAwait(false); + ValidateWorkerHello(helloEnvelope.WorkerHello); + + WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync( + WorkerEnvelope.BodyOneofCase.WorkerReady, + cancellationToken).ConfigureAwait(false); + MarkReady(readyEnvelope.WorkerReady); + + _readLoopTask = Task.Run(ReadLoopAsync); + _heartbeatLoopTask = Task.Run(HeartbeatLoopAsync); + } + + public async Task InvokeAsync( + WorkerCommand command, + TimeSpan timeout, + CancellationToken cancellationToken) + { + ArgumentNullException.ThrowIfNull(command); + ThrowIfDisposed(); + EnsureReady(); + + if (timeout <= TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero."); + } + + string correlationId = Guid.NewGuid().ToString("N"); + string method = GetCommandMethod(command); + PendingCommand pendingCommand = new( + correlationId, + method, + _timeProvider.GetTimestamp()); + + if (!_pendingCommands.TryAdd(correlationId, pendingCommand)) + { + throw new InvalidOperationException("Generated a duplicate command correlation id."); + } + + _metrics?.CommandStarted(method); + + try + { + await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false); + using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task timeoutTask = Task.Delay(timeout, timeoutCts.Token); + Task replyTask = pendingCommand.Task; + Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false); + + if (completedTask == replyTask) + { + await timeoutCts.CancelAsync().ConfigureAwait(false); + return await replyTask.ConfigureAwait(false); + } + + if (cancellationToken.IsCancellationRequested) + { + RemovePendingCommandAsFailed( + correlationId, + pendingCommand, + WorkerClientErrorCode.GatewayShutdown, + "Command wait was canceled."); + cancellationToken.ThrowIfCancellationRequested(); + } + + RemovePendingCommandAsFailed( + correlationId, + pendingCommand, + WorkerClientErrorCode.CommandTimeout, + $"Worker command {method} timed out after {timeout}."); + + throw new WorkerClientException( + WorkerClientErrorCode.CommandTimeout, + $"Worker command {method} timed out after {timeout}."); + } + catch + { + _pendingCommands.TryRemove(correlationId, out _); + throw; + } + } + + public async IAsyncEnumerable ReadEventsAsync( + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + yield return workerEvent; + } + } + + public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + if (timeout <= TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero."); + } + + WorkerClientState state = State; + if (state is WorkerClientState.Closed or WorkerClientState.Faulted) + { + return; + } + + MarkClosing(); + await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false); + _outboundEnvelopes.Writer.TryComplete(); + + using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCts.CancelAfter(timeout); + try + { + await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false); + MarkClosed("shutdown"); + } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + SetFaulted( + WorkerClientErrorCode.ShutdownTimeout, + "Worker shutdown timed out.", + null); + throw new WorkerClientException( + WorkerClientErrorCode.ShutdownTimeout, + $"Worker shutdown timed out after {timeout}."); + } + } + + public void Kill(string reason) + { + ThrowIfDisposed(); + _connection.ProcessHandle?.Process.Kill(entireProcessTree: true); + _metrics?.WorkerKilled(reason); + SetFaulted( + WorkerClientErrorCode.WorkerFaulted, + $"Worker was killed by the gateway: {reason}.", + null); + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + _stopCts.Cancel(); + _outboundEnvelopes.Writer.TryComplete(); + _events.Writer.TryComplete(); + CompletePendingCommands( + new WorkerClientException( + WorkerClientErrorCode.GatewayShutdown, + "Worker client was disposed.")); + + await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false); + await _connection.Stream.DisposeAsync().ConfigureAwait(false); + _connection.ProcessHandle?.Dispose(); + _stopCts.Dispose(); + } + + private async Task WriteLoopAsync() + { + try + { + await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false)) + { + await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false); + } + } + catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState()) + { + } + catch (Exception exception) + { + SetFaulted( + WorkerClientErrorCode.WriteFailed, + "Worker pipe write failed.", + exception); + } + } + + private async Task ReadLoopAsync() + { + try + { + while (!_stopCts.IsCancellationRequested) + { + WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false); + await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false); + } + } + catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState()) + { + } + catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream) + { + SetFaulted( + WorkerClientErrorCode.PipeDisconnected, + "Worker pipe disconnected.", + exception); + } + catch (Exception exception) + { + SetFaulted( + WorkerClientErrorCode.ProtocolViolation, + "Worker read loop failed.", + exception); + } + } + + private async Task HeartbeatLoopAsync() + { + try + { + while (!_stopCts.IsCancellationRequested) + { + await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false); + if (State != WorkerClientState.Ready) + { + continue; + } + + DateTimeOffset lastHeartbeatAt = LastHeartbeatAt; + DateTimeOffset now = _timeProvider.GetUtcNow(); + if (now - lastHeartbeatAt <= _options.HeartbeatGrace) + { + continue; + } + + _metrics?.HeartbeatFailed(SessionId); + SetFaulted( + WorkerClientErrorCode.HeartbeatExpired, + $"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.", + null); + } + } + catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState()) + { + } + } + + private async Task DispatchEnvelopeAsync( + WorkerEnvelope envelope, + CancellationToken cancellationToken) + { + switch (envelope.BodyCase) + { + case WorkerEnvelope.BodyOneofCase.WorkerCommandReply: + CompleteCommand(envelope); + break; + case WorkerEnvelope.BodyOneofCase.WorkerEvent: + await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false); + break; + case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat: + MarkHeartbeat(envelope.WorkerHeartbeat); + break; + case WorkerEnvelope.BodyOneofCase.WorkerFault: + SetFaulted( + WorkerClientErrorCode.WorkerFaulted, + CreateWorkerFaultMessage(envelope.WorkerFault), + null); + break; + case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck: + MarkClosed("worker-shutdown-ack"); + break; + default: + SetFaulted( + WorkerClientErrorCode.ProtocolViolation, + $"Worker sent unexpected envelope body {envelope.BodyCase}.", + null); + break; + } + } + + private async Task EnqueueWorkerEventAsync( + WorkerEvent workerEvent, + CancellationToken cancellationToken) + { + if (workerEvent.Event is not null) + { + _metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString()); + } + + if (!await _events.Writer.WaitToWriteAsync(cancellationToken).ConfigureAwait(false)) + { + return; + } + + if (!_events.Writer.TryWrite(workerEvent)) + { + _metrics?.QueueOverflow("worker-events"); + SetFaulted( + WorkerClientErrorCode.ProtocolViolation, + "Worker event channel rejected an event.", + null); + } + } + + private void CompleteCommand(WorkerEnvelope envelope) + { + string correlationId = envelope.CorrelationId; + if (string.IsNullOrWhiteSpace(correlationId)) + { + correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty; + } + + if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand)) + { + _logger.LogDebug( + "Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.", + SessionId, + correlationId); + return; + } + + TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp); + _metrics?.CommandSucceeded(pendingCommand.Method, duration); + pendingCommand.SetResult(envelope.WorkerCommandReply); + } + + private void RemovePendingCommandAsFailed( + string correlationId, + PendingCommand pendingCommand, + WorkerClientErrorCode errorCode, + string message) + { + if (!_pendingCommands.TryRemove(correlationId, out _)) + { + return; + } + + TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp); + _metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration); + pendingCommand.SetException(new WorkerClientException(errorCode, message)); + } + + private async Task ReadHandshakeEnvelopeAsync( + WorkerEnvelope.BodyOneofCase expectedBody, + CancellationToken cancellationToken) + { + WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false); + if (envelope.BodyCase != expectedBody) + { + throw new WorkerClientException( + WorkerClientErrorCode.ProtocolViolation, + $"Worker handshake expected {expectedBody} but received {envelope.BodyCase}."); + } + + return envelope; + } + + private void ValidateWorkerHello(WorkerHello workerHello) + { + if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion) + { + throw new WorkerClientException( + WorkerClientErrorCode.ProtocolViolation, + "Worker hello protocol version does not match the gateway protocol version."); + } + + if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal)) + { + throw new WorkerClientException( + WorkerClientErrorCode.ProtocolViolation, + "Worker hello nonce does not match the gateway nonce."); + } + + lock (_syncRoot) + { + _processId = workerHello.WorkerProcessId == 0 + ? _connection.ProcessHandle?.ProcessId + : workerHello.WorkerProcessId; + } + } + + private void MarkReady(WorkerReady ready) + { + lock (_syncRoot) + { + _processId = ready.WorkerProcessId == 0 + ? _processId ?? _connection.ProcessHandle?.ProcessId + : ready.WorkerProcessId; + _lastHeartbeatAt = _timeProvider.GetUtcNow(); + _state = WorkerClientState.Ready; + } + + DateTimeOffset readyAt = _timeProvider.GetUtcNow(); + DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt; + _metrics?.WorkerStarted(readyAt - launchedAt); + } + + private void MarkHeartbeat(WorkerHeartbeat heartbeat) + { + lock (_syncRoot) + { + _lastHeartbeatAt = _timeProvider.GetUtcNow(); + if (heartbeat.WorkerProcessId != 0) + { + _processId = heartbeat.WorkerProcessId; + } + } + } + + private void MarkClosing() + { + lock (_syncRoot) + { + if (_state is WorkerClientState.Closed or WorkerClientState.Faulted) + { + return; + } + + _state = WorkerClientState.Closing; + } + } + + private void MarkClosed(string reason) + { + lock (_syncRoot) + { + if (_state == WorkerClientState.Closed) + { + return; + } + + _state = WorkerClientState.Closed; + } + + _stopCts.Cancel(); + _outboundEnvelopes.Writer.TryComplete(); + _events.Writer.TryComplete(); + CompletePendingCommands( + new WorkerClientException( + WorkerClientErrorCode.GatewayShutdown, + $"Worker client closed because {reason}.")); + _metrics?.WorkerStopped(reason); + } + + private void SetFaulted( + WorkerClientErrorCode errorCode, + string message, + Exception? exception) + { + WorkerClientException fault = exception is null + ? new WorkerClientException(errorCode, message) + : new WorkerClientException(errorCode, message, exception); + + lock (_syncRoot) + { + if (_state is WorkerClientState.Faulted or WorkerClientState.Closed) + { + return; + } + + _state = WorkerClientState.Faulted; + } + + _stopCts.Cancel(); + _outboundEnvelopes.Writer.TryComplete(fault); + _events.Writer.TryComplete(fault); + CompletePendingCommands(fault); + _metrics?.Fault(errorCode.ToString()); + _logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message); + } + + private void CompletePendingCommands(Exception exception) + { + foreach (KeyValuePair item in _pendingCommands.ToArray()) + { + if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand)) + { + TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp); + _metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration); + pendingCommand.SetException(exception); + } + } + } + + private void TransitionFromCreatedToHandshaking() + { + lock (_syncRoot) + { + if (_state != WorkerClientState.Created) + { + throw new WorkerClientException( + WorkerClientErrorCode.InvalidState, + $"Worker client cannot start from state {_state}."); + } + + _state = WorkerClientState.Handshaking; + } + } + + private void EnsureReady() + { + WorkerClientState state = State; + if (state != WorkerClientState.Ready) + { + throw new WorkerClientException( + WorkerClientErrorCode.InvalidState, + $"Worker client is not ready. Current state is {state}."); + } + } + + private bool IsTerminalState() + { + WorkerClientState state = State; + return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted; + } + + private async Task EnqueueAsync( + WorkerEnvelope envelope, + CancellationToken cancellationToken) + { + try + { + await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false); + } + catch (ChannelClosedException exception) + { + throw new WorkerClientException( + WorkerClientErrorCode.WriteFailed, + "Worker outbound channel is closed.", + exception); + } + } + + private WorkerEnvelope CreateGatewayHelloEnvelope() + { + return CreateEnvelope( + correlationId: string.Empty, + envelope => envelope.GatewayHello = new GatewayHello + { + SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion, + Nonce = _connection.Nonce, + GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback, + }); + } + + private WorkerEnvelope CreateCommandEnvelope( + string correlationId, + WorkerCommand command) + { + return CreateEnvelope( + correlationId, + envelope => envelope.WorkerCommand = command.Clone()); + } + + private WorkerEnvelope CreateShutdownEnvelope( + TimeSpan timeout, + string reason) + { + return CreateEnvelope( + correlationId: string.Empty, + envelope => envelope.WorkerShutdown = new WorkerShutdown + { + GracePeriod = Duration.FromTimeSpan(timeout), + Reason = reason, + }); + } + + private WorkerEnvelope CreateEnvelope( + string correlationId, + Action setBody) + { + WorkerEnvelope envelope = new() + { + ProtocolVersion = _connection.FrameOptions.ProtocolVersion, + SessionId = SessionId, + Sequence = (ulong)Interlocked.Increment(ref _nextSequence), + CorrelationId = correlationId, + }; + setBody(envelope); + + return envelope; + } + + private static string GetCommandMethod(WorkerCommand command) + { + return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString(); + } + + private static string CreateWorkerFaultMessage(WorkerFault fault) + { + return string.IsNullOrWhiteSpace(fault.DiagnosticMessage) + ? $"Worker faulted with category {fault.Category}." + : $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}"; + } + + private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken) + { + Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask } + .Where(task => task is not null) + .Cast() + .ToArray(); + + if (tasks.Length == 0) + { + return; + } + + await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed, this); + } + + private sealed class PendingCommand + { + private readonly TaskCompletionSource _completion = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public PendingCommand( + string correlationId, + string method, + long startTimestamp) + { + CorrelationId = correlationId; + Method = method; + StartTimestamp = startTimestamp; + } + + public string CorrelationId { get; } + + public string Method { get; } + + public long StartTimestamp { get; } + + public Task Task => _completion.Task; + + public void SetResult(WorkerCommandReply reply) + { + _completion.TrySetResult(reply); + } + + public void SetException(Exception exception) + { + _completion.TrySetException(exception); + } + } +} diff --git a/src/MxGateway.Server/Workers/WorkerClientConnection.cs b/src/MxGateway.Server/Workers/WorkerClientConnection.cs new file mode 100644 index 0000000..b4256d0 --- /dev/null +++ b/src/MxGateway.Server/Workers/WorkerClientConnection.cs @@ -0,0 +1,38 @@ +namespace MxGateway.Server.Workers; + +public sealed class WorkerClientConnection +{ + public WorkerClientConnection( + string sessionId, + string nonce, + Stream stream, + WorkerFrameProtocolOptions frameOptions, + WorkerProcessHandle? processHandle = null) + { + if (string.IsNullOrWhiteSpace(sessionId)) + { + throw new ArgumentException("Session id is required.", nameof(sessionId)); + } + + if (string.IsNullOrWhiteSpace(nonce)) + { + throw new ArgumentException("Worker nonce is required.", nameof(nonce)); + } + + SessionId = sessionId; + Nonce = nonce; + Stream = stream ?? throw new ArgumentNullException(nameof(stream)); + FrameOptions = frameOptions ?? throw new ArgumentNullException(nameof(frameOptions)); + ProcessHandle = processHandle; + } + + public string SessionId { get; } + + public string Nonce { get; } + + public Stream Stream { get; } + + public WorkerFrameProtocolOptions FrameOptions { get; } + + public WorkerProcessHandle? ProcessHandle { get; } +} diff --git a/src/MxGateway.Server/Workers/WorkerClientErrorCode.cs b/src/MxGateway.Server/Workers/WorkerClientErrorCode.cs new file mode 100644 index 0000000..452a6ed --- /dev/null +++ b/src/MxGateway.Server/Workers/WorkerClientErrorCode.cs @@ -0,0 +1,14 @@ +namespace MxGateway.Server.Workers; + +public enum WorkerClientErrorCode +{ + InvalidState, + ProtocolViolation, + PipeDisconnected, + CommandTimeout, + WorkerFaulted, + HeartbeatExpired, + ShutdownTimeout, + GatewayShutdown, + WriteFailed, +} diff --git a/src/MxGateway.Server/Workers/WorkerClientException.cs b/src/MxGateway.Server/Workers/WorkerClientException.cs new file mode 100644 index 0000000..8c98a6b --- /dev/null +++ b/src/MxGateway.Server/Workers/WorkerClientException.cs @@ -0,0 +1,23 @@ +namespace MxGateway.Server.Workers; + +public sealed class WorkerClientException : Exception +{ + public WorkerClientException( + WorkerClientErrorCode errorCode, + string message) + : base(message) + { + ErrorCode = errorCode; + } + + public WorkerClientException( + WorkerClientErrorCode errorCode, + string message, + Exception innerException) + : base(message, innerException) + { + ErrorCode = errorCode; + } + + public WorkerClientErrorCode ErrorCode { get; } +} diff --git a/src/MxGateway.Server/Workers/WorkerClientOptions.cs b/src/MxGateway.Server/Workers/WorkerClientOptions.cs new file mode 100644 index 0000000..8ef59fd --- /dev/null +++ b/src/MxGateway.Server/Workers/WorkerClientOptions.cs @@ -0,0 +1,24 @@ +namespace MxGateway.Server.Workers; + +public sealed class WorkerClientOptions +{ + public static readonly TimeSpan DefaultHeartbeatGrace = TimeSpan.FromSeconds(15); + public static readonly TimeSpan DefaultHeartbeatCheckInterval = TimeSpan.FromSeconds(1); + public static readonly TimeSpan DefaultEventChannelFullModeTimeout = TimeSpan.FromSeconds(5); + + public WorkerClientOptions() + { + HeartbeatGrace = DefaultHeartbeatGrace; + HeartbeatCheckInterval = DefaultHeartbeatCheckInterval; + EventChannelCapacity = 1_024; + EventChannelFullModeTimeout = DefaultEventChannelFullModeTimeout; + } + + public TimeSpan HeartbeatGrace { get; init; } + + public TimeSpan HeartbeatCheckInterval { get; init; } + + public int EventChannelCapacity { get; init; } + + public TimeSpan EventChannelFullModeTimeout { get; init; } +} diff --git a/src/MxGateway.Server/Workers/WorkerClientState.cs b/src/MxGateway.Server/Workers/WorkerClientState.cs new file mode 100644 index 0000000..9ba3be1 --- /dev/null +++ b/src/MxGateway.Server/Workers/WorkerClientState.cs @@ -0,0 +1,11 @@ +namespace MxGateway.Server.Workers; + +public enum WorkerClientState +{ + Created, + Handshaking, + Ready, + Closing, + Closed, + Faulted, +} diff --git a/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs b/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs new file mode 100644 index 0000000..c519c07 --- /dev/null +++ b/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs @@ -0,0 +1,341 @@ +using System.IO.Pipes; +using MxGateway.Contracts; +using MxGateway.Contracts.Proto; +using MxGateway.Server.Workers; + +namespace MxGateway.Tests.Gateway.Workers; + +public sealed class WorkerClientTests +{ + private const string SessionId = "session-worker-client"; + private const string Nonce = "nonce-worker-client"; + private const int WorkerProcessId = 4321; + private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(5); + + [Fact] + public async Task StartAsync_WithWorkerHelloAndReady_EntersReadyState() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient(pipePair); + + await CompleteHandshakeAsync(client, pipePair); + + Assert.Equal(WorkerClientState.Ready, client.State); + Assert.Equal(WorkerProcessId, client.ProcessId); + } + + [Fact] + public async Task InvokeAsync_WithMatchingReply_CompletesPendingCommand() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient(pipePair); + await CompleteHandshakeAsync(client, pipePair); + + Task invokeTask = client.InvokeAsync( + CreateCommand(MxCommandKind.Ping), + TestTimeout, + CancellationToken.None); + + WorkerEnvelope commandEnvelope = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout); + Assert.Equal(WorkerEnvelope.BodyOneofCase.WorkerCommand, commandEnvelope.BodyCase); + Assert.False(string.IsNullOrWhiteSpace(commandEnvelope.CorrelationId)); + + await pipePair.WorkerWriter.WriteAsync( + CreateCommandReplyEnvelope(commandEnvelope.CorrelationId, MxCommandKind.Ping)); + + WorkerCommandReply reply = await invokeTask.WaitAsync(TestTimeout); + + Assert.Equal(commandEnvelope.CorrelationId, reply.Reply.CorrelationId); + Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind); + } + + [Fact] + public async Task InvokeAsync_WithLateReply_IgnoresLateReplyAndKeepsClientReady() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient(pipePair); + await CompleteHandshakeAsync(client, pipePair); + + Task timedOutInvokeTask = client.InvokeAsync( + CreateCommand(MxCommandKind.Ping), + TimeSpan.FromMilliseconds(50), + CancellationToken.None); + WorkerEnvelope timedOutCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout); + + WorkerClientException exception = await Assert.ThrowsAsync( + async () => await timedOutInvokeTask); + Assert.Equal(WorkerClientErrorCode.CommandTimeout, exception.ErrorCode); + + await pipePair.WorkerWriter.WriteAsync( + CreateCommandReplyEnvelope(timedOutCommand.CorrelationId, MxCommandKind.Ping)); + await Task.Delay(TimeSpan.FromMilliseconds(50)); + + Task secondInvokeTask = client.InvokeAsync( + CreateCommand(MxCommandKind.GetWorkerInfo), + TestTimeout, + CancellationToken.None); + WorkerEnvelope secondCommand = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout); + await pipePair.WorkerWriter.WriteAsync( + CreateCommandReplyEnvelope(secondCommand.CorrelationId, MxCommandKind.GetWorkerInfo)); + + WorkerCommandReply reply = await secondInvokeTask.WaitAsync(TestTimeout); + + Assert.Equal(WorkerClientState.Ready, client.State); + Assert.Equal(MxCommandKind.GetWorkerInfo, reply.Reply.Kind); + } + + [Fact] + public async Task ReadEventsAsync_WithWorkerEvents_YieldsEventsInPipeOrder() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient(pipePair); + await CompleteHandshakeAsync(client, pipePair); + using CancellationTokenSource cancellationTokenSource = new(TestTimeout); + + await using IAsyncEnumerator events = + client.ReadEventsAsync(cancellationTokenSource.Token).GetAsyncEnumerator(cancellationTokenSource.Token); + + await pipePair.WorkerWriter.WriteAsync( + CreateEventEnvelope(sequence: 11, MxEventFamily.OnDataChange)); + await pipePair.WorkerWriter.WriteAsync( + CreateEventEnvelope(sequence: 12, MxEventFamily.OperationComplete)); + + Assert.True(await events.MoveNextAsync()); + Assert.Equal((ulong)11, events.Current.Event.WorkerSequence); + Assert.Equal(MxEventFamily.OnDataChange, events.Current.Event.Family); + + Assert.True(await events.MoveNextAsync()); + Assert.Equal((ulong)12, events.Current.Event.WorkerSequence); + Assert.Equal(MxEventFamily.OperationComplete, events.Current.Event.Family); + } + + [Fact] + public async Task ReadLoop_WhenPipeDisconnects_FaultsClient() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient(pipePair); + await CompleteHandshakeAsync(client, pipePair); + + await pipePair.DisposeWorkerSideAsync(); + + await WaitUntilAsync( + () => client.State == WorkerClientState.Faulted, + TestTimeout); + + Assert.Equal(WorkerClientState.Faulted, client.State); + } + + [Fact] + public async Task HeartbeatMonitor_WhenHeartbeatExpires_FaultsClient() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + await using WorkerClient client = CreateClient( + pipePair, + new WorkerClientOptions + { + HeartbeatGrace = TimeSpan.FromMilliseconds(80), + HeartbeatCheckInterval = TimeSpan.FromMilliseconds(20), + EventChannelCapacity = 8, + }); + await CompleteHandshakeAsync(client, pipePair); + + await WaitUntilAsync( + () => client.State == WorkerClientState.Faulted, + TestTimeout); + + Assert.Equal(WorkerClientState.Faulted, client.State); + } + + private static WorkerClient CreateClient( + PipePair pipePair, + WorkerClientOptions? options = null) + { + WorkerFrameProtocolOptions frameOptions = new(SessionId); + WorkerClientConnection connection = new( + SessionId, + Nonce, + pipePair.GatewayStream, + frameOptions); + + return new WorkerClient(connection, options); + } + + private static async Task CompleteHandshakeAsync( + WorkerClient client, + PipePair pipePair) + { + Task startTask = client.StartAsync(CancellationToken.None); + + WorkerEnvelope gatewayHello = await pipePair.WorkerReader.ReadAsync().AsTask().WaitAsync(TestTimeout); + Assert.Equal(WorkerEnvelope.BodyOneofCase.GatewayHello, gatewayHello.BodyCase); + Assert.Equal(Nonce, gatewayHello.GatewayHello.Nonce); + Assert.Equal(GatewayContractInfo.WorkerProtocolVersion, gatewayHello.GatewayHello.SupportedProtocolVersion); + + await pipePair.WorkerWriter.WriteAsync(CreateWorkerHelloEnvelope()); + await pipePair.WorkerWriter.WriteAsync(CreateWorkerReadyEnvelope()); + await startTask.WaitAsync(TestTimeout); + } + + private static WorkerCommand CreateCommand(MxCommandKind kind) + { + return new WorkerCommand + { + Command = new MxCommand + { + Kind = kind, + }, + }; + } + + private static WorkerEnvelope CreateWorkerHelloEnvelope() + { + return CreateWorkerEnvelope( + correlationId: string.Empty, + sequence: 1, + envelope => envelope.WorkerHello = new WorkerHello + { + ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + Nonce = Nonce, + WorkerProcessId = WorkerProcessId, + WorkerVersion = "fake-worker", + }); + } + + private static WorkerEnvelope CreateWorkerReadyEnvelope() + { + return CreateWorkerEnvelope( + correlationId: string.Empty, + sequence: 2, + envelope => envelope.WorkerReady = new WorkerReady + { + WorkerProcessId = WorkerProcessId, + MxaccessProgid = "LMXProxy.LMXProxyServer.1", + MxaccessClsid = "{C30B52F5-2CB5-4760-AF0A-3A344A7EB5DC}", + }); + } + + private static WorkerEnvelope CreateCommandReplyEnvelope( + string correlationId, + MxCommandKind kind) + { + return CreateWorkerEnvelope( + correlationId, + sequence: 10, + envelope => envelope.WorkerCommandReply = new WorkerCommandReply + { + Reply = new MxCommandReply + { + SessionId = SessionId, + CorrelationId = correlationId, + Kind = kind, + }, + }); + } + + private static WorkerEnvelope CreateEventEnvelope( + ulong sequence, + MxEventFamily family) + { + return CreateWorkerEnvelope( + correlationId: string.Empty, + sequence, + envelope => envelope.WorkerEvent = new WorkerEvent + { + Event = new MxEvent + { + SessionId = SessionId, + Family = family, + WorkerSequence = sequence, + }, + }); + } + + private static WorkerEnvelope CreateWorkerEnvelope( + string correlationId, + ulong sequence, + Action setBody) + { + WorkerEnvelope envelope = new() + { + ProtocolVersion = GatewayContractInfo.WorkerProtocolVersion, + SessionId = SessionId, + Sequence = sequence, + CorrelationId = correlationId, + }; + setBody(envelope); + + return envelope; + } + + private static async Task WaitUntilAsync( + Func predicate, + TimeSpan timeout) + { + using CancellationTokenSource cancellationTokenSource = new(timeout); + while (!predicate()) + { + await Task.Delay(TimeSpan.FromMilliseconds(10), cancellationTokenSource.Token); + } + } + + private sealed class PipePair : IAsyncDisposable + { + private readonly NamedPipeClientStream _workerStream; + private bool _workerSideDisposed; + + private PipePair( + NamedPipeServerStream gatewayStream, + NamedPipeClientStream workerStream) + { + GatewayStream = gatewayStream; + _workerStream = workerStream; + WorkerReader = new WorkerFrameReader(_workerStream, new WorkerFrameProtocolOptions(SessionId)); + WorkerWriter = new WorkerFrameWriter(_workerStream, new WorkerFrameProtocolOptions(SessionId)); + } + + public NamedPipeServerStream GatewayStream { get; } + + public WorkerFrameReader WorkerReader { get; } + + public WorkerFrameWriter WorkerWriter { get; } + + public static async Task CreateAsync() + { + string pipeName = $"mxaccessgw-workerclient-tests-{Guid.NewGuid():N}"; + NamedPipeServerStream gatewayStream = new( + pipeName, + PipeDirection.InOut, + maxNumberOfServerInstances: 1, + PipeTransmissionMode.Byte, + PipeOptions.Asynchronous); + NamedPipeClientStream workerStream = new( + ".", + pipeName, + PipeDirection.InOut, + PipeOptions.Asynchronous); + + Task waitForConnectionTask = gatewayStream.WaitForConnectionAsync(); + await workerStream.ConnectAsync(); + await waitForConnectionTask; + + return new PipePair(gatewayStream, workerStream); + } + + public async ValueTask DisposeWorkerSideAsync() + { + if (_workerSideDisposed) + { + return; + } + + await _workerStream.DisposeAsync(); + _workerSideDisposed = true; + } + + public async ValueTask DisposeAsync() + { + await DisposeWorkerSideAsync(); + await GatewayStream.DisposeAsync(); + } + } +} -- 2.52.0