"""Async MXAccess Gateway client wrapper.""" from __future__ import annotations import asyncio from collections.abc import AsyncIterator, Sequence from typing import Any import grpc from .auth import merge_metadata from .errors import ensure_protocol_success, map_rpc_error from .generated import mxaccess_gateway_pb2 as pb from .generated import mxaccess_gateway_pb2_grpc as pb_grpc from .options import ClientOptions, create_channel class GatewayClient: """Async client for the public MXAccess Gateway gRPC API.""" def __init__( self, *, options: ClientOptions, stub: Any, channel: grpc.aio.Channel | None = None, ) -> None: """Initialize the client with resolved options and a gRPC stub.""" self.options = options self.raw_stub = stub self._channel = channel self._closed = False @classmethod async def connect( cls, options: ClientOptions | None = None, *, endpoint: str | None = None, api_key: str | None = None, plaintext: bool = False, ca_file: str | None = None, server_name_override: str | None = None, stub: Any | None = None, ) -> "GatewayClient": """Create a client with either a real async channel or a supplied fake stub.""" resolved = options or ClientOptions( endpoint=endpoint or "", api_key=api_key, plaintext=plaintext, ca_file=ca_file, server_name_override=server_name_override, ) if stub is not None: return cls(options=resolved, stub=stub) channel = create_channel(resolved) return cls( options=resolved, stub=pb_grpc.MxAccessGatewayStub(channel), channel=channel, ) async def __aenter__(self) -> "GatewayClient": """Return self to support ``async with`` usage.""" return self async def __aexit__(self, *_exc_info: object) -> None: """Close the client when leaving an ``async with`` block.""" await self.close() async def close(self) -> None: """Close the owned gRPC channel.""" if self._closed: return if self._channel is not None: await self._channel.close() self._closed = True async def open_session( self, request: pb.OpenSessionRequest | None = None, *, requested_backend: str = "", client_session_name: str = "", client_correlation_id: str = "", ) -> "Session": """Open a gateway session and return a high-level session wrapper.""" from .session import Session raw_request = request or pb.OpenSessionRequest( requested_backend=requested_backend, client_session_name=client_session_name, client_correlation_id=client_correlation_id, ) reply = await self.open_session_raw(raw_request) return Session(client=self, session_id=reply.session_id, open_reply=reply) async def open_session_raw(self, request: pb.OpenSessionRequest) -> pb.OpenSessionReply: """Send an `OpenSession` RPC and return the raw reply.""" reply = await self._unary("open session", self.raw_stub.OpenSession, request) ensure_protocol_success("open session", reply.protocol_status, reply) return reply async def close_session_raw( self, request: pb.CloseSessionRequest, ) -> pb.CloseSessionReply: """Send a `CloseSession` RPC and return the raw reply.""" reply = await self._unary("close session", self.raw_stub.CloseSession, request) ensure_protocol_success("close session", reply.protocol_status, reply) return reply async def invoke_raw(self, request: pb.MxCommandRequest) -> pb.MxCommandReply: """Send an `Invoke` RPC and return the raw reply.""" reply = await self._unary("invoke", self.raw_stub.Invoke, request) ensure_protocol_success("invoke", reply.protocol_status, reply) return reply def stream_events_raw( self, request: pb.StreamEventsRequest, *, metadata: Sequence[tuple[str, str]] | None = None, ) -> AsyncIterator[pb.MxEvent]: """Return an async event iterator and cancel the stream when iteration stops.""" kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)} if self.options.stream_timeout is not None: kwargs["timeout"] = self.options.stream_timeout call = self.raw_stub.StreamEvents(request, **kwargs) return _canceling_iterator(call) async def acknowledge_alarm( self, request: pb.AcknowledgeAlarmRequest, ) -> pb.AcknowledgeAlarmReply: """Acknowledge an active MXAccess alarm condition through the gateway. The gateway authenticates the request against the API key's ``invoke:alarm-ack`` scope and forwards the acknowledge to the worker's MXAccess session; the resulting native ``MxStatus`` is returned in the reply. Acks are idempotent — re-acking an already-acked condition is a no-op at the MxAccess layer. """ reply = await self._unary("acknowledge alarm", self.raw_stub.AcknowledgeAlarm, request) ensure_protocol_success("acknowledge alarm", reply.protocol_status, reply) return reply def query_active_alarms( self, request: pb.QueryActiveAlarmsRequest, *, metadata: Sequence[tuple[str, str]] | None = None, ) -> AsyncIterator[pb.ActiveAlarmSnapshot]: """Stream a snapshot of all alarms currently Active or ActiveAcked. The gateway's ConditionRefresh equivalent. Use after reconnect to seed local Part 9 state, or to reconcile alarms that may have been missed during a transport blip. Optionally scoped by alarm-reference prefix (``request.alarm_filter_prefix``) so a partial refresh can target an equipment sub-tree. """ kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)} if self.options.stream_timeout is not None: kwargs["timeout"] = self.options.stream_timeout call = self.raw_stub.QueryActiveAlarms(request, **kwargs) return _canceling_active_alarms_iterator(call) def stream_alarms( self, request: pb.StreamAlarmsRequest, *, metadata: Sequence[tuple[str, str]] | None = None, ) -> AsyncIterator[pb.AlarmFeedMessage]: """Attach to the gateway's central alarm feed. The stream opens with one ``AlarmFeedMessage`` per currently-active alarm (the ConditionRefresh snapshot), then a single ``snapshot_complete``, then a ``transition`` for every subsequent raise / acknowledge / clear. Served by the gateway's always-on alarm monitor — no worker session is opened — so any number of clients may attach. Optionally scoped by alarm-reference prefix (``request.alarm_filter_prefix``). """ kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)} if self.options.stream_timeout is not None: kwargs["timeout"] = self.options.stream_timeout call = self.raw_stub.StreamAlarms(request, **kwargs) return _canceling_alarm_feed_iterator(call) async def _unary( self, operation: str, method: Any, request: Any, *, metadata: Sequence[tuple[str, str]] | None = None, ) -> Any: kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)} if self.options.call_timeout is not None: kwargs["timeout"] = self.options.call_timeout try: call = method(request, **kwargs) except TypeError as error: if "timeout" not in kwargs or "unexpected keyword argument 'timeout'" not in str(error): raise kwargs.pop("timeout") call = method(request, **kwargs) try: return await call except asyncio.CancelledError: cancel = getattr(call, "cancel", None) if cancel is not None: cancel() raise except grpc.RpcError as error: raise map_rpc_error(operation, error) from error async def _canceling_iterator(call: Any) -> AsyncIterator[pb.MxEvent]: try: async for event in call: yield event except grpc.RpcError as error: raise map_rpc_error("stream events", error) from error finally: cancel = getattr(call, "cancel", None) if cancel is not None: cancel() async def _canceling_active_alarms_iterator(call: Any) -> AsyncIterator[pb.ActiveAlarmSnapshot]: try: async for snapshot in call: yield snapshot except grpc.RpcError as error: raise map_rpc_error("query active alarms", error) from error finally: cancel = getattr(call, "cancel", None) if cancel is not None: cancel() async def _canceling_alarm_feed_iterator(call: Any) -> AsyncIterator[pb.AlarmFeedMessage]: try: async for message in call: yield message except grpc.RpcError as error: raise map_rpc_error("stream alarms", error) from error finally: cancel = getattr(call, "cancel", None) if cancel is not None: cancel()