using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using MxGateway.Contracts;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Metrics;

namespace MxGateway.Server.Workers;

/// <summary>
/// Gateway-side client for a single worker session.
/// </summary>
/// <remarks>
/// <para>
/// After a hello/ready handshake over the connection stream, three background loops run:
/// </para>
/// <list type="bullet">
/// <item><description>a write loop that drains queued outbound envelopes to the pipe;</description></item>
/// <item><description>a read loop that dispatches inbound envelopes (command replies, events,
/// heartbeats, faults, shutdown acks);</description></item>
/// <item><description>a heartbeat loop that faults the client when no heartbeat arrives within
/// <see cref="WorkerClientOptions.HeartbeatGrace"/>.</description></item>
/// </list>
/// <para>
/// State (Created → Handshaking → Ready → Closing → Closed, or → Faulted), the last heartbeat time
/// and the worker process id are guarded by <c>_syncRoot</c>. Closing or faulting cancels the loops,
/// completes both channels and fails every pending command.
/// </para>
/// </remarks>
public sealed class WorkerClient : IWorkerClient
{
    private const string GatewayVersionFallback = "unknown";

    private readonly object _syncRoot = new();
    private readonly WorkerClientConnection _connection;
    private readonly WorkerClientOptions _options;
    private readonly GatewayMetrics? _metrics;
    private readonly TimeProvider _timeProvider;
    private readonly ILogger _logger;
    private readonly WorkerFrameReader _reader;
    private readonly WorkerFrameWriter _writer;
    private readonly Channel<WorkerEnvelope> _outboundEnvelopes;
    private readonly Channel<WorkerEvent> _events;
    private readonly ConcurrentDictionary<string, PendingCommand> _pendingCommands = new(StringComparer.Ordinal);
    private readonly CancellationTokenSource _stopCts = new();

    // Updated with Interlocked.
    private long _nextSequence;
    private int _eventQueueDepth;

    // Guarded by _syncRoot.
    private WorkerClientState _state;
    private DateTimeOffset _lastHeartbeatAt;
    private int? _processId;

    private Task? _readLoopTask;
    private Task? _writeLoopTask;
    private Task? _heartbeatLoopTask;
    private bool _disposed;

    /// <summary>Creates a client over an already-established worker connection.</summary>
    /// <param name="connection">Worker connection whose stream and frame options are used for all framing.</param>
    /// <param name="options">Client options; a default <see cref="WorkerClientOptions"/> is used when <see langword="null"/>.</param>
    /// <param name="metrics">Optional metrics sink.</param>
    /// <param name="timeProvider">Clock for heartbeat timestamps and durations; defaults to <see cref="TimeProvider.System"/>.</param>
    /// <param name="logger">Optional logger; defaults to <see cref="NullLogger.Instance"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> is <see langword="null"/>.</exception>
    public WorkerClient(
        WorkerClientConnection connection,
        WorkerClientOptions? options = null,
        GatewayMetrics? metrics = null,
        TimeProvider? timeProvider = null,
        ILogger? logger = null)
    {
        _connection = connection ?? throw new ArgumentNullException(nameof(connection));
        _options = options ?? new WorkerClientOptions();
        _metrics = metrics;
        _timeProvider = timeProvider ?? TimeProvider.System;
        _logger = logger ?? NullLogger.Instance;
        _reader = new WorkerFrameReader(connection.Stream, connection.FrameOptions);
        _writer = new WorkerFrameWriter(connection.Stream, connection.FrameOptions);

        // Several public methods enqueue; only the write loop drains.
        _outboundEnvelopes = Channel.CreateUnbounded<WorkerEnvelope>(
            new UnboundedChannelOptions
            {
                SingleReader = true,
                SingleWriter = false,
                AllowSynchronousContinuations = false,
            });

        // Only the read loop publishes events. It uses TryWrite, so a full channel faults the
        // client instead of stalling the read loop.
        _events = Channel.CreateBounded<WorkerEvent>(
            new BoundedChannelOptions(_options.EventChannelCapacity)
            {
                SingleReader = false,
                SingleWriter = true,
                FullMode = BoundedChannelFullMode.Wait,
                AllowSynchronousContinuations = false,
            });

        _lastHeartbeatAt = _timeProvider.GetUtcNow();
    }

    /// <summary>Session id of the underlying connection.</summary>
    public string SessionId => _connection.SessionId;

    /// <summary>Worker process id reported by the worker, or taken from the process handle.</summary>
    public int? ProcessId
    {
        get
        {
            lock (_syncRoot)
            {
                return _processId;
            }
        }
    }

    /// <summary>Current lifecycle state.</summary>
    public WorkerClientState State
    {
        get
        {
            lock (_syncRoot)
            {
                return _state;
            }
        }
    }

    /// <summary>Time (from the injected clock) at which the last heartbeat or ready message was received.</summary>
    public DateTimeOffset LastHeartbeatAt
    {
        get
        {
            lock (_syncRoot)
            {
                return _lastHeartbeatAt;
            }
        }
    }

    /// <summary>
    /// Performs the handshake and then starts the read and heartbeat loops.
    /// </summary>
    /// <remarks>
    /// The handshake sends GatewayHello, then expects WorkerHello (protocol version and nonce must
    /// match) followed by WorkerReady.
    /// </remarks>
    /// <param name="cancellationToken">Cancels the handshake reads.</param>
    /// <exception cref="WorkerClientException">
    /// The client was not in the Created state, the worker sent an unexpected handshake envelope,
    /// hello validation failed, or the client left Handshaking (fault/shutdown) before becoming ready.
    /// </exception>
    public async Task StartAsync(CancellationToken cancellationToken)
    {
        ThrowIfDisposed();
        TransitionFromCreatedToHandshaking();

        // Start the write loop first so the queued GatewayHello is flushed while we wait for the reply.
        // NOTE(review): if the handshake below throws, the state stays Handshaking and the write loop
        // keeps running until ShutdownAsync/DisposeAsync is called.
        _writeLoopTask = Task.Run(WriteLoopAsync);
        await EnqueueAsync(CreateGatewayHelloEnvelope(), cancellationToken).ConfigureAwait(false);

        // Handshake frames are read inline; the read loop only starts once the worker is ready.
        WorkerEnvelope helloEnvelope = await ReadHandshakeEnvelopeAsync(
            WorkerEnvelope.BodyOneofCase.WorkerHello,
            cancellationToken).ConfigureAwait(false);
        ValidateWorkerHello(helloEnvelope.WorkerHello);

        WorkerEnvelope readyEnvelope = await ReadHandshakeEnvelopeAsync(
            WorkerEnvelope.BodyOneofCase.WorkerReady,
            cancellationToken).ConfigureAwait(false);
        MarkReady(readyEnvelope.WorkerReady);

        _readLoopTask = Task.Run(ReadLoopAsync);
        _heartbeatLoopTask = Task.Run(HeartbeatLoopAsync);
    }

    /// <summary>
    /// Sends <paramref name="command"/> to the worker and waits for the reply with the matching correlation id.
    /// </summary>
    /// <param name="command">The command to send; a clone is placed in the envelope.</param>
    /// <param name="timeout">Maximum time to wait for the reply; must be greater than zero.</param>
    /// <param name="cancellationToken">Cancels the enqueue and the wait.</param>
    /// <returns>The worker's command reply.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="command"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="timeout"/> is not positive.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled before a reply arrived.</exception>
    /// <exception cref="WorkerClientException">
    /// The client is not ready, the reply timed out, the outbound channel is closed, or the client
    /// faulted/closed while the command was pending.
    /// </exception>
    public async Task<WorkerCommandReply> InvokeAsync(
        WorkerCommand command,
        TimeSpan timeout,
        CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(command);
        ThrowIfDisposed();
        EnsureReady();
        if (timeout <= TimeSpan.Zero)
        {
            throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Command timeout must be greater than zero.");
        }

        string correlationId = Guid.NewGuid().ToString("N");
        string method = GetCommandMethod(command);
        PendingCommand pendingCommand = new(correlationId, method, _timeProvider.GetTimestamp());
        if (!_pendingCommands.TryAdd(correlationId, pendingCommand))
        {
            throw new InvalidOperationException("Generated a duplicate command correlation id.");
        }

        _metrics?.CommandStarted(method);
        try
        {
            await EnqueueAsync(CreateCommandEnvelope(correlationId, command), cancellationToken).ConfigureAwait(false);

            // NOTE(review): this delay runs on the system clock, not _timeProvider.
            using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            Task timeoutTask = Task.Delay(timeout, timeoutCts.Token);
            Task<WorkerCommandReply> replyTask = pendingCommand.Task;
            Task completedTask = await Task.WhenAny(replyTask, timeoutTask).ConfigureAwait(false);
            if (completedTask == replyTask)
            {
                // Stop the timer; the reply (or the client-wide failure that completed it) wins.
                await timeoutCts.CancelAsync().ConfigureAwait(false);
                return await replyTask.ConfigureAwait(false);
            }

            if (cancellationToken.IsCancellationRequested)
            {
                if (!RemovePendingCommandAsFailed(
                    correlationId,
                    pendingCommand,
                    WorkerClientErrorCode.GatewayShutdown,
                    "Command wait was canceled."))
                {
                    // A reply or fault completed the command after WhenAny returned; surface that outcome.
                    return await replyTask.ConfigureAwait(false);
                }

                cancellationToken.ThrowIfCancellationRequested();
            }

            string timeoutMessage = $"Worker command {method} timed out after {timeout}.";
            if (!RemovePendingCommandAsFailed(
                correlationId,
                pendingCommand,
                WorkerClientErrorCode.CommandTimeout,
                timeoutMessage))
            {
                // Lost the race to a late reply or fault; report what actually happened.
                return await replyTask.ConfigureAwait(false);
            }

            throw new WorkerClientException(WorkerClientErrorCode.CommandTimeout, timeoutMessage);
        }
        catch (Exception exception)
        {
            // The command is still registered only if we failed before anything completed it
            // (e.g. the enqueue threw). Record that failure so CommandStarted gets a matching outcome.
            if (_pendingCommands.TryRemove(correlationId, out _))
            {
                TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
                _metrics?.CommandFailed(method, exception.GetType().Name, duration);
            }

            throw;
        }
    }

    /// <summary>
    /// Streams worker events until the event channel completes.
    /// </summary>
    /// <remarks>
    /// When the client faults, the channel is completed with the fault exception and enumeration
    /// surfaces it. Each dequeued event decrements the reported event queue depth.
    /// </remarks>
    /// <param name="cancellationToken">Cancels the enumeration.</param>
    public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        await foreach (WorkerEvent workerEvent in _events.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
        {
            int queueDepth = Math.Max(0, Interlocked.Decrement(ref _eventQueueDepth));
            _metrics?.SetEventQueueDepth(queueDepth);
            yield return workerEvent;
        }
    }

    /// <summary>
    /// Requests a graceful shutdown and waits for the background loops to finish.
    /// </summary>
    /// <remarks>
    /// Sends WorkerShutdown with <paramref name="timeout"/> as the grace period and stops accepting
    /// outbound envelopes. The loops finish once the client is closed (e.g. by a WorkerShutdownAck)
    /// or faulted. Returns immediately if the client is already Closed or Faulted.
    /// </remarks>
    /// <param name="timeout">Grace period sent to the worker and local wait limit; must be greater than zero.</param>
    /// <param name="cancellationToken">Cancels the enqueue and the wait.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="timeout"/> is not positive.</exception>
    /// <exception cref="WorkerClientException">The loops did not finish within <paramref name="timeout"/>; the client is faulted.</exception>
    public async Task ShutdownAsync(TimeSpan timeout, CancellationToken cancellationToken)
    {
        ThrowIfDisposed();
        if (timeout <= TimeSpan.Zero)
        {
            throw new ArgumentOutOfRangeException(nameof(timeout), timeout, "Shutdown timeout must be greater than zero.");
        }

        WorkerClientState state = State;
        if (state is WorkerClientState.Closed or WorkerClientState.Faulted)
        {
            return;
        }

        MarkClosing();
        await EnqueueAsync(CreateShutdownEnvelope(timeout, "gateway-shutdown"), cancellationToken).ConfigureAwait(false);

        // Let the write loop drain the shutdown envelope and then exit.
        _outboundEnvelopes.Writer.TryComplete();

        using CancellationTokenSource timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        timeoutCts.CancelAfter(timeout);
        try
        {
            await WaitForBackgroundTasksAsync(timeoutCts.Token).ConfigureAwait(false);
            MarkClosed("shutdown");
        }
        catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
        {
            // Only our own timeout lands here; caller cancellation propagates unchanged.
            SetFaulted(
                WorkerClientErrorCode.ShutdownTimeout,
                "Worker shutdown timed out.",
                null);
            throw new WorkerClientException(
                WorkerClientErrorCode.ShutdownTimeout,
                $"Worker shutdown timed out after {timeout}.");
        }
    }

    /// <summary>
    /// Kills the worker's process tree (when the connection has a process handle), records the kill
    /// metric and faults the client with <see cref="WorkerClientErrorCode.WorkerFaulted"/>.
    /// </summary>
    /// <param name="reason">Reason recorded in metrics and in the fault message.</param>
    public void Kill(string reason)
    {
        ThrowIfDisposed();
        _connection.ProcessHandle?.Process.Kill(entireProcessTree: true);
        _metrics?.WorkerKilled(reason);
        SetFaulted(
            WorkerClientErrorCode.WorkerFaulted,
            $"Worker was killed by the gateway: {reason}.",
            null);
    }

    /// <summary>
    /// Stops the loops, completes both channels and fails pending commands with
    /// <see cref="WorkerClientErrorCode.GatewayShutdown"/>.
    /// </summary>
    /// <remarks>
    /// After the loops finish, it disposes the connection stream, the process handle and the stop
    /// token source. Later calls return immediately.
    /// NOTE(review): the <c>_disposed</c> check is not synchronized, so two concurrent calls could
    /// both proceed.
    /// </remarks>
    public async ValueTask DisposeAsync()
    {
        if (_disposed)
        {
            return;
        }

        _disposed = true;
        _stopCts.Cancel();
        _outboundEnvelopes.Writer.TryComplete();
        _events.Writer.TryComplete();
        CompletePendingCommands(
            new WorkerClientException(
                WorkerClientErrorCode.GatewayShutdown,
                "Worker client was disposed."));

        await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
        await _connection.Stream.DisposeAsync().ConfigureAwait(false);
        _connection.ProcessHandle?.Dispose();
        _stopCts.Dispose();
    }

    /// <summary>
    /// Writes queued outbound envelopes to the pipe until the channel completes or the client stops.
    /// </summary>
    /// <remarks>
    /// Cancellation during stop is swallowed. Any other failure faults the client with
    /// <see cref="WorkerClientErrorCode.WriteFailed"/>; this is a no-op if the client is already
    /// Faulted or Closed.
    /// </remarks>
    private async Task WriteLoopAsync()
    {
        try
        {
            await foreach (WorkerEnvelope envelope in _outboundEnvelopes.Reader.ReadAllAsync(_stopCts.Token).ConfigureAwait(false))
            {
                await _writer.WriteAsync(envelope, _stopCts.Token).ConfigureAwait(false);
            }
        }
        catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
        {
        }
        catch (Exception exception)
        {
            SetFaulted(
                WorkerClientErrorCode.WriteFailed,
                "Worker pipe write failed.",
                exception);
        }
    }

    /// <summary>
    /// Reads and dispatches inbound envelopes until the client stops.
    /// </summary>
    /// <remarks>
    /// End-of-stream faults the client with <see cref="WorkerClientErrorCode.PipeDisconnected"/>.
    /// Any other failure faults it with <see cref="WorkerClientErrorCode.ProtocolViolation"/>.
    /// </remarks>
    private async Task ReadLoopAsync()
    {
        try
        {
            while (!_stopCts.IsCancellationRequested)
            {
                WorkerEnvelope envelope = await _reader.ReadAsync(_stopCts.Token).ConfigureAwait(false);
                await DispatchEnvelopeAsync(envelope, _stopCts.Token).ConfigureAwait(false);
            }
        }
        catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
        {
        }
        catch (WorkerFrameProtocolException exception) when (exception.ErrorCode == WorkerFrameProtocolErrorCode.EndOfStream)
        {
            SetFaulted(
                WorkerClientErrorCode.PipeDisconnected,
                "Worker pipe disconnected.",
                exception);
        }
        catch (Exception exception)
        {
            SetFaulted(
                WorkerClientErrorCode.ProtocolViolation,
                "Worker read loop failed.",
                exception);
        }
    }

    /// <summary>
    /// Checks heartbeat freshness every <see cref="WorkerClientOptions.HeartbeatCheckInterval"/> while Ready.
    /// </summary>
    /// <remarks>
    /// The client is faulted with <see cref="WorkerClientErrorCode.HeartbeatExpired"/> once the
    /// time since the last heartbeat exceeds <see cref="WorkerClientOptions.HeartbeatGrace"/>.
    /// </remarks>
    private async Task HeartbeatLoopAsync()
    {
        try
        {
            while (!_stopCts.IsCancellationRequested)
            {
                // NOTE(review): the check interval runs on the system clock, while the freshness
                // comparison below uses _timeProvider.
                await Task.Delay(_options.HeartbeatCheckInterval, _stopCts.Token).ConfigureAwait(false);
                if (State != WorkerClientState.Ready)
                {
                    continue;
                }

                DateTimeOffset lastHeartbeatAt = LastHeartbeatAt;
                DateTimeOffset now = _timeProvider.GetUtcNow();
                if (now - lastHeartbeatAt <= _options.HeartbeatGrace)
                {
                    continue;
                }

                _metrics?.HeartbeatFailed(SessionId);
                SetFaulted(
                    WorkerClientErrorCode.HeartbeatExpired,
                    $"Worker heartbeat expired. Last heartbeat was at {lastHeartbeatAt:O}.",
                    null);
            }
        }
        catch (OperationCanceledException) when (_stopCts.IsCancellationRequested || IsTerminalState())
        {
        }
    }

    /// <summary>Routes one inbound envelope by body type; unknown bodies fault the client.</summary>
    private async Task DispatchEnvelopeAsync(
        WorkerEnvelope envelope,
        CancellationToken cancellationToken)
    {
        switch (envelope.BodyCase)
        {
            case WorkerEnvelope.BodyOneofCase.WorkerCommandReply:
                CompleteCommand(envelope);
                break;
            case WorkerEnvelope.BodyOneofCase.WorkerEvent:
                await EnqueueWorkerEventAsync(envelope.WorkerEvent, cancellationToken).ConfigureAwait(false);
                break;
            case WorkerEnvelope.BodyOneofCase.WorkerHeartbeat:
                MarkHeartbeat(envelope.WorkerHeartbeat);
                break;
            case WorkerEnvelope.BodyOneofCase.WorkerFault:
                SetFaulted(
                    WorkerClientErrorCode.WorkerFaulted,
                    CreateWorkerFaultMessage(envelope.WorkerFault),
                    null);
                break;
            case WorkerEnvelope.BodyOneofCase.WorkerShutdownAck:
                MarkClosed("worker-shutdown-ack");
                break;
            default:
                SetFaulted(
                    WorkerClientErrorCode.ProtocolViolation,
                    $"Worker sent unexpected envelope body {envelope.BodyCase}.",
                    null);
                break;
        }
    }

    /// <summary>
    /// Publishes a worker event to the bounded event channel without blocking the read loop.
    /// </summary>
    /// <remarks>
    /// If the channel is full, a queue-overflow metric is recorded and the client is faulted. If
    /// the channel was completed because the client is stopping, the event is dropped.
    /// </remarks>
    private async Task EnqueueWorkerEventAsync(
        WorkerEvent workerEvent,
        CancellationToken cancellationToken)
    {
        if (workerEvent.Event is not null)
        {
            _metrics?.EventReceived(SessionId, workerEvent.Event.Family.ToString());
        }

        // Count the event before publishing it, so a consumer that dequeues it immediately can
        // never decrement ahead of this increment.
        int queueDepth = Interlocked.Increment(ref _eventQueueDepth);
        if (!_events.Writer.TryWrite(workerEvent))
        {
            Interlocked.Decrement(ref _eventQueueDepth);

            // Close/fault/dispose cancel _stopCts before completing the channel, so a rejection
            // here is a shutdown race rather than an overflow.
            if (_stopCts.IsCancellationRequested)
            {
                return;
            }

            _metrics?.QueueOverflow("worker-events");
            SetFaulted(
                WorkerClientErrorCode.ProtocolViolation,
                "Worker event channel rejected an event.",
                null);
            return;
        }

        _metrics?.SetEventQueueDepth(queueDepth);
    }

    /// <summary>
    /// Completes the pending command matching the reply's correlation id.
    /// </summary>
    /// <remarks>
    /// The id is taken from the envelope, or from the nested reply if the envelope's id is blank.
    /// Unknown or late replies are logged at Debug and ignored.
    /// </remarks>
    private void CompleteCommand(WorkerEnvelope envelope)
    {
        string correlationId = envelope.CorrelationId;
        if (string.IsNullOrWhiteSpace(correlationId))
        {
            correlationId = envelope.WorkerCommandReply.Reply?.CorrelationId ?? string.Empty;
        }

        if (!_pendingCommands.TryRemove(correlationId, out PendingCommand? pendingCommand))
        {
            _logger.LogDebug(
                "Ignoring late or unknown worker command reply for session {SessionId} and correlation {CorrelationId}.",
                SessionId,
                correlationId);
            return;
        }

        TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
        _metrics?.CommandSucceeded(pendingCommand.Method, duration);
        pendingCommand.SetResult(envelope.WorkerCommandReply);
    }

    /// <summary>
    /// Removes a still-pending command, records it as failed and faults its completion source.
    /// </summary>
    /// <returns>
    /// <see langword="true"/> if this call removed the command; <see langword="false"/> if something
    /// else (a reply or a client-wide failure) removed it first.
    /// </returns>
    private bool RemovePendingCommandAsFailed(
        string correlationId,
        PendingCommand pendingCommand,
        WorkerClientErrorCode errorCode,
        string message)
    {
        if (!_pendingCommands.TryRemove(correlationId, out _))
        {
            return false;
        }

        TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
        _metrics?.CommandFailed(pendingCommand.Method, errorCode.ToString(), duration);
        pendingCommand.SetException(new WorkerClientException(errorCode, message));

        // The only awaiter (InvokeAsync) abandons this task and throws on its own, so mark the stored
        // exception as observed to keep it from being raised as an unobserved task exception.
        _ = pendingCommand.Task.Exception;
        return true;
    }

    /// <summary>Reads one envelope and fails with ProtocolViolation unless it has the expected body.</summary>
    private async Task<WorkerEnvelope> ReadHandshakeEnvelopeAsync(
        WorkerEnvelope.BodyOneofCase expectedBody,
        CancellationToken cancellationToken)
    {
        WorkerEnvelope envelope = await _reader.ReadAsync(cancellationToken).ConfigureAwait(false);
        if (envelope.BodyCase != expectedBody)
        {
            throw new WorkerClientException(
                WorkerClientErrorCode.ProtocolViolation,
                $"Worker handshake expected {expectedBody} but received {envelope.BodyCase}.");
        }

        return envelope;
    }

    /// <summary>
    /// Validates the worker hello's protocol version and nonce against the connection, then records
    /// the worker process id.
    /// </summary>
    /// <remarks>
    /// A reported process id of 0 falls back to the process handle's id.
    /// </remarks>
    private void ValidateWorkerHello(WorkerHello workerHello)
    {
        if (workerHello.ProtocolVersion != _connection.FrameOptions.ProtocolVersion)
        {
            throw new WorkerClientException(
                WorkerClientErrorCode.ProtocolViolation,
                "Worker hello protocol version does not match the gateway protocol version.");
        }

        if (!string.Equals(workerHello.Nonce, _connection.Nonce, StringComparison.Ordinal))
        {
            throw new WorkerClientException(
                WorkerClientErrorCode.ProtocolViolation,
                "Worker hello nonce does not match the gateway nonce.");
        }

        lock (_syncRoot)
        {
            _processId = workerHello.WorkerProcessId == 0
                ? _connection.ProcessHandle?.ProcessId
                : workerHello.WorkerProcessId;
        }
    }

    /// <summary>
    /// Transitions Handshaking → Ready, refreshes the heartbeat time and records the startup-latency metric.
    /// </summary>
    /// <exception cref="WorkerClientException">
    /// The client is no longer Handshaking, e.g. it faulted or began shutting down during the handshake.
    /// </exception>
    private void MarkReady(WorkerReady ready)
    {
        lock (_syncRoot)
        {
            // Never resurrect a client that faulted or started closing while the handshake ran.
            if (_state != WorkerClientState.Handshaking)
            {
                throw new WorkerClientException(
                    WorkerClientErrorCode.InvalidState,
                    $"Worker client cannot become ready from state {_state}.");
            }

            _processId = ready.WorkerProcessId == 0
                ? (_processId ?? _connection.ProcessHandle?.ProcessId)
                : ready.WorkerProcessId;
            _lastHeartbeatAt = _timeProvider.GetUtcNow();
            _state = WorkerClientState.Ready;
        }

        DateTimeOffset readyAt = _timeProvider.GetUtcNow();
        DateTimeOffset launchedAt = _connection.ProcessHandle?.LaunchedAt ?? readyAt;
        _metrics?.WorkerStarted(readyAt - launchedAt);
    }

    /// <summary>Records a heartbeat time and, when non-zero, the reported worker process id.</summary>
    private void MarkHeartbeat(WorkerHeartbeat heartbeat)
    {
        lock (_syncRoot)
        {
            _lastHeartbeatAt = _timeProvider.GetUtcNow();
            if (heartbeat.WorkerProcessId != 0)
            {
                _processId = heartbeat.WorkerProcessId;
            }
        }
    }

    /// <summary>Moves to Closing unless the client is already Closed or Faulted.</summary>
    private void MarkClosing()
    {
        lock (_syncRoot)
        {
            if (_state is WorkerClientState.Closed or WorkerClientState.Faulted)
            {
                return;
            }

            _state = WorkerClientState.Closing;
        }
    }

    /// <summary>
    /// Moves to Closed and stops the client.
    /// </summary>
    /// <remarks>
    /// Stopping cancels the loops, completes both channels, fails pending commands with
    /// <see cref="WorkerClientErrorCode.GatewayShutdown"/> and records the stop metric.
    /// NOTE(review): only an existing Closed state short-circuits, so this also overrides Faulted.
    /// </remarks>
    private void MarkClosed(string reason)
    {
        lock (_syncRoot)
        {
            if (_state == WorkerClientState.Closed)
            {
                return;
            }

            _state = WorkerClientState.Closed;
        }

        _stopCts.Cancel();
        _outboundEnvelopes.Writer.TryComplete();
        _events.Writer.TryComplete();
        CompletePendingCommands(
            new WorkerClientException(
                WorkerClientErrorCode.GatewayShutdown,
                $"Worker client closed because {reason}."));
        _metrics?.WorkerStopped(reason);
    }

    /// <summary>
    /// Moves to Faulted (unless already Faulted or Closed) and tears the client down with the fault.
    /// </summary>
    /// <remarks>
    /// Tear-down cancels the loops, completes both channels with the fault, fails pending commands,
    /// records the fault metric and logs a warning.
    /// </remarks>
    private void SetFaulted(
        WorkerClientErrorCode errorCode,
        string message,
        Exception? exception)
    {
        WorkerClientException fault = exception is null
            ? new WorkerClientException(errorCode, message)
            : new WorkerClientException(errorCode, message, exception);

        lock (_syncRoot)
        {
            if (_state is WorkerClientState.Faulted or WorkerClientState.Closed)
            {
                return;
            }

            _state = WorkerClientState.Faulted;
        }

        _stopCts.Cancel();
        _outboundEnvelopes.Writer.TryComplete(fault);
        _events.Writer.TryComplete(fault);
        CompletePendingCommands(fault);
        _metrics?.Fault(errorCode.ToString());
        _logger.LogWarning(exception, "Worker client faulted for session {SessionId}: {Message}", SessionId, message);
    }

    /// <summary>Fails every pending command with <paramref name="exception"/> and records each failure.</summary>
    private void CompletePendingCommands(Exception exception)
    {
        foreach (KeyValuePair<string, PendingCommand> item in _pendingCommands.ToArray())
        {
            if (_pendingCommands.TryRemove(item.Key, out PendingCommand? pendingCommand))
            {
                TimeSpan duration = _timeProvider.GetElapsedTime(pendingCommand.StartTimestamp);
                _metrics?.CommandFailed(pendingCommand.Method, exception.GetType().Name, duration);
                pendingCommand.SetException(exception);
            }
        }
    }

    /// <summary>Moves Created → Handshaking; any other starting state is rejected.</summary>
    private void TransitionFromCreatedToHandshaking()
    {
        lock (_syncRoot)
        {
            if (_state != WorkerClientState.Created)
            {
                throw new WorkerClientException(
                    WorkerClientErrorCode.InvalidState,
                    $"Worker client cannot start from state {_state}.");
            }

            _state = WorkerClientState.Handshaking;
        }
    }

    /// <summary>Throws InvalidState unless the client is Ready.</summary>
    private void EnsureReady()
    {
        WorkerClientState state = State;
        if (state != WorkerClientState.Ready)
        {
            throw new WorkerClientException(
                WorkerClientErrorCode.InvalidState,
                $"Worker client is not ready. Current state is {state}.");
        }
    }

    /// <summary>True when the client is Closing, Closed or Faulted.</summary>
    private bool IsTerminalState()
    {
        WorkerClientState state = State;
        return state is WorkerClientState.Closing or WorkerClientState.Closed or WorkerClientState.Faulted;
    }

    /// <summary>Queues an envelope for the write loop; a closed channel becomes WriteFailed.</summary>
    private async Task EnqueueAsync(
        WorkerEnvelope envelope,
        CancellationToken cancellationToken)
    {
        try
        {
            await _outboundEnvelopes.Writer.WriteAsync(envelope, cancellationToken).ConfigureAwait(false);
        }
        catch (ChannelClosedException exception)
        {
            throw new WorkerClientException(
                WorkerClientErrorCode.WriteFailed,
                "Worker outbound channel is closed.",
                exception);
        }
    }

    private WorkerEnvelope CreateGatewayHelloEnvelope()
    {
        return CreateEnvelope(
            correlationId: string.Empty,
            envelope => envelope.GatewayHello = new GatewayHello
            {
                SupportedProtocolVersion = _connection.FrameOptions.ProtocolVersion,
                Nonce = _connection.Nonce,
                GatewayVersion = typeof(GatewayContractInfo).Assembly.GetName().Version?.ToString() ?? GatewayVersionFallback,
            });
    }

    private WorkerEnvelope CreateCommandEnvelope(
        string correlationId,
        WorkerCommand command)
    {
        // Clone so the envelope does not share the caller's message instance.
        return CreateEnvelope(
            correlationId,
            envelope => envelope.WorkerCommand = command.Clone());
    }

    private WorkerEnvelope CreateShutdownEnvelope(
        TimeSpan timeout,
        string reason)
    {
        return CreateEnvelope(
            correlationId: string.Empty,
            envelope => envelope.WorkerShutdown = new WorkerShutdown
            {
                GracePeriod = Duration.FromTimeSpan(timeout),
                Reason = reason,
            });
    }

    /// <summary>Builds an envelope with protocol version, session id and next sequence number, then applies <paramref name="setBody"/>.</summary>
    private WorkerEnvelope CreateEnvelope(
        string correlationId,
        Action<WorkerEnvelope> setBody)
    {
        WorkerEnvelope envelope = new()
        {
            ProtocolVersion = _connection.FrameOptions.ProtocolVersion,
            SessionId = SessionId,
            Sequence = (ulong)Interlocked.Increment(ref _nextSequence),
            CorrelationId = correlationId,
        };
        setBody(envelope);
        return envelope;
    }

    private static string GetCommandMethod(WorkerCommand command)
    {
        return command.Command?.Kind.ToString() ?? MxCommandKind.Unspecified.ToString();
    }

    private static string CreateWorkerFaultMessage(WorkerFault fault)
    {
        return string.IsNullOrWhiteSpace(fault.DiagnosticMessage)
            ? $"Worker faulted with category {fault.Category}."
            : $"Worker faulted with category {fault.Category}: {fault.DiagnosticMessage}";
    }

    /// <summary>Awaits whichever background loops have been started.</summary>
    private async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken)
    {
        Task[] tasks = new[] { _readLoopTask, _writeLoopTask, _heartbeatLoopTask }
            .Where(task => task is not null)
            .Cast<Task>()
            .ToArray();
        if (tasks.Length == 0)
        {
            return;
        }

        await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
    }

    private void ThrowIfDisposed()
    {
        ObjectDisposedException.ThrowIf(_disposed, this);
    }

    /// <summary>
    /// An in-flight command: its method name, start timestamp and a completion source that is
    /// resolved at most once.
    /// </summary>
    /// <remarks>
    /// Continuations run asynchronously, so completing the source never runs the awaiting caller
    /// inline on the read loop.
    /// </remarks>
    private sealed class PendingCommand
    {
        private readonly TaskCompletionSource<WorkerCommandReply> _completion =
            new(TaskCreationOptions.RunContinuationsAsynchronously);

        public PendingCommand(
            string correlationId,
            string method,
            long startTimestamp)
        {
            CorrelationId = correlationId;
            Method = method;
            StartTimestamp = startTimestamp;
        }

        public string CorrelationId { get; }

        public string Method { get; }

        public long StartTimestamp { get; }

        public Task<WorkerCommandReply> Task => _completion.Task;

        public void SetResult(WorkerCommandReply reply)
        {
            _completion.TrySetResult(reply);
        }

        public void SetException(Exception exception)
        {
            _completion.TrySetException(exception);
        }
    }
}