"""Tests for the AcknowledgeAlarm + QueryActiveAlarms client surface (PR E.3)."""

from __future__ import annotations

import asyncio
from collections import deque
from typing import Any

import grpc
import pytest

from zb_mom_ww_mxgateway import ClientOptions, GatewayClient
from zb_mom_ww_mxgateway.errors import MxGatewayAuthenticationError, MxGatewayAuthorizationError
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb


async def _connect(stub: FakeGatewayStub) -> GatewayClient:
    """Connect a client to *stub* using the shared plaintext test options and API key."""
    return await GatewayClient.connect(
        ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
        stub=stub,
    )


@pytest.mark.asyncio
async def test_acknowledge_alarm_sends_request_and_returns_reply() -> None:
    """The ack request reaches the stub with bearer metadata and the stub's reply is returned."""
    stub = FakeGatewayStub()
    stub.acknowledge_alarm.replies = [
        pb.AcknowledgeAlarmReply(
            correlation_id="corr-7",
            protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
            status=pb.MxStatusProxy(success=1, category=pb.MX_STATUS_CATEGORY_OK),
        ),
    ]
    client = await _connect(stub)

    reply = await client.acknowledge_alarm(
        pb.AcknowledgeAlarmRequest(
            client_correlation_id="corr-7",
            alarm_full_reference="Tank01.Level.HiHi",
            comment="investigating",
            operator_user="alice",
        ),
    )

    assert reply.protocol_status.code == pb.PROTOCOL_STATUS_CODE_OK
    assert reply.status.category == pb.MX_STATUS_CATEGORY_OK
    captured_request = stub.acknowledge_alarm.requests[0]
    assert captured_request.alarm_full_reference == "Tank01.Level.HiHi"
    assert captured_request.comment == "investigating"
    assert captured_request.operator_user == "alice"
    assert stub.acknowledge_alarm.metadata == (("authorization", "Bearer mxgw_test_secret"),)


@pytest.mark.asyncio
async def test_acknowledge_alarm_unauthenticated_raises_typed_error() -> None:
    """A gRPC UNAUTHENTICATED failure surfaces as MxGatewayAuthenticationError."""
    stub = FakeGatewayStub()
    stub.acknowledge_alarm.exception = FakeRpcError(grpc.StatusCode.UNAUTHENTICATED, "expired key")
    client = await _connect(stub)

    with pytest.raises(MxGatewayAuthenticationError):
        await client.acknowledge_alarm(
            pb.AcknowledgeAlarmRequest(
                alarm_full_reference="Tank01.Level.HiHi",
                comment="",
                operator_user="alice",
            ),
        )


@pytest.mark.asyncio
async def test_acknowledge_alarm_permission_denied_raises_typed_error() -> None:
    """A gRPC PERMISSION_DENIED failure surfaces as MxGatewayAuthorizationError."""
    stub = FakeGatewayStub()
    stub.acknowledge_alarm.exception = FakeRpcError(grpc.StatusCode.PERMISSION_DENIED, "missing scope")
    client = await _connect(stub)

    with pytest.raises(MxGatewayAuthorizationError):
        await client.acknowledge_alarm(
            pb.AcknowledgeAlarmRequest(
                alarm_full_reference="Tank01.Level.HiHi",
                comment="",
                operator_user="alice",
            ),
        )


@pytest.mark.asyncio
async def test_query_active_alarms_streams_snapshots() -> None:
    """Snapshots from the server stream are yielded in order; the call carries bearer metadata."""
    snapshots = [
        pb.ActiveAlarmSnapshot(
            alarm_full_reference="Tank01.Level.HiHi",
            current_state=pb.ALARM_CONDITION_STATE_ACTIVE,
            severity=750,
        ),
        pb.ActiveAlarmSnapshot(
            alarm_full_reference="Tank02.Level.HiHi",
            current_state=pb.ALARM_CONDITION_STATE_ACTIVE_ACKED,
            severity=750,
        ),
    ]
    stream = FakeSnapshotStream(snapshots)
    stub = FakeGatewayStub(snapshot_stream=stream)
    client = await _connect(stub)

    received: list[pb.ActiveAlarmSnapshot] = []
    async for snapshot in client.query_active_alarms(
        pb.QueryActiveAlarmsRequest(session_id="session-1"),
    ):
        received.append(snapshot)

    assert len(received) == 2
    assert received[0].alarm_full_reference == "Tank01.Level.HiHi"
    assert received[0].current_state == pb.ALARM_CONDITION_STATE_ACTIVE
    assert received[1].current_state == pb.ALARM_CONDITION_STATE_ACTIVE_ACKED
    assert stub.query_metadata == (("authorization", "Bearer mxgw_test_secret"),)


@pytest.mark.asyncio
async def test_query_active_alarms_passes_filter_prefix() -> None:
    """The request's alarm_filter_prefix reaches the stub unchanged."""
    stream = FakeSnapshotStream([])
    stub = FakeGatewayStub(snapshot_stream=stream)
    client = await _connect(stub)

    iterator = client.query_active_alarms(
        pb.QueryActiveAlarmsRequest(session_id="session-1", alarm_filter_prefix="Tank01."),
    )
    # Drain to trigger the stub call.
    async for _ in iterator:
        pass

    assert stub.query_request is not None
    assert stub.query_request.alarm_filter_prefix == "Tank01."


@pytest.mark.asyncio
async def test_query_active_alarms_cancels_underlying_stream_on_close() -> None:
    """Closing the client's async iterator early cancels the underlying stream call."""
    snapshots = [
        pb.ActiveAlarmSnapshot(
            alarm_full_reference="Tank01.Level.HiHi",
            current_state=pb.ALARM_CONDITION_STATE_ACTIVE,
        ),
    ]
    stream = FakeSnapshotStream(snapshots)
    stub = FakeGatewayStub(snapshot_stream=stream)
    client = await _connect(stub)

    iterator = client.query_active_alarms(pb.QueryActiveAlarmsRequest(session_id="session-1"))
    first = await anext(iterator)
    # The stream has not yet signalled exhaustion, so this close happens mid-stream.
    await iterator.aclose()

    assert first.alarm_full_reference == "Tank01.Level.HiHi"
    assert stream.cancelled


class FakeGatewayStub:
    """In-memory stand-in for the generated gateway stub, passed via ``connect(stub=...)``.

    ``OpenSession`` answers exactly one call with session ``"session-1"`` and an OK
    protocol status; ``AcknowledgeAlarm`` replies are configured per test; and
    ``QueryActiveAlarms`` records its request/metadata and returns ``snapshot_stream``.
    """

    def __init__(self, snapshot_stream: "FakeSnapshotStream | None" = None) -> None:
        self.open_session = FakeUnary(
            [
                pb.OpenSessionReply(
                    session_id="session-1",
                    protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
                ),
            ],
        )
        self.acknowledge_alarm = FakeUnary([])
        # Expose the fakes under the generated stub's PascalCase RPC names too.
        self.OpenSession = self.open_session
        self.AcknowledgeAlarm = self.acknowledge_alarm
        self._snapshot_stream = (
            snapshot_stream if snapshot_stream is not None else FakeSnapshotStream([])
        )
        self.query_request: pb.QueryActiveAlarmsRequest | None = None
        self.query_metadata: tuple[tuple[str, str], ...] | None = None

    def QueryActiveAlarms(
        self,
        request: pb.QueryActiveAlarmsRequest,
        *,
        metadata: tuple[tuple[str, str], ...],
    ) -> "FakeSnapshotStream":
        """Record the request and metadata, then return the configured stream (not awaited)."""
        self.query_request = request
        self.query_metadata = metadata
        return self._snapshot_stream


class FakeUnary:
    """Async callable standing in for a unary-unary stub method.

    Each call records the request and metadata, then raises ``exception`` when it is
    set, otherwise returns the next canned reply from ``replies`` (front first).
    """

    def __init__(self, replies: list[Any]) -> None:
        self.replies = replies  # canned replies, consumed front-to-back
        self.requests: list[Any] = []  # every request received, in call order
        self.metadata: tuple[tuple[str, str], ...] | None = None  # metadata of the latest call
        self.exception: Exception | None = None  # raised instead of replying when set

    async def __call__(
        self,
        request: Any,
        *,
        metadata: tuple[tuple[str, str], ...],
    ) -> Any:
        self.requests.append(request)
        self.metadata = metadata
        if self.exception is not None:
            raise self.exception
        if not self.replies:
            # Explain the failure instead of surfacing "IndexError: pop from empty list".
            raise AssertionError(
                f"FakeUnary received call #{len(self.requests)} but has no canned reply left",
            )
        return self.replies.pop(0)


class FakeSnapshotStream:
    """Async iterator standing in for a server-streaming call; records cancellation."""

    def __init__(self, snapshots: list[pb.ActiveAlarmSnapshot]) -> None:
        # deque gives O(1) removal from the front, unlike list.pop(0).
        self._snapshots: deque[pb.ActiveAlarmSnapshot] = deque(snapshots)
        self.cancelled = False

    def __aiter__(self) -> "FakeSnapshotStream":
        return self

    async def __anext__(self) -> pb.ActiveAlarmSnapshot:
        if not self._snapshots:
            raise StopAsyncIteration
        return self._snapshots.popleft()

    def cancel(self) -> bool:
        """Mark the stream cancelled; return True only on the first call, like grpc.aio."""
        already_cancelled = self.cancelled
        self.cancelled = True
        return not already_cancelled


class FakeRpcError(grpc.RpcError):
    """``grpc.RpcError`` with a fixed status code and details, for error-mapping tests."""

    def __init__(self, code: grpc.StatusCode, details: str) -> None:
        # Populate Exception.args so str(err) and pytest output show the details.
        super().__init__(details)
        self._code = code
        self._details = details

    def code(self) -> grpc.StatusCode:  # noqa: D401
        return self._code

    def details(self) -> str:  # noqa: D401
        return self._details