Compare commits

...

2 Commits

Author SHA1 Message Date
Joseph Doherty b57662aae7 Issue #46: implement Python async client values errors and CLI 2026-04-26 20:46:18 -04:00
dohertj2 14afb325c3 Merge pull request #96 from agent-1/issue-47-scaffold-java-gradle-build
Issue #47: scaffold Java Gradle build
2026-04-26 20:42:39 -04:00
14 changed files with 1883 additions and 10 deletions
+56 -6
View File
@@ -1,8 +1,8 @@
# Python Client
The Python client package contains generated MXAccess Gateway protobuf
bindings, the `mxgateway` package scaffold, and the `mxgw-py` test CLI
scaffold. The package uses the shared proto inputs documented in
bindings, the async `mxgateway` package, and the `mxgw-py` test CLI. The
package uses the shared proto inputs documented in
`../../docs/client-proto-generation.md` so gateway and client contracts stay in
sync.
@@ -43,15 +43,65 @@ python -m pytest
python -m pip wheel . --no-deps --wheel-dir "$env:TEMP\mxgateway-python-wheel"
```
The scaffold tests import the generated gateway and worker stubs and exercise
the deterministic CLI version output.
The tests import the generated gateway and worker stubs, run fake async gateway
stubs, verify API key metadata, exercise stream cancellation, load shared value
and command fixtures, and check deterministic CLI output.
## Library Usage
The library is async-first:
```python
from mxgateway import GatewayClient
async with await GatewayClient.connect(
endpoint="localhost:5000",
api_key="mxgw_example",
plaintext=True,
) as client:
session = await client.open_session(client_session_name="python-client")
try:
server_handle = await session.register("python-client")
item_handle = await session.add_item(server_handle, "Object.Attribute")
await session.advise(server_handle, item_handle)
finally:
await session.close()
```
`GatewayClient.open_session_raw`, `GatewayClient.invoke_raw`, and
`GatewayClient.stream_events_raw` keep the generated protobuf replies and
events available for parity tests. `Session` helpers call the method-specific
MXAccess commands and preserve raw replies on typed command exceptions.
Canceling a Python task cancels the client-side gRPC call or stream wait. It
does not abort an in-flight MXAccess COM call inside the worker process.
## Authentication And TLS
`ClientOptions.api_key` adds this metadata to unary calls and streams:
```text
authorization: Bearer <api-key>
```
The client supports plaintext channels for local development, TLS with system
roots, TLS with a custom `ca_file`, and an optional test server name override.
API keys are redacted from option repr output and CLI error output.
## CLI
The scaffold CLI exposes version information:
The CLI emits deterministic JSON for automation:
```powershell
mxgw-py version --json
mxgw-py open-session --endpoint localhost:5000 --plaintext --json
mxgw-py register --session-id <id> --client-name python-client --json
mxgw-py add-item --session-id <id> --server-handle 1 --item Object.Attribute --json
mxgw-py advise --session-id <id> --server-handle 1 --item-handle 2 --json
mxgw-py stream-events --session-id <id> --max-events 1 --json
mxgw-py write --session-id <id> --server-handle 1 --item-handle 2 --type int32 --value 123 --json
```
Additional commands are implemented with the async client/session wrapper work.
Use `--api-key` or `--api-key-env MXGATEWAY_API_KEY` to attach API key
metadata. `smoke` opens a session, registers, adds an item, advises, streams a
bounded event count, and closes the session in a `finally` block.
+34 -1
View File
@@ -1,5 +1,38 @@
"""MXAccess Gateway Python client package."""
from .auth import ApiKey, auth_metadata
from .client import GatewayClient
from .errors import (
MxAccessError,
MxGatewayAuthenticationError,
MxGatewayAuthorizationError,
MxGatewayCommandError,
MxGatewayError,
MxGatewaySessionError,
MxGatewayTransportError,
MxGatewayWorkerError,
)
from .options import ClientOptions
from .session import Session
from .values import MxValueView, from_mx_value, to_mx_value
from .version import __version__
__all__ = ["__version__"]
__all__ = [
"ApiKey",
"ClientOptions",
"GatewayClient",
"MxAccessError",
"MxGatewayAuthenticationError",
"MxGatewayAuthorizationError",
"MxGatewayCommandError",
"MxGatewayError",
"MxGatewaySessionError",
"MxGatewayTransportError",
"MxGatewayWorkerError",
"MxValueView",
"Session",
"__version__",
"auth_metadata",
"from_mx_value",
"to_mx_value",
]
+58
View File
@@ -0,0 +1,58 @@
"""Authentication metadata helpers for MXAccess Gateway clients."""
from collections.abc import Sequence
from dataclasses import dataclass
AUTHORIZATION_HEADER = "authorization"
REDACTED = "[redacted]"
@dataclass(frozen=True)
class ApiKey:
"""API key wrapper that avoids leaking the secret through repr output."""
value: str
def __post_init__(self) -> None:
if not self.value:
raise ValueError("api_key must not be empty")
def __repr__(self) -> str:
return f"{type(self).__name__}({REDACTED!r})"
def bearer_value(self) -> str:
return f"Bearer {self.value}"
def auth_metadata(api_key: str | ApiKey | None) -> tuple[tuple[str, str], ...]:
"""Return gRPC metadata for API key auth."""
if api_key is None:
return ()
key = api_key.value if isinstance(api_key, ApiKey) else api_key
if not key:
return ()
return ((AUTHORIZATION_HEADER, f"Bearer {key}"),)
def merge_metadata(
api_key: str | ApiKey | None,
metadata: Sequence[tuple[str, str]] | None = None,
) -> tuple[tuple[str, str], ...]:
"""Merge caller metadata with API key metadata."""
merged = list(metadata or ())
merged.extend(auth_metadata(api_key))
return tuple(merged)
def redact_secret(text: str, secrets: Sequence[str | None]) -> str:
"""Replace known secret values with a stable redaction marker."""
redacted = text
for secret in secrets:
if secret:
redacted = redacted.replace(secret, REDACTED)
return redacted
+165
View File
@@ -0,0 +1,165 @@
"""Async MXAccess Gateway client wrapper."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Sequence
from typing import Any
import grpc
from .auth import merge_metadata
from .errors import ensure_protocol_success, map_rpc_error
from .generated import mxaccess_gateway_pb2 as pb
from .generated import mxaccess_gateway_pb2_grpc as pb_grpc
from .options import ClientOptions, create_channel
class GatewayClient:
"""Async client for the public MXAccess Gateway gRPC API."""
def __init__(
self,
*,
options: ClientOptions,
stub: Any,
channel: grpc.aio.Channel | None = None,
) -> None:
self.options = options
self.raw_stub = stub
self._channel = channel
self._closed = False
@classmethod
async def connect(
cls,
options: ClientOptions | None = None,
*,
endpoint: str | None = None,
api_key: str | None = None,
plaintext: bool = False,
ca_file: str | None = None,
server_name_override: str | None = None,
stub: Any | None = None,
) -> "GatewayClient":
"""Create a client with either a real async channel or a supplied fake stub."""
resolved = options or ClientOptions(
endpoint=endpoint or "",
api_key=api_key,
plaintext=plaintext,
ca_file=ca_file,
server_name_override=server_name_override,
)
if stub is not None:
return cls(options=resolved, stub=stub)
channel = create_channel(resolved)
return cls(
options=resolved,
stub=pb_grpc.MxAccessGatewayStub(channel),
channel=channel,
)
async def __aenter__(self) -> "GatewayClient":
return self
async def __aexit__(self, *_exc_info: object) -> None:
await self.close()
async def close(self) -> None:
"""Close the owned gRPC channel."""
if self._closed:
return
self._closed = True
if self._channel is not None:
await self._channel.close()
async def open_session(
self,
request: pb.OpenSessionRequest | None = None,
*,
requested_backend: str = "",
client_session_name: str = "",
client_correlation_id: str = "",
) -> "Session":
"""Open a gateway session and return a high-level session wrapper."""
from .session import Session
raw_request = request or pb.OpenSessionRequest(
requested_backend=requested_backend,
client_session_name=client_session_name,
client_correlation_id=client_correlation_id,
)
reply = await self.open_session_raw(raw_request)
return Session(client=self, session_id=reply.session_id, open_reply=reply)
async def open_session_raw(self, request: pb.OpenSessionRequest) -> pb.OpenSessionReply:
reply = await self._unary("open session", self.raw_stub.OpenSession, request)
ensure_protocol_success("open session", reply.protocol_status, reply)
return reply
async def close_session_raw(
self,
request: pb.CloseSessionRequest,
) -> pb.CloseSessionReply:
reply = await self._unary("close session", self.raw_stub.CloseSession, request)
ensure_protocol_success("close session", reply.protocol_status, reply)
return reply
async def invoke_raw(self, request: pb.MxCommandRequest) -> pb.MxCommandReply:
reply = await self._unary("invoke", self.raw_stub.Invoke, request)
ensure_protocol_success("invoke", reply.protocol_status, reply)
return reply
def stream_events_raw(
self,
request: pb.StreamEventsRequest,
*,
metadata: Sequence[tuple[str, str]] | None = None,
) -> AsyncIterator[pb.MxEvent]:
"""Return an async event iterator and cancel the stream when iteration stops."""
call = self.raw_stub.StreamEvents(
request,
metadata=merge_metadata(self.options.api_key, metadata),
)
return _canceling_iterator(call)
async def _unary(
self,
operation: str,
method: Any,
request: Any,
*,
metadata: Sequence[tuple[str, str]] | None = None,
) -> Any:
call = method(
request,
metadata=merge_metadata(self.options.api_key, metadata),
)
try:
return await call
except asyncio.CancelledError:
cancel = getattr(call, "cancel", None)
if cancel is not None:
cancel()
raise
except grpc.RpcError as error:
raise map_rpc_error(operation, error) from error
async def _canceling_iterator(call: Any) -> AsyncIterator[pb.MxEvent]:
try:
async for event in call:
yield event
except grpc.RpcError as error:
raise map_rpc_error("stream events", error) from error
finally:
cancel = getattr(call, "cancel", None)
if cancel is not None:
cancel()
+157
View File
@@ -0,0 +1,157 @@
"""Typed exception model for MXAccess Gateway Python clients."""
from __future__ import annotations
from typing import Any
import grpc
from .generated import mxaccess_gateway_pb2 as pb
class MxGatewayError(Exception):
"""Base class for client wrapper errors."""
def __init__(
self,
message: str,
*,
protocol_status: pb.ProtocolStatus | None = None,
raw_reply: Any | None = None,
) -> None:
super().__init__(message)
self.protocol_status = protocol_status
self.raw_reply = raw_reply
class MxGatewayTransportError(MxGatewayError):
"""Transport-level gRPC failure."""
class MxGatewayAuthenticationError(MxGatewayTransportError):
"""Authentication failure reported by gRPC."""
class MxGatewayAuthorizationError(MxGatewayTransportError):
"""Authorization failure reported by gRPC."""
class MxGatewaySessionError(MxGatewayError):
"""Gateway session failure."""
class MxGatewayWorkerError(MxGatewayError):
"""Gateway worker process or protocol failure."""
class MxGatewayCommandError(MxGatewayError):
"""Command failure that preserves the raw protobuf reply."""
class MxAccessError(MxGatewayCommandError):
"""MXAccess HRESULT or status failure."""
def map_rpc_error(operation: str, error: grpc.RpcError) -> MxGatewayTransportError:
"""Map a generated gRPC exception to the client exception hierarchy."""
code = error.code() if hasattr(error, "code") else None
details = error.details() if hasattr(error, "details") else str(error)
message = f"{operation} failed: {details}"
if code == grpc.StatusCode.UNAUTHENTICATED:
return MxGatewayAuthenticationError(message)
if code == grpc.StatusCode.PERMISSION_DENIED:
return MxGatewayAuthorizationError(message)
return MxGatewayTransportError(message)
def ensure_protocol_success(
operation: str,
protocol_status: pb.ProtocolStatus | None,
raw_reply: Any | None = None,
) -> Any | None:
"""Raise typed gateway errors for non-OK protocol statuses."""
code = (
protocol_status.code
if protocol_status is not None
else pb.PROTOCOL_STATUS_CODE_UNSPECIFIED
)
if code in (
pb.PROTOCOL_STATUS_CODE_OK,
pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE,
):
return raw_reply
message_text = protocol_status.message if protocol_status else ""
message = f"{operation} failed: {message_text or pb.ProtocolStatusCode.Name(code)}"
if code in (
pb.PROTOCOL_STATUS_CODE_SESSION_NOT_FOUND,
pb.PROTOCOL_STATUS_CODE_SESSION_NOT_READY,
):
raise MxGatewaySessionError(
message,
protocol_status=protocol_status,
raw_reply=raw_reply,
)
if code in (
pb.PROTOCOL_STATUS_CODE_WORKER_UNAVAILABLE,
pb.PROTOCOL_STATUS_CODE_TIMEOUT,
pb.PROTOCOL_STATUS_CODE_CANCELED,
pb.PROTOCOL_STATUS_CODE_PROTOCOL_VIOLATION,
):
raise MxGatewayWorkerError(
message,
protocol_status=protocol_status,
raw_reply=raw_reply,
)
raise MxGatewayCommandError(
message,
protocol_status=protocol_status,
raw_reply=raw_reply,
)
def ensure_mxaccess_success(operation: str, reply: pb.MxCommandReply) -> pb.MxCommandReply:
"""Raise `MxAccessError` when MXAccess returned HRESULT or status failure."""
status = reply.protocol_status
if status.code == pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE:
raise MxAccessError(
_mxaccess_message(operation, reply),
protocol_status=status,
raw_reply=reply,
)
if reply.HasField("hresult") and reply.hresult < 0:
raise MxAccessError(
_mxaccess_message(operation, reply),
protocol_status=status,
raw_reply=reply,
)
for mx_status in reply.statuses:
if mx_status.success == 0:
raise MxAccessError(
_mxaccess_message(operation, reply),
protocol_status=status,
raw_reply=reply,
)
return reply
def _mxaccess_message(operation: str, reply: pb.MxCommandReply) -> str:
status_text = reply.protocol_status.message or "MXAccess command failed"
hresult = reply.hresult if reply.HasField("hresult") else None
return (
f"{operation} failed: {status_text}; "
f"session={reply.session_id}; correlation={reply.correlation_id}; "
f"hresult={hresult}; statuses={len(reply.statuses)}"
)
+59
View File
@@ -0,0 +1,59 @@
"""Client connection options for the async Python wrapper."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import grpc
from .auth import REDACTED, ApiKey
@dataclass(frozen=True)
class ClientOptions:
"""Connection settings for `GatewayClient.connect`."""
endpoint: str
api_key: str | ApiKey | None = None
plaintext: bool = False
ca_file: str | None = None
server_name_override: str | None = None
def __post_init__(self) -> None:
if not self.endpoint:
raise ValueError("endpoint must not be empty")
if self.plaintext and self.ca_file:
raise ValueError("ca_file cannot be used with plaintext connections")
def __repr__(self) -> str:
api_key = REDACTED if self.api_key else None
return (
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
f"ca_file={self.ca_file!r}, "
f"server_name_override={self.server_name_override!r})"
)
def create_channel(options: ClientOptions) -> grpc.aio.Channel:
"""Create a plaintext or TLS `grpc.aio` channel from client options."""
channel_options: list[tuple[str, str]] = []
if options.server_name_override:
channel_options.append(("grpc.ssl_target_name_override", options.server_name_override))
if options.plaintext:
return grpc.aio.insecure_channel(options.endpoint, options=channel_options)
root_certificates = None
if options.ca_file:
root_certificates = Path(options.ca_file).read_bytes()
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
return grpc.aio.secure_channel(
options.endpoint,
credentials,
options=channel_options,
)
+209
View File
@@ -0,0 +1,209 @@
"""Async session wrapper for MXAccess Gateway commands."""
from __future__ import annotations
from collections.abc import AsyncIterator
from .errors import ensure_mxaccess_success
from .generated import mxaccess_gateway_pb2 as pb
from .values import MxValueInput, to_mx_value
class Session:
"""A single gateway-backed MXAccess session."""
def __init__(
self,
*,
client: "GatewayClient",
session_id: str,
open_reply: pb.OpenSessionReply | None = None,
) -> None:
self.client = client
self.session_id = session_id
self.open_reply = open_reply
self._closed = False
async def __aenter__(self) -> "Session":
return self
async def __aexit__(self, *_exc_info: object) -> None:
await self.close()
async def close(self, *, client_correlation_id: str = "") -> pb.CloseSessionReply:
"""Close the gateway session. Repeated calls return a local closed reply."""
if self._closed:
return pb.CloseSessionReply(
session_id=self.session_id,
final_state=pb.SESSION_STATE_CLOSED,
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
)
self._closed = True
return await self.client.close_session_raw(
pb.CloseSessionRequest(
session_id=self.session_id,
client_correlation_id=client_correlation_id,
),
)
async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply:
"""Invoke a raw command and enforce gateway and MXAccess success."""
reply = await self.invoke_raw(command, correlation_id=correlation_id)
return ensure_mxaccess_success("invoke", reply)
async def invoke_raw(
self,
command: pb.MxCommand,
*,
correlation_id: str = "",
) -> pb.MxCommandReply:
"""Invoke a raw command and preserve the raw reply."""
return await self.client.invoke_raw(
pb.MxCommandRequest(
session_id=self.session_id,
client_correlation_id=correlation_id,
command=command,
),
)
async def register(self, client_name: str, *, correlation_id: str = "") -> int:
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_REGISTER,
register=pb.RegisterCommand(client_name=client_name),
),
correlation_id=correlation_id,
)
return reply.register.server_handle
async def unregister(self, server_handle: int, *, correlation_id: str = "") -> None:
await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_UNREGISTER,
unregister=pb.UnregisterCommand(server_handle=server_handle),
),
correlation_id=correlation_id,
)
async def add_item(
self,
server_handle: int,
item_definition: str,
*,
correlation_id: str = "",
) -> int:
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_ADD_ITEM,
add_item=pb.AddItemCommand(
server_handle=server_handle,
item_definition=item_definition,
),
),
correlation_id=correlation_id,
)
return reply.add_item.item_handle
async def add_item2(
self,
server_handle: int,
item_definition: str,
item_context: str,
*,
correlation_id: str = "",
) -> int:
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_ADD_ITEM2,
add_item2=pb.AddItem2Command(
server_handle=server_handle,
item_definition=item_definition,
item_context=item_context,
),
),
correlation_id=correlation_id,
)
return reply.add_item2.item_handle
async def advise(
self,
server_handle: int,
item_handle: int,
*,
correlation_id: str = "",
) -> None:
await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_ADVISE,
advise=pb.AdviseCommand(
server_handle=server_handle,
item_handle=item_handle,
),
),
correlation_id=correlation_id,
)
async def write(
self,
server_handle: int,
item_handle: int,
value: MxValueInput,
*,
user_id: int = 0,
correlation_id: str = "",
) -> None:
await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_WRITE,
write=pb.WriteCommand(
server_handle=server_handle,
item_handle=item_handle,
value=to_mx_value(value),
user_id=user_id,
),
),
correlation_id=correlation_id,
)
async def write2(
self,
server_handle: int,
item_handle: int,
value: MxValueInput,
timestamp_value: MxValueInput,
*,
user_id: int = 0,
correlation_id: str = "",
) -> None:
await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_WRITE2,
write2=pb.Write2Command(
server_handle=server_handle,
item_handle=item_handle,
value=to_mx_value(value),
timestamp_value=to_mx_value(timestamp_value),
user_id=user_id,
),
),
correlation_id=correlation_id,
)
def stream_events(
self,
*,
after_worker_sequence: int = 0,
) -> AsyncIterator[pb.MxEvent]:
return self.client.stream_events_raw(
pb.StreamEventsRequest(
session_id=self.session_id,
after_worker_sequence=after_worker_sequence,
),
)
from .client import GatewayClient # noqa: E402
+234
View File
@@ -0,0 +1,234 @@
"""MXAccess value conversion helpers."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from google.protobuf.timestamp_pb2 import Timestamp
from .generated import mxaccess_gateway_pb2 as pb
MxValueInput = bool | int | float | str | datetime | bytes | None | Sequence[Any]
@dataclass(frozen=True)
class MxValueView:
"""Typed projection of a raw `MxValue` protobuf message."""
value: Any
kind: str
raw: pb.MxValue
def to_mx_value(value: MxValueInput, *, data_type: str | None = None) -> pb.MxValue:
"""Convert a Python value into the public protobuf `MxValue` union."""
if isinstance(value, pb.MxValue):
return value
if value is None:
return pb.MxValue(
data_type=pb.MX_DATA_TYPE_NO_DATA,
variant_type="VT_EMPTY",
is_null=True,
raw_data_type=pb.MX_DATA_TYPE_NO_DATA,
)
if isinstance(value, bool):
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_BOOLEAN),
variant_type="VT_BOOL",
bool_value=value,
)
if isinstance(value, int):
if -(2**31) <= value <= (2**31 - 1):
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_INTEGER),
variant_type="VT_I4",
int32_value=value,
)
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_INTEGER),
variant_type="VT_I8",
int64_value=value,
)
if isinstance(value, float):
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_DOUBLE),
variant_type="VT_R8",
double_value=value,
)
if isinstance(value, str):
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_STRING),
variant_type="VT_BSTR",
string_value=value,
)
if isinstance(value, datetime):
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_TIME),
variant_type="VT_DATE",
timestamp_value=_timestamp_from_datetime(value),
)
if isinstance(value, bytes):
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_UNKNOWN),
variant_type="VT_RECORD",
raw_value=value,
)
if isinstance(value, Sequence):
return _sequence_to_mx_value(value, data_type=data_type)
raise TypeError(f"unsupported MxValue input type: {type(value).__name__}")
def from_mx_value(value: pb.MxValue) -> MxValueView:
"""Project a protobuf `MxValue` into an idiomatic Python value."""
kind = value.WhichOneof("kind")
if kind is None:
return MxValueView(None, "none", value)
if kind == "timestamp_value":
return MxValueView(
value.timestamp_value.ToDatetime().replace(tzinfo=timezone.utc),
kind,
value,
)
if kind == "array_value":
return MxValueView(from_mx_array(value.array_value), kind, value)
return MxValueView(getattr(value, kind), kind, value)
def from_mx_array(array: pb.MxArray) -> list[Any]:
"""Project a protobuf `MxArray` into a Python list."""
kind = array.WhichOneof("values")
if kind is None:
return []
values = list(getattr(array, kind).values)
if kind == "timestamp_values":
return [
timestamp.ToDatetime().replace(tzinfo=timezone.utc)
for timestamp in values
]
return values
def _sequence_to_mx_value(
values: Sequence[Any],
*,
data_type: str | None,
) -> pb.MxValue:
sequence = list(values)
if not sequence:
return pb.MxValue(
data_type=_data_type(data_type, pb.MX_DATA_TYPE_UNKNOWN),
array_value=pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_UNKNOWN,
dimensions=[0],
),
)
first = sequence[0]
dimensions = [len(sequence)]
if all(isinstance(item, bool) for item in sequence):
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_BOOLEAN,
variant_type="VT_ARRAY|VT_BOOL",
dimensions=dimensions,
bool_values=pb.BoolArray(values=sequence),
)
return pb.MxValue(data_type=pb.MX_DATA_TYPE_BOOLEAN, array_value=array)
if all(isinstance(item, int) and not isinstance(item, bool) for item in sequence):
use_int32 = all(-(2**31) <= item <= (2**31 - 1) for item in sequence)
if use_int32:
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_INTEGER,
variant_type="VT_ARRAY|VT_I4",
dimensions=dimensions,
int32_values=pb.Int32Array(values=sequence),
)
else:
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_INTEGER,
variant_type="VT_ARRAY|VT_I8",
dimensions=dimensions,
int64_values=pb.Int64Array(values=sequence),
)
return pb.MxValue(data_type=pb.MX_DATA_TYPE_INTEGER, array_value=array)
if all(isinstance(item, float) for item in sequence):
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_DOUBLE,
variant_type="VT_ARRAY|VT_R8",
dimensions=dimensions,
double_values=pb.DoubleArray(values=sequence),
)
return pb.MxValue(data_type=pb.MX_DATA_TYPE_DOUBLE, array_value=array)
if all(isinstance(item, str) for item in sequence):
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_STRING,
variant_type="VT_ARRAY|VT_BSTR",
dimensions=dimensions,
string_values=pb.StringArray(values=sequence),
)
return pb.MxValue(data_type=pb.MX_DATA_TYPE_STRING, array_value=array)
if all(isinstance(item, datetime) for item in sequence):
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_TIME,
variant_type="VT_ARRAY|VT_DATE",
dimensions=dimensions,
timestamp_values=pb.TimestampArray(
values=[_timestamp_from_datetime(item) for item in sequence],
),
)
return pb.MxValue(data_type=pb.MX_DATA_TYPE_TIME, array_value=array)
if all(isinstance(item, bytes) for item in sequence):
array = pb.MxArray(
element_data_type=pb.MX_DATA_TYPE_UNKNOWN,
variant_type="VT_ARRAY|VT_VARIANT",
dimensions=dimensions,
raw_values=pb.RawArray(values=sequence),
)
return pb.MxValue(data_type=pb.MX_DATA_TYPE_UNKNOWN, array_value=array)
raise TypeError(
"MxValue array inputs must use one supported element type; "
f"first element was {type(first).__name__}"
)
def _timestamp_from_datetime(value: datetime) -> Timestamp:
timestamp = Timestamp()
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
timestamp.FromDatetime(value.astimezone(timezone.utc))
return timestamp
def _data_type(name: str | None, default: int) -> int:
if name is None:
return default
return pb.MxDataType.Value(name)
+437 -2
View File
@@ -1,10 +1,24 @@
"""CLI scaffold for the MXAccess Gateway Python client."""
"""Command line interface for the MXAccess Gateway Python client."""
from __future__ import annotations
import asyncio
import json
import os
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any
import click
from google.protobuf.json_format import MessageToDict
from mxgateway import __version__
from mxgateway.auth import redact_secret
from mxgateway.client import GatewayClient
from mxgateway.errors import MxGatewayError
from mxgateway.generated import mxaccess_gateway_pb2 as pb
from mxgateway.options import ClientOptions
from mxgateway.values import MxValueInput
@click.group()
@@ -16,14 +30,435 @@ def main() -> None:
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def version(output_json: bool) -> None:
"""Print client package version information."""
payload = {
"client": "mxgw-py",
"package": "mxaccess-gateway-client",
"version": __version__,
}
_emit(payload, output_json=output_json, text=f"mxgw-py {__version__}")
def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]:
command = click.option("--endpoint", default="localhost:5000", show_default=True)(command)
command = click.option("--api-key", default=None, help="Gateway API key.")(command)
command = click.option(
"--api-key-env",
default=None,
help="Environment variable containing the gateway API key.",
)(command)
command = click.option("--plaintext", is_flag=True, help="Use plaintext gRPC.")(command)
command = click.option("--tls", "use_tls", is_flag=True, help="Use TLS gRPC.")(command)
command = click.option("--ca-file", default=None, help="Custom root certificate file.")(command)
command = click.option(
"--server-name-override",
default=None,
help="TLS server name override for test environments.",
)(command)
return command
@main.command("open-session")
@gateway_options
@click.option("--client-name", default="", help="Client session name.")
@click.option("--requested-backend", default="", help="Requested backend name.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def open_session(**kwargs: Any) -> None:
"""Open a gateway session."""
_run(
_open_session(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command("close-session")
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def close_session(**kwargs: Any) -> None:
"""Close a gateway session."""
_run(
_close_session(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command()
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--message", default="ping", show_default=True, help="Ping payload.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def ping(**kwargs: Any) -> None:
"""Send a diagnostic ping command."""
_run(_ping(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
@main.command()
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--client-name", required=True, help="MXAccess client name.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def register(**kwargs: Any) -> None:
"""Invoke MXAccess Register."""
_run(
_register(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command("add-item")
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
@click.option("--item", required=True, help="MXAccess item definition.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def add_item(**kwargs: Any) -> None:
"""Invoke MXAccess AddItem."""
_run(
_add_item(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command()
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def advise(**kwargs: Any) -> None:
"""Invoke MXAccess Advise."""
_run(_advise(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
@main.command("stream-events")
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--after-worker-sequence", default=0, type=int, show_default=True)
@click.option("--max-events", default=1, type=int, show_default=True)
@click.option("--timeout", default=5.0, type=float, show_default=True)
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def stream_events(**kwargs: Any) -> None:
"""Stream a bounded number of events."""
_run(
_stream_events(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command()
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.")
@click.option("--type", "value_type", default="string", show_default=True)
@click.option("--value", required=True, help="Value to write.")
@click.option("--user-id", default=0, type=int, show_default=True)
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def write(**kwargs: Any) -> None:
"""Invoke MXAccess Write."""
_run(_write(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
@main.command()
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.")
@click.option("--type", "value_type", default="string", show_default=True)
@click.option("--value", required=True, help="Value to write.")
@click.option("--timestamp", required=True, help="ISO-8601 timestamp value.")
@click.option("--user-id", default=0, type=int, show_default=True)
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def write2(**kwargs: Any) -> None:
"""Invoke MXAccess Write2."""
_run(_write2(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
@main.command()
@gateway_options
@click.option("--client-name", default="mxgw-py-smoke", show_default=True)
@click.option("--item", required=True, help="MXAccess item definition.")
@click.option("--max-events", default=1, type=int, show_default=True)
@click.option("--timeout", default=5.0, type=float, show_default=True)
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def smoke(**kwargs: Any) -> None:
"""Run a bounded open/register/add/advise/stream/close smoke flow."""
_run(_smoke(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
async def _open_session(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
reply = await client.open_session_raw(
pb.OpenSessionRequest(
requested_backend=kwargs["requested_backend"],
client_session_name=kwargs["client_name"],
client_correlation_id=kwargs["correlation_id"],
),
)
return {"sessionId": reply.session_id, "rawReply": _message_dict(reply)}
async def _close_session(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
reply = await client.close_session_raw(
pb.CloseSessionRequest(
session_id=kwargs["session_id"],
client_correlation_id=kwargs["correlation_id"],
),
)
return {"sessionId": reply.session_id, "rawReply": _message_dict(reply)}
async def _ping(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
reply = await client.invoke_raw(
pb.MxCommandRequest(
session_id=kwargs["session_id"],
command=pb.MxCommand(
kind=pb.MX_COMMAND_KIND_PING,
ping=pb.PingCommand(message=kwargs["message"]),
),
),
)
return {"kind": "ping", "rawReply": _message_dict(reply)}
async def _register(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
server_handle = await session.register(
kwargs["client_name"],
correlation_id=kwargs["correlation_id"],
)
return {"serverHandle": server_handle}
async def _add_item(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
item_handle = await session.add_item(
kwargs["server_handle"],
kwargs["item"],
correlation_id=kwargs["correlation_id"],
)
return {"itemHandle": item_handle}
async def _advise(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
await session.advise(
kwargs["server_handle"],
kwargs["item_handle"],
correlation_id=kwargs["correlation_id"],
)
return {"ok": True}
async def _stream_events(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
events = await _collect_events(
session.stream_events(after_worker_sequence=kwargs["after_worker_sequence"]),
max_events=kwargs["max_events"],
timeout=kwargs["timeout"],
)
return {"events": [_message_dict(event) for event in events]}
async def _write(**kwargs: Any) -> dict[str, Any]:
value = _parse_value(kwargs["value"], kwargs["value_type"])
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
await session.write(
kwargs["server_handle"],
kwargs["item_handle"],
value,
user_id=kwargs["user_id"],
correlation_id=kwargs["correlation_id"],
)
return {"ok": True}
async def _write2(**kwargs: Any) -> dict[str, Any]:
value = _parse_value(kwargs["value"], kwargs["value_type"])
timestamp = _parse_datetime(kwargs["timestamp"])
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
await session.write2(
kwargs["server_handle"],
kwargs["item_handle"],
value,
timestamp,
user_id=kwargs["user_id"],
correlation_id=kwargs["correlation_id"],
)
return {"ok": True}
async def _smoke(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = await client.open_session(client_session_name=kwargs["client_name"])
closed = False
try:
server_handle = await session.register(kwargs["client_name"])
item_handle = await session.add_item(server_handle, kwargs["item"])
await session.advise(server_handle, item_handle)
events = await _collect_events(
session.stream_events(),
max_events=kwargs["max_events"],
timeout=kwargs["timeout"],
)
return {
"sessionId": session.session_id,
"serverHandle": server_handle,
"itemHandle": item_handle,
"events": [_message_dict(event) for event in events],
}
finally:
if not closed:
await session.close()
async def _connect(kwargs: dict[str, Any]) -> GatewayClient:
api_key = kwargs.get("api_key") or _api_key_from_env(kwargs.get("api_key_env"))
return await GatewayClient.connect(
ClientOptions(
endpoint=kwargs["endpoint"],
api_key=api_key,
plaintext=_use_plaintext(kwargs),
ca_file=kwargs.get("ca_file"),
server_name_override=kwargs.get("server_name_override"),
),
)
def _session(client: GatewayClient, session_id: str):
from mxgateway.session import Session
return Session(client=client, session_id=session_id)
def _use_plaintext(kwargs: dict[str, Any]) -> bool:
if kwargs.get("use_tls"):
return False
if kwargs.get("plaintext"):
return True
return kwargs["endpoint"].startswith("localhost:") or kwargs["endpoint"].startswith("127.0.0.1:")
def _api_key_from_env(name: str | None) -> str | None:
if not name:
return None
return os.environ.get(name)
def _secrets(kwargs: dict[str, Any]) -> list[str | None]:
return [
kwargs.get("api_key"),
_api_key_from_env(kwargs.get("api_key_env")),
]
def _run(
awaitable: Awaitable[dict[str, Any]],
*,
output_json: bool,
secrets: list[str | None],
) -> None:
try:
payload = asyncio.run(awaitable)
except MxGatewayError as error:
raise click.ClickException(redact_secret(str(error), secrets)) from error
_emit(payload, output_json=output_json)
def _emit(
payload: dict[str, Any],
*,
output_json: bool,
text: str | None = None,
) -> None:
if output_json:
click.echo(json.dumps(payload, sort_keys=True))
return
click.echo(f"mxgw-py {__version__}")
click.echo(text or json.dumps(payload, sort_keys=True))
async def _collect_events(
events: Any,
*,
max_events: int,
timeout: float,
) -> list[pb.MxEvent]:
collected: list[pb.MxEvent] = []
iterator = events.__aiter__()
try:
while len(collected) < max_events:
collected.append(await asyncio.wait_for(iterator.__anext__(), timeout=timeout))
except StopAsyncIteration:
pass
finally:
close = getattr(iterator, "aclose", None)
if close is not None:
await close()
return collected
def _parse_value(raw_value: str, value_type: str) -> MxValueInput:
normalized = value_type.lower()
if normalized == "bool":
return raw_value.lower() in ("1", "true", "yes", "on")
if normalized in ("int", "int32", "int64"):
return int(raw_value)
if normalized in ("float", "double"):
return float(raw_value)
if normalized in ("time", "timestamp"):
return _parse_datetime(raw_value)
if normalized == "raw":
return raw_value.encode("utf-8")
if normalized == "string":
return raw_value
raise click.BadParameter(f"unsupported value type: {value_type}", param_hint="--type")
def _parse_datetime(raw_value: str) -> datetime:
if raw_value.endswith("Z"):
raw_value = raw_value[:-1] + "+00:00"
parsed = datetime.fromisoformat(raw_value)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed
def _message_dict(message: Any) -> dict[str, Any]:
return MessageToDict(
message,
preserving_proto_field_name=False,
use_integers_for_enums=False,
)
+103
View File
@@ -0,0 +1,103 @@
"""Tests for auth metadata and connection options."""
import pytest
from mxgateway.auth import REDACTED, ApiKey, auth_metadata, redact_secret
from mxgateway import options as options_module
from mxgateway.options import ClientOptions, create_channel
def test_auth_metadata_adds_bearer_api_key() -> None:
assert auth_metadata("mxgw_test_secret") == (
("authorization", "Bearer mxgw_test_secret"),
)
def test_api_key_repr_is_redacted() -> None:
api_key = ApiKey("mxgw_test_secret")
assert "mxgw_test_secret" not in repr(api_key)
assert REDACTED in repr(api_key)
def test_redact_secret_replaces_known_values() -> None:
redacted = redact_secret(
"authorization failed for mxgw_test_secret",
["mxgw_test_secret"],
)
assert redacted == f"authorization failed for {REDACTED}"
def test_client_options_reject_plaintext_with_ca_file() -> None:
with pytest.raises(ValueError, match="ca_file"):
ClientOptions(
endpoint="localhost:5000",
plaintext=True,
ca_file="ca.pem",
)
def test_client_options_repr_redacts_api_key() -> None:
options = ClientOptions(endpoint="localhost:5000", api_key="mxgw_test_secret")
assert "mxgw_test_secret" not in repr(options)
assert REDACTED in repr(options)
def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[tuple[str, object]] = []
def fake_insecure_channel(endpoint: str, *, options: object) -> str:
calls.append((endpoint, options))
return "plain-channel"
monkeypatch.setattr(
options_module.grpc.aio,
"insecure_channel",
fake_insecure_channel,
)
channel = create_channel(ClientOptions(endpoint="localhost:5000", plaintext=True))
assert channel == "plain-channel"
assert calls == [("localhost:5000", [])]
def test_create_channel_uses_tls_channel(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[tuple[str, object, object]] = []
def fake_credentials(*, root_certificates: object) -> str:
assert root_certificates is None
return "creds"
def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
calls.append((endpoint, credentials, options))
return "tls-channel"
monkeypatch.setattr(
options_module.grpc,
"ssl_channel_credentials",
fake_credentials,
)
monkeypatch.setattr(
options_module.grpc.aio,
"secure_channel",
fake_secure_channel,
)
channel = create_channel(
ClientOptions(
endpoint="gateway.example:5001",
server_name_override="gateway.test",
),
)
assert channel == "tls-channel"
assert calls == [
(
"gateway.example:5001",
"creds",
[("grpc.ssl_target_name_override", "gateway.test")],
),
]
+48 -1
View File
@@ -1,4 +1,4 @@
"""Tests for the Python CLI scaffold."""
"""Tests for the Python CLI."""
import json
@@ -19,3 +19,50 @@ def test_version_json_is_deterministic() -> None:
"package": "mxaccess-gateway-client",
"version": __version__,
}
def test_write_parser_rejects_unknown_value_type() -> None:
runner = CliRunner()
result = runner.invoke(
main,
[
"write",
"--session-id",
"session-1",
"--server-handle",
"12",
"--item-handle",
"34",
"--type",
"unsupported",
"--value",
"123",
"--api-key",
"mxgw_test_secret",
"--json",
],
)
assert result.exit_code != 0
assert "unsupported value type" in result.output
def test_cli_error_output_redacts_api_key() -> None:
runner = CliRunner()
result = runner.invoke(
main,
[
"open-session",
"--endpoint",
"127.0.0.1:1",
"--api-key",
"mxgw_test_secret",
"--plaintext",
"--json",
],
)
assert result.exit_code != 0
assert "mxgw_test_secret" not in result.output
+225
View File
@@ -0,0 +1,225 @@
"""Tests for the async client and session wrappers."""
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from mxgateway import ClientOptions, GatewayClient, MxAccessError
from mxgateway.generated import mxaccess_gateway_pb2 as pb
@pytest.mark.asyncio
async def test_session_helpers_send_auth_metadata_and_preserve_raw_replies() -> None:
stub = FakeGatewayStub()
client = await GatewayClient.connect(
ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
stub=stub,
)
session = await client.open_session(client_session_name="pytest")
server_handle = await session.register("pytest-client")
item_handle = await session.add_item(server_handle, "Object.Attribute")
await session.advise(server_handle, item_handle)
assert session.session_id == "session-1"
assert server_handle == 12
assert item_handle == 34
assert stub.open_session.metadata == (("authorization", "Bearer mxgw_test_secret"),)
assert stub.invoke.requests[0].command.register.client_name == "pytest-client"
assert stub.invoke.requests[1].command.add_item.item_definition == "Object.Attribute"
assert stub.invoke.requests[2].command.advise.item_handle == 34
@pytest.mark.asyncio
async def test_mxaccess_error_preserves_raw_reply() -> None:
stub = FakeGatewayStub()
failure_reply = pb.MxCommandReply(
session_id="session-1",
kind=pb.MX_COMMAND_KIND_WRITE,
protocol_status=pb.ProtocolStatus(
code=pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE,
message="MXAccess rejected write.",
),
hresult=-1,
)
stub.invoke.replies = [failure_reply]
client = await GatewayClient.connect(
ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
stub=stub,
)
session = await client.open_session()
with pytest.raises(MxAccessError) as captured:
await session.write(12, 34, 123)
assert captured.value.raw_reply is failure_reply
@pytest.mark.asyncio
async def test_stream_events_cancels_underlying_call_when_closed() -> None:
stream = FakeStream(
[
pb.MxEvent(
session_id="session-1",
worker_sequence=1,
family=pb.MX_EVENT_FAMILY_ON_DATA_CHANGE,
),
],
)
stub = FakeGatewayStub(stream=stream)
client = await GatewayClient.connect(
ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
stub=stub,
)
session = await client.open_session()
events = session.stream_events()
first = await anext(events)
await events.aclose()
assert first.worker_sequence == 1
assert stream.cancelled
assert stub.stream_metadata == (("authorization", "Bearer mxgw_test_secret"),)
@pytest.mark.asyncio
async def test_unary_task_cancellation_reaches_fake_call() -> None:
blocking = BlockingCancellableUnary()
stub = FakeGatewayStub()
stub.OpenSession = blocking
client = await GatewayClient.connect(
ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True),
stub=stub,
)
task = asyncio.create_task(client.open_session())
await blocking.started.wait()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
assert blocking.call is not None
assert blocking.call.cancelled
class FakeGatewayStub:
def __init__(self, stream: "FakeStream | None" = None) -> None:
self.open_session = FakeUnary(
[
pb.OpenSessionReply(
session_id="session-1",
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
),
],
)
self.close_session = FakeUnary(
[
pb.CloseSessionReply(
session_id="session-1",
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
),
],
)
self.invoke = FakeUnary(
[
pb.MxCommandReply(
session_id="session-1",
kind=pb.MX_COMMAND_KIND_REGISTER,
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
register=pb.RegisterReply(server_handle=12),
),
pb.MxCommandReply(
session_id="session-1",
kind=pb.MX_COMMAND_KIND_ADD_ITEM,
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
add_item=pb.AddItemReply(item_handle=34),
),
pb.MxCommandReply(
session_id="session-1",
kind=pb.MX_COMMAND_KIND_ADVISE,
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
),
],
)
self.OpenSession = self.open_session
self.CloseSession = self.close_session
self.Invoke = self.invoke
self._stream = stream or FakeStream([])
self.stream_metadata: tuple[tuple[str, str], ...] | None = None
def StreamEvents(
self,
request: pb.StreamEventsRequest,
*,
metadata: tuple[tuple[str, str], ...],
) -> "FakeStream":
self.stream_request = request
self.stream_metadata = metadata
return self._stream
class FakeUnary:
def __init__(self, replies: list[Any]) -> None:
self.replies = replies
self.requests: list[Any] = []
self.metadata: tuple[tuple[str, str], ...] | None = None
async def __call__(
self,
request: Any,
*,
metadata: tuple[tuple[str, str], ...],
) -> Any:
self.requests.append(request)
self.metadata = metadata
return self.replies.pop(0)
class BlockingCancellableUnary:
def __init__(self) -> None:
self.started = asyncio.Event()
self.call: BlockingCall | None = None
def __call__(self, *_args: Any, **_kwargs: Any) -> "BlockingCall":
self.call = BlockingCall(self.started)
return self.call
class BlockingCall:
def __init__(self, started: asyncio.Event) -> None:
self.started = started
self.cancelled = False
def __await__(self):
return self._wait().__await__()
async def _wait(self) -> Any:
self.started.set()
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
raise
def cancel(self) -> None:
self.cancelled = True
class FakeStream:
def __init__(self, events: list[pb.MxEvent]) -> None:
self._events = events
self.cancelled = False
def __aiter__(self) -> "FakeStream":
return self
async def __anext__(self) -> pb.MxEvent:
if not self._events:
await asyncio.sleep(3600)
return self._events.pop(0)
def cancel(self) -> None:
self.cancelled = True
+49
View File
@@ -0,0 +1,49 @@
"""Tests for typed command error mapping."""
import json
from pathlib import Path
import pytest
from google.protobuf.json_format import ParseDict
from mxgateway.errors import ensure_mxaccess_success, ensure_protocol_success
from mxgateway import MxAccessError, MxGatewaySessionError
from mxgateway.generated import mxaccess_gateway_pb2 as pb
FIXTURE_ROOT = Path(__file__).resolve().parents[2] / "proto" / "fixtures" / "behavior"
def test_register_fixture_is_protocol_and_mxaccess_success() -> None:
reply = _load_reply("command-replies/register.ok.reply.json")
assert ensure_protocol_success("register", reply.protocol_status, reply) is reply
assert ensure_mxaccess_success("register", reply) is reply
def test_write_failure_fixture_preserves_raw_reply() -> None:
reply = _load_reply("command-replies/write.mxaccess-failure.reply.json")
assert ensure_protocol_success("write", reply.protocol_status, reply) is reply
with pytest.raises(MxAccessError) as captured:
ensure_mxaccess_success("write", reply)
assert captured.value.raw_reply is reply
assert captured.value.raw_reply.hresult == -2147220992
assert len(captured.value.raw_reply.statuses) == 2
def test_session_status_maps_to_session_error() -> None:
status = pb.ProtocolStatus(
code=pb.PROTOCOL_STATUS_CODE_SESSION_NOT_FOUND,
message="session missing",
)
with pytest.raises(MxGatewaySessionError) as captured:
ensure_protocol_success("invoke", status)
assert captured.value.protocol_status is status
def _load_reply(name: str) -> pb.MxCommandReply:
payload = json.loads((FIXTURE_ROOT / name).read_text(encoding="utf-8"))
return ParseDict(payload, pb.MxCommandReply())
+49
View File
@@ -0,0 +1,49 @@
"""Tests for MXAccess value conversion helpers."""
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from google.protobuf.json_format import ParseDict
from mxgateway.generated import mxaccess_gateway_pb2 as pb
from mxgateway.values import from_mx_value, to_mx_value
FIXTURE_ROOT = Path(__file__).resolve().parents[2] / "proto" / "fixtures" / "behavior"
def test_value_conversion_fixtures_project_expected_oneof_kind() -> None:
payload = json.loads(
(FIXTURE_ROOT / "values" / "value-conversion-cases.json").read_text(
encoding="utf-8",
),
)
for case in payload["cases"]:
value = ParseDict(case["value"], pb.MxValue())
projection = from_mx_value(value)
assert projection.kind == _snake_case(case["expectedKind"])
assert projection.raw is value
def test_to_mx_value_supports_scalar_and_array_inputs() -> None:
assert to_mx_value(True).WhichOneof("kind") == "bool_value"
assert to_mx_value(12).int32_value == 12
assert to_mx_value(2**40).int64_value == 2**40
assert to_mx_value(12.5).double_value == 12.5
assert to_mx_value("abc").string_value == "abc"
assert to_mx_value([1, 2]).array_value.int32_values.values == [1, 2]
assert to_mx_value(["a", "b"]).array_value.string_values.values == ["a", "b"]
def test_to_mx_value_uses_utc_timestamps() -> None:
value = to_mx_value(datetime(2026, 1, 1, 0, 0, 4, tzinfo=timezone.utc))
assert value.data_type == pb.MX_DATA_TYPE_TIME
assert value.timestamp_value.seconds == 1767225604
def _snake_case(value: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", value).lower()