"""Async session wrapper for MXAccess Gateway commands.""" from __future__ import annotations from collections.abc import AsyncIterator from .errors import ensure_mxaccess_success from .generated import mxaccess_gateway_pb2 as pb from .values import MxValueInput, to_mx_value class Session: """A single gateway-backed MXAccess session.""" def __init__( self, *, client: "GatewayClient", session_id: str, open_reply: pb.OpenSessionReply | None = None, ) -> None: self.client = client self.session_id = session_id self.open_reply = open_reply self._closed = False async def __aenter__(self) -> "Session": return self async def __aexit__(self, *_exc_info: object) -> None: await self.close() async def close(self, *, client_correlation_id: str = "") -> pb.CloseSessionReply: """Close the gateway session. Repeated calls return a local closed reply.""" if self._closed: return pb.CloseSessionReply( session_id=self.session_id, final_state=pb.SESSION_STATE_CLOSED, protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), ) self._closed = True return await self.client.close_session_raw( pb.CloseSessionRequest( session_id=self.session_id, client_correlation_id=client_correlation_id, ), ) async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply: """Invoke a raw command and enforce gateway and MXAccess success.""" reply = await self.invoke_raw(command, correlation_id=correlation_id) return ensure_mxaccess_success("invoke", reply) async def invoke_raw( self, command: pb.MxCommand, *, correlation_id: str = "", ) -> pb.MxCommandReply: """Invoke a raw command and preserve the raw reply.""" return await self.client.invoke_raw( pb.MxCommandRequest( session_id=self.session_id, client_correlation_id=correlation_id, command=command, ), ) async def register(self, client_name: str, *, correlation_id: str = "") -> int: reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_REGISTER, register=pb.RegisterCommand(client_name=client_name), ), correlation_id=correlation_id, ) return reply.register.server_handle async def unregister(self, server_handle: int, *, correlation_id: str = "") -> None: await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_UNREGISTER, unregister=pb.UnregisterCommand(server_handle=server_handle), ), correlation_id=correlation_id, ) async def add_item( self, server_handle: int, item_definition: str, *, correlation_id: str = "", ) -> int: reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADD_ITEM, add_item=pb.AddItemCommand( server_handle=server_handle, item_definition=item_definition, ), ), correlation_id=correlation_id, ) return reply.add_item.item_handle async def add_item2( self, server_handle: int, item_definition: str, item_context: str, *, correlation_id: str = "", ) -> int: reply = await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADD_ITEM2, add_item2=pb.AddItem2Command( server_handle=server_handle, item_definition=item_definition, item_context=item_context, ), ), correlation_id=correlation_id, ) return reply.add_item2.item_handle async def advise( self, server_handle: int, item_handle: int, *, correlation_id: str = "", ) -> None: await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_ADVISE, advise=pb.AdviseCommand( server_handle=server_handle, item_handle=item_handle, ), ), correlation_id=correlation_id, ) async def write( self, server_handle: int, item_handle: int, value: MxValueInput, *, user_id: int = 0, correlation_id: str = "", ) -> None: await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_WRITE, write=pb.WriteCommand( server_handle=server_handle, item_handle=item_handle, value=to_mx_value(value), user_id=user_id, ), ), correlation_id=correlation_id, ) async def write2( self, server_handle: int, item_handle: int, value: MxValueInput, timestamp_value: MxValueInput, *, user_id: int = 0, correlation_id: str = "", ) -> None: await self.invoke( pb.MxCommand( kind=pb.MX_COMMAND_KIND_WRITE2, write2=pb.Write2Command( server_handle=server_handle, item_handle=item_handle, value=to_mx_value(value), timestamp_value=to_mx_value(timestamp_value), user_id=user_id, ), ), correlation_id=correlation_id, ) def stream_events( self, *, after_worker_sequence: int = 0, ) -> AsyncIterator[pb.MxEvent]: return self.client.stream_events_raw( pb.StreamEventsRequest( session_id=self.session_id, after_worker_sequence=after_worker_sequence, ), ) from .client import GatewayClient # noqa: E402