142 lines
5.8 KiB
Python
142 lines
5.8 KiB
Python
"""Client connection options for the async Python wrapper."""
|
|
|
|
from __future__ import annotations

import socket
import ssl
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path

import grpc

from .auth import REDACTED, ApiKey
from .errors import MxGatewayTransportError
|
|
|
|
|
|
@dataclass(frozen=True)
class ClientOptions:
    """Connection settings for `GatewayClient.connect`.

    Instances are immutable and validated on construction; invalid values or
    combinations raise `ValueError`. The API key value never appears in
    ``repr`` output.
    """

    # gRPC target string handed to the grpc.aio channel factory (e.g. "host:port").
    endpoint: str
    # API key; __repr__ only reveals whether one is set, never its value.
    api_key: str | ApiKey | None = None
    # Use an insecure (non-TLS) channel; cannot be combined with ca_file.
    plaintext: bool = False
    # Path to a PEM file of trusted root certificates for TLS verification.
    ca_file: str | None = None
    # Verify the server against gRPC's default trust roots instead of the
    # lenient default.
    require_certificate_validation: bool = False
    # Value for the "grpc.ssl_target_name_override" channel option.
    server_name_override: str | None = None
    # Call timeout (presumably seconds; None presumably means no deadline —
    # confirm in GatewayClient). Must be > 0 when set.
    call_timeout: float | None = 30.0
    # Streaming-call timeout (presumably seconds; None presumably means no
    # deadline — confirm in GatewayClient). Must be > 0 when set.
    stream_timeout: float | None = None
    # Max send/receive gRPC message size in bytes (default 16 MiB).
    max_grpc_message_bytes: int = 16 * 1024 * 1024

    def __post_init__(self) -> None:
        """Validate options; raise `ValueError` for invalid combinations."""
        if not self.endpoint:
            raise ValueError("endpoint must not be empty")

        if self.plaintext and self.ca_file:
            raise ValueError("ca_file cannot be used with plaintext connections")
        # Written as ``not (x > 0)`` rather than ``x <= 0`` so that NaN is
        # rejected too: every comparison with NaN is False, so ``nan <= 0``
        # would silently let it through.
        if self.call_timeout is not None and not (self.call_timeout > 0):
            raise ValueError("call_timeout must be greater than zero")
        if self.stream_timeout is not None and not (self.stream_timeout > 0):
            raise ValueError("stream_timeout must be greater than zero")
        if not (self.max_grpc_message_bytes > 0):
            raise ValueError("max_grpc_message_bytes must be greater than zero")

    def __repr__(self) -> str:
        """Return a repr that redacts the API key value."""
        api_key = REDACTED if self.api_key else None
        return (
            f"{type(self).__name__}(endpoint={self.endpoint!r}, "
            f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
            f"ca_file={self.ca_file!r}, "
            f"require_certificate_validation={self.require_certificate_validation!r}, "
            f"server_name_override={self.server_name_override!r}, "
            f"call_timeout={self.call_timeout!r}, "
            f"stream_timeout={self.stream_timeout!r}, "
            f"max_grpc_message_bytes={self.max_grpc_message_bytes!r})"
        )
|
|
|
|
|
|
@dataclass(frozen=True)
class BrowseChildrenOptions:
    """Filters and shape options for ``GalaxyRepositoryClient.browse``.

    Mirrors the AND-combined filter set on ``BrowseChildrenRequest`` so a
    single instance can be re-used across an entire lazy browse session
    (the filter set is part of the page-token contract).

    The sequence filters are snapshotted into tuples on construction, so the
    instance really is immutable and hashable even when callers pass lists.
    Mutating the caller's list afterwards cannot silently change the filter
    set part-way through a session, and a one-shot iterator is not exhausted
    after the first page. ``None`` is accepted for either sequence and stored
    as an empty tuple.
    """

    # Category-id filter (presumably empty means "no category filter" — confirm server-side).
    category_ids: Sequence[int] = field(default_factory=tuple)
    # Template-chain filter entries (presumably empty means "no filter" — confirm server-side).
    template_chain_contains: Sequence[str] = field(default_factory=tuple)
    # Glob over tag names; None leaves this filter unset.
    tag_name_glob: str | None = None
    # Tri-state: None leaves the choice unset (presumably server default — confirm).
    include_attributes: bool | None = None
    alarm_bearing_only: bool = False
    historized_only: bool = False

    def __post_init__(self) -> None:
        """Freeze the sequence filters into tuples (``None`` becomes ``()``)."""
        for name in ("category_ids", "template_chain_contains"):
            value = getattr(self, name)
            # frozen=True blocks normal assignment; object.__setattr__ is the
            # documented way to set a field from __post_init__ in that case.
            object.__setattr__(self, name, () if value is None else tuple(value))
|
|
|
|
|
|
def _split_authority(endpoint: str) -> tuple[str, int]:
|
|
"""Split a gRPC target (optionally scheme-prefixed) into (host, port).
|
|
|
|
Handles bracketed IPv6 literals (e.g. ``[::1]:5120`` or bare ``[::1]``),
|
|
returning the host without brackets so it is safe to pass to
|
|
``ssl.get_server_certificate``.
|
|
"""
|
|
target = endpoint.split("://", 1)[-1]
|
|
if target.startswith("["):
|
|
# Bracketed IPv6: "[::1]:5120" or "[::1]"
|
|
bracket_end = target.find("]")
|
|
host = target[1:bracket_end] # strip surrounding brackets
|
|
remainder = target[bracket_end + 1 :] # ":5120" or ""
|
|
port_str = remainder.lstrip(":")
|
|
return (host, int(port_str) if port_str else 443)
|
|
host, _, port = target.rpartition(":")
|
|
return (host or "localhost", int(port) if port else 443)
|
|
|
|
|
|
def _fetch_presented_certificate(host: str, port: int, timeout: float | None) -> str:
    """Return, as PEM, the certificate ``host:port`` presents, without verifying it.

    Equivalent to ``ssl.get_server_certificate((host, port))``, but applies a
    connect/handshake timeout on every Python version (the stdlib helper only
    gained a ``timeout`` argument in 3.10). ``timeout=None`` falls back to the
    process-wide socket default, which is what the stdlib helper uses when no
    timeout is given.

    Raises:
        OSError: on connection failure, timeout, or TLS handshake failure
            (``socket.timeout`` and ``ssl.SSLError`` are both ``OSError``
            subclasses), or when the server presents no certificate.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # Deliberately unverified: this probe exists to learn the certificate so
    # it can be pinned. check_hostname must be cleared before verify_mode.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    if ssl.HAS_ALPN:
        # Offer HTTP/2 like a real gRPC client so servers that enforce ALPN
        # complete the handshake (and present their certificate).
        context.set_alpn_protocols(["h2"])
    effective_timeout = timeout if timeout is not None else socket.getdefaulttimeout()
    with socket.create_connection((host, port), timeout=effective_timeout) as raw_sock:
        with context.wrap_socket(raw_sock, server_hostname=host) as tls_sock:
            der = tls_sock.getpeercert(binary_form=True)
    if not der:
        raise OSError(f"{host}:{port} presented no TLS certificate")
    return ssl.DER_cert_to_PEM_cert(der)


def create_channel(options: ClientOptions) -> grpc.aio.Channel:
    """Create a plaintext or TLS `grpc.aio` channel from client options.

    Channel selection, in priority order:

    * ``plaintext`` — insecure channel, no TLS.
    * ``ca_file`` — TLS verified against the PEM roots read from that file.
    * ``require_certificate_validation`` — TLS verified against gRPC's default
      trust roots.
    * otherwise, the lenient default. grpc-python has no per-channel
      skip-verify, so the certificate the server presents is fetched once,
      unverified, and pinned as the channel's only trust root
      (trust-on-first-use, TOFU). Unless ``server_name_override`` is set, the
      TLS target name is overridden to ``"localhost"``.

    The TOFU probe is synchronous socket I/O on the calling thread. It is
    bounded by ``options.call_timeout``; ``None`` falls back to the
    process-wide socket default timeout, which means no bound unless one has
    been set. Calling this from a coroutine therefore blocks the event loop
    for up to that long.

    NOTE(review): only the leaf certificate is pinned. This presumably
    verifies only when that certificate is self-signed (as the gateway's
    default certificate is described to be below); a CA-issued leaf is
    unlikely to verify as its own root — confirm.

    Args:
        options: Validated connection settings.

    Returns:
        The new ``grpc.aio`` channel.

    Raises:
        MxGatewayTransportError: If the TOFU certificate probe fails
            (connection error, timeout, TLS handshake failure, or no
            certificate presented).
        ValueError: If the endpoint cannot be split into host and port for
            the TOFU probe.
        OSError: If ``ca_file`` cannot be read (propagated unwrapped).
    """
    channel_options: list[tuple[str, str | int]] = [
        ("grpc.max_receive_message_length", options.max_grpc_message_bytes),
        ("grpc.max_send_message_length", options.max_grpc_message_bytes),
    ]
    if options.server_name_override:
        channel_options.append(("grpc.ssl_target_name_override", options.server_name_override))

    if options.plaintext:
        return grpc.aio.insecure_channel(options.endpoint, options=channel_options)

    if options.ca_file:
        root_certificates = Path(options.ca_file).read_bytes()
        credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
    elif options.require_certificate_validation:
        credentials = grpc.ssl_channel_credentials()
    else:
        # Lenient default: grpc-python has no per-channel skip-verify, so fetch
        # the server's certificate (unverified) and pin it for this channel.
        host, port = _split_authority(options.endpoint)
        try:
            presented = _fetch_presented_certificate(host, port, options.call_timeout)
        except OSError as error:
            raise MxGatewayTransportError(
                f"failed to fetch TLS certificate from {options.endpoint}: {error}"
            ) from error
        credentials = grpc.ssl_channel_credentials(root_certificates=presented.encode("ascii"))
        # The gateway self-signed cert always carries a "localhost" SAN, so default
        # the SNI/target-name override to it when none was supplied, tolerating
        # dial-by-IP or hostname mismatch.
        if not options.server_name_override:
            channel_options.append(("grpc.ssl_target_name_override", "localhost"))

    return grpc.aio.secure_channel(
        options.endpoint,
        credentials,
        options=channel_options,
    )
|