Issue #46: implement Python async client values errors and CLI
This commit is contained in:
@@ -1,5 +1,38 @@
|
||||
"""MXAccess Gateway Python client package."""
|
||||
|
||||
from .auth import ApiKey, auth_metadata
|
||||
from .client import GatewayClient
|
||||
from .errors import (
|
||||
MxAccessError,
|
||||
MxGatewayAuthenticationError,
|
||||
MxGatewayAuthorizationError,
|
||||
MxGatewayCommandError,
|
||||
MxGatewayError,
|
||||
MxGatewaySessionError,
|
||||
MxGatewayTransportError,
|
||||
MxGatewayWorkerError,
|
||||
)
|
||||
from .options import ClientOptions
|
||||
from .session import Session
|
||||
from .values import MxValueView, from_mx_value, to_mx_value
|
||||
from .version import __version__
|
||||
|
||||
__all__ = ["__version__"]
|
||||
__all__ = [
|
||||
"ApiKey",
|
||||
"ClientOptions",
|
||||
"GatewayClient",
|
||||
"MxAccessError",
|
||||
"MxGatewayAuthenticationError",
|
||||
"MxGatewayAuthorizationError",
|
||||
"MxGatewayCommandError",
|
||||
"MxGatewayError",
|
||||
"MxGatewaySessionError",
|
||||
"MxGatewayTransportError",
|
||||
"MxGatewayWorkerError",
|
||||
"MxValueView",
|
||||
"Session",
|
||||
"__version__",
|
||||
"auth_metadata",
|
||||
"from_mx_value",
|
||||
"to_mx_value",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Authentication metadata helpers for MXAccess Gateway clients."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
AUTHORIZATION_HEADER = "authorization"
|
||||
REDACTED = "[redacted]"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ApiKey:
|
||||
"""API key wrapper that avoids leaking the secret through repr output."""
|
||||
|
||||
value: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.value:
|
||||
raise ValueError("api_key must not be empty")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}({REDACTED!r})"
|
||||
|
||||
def bearer_value(self) -> str:
|
||||
return f"Bearer {self.value}"
|
||||
|
||||
|
||||
def auth_metadata(api_key: str | ApiKey | None) -> tuple[tuple[str, str], ...]:
|
||||
"""Return gRPC metadata for API key auth."""
|
||||
|
||||
if api_key is None:
|
||||
return ()
|
||||
|
||||
key = api_key.value if isinstance(api_key, ApiKey) else api_key
|
||||
if not key:
|
||||
return ()
|
||||
|
||||
return ((AUTHORIZATION_HEADER, f"Bearer {key}"),)
|
||||
|
||||
|
||||
def merge_metadata(
|
||||
api_key: str | ApiKey | None,
|
||||
metadata: Sequence[tuple[str, str]] | None = None,
|
||||
) -> tuple[tuple[str, str], ...]:
|
||||
"""Merge caller metadata with API key metadata."""
|
||||
|
||||
merged = list(metadata or ())
|
||||
merged.extend(auth_metadata(api_key))
|
||||
return tuple(merged)
|
||||
|
||||
|
||||
def redact_secret(text: str, secrets: Sequence[str | None]) -> str:
|
||||
"""Replace known secret values with a stable redaction marker."""
|
||||
|
||||
redacted = text
|
||||
for secret in secrets:
|
||||
if secret:
|
||||
redacted = redacted.replace(secret, REDACTED)
|
||||
return redacted
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Async MXAccess Gateway client wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from typing import Any
|
||||
|
||||
import grpc
|
||||
|
||||
from .auth import merge_metadata
|
||||
from .errors import ensure_protocol_success, map_rpc_error
|
||||
from .generated import mxaccess_gateway_pb2 as pb
|
||||
from .generated import mxaccess_gateway_pb2_grpc as pb_grpc
|
||||
from .options import ClientOptions, create_channel
|
||||
|
||||
|
||||
class GatewayClient:
|
||||
"""Async client for the public MXAccess Gateway gRPC API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
options: ClientOptions,
|
||||
stub: Any,
|
||||
channel: grpc.aio.Channel | None = None,
|
||||
) -> None:
|
||||
self.options = options
|
||||
self.raw_stub = stub
|
||||
self._channel = channel
|
||||
self._closed = False
|
||||
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls,
|
||||
options: ClientOptions | None = None,
|
||||
*,
|
||||
endpoint: str | None = None,
|
||||
api_key: str | None = None,
|
||||
plaintext: bool = False,
|
||||
ca_file: str | None = None,
|
||||
server_name_override: str | None = None,
|
||||
stub: Any | None = None,
|
||||
) -> "GatewayClient":
|
||||
"""Create a client with either a real async channel or a supplied fake stub."""
|
||||
|
||||
resolved = options or ClientOptions(
|
||||
endpoint=endpoint or "",
|
||||
api_key=api_key,
|
||||
plaintext=plaintext,
|
||||
ca_file=ca_file,
|
||||
server_name_override=server_name_override,
|
||||
)
|
||||
|
||||
if stub is not None:
|
||||
return cls(options=resolved, stub=stub)
|
||||
|
||||
channel = create_channel(resolved)
|
||||
return cls(
|
||||
options=resolved,
|
||||
stub=pb_grpc.MxAccessGatewayStub(channel),
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> "GatewayClient":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_exc_info: object) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the owned gRPC channel."""
|
||||
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
if self._channel is not None:
|
||||
await self._channel.close()
|
||||
|
||||
async def open_session(
|
||||
self,
|
||||
request: pb.OpenSessionRequest | None = None,
|
||||
*,
|
||||
requested_backend: str = "",
|
||||
client_session_name: str = "",
|
||||
client_correlation_id: str = "",
|
||||
) -> "Session":
|
||||
"""Open a gateway session and return a high-level session wrapper."""
|
||||
|
||||
from .session import Session
|
||||
|
||||
raw_request = request or pb.OpenSessionRequest(
|
||||
requested_backend=requested_backend,
|
||||
client_session_name=client_session_name,
|
||||
client_correlation_id=client_correlation_id,
|
||||
)
|
||||
reply = await self.open_session_raw(raw_request)
|
||||
return Session(client=self, session_id=reply.session_id, open_reply=reply)
|
||||
|
||||
async def open_session_raw(self, request: pb.OpenSessionRequest) -> pb.OpenSessionReply:
|
||||
reply = await self._unary("open session", self.raw_stub.OpenSession, request)
|
||||
ensure_protocol_success("open session", reply.protocol_status, reply)
|
||||
return reply
|
||||
|
||||
async def close_session_raw(
|
||||
self,
|
||||
request: pb.CloseSessionRequest,
|
||||
) -> pb.CloseSessionReply:
|
||||
reply = await self._unary("close session", self.raw_stub.CloseSession, request)
|
||||
ensure_protocol_success("close session", reply.protocol_status, reply)
|
||||
return reply
|
||||
|
||||
async def invoke_raw(self, request: pb.MxCommandRequest) -> pb.MxCommandReply:
|
||||
reply = await self._unary("invoke", self.raw_stub.Invoke, request)
|
||||
ensure_protocol_success("invoke", reply.protocol_status, reply)
|
||||
return reply
|
||||
|
||||
def stream_events_raw(
|
||||
self,
|
||||
request: pb.StreamEventsRequest,
|
||||
*,
|
||||
metadata: Sequence[tuple[str, str]] | None = None,
|
||||
) -> AsyncIterator[pb.MxEvent]:
|
||||
"""Return an async event iterator and cancel the stream when iteration stops."""
|
||||
|
||||
call = self.raw_stub.StreamEvents(
|
||||
request,
|
||||
metadata=merge_metadata(self.options.api_key, metadata),
|
||||
)
|
||||
return _canceling_iterator(call)
|
||||
|
||||
async def _unary(
|
||||
self,
|
||||
operation: str,
|
||||
method: Any,
|
||||
request: Any,
|
||||
*,
|
||||
metadata: Sequence[tuple[str, str]] | None = None,
|
||||
) -> Any:
|
||||
call = method(
|
||||
request,
|
||||
metadata=merge_metadata(self.options.api_key, metadata),
|
||||
)
|
||||
try:
|
||||
return await call
|
||||
except asyncio.CancelledError:
|
||||
cancel = getattr(call, "cancel", None)
|
||||
if cancel is not None:
|
||||
cancel()
|
||||
raise
|
||||
except grpc.RpcError as error:
|
||||
raise map_rpc_error(operation, error) from error
|
||||
|
||||
|
||||
async def _canceling_iterator(call: Any) -> AsyncIterator[pb.MxEvent]:
|
||||
try:
|
||||
async for event in call:
|
||||
yield event
|
||||
except grpc.RpcError as error:
|
||||
raise map_rpc_error("stream events", error) from error
|
||||
finally:
|
||||
cancel = getattr(call, "cancel", None)
|
||||
if cancel is not None:
|
||||
cancel()
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Typed exception model for MXAccess Gateway Python clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import grpc
|
||||
|
||||
from .generated import mxaccess_gateway_pb2 as pb
|
||||
|
||||
|
||||
class MxGatewayError(Exception):
|
||||
"""Base class for client wrapper errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
protocol_status: pb.ProtocolStatus | None = None,
|
||||
raw_reply: Any | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.protocol_status = protocol_status
|
||||
self.raw_reply = raw_reply
|
||||
|
||||
|
||||
class MxGatewayTransportError(MxGatewayError):
|
||||
"""Transport-level gRPC failure."""
|
||||
|
||||
|
||||
class MxGatewayAuthenticationError(MxGatewayTransportError):
|
||||
"""Authentication failure reported by gRPC."""
|
||||
|
||||
|
||||
class MxGatewayAuthorizationError(MxGatewayTransportError):
|
||||
"""Authorization failure reported by gRPC."""
|
||||
|
||||
|
||||
class MxGatewaySessionError(MxGatewayError):
|
||||
"""Gateway session failure."""
|
||||
|
||||
|
||||
class MxGatewayWorkerError(MxGatewayError):
|
||||
"""Gateway worker process or protocol failure."""
|
||||
|
||||
|
||||
class MxGatewayCommandError(MxGatewayError):
|
||||
"""Command failure that preserves the raw protobuf reply."""
|
||||
|
||||
|
||||
class MxAccessError(MxGatewayCommandError):
|
||||
"""MXAccess HRESULT or status failure."""
|
||||
|
||||
|
||||
def map_rpc_error(operation: str, error: grpc.RpcError) -> MxGatewayTransportError:
|
||||
"""Map a generated gRPC exception to the client exception hierarchy."""
|
||||
|
||||
code = error.code() if hasattr(error, "code") else None
|
||||
details = error.details() if hasattr(error, "details") else str(error)
|
||||
message = f"{operation} failed: {details}"
|
||||
|
||||
if code == grpc.StatusCode.UNAUTHENTICATED:
|
||||
return MxGatewayAuthenticationError(message)
|
||||
if code == grpc.StatusCode.PERMISSION_DENIED:
|
||||
return MxGatewayAuthorizationError(message)
|
||||
|
||||
return MxGatewayTransportError(message)
|
||||
|
||||
|
||||
def ensure_protocol_success(
|
||||
operation: str,
|
||||
protocol_status: pb.ProtocolStatus | None,
|
||||
raw_reply: Any | None = None,
|
||||
) -> Any | None:
|
||||
"""Raise typed gateway errors for non-OK protocol statuses."""
|
||||
|
||||
code = (
|
||||
protocol_status.code
|
||||
if protocol_status is not None
|
||||
else pb.PROTOCOL_STATUS_CODE_UNSPECIFIED
|
||||
)
|
||||
|
||||
if code in (
|
||||
pb.PROTOCOL_STATUS_CODE_OK,
|
||||
pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE,
|
||||
):
|
||||
return raw_reply
|
||||
|
||||
message_text = protocol_status.message if protocol_status else ""
|
||||
message = f"{operation} failed: {message_text or pb.ProtocolStatusCode.Name(code)}"
|
||||
|
||||
if code in (
|
||||
pb.PROTOCOL_STATUS_CODE_SESSION_NOT_FOUND,
|
||||
pb.PROTOCOL_STATUS_CODE_SESSION_NOT_READY,
|
||||
):
|
||||
raise MxGatewaySessionError(
|
||||
message,
|
||||
protocol_status=protocol_status,
|
||||
raw_reply=raw_reply,
|
||||
)
|
||||
|
||||
if code in (
|
||||
pb.PROTOCOL_STATUS_CODE_WORKER_UNAVAILABLE,
|
||||
pb.PROTOCOL_STATUS_CODE_TIMEOUT,
|
||||
pb.PROTOCOL_STATUS_CODE_CANCELED,
|
||||
pb.PROTOCOL_STATUS_CODE_PROTOCOL_VIOLATION,
|
||||
):
|
||||
raise MxGatewayWorkerError(
|
||||
message,
|
||||
protocol_status=protocol_status,
|
||||
raw_reply=raw_reply,
|
||||
)
|
||||
|
||||
raise MxGatewayCommandError(
|
||||
message,
|
||||
protocol_status=protocol_status,
|
||||
raw_reply=raw_reply,
|
||||
)
|
||||
|
||||
|
||||
def ensure_mxaccess_success(operation: str, reply: pb.MxCommandReply) -> pb.MxCommandReply:
|
||||
"""Raise `MxAccessError` when MXAccess returned HRESULT or status failure."""
|
||||
|
||||
status = reply.protocol_status
|
||||
if status.code == pb.PROTOCOL_STATUS_CODE_MXACCESS_FAILURE:
|
||||
raise MxAccessError(
|
||||
_mxaccess_message(operation, reply),
|
||||
protocol_status=status,
|
||||
raw_reply=reply,
|
||||
)
|
||||
|
||||
if reply.HasField("hresult") and reply.hresult < 0:
|
||||
raise MxAccessError(
|
||||
_mxaccess_message(operation, reply),
|
||||
protocol_status=status,
|
||||
raw_reply=reply,
|
||||
)
|
||||
|
||||
for mx_status in reply.statuses:
|
||||
if mx_status.success == 0:
|
||||
raise MxAccessError(
|
||||
_mxaccess_message(operation, reply),
|
||||
protocol_status=status,
|
||||
raw_reply=reply,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
def _mxaccess_message(operation: str, reply: pb.MxCommandReply) -> str:
|
||||
status_text = reply.protocol_status.message or "MXAccess command failed"
|
||||
hresult = reply.hresult if reply.HasField("hresult") else None
|
||||
return (
|
||||
f"{operation} failed: {status_text}; "
|
||||
f"session={reply.session_id}; correlation={reply.correlation_id}; "
|
||||
f"hresult={hresult}; statuses={len(reply.statuses)}"
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Client connection options for the async Python wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import grpc
|
||||
|
||||
from .auth import REDACTED, ApiKey
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClientOptions:
|
||||
"""Connection settings for `GatewayClient.connect`."""
|
||||
|
||||
endpoint: str
|
||||
api_key: str | ApiKey | None = None
|
||||
plaintext: bool = False
|
||||
ca_file: str | None = None
|
||||
server_name_override: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.endpoint:
|
||||
raise ValueError("endpoint must not be empty")
|
||||
|
||||
if self.plaintext and self.ca_file:
|
||||
raise ValueError("ca_file cannot be used with plaintext connections")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
api_key = REDACTED if self.api_key else None
|
||||
return (
|
||||
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
|
||||
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
|
||||
f"ca_file={self.ca_file!r}, "
|
||||
f"server_name_override={self.server_name_override!r})"
|
||||
)
|
||||
|
||||
|
||||
def create_channel(options: ClientOptions) -> grpc.aio.Channel:
|
||||
"""Create a plaintext or TLS `grpc.aio` channel from client options."""
|
||||
|
||||
channel_options: list[tuple[str, str]] = []
|
||||
if options.server_name_override:
|
||||
channel_options.append(("grpc.ssl_target_name_override", options.server_name_override))
|
||||
|
||||
if options.plaintext:
|
||||
return grpc.aio.insecure_channel(options.endpoint, options=channel_options)
|
||||
|
||||
root_certificates = None
|
||||
if options.ca_file:
|
||||
root_certificates = Path(options.ca_file).read_bytes()
|
||||
|
||||
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
|
||||
return grpc.aio.secure_channel(
|
||||
options.endpoint,
|
||||
credentials,
|
||||
options=channel_options,
|
||||
)
|
||||
@@ -0,0 +1,209 @@
|
||||
"""Async session wrapper for MXAccess Gateway commands."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from .errors import ensure_mxaccess_success
|
||||
from .generated import mxaccess_gateway_pb2 as pb
|
||||
from .values import MxValueInput, to_mx_value
|
||||
|
||||
|
||||
class Session:
|
||||
"""A single gateway-backed MXAccess session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: "GatewayClient",
|
||||
session_id: str,
|
||||
open_reply: pb.OpenSessionReply | None = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.session_id = session_id
|
||||
self.open_reply = open_reply
|
||||
self._closed = False
|
||||
|
||||
async def __aenter__(self) -> "Session":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_exc_info: object) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self, *, client_correlation_id: str = "") -> pb.CloseSessionReply:
|
||||
"""Close the gateway session. Repeated calls return a local closed reply."""
|
||||
|
||||
if self._closed:
|
||||
return pb.CloseSessionReply(
|
||||
session_id=self.session_id,
|
||||
final_state=pb.SESSION_STATE_CLOSED,
|
||||
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
|
||||
)
|
||||
|
||||
self._closed = True
|
||||
return await self.client.close_session_raw(
|
||||
pb.CloseSessionRequest(
|
||||
session_id=self.session_id,
|
||||
client_correlation_id=client_correlation_id,
|
||||
),
|
||||
)
|
||||
|
||||
async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply:
|
||||
"""Invoke a raw command and enforce gateway and MXAccess success."""
|
||||
|
||||
reply = await self.invoke_raw(command, correlation_id=correlation_id)
|
||||
return ensure_mxaccess_success("invoke", reply)
|
||||
|
||||
async def invoke_raw(
|
||||
self,
|
||||
command: pb.MxCommand,
|
||||
*,
|
||||
correlation_id: str = "",
|
||||
) -> pb.MxCommandReply:
|
||||
"""Invoke a raw command and preserve the raw reply."""
|
||||
|
||||
return await self.client.invoke_raw(
|
||||
pb.MxCommandRequest(
|
||||
session_id=self.session_id,
|
||||
client_correlation_id=correlation_id,
|
||||
command=command,
|
||||
),
|
||||
)
|
||||
|
||||
async def register(self, client_name: str, *, correlation_id: str = "") -> int:
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_REGISTER,
|
||||
register=pb.RegisterCommand(client_name=client_name),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
return reply.register.server_handle
|
||||
|
||||
async def unregister(self, server_handle: int, *, correlation_id: str = "") -> None:
|
||||
await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_UNREGISTER,
|
||||
unregister=pb.UnregisterCommand(server_handle=server_handle),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
async def add_item(
|
||||
self,
|
||||
server_handle: int,
|
||||
item_definition: str,
|
||||
*,
|
||||
correlation_id: str = "",
|
||||
) -> int:
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADD_ITEM,
|
||||
add_item=pb.AddItemCommand(
|
||||
server_handle=server_handle,
|
||||
item_definition=item_definition,
|
||||
),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
return reply.add_item.item_handle
|
||||
|
||||
async def add_item2(
|
||||
self,
|
||||
server_handle: int,
|
||||
item_definition: str,
|
||||
item_context: str,
|
||||
*,
|
||||
correlation_id: str = "",
|
||||
) -> int:
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADD_ITEM2,
|
||||
add_item2=pb.AddItem2Command(
|
||||
server_handle=server_handle,
|
||||
item_definition=item_definition,
|
||||
item_context=item_context,
|
||||
),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
return reply.add_item2.item_handle
|
||||
|
||||
async def advise(
|
||||
self,
|
||||
server_handle: int,
|
||||
item_handle: int,
|
||||
*,
|
||||
correlation_id: str = "",
|
||||
) -> None:
|
||||
await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADVISE,
|
||||
advise=pb.AdviseCommand(
|
||||
server_handle=server_handle,
|
||||
item_handle=item_handle,
|
||||
),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
async def write(
|
||||
self,
|
||||
server_handle: int,
|
||||
item_handle: int,
|
||||
value: MxValueInput,
|
||||
*,
|
||||
user_id: int = 0,
|
||||
correlation_id: str = "",
|
||||
) -> None:
|
||||
await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_WRITE,
|
||||
write=pb.WriteCommand(
|
||||
server_handle=server_handle,
|
||||
item_handle=item_handle,
|
||||
value=to_mx_value(value),
|
||||
user_id=user_id,
|
||||
),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
async def write2(
|
||||
self,
|
||||
server_handle: int,
|
||||
item_handle: int,
|
||||
value: MxValueInput,
|
||||
timestamp_value: MxValueInput,
|
||||
*,
|
||||
user_id: int = 0,
|
||||
correlation_id: str = "",
|
||||
) -> None:
|
||||
await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_WRITE2,
|
||||
write2=pb.Write2Command(
|
||||
server_handle=server_handle,
|
||||
item_handle=item_handle,
|
||||
value=to_mx_value(value),
|
||||
timestamp_value=to_mx_value(timestamp_value),
|
||||
user_id=user_id,
|
||||
),
|
||||
),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
def stream_events(
|
||||
self,
|
||||
*,
|
||||
after_worker_sequence: int = 0,
|
||||
) -> AsyncIterator[pb.MxEvent]:
|
||||
return self.client.stream_events_raw(
|
||||
pb.StreamEventsRequest(
|
||||
session_id=self.session_id,
|
||||
after_worker_sequence=after_worker_sequence,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
from .client import GatewayClient # noqa: E402
|
||||
@@ -0,0 +1,234 @@
|
||||
"""MXAccess value conversion helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from google.protobuf.timestamp_pb2 import Timestamp
|
||||
|
||||
from .generated import mxaccess_gateway_pb2 as pb
|
||||
|
||||
|
||||
MxValueInput = bool | int | float | str | datetime | bytes | None | Sequence[Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MxValueView:
|
||||
"""Typed projection of a raw `MxValue` protobuf message."""
|
||||
|
||||
value: Any
|
||||
kind: str
|
||||
raw: pb.MxValue
|
||||
|
||||
|
||||
def to_mx_value(value: MxValueInput, *, data_type: str | None = None) -> pb.MxValue:
|
||||
"""Convert a Python value into the public protobuf `MxValue` union."""
|
||||
|
||||
if isinstance(value, pb.MxValue):
|
||||
return value
|
||||
|
||||
if value is None:
|
||||
return pb.MxValue(
|
||||
data_type=pb.MX_DATA_TYPE_NO_DATA,
|
||||
variant_type="VT_EMPTY",
|
||||
is_null=True,
|
||||
raw_data_type=pb.MX_DATA_TYPE_NO_DATA,
|
||||
)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_BOOLEAN),
|
||||
variant_type="VT_BOOL",
|
||||
bool_value=value,
|
||||
)
|
||||
|
||||
if isinstance(value, int):
|
||||
if -(2**31) <= value <= (2**31 - 1):
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_INTEGER),
|
||||
variant_type="VT_I4",
|
||||
int32_value=value,
|
||||
)
|
||||
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_INTEGER),
|
||||
variant_type="VT_I8",
|
||||
int64_value=value,
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_DOUBLE),
|
||||
variant_type="VT_R8",
|
||||
double_value=value,
|
||||
)
|
||||
|
||||
if isinstance(value, str):
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_STRING),
|
||||
variant_type="VT_BSTR",
|
||||
string_value=value,
|
||||
)
|
||||
|
||||
if isinstance(value, datetime):
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_TIME),
|
||||
variant_type="VT_DATE",
|
||||
timestamp_value=_timestamp_from_datetime(value),
|
||||
)
|
||||
|
||||
if isinstance(value, bytes):
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_UNKNOWN),
|
||||
variant_type="VT_RECORD",
|
||||
raw_value=value,
|
||||
)
|
||||
|
||||
if isinstance(value, Sequence):
|
||||
return _sequence_to_mx_value(value, data_type=data_type)
|
||||
|
||||
raise TypeError(f"unsupported MxValue input type: {type(value).__name__}")
|
||||
|
||||
|
||||
def from_mx_value(value: pb.MxValue) -> MxValueView:
|
||||
"""Project a protobuf `MxValue` into an idiomatic Python value."""
|
||||
|
||||
kind = value.WhichOneof("kind")
|
||||
if kind is None:
|
||||
return MxValueView(None, "none", value)
|
||||
|
||||
if kind == "timestamp_value":
|
||||
return MxValueView(
|
||||
value.timestamp_value.ToDatetime().replace(tzinfo=timezone.utc),
|
||||
kind,
|
||||
value,
|
||||
)
|
||||
|
||||
if kind == "array_value":
|
||||
return MxValueView(from_mx_array(value.array_value), kind, value)
|
||||
|
||||
return MxValueView(getattr(value, kind), kind, value)
|
||||
|
||||
|
||||
def from_mx_array(array: pb.MxArray) -> list[Any]:
|
||||
"""Project a protobuf `MxArray` into a Python list."""
|
||||
|
||||
kind = array.WhichOneof("values")
|
||||
if kind is None:
|
||||
return []
|
||||
|
||||
values = list(getattr(array, kind).values)
|
||||
if kind == "timestamp_values":
|
||||
return [
|
||||
timestamp.ToDatetime().replace(tzinfo=timezone.utc)
|
||||
for timestamp in values
|
||||
]
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def _sequence_to_mx_value(
|
||||
values: Sequence[Any],
|
||||
*,
|
||||
data_type: str | None,
|
||||
) -> pb.MxValue:
|
||||
sequence = list(values)
|
||||
if not sequence:
|
||||
return pb.MxValue(
|
||||
data_type=_data_type(data_type, pb.MX_DATA_TYPE_UNKNOWN),
|
||||
array_value=pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_UNKNOWN,
|
||||
dimensions=[0],
|
||||
),
|
||||
)
|
||||
|
||||
first = sequence[0]
|
||||
dimensions = [len(sequence)]
|
||||
|
||||
if all(isinstance(item, bool) for item in sequence):
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_BOOLEAN,
|
||||
variant_type="VT_ARRAY|VT_BOOL",
|
||||
dimensions=dimensions,
|
||||
bool_values=pb.BoolArray(values=sequence),
|
||||
)
|
||||
return pb.MxValue(data_type=pb.MX_DATA_TYPE_BOOLEAN, array_value=array)
|
||||
|
||||
if all(isinstance(item, int) and not isinstance(item, bool) for item in sequence):
|
||||
use_int32 = all(-(2**31) <= item <= (2**31 - 1) for item in sequence)
|
||||
if use_int32:
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_INTEGER,
|
||||
variant_type="VT_ARRAY|VT_I4",
|
||||
dimensions=dimensions,
|
||||
int32_values=pb.Int32Array(values=sequence),
|
||||
)
|
||||
else:
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_INTEGER,
|
||||
variant_type="VT_ARRAY|VT_I8",
|
||||
dimensions=dimensions,
|
||||
int64_values=pb.Int64Array(values=sequence),
|
||||
)
|
||||
|
||||
return pb.MxValue(data_type=pb.MX_DATA_TYPE_INTEGER, array_value=array)
|
||||
|
||||
if all(isinstance(item, float) for item in sequence):
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_DOUBLE,
|
||||
variant_type="VT_ARRAY|VT_R8",
|
||||
dimensions=dimensions,
|
||||
double_values=pb.DoubleArray(values=sequence),
|
||||
)
|
||||
return pb.MxValue(data_type=pb.MX_DATA_TYPE_DOUBLE, array_value=array)
|
||||
|
||||
if all(isinstance(item, str) for item in sequence):
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_STRING,
|
||||
variant_type="VT_ARRAY|VT_BSTR",
|
||||
dimensions=dimensions,
|
||||
string_values=pb.StringArray(values=sequence),
|
||||
)
|
||||
return pb.MxValue(data_type=pb.MX_DATA_TYPE_STRING, array_value=array)
|
||||
|
||||
if all(isinstance(item, datetime) for item in sequence):
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_TIME,
|
||||
variant_type="VT_ARRAY|VT_DATE",
|
||||
dimensions=dimensions,
|
||||
timestamp_values=pb.TimestampArray(
|
||||
values=[_timestamp_from_datetime(item) for item in sequence],
|
||||
),
|
||||
)
|
||||
return pb.MxValue(data_type=pb.MX_DATA_TYPE_TIME, array_value=array)
|
||||
|
||||
if all(isinstance(item, bytes) for item in sequence):
|
||||
array = pb.MxArray(
|
||||
element_data_type=pb.MX_DATA_TYPE_UNKNOWN,
|
||||
variant_type="VT_ARRAY|VT_VARIANT",
|
||||
dimensions=dimensions,
|
||||
raw_values=pb.RawArray(values=sequence),
|
||||
)
|
||||
return pb.MxValue(data_type=pb.MX_DATA_TYPE_UNKNOWN, array_value=array)
|
||||
|
||||
raise TypeError(
|
||||
"MxValue array inputs must use one supported element type; "
|
||||
f"first element was {type(first).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _timestamp_from_datetime(value: datetime) -> Timestamp:
|
||||
timestamp = Timestamp()
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
timestamp.FromDatetime(value.astimezone(timezone.utc))
|
||||
return timestamp
|
||||
|
||||
|
||||
def _data_type(name: str | None, default: int) -> int:
|
||||
if name is None:
|
||||
return default
|
||||
return pb.MxDataType.Value(name)
|
||||
@@ -1,10 +1,24 @@
|
||||
"""CLI scaffold for the MXAccess Gateway Python client."""
|
||||
"""Command line interface for the MXAccess Gateway Python client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
from mxgateway import __version__
|
||||
from mxgateway.auth import redact_secret
|
||||
from mxgateway.client import GatewayClient
|
||||
from mxgateway.errors import MxGatewayError
|
||||
from mxgateway.generated import mxaccess_gateway_pb2 as pb
|
||||
from mxgateway.options import ClientOptions
|
||||
from mxgateway.values import MxValueInput
|
||||
|
||||
|
||||
@click.group()
|
||||
@@ -16,14 +30,435 @@ def main() -> None:
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def version(output_json: bool) -> None:
|
||||
"""Print client package version information."""
|
||||
|
||||
payload = {
|
||||
"client": "mxgw-py",
|
||||
"package": "mxaccess-gateway-client",
|
||||
"version": __version__,
|
||||
}
|
||||
_emit(payload, output_json=output_json, text=f"mxgw-py {__version__}")
|
||||
|
||||
|
||||
def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]:
|
||||
command = click.option("--endpoint", default="localhost:5000", show_default=True)(command)
|
||||
command = click.option("--api-key", default=None, help="Gateway API key.")(command)
|
||||
command = click.option(
|
||||
"--api-key-env",
|
||||
default=None,
|
||||
help="Environment variable containing the gateway API key.",
|
||||
)(command)
|
||||
command = click.option("--plaintext", is_flag=True, help="Use plaintext gRPC.")(command)
|
||||
command = click.option("--tls", "use_tls", is_flag=True, help="Use TLS gRPC.")(command)
|
||||
command = click.option("--ca-file", default=None, help="Custom root certificate file.")(command)
|
||||
command = click.option(
|
||||
"--server-name-override",
|
||||
default=None,
|
||||
help="TLS server name override for test environments.",
|
||||
)(command)
|
||||
return command
|
||||
|
||||
|
||||
@main.command("open-session")
|
||||
@gateway_options
|
||||
@click.option("--client-name", default="", help="Client session name.")
|
||||
@click.option("--requested-backend", default="", help="Requested backend name.")
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def open_session(**kwargs: Any) -> None:
|
||||
"""Open a gateway session."""
|
||||
|
||||
_run(
|
||||
_open_session(**kwargs),
|
||||
output_json=kwargs["output_json"],
|
||||
secrets=_secrets(kwargs),
|
||||
)
|
||||
|
||||
|
||||
@main.command("close-session")
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def close_session(**kwargs: Any) -> None:
|
||||
"""Close a gateway session."""
|
||||
|
||||
_run(
|
||||
_close_session(**kwargs),
|
||||
output_json=kwargs["output_json"],
|
||||
secrets=_secrets(kwargs),
|
||||
)
|
||||
|
||||
|
||||
@main.command()
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--message", default="ping", show_default=True, help="Ping payload.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def ping(**kwargs: Any) -> None:
|
||||
"""Send a diagnostic ping command."""
|
||||
|
||||
_run(_ping(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
|
||||
|
||||
|
||||
@main.command()
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--client-name", required=True, help="MXAccess client name.")
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def register(**kwargs: Any) -> None:
|
||||
"""Invoke MXAccess Register."""
|
||||
|
||||
_run(
|
||||
_register(**kwargs),
|
||||
output_json=kwargs["output_json"],
|
||||
secrets=_secrets(kwargs),
|
||||
)
|
||||
|
||||
|
||||
@main.command("add-item")
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
|
||||
@click.option("--item", required=True, help="MXAccess item definition.")
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def add_item(**kwargs: Any) -> None:
|
||||
"""Invoke MXAccess AddItem."""
|
||||
|
||||
_run(
|
||||
_add_item(**kwargs),
|
||||
output_json=kwargs["output_json"],
|
||||
secrets=_secrets(kwargs),
|
||||
)
|
||||
|
||||
|
||||
@main.command()
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
|
||||
@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.")
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def advise(**kwargs: Any) -> None:
|
||||
"""Invoke MXAccess Advise."""
|
||||
|
||||
_run(_advise(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
|
||||
|
||||
|
||||
@main.command("stream-events")
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--after-worker-sequence", default=0, type=int, show_default=True)
|
||||
@click.option("--max-events", default=1, type=int, show_default=True)
|
||||
@click.option("--timeout", default=5.0, type=float, show_default=True)
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def stream_events(**kwargs: Any) -> None:
|
||||
"""Stream a bounded number of events."""
|
||||
|
||||
_run(
|
||||
_stream_events(**kwargs),
|
||||
output_json=kwargs["output_json"],
|
||||
secrets=_secrets(kwargs),
|
||||
)
|
||||
|
||||
|
||||
@main.command()
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
|
||||
@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.")
|
||||
@click.option("--type", "value_type", default="string", show_default=True)
|
||||
@click.option("--value", required=True, help="Value to write.")
|
||||
@click.option("--user-id", default=0, type=int, show_default=True)
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def write(**kwargs: Any) -> None:
|
||||
"""Invoke MXAccess Write."""
|
||||
|
||||
_run(_write(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
|
||||
|
||||
|
||||
@main.command()
|
||||
@gateway_options
|
||||
@click.option("--session-id", required=True, help="Gateway session id.")
|
||||
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
|
||||
@click.option("--item-handle", required=True, type=int, help="MXAccess item handle.")
|
||||
@click.option("--type", "value_type", default="string", show_default=True)
|
||||
@click.option("--value", required=True, help="Value to write.")
|
||||
@click.option("--timestamp", required=True, help="ISO-8601 timestamp value.")
|
||||
@click.option("--user-id", default=0, type=int, show_default=True)
|
||||
@click.option("--correlation-id", default="", help="Client correlation id.")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def write2(**kwargs: Any) -> None:
|
||||
"""Invoke MXAccess Write2."""
|
||||
|
||||
_run(_write2(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
|
||||
|
||||
|
||||
@main.command()
|
||||
@gateway_options
|
||||
@click.option("--client-name", default="mxgw-py-smoke", show_default=True)
|
||||
@click.option("--item", required=True, help="MXAccess item definition.")
|
||||
@click.option("--max-events", default=1, type=int, show_default=True)
|
||||
@click.option("--timeout", default=5.0, type=float, show_default=True)
|
||||
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
|
||||
def smoke(**kwargs: Any) -> None:
|
||||
"""Run a bounded open/register/add/advise/stream/close smoke flow."""
|
||||
|
||||
_run(_smoke(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
|
||||
|
||||
|
||||
async def _open_session(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
reply = await client.open_session_raw(
|
||||
pb.OpenSessionRequest(
|
||||
requested_backend=kwargs["requested_backend"],
|
||||
client_session_name=kwargs["client_name"],
|
||||
client_correlation_id=kwargs["correlation_id"],
|
||||
),
|
||||
)
|
||||
return {"sessionId": reply.session_id, "rawReply": _message_dict(reply)}
|
||||
|
||||
|
||||
async def _close_session(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
reply = await client.close_session_raw(
|
||||
pb.CloseSessionRequest(
|
||||
session_id=kwargs["session_id"],
|
||||
client_correlation_id=kwargs["correlation_id"],
|
||||
),
|
||||
)
|
||||
return {"sessionId": reply.session_id, "rawReply": _message_dict(reply)}
|
||||
|
||||
|
||||
async def _ping(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
reply = await client.invoke_raw(
|
||||
pb.MxCommandRequest(
|
||||
session_id=kwargs["session_id"],
|
||||
command=pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_PING,
|
||||
ping=pb.PingCommand(message=kwargs["message"]),
|
||||
),
|
||||
),
|
||||
)
|
||||
return {"kind": "ping", "rawReply": _message_dict(reply)}
|
||||
|
||||
|
||||
async def _register(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
session = _session(client, kwargs["session_id"])
|
||||
server_handle = await session.register(
|
||||
kwargs["client_name"],
|
||||
correlation_id=kwargs["correlation_id"],
|
||||
)
|
||||
return {"serverHandle": server_handle}
|
||||
|
||||
|
||||
async def _add_item(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
session = _session(client, kwargs["session_id"])
|
||||
item_handle = await session.add_item(
|
||||
kwargs["server_handle"],
|
||||
kwargs["item"],
|
||||
correlation_id=kwargs["correlation_id"],
|
||||
)
|
||||
return {"itemHandle": item_handle}
|
||||
|
||||
|
||||
async def _advise(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
session = _session(client, kwargs["session_id"])
|
||||
await session.advise(
|
||||
kwargs["server_handle"],
|
||||
kwargs["item_handle"],
|
||||
correlation_id=kwargs["correlation_id"],
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
async def _stream_events(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
session = _session(client, kwargs["session_id"])
|
||||
events = await _collect_events(
|
||||
session.stream_events(after_worker_sequence=kwargs["after_worker_sequence"]),
|
||||
max_events=kwargs["max_events"],
|
||||
timeout=kwargs["timeout"],
|
||||
)
|
||||
return {"events": [_message_dict(event) for event in events]}
|
||||
|
||||
|
||||
async def _write(**kwargs: Any) -> dict[str, Any]:
|
||||
value = _parse_value(kwargs["value"], kwargs["value_type"])
|
||||
async with await _connect(kwargs) as client:
|
||||
session = _session(client, kwargs["session_id"])
|
||||
await session.write(
|
||||
kwargs["server_handle"],
|
||||
kwargs["item_handle"],
|
||||
value,
|
||||
user_id=kwargs["user_id"],
|
||||
correlation_id=kwargs["correlation_id"],
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
async def _write2(**kwargs: Any) -> dict[str, Any]:
|
||||
value = _parse_value(kwargs["value"], kwargs["value_type"])
|
||||
timestamp = _parse_datetime(kwargs["timestamp"])
|
||||
async with await _connect(kwargs) as client:
|
||||
session = _session(client, kwargs["session_id"])
|
||||
await session.write2(
|
||||
kwargs["server_handle"],
|
||||
kwargs["item_handle"],
|
||||
value,
|
||||
timestamp,
|
||||
user_id=kwargs["user_id"],
|
||||
correlation_id=kwargs["correlation_id"],
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
async def _smoke(**kwargs: Any) -> dict[str, Any]:
|
||||
async with await _connect(kwargs) as client:
|
||||
session = await client.open_session(client_session_name=kwargs["client_name"])
|
||||
closed = False
|
||||
try:
|
||||
server_handle = await session.register(kwargs["client_name"])
|
||||
item_handle = await session.add_item(server_handle, kwargs["item"])
|
||||
await session.advise(server_handle, item_handle)
|
||||
events = await _collect_events(
|
||||
session.stream_events(),
|
||||
max_events=kwargs["max_events"],
|
||||
timeout=kwargs["timeout"],
|
||||
)
|
||||
return {
|
||||
"sessionId": session.session_id,
|
||||
"serverHandle": server_handle,
|
||||
"itemHandle": item_handle,
|
||||
"events": [_message_dict(event) for event in events],
|
||||
}
|
||||
finally:
|
||||
if not closed:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def _connect(kwargs: dict[str, Any]) -> GatewayClient:
|
||||
api_key = kwargs.get("api_key") or _api_key_from_env(kwargs.get("api_key_env"))
|
||||
return await GatewayClient.connect(
|
||||
ClientOptions(
|
||||
endpoint=kwargs["endpoint"],
|
||||
api_key=api_key,
|
||||
plaintext=_use_plaintext(kwargs),
|
||||
ca_file=kwargs.get("ca_file"),
|
||||
server_name_override=kwargs.get("server_name_override"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _session(client: GatewayClient, session_id: str):
|
||||
from mxgateway.session import Session
|
||||
|
||||
return Session(client=client, session_id=session_id)
|
||||
|
||||
|
||||
def _use_plaintext(kwargs: dict[str, Any]) -> bool:
|
||||
if kwargs.get("use_tls"):
|
||||
return False
|
||||
if kwargs.get("plaintext"):
|
||||
return True
|
||||
return kwargs["endpoint"].startswith("localhost:") or kwargs["endpoint"].startswith("127.0.0.1:")
|
||||
|
||||
|
||||
def _api_key_from_env(name: str | None) -> str | None:
|
||||
if not name:
|
||||
return None
|
||||
return os.environ.get(name)
|
||||
|
||||
|
||||
def _secrets(kwargs: dict[str, Any]) -> list[str | None]:
|
||||
return [
|
||||
kwargs.get("api_key"),
|
||||
_api_key_from_env(kwargs.get("api_key_env")),
|
||||
]
|
||||
|
||||
|
||||
def _run(
|
||||
awaitable: Awaitable[dict[str, Any]],
|
||||
*,
|
||||
output_json: bool,
|
||||
secrets: list[str | None],
|
||||
) -> None:
|
||||
try:
|
||||
payload = asyncio.run(awaitable)
|
||||
except MxGatewayError as error:
|
||||
raise click.ClickException(redact_secret(str(error), secrets)) from error
|
||||
|
||||
_emit(payload, output_json=output_json)
|
||||
|
||||
|
||||
def _emit(
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
output_json: bool,
|
||||
text: str | None = None,
|
||||
) -> None:
|
||||
if output_json:
|
||||
click.echo(json.dumps(payload, sort_keys=True))
|
||||
return
|
||||
|
||||
click.echo(f"mxgw-py {__version__}")
|
||||
click.echo(text or json.dumps(payload, sort_keys=True))
|
||||
|
||||
|
||||
async def _collect_events(
|
||||
events: Any,
|
||||
*,
|
||||
max_events: int,
|
||||
timeout: float,
|
||||
) -> list[pb.MxEvent]:
|
||||
collected: list[pb.MxEvent] = []
|
||||
iterator = events.__aiter__()
|
||||
try:
|
||||
while len(collected) < max_events:
|
||||
collected.append(await asyncio.wait_for(iterator.__anext__(), timeout=timeout))
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
close = getattr(iterator, "aclose", None)
|
||||
if close is not None:
|
||||
await close()
|
||||
return collected
|
||||
|
||||
|
||||
def _parse_value(raw_value: str, value_type: str) -> MxValueInput:
|
||||
normalized = value_type.lower()
|
||||
if normalized == "bool":
|
||||
return raw_value.lower() in ("1", "true", "yes", "on")
|
||||
if normalized in ("int", "int32", "int64"):
|
||||
return int(raw_value)
|
||||
if normalized in ("float", "double"):
|
||||
return float(raw_value)
|
||||
if normalized in ("time", "timestamp"):
|
||||
return _parse_datetime(raw_value)
|
||||
if normalized == "raw":
|
||||
return raw_value.encode("utf-8")
|
||||
if normalized == "string":
|
||||
return raw_value
|
||||
raise click.BadParameter(f"unsupported value type: {value_type}", param_hint="--type")
|
||||
|
||||
|
||||
def _parse_datetime(raw_value: str) -> datetime:
|
||||
if raw_value.endswith("Z"):
|
||||
raw_value = raw_value[:-1] + "+00:00"
|
||||
parsed = datetime.fromisoformat(raw_value)
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
|
||||
def _message_dict(message: Any) -> dict[str, Any]:
|
||||
return MessageToDict(
|
||||
message,
|
||||
preserving_proto_field_name=False,
|
||||
use_integers_for_enums=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user