fix(client/python): reachable cert-validation flag; bounded off-loop TOFU probe; license/marker fixes (Client.Python-027..031)

This commit is contained in:
Joseph Doherty
2026-06-15 02:39:11 -04:00
parent d0d1dcef15
commit 47062c1a6e
11 changed files with 550 additions and 13 deletions
@@ -40,6 +40,7 @@ class GatewayClient:
api_key: str | None = None,
plaintext: bool = False,
ca_file: str | None = None,
require_certificate_validation: bool = False,
server_name_override: str | None = None,
stub: Any | None = None,
) -> "GatewayClient":
@@ -50,13 +51,16 @@ class GatewayClient:
api_key=api_key,
plaintext=plaintext,
ca_file=ca_file,
require_certificate_validation=require_certificate_validation,
server_name_override=server_name_override,
)
if stub is not None:
return cls(options=resolved, stub=stub)
channel = create_channel(resolved)
# create_channel may perform a blocking TLS certificate probe (TOFU
# default); run it off the event loop so connect never freezes the loop.
channel = await asyncio.to_thread(create_channel, resolved)
return cls(
options=resolved,
stub=pb_grpc.MxAccessGatewayStub(channel),
@@ -52,6 +52,7 @@ class GalaxyRepositoryClient:
api_key: str | None = None,
plaintext: bool = False,
ca_file: str | None = None,
require_certificate_validation: bool = False,
server_name_override: str | None = None,
stub: Any | None = None,
) -> "GalaxyRepositoryClient":
@@ -62,13 +63,16 @@ class GalaxyRepositoryClient:
api_key=api_key,
plaintext=plaintext,
ca_file=ca_file,
require_certificate_validation=require_certificate_validation,
server_name_override=server_name_override,
)
if stub is not None:
return cls(options=resolved, stub=stub)
channel = create_channel(resolved)
# create_channel may perform a blocking TLS certificate probe (TOFU
# default); run it off the event loop so connect never freezes the loop.
channel = await asyncio.to_thread(create_channel, resolved)
return cls(
options=resolved,
stub=galaxy_pb_grpc.GalaxyRepositoryStub(channel),
@@ -12,6 +12,10 @@ import grpc
from .auth import REDACTED, ApiKey
from .errors import MxGatewayTransportError
# Fallback bound for the TOFU certificate probe when no call_timeout is set, so a
# black-holed host fails fast instead of hanging on the OS default connect timeout.
_TOFU_PROBE_TIMEOUT_SECONDS = 10.0
@dataclass(frozen=True)
class ClientOptions:
@@ -88,8 +92,17 @@ def _split_authority(endpoint: str) -> tuple[str, int]:
remainder = target[bracket_end + 1 :] # ":5120" or ""
port_str = remainder.lstrip(":")
return (host, int(port_str) if port_str else 443)
host, _, port = target.rpartition(":")
return (host or "localhost", int(port) if port else 443)
host, sep, port = target.rpartition(":")
if not sep:
# No colon at all (e.g. a bare hostname "mygateway"): the whole target
# is the host; default the port rather than raising on int("mygateway").
return (target or "localhost", 443)
if not port.isdigit():
# A colon with a non-numeric / empty tail (e.g. a trailing ":") is not
# an explicit port — keep the left side as the host and default the
# port so a typo cannot raise an uncaught ValueError on the TOFU path.
return (host or "localhost", 443)
return (host or "localhost", int(port))
def create_channel(options: ClientOptions) -> grpc.aio.Channel:
@@ -120,9 +133,15 @@ def create_channel(options: ClientOptions) -> grpc.aio.Channel:
else:
# Lenient default: grpc-python has no per-channel skip-verify, so fetch the
# server's certificate (unverified) and pin it for this channel (TOFU).
# The probe opens a real blocking TCP+TLS socket, so it MUST be bounded —
# a black-holed / firewall-drop host would otherwise hang on the OS default
# connect timeout (minutes). Bound it by call_timeout (or a short fixed
# fallback) so the dial fails fast as a transport error. The async
# `connect` classmethods run this off the event loop (asyncio.to_thread).
host, port = _split_authority(options.endpoint)
probe_timeout = options.call_timeout if options.call_timeout else _TOFU_PROBE_TIMEOUT_SECONDS
try:
presented = ssl.get_server_certificate((host, port))
presented = ssl.get_server_certificate((host, port), timeout=probe_timeout)
except OSError as error:
raise MxGatewayTransportError(
f"failed to fetch TLS certificate from {options.endpoint}: {error}"
@@ -170,6 +170,13 @@ def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]:
command = click.option("--plaintext", is_flag=True, help="Use plaintext gRPC.")(command)
command = click.option("--tls", "use_tls", is_flag=True, help="Use TLS gRPC.")(command)
command = click.option("--ca-file", default=None, help="Custom root certificate file.")(command)
command = click.option(
"--require-certificate-validation",
"require_certificate_validation",
is_flag=True,
help="Verify the TLS certificate against the system trust store "
"instead of the lenient trust-on-first-use default.",
)(command)
command = click.option(
"--server-name-override",
default=None,
@@ -923,6 +930,7 @@ async def _connect(kwargs: dict[str, Any]) -> GatewayClient:
api_key=api_key,
plaintext=_use_plaintext(kwargs),
ca_file=kwargs.get("ca_file"),
require_certificate_validation=bool(kwargs.get("require_certificate_validation")),
server_name_override=kwargs.get("server_name_override"),
call_timeout=kwargs.get("call_timeout"),
stream_timeout=kwargs.get("stream_timeout"),