"""Tests for the async client and session wrappers.""" from __future__ import annotations import asyncio from typing import Any import pytest from mxgateway import ClientOptions, GatewayClient, MxAccessError from mxgateway.generated import mxaccess_gateway_pb2 as pb @pytest.mark.asyncio async def test_session_helpers_send_auth_metadata_and_preserve_raw_replies() -> None: stub = FakeGatewayStub() client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session(client_session_name="pytest") server_handle = await session.register("pytest-client") item_handle = await session.add_item(server_handle, "Object.Attribute") await session.advise(server_handle, item_handle) assert session.session_id == "session-1" assert server_handle == 12 assert item_handle == 34 assert stub.open_session.metadata == (("authorization", "Bearer mxgw_test_secret"),) assert stub.invoke.requests[0].command.register.client_name == "pytest-client" assert stub.invoke.requests[1].command.add_item.item_definition == "Object.Attribute" assert stub.invoke.requests[2].command.advise.item_handle == 34 @pytest.mark.asyncio async def test_mxaccess_error_preserves_raw_reply() -> None: stub = FakeGatewayStub() failure_reply = pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_WRITE, protocol_status=pb.ProtocolStatus( code=pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, message="MXAccess rejected write.", ), hresult=-1, ) stub.invoke.replies = [failure_reply] client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() with pytest.raises(MxAccessError) as captured: await session.write(12, 34, 123) assert captured.value.raw_reply is failure_reply @pytest.mark.asyncio async def test_subscribe_bulk_sends_one_bulk_command_and_returns_results() -> None: stub = FakeGatewayStub() bulk_reply = pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_SUBSCRIBE_BULK, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), subscribe_bulk=pb.BulkSubscribeReply( results=[ pb.SubscribeResult( server_handle=12, tag_address="Area001.Pump001.Speed", item_handle=34, was_successful=True, ), ], ), ) stub.invoke.replies = [bulk_reply] client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() results = await session.subscribe_bulk(12, ["Area001.Pump001.Speed"]) assert results[0].item_handle == 34 assert len(stub.invoke.requests) == 1 assert stub.invoke.requests[0].command.kind == pb.MX_COMMAND_KIND_SUBSCRIBE_BULK assert list(stub.invoke.requests[0].command.subscribe_bulk.tag_addresses) == [ "Area001.Pump001.Speed", ] @pytest.mark.asyncio async def test_write_bulk_sends_one_bulk_command_and_returns_per_entry_results() -> None: stub = FakeGatewayStub() bulk_reply = pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_WRITE_BULK, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), write_bulk=pb.BulkWriteReply( results=[ pb.BulkWriteResult(server_handle=12, item_handle=901, was_successful=True), pb.BulkWriteResult(server_handle=12, item_handle=902, was_successful=False, error_message="invalid handle"), ], ), ) stub.invoke.replies = [bulk_reply] client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() entries = [ pb.WriteBulkEntry(item_handle=901, user_id=5, value=pb.MxValue(data_type=pb.MX_DATA_TYPE_INTEGER, int32_value=11)), pb.WriteBulkEntry(item_handle=902, user_id=5, value=pb.MxValue(data_type=pb.MX_DATA_TYPE_INTEGER, int32_value=22)), ] results = await session.write_bulk(12, entries) assert len(results) == 2 assert results[0].was_successful is True assert results[1].was_successful is False sent = stub.invoke.requests[0].command assert sent.kind == pb.MX_COMMAND_KIND_WRITE_BULK assert len(sent.write_bulk.entries) == 2 @pytest.mark.asyncio async def test_read_bulk_forwards_timeout_and_unpacks_cached_flag() -> None: stub = FakeGatewayStub() bulk_reply = pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_READ_BULK, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), read_bulk=pb.BulkReadReply( results=[ pb.BulkReadResult( server_handle=12, tag_address="Area001.Pump001.Speed", item_handle=34, was_successful=True, was_cached=True, value=pb.MxValue(data_type=pb.MX_DATA_TYPE_INTEGER, int32_value=99), ), ], ), ) stub.invoke.replies = [bulk_reply] client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() results = await session.read_bulk(12, ["Area001.Pump001.Speed"], timeout_ms=750) assert len(results) == 1 assert results[0].was_cached is True assert results[0].value.int32_value == 99 sent = stub.invoke.requests[0].command assert sent.kind == pb.MX_COMMAND_KIND_READ_BULK assert list(sent.read_bulk.tag_addresses) == ["Area001.Pump001.Speed"] assert sent.read_bulk.timeout_ms == 750 @pytest.mark.asyncio async def test_stream_events_cancels_underlying_call_when_closed() -> None: stream = FakeStream( [ pb.MxEvent( session_id="session-1", worker_sequence=1, family=pb.MX_EVENT_FAMILY_ON_DATA_CHANGE, ), ], ) stub = FakeGatewayStub(stream=stream) client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) session = await client.open_session() events = session.stream_events() first = await anext(events) await events.aclose() assert first.worker_sequence == 1 assert stream.cancelled assert stub.stream_metadata == (("authorization", "Bearer mxgw_test_secret"),) @pytest.mark.asyncio async def test_unary_task_cancellation_reaches_fake_call() -> None: blocking = BlockingCancellableUnary() stub = FakeGatewayStub() stub.OpenSession = blocking client = await GatewayClient.connect( ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), stub=stub, ) task = asyncio.create_task(client.open_session()) await blocking.started.wait() task.cancel() with pytest.raises(asyncio.CancelledError): await task assert blocking.call is not None assert blocking.call.cancelled class FakeGatewayStub: def __init__(self, stream: "FakeStream | None" = None) -> None: self.open_session = FakeUnary( [ pb.OpenSessionReply( session_id="session-1", protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ), ], ) self.close_session = FakeUnary( [ pb.CloseSessionReply( session_id="session-1", protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ), ], ) self.invoke = FakeUnary( [ pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_REGISTER, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), register=pb.RegisterReply(server_handle=12), ), pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_ADD_ITEM, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), add_item=pb.AddItemReply(item_handle=34), ), pb.MxCommandReply( session_id="session-1", kind=pb.MX_COMMAND_KIND_ADVISE, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ), ], ) self.OpenSession = self.open_session self.CloseSession = self.close_session self.Invoke = self.invoke self._stream = stream or FakeStream([]) self.stream_metadata: tuple[tuple[str, str], ...] | None = None def StreamEvents( self, request: pb.StreamEventsRequest, *, metadata: tuple[tuple[str, str], ...], ) -> "FakeStream": self.stream_request = request self.stream_metadata = metadata return self._stream class FakeUnary: def __init__(self, replies: list[Any]) -> None: self.replies = replies self.requests: list[Any] = [] self.metadata: tuple[tuple[str, str], ...] | None = None async def __call__( self, request: Any, *, metadata: tuple[tuple[str, str], ...], ) -> Any: self.requests.append(request) self.metadata = metadata return self.replies.pop(0) class BlockingCancellableUnary: def __init__(self) -> None: self.started = asyncio.Event() self.call: BlockingCall | None = None def __call__(self, *_args: Any, **_kwargs: Any) -> "BlockingCall": self.call = BlockingCall(self.started) return self.call class BlockingCall: def __init__(self, started: asyncio.Event) -> None: self.started = started self.cancelled = False def __await__(self): return self._wait().__await__() async def _wait(self) -> Any: self.started.set() try: await asyncio.Event().wait() except asyncio.CancelledError: raise def cancel(self) -> None: self.cancelled = True class FakeStream: def __init__(self, events: list[pb.MxEvent]) -> None: self._events = events self.cancelled = False def __aiter__(self) -> "FakeStream": return self async def __anext__(self) -> pb.MxEvent: if not self._events: await asyncio.sleep(3600) return self._events.pop(0) def cancel(self) -> None: self.cancelled = True