using System;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
using GrpcStatus = Grpc.Core.Status;
using Serilog;
using ZB.MOM.WW.LmxProxy.Host.Domain;
using ZB.MOM.WW.LmxProxy.Host.Metrics;
using ZB.MOM.WW.LmxProxy.Host.Sessions;
using ZB.MOM.WW.LmxProxy.Host.Security;
using ZB.MOM.WW.LmxProxy.Host.Subscriptions;

namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services
{
    /// <summary>
    /// gRPC service implementation for all 10 SCADA RPCs.
    /// Inherits from proto-generated ScadaService.ScadaServiceBase.
    /// </summary>
    public class ScadaGrpcService : Scada.ScadaService.ScadaServiceBase
    {
        private static readonly ILogger Log = Serilog.Log.ForContext<ScadaGrpcService>();

        /// <summary>
        /// How long <see cref="Subscribe"/> waits for data before re-checking that its
        /// session is still valid.
        /// </summary>
        private static readonly TimeSpan SessionRecheckInterval = TimeSpan.FromSeconds(30);

        private readonly IScadaClient _scadaClient;
        private readonly SessionManager _sessionManager;
        private readonly SubscriptionManager _subscriptionManager;
        private readonly PerformanceMetrics? _performanceMetrics;
        private readonly ApiKeyService? _apiKeyService;

        /// <summary>Creates the service over the given SCADA client and managers.</summary>
        /// <param name="scadaClient">Client used for all tag reads and writes.</param>
        /// <param name="sessionManager">Creates, validates and terminates sessions.</param>
        /// <param name="subscriptionManager">Owns per-session tag subscriptions.</param>
        /// <param name="performanceMetrics">Optional; when null no operation timings are recorded.</param>
        /// <param name="apiKeyService">Optional; when null <see cref="CheckApiKey"/> reports every key as invalid.</param>
        public ScadaGrpcService(
            IScadaClient scadaClient,
            SessionManager sessionManager,
            SubscriptionManager subscriptionManager,
            PerformanceMetrics? performanceMetrics = null,
            ApiKeyService? apiKeyService = null)
        {
            _scadaClient = scadaClient;
            _sessionManager = sessionManager;
            _subscriptionManager = subscriptionManager;
            _performanceMetrics = performanceMetrics;
            _apiKeyService = apiKeyService;
        }

        // -- Connection Management ------------------------------------

        /// <summary>
        /// Creates a session for the calling client. Fails (Success = false) without creating a
        /// session when the SCADA client is not connected, or when session creation throws.
        /// </summary>
        public override Task<Scada.ConnectResponse> Connect(
            Scada.ConnectRequest request, ServerCallContext context)
        {
            try
            {
                if (!_scadaClient.IsConnected)
                {
                    return Task.FromResult(new Scada.ConnectResponse
                    {
                        Success = false,
                        Message = "MxAccess is not connected"
                    });
                }

                var sessionId = _sessionManager.CreateSession(request.ClientId, request.ApiKey);
                return Task.FromResult(new Scada.ConnectResponse
                {
                    Success = true,
                    Message = "Connected",
                    SessionId = sessionId
                });
            }
            catch (Exception ex)
            {
                Log.Error(ex, "Connect failed for client {ClientId}", request.ClientId);
                return Task.FromResult(new Scada.ConnectResponse
                {
                    Success = false,
                    Message = ex.Message
                });
            }
        }

        /// <summary>
        /// Terminates the session and then removes all of its subscriptions.
        /// Success reflects whether the session existed.
        /// </summary>
        public override Task<Scada.DisconnectResponse> Disconnect(
            Scada.DisconnectRequest request, ServerCallContext context)
        {
            try
            {
                // Terminate session first — prevents new Subscribe RPCs from passing
                // session validation while we clean up subscriptions
                var terminated = _sessionManager.TerminateSession(request.SessionId);

                // Then clean up all subscriptions for this session
                _subscriptionManager.UnsubscribeSession(request.SessionId);

                return Task.FromResult(new Scada.DisconnectResponse
                {
                    Success = terminated,
                    Message = terminated ? "Disconnected" : "Session not found"
                });
            }
            catch (Exception ex)
            {
                Log.Error(ex, "Disconnect failed for session {SessionId}", request.SessionId);
                return Task.FromResult(new Scada.DisconnectResponse
                {
                    Success = false,
                    Message = ex.Message
                });
            }
        }

        /// <summary>
        /// Reports SCADA connectivity plus, when the session is known, its client id and
        /// connect time (empty / 0 otherwise). Does not require a valid session.
        /// </summary>
        public override Task<Scada.GetConnectionStateResponse> GetConnectionState(
            Scada.GetConnectionStateRequest request, ServerCallContext context)
        {
            var session = _sessionManager.GetSession(request.SessionId);
            return Task.FromResult(new Scada.GetConnectionStateResponse
            {
                IsConnected = _scadaClient.IsConnected,
                ClientId = session?.ClientId ?? "",
                ConnectedSinceUtcTicks = session?.ConnectedSinceUtcTicks ?? 0
            });
        }

        // -- Read Operations ------------------------------------------

        /// <summary>
        /// Reads a single tag. On an invalid session or a read failure the response carries
        /// Success = false and a bad-quality VTQ for the requested tag.
        /// If the caller cancels the call, the cancellation is propagated instead.
        /// </summary>
        public override async Task<Scada.ReadResponse> Read(
            Scada.ReadRequest request, ServerCallContext context)
        {
            if (!_sessionManager.ValidateSession(request.SessionId))
            {
                return new Scada.ReadResponse
                {
                    Success = false,
                    Message = "Invalid session",
                    Vtq = CreateBadVtq(request.Tag, QualityCodeMapper.Bad())
                };
            }

            using var timing = _performanceMetrics?.BeginOperation("Read");
            try
            {
                var vtq = await _scadaClient.ReadAsync(request.Tag, context.CancellationToken);
                return new Scada.ReadResponse
                {
                    Success = true,
                    Message = "",
                    Vtq = ConvertToProtoVtq(request.Tag, vtq)
                };
            }
            catch (OperationCanceledException) when (context.CancellationToken.IsCancellationRequested)
            {
                // The caller cancelled: nobody will read a response, so let the call end as cancelled
                // rather than logging it as a read failure.
                timing?.SetSuccess(false);
                throw;
            }
            catch (Exception ex)
            {
                timing?.SetSuccess(false);
                Log.Error(ex, "Read failed for tag {Tag}", request.Tag);
                return new Scada.ReadResponse
                {
                    Success = false,
                    Message = ex.Message,
                    Vtq = CreateBadVtq(request.Tag, QualityCodeMapper.BadCommunicationFailure())
                };
            }
        }

        /// <summary>
        /// Reads several tags in one call. VTQs are returned in request order; a tag missing
        /// from the client's result gets a BadConfigurationError-quality VTQ.
        /// If the caller cancels the call, the cancellation is propagated.
        /// </summary>
        public override async Task<Scada.ReadBatchResponse> ReadBatch(
            Scada.ReadBatchRequest request, ServerCallContext context)
        {
            if (!_sessionManager.ValidateSession(request.SessionId))
            {
                return new Scada.ReadBatchResponse
                {
                    Success = false,
                    Message = "Invalid session"
                };
            }

            using var timing = _performanceMetrics?.BeginOperation("ReadBatch");
            try
            {
                var results = await _scadaClient.ReadBatchAsync(request.Tags, context.CancellationToken);

                var response = new Scada.ReadBatchResponse
                {
                    Success = true,
                    Message = ""
                };

                // Return results in request order
                foreach (var tag in request.Tags)
                {
                    if (results.TryGetValue(tag, out var vtq))
                    {
                        response.Vtqs.Add(ConvertToProtoVtq(tag, vtq));
                    }
                    else
                    {
                        response.Vtqs.Add(CreateBadVtq(tag, QualityCodeMapper.BadConfigurationError()));
                    }
                }

                return response;
            }
            catch (OperationCanceledException) when (context.CancellationToken.IsCancellationRequested)
            {
                timing?.SetSuccess(false);
                throw;
            }
            catch (Exception ex)
            {
                timing?.SetSuccess(false);
                Log.Error(ex, "ReadBatch failed");
                return new Scada.ReadBatchResponse
                {
                    Success = false,
                    Message = ex.Message
                };
            }
        }

        // -- Write Operations -----------------------------------------

        /// <summary>
        /// Writes a single tag. Failures are reported in the response (Success = false);
        /// if the caller cancels the call, the cancellation is propagated instead.
        /// </summary>
        public override async Task<Scada.WriteResponse> Write(
            Scada.WriteRequest request, ServerCallContext context)
        {
            if (!_sessionManager.ValidateSession(request.SessionId))
            {
                return new Scada.WriteResponse
                {
                    Success = false,
                    Message = "Invalid session"
                };
            }

            using var timing = _performanceMetrics?.BeginOperation("Write");
            try
            {
                var value = TypedValueConverter.FromTypedValue(request.Value);
                await _scadaClient.WriteAsync(request.Tag, value!, context.CancellationToken);
                return new Scada.WriteResponse { Success = true, Message = "" };
            }
            catch (OperationCanceledException) when (context.CancellationToken.IsCancellationRequested)
            {
                timing?.SetSuccess(false);
                throw;
            }
            catch (Exception ex)
            {
                timing?.SetSuccess(false);
                Log.Error(ex, "Write failed for tag {Tag}", request.Tag);
                return new Scada.WriteResponse
                {
                    Success = false,
                    Message = ex.Message
                };
            }
        }

        /// <summary>
        /// Writes each item sequentially, recording a per-tag result. A failed item does not
        /// stop the batch; it sets the overall Success to false. Cancellation by the caller
        /// aborts the batch and is propagated.
        /// </summary>
        public override async Task<Scada.WriteBatchResponse> WriteBatch(
            Scada.WriteBatchRequest request, ServerCallContext context)
        {
            if (!_sessionManager.ValidateSession(request.SessionId))
            {
                return new Scada.WriteBatchResponse
                {
                    Success = false,
                    Message = "Invalid session"
                };
            }

            using var timing = _performanceMetrics?.BeginOperation("WriteBatch");
            var response = new Scada.WriteBatchResponse { Success = true, Message = "" };

            foreach (var item in request.Items)
            {
                try
                {
                    var value = TypedValueConverter.FromTypedValue(item.Value);
                    await _scadaClient.WriteAsync(item.Tag, value!, context.CancellationToken);
                    response.Results.Add(new Scada.WriteResult
                    {
                        Tag = item.Tag,
                        Success = true,
                        Message = ""
                    });
                }
                catch (OperationCanceledException) when (context.CancellationToken.IsCancellationRequested)
                {
                    // Stop writing the remaining items for a caller that has gone away.
                    timing?.SetSuccess(false);
                    throw;
                }
                catch (Exception ex)
                {
                    Log.Warning(ex, "WriteBatch write failed for tag {Tag}", item.Tag);
                    response.Success = false;
                    response.Results.Add(new Scada.WriteResult
                    {
                        Tag = item.Tag,
                        Success = false,
                        Message = ex.Message
                    });
                }
            }

            if (!response.Success)
            {
                timing?.SetSuccess(false);
            }

            return response;
        }

        /// <summary>
        /// Writes every item, then, if all writes succeeded, polls <c>FlagTag</c> until it reads
        /// with good quality and equals <c>FlagValue</c>, or until the timeout elapses.
        /// Defaults: TimeoutMs = 5000 and PollIntervalMs = 100 when the request value is not positive.
        /// Reaching the timeout is not an error: the response has Success = true and FlagReached = false.
        /// </summary>
        public override async Task<Scada.WriteBatchAndWaitResponse> WriteBatchAndWait(
            Scada.WriteBatchAndWaitRequest request, ServerCallContext context)
        {
            if (!_sessionManager.ValidateSession(request.SessionId))
            {
                return new Scada.WriteBatchAndWaitResponse
                {
                    Success = false,
                    Message = "Invalid session"
                };
            }

            using var timing = _performanceMetrics?.BeginOperation("WriteBatchAndWait");
            var response = new Scada.WriteBatchAndWaitResponse { Success = true };

            try
            {
                // Execute writes and collect results
                foreach (var item in request.Items)
                {
                    try
                    {
                        var value = TypedValueConverter.FromTypedValue(item.Value);
                        await _scadaClient.WriteAsync(item.Tag, value!, context.CancellationToken);
                        response.WriteResults.Add(new Scada.WriteResult
                        {
                            Tag = item.Tag,
                            Success = true,
                            Message = ""
                        });
                    }
                    catch (OperationCanceledException) when (context.CancellationToken.IsCancellationRequested)
                    {
                        // Propagated by the outer OperationCanceledException handler.
                        throw;
                    }
                    catch (Exception ex)
                    {
                        Log.Warning(ex, "WriteBatchAndWait write failed for tag {Tag}", item.Tag);
                        response.Success = false;
                        response.Message = "One or more writes failed";
                        response.WriteResults.Add(new Scada.WriteResult
                        {
                            Tag = item.Tag,
                            Success = false,
                            Message = ex.Message
                        });
                    }
                }

                // If any write failed, return immediately without polling the flag
                if (!response.Success)
                {
                    timing?.SetSuccess(false);
                    return response;
                }

                // Poll flag tag
                var flagValue = TypedValueConverter.FromTypedValue(request.FlagValue);
                var timeoutMs = request.TimeoutMs > 0 ? request.TimeoutMs : 5000;
                var pollIntervalMs = request.PollIntervalMs > 0 ? request.PollIntervalMs : 100;
                var sw = Stopwatch.StartNew();

                while (true)
                {
                    context.CancellationToken.ThrowIfCancellationRequested();

                    var vtq = await _scadaClient.ReadAsync(request.FlagTag, context.CancellationToken);
                    if (vtq.Quality.IsGood() && TypedValueComparer.Equals(vtq.Value, flagValue))
                    {
                        response.FlagReached = true;
                        response.ElapsedMs = (int)sw.ElapsedMilliseconds;
                        return response;
                    }

                    long remainingMs = timeoutMs - sw.ElapsedMilliseconds;
                    if (remainingMs <= 0)
                    {
                        break;
                    }

                    // Never sleep past the deadline: a poll interval longer than the remaining
                    // time would otherwise overrun TimeoutMs and skip the final check.
                    await Task.Delay((int)Math.Min(pollIntervalMs, remainingMs), context.CancellationToken);
                }

                // Timeout -- not an error
                response.FlagReached = false;
                response.ElapsedMs = (int)sw.ElapsedMilliseconds;
                return response;
            }
            catch (OperationCanceledException)
            {
                timing?.SetSuccess(false);
                throw;
            }
            catch (Exception ex)
            {
                timing?.SetSuccess(false);
                Log.Error(ex, "WriteBatchAndWait failed");
                // Keep the already-collected WriteResults so the caller can see which writes
                // succeeded before the failure (e.g. while polling the flag tag).
                response.Success = false;
                response.Message = ex.Message;
                return response;
            }
        }

        // -- Subscription ---------------------------------------------

        /// <summary>
        /// Streams VTQ updates for the requested tags until the subscription channel completes,
        /// the caller cancels, or a periodic recheck finds the session no longer valid.
        /// The subscription created here (and only that one) is removed when the stream ends.
        /// </summary>
        /// <exception cref="RpcException">
        /// Unauthenticated when the session is invalid; Internal for unexpected stream errors.
        /// </exception>
        public override async Task Subscribe(
            Scada.SubscribeRequest request,
            IServerStreamWriter<Scada.VtqMessage> responseStream,
            ServerCallContext context)
        {
            if (!_sessionManager.ValidateSession(request.SessionId))
            {
                throw new RpcException(new GrpcStatus(StatusCode.Unauthenticated, "Invalid session"));
            }

            var (reader, subscriptionId) = await _subscriptionManager.SubscribeAsync(
                request.SessionId, request.Tags, context.CancellationToken);

            try
            {
                // Use a combined approach: check both the gRPC cancellation token AND
                // periodic session validity. This works around Grpc.Core not reliably
                // firing CancellationToken on client disconnect.
                while (true)
                {
                    // Wait for data with a timeout so we can periodically check session validity
                    using (var timeoutCts = new CancellationTokenSource(SessionRecheckInterval))
                    using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
                        context.CancellationToken, timeoutCts.Token))
                    {
                        bool hasData;
                        try
                        {
                            hasData = await reader.WaitToReadAsync(linkedCts.Token);
                        }
                        catch (OperationCanceledException) when (timeoutCts.IsCancellationRequested
                                                                 && !context.CancellationToken.IsCancellationRequested)
                        {
                            // Timeout expired, not a client disconnect — check if session is still valid
                            if (!_sessionManager.ValidateSession(request.SessionId))
                            {
                                Log.Information("Subscribe stream ending — session {SessionId} no longer valid",
                                    request.SessionId);
                                break;
                            }

                            continue; // Session still valid, keep waiting
                        }

                        if (!hasData)
                        {
                            break; // Channel completed
                        }

                        while (reader.TryRead(out var item))
                        {
                            var protoVtq = ConvertToProtoVtq(item.address, item.vtq);
                            await responseStream.WriteAsync(protoVtq);
                        }
                    }
                }
            }
            catch (OperationCanceledException)
            {
                // Client disconnected -- normal
            }
            catch (Exception ex)
            {
                Log.Error(ex, "Subscribe stream error for session {SessionId} subscription {SubscriptionId}",
                    request.SessionId, subscriptionId);
                throw new RpcException(new GrpcStatus(StatusCode.Internal, ex.Message));
            }
            finally
            {
                // Clean up THIS subscription only, not the entire session
                _subscriptionManager.UnsubscribeSubscription(subscriptionId);
            }
        }

        // -- API Key Check --------------------------------------------

        /// <summary>
        /// Reports whether the API key in the request body is accepted by the configured
        /// <see cref="ApiKeyService"/>. Always invalid when no ApiKeyService was supplied.
        /// </summary>
        public override Task<Scada.CheckApiKeyResponse> CheckApiKey(
            Scada.CheckApiKeyRequest request, ServerCallContext context)
        {
            // Check the API key from the request body against the key store.
            var isValid = _apiKeyService != null && _apiKeyService.ValidateApiKey(request.ApiKey) != null;
            return Task.FromResult(new Scada.CheckApiKeyResponse
            {
                IsValid = isValid,
                Message = isValid ? "Valid" : "Invalid"
            });
        }

        // -- Helpers --------------------------------------------------

        /// <summary>Converts a domain Vtq to a proto VtqMessage.</summary>
        /// <remarks>
        /// Timestamp ticks are copied as-is. NOTE(review): assumes <c>vtq.Timestamp</c> is already
        /// UTC, as the TimestampUtcTicks field name implies — confirm.
        /// </remarks>
        private static Scada.VtqMessage ConvertToProtoVtq(string tag, Vtq vtq)
        {
            return new Scada.VtqMessage
            {
                Tag = tag,
                Value = TypedValueConverter.ToTypedValue(vtq.Value),
                TimestampUtcTicks = vtq.Timestamp.Ticks,
                Quality = QualityCodeMapper.ToQualityCode(vtq.Quality)
            };
        }

        /// <summary>
        /// Creates a VtqMessage with the given (bad) quality for error responses.
        /// The value is left unset and the timestamp is the current UTC time.
        /// </summary>
        private static Scada.VtqMessage CreateBadVtq(string tag, Scada.QualityCode quality)
        {
            return new Scada.VtqMessage
            {
                Tag = tag,
                TimestampUtcTicks = DateTime.UtcNow.Ticks,
                Quality = quality
            };
        }
    }
}