"""Tests for the async client and session wrappers.""" from __future__ import annotations import asyncio from typing import Any import pytest from zb_mom_ww_mxgateway import ClientOptions, GatewayClient, MxAccessError from zb_mom_ww_mxgateway import client as client_module from zb_mom_ww_mxgateway import galaxy as galaxy_module from zb_mom_ww_mxgateway.galaxy import GalaxyRepositoryClient from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb @pytest.mark.asyncio async def test_gateway_connect_forwards_require_certificate_validation( monkeypatch: pytest.MonkeyPatch, ) -> None: """The connect convenience kwarg must reach ClientOptions (Client.Python-027).""" captured: dict[str, Any] = {} def fake_create_channel(options: ClientOptions) -> object: captured["options"] = options return object() monkeypatch.setattr(client_module, "create_channel", fake_create_channel) monkeypatch.setattr(client_module.pb_grpc, "MxAccessGatewayStub", lambda channel: object()) await GatewayClient.connect( endpoint="gateway.example:5001", require_certificate_validation=True, ) assert captured["options"].require_certificate_validation is True @pytest.mark.asyncio async def test_galaxy_connect_forwards_require_certificate_validation( monkeypatch: pytest.MonkeyPatch, ) -> None: """GalaxyRepositoryClient.connect must thread the flag too (Client.Python-027).""" captured: dict[str, Any] = {} def fake_create_channel(options: ClientOptions) -> object: captured["options"] = options return object() monkeypatch.setattr(galaxy_module, "create_channel", fake_create_channel) monkeypatch.setattr( galaxy_module.galaxy_pb_grpc, "GalaxyRepositoryStub", lambda channel: object() ) await GalaxyRepositoryClient.connect( endpoint="gateway.example:5001", require_certificate_validation=True, ) assert captured["options"].require_certificate_validation is True @pytest.mark.asyncio async def test_gateway_connect_runs_create_channel_off_the_event_loop( monkeypatch: pytest.MonkeyPatch, ) -> None: """connect must run the blocking channel factory off the loop (Client.Python-028).""" ran_in_thread: dict[str, bool] = {} def fake_create_channel(options: ClientOptions) -> object: # If this runs on the event loop thread, get_running_loop() succeeds. try: asyncio.get_running_loop() ran_in_thread["off_loop"] = False except RuntimeError: ran_in_thread["off_loop"] = True return object() monkeypatch.setattr(client_module, "create_channel", fake_create_channel) monkeypatch.setattr(client_module.pb_grpc, "MxAccessGatewayStub", lambda channel: object()) await GatewayClient.connect(endpoint="gateway.example:5001") assert ran_in_thread["off_loop"] is True @pytest.mark.asyncio async def test_galaxy_connect_runs_create_channel_off_the_event_loop( monkeypatch: pytest.MonkeyPatch, ) -> None: """GalaxyRepositoryClient.connect must also run the probe off the loop (Client.Python-028).""" ran_in_thread: dict[str, bool] = {} def fake_create_channel(options: ClientOptions) -> object: try: asyncio.get_running_loop() ran_in_thread["off_loop"] = False except RuntimeError: ran_in_thread["off_loop"] = True return object() monkeypatch.setattr(galaxy_module, "create_channel", fake_create_channel) monkeypatch.setattr( galaxy_module.galaxy_pb_grpc, "GalaxyRepositoryStub", lambda channel: object() ) await GalaxyRepositoryClient.connect(endpoint="gateway.example:5001") assert ran_in_thread["off_loop"] is True @pytest.mark.asyncio async def test_session_helpers_send_auth_metadata_and_preserve_raw_replies() -> None: stub = FakeGatewayStub() client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session(client_session_name="pytest") server_handle = await session.register("pytest-client") item_handle = await session.add_item(server_handle, "Object.Attribute") await session.advise(server_handle, item_handle) assert session.session_id == "session-1" assert server_handle == 12 assert item_handle == 34 assert stub.open_session.metadata == (("authorization", "Bearer mxgw_test_secret"),) assert stub.invoke.requests[0].command.register.client_name == "pytest-client" assert stub.invoke.requests[1].command.add_item.item_definition == "Object.Attribute" assert stub.invoke.requests[2].command.advise.item_handle == 34 @pytest.mark.asyncio async def test_mxaccess_error_preserves_raw_reply() -> None: stub = FakeGatewayStub() failure_reply = pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_WRITE, protocol_status=pb.ProtocolStatus( code=pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, message="MXAccess rejected write.", ), hresult=-1, ) stub.invoke.replies = [failure_reply] client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() with pytest.raises(MxAccessError) as captured: await session.write(12, 34, 123) assert captured.value.raw_reply is failure_reply @pytest.mark.asyncio async def test_subscribe_bulk_sends_one_bulk_command_and_returns_results() -> None: stub = FakeGatewayStub() bulk_reply = pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_SUBSCRIBE_BULK, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), subscribe_bulk=pb.BulkSubscribeReply( results=[ pb.SubscribeResult( server_handle=12, tag_address="Area001.Pump001.Speed", item_handle=34, was_successful=True, ), ], ), ) stub.invoke.replies = [bulk_reply] client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() results = await session.subscribe_bulk(12, ["Area001.Pump001.Speed"]) assert results[0].item_handle == 34 assert len(stub.invoke.requests) == 1 assert stub.invoke.requests[0].command.kind == pb.MX_COMMAND_KIND_SUBSCRIBE_BULK assert list(stub.invoke.requests[0].command.subscribe_bulk.tag_addresses) == [ "Area001.Pump001.Speed", ] @pytest.mark.asyncio async def test_stream_events_cancels_underlying_call_when_closed() -> None: stream = FakeStream( [ pb.MxEvent( session_id="session-1", worker_sequence=1, family=pb.MX_EVENT_FAMILY_ON_DATA_CHANGE, ), ], ) stub = FakeGatewayStub(stream=stream) client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() events = session.stream_events() first = await anext(events) await events.aclose() assert first.worker_sequence == 1 assert stream.cancelled assert stub.stream_metadata == (("authorization", "Bearer mxgw_test_secret"),) @pytest.mark.asyncio async def test_unary_task_cancellation_reaches_fake_call() -> None: blocking = BlockingCancellableUnary() stub = FakeGatewayStub() stub.OpenSession = blocking client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) task = asyncio.create_task(client.open_session()) await blocking.started.wait() task.cancel() with pytest.raises(asyncio.CancelledError): await task assert blocking.call is not None assert blocking.call.cancelled class FakeGatewayStub: def __init__(self, stream: "FakeStream | None" = None) -> None: self.open_session = FakeUnary( [ pb.OpenSessionReply( session_id="session-1", protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ), ], ) self.close_session = FakeUnary( [ pb.CloseSessionReply( session_id="session-1", protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ), ], ) self.invoke = FakeUnary( [ pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_REGISTER, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), register=pb.RegisterReply(server_handle=12), ), pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_ADD_ITEM, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), add_item=pb.AddItemReply(item_handle=34), ), pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_ADVISE, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ), ], ) self.OpenSession = self.open_session self.CloseSession = self.close_session self.Invoke = self.invoke self._stream = stream or FakeStream([]) self.stream_metadata: tuple[tuple[str, str], ...] | None = None def StreamEvents( self, request: pb.StreamEventsRequest, *, metadata: tuple[tuple[str, str], ...], ) -> "FakeStream": self.stream_request = request self.stream_metadata = metadata return self._stream class FakeUnary: def __init__(self, replies: list[Any]) -> None: self.replies = replies self.requests: list[Any] = [] self.metadata: tuple[tuple[str, str], ...] | None = None async def __call__( self, request: Any, *, metadata: tuple[tuple[str, str], ...], ) -> Any: self.requests.append(request) self.metadata = metadata return self.replies.pop(0) class BlockingCancellableUnary: def __init__(self) -> None: self.started = asyncio.Event() self.call: BlockingCall | None = None def __call__(self, *_args: Any, **_kwargs: Any) -> "BlockingCall": self.call = BlockingCall(self.started) return self.call class BlockingCall: def __init__(self, started: asyncio.Event) -> None: self.started = started self.cancelled = False def __await__(self): return self._wait().__await__() async def _wait(self) -> Any: self.started.set() try: await asyncio.Event().wait() except asyncio.CancelledError: raise def cancel(self) -> None: self.cancelled = True class FakeStream: def __init__(self, events: list[pb.MxEvent]) -> None: self._events = events self.cancelled = False def __aiter__(self) -> "FakeStream": return self async def __anext__(self) -> pb.MxEvent: if not self._events: await asyncio.sleep(3600) return self._events.pop(0) def cancel(self) -> None: self.cancelled = True