From b57662aae7dc97f55b36ccc34befdfd3cd87a83b Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 20:46:18 -0400 Subject: [PATCH] Issue #46: implement Python async client values errors and CLI --- clients/python/README.md | 62 ++- clients/python/src/mxgateway/__init__.py | 35 +- clients/python/src/mxgateway/auth.py | 58 +++ clients/python/src/mxgateway/client.py | 165 +++++++ clients/python/src/mxgateway/errors.py | 157 +++++++ clients/python/src/mxgateway/options.py | 59 +++ clients/python/src/mxgateway/session.py | 209 +++++++++ clients/python/src/mxgateway/values.py | 234 ++++++++++ clients/python/src/mxgateway_cli/commands.py | 439 ++++++++++++++++++- clients/python/tests/test_auth_options.py | 103 +++++ clients/python/tests/test_cli.py | 49 ++- clients/python/tests/test_client_session.py | 225 ++++++++++ clients/python/tests/test_errors.py | 49 +++ clients/python/tests/test_values.py | 49 +++ 14 files changed, 1883 insertions(+), 10 deletions(-) create mode 100644 clients/python/src/mxgateway/auth.py create mode 100644 clients/python/src/mxgateway/client.py create mode 100644 clients/python/src/mxgateway/errors.py create mode 100644 clients/python/src/mxgateway/options.py create mode 100644 clients/python/src/mxgateway/session.py create mode 100644 clients/python/src/mxgateway/values.py create mode 100644 clients/python/tests/test_auth_options.py create mode 100644 clients/python/tests/test_client_session.py create mode 100644 clients/python/tests/test_errors.py create mode 100644 clients/python/tests/test_values.py diff --git a/clients/python/README.md b/clients/python/README.md index ba8cd5e..2169299 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -1,8 +1,8 @@ # Python Client The Python client package contains generated MXAccess Gateway protobuf -bindings, the `mxgateway` package scaffold, and the `mxgw-py` test CLI -scaffold. The package uses the shared proto inputs documented in +bindings, the async `mxgateway` package, and the `mxgw-py` test CLI. The +package uses the shared proto inputs documented in `../../docs/client-proto-generation.md` so gateway and client contracts stay in sync. @@ -43,15 +43,65 @@ python -m pytest python -m pip wheel . --no-deps --wheel-dir "$env:TEMP\mxgateway-python-wheel" ``` -The scaffold tests import the generated gateway and worker stubs and exercise -the deterministic CLI version output. +The tests import the generated gateway and worker stubs, run fake async gateway +stubs, verify API key metadata, exercise stream cancellation, load shared value +and command fixtures, and check deterministic CLI output. + +## Library Usage + +The library is async-first: + +```python +from mxgateway import GatewayClient + +async with await GatewayClient.connect( + endpoint="localhost:5000", + api_key="mxgw_example", + plaintext=True, +) as client: + session = await client.open_session(client_session_name="python-client") + try: + server_handle = await session.register("python-client") + item_handle = await session.add_item(server_handle, "Object.Attribute") + await session.advise(server_handle, item_handle) + finally: + await session.close() +``` + +`GatewayClient.open_session_raw`, `GatewayClient.invoke_raw`, and +`GatewayClient.stream_events_raw` keep the generated protobuf replies and +events available for parity tests. `Session` helpers call the method-specific +MXAccess commands and preserve raw replies on typed command exceptions. + +Canceling a Python task cancels the client-side gRPC call or stream wait. It +does not abort an in-flight MXAccess COM call inside the worker process. + +## Authentication And TLS + +`ClientOptions.api_key` adds this metadata to unary calls and streams: + +```text +authorization: Bearer +``` + +The client supports plaintext channels for local development, TLS with system +roots, TLS with a custom `ca_file`, and an optional test server name override. +API keys are redacted from option repr output and CLI error output. ## CLI -The scaffold CLI exposes version information: +The CLI emits deterministic JSON for automation: ```powershell mxgw-py version --json +mxgw-py open-session --endpoint localhost:5000 --plaintext --json +mxgw-py register --session-id --client-name python-client --json +mxgw-py add-item --session-id --server-handle 1 --item Object.Attribute --json +mxgw-py advise --session-id --server-handle 1 --item-handle 2 --json +mxgw-py stream-events --session-id --max-events 1 --json +mxgw-py write --session-id --server-handle 1 --item-handle 2 --type int32 --value 123 --json ``` -Additional commands are implemented with the async client/session wrapper work. +Use `--api-key` or `--api-key-env MXGATEWAY_API_KEY` to attach API key +metadata. `smoke` opens a session, registers, adds an item, advises, streams a +bounded event count, and closes the session in a `finally` block. diff --git a/clients/python/src/mxgateway/__init__.py b/clients/python/src/mxgateway/__init__.py index 44c6096..7e248b1 100644 --- a/clients/python/src/mxgateway/__init__.py +++ b/clients/python/src/mxgateway/__init__.py @@ -1,5 +1,38 @@ """MXAccess Gateway Python client package.""" +from .auth import ApiKey, auth_metadata +from .client import GatewayClient +from .errors import ( + MxAccessError, + MxGatewayAuthenticationError, + MxGatewayAuthorizationError, + MxGatewayCommandError, + MxGatewayError, + MxGatewaySessionError, + MxGatewayTransportError, + MxGatewayWorkerError, +) +from .options import ClientOptions +from .session import Session +from .values import MxValueView, from_mx_value, to_mx_value from .version import __version__ -__all__ = ["__version__"] +__all__ = [ + "ApiKey", + "ClientOptions", + "GatewayClient", + "MxAccessError", + "MxGatewayAuthenticationError", + "MxGatewayAuthorizationError", + "MxGatewayCommandError", + "MxGatewayError", + "MxGatewaySessionError", + "MxGatewayTransportError", + "MxGatewayWorkerError", + "MxValueView", + "Session", + "__version__", + "auth_metadata", + "from_mx_value", + "to_mx_value", +] diff --git a/clients/python/src/mxgateway/auth.py b/clients/python/src/mxgateway/auth.py new file mode 100644 index 0000000..3c5d041 --- /dev/null +++ b/clients/python/src/mxgateway/auth.py @@ -0,0 +1,58 @@ +"""Authentication metadata helpers for MXAccess Gateway clients.""" + +from collections.abc import Sequence +from dataclasses import dataclass + +AUTHORIZATION_HEADER = "authorization" +REDACTED = "[redacted]" + + +@dataclass(frozen=True) +class ApiKey: + """API key wrapper that avoids leaking the secret through repr output.""" + + value: str + + def __post_init__(self) -> None: + if not self.value: + raise ValueError("api_key must not be empty") + + def __repr__(self) -> str: + return f"{type(self).__name__}({REDACTED!r})" + + def bearer_value(self) -> str: + return f"Bearer {self.value}" + + +def auth_metadata(api_key: str | ApiKey | None) -> tuple[tuple[str, str], ...]: + """Return gRPC metadata for API key auth.""" + + if api_key is None: + return () + + key = api_key.value if isinstance(api_key, ApiKey) else api_key + if not key: + return () + + return ((AUTHORIZATION_HEADER, f"Bearer {key}"),) + + +def merge_metadata( + api_key: str | ApiKey | None, + metadata: Sequence[tuple[str, str]] | None = None, +) -> tuple[tuple[str, str], ...]: + """Merge caller metadata with API key metadata.""" + + merged = list(metadata or ()) + merged.extend(auth_metadata(api_key)) + return tuple(merged) + + +def redact_secret(text: str, secrets: Sequence[str | None]) -> str: + """Replace known secret values with a stable redaction marker.""" + + redacted = text + for secret in secrets: + if secret: + redacted = redacted.replace(secret, REDACTED) + return redacted diff --git a/clients/python/src/mxgateway/client.py b/clients/python/src/mxgateway/client.py new file mode 100644 index 0000000..bd2719a --- /dev/null +++ b/clients/python/src/mxgateway/client.py @@ -0,0 +1,165 @@ +"""Async MXAccess Gateway client wrapper.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator, Sequence +from typing import Any + +import grpc + +from .auth import merge_metadata +from .errors import ensure_protocol_success, map_rpc_error +from .generated import mxaccess_gateway_pb2 as pb +from .generated import mxaccess_gateway_pb2_grpc as pb_grpc +from .options import ClientOptions, create_channel + + +class GatewayClient: + """Async client for the public MXAccess Gateway gRPC API.""" + + def __init__( + self, + *, + options: ClientOptions, + stub: Any, + channel: grpc.aio.Channel | None = None, + ) -> None: + self.options = options + self.raw_stub = stub + self._channel = channel + self._closed = False + + @classmethod + async def connect( + cls, + options: ClientOptions | None = None, + *, + endpoint: str | None = None, + api_key: str | None = None, + plaintext: bool = False, + ca_file: str | None = None, + server_name_override: str | None = None, + stub: Any | None = None, + ) -> "GatewayClient": + """Create a client with either a real async channel or a supplied fake stub.""" + + resolved = options or ClientOptions( + endpoint=endpoint or "", + api_key=api_key, + plaintext=plaintext, + ca_file=ca_file, + server_name_override=server_name_override, + ) + + if stub is not None: + return cls(options=resolved, stub=stub) + + channel = create_channel(resolved) + return cls( + options=resolved, + stub=pb_grpc.MxAccessGatewayStub(channel), + channel=channel, + ) + + async def __aenter__(self) -> "GatewayClient": + return self + + async def __aexit__(self, *_exc_info: object) -> None: + await self.close() + + async def close(self) -> None: + """Close the owned gRPC channel.""" + + if self._closed: + return + + self._closed = True + if self._channel is not None: + await self._channel.close() + + async def open_session( + self, + request: pb.OpenSessionRequest | None = None, + *, + requested_backend: str = "", + client_session_name: str = "", + client_correlation_id: str = "", + ) -> "Session": + """Open a gateway session and return a high-level session wrapper.""" + + from .session import Session + + raw_request = request or pb.OpenSessionRequest( + requested_backend=requested_backend, + client_session_name=client_session_name, + client_correlation_id=client_correlation_id, + ) + reply = await self.open_session_raw(raw_request) + return Session(client=self, session_id=reply.session_id, open_reply=reply) + + async def open_session_raw(self, request: pb.OpenSessionRequest) -> pb.OpenSessionReply: + reply = await self._unary("open session", self.raw_stub.OpenSession, request) + ensure_protocol_success("open session", reply.protocol_status, reply) + return reply + + async def close_session_raw( + self, + request: pb.CloseSessionRequest, + ) -> pb.CloseSessionReply: + reply = await self._unary("close session", self.raw_stub.CloseSession, request) + ensure_protocol_success("close session", reply.protocol_status, reply) + return reply + + async def invoke_raw(self, request: pb.MxCommandRequest) -> pb.MxCommandReply: + reply = await self._unary("invoke", self.raw_stub.Invoke, request) + ensure_protocol_success("invoke", reply.protocol_status, reply) + return reply + + def stream_events_raw( + self, + request: pb.StreamEventsRequest, + *, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> AsyncIterator[pb.MxEvent]: + """Return an async event iterator and cancel the stream when iteration stops.""" + + call = self.raw_stub.StreamEvents( + request, + metadata=merge_metadata(self.options.api_key, metadata), + ) + return _canceling_iterator(call) + + async def _unary( + self, + operation: str, + method: Any, + request: Any, + *, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> Any: + call = method( + request, + metadata=merge_metadata(self.options.api_key, metadata), + ) + try: + return await call + except asyncio.CancelledError: + cancel = getattr(call, "cancel", None) + if cancel is not None: + cancel() + raise + except grpc.RpcError as error: + raise map_rpc_error(operation, error) from error + + +async def _canceling_iterator(call: Any) -> AsyncIterator[pb.MxEvent]: + try: + async for event in call: + yield event + except grpc.RpcError as error: + raise map_rpc_error("stream events", error) from error + finally: + cancel = getattr(call, "cancel", None) + if cancel is not None: + cancel() diff --git a/clients/python/src/mxgateway/errors.py b/clients/python/src/mxgateway/errors.py new file mode 100644 index 0000000..9939d21 --- /dev/null +++ b/clients/python/src/mxgateway/errors.py @@ -0,0 +1,157 @@ +"""Typed exception model for MXAccess Gateway Python clients.""" + +from __future__ import annotations + +from typing import Any + +import grpc + +from .generated import mxaccess_gateway_pb2 as pb + + +class MxGatewayError(Exception): + """Base class for client wrapper errors.""" + + def __init__( + self, + message: str, + *, + protocol_status: pb.ProtocolStatus | None = None, + raw_reply: Any | None = None, + ) -> None: + super().__init__(message) + self.protocol_status = protocol_status + self.raw_reply = raw_reply + + +class MxGatewayTransportError(MxGatewayError): + """Transport-level gRPC failure.""" + + +class MxGatewayAuthenticationError(MxGatewayTransportError): + """Authentication failure reported by gRPC.""" + + +class MxGatewayAuthorizationError(MxGatewayTransportError): + """Authorization failure reported by gRPC.""" + + +class MxGatewaySessionError(MxGatewayError): + """Gateway session failure.""" + + +class MxGatewayWorkerError(MxGatewayError): + """Gateway worker process or protocol failure.""" + + +class MxGatewayCommandError(MxGatewayError): + """Command failure that preserves the raw protobuf reply.""" + + +class MxAccessError(MxGatewayCommandError): + """MXAccess HRESULT or status failure.""" + + +def map_rpc_error(operation: str, error: grpc.RpcError) -> MxGatewayTransportError: + """Map a generated gRPC exception to the client exception hierarchy.""" + + code = error.code() if hasattr(error, "code") else None + details = error.details() if hasattr(error, "details") else str(error) + message = f"{operation} failed: {details}" + + if code == grpc.StatusCode.UNAUTHENTICATED: + return MxGatewayAuthenticationError(message) + if code == grpc.StatusCode.PERMISSION_DENIED: + return MxGatewayAuthorizationError(message) + + return MxGatewayTransportError(message) + + +def ensure_protocol_success( + operation: str, + protocol_status: pb.ProtocolStatus | None, + raw_reply: Any | None = None, +) -> Any | None: + """Raise typed gateway errors for non-OK protocol statuses.""" + + code = ( + protocol_status.code + if protocol_status is not None + else pb.PROTOCOL_STATUS_CODE_UNSPECIFIED + ) + + if code in ( + pb.PROTOCOL_STATUS_CODE_OK, + pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, + ): + return raw_reply + + message_text = protocol_status.message if protocol_status else "" + message = f"{operation} failed: {message_text or pb.ProtocolStatusCode.Name(code)}" + + if code in ( + pb.PROTOCOL_STATUS_CODE_SESSION_NOT_FOUND, + pb.PROTOCOL_STATUS_CODE_SESSION_NOT_READY, + ): + raise MxGatewaySessionError( + message, + protocol_status=protocol_status, + raw_reply=raw_reply, + ) + + if code in ( + pb.PROTOCOL_STATUS_CODE_WORKER_UNAVAILABLE, + pb.PROTOCOL_STATUS_CODE_TIMEOUT, + pb.PROTOCOL_STATUS_CODE_CANCELED, + pb.PROTOCOL_STATUS_CODE_PROTOCOL_VIOLATION, + ): + raise MxGatewayWorkerError( + message, + protocol_status=protocol_status, + raw_reply=raw_reply, + ) + + raise MxGatewayCommandError( + message, + protocol_status=protocol_status, + raw_reply=raw_reply, + ) + + +def ensure_mxaccess_success(operation: str, reply: pb.MxCommandReply) -> pb.MxCommandReply: + """Raise `MxAccessError` when MXAccess returned HRESULT or status failure.""" + + status = reply.protocol_status + if status.code == pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE: + raise MxAccessError( + _mxaccess_message(operation, reply), + protocol_status=status, + raw_reply=reply, + ) + + if reply.HasField("hresult") and reply.hresult < 0: + raise MxAccessError( + _mxaccess_message(operation, reply), + protocol_status=status, + raw_reply=reply, + ) + + for mx_status in reply.statuses: + if mx_status.success == 0: + raise MxAccessError( + _mxaccess_message(operation, reply), + protocol_status=status, + raw_reply=reply, + ) + + return reply + + +def _mxaccess_message(operation: str, reply: pb.MxCommandReply) -> str: + status_text = reply.protocol_status.message or "MXAccess command failed" + hresult = reply.hresult if reply.HasField("hresult") else None + return ( + f"{operation} failed: {status_text}; " + f"session={reply.session_id}; correlation={reply.correlation_id}; " + f"hresult={hresult}; statuses={len(reply.statuses)}" + ) diff --git a/clients/python/src/mxgateway/options.py b/clients/python/src/mxgateway/options.py new file mode 100644 index 0000000..845e544 --- /dev/null +++ b/clients/python/src/mxgateway/options.py @@ -0,0 +1,59 @@ +"""Client connection options for the async Python wrapper.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +import grpc + +from .auth import REDACTED, ApiKey + + +@dataclass(frozen=True) +class ClientOptions: + """Connection settings for `GatewayClient.connect`.""" + + endpoint: str + api_key: str | ApiKey | None = None + plaintext: bool = False + ca_file: str | None = None + server_name_override: str | None = None + + def __post_init__(self) -> None: + if not self.endpoint: + raise ValueError("endpoint must not be empty") + + if self.plaintext and self.ca_file: + raise ValueError("ca_file cannot be used with plaintext connections") + + def __repr__(self) -> str: + api_key = REDACTED if self.api_key else None + return ( + f"{type(self).__name__}(endpoint={self.endpoint!r}, " + f"api_key={api_key!r}, plaintext={self.plaintext!r}, " + f"ca_file={self.ca_file!r}, " + f"server_name_override={self.server_name_override!r})" + ) + + +def create_channel(options: ClientOptions) -> grpc.aio.Channel: + """Create a plaintext or TLS `grpc.aio` channel from client options.""" + + channel_options: list[tuple[str, str]] = [] + if options.server_name_override: + channel_options.append(("grpc.ssl_target_name_override", options.server_name_override)) + + if options.plaintext: + return grpc.aio.insecure_channel(options.endpoint, options=channel_options) + + root_certificates = None + if options.ca_file: + root_certificates = Path(options.ca_file).read_bytes() + + credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates) + return grpc.aio.secure_channel( + options.endpoint, + credentials, + options=channel_options, + ) diff --git a/clients/python/src/mxgateway/session.py b/clients/python/src/mxgateway/session.py new file mode 100644 index 0000000..13aa90b --- /dev/null +++ b/clients/python/src/mxgateway/session.py @@ -0,0 +1,209 @@ +"""Async session wrapper for MXAccess Gateway commands.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator + +from .errors import ensure_mxaccess_success +from .generated import mxaccess_gateway_pb2 as pb +from .values import MxValueInput, to_mx_value + + +class Session: + """A single gateway-backed MXAccess session.""" + + def __init__( + self, + *, + client: "GatewayClient", + session_id: str, + open_reply: pb.OpenSessionReply | None = None, + ) -> None: + self.client = client + self.session_id = session_id + self.open_reply = open_reply + self._closed = False + + async def __aenter__(self) -> "Session": + return self + + async def __aexit__(self, *_exc_info: object) -> None: + await self.close() + + async def close(self, *, client_correlation_id: str = "") -> pb.CloseSessionReply: + """Close the gateway session. Repeated calls return a local closed reply.""" + + if self._closed: + return pb.CloseSessionReply( + session_id=self.session_id, + final_state=pb.SESSION_STATE_CLOSED, + protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), + ) + + self._closed = True + return await self.client.close_session_raw( + pb.CloseSessionRequest( + session_id=self.session_id, + client_correlation_id=client_correlation_id, + ), + ) + + async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply: + """Invoke a raw command and enforce gateway and MXAccess success.""" + + reply = await self.invoke_raw(command, correlation_id=correlation_id) + return ensure_mxaccess_success("invoke", reply) + + async def invoke_raw( + self, + command: pb.MxCommand, + *, + correlation_id: str = "", + ) -> pb.MxCommandReply: + """Invoke a raw command and preserve the raw reply.""" + + return await self.client.invoke_raw( + pb.MxCommandRequest( + session_id=self.session_id, + client_correlation_id=correlation_id, + command=command, + ), + ) + + async def register(self, client_name: str, *, correlation_id: str = "") -> int: + reply = await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_REGISTER, + register=pb.RegisterCommand(client_name=client_name), + ), + correlation_id=correlation_id, + ) + return reply.register.server_handle + + async def unregister(self, server_handle: int, *, correlation_id: str = "") -> None: + await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_UNREGISTER, + unregister=pb.UnregisterCommand(server_handle=server_handle), + ), + correlation_id=correlation_id, + ) + + async def add_item( + self, + server_handle: int, + item_definition: str, + *, + correlation_id: str = "", + ) -> int: + reply = await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_ADD_ITEM, + add_item=pb.AddItemCommand( + server_handle=server_handle, + item_definition=item_definition, + ), + ), + correlation_id=correlation_id, + ) + return reply.add_item.item_handle + + async def add_item2( + self, + server_handle: int, + item_definition: str, + item_context: str, + *, + correlation_id: str = "", + ) -> int: + reply = await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_ADD_ITEM2, + add_item2=pb.AddItem2Command( + server_handle=server_handle, + item_definition=item_definition, + item_context=item_context, + ), + ), + correlation_id=correlation_id, + ) + return reply.add_item2.item_handle + + async def advise( + self, + server_handle: int, + item_handle: int, + *, + correlation_id: str = "", + ) -> None: + await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_ADVISE, + advise=pb.AdviseCommand( + server_handle=server_handle, + item_handle=item_handle, + ), + ), + correlation_id=correlation_id, + ) + + async def write( + self, + server_handle: int, + item_handle: int, + value: MxValueInput, + *, + user_id: int = 0, + correlation_id: str = "", + ) -> None: + await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_WRITE, + write=pb.WriteCommand( + server_handle=server_handle, + item_handle=item_handle, + value=to_mx_value(value), + user_id=user_id, + ), + ), + correlation_id=correlation_id, + ) + + async def write2( + self, + server_handle: int, + item_handle: int, + value: MxValueInput, + timestamp_value: MxValueInput, + *, + user_id: int = 0, + correlation_id: str = "", + ) -> None: + await self.invoke( + pb.MxCommand( + kind=pb.MX_COMMAND_KIND_WRITE2, + write2=pb.Write2Command( + server_handle=server_handle, + item_handle=item_handle, + value=to_mx_value(value), + timestamp_value=to_mx_value(timestamp_value), + user_id=user_id, + ), + ), + correlation_id=correlation_id, + ) + + def stream_events( + self, + *, + after_worker_sequence: int = 0, + ) -> AsyncIterator[pb.MxEvent]: + return self.client.stream_events_raw( + pb.StreamEventsRequest( + session_id=self.session_id, + after_worker_sequence=after_worker_sequence, + ), + ) + + +from .client import GatewayClient # noqa: E402 diff --git a/clients/python/src/mxgateway/values.py b/clients/python/src/mxgateway/values.py new file mode 100644 index 0000000..e9251bf --- /dev/null +++ b/clients/python/src/mxgateway/values.py @@ -0,0 +1,234 @@ +"""MXAccess value conversion helpers.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from google.protobuf.timestamp_pb2 import Timestamp + +from .generated import mxaccess_gateway_pb2 as pb + + +MxValueInput = bool | int | float | str | datetime | bytes | None | Sequence[Any] + + +@dataclass(frozen=True) +class MxValueView: + """Typed projection of a raw `MxValue` protobuf message.""" + + value: Any + kind: str + raw: pb.MxValue + + +def to_mx_value(value: MxValueInput, *, data_type: str | None = None) -> pb.MxValue: + """Convert a Python value into the public protobuf `MxValue` union.""" + + if isinstance(value, pb.MxValue): + return value + + if value is None: + return pb.MxValue( + data_type=pb.MX_DATA_TYPE_NO_DATA, + variant_type="VT_EMPTY", + is_null=True, + raw_data_type=pb.MX_DATA_TYPE_NO_DATA, + ) + + if isinstance(value, bool): + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_BOOLEAN), + variant_type="VT_BOOL", + bool_value=value, + ) + + if isinstance(value, int): + if -(2**31) <= value <= (2**31 - 1): + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_INTEGER), + variant_type="VT_I4", + int32_value=value, + ) + + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_INTEGER), + variant_type="VT_I8", + int64_value=value, + ) + + if isinstance(value, float): + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_DOUBLE), + variant_type="VT_R8", + double_value=value, + ) + + if isinstance(value, str): + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_STRING), + variant_type="VT_BSTR", + string_value=value, + ) + + if isinstance(value, datetime): + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_TIME), + variant_type="VT_DATE", + timestamp_value=_timestamp_from_datetime(value), + ) + + if isinstance(value, bytes): + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_UNKNOWN), + variant_type="VT_RECORD", + raw_value=value, + ) + + if isinstance(value, Sequence): + return _sequence_to_mx_value(value, data_type=data_type) + + raise TypeError(f"unsupported MxValue input type: {type(value).__name__}") + + +def from_mx_value(value: pb.MxValue) -> MxValueView: + """Project a protobuf `MxValue` into an idiomatic Python value.""" + + kind = value.WhichOneof("kind") + if kind is None: + return MxValueView(None, "none", value) + + if kind == "timestamp_value": + return MxValueView( + value.timestamp_value.ToDatetime().replace(tzinfo=timezone.utc), + kind, + value, + ) + + if kind == "array_value": + return MxValueView(from_mx_array(value.array_value), kind, value) + + return MxValueView(getattr(value, kind), kind, value) + + +def from_mx_array(array: pb.MxArray) -> list[Any]: + """Project a protobuf `MxArray` into a Python list.""" + + kind = array.WhichOneof("values") + if kind is None: + return [] + + values = list(getattr(array, kind).values) + if kind == "timestamp_values": + return [ + timestamp.ToDatetime().replace(tzinfo=timezone.utc) + for timestamp in values + ] + + return values + + +def _sequence_to_mx_value( + values: Sequence[Any], + *, + data_type: str | None, +) -> pb.MxValue: + sequence = list(values) + if not sequence: + return pb.MxValue( + data_type=_data_type(data_type, pb.MX_DATA_TYPE_UNKNOWN), + array_value=pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_UNKNOWN, + dimensions=[0], + ), + ) + + first = sequence[0] + dimensions = [len(sequence)] + + if all(isinstance(item, bool) for item in sequence): + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_BOOLEAN, + variant_type="VT_ARRAY|VT_BOOL", + dimensions=dimensions, + bool_values=pb.BoolArray(values=sequence), + ) + return pb.MxValue(data_type=pb.MX_DATA_TYPE_BOOLEAN, array_value=array) + + if all(isinstance(item, int) and not isinstance(item, bool) for item in sequence): + use_int32 = all(-(2**31) <= item <= (2**31 - 1) for item in sequence) + if use_int32: + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_INTEGER, + variant_type="VT_ARRAY|VT_I4", + dimensions=dimensions, + int32_values=pb.Int32Array(values=sequence), + ) + else: + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_INTEGER, + variant_type="VT_ARRAY|VT_I8", + dimensions=dimensions, + int64_values=pb.Int64Array(values=sequence), + ) + + return pb.MxValue(data_type=pb.MX_DATA_TYPE_INTEGER, array_value=array) + + if all(isinstance(item, float) for item in sequence): + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_DOUBLE, + variant_type="VT_ARRAY|VT_R8", + dimensions=dimensions, + double_values=pb.DoubleArray(values=sequence), + ) + return pb.MxValue(data_type=pb.MX_DATA_TYPE_DOUBLE, array_value=array) + + if all(isinstance(item, str) for item in sequence): + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_STRING, + variant_type="VT_ARRAY|VT_BSTR", + dimensions=dimensions, + string_values=pb.StringArray(values=sequence), + ) + return pb.MxValue(data_type=pb.MX_DATA_TYPE_STRING, array_value=array) + + if all(isinstance(item, datetime) for item in sequence): + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_TIME, + variant_type="VT_ARRAY|VT_DATE", + dimensions=dimensions, + timestamp_values=pb.TimestampArray( + values=[_timestamp_from_datetime(item) for item in sequence], + ), + ) + return pb.MxValue(data_type=pb.MX_DATA_TYPE_TIME, array_value=array) + + if all(isinstance(item, bytes) for item in sequence): + array = pb.MxArray( + element_data_type=pb.MX_DATA_TYPE_UNKNOWN, + variant_type="VT_ARRAY|VT_VARIANT", + dimensions=dimensions, + raw_values=pb.RawArray(values=sequence), + ) + return pb.MxValue(data_type=pb.MX_DATA_TYPE_UNKNOWN, array_value=array) + + raise TypeError( + "MxValue array inputs must use one supported element type; " + f"first element was {type(first).__name__}" + ) + + +def _timestamp_from_datetime(value: datetime) -> Timestamp: + timestamp = Timestamp() + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + timestamp.FromDatetime(value.astimezone(timezone.utc)) + return timestamp + + +def _data_type(name: str | None, default: int) -> int: + if name is None: + return default + return pb.MxDataType.Value(name) diff --git a/clients/python/src/mxgateway_cli/commands.py b/clients/python/src/mxgateway_cli/commands.py index dbd70af..c72b56e 100644 --- a/clients/python/src/mxgateway_cli/commands.py +++ b/clients/python/src/mxgateway_cli/commands.py @@ -1,10 +1,24 @@ -"""CLI scaffold for the MXAccess Gateway Python client.""" +"""Command line interface for the MXAccess Gateway Python client.""" +from __future__ import annotations + +import asyncio import json +import os +from collections.abc import Awaitable, Callable +from datetime import datetime, timezone +from typing import Any import click +from google.protobuf.json_format import MessageToDict from mxgateway import __version__ +from mxgateway.auth import redact_secret +from mxgateway.client import GatewayClient +from mxgateway.errors import MxGatewayError +from mxgateway.generated import mxaccess_gateway_pb2 as pb +from mxgateway.options import ClientOptions +from mxgateway.values import MxValueInput @click.group() @@ -16,14 +30,435 @@ def main() -> None: @click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") def version(output_json: bool) -> None: """Print client package version information.""" + payload = { "client": "mxgw-py", "package": "mxaccess-gateway-client", "version": __version__, } + _emit(payload, output_json=output_json, text=f"mxgw-py {__version__}") + +def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]: + command = click.option("--endpoint", default="localhost:5000", show_default=True)(command) + command = click.option("--api-key", default=None, help="Gateway API key.")(command) + command = click.option( + "--api-key-env", + default=None, + help="Environment variable containing the gateway API key.", + )(command) + command = click.option("--plaintext", is_flag=True, help="Use plaintext gRPC.")(command) + command = click.option("--tls", "use_tls", is_flag=True, help="Use TLS gRPC.")(command) + command = click.option("--ca-file", default=None, help="Custom root certificate file.")(command) + command = click.option( + "--server-name-override", + default=None, + help="TLS server name override for test environments.", + )(command) + return command + + +@main.command("open-session") +@gateway_options +@click.option("--client-name", default="", help="Client session name.") +@click.option("--requested-backend", default="", help="Requested backend name.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def open_session(**kwargs: Any) -> None: + """Open a gateway session.""" + + _run( + _open_session(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + +@main.command("close-session") +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def close_session(**kwargs: Any) -> None: + """Close a gateway session.""" + + _run( + _close_session(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + +@main.command() +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--message", default="ping", show_default=True, help="Ping payload.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def ping(**kwargs: Any) -> None: + """Send a diagnostic ping command.""" + + _run(_ping(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs)) + + +@main.command() +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--client-name", required=True, help="MXAccess client name.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def register(**kwargs: Any) -> None: + """Invoke MXAccess Register.""" + + _run( + _register(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + +@main.command("add-item") +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.") +@click.option("--item", required=True, help="MXAccess item definition.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def add_item(**kwargs: Any) -> None: + """Invoke MXAccess AddItem.""" + + _run( + _add_item(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + +@main.command() +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.") +@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def advise(**kwargs: Any) -> None: + """Invoke MXAccess Advise.""" + + _run(_advise(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs)) + + +@main.command("stream-events") +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--after-worker-sequence", default=0, type=int, show_default=True) +@click.option("--max-events", default=1, type=int, show_default=True) +@click.option("--timeout", default=5.0, type=float, show_default=True) +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def stream_events(**kwargs: Any) -> None: + """Stream a bounded number of events.""" + + _run( + _stream_events(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + +@main.command() +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.") +@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.") +@click.option("--type", "value_type", default="string", show_default=True) +@click.option("--value", required=True, help="Value to write.") +@click.option("--user-id", default=0, type=int, show_default=True) +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def write(**kwargs: Any) -> None: + """Invoke MXAccess Write.""" + + _run(_write(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs)) + + +@main.command() +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.") +@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.") +@click.option("--type", "value_type", default="string", show_default=True) +@click.option("--value", required=True, help="Value to write.") +@click.option("--timestamp", required=True, help="ISO-8601 timestamp value.") +@click.option("--user-id", default=0, type=int, show_default=True) +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def write2(**kwargs: Any) -> None: + """Invoke MXAccess Write2.""" + + _run(_write2(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs)) + + +@main.command() +@gateway_options +@click.option("--client-name", default="mxgw-py-smoke", show_default=True) +@click.option("--item", required=True, help="MXAccess item definition.") +@click.option("--max-events", default=1, type=int, show_default=True) +@click.option("--timeout", default=5.0, type=float, show_default=True) +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def smoke(**kwargs: Any) -> None: + """Run a bounded open/register/add/advise/stream/close smoke flow.""" + + _run(_smoke(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs)) + + +async def _open_session(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + reply = await client.open_session_raw( + pb.OpenSessionRequest( + requested_backend=kwargs["requested_backend"], + client_session_name=kwargs["client_name"], + client_correlation_id=kwargs["correlation_id"], + ), + ) + return {"sessionId": reply.session_id, "rawReply": _message_dict(reply)} + + +async def _close_session(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + reply = await client.close_session_raw( + pb.CloseSessionRequest( + session_id=kwargs["session_id"], + client_correlation_id=kwargs["correlation_id"], + ), + ) + return {"sessionId": reply.session_id, "rawReply": _message_dict(reply)} + + +async def _ping(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + reply = await client.invoke_raw( + pb.MxCommandRequest( + session_id=kwargs["session_id"], + command=pb.MxCommand( + kind=pb.MX_COMMAND_KIND_PING, + ping=pb.PingCommand(message=kwargs["message"]), + ), + ), + ) + return {"kind": "ping", "rawReply": _message_dict(reply)} + + +async def _register(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + server_handle = await session.register( + kwargs["client_name"], + correlation_id=kwargs["correlation_id"], + ) + return {"serverHandle": server_handle} + + +async def _add_item(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + item_handle = await session.add_item( + kwargs["server_handle"], + kwargs["item"], + correlation_id=kwargs["correlation_id"], + ) + return {"itemHandle": item_handle} + + +async def _advise(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + await session.advise( + kwargs["server_handle"], + kwargs["item_handle"], + correlation_id=kwargs["correlation_id"], + ) + return {"ok": True} + + +async def _stream_events(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + events = await _collect_events( + session.stream_events(after_worker_sequence=kwargs["after_worker_sequence"]), + max_events=kwargs["max_events"], + timeout=kwargs["timeout"], + ) + return {"events": [_message_dict(event) for event in events]} + + +async def _write(**kwargs: Any) -> dict[str, Any]: + value = _parse_value(kwargs["value"], kwargs["value_type"]) + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + await session.write( + kwargs["server_handle"], + kwargs["item_handle"], + value, + user_id=kwargs["user_id"], + correlation_id=kwargs["correlation_id"], + ) + return {"ok": True} + + +async def _write2(**kwargs: Any) -> dict[str, Any]: + value = _parse_value(kwargs["value"], kwargs["value_type"]) + timestamp = _parse_datetime(kwargs["timestamp"]) + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + await session.write2( + kwargs["server_handle"], + kwargs["item_handle"], + value, + timestamp, + user_id=kwargs["user_id"], + correlation_id=kwargs["correlation_id"], + ) + return {"ok": True} + + +async def _smoke(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = await client.open_session(client_session_name=kwargs["client_name"]) + closed = False + try: + server_handle = await session.register(kwargs["client_name"]) + item_handle = await session.add_item(server_handle, kwargs["item"]) + await session.advise(server_handle, item_handle) + events = await _collect_events( + session.stream_events(), + max_events=kwargs["max_events"], + timeout=kwargs["timeout"], + ) + return { + "sessionId": session.session_id, + "serverHandle": server_handle, + "itemHandle": item_handle, + "events": [_message_dict(event) for event in events], + } + finally: + if not closed: + await session.close() + + +async def _connect(kwargs: dict[str, Any]) -> GatewayClient: + api_key = kwargs.get("api_key") or _api_key_from_env(kwargs.get("api_key_env")) + return await GatewayClient.connect( + ClientOptions( + endpoint=kwargs["endpoint"], + api_key=api_key, + plaintext=_use_plaintext(kwargs), + ca_file=kwargs.get("ca_file"), + server_name_override=kwargs.get("server_name_override"), + ), + ) + + +def _session(client: GatewayClient, session_id: str): + from mxgateway.session import Session + + return Session(client=client, session_id=session_id) + + +def _use_plaintext(kwargs: dict[str, Any]) -> bool: + if kwargs.get("use_tls"): + return False + if kwargs.get("plaintext"): + return True + return kwargs["endpoint"].startswith("localhost:") or kwargs["endpoint"].startswith("127.0.0.1:") + + +def _api_key_from_env(name: str | None) -> str | None: + if not name: + return None + return os.environ.get(name) + + +def _secrets(kwargs: dict[str, Any]) -> list[str | None]: + return [ + kwargs.get("api_key"), + _api_key_from_env(kwargs.get("api_key_env")), + ] + + +def _run( + awaitable: Awaitable[dict[str, Any]], + *, + output_json: bool, + secrets: list[str | None], +) -> None: + try: + payload = asyncio.run(awaitable) + except MxGatewayError as error: + raise click.ClickException(redact_secret(str(error), secrets)) from error + + _emit(payload, output_json=output_json) + + +def _emit( + payload: dict[str, Any], + *, + output_json: bool, + text: str | None = None, +) -> None: if output_json: click.echo(json.dumps(payload, sort_keys=True)) return - click.echo(f"mxgw-py {__version__}") + click.echo(text or json.dumps(payload, sort_keys=True)) + + +async def _collect_events( + events: Any, + *, + max_events: int, + timeout: float, +) -> list[pb.MxEvent]: + collected: list[pb.MxEvent] = [] + iterator = events.__aiter__() + try: + while len(collected) < max_events: + collected.append(await asyncio.wait_for(iterator.__anext__(), timeout=timeout)) + except StopAsyncIteration: + pass + finally: + close = getattr(iterator, "aclose", None) + if close is not None: + await close() + return collected + + +def _parse_value(raw_value: str, value_type: str) -> MxValueInput: + normalized = value_type.lower() + if normalized == "bool": + return raw_value.lower() in ("1", "true", "yes", "on") + if normalized in ("int", "int32", "int64"): + return int(raw_value) + if normalized in ("float", "double"): + return float(raw_value) + if normalized in ("time", "timestamp"): + return _parse_datetime(raw_value) + if normalized == "raw": + return raw_value.encode("utf-8") + if normalized == "string": + return raw_value + raise click.BadParameter(f"unsupported value type: {value_type}", param_hint="--type") + + +def _parse_datetime(raw_value: str) -> datetime: + if raw_value.endswith("Z"): + raw_value = raw_value[:-1] + "+00:00" + parsed = datetime.fromisoformat(raw_value) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed + + +def _message_dict(message: Any) -> dict[str, Any]: + return MessageToDict( + message, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) diff --git a/clients/python/tests/test_auth_options.py b/clients/python/tests/test_auth_options.py new file mode 100644 index 0000000..522e017 --- /dev/null +++ b/clients/python/tests/test_auth_options.py @@ -0,0 +1,103 @@ +"""Tests for auth metadata and connection options.""" + +import pytest + +from mxgateway.auth import REDACTED, ApiKey, auth_metadata, redact_secret +from mxgateway import options as options_module +from mxgateway.options import ClientOptions, create_channel + + +def test_auth_metadata_adds_bearer_api_key() -> None: + assert auth_metadata("mxgw_test_secret") == ( + ("authorization", "Bearer mxgw_test_secret"), + ) + + +def test_api_key_repr_is_redacted() -> None: + api_key = ApiKey("mxgw_test_secret") + + assert "mxgw_test_secret" not in repr(api_key) + assert REDACTED in repr(api_key) + + +def test_redact_secret_replaces_known_values() -> None: + redacted = redact_secret( + "authorization failed for mxgw_test_secret", + ["mxgw_test_secret"], + ) + + assert redacted == f"authorization failed for {REDACTED}" + + +def test_client_options_reject_plaintext_with_ca_file() -> None: + with pytest.raises(ValueError, match="ca_file"): + ClientOptions( + endpoint="localhost:5000", + plaintext=True, + ca_file="ca.pem", + ) + + +def test_client_options_repr_redacts_api_key() -> None: + options = ClientOptions(endpoint="localhost:5000", api_key="mxgw_test_secret") + + assert "mxgw_test_secret" not in repr(options) + assert REDACTED in repr(options) + + +def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[str, object]] = [] + + def fake_insecure_channel(endpoint: str, *, options: object) -> str: + calls.append((endpoint, options)) + return "plain-channel" + + monkeypatch.setattr( + options_module.grpc.aio, + "insecure_channel", + fake_insecure_channel, + ) + + channel = create_channel(ClientOptions(endpoint="localhost:5000", plaintext=True)) + + assert channel == "plain-channel" + assert calls == [("localhost:5000", [])] + + +def test_create_channel_uses_tls_channel(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[str, object, object]] = [] + + def fake_credentials(*, root_certificates: object) -> str: + assert root_certificates is None + return "creds" + + def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str: + calls.append((endpoint, credentials, options)) + return "tls-channel" + + monkeypatch.setattr( + options_module.grpc, + "ssl_channel_credentials", + fake_credentials, + ) + monkeypatch.setattr( + options_module.grpc.aio, + "secure_channel", + fake_secure_channel, + ) + + channel = create_channel( + ClientOptions( + endpoint="gateway.example:5001", + server_name_override="gateway.test", + ), + ) + + assert channel == "tls-channel" + assert calls == [ + ( + "gateway.example:5001", + "creds", + [("grpc.ssl_target_name_override", "gateway.test")], + ), + ] diff --git a/clients/python/tests/test_cli.py b/clients/python/tests/test_cli.py index 9af78bf..a2ff19e 100644 --- a/clients/python/tests/test_cli.py +++ b/clients/python/tests/test_cli.py @@ -1,4 +1,4 @@ -"""Tests for the Python CLI scaffold.""" +"""Tests for the Python CLI.""" import json @@ -19,3 +19,50 @@ def test_version_json_is_deterministic() -> None: "package": "mxaccess-gateway-client", "version": __version__, } + + +def test_write_parser_rejects_unknown_value_type() -> None: + runner = CliRunner() + + result = runner.invoke( + main, + [ + "write", + "--session-id", + "session-1", + "--server-handle", + "12", + "--item-handle", + "34", + "--type", + "unsupported", + "--value", + "123", + "--api-key", + "mxgw_test_secret", + "--json", + ], + ) + + assert result.exit_code != 0 + assert "unsupported value type" in result.output + + +def test_cli_error_output_redacts_api_key() -> None: + runner = CliRunner() + + result = runner.invoke( + main, + [ + "open-session", + "--endpoint", + "127.0.0.1:1", + "--api-key", + "mxgw_test_secret", + "--plaintext", + "--json", + ], + ) + + assert result.exit_code != 0 + assert "mxgw_test_secret" not in result.output diff --git a/clients/python/tests/test_client_session.py b/clients/python/tests/test_client_session.py new file mode 100644 index 0000000..7e278da --- /dev/null +++ b/clients/python/tests/test_client_session.py @@ -0,0 +1,225 @@ +"""Tests for the async client and session wrappers.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from mxgateway import ClientOptions, GatewayClient, MxAccessError +from mxgateway.generated import mxaccess_gateway_pb2 as pb + + +@pytest.mark.asyncio +async def test_session_helpers_send_auth_metadata_and_preserve_raw_replies() -> None: + stub = FakeGatewayStub() + client = await GatewayClient.connect( + ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), + stub=stub, + ) + + session = await client.open_session(client_session_name="pytest") + server_handle = await session.register("pytest-client") + item_handle = await session.add_item(server_handle, "Object.Attribute") + await session.advise(server_handle, item_handle) + + assert session.session_id == "session-1" + assert server_handle == 12 + assert item_handle == 34 + assert stub.open_session.metadata == (("authorization", "Bearer mxgw_test_secret"),) + assert stub.invoke.requests[0].command.register.client_name == "pytest-client" + assert stub.invoke.requests[1].command.add_item.item_definition == "Object.Attribute" + assert stub.invoke.requests[2].command.advise.item_handle == 34 + + +@pytest.mark.asyncio +async def test_mxaccess_error_preserves_raw_reply() -> None: + stub = FakeGatewayStub() + failure_reply = pb.MxCommandReply( + session_id="session-1", + kind=pb.MX_COMMAND_KIND_WRITE, + protocol_status=pb.ProtocolStatus( + code=pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, + message="MXAccess rejected write.", + ), + hresult=-1, + ) + stub.invoke.replies = [failure_reply] + client = await GatewayClient.connect( + ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), + stub=stub, + ) + session = await client.open_session() + + with pytest.raises(MxAccessError) as captured: + await session.write(12, 34, 123) + + assert captured.value.raw_reply is failure_reply + + +@pytest.mark.asyncio +async def test_stream_events_cancels_underlying_call_when_closed() -> None: + stream = FakeStream( + [ + pb.MxEvent( + session_id="session-1", + worker_sequence=1, + family=pb.MX_EVENT_FAMILY_ON_DATA_CHANGE, + ), + ], + ) + stub = FakeGatewayStub(stream=stream) + client = await GatewayClient.connect( + ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), + stub=stub, + ) + session = await client.open_session() + + events = session.stream_events() + first = await anext(events) + await events.aclose() + + assert first.worker_sequence == 1 + assert stream.cancelled + assert stub.stream_metadata == (("authorization", "Bearer mxgw_test_secret"),) + + +@pytest.mark.asyncio +async def test_unary_task_cancellation_reaches_fake_call() -> None: + blocking = BlockingCancellableUnary() + stub = FakeGatewayStub() + stub.OpenSession = blocking + client = await GatewayClient.connect( + ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True), + stub=stub, + ) + + task = asyncio.create_task(client.open_session()) + await blocking.started.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert blocking.call is not None + assert blocking.call.cancelled + + +class FakeGatewayStub: + def __init__(self, stream: "FakeStream | None" = None) -> None: + self.open_session = FakeUnary( + [ + pb.OpenSessionReply( + session_id="session-1", + protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), + ), + ], + ) + self.close_session = FakeUnary( + [ + pb.CloseSessionReply( + session_id="session-1", + protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), + ), + ], + ) + self.invoke = FakeUnary( + [ + pb.MxCommandReply( + session_id="session-1", + kind=pb.MX_COMMAND_KIND_REGISTER, + protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), + register=pb.RegisterReply(server_handle=12), + ), + pb.MxCommandReply( + session_id="session-1", + kind=pb.MX_COMMAND_KIND_ADD_ITEM, + protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), + add_item=pb.AddItemReply(item_handle=34), + ), + pb.MxCommandReply( + session_id="session-1", + kind=pb.MX_COMMAND_KIND_ADVISE, + protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK), + ), + ], + ) + self.OpenSession = self.open_session + self.CloseSession = self.close_session + self.Invoke = self.invoke + self._stream = stream or FakeStream([]) + self.stream_metadata: tuple[tuple[str, str], ...] | None = None + + def StreamEvents( + self, + request: pb.StreamEventsRequest, + *, + metadata: tuple[tuple[str, str], ...], + ) -> "FakeStream": + self.stream_request = request + self.stream_metadata = metadata + return self._stream + + +class FakeUnary: + def __init__(self, replies: list[Any]) -> None: + self.replies = replies + self.requests: list[Any] = [] + self.metadata: tuple[tuple[str, str], ...] | None = None + + async def __call__( + self, + request: Any, + *, + metadata: tuple[tuple[str, str], ...], + ) -> Any: + self.requests.append(request) + self.metadata = metadata + return self.replies.pop(0) + + +class BlockingCancellableUnary: + def __init__(self) -> None: + self.started = asyncio.Event() + self.call: BlockingCall | None = None + + def __call__(self, *_args: Any, **_kwargs: Any) -> "BlockingCall": + self.call = BlockingCall(self.started) + return self.call + + +class BlockingCall: + def __init__(self, started: asyncio.Event) -> None: + self.started = started + self.cancelled = False + + def __await__(self): + return self._wait().__await__() + + async def _wait(self) -> Any: + self.started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + raise + + def cancel(self) -> None: + self.cancelled = True + + +class FakeStream: + def __init__(self, events: list[pb.MxEvent]) -> None: + self._events = events + self.cancelled = False + + def __aiter__(self) -> "FakeStream": + return self + + async def __anext__(self) -> pb.MxEvent: + if not self._events: + await asyncio.sleep(3600) + return self._events.pop(0) + + def cancel(self) -> None: + self.cancelled = True diff --git a/clients/python/tests/test_errors.py b/clients/python/tests/test_errors.py new file mode 100644 index 0000000..7720ed2 --- /dev/null +++ b/clients/python/tests/test_errors.py @@ -0,0 +1,49 @@ +"""Tests for typed command error mapping.""" + +import json +from pathlib import Path + +import pytest +from google.protobuf.json_format import ParseDict + +from mxgateway.errors import ensure_mxaccess_success, ensure_protocol_success +from mxgateway import MxAccessError, MxGatewaySessionError +from mxgateway.generated import mxaccess_gateway_pb2 as pb + +FIXTURE_ROOT = Path(__file__).resolve().parents[2] / "proto" / "fixtures" / "behavior" + + +def test_register_fixture_is_protocol_and_mxaccess_success() -> None: + reply = _load_reply("command-replies/register.ok.reply.json") + + assert ensure_protocol_success("register", reply.protocol_status, reply) is reply + assert ensure_mxaccess_success("register", reply) is reply + + +def test_write_failure_fixture_preserves_raw_reply() -> None: + reply = _load_reply("command-replies/write.mxaccess-failure.reply.json") + + assert ensure_protocol_success("write", reply.protocol_status, reply) is reply + with pytest.raises(MxAccessError) as captured: + ensure_mxaccess_success("write", reply) + + assert captured.value.raw_reply is reply + assert captured.value.raw_reply.hresult == -2147220992 + assert len(captured.value.raw_reply.statuses) == 2 + + +def test_session_status_maps_to_session_error() -> None: + status = pb.ProtocolStatus( + code=pb.PROTOCOL_STATUS_CODE_SESSION_NOT_FOUND, + message="session missing", + ) + + with pytest.raises(MxGatewaySessionError) as captured: + ensure_protocol_success("invoke", status) + + assert captured.value.protocol_status is status + + +def _load_reply(name: str) -> pb.MxCommandReply: + payload = json.loads((FIXTURE_ROOT / name).read_text(encoding="utf-8")) + return ParseDict(payload, pb.MxCommandReply()) diff --git a/clients/python/tests/test_values.py b/clients/python/tests/test_values.py new file mode 100644 index 0000000..c6d1c69 --- /dev/null +++ b/clients/python/tests/test_values.py @@ -0,0 +1,49 @@ +"""Tests for MXAccess value conversion helpers.""" + +import json +import re +from datetime import datetime, timezone +from pathlib import Path + +from google.protobuf.json_format import ParseDict + +from mxgateway.generated import mxaccess_gateway_pb2 as pb +from mxgateway.values import from_mx_value, to_mx_value + +FIXTURE_ROOT = Path(__file__).resolve().parents[2] / "proto" / "fixtures" / "behavior" + + +def test_value_conversion_fixtures_project_expected_oneof_kind() -> None: + payload = json.loads( + (FIXTURE_ROOT / "values" / "value-conversion-cases.json").read_text( + encoding="utf-8", + ), + ) + + for case in payload["cases"]: + value = ParseDict(case["value"], pb.MxValue()) + projection = from_mx_value(value) + + assert projection.kind == _snake_case(case["expectedKind"]) + assert projection.raw is value + + +def test_to_mx_value_supports_scalar_and_array_inputs() -> None: + assert to_mx_value(True).WhichOneof("kind") == "bool_value" + assert to_mx_value(12).int32_value == 12 + assert to_mx_value(2**40).int64_value == 2**40 + assert to_mx_value(12.5).double_value == 12.5 + assert to_mx_value("abc").string_value == "abc" + assert to_mx_value([1, 2]).array_value.int32_values.values == [1, 2] + assert to_mx_value(["a", "b"]).array_value.string_values.values == ["a", "b"] + + +def test_to_mx_value_uses_utc_timestamps() -> None: + value = to_mx_value(datetime(2026, 1, 1, 0, 0, 4, tzinfo=timezone.utc)) + + assert value.data_type == pb.MX_DATA_TYPE_TIME + assert value.timestamp_value.seconds == 1767225604 + + +def _snake_case(value: str) -> str: + return re.sub(r"(?