"""Client connection options for the async Python wrapper."""

from __future__ import annotations

import socket
import ssl
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path

import grpc

from .auth import REDACTED, ApiKey
from .errors import MxGatewayTransportError

# Fallback bound for the TOFU certificate probe when no call_timeout is set, so a
# black-holed host fails fast instead of hanging on the OS default connect timeout.
_TOFU_PROBE_TIMEOUT_SECONDS = 10.0

# Port assumed when a target does not name one explicitly (or names a malformed one).
_DEFAULT_TLS_PORT = 443

# Target name used on the TOFU path when no server_name_override is supplied: the
# gateway's self-signed certificate always carries a "localhost" SAN.
_DEFAULT_TOFU_TARGET_NAME = "localhost"

# gRPC name-resolver schemes whose payload is the address itself, e.g.
# "dns:[//authority/]host[:port]" or "ipv4:addr[:port][,addr[:port]...]".
_GRPC_ADDRESS_SCHEMES = frozenset({"dns", "ipv4", "ipv6"})


@dataclass(frozen=True)
class ClientOptions:
    """Connection settings for `GatewayClient.connect`.

    Instances are immutable; invalid combinations are rejected at construction
    time by `__post_init__` with `ValueError`.
    """

    # gRPC target passed to the channel, e.g. "gateway:5120" or "dns:///gateway:5120".
    endpoint: str
    # Client credential; never shown by `repr` (replaced with REDACTED).
    api_key: str | ApiKey | None = None
    # Use an insecure (non-TLS) channel.
    plaintext: bool = False
    # Path to a PEM CA bundle the server certificate is verified against.
    ca_file: str | None = None
    # Verify against the system trust store instead of trust-on-first-use.
    require_certificate_validation: bool = False
    # Sent to gRPC as "grpc.ssl_target_name_override"; also used as the TOFU probe's SNI.
    server_name_override: str | None = None
    # Call timeout in seconds, or None; create_channel also uses it to bound the TOFU probe.
    call_timeout: float | None = 30.0
    # Streaming-call timeout in seconds, or None (consumed outside this module).
    stream_timeout: float | None = None
    # Max gRPC send and receive message size, in bytes.
    max_grpc_message_bytes: int = 16 * 1024 * 1024

    def __post_init__(self) -> None:
        """Validate options; raise `ValueError` for invalid combinations."""
        if not self.endpoint:
            raise ValueError("endpoint must not be empty")
        if self.plaintext and self.ca_file:
            raise ValueError("ca_file cannot be used with plaintext connections")
        if self.call_timeout is not None and self.call_timeout <= 0:
            raise ValueError("call_timeout must be greater than zero")
        if self.stream_timeout is not None and self.stream_timeout <= 0:
            raise ValueError("stream_timeout must be greater than zero")
        if self.max_grpc_message_bytes <= 0:
            raise ValueError("max_grpc_message_bytes must be greater than zero")

    def __repr__(self) -> str:
        """Return a repr that redacts the API key value."""
        api_key = REDACTED if self.api_key else None
        return (
            f"{type(self).__name__}(endpoint={self.endpoint!r}, "
            f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
            f"ca_file={self.ca_file!r}, "
            f"require_certificate_validation={self.require_certificate_validation!r}, "
            f"server_name_override={self.server_name_override!r}, "
            f"call_timeout={self.call_timeout!r}, "
            f"stream_timeout={self.stream_timeout!r}, "
            f"max_grpc_message_bytes={self.max_grpc_message_bytes!r})"
        )


@dataclass(frozen=True)
class BrowseChildrenOptions:
    """Filters and shape options for ``GalaxyRepositoryClient.browse``.

    Mirrors the AND-combined filter set on ``BrowseChildrenRequest`` so a single
    instance can be re-used across an entire lazy browse session (the filter
    set is part of the page-token contract).
    """

    category_ids: Sequence[int] = field(default_factory=tuple)
    template_chain_contains: Sequence[str] = field(default_factory=tuple)
    tag_name_glob: str | None = None
    include_attributes: bool | None = None
    alarm_bearing_only: bool = False
    historized_only: bool = False


def _strip_target_scheme(endpoint: str) -> str:
    """Return the ``host[:port]`` portion of a gRPC target, without any scheme.

    Handles ``scheme://[authority]/host[:port]`` (e.g. ``dns:///gw:5120``),
    ``scheme://host[:port]``, and bare ``dns:`` / ``ipv4:`` / ``ipv6:``
    prefixes. For the address schemes only the first entry of a
    comma-separated address list is kept. Plain targets are returned as-is.
    """
    scheme, sep, rest = endpoint.partition(":")
    is_address_scheme = bool(sep) and scheme.lower() in _GRPC_ADDRESS_SCHEMES

    if sep and rest.startswith("//"):
        # URI form: the dial target is the path after the (possibly empty)
        # authority; with no path at all, the authority itself is the target.
        authority, slash, path = rest[2:].partition("/")
        target = path if slash and path else authority
    elif is_address_scheme:
        target = rest
    else:
        target = endpoint

    if is_address_scheme:
        # "ipv4:a:1,b:2" lists several addresses; probe the first one.
        target = target.split(",", 1)[0]
    return target


def _parse_port(text: str) -> int | None:
    """Return *text* as a port number, or ``None`` unless it is ASCII digits.

    ``str.isdigit`` alone is not enough: it accepts characters such as ``"²"``
    that ``int()`` rejects.
    """
    if text and text.isascii() and text.isdigit():
        return int(text)
    return None


def _split_authority(endpoint: str) -> tuple[str, int]:
    """Split a gRPC target (optionally scheme-prefixed) into (host, port).

    Accepts plain ``host[:port]`` targets and gRPC URI forms such as
    ``dns:///host:5120``, ``dns://resolver/host:5120`` or
    ``ipv4:10.0.0.1:5120``. It also accepts bracketed IPv6 literals
    (``[::1]:5120`` or bare ``[::1]``), returning the host without brackets
    so it is safe to pass to ``socket``/``ssl`` APIs.

    A missing or malformed port falls back to 443, and an empty host to
    ``"localhost"``. The function never raises, so a typo in the endpoint
    surfaces as a transport error from the probe rather than a ``ValueError``.
    """
    target = _strip_target_scheme(endpoint)

    if target.startswith("["):
        # Bracketed IPv6: "[::1]:5120" or "[::1]".
        bracket_end = target.find("]")
        if bracket_end == -1:
            # Unterminated bracket: no trustworthy port; let the probe fail on the host.
            return (target[1:] or "localhost", _DEFAULT_TLS_PORT)
        host = target[1:bracket_end]  # strip surrounding brackets
        port = _parse_port(target[bracket_end + 1 :].lstrip(":"))  # ":5120" or ""
        return (host or "localhost", _DEFAULT_TLS_PORT if port is None else port)

    host, sep, port_text = target.rpartition(":")
    if not sep:
        # No colon at all (e.g. a bare hostname "mygateway"): the whole target
        # is the host; default the port rather than raising on int("mygateway").
        return (target or "localhost", _DEFAULT_TLS_PORT)
    port = _parse_port(port_text)
    if port is None:
        # A colon with a non-numeric / empty tail (e.g. a trailing ":") is not
        # an explicit port: keep the left side as the host and default the port.
        return (host or "localhost", _DEFAULT_TLS_PORT)
    return (host or "localhost", port)


def _fetch_server_certificate(options: ClientOptions, server_name: str) -> str:
    """Fetch the endpoint's TLS certificate, unverified, as a PEM string.

    The handshake presents *server_name* as SNI. This is the same target name
    the gRPC channel will use, so a server that selects certificates by SNI
    returns the certificate the channel will see. The connect and handshake
    are bounded by ``call_timeout``, or by ``_TOFU_PROBE_TIMEOUT_SECONDS``
    when that is None.

    Raises:
        MxGatewayTransportError: the connection, DNS lookup or TLS handshake
            failed or timed out, or the server presented no certificate.
    """
    host, port = _split_authority(options.endpoint)
    timeout = (
        options.call_timeout if options.call_timeout is not None else _TOFU_PROBE_TIMEOUT_SECONDS
    )
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # Deliberately unverified: this is the trust-on-first-use probe. check_hostname
    # must be cleared before verify_mode can be relaxed to CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    try:
        with socket.create_connection((host, port), timeout=timeout) as raw_sock:
            with context.wrap_socket(raw_sock, server_hostname=server_name) as tls_sock:
                der_certificate = tls_sock.getpeercert(binary_form=True)
    except (OSError, UnicodeError) as error:
        # OSError covers refused/timeout/DNS/TLS failures; UnicodeError comes from
        # IDNA-encoding a malformed host name.
        raise MxGatewayTransportError(
            f"failed to fetch TLS certificate from {options.endpoint}: {error}"
        ) from error
    if not der_certificate:
        raise MxGatewayTransportError(
            f"failed to fetch TLS certificate from {options.endpoint}: "
            "server presented no certificate"
        )
    return ssl.DER_cert_to_PEM_cert(der_certificate)


def create_channel(options: ClientOptions) -> grpc.aio.Channel:
    """Create a plaintext or TLS `grpc.aio` channel from client options.

    The TLS default is lenient. grpc-python has no per-channel skip-verify, so
    the server's presented certificate is fetched once (unverified) and pinned
    as the channel's only trust root (trust-on-first-use). Set
    `require_certificate_validation=True` to force system-trust verification,
    or pass `ca_file` to verify against a specific CA. Both bypass the TOFU
    path.

    Raises:
        MxGatewayTransportError: the TOFU certificate probe failed.
        OSError: ``ca_file`` could not be read.
    """
    channel_options: list[tuple[str, str | int]] = [
        ("grpc.max_receive_message_length", options.max_grpc_message_bytes),
        ("grpc.max_send_message_length", options.max_grpc_message_bytes),
    ]
    if options.server_name_override:
        channel_options.append(("grpc.ssl_target_name_override", options.server_name_override))

    if options.plaintext:
        return grpc.aio.insecure_channel(options.endpoint, options=channel_options)

    if options.ca_file:
        root_certificates = Path(options.ca_file).read_bytes()
        credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
    elif options.require_certificate_validation:
        credentials = grpc.ssl_channel_credentials()
    else:
        # Lenient default: grpc-python has no per-channel skip-verify, so fetch the
        # server's certificate (unverified) and pin it for this channel (TOFU).
        # The probe opens a real blocking TCP+TLS socket, so it is bounded by
        # call_timeout (or a short fixed fallback). A black-holed host then fails
        # fast as a transport error instead of hanging on the OS connect timeout.
        # The async `connect` classmethods run this off the event loop (asyncio.to_thread).
        # NOTE(review): pinning a CA-issued leaf (rather than a self-signed cert) as
        # the sole root may not verify under gRPC's TLS stack. Confirm whether
        # non-self-signed servers must be supported on this path.
        target_name = options.server_name_override or _DEFAULT_TOFU_TARGET_NAME
        presented = _fetch_server_certificate(options, target_name)
        credentials = grpc.ssl_channel_credentials(root_certificates=presented.encode("ascii"))
        # Default the target-name override to the self-signed cert's "localhost"
        # SAN when none was supplied, tolerating dial-by-IP or hostname mismatch.
        if not options.server_name_override:
            channel_options.append(("grpc.ssl_target_name_override", _DEFAULT_TOFU_TARGET_NAME))

    return grpc.aio.secure_channel(
        options.endpoint,
        credentials,
        options=channel_options,
    )