"""Async session wrapper for MXAccess Gateway commands.""" from __future__ import annotations from collections.abc import AsyncIterator, Sequence from typing import TYPE_CHECKING from .errors import ensure_mxaccess_success from .generated import mxaccess_gateway_pb2 as pb from .values import MxValueInput, to_mx_value if TYPE_CHECKING: from .client import GatewayClient MAX_BULK_ITEMS = 1000 class Session: """A single gateway-backed MXAccess session.""" def __init__( self, *, client: "GatewayClient", session_id: str, open_reply: pb.OpenSessionReply | None = None, ) -> None: """Initialize a session bound to a client and gateway session id.""" self.client = client self.session_id = session_id self.open_reply = open_reply self._closed = False async def __aenter__(self) -> "Session": """Return self to support ``async with`` usage.""" return self async def __aexit__(self, *_exc_info: object) -> None: """Close the session when leaving an ``async with`` block.""" await self.close() async def close(self, *, client_correlation_id: str = "") -> pb.CloseSessionReply: """Close the gateway session. Repeated calls return a local closed reply. Idempotent, including under concurrent calls: ``_closed`` is set before the ``CloseSession`` RPC is awaited so a second coroutine entering ``close()`` while the first RPC is in flight returns the local closed reply instead of issuing a second ``CloseSession``. """ if self._closed: return pb.CloseSessionReply( session_id=self.session_id, final_state=pb.SESSION_STATE_CLOSED, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ) self._closed = True return await self.client.close_session_raw( pb.CloseSessionRequest( session_id=self.session_id, client_correlation_id=client_correlation_id, ), ) async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply: """Invoke a raw command and enforce gateway and MXAccess success.""" reply = await self.invoke_raw(command, correlation_id=correlation_id) return ensure_mxaccess_success("invoke", reply) async def invoke_raw( self, command: pb.MxCommand, *, correlation_id: str = "", ) -> pb.MxCommandReply: """Invoke a raw command and preserve the raw reply. Enforces gateway protocol success only — unlike :meth:`invoke`, it does not run MXAccess-failure detection. An MXAccess HRESULT or ``MxStatusProxy`` status failure is left embedded in the returned reply and no ``MxAccessError`` is raised. Parity-test callers must inspect ``protocol_status``, ``hresult``, and ``statuses`` on the reply themselves. """ return await self.client.invoke_raw( pb.MxCommandRequest( session_id=self.session_id, client_correlation_id=correlation_id, command=command, ), ) async def register(self, client_name: str, *, correlation_id: str = "") -> int: """Invoke MXAccess `Register` and return the new `ServerHandle`.""" reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_REGISTER, register=pb.RegisterCommand(client_name=client_name), ), correlation_id=correlation_id, ) return reply.register.server_handle async def unregister(self, server_handle: int, *, correlation_id: str = "") -> None: """Invoke MXAccess `Unregister` for a previously registered `ServerHandle`.""" await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_UNREGISTER, unregister=pb.UnregisterCommand(server_handle=server_handle), ), correlation_id=correlation_id, ) async def remove_item( self, server_handle: int, item_handle: int, *, correlation_id: str = "", ) -> None: """Invoke MXAccess `RemoveItem` for the given `ItemHandle`.""" await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_REMOVE_ITEM, remove_item=pb.RemoveItemCommand( server_handle=server_handle, item_handle=item_handle, ), ), correlation_id=correlation_id, ) async def add_item( self, server_handle: int, item_definition: str, *, correlation_id: str = "", ) -> int: """Invoke MXAccess `AddItem` and return the new `ItemHandle`.""" reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADD_ITEM, add_item=pb.AddItemCommand( server_handle=server_handle, item_definition=item_definition, ), ), correlation_id=correlation_id, ) return reply.add_item.item_handle async def add_item2( self, server_handle: int, item_definition: str, item_context: str, *, correlation_id: str = "", ) -> int: """Invoke MXAccess `AddItem2` with item context and return the new `ItemHandle`.""" reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADD_ITEM2, add_item2=pb.AddItem2Command( server_handle=server_handle, item_definition=item_definition, item_context=item_context, ), ), correlation_id=correlation_id, ) return reply.add_item2.item_handle async def advise( self, server_handle: int, item_handle: int, *, correlation_id: str = "", ) -> None: """Invoke MXAccess `Advise` to subscribe an existing `ItemHandle` to events.""" await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADVISE, advise=pb.AdviseCommand( server_handle=server_handle, item_handle=item_handle, ), ), correlation_id=correlation_id, ) async def unadvise( self, server_handle: int, item_handle: int, *, correlation_id: str = "", ) -> None: """Invoke MXAccess `UnAdvise` to stop event delivery for an `ItemHandle`.""" await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_UN_ADVISE, un_advise=pb.UnAdviseCommand( server_handle=server_handle, item_handle=item_handle, ), ), correlation_id=correlation_id, ) async def add_item_bulk( self, server_handle: int, tag_addresses: Sequence[str], *, correlation_id: str = "", ) -> list[pb.SubscribeResult]: """Invoke MXAccess `AddItemBulk` and return one result per tag address.""" if tag_addresses is None: raise TypeError("tag_addresses is required") _ensure_bulk_size("tag_addresses", len(tag_addresses)) reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADD_ITEM_BULK, add_item_bulk=pb.AddItemBulkCommand( server_handle=server_handle, tag_addresses=tag_addresses, ), ), correlation_id=correlation_id, ) return list(reply.add_item_bulk.results) async def advise_item_bulk( self, server_handle: int, item_handles: Sequence[int], *, correlation_id: str = "", ) -> list[pb.SubscribeResult]: """Invoke MXAccess `AdviseItemBulk` and return one result per item handle.""" if item_handles is None: raise TypeError("item_handles is required") _ensure_bulk_size("item_handles", len(item_handles)) reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADVISE_ITEM_BULK, advise_item_bulk=pb.AdviseItemBulkCommand( server_handle=server_handle, item_handles=item_handles, ), ), correlation_id=correlation_id, ) return list(reply.advise_item_bulk.results) async def remove_item_bulk( self, server_handle: int, item_handles: Sequence[int], *, correlation_id: str = "", ) -> list[pb.SubscribeResult]: """Invoke MXAccess `RemoveItemBulk` and return one result per item handle.""" if item_handles is None: raise TypeError("item_handles is required") _ensure_bulk_size("item_handles", len(item_handles)) reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_REMOVE_ITEM_BULK, remove_item_bulk=pb.RemoveItemBulkCommand( server_handle=server_handle, item_handles=item_handles, ), ), correlation_id=correlation_id, ) return list(reply.remove_item_bulk.results) async def unadvise_item_bulk( self, server_handle: int, item_handles: Sequence[int], *, correlation_id: str = "", ) -> list[pb.SubscribeResult]: """Invoke MXAccess `UnAdviseItemBulk` and return one result per item handle.""" if item_handles is None: raise TypeError("item_handles is required") _ensure_bulk_size("item_handles", len(item_handles)) reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_UN_ADVISE_ITEM_BULK, un_advise_item_bulk=pb.UnAdviseItemBulkCommand( server_handle=server_handle, item_handles=item_handles, ), ), correlation_id=correlation_id, ) return list(reply.un_advise_item_bulk.results) async def subscribe_bulk( self, server_handle: int, tag_addresses: Sequence[str], *, correlation_id: str = "", ) -> list[pb.SubscribeResult]: """Invoke MXAccess `SubscribeBulk` and return one result per tag address.""" if tag_addresses is None: raise TypeError("tag_addresses is required") _ensure_bulk_size("tag_addresses", len(tag_addresses)) reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_SUBSCRIBE_BULK, subscribe_bulk=pb.SubscribeBulkCommand( server_handle=server_handle, tag_addresses=tag_addresses, ), ), correlation_id=correlation_id, ) return list(reply.subscribe_bulk.results) async def unsubscribe_bulk( self, server_handle: int, item_handles: Sequence[int], *, correlation_id: str = "", ) -> list[pb.SubscribeResult]: """Invoke MXAccess `UnsubscribeBulk` and return one result per item handle.""" if item_handles is None: raise TypeError("item_handles is required") _ensure_bulk_size("item_handles", len(item_handles)) reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_UNSUBSCRIBE_BULK, unsubscribe_bulk=pb.UnsubscribeBulkCommand( server_handle=server_handle, item_handles=item_handles, ), ), correlation_id=correlation_id, ) return list(reply.unsubscribe_bulk.results) async def write( self, server_handle: int, item_handle: int, value: MxValueInput, *, user_id: int = 0, correlation_id: str = "", ) -> None: """Invoke MXAccess `Write` for an `ItemHandle` with the converted value.""" await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_WRITE, write=pb.WriteCommand( server_handle=server_handle, item_handle=item_handle, value=to_mx_value(value), user_id=user_id, ), ), correlation_id=correlation_id, ) async def write2( self, server_handle: int, item_handle: int, value: MxValueInput, timestamp_value: MxValueInput, *, user_id: int = 0, correlation_id: str = "", ) -> None: """Invoke MXAccess `Write2` with both a value and a client-supplied timestamp.""" await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_WRITE2, write2=pb.Write2Command( server_handle=server_handle, item_handle=item_handle, value=to_mx_value(value), timestamp_value=to_mx_value(timestamp_value), user_id=user_id, ), ), correlation_id=correlation_id, ) def stream_events( self, *, after_worker_sequence: int = 0, ) -> AsyncIterator[pb.MxEvent]: """Return an async iterator of `MxEvent` messages for this session.""" return self.client.stream_events_raw( pb.StreamEventsRequest( session_id=self.session_id, after_worker_sequence=after_worker_sequence, ), ) def _ensure_bulk_size(name: str, count: int) -> None: if count > MAX_BULK_ITEMS: raise ValueError(f"{name} bulk commands are limited to {MAX_BULK_ITEMS} item(s)")