226 lines
6.9 KiB
Python
226 lines
6.9 KiB
Python
"""Tests for the async client and session wrappers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from mxgateway import ClientOptions, GatewayClient, MxAccessError
|
|
from mxgateway.generated import mxaccess_gateway_pb2 as pb
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_helpers_send_auth_metadata_and_preserve_raw_replies() -> None:
    """Session helpers unwrap handles and OpenSession carries the bearer token.

    Drives open_session -> register -> add_item -> advise against the default
    FakeGatewayStub script, then checks both the values the helpers returned and
    the requests/metadata the fake recorded.
    """
    fake = FakeGatewayStub()
    client = await GatewayClient.connect(
        ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
        stub=fake,
    )

    session = await client.open_session(client_session_name="pytest")
    server = await session.register("pytest-client")
    item = await session.add_item(server, "Object.Attribute")
    await session.advise(server, item)

    # Values decoded from the scripted replies.
    assert session.session_id == "session-1"
    assert server == 12
    assert item == 34

    # What the client actually handed to the fake stub.
    assert fake.open_session.metadata == (("authorization", "Bearer mxgw_test_secret"),)
    sent = fake.invoke.requests
    assert sent[0].command.register.client_name == "pytest-client"
    assert sent[1].command.add_item.item_definition == "Object.Attribute"
    assert sent[2].command.advise.item_handle == 34
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mxaccess_error_preserves_raw_reply() -> None:
    """A rejected write surfaces as MxAccessError holding the very same reply object."""
    fake = FakeGatewayStub()
    rejection = pb.MxCommandReply(
        session_id="session-1",
        kind=pb.MX_COMMAND_KIND_WRITE,
        protocol_status=pb.ProtocolStatus(
            code=pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE,
            message="MXAccess rejected write.",
        ),
        hresult=-1,
    )
    # Swap out the default register/add_item/advise script so Invoke answers
    # with the failure instead.
    fake.invoke.replies = [rejection]

    client = await GatewayClient.connect(
        ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
        stub=fake,
    )
    session = await client.open_session()

    with pytest.raises(MxAccessError) as excinfo:
        await session.write(12, 34, 123)

    # Identity, not equality: the error must carry the reply object itself.
    assert excinfo.value.raw_reply is rejection
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_events_cancels_underlying_call_when_closed() -> None:
    """Closing the stream_events() iterator cancels the fake stream call.

    Also checks that the StreamEvents call was made with the bearer token.
    """
    data_change = pb.MxEvent(
        session_id="session-1",
        worker_sequence=1,
        family=pb.MX_EVENT_FAMILY_ON_DATA_CHANGE,
    )
    fake_stream = FakeStream([data_change])
    fake = FakeGatewayStub(stream=fake_stream)
    client = await GatewayClient.connect(
        ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
        stub=fake,
    )
    session = await client.open_session()

    iterator = session.stream_events()
    received = await anext(iterator)
    await iterator.aclose()

    assert received.worker_sequence == 1
    assert fake_stream.cancelled
    assert fake.stream_metadata == (("authorization", "Bearer mxgw_test_secret"),)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unary_task_cancellation_reaches_fake_call() -> None:
    """Cancelling a pending open_session() task cancels the in-flight fake call."""
    hanging_rpc = BlockingCancellableUnary()
    fake = FakeGatewayStub()
    # Replace only the PascalCase RPC entry point; the OpenSession call never completes.
    fake.OpenSession = hanging_rpc
    client = await GatewayClient.connect(
        ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
        stub=fake,
    )

    pending = asyncio.create_task(client.open_session())
    # Make sure the call is actually being awaited before cancelling the task.
    await hanging_rpc.started.wait()
    pending.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending

    in_flight = hanging_rpc.call
    assert in_flight is not None
    assert in_flight.cancelled
|
|
|
|
|
|
class FakeGatewayStub:
    """In-memory stand-in for the gateway stub used by the tests in this module.

    Exposes the RPC entry points ``OpenSession``, ``CloseSession``, ``Invoke``
    and ``StreamEvents``, and records what was sent so tests can assert on it.
    The unary RPCs are also reachable through snake_case aliases
    (``open_session``, ``close_session``, ``invoke``) that point at the very
    same ``FakeUnary`` objects.

    Default script:
      * OpenSession / CloseSession: one OK reply each for ``session-1``.
      * Invoke: three OK replies consumed in order: register (server handle
        12), add_item (item handle 34), then advise.
    """

    def __init__(self, stream: "FakeStream | None" = None) -> None:
        def ok() -> pb.ProtocolStatus:
            # A fresh OK status message for each scripted reply.
            return pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK)

        self.open_session = FakeUnary(
            [pb.OpenSessionReply(session_id="session-1", protocol_status=ok())],
        )
        self.close_session = FakeUnary(
            [pb.CloseSessionReply(session_id="session-1", protocol_status=ok())],
        )
        self.invoke = FakeUnary(
            [
                pb.MxCommandReply(
                    session_id="session-1",
                    kind=pb.MX_COMMAND_KIND_REGISTER,
                    protocol_status=ok(),
                    register=pb.RegisterReply(server_handle=12),
                ),
                pb.MxCommandReply(
                    session_id="session-1",
                    kind=pb.MX_COMMAND_KIND_ADD_ITEM,
                    protocol_status=ok(),
                    add_item=pb.AddItemReply(item_handle=34),
                ),
                pb.MxCommandReply(
                    session_id="session-1",
                    kind=pb.MX_COMMAND_KIND_ADVISE,
                    protocol_status=ok(),
                ),
            ],
        )

        # PascalCase aliases are the RPC names the client invokes; tests may
        # rebind them individually (e.g. OpenSession) without touching the
        # snake_case recorders.
        self.OpenSession = self.open_session
        self.CloseSession = self.close_session
        self.Invoke = self.invoke

        self._stream = stream if stream is not None else FakeStream([])
        # Both are populated by StreamEvents; None until it has been called.
        self.stream_request: pb.StreamEventsRequest | None = None
        self.stream_metadata: tuple[tuple[str, str], ...] | None = None

    def StreamEvents(
        self,
        request: pb.StreamEventsRequest,
        *,
        metadata: tuple[tuple[str, str], ...],
    ) -> "FakeStream":
        """Record the request and metadata, then return the configured stream.

        The same ``FakeStream`` instance is returned on every call.
        """
        self.stream_request = request
        self.stream_metadata = metadata
        return self._stream
|
|
|
|
|
|
class FakeUnary:
    """Scripted fake for a unary-unary RPC.

    Each call records its request and metadata, then returns the next reply from
    ``replies`` in FIFO order. ``requests`` accumulates every request ever sent;
    ``metadata`` holds only the metadata of the most recent call.
    """

    def __init__(self, replies: list[Any]) -> None:
        # Consumed front-to-back; tests may reassign this list to change the script.
        self.replies = replies
        self.requests: list[Any] = []
        self.metadata: tuple[tuple[str, str], ...] | None = None

    async def __call__(
        self,
        request: Any,
        *,
        metadata: tuple[tuple[str, str], ...],
    ) -> Any:
        """Record ``request``/``metadata`` and return the next scripted reply.

        Raises:
            IndexError: no scripted replies remain. The message names the call
                number so an unexpected extra RPC is easy to spot in test output.
                The request and metadata are still recorded in that case.
        """
        self.requests.append(request)
        self.metadata = metadata
        if not self.replies:
            raise IndexError(
                f"FakeUnary has no scripted reply left for call #{len(self.requests)}"
            )
        return self.replies.pop(0)
|
|
|
|
|
|
class BlockingCancellableUnary:
    """Replacement for a unary stub entry point whose calls never finish on their own.

    Every invocation ignores its arguments and creates a fresh ``BlockingCall``
    sharing this object's ``started`` event. It stores that call as ``call``,
    replacing any earlier one, and hands it back un-awaited.
    """

    def __init__(self) -> None:
        self.call: BlockingCall | None = None
        # Shared with every BlockingCall created here; set once one is awaited.
        self.started = asyncio.Event()

    def __call__(self, *_args: Any, **_kwargs: Any) -> "BlockingCall":
        latest = BlockingCall(self.started)
        self.call = latest
        return latest
|
|
|
|
|
|
class BlockingCall:
    """Awaitable fake of an in-flight unary call that never completes by itself.

    Awaiting it sets ``started`` and then suspends until the awaiting task is
    cancelled. ``cancel()`` only records that cancellation was requested on the
    call; it does not wake the awaiting task.
    """

    def __init__(self, started: asyncio.Event) -> None:
        self.started = started
        self.cancelled = False  # flipped by cancel()

    def __await__(self):
        return self._wait().__await__()

    async def _wait(self) -> Any:
        self.started.set()
        # Wait on an Event nobody sets: the only way out is CancelledError,
        # which propagates to the awaiting task unchanged.
        await asyncio.Event().wait()

    def cancel(self) -> None:
        self.cancelled = True
|
|
|
|
|
|
class FakeStream:
    """Fake server-streaming call that yields scripted events, then idles.

    Once the scripted events are used up, ``__anext__`` does not end the
    iteration; it blocks until the awaiting task is cancelled, like an open
    stream with no new traffic. ``cancel()`` records that cancellation was
    requested on the call.
    """

    def __init__(self, events: list[pb.MxEvent]) -> None:
        self._events = events  # consumed front-to-back by __anext__
        self.cancelled = False

    def __aiter__(self) -> "FakeStream":
        return self

    async def __anext__(self) -> pb.MxEvent:
        if not self._events:
            # Wait on an Event nobody sets, so the only exit is cancellation.
            # A timed sleep would eventually fall through to pop() on an empty
            # list and raise IndexError.
            await asyncio.Event().wait()
        return self._events.pop(0)

    def cancel(self) -> None:
        self.cancelled = True
|