Improve gateway reliability and dashboard docs
This commit is contained in:
@@ -74,9 +74,9 @@ class GatewayClient:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
if self._channel is not None:
|
||||
await self._channel.close()
|
||||
self._closed = True
|
||||
|
||||
async def open_session(
|
||||
self,
|
||||
@@ -124,10 +124,10 @@ class GatewayClient:
|
||||
) -> AsyncIterator[pb.MxEvent]:
|
||||
"""Return an async event iterator and cancel the stream when iteration stops."""
|
||||
|
||||
call = self.raw_stub.StreamEvents(
|
||||
request,
|
||||
metadata=merge_metadata(self.options.api_key, metadata),
|
||||
)
|
||||
kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)}
|
||||
if self.options.stream_timeout is not None:
|
||||
kwargs["timeout"] = self.options.stream_timeout
|
||||
call = self.raw_stub.StreamEvents(request, **kwargs)
|
||||
return _canceling_iterator(call)
|
||||
|
||||
async def _unary(
|
||||
@@ -138,10 +138,16 @@ class GatewayClient:
|
||||
*,
|
||||
metadata: Sequence[tuple[str, str]] | None = None,
|
||||
) -> Any:
|
||||
call = method(
|
||||
request,
|
||||
metadata=merge_metadata(self.options.api_key, metadata),
|
||||
)
|
||||
kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)}
|
||||
if self.options.call_timeout is not None:
|
||||
kwargs["timeout"] = self.options.call_timeout
|
||||
try:
|
||||
call = method(request, **kwargs)
|
||||
except TypeError as error:
|
||||
if "timeout" not in kwargs or "unexpected keyword argument 'timeout'" not in str(error):
|
||||
raise
|
||||
kwargs.pop("timeout")
|
||||
call = method(request, **kwargs)
|
||||
try:
|
||||
return await call
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -19,6 +19,8 @@ class ClientOptions:
|
||||
plaintext: bool = False
|
||||
ca_file: str | None = None
|
||||
server_name_override: str | None = None
|
||||
call_timeout: float | None = 30.0
|
||||
stream_timeout: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.endpoint:
|
||||
@@ -26,6 +28,10 @@ class ClientOptions:
|
||||
|
||||
if self.plaintext and self.ca_file:
|
||||
raise ValueError("ca_file cannot be used with plaintext connections")
|
||||
if self.call_timeout is not None and self.call_timeout <= 0:
|
||||
raise ValueError("call_timeout must be greater than zero")
|
||||
if self.stream_timeout is not None and self.stream_timeout <= 0:
|
||||
raise ValueError("stream_timeout must be greater than zero")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
api_key = REDACTED if self.api_key else None
|
||||
@@ -33,7 +39,9 @@ class ClientOptions:
|
||||
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
|
||||
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
|
||||
f"ca_file={self.ca_file!r}, "
|
||||
f"server_name_override={self.server_name_override!r})"
|
||||
f"server_name_override={self.server_name_override!r}, "
|
||||
f"call_timeout={self.call_timeout!r}, "
|
||||
f"stream_timeout={self.stream_timeout!r})"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from .errors import ensure_mxaccess_success
|
||||
from .generated import mxaccess_gateway_pb2 as pb
|
||||
from .values import MxValueInput, to_mx_value
|
||||
|
||||
MAX_BULK_ITEMS = 1000
|
||||
|
||||
|
||||
class Session:
|
||||
"""A single gateway-backed MXAccess session."""
|
||||
@@ -40,13 +42,14 @@ class Session:
|
||||
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
|
||||
)
|
||||
|
||||
self._closed = True
|
||||
return await self.client.close_session_raw(
|
||||
reply = await self.client.close_session_raw(
|
||||
pb.CloseSessionRequest(
|
||||
session_id=self.session_id,
|
||||
client_correlation_id=client_correlation_id,
|
||||
),
|
||||
)
|
||||
self._closed = True
|
||||
return reply
|
||||
|
||||
async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply:
|
||||
"""Invoke a raw command and enforce gateway and MXAccess success."""
|
||||
@@ -192,6 +195,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if tag_addresses is None:
|
||||
raise TypeError("tag_addresses is required")
|
||||
_ensure_bulk_size("tag_addresses", len(tag_addresses))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADD_ITEM_BULK,
|
||||
@@ -213,6 +217,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADVISE_ITEM_BULK,
|
||||
@@ -234,6 +239,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_REMOVE_ITEM_BULK,
|
||||
@@ -255,6 +261,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_UN_ADVISE_ITEM_BULK,
|
||||
@@ -276,6 +283,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if tag_addresses is None:
|
||||
raise TypeError("tag_addresses is required")
|
||||
_ensure_bulk_size("tag_addresses", len(tag_addresses))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_SUBSCRIBE_BULK,
|
||||
@@ -297,6 +305,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_UNSUBSCRIBE_BULK,
|
||||
@@ -368,4 +377,9 @@ class Session:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_bulk_size(name: str, count: int) -> None:
|
||||
if count > MAX_BULK_ITEMS:
|
||||
raise ValueError(f"{name} bulk commands are limited to {MAX_BULK_ITEMS} item(s)")
|
||||
|
||||
|
||||
from .client import GatewayClient # noqa: E402
|
||||
|
||||
@@ -20,6 +20,8 @@ from mxgateway.generated import mxaccess_gateway_pb2 as pb
|
||||
from mxgateway.options import ClientOptions
|
||||
from mxgateway.values import MxValueInput
|
||||
|
||||
MAX_AGGREGATE_EVENTS = 10_000
|
||||
|
||||
|
||||
@click.group()
|
||||
def main() -> None:
|
||||
@@ -55,6 +57,8 @@ def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]:
|
||||
default=None,
|
||||
help="TLS server name override for test environments.",
|
||||
)(command)
|
||||
command = click.option("--call-timeout", default=30.0, type=float, show_default=True)(command)
|
||||
command = click.option("--stream-timeout", default=None, type=float)(command)
|
||||
return command
|
||||
|
||||
|
||||
@@ -352,6 +356,8 @@ async def _connect(kwargs: dict[str, Any]) -> GatewayClient:
|
||||
plaintext=_use_plaintext(kwargs),
|
||||
ca_file=kwargs.get("ca_file"),
|
||||
server_name_override=kwargs.get("server_name_override"),
|
||||
call_timeout=kwargs.get("call_timeout"),
|
||||
stream_timeout=kwargs.get("stream_timeout"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -416,6 +422,12 @@ async def _collect_events(
|
||||
max_events: int,
|
||||
timeout: float,
|
||||
) -> list[pb.MxEvent]:
|
||||
if max_events > MAX_AGGREGATE_EVENTS:
|
||||
raise click.BadParameter(
|
||||
f"must be less than or equal to {MAX_AGGREGATE_EVENTS}",
|
||||
param_hint="--max-events",
|
||||
)
|
||||
|
||||
collected: list[pb.MxEvent] = []
|
||||
iterator = events.__aiter__()
|
||||
try:
|
||||
@@ -423,6 +435,8 @@ async def _collect_events(
|
||||
collected.append(await asyncio.wait_for(iterator.__anext__(), timeout=timeout))
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
close = getattr(iterator, "aclose", None)
|
||||
if close is not None:
|
||||
|
||||
Reference in New Issue
Block a user