feat(client-python): accept gateway cert by default via TOFU pre-fetch

This commit is contained in:
Joseph Doherty
2026-06-01 07:10:55 -04:00
parent f47bbaea95
commit 4c093a64fa
2 changed files with 192 additions and 3 deletions
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import ssl
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -9,6 +10,7 @@ from pathlib import Path
import grpc import grpc
from .auth import REDACTED, ApiKey from .auth import REDACTED, ApiKey
from .errors import MxGatewayTransportError
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -19,6 +21,7 @@ class ClientOptions:
api_key: str | ApiKey | None = None api_key: str | ApiKey | None = None
plaintext: bool = False plaintext: bool = False
ca_file: str | None = None ca_file: str | None = None
require_certificate_validation: bool = False
server_name_override: str | None = None server_name_override: str | None = None
call_timeout: float | None = 30.0 call_timeout: float | None = 30.0
stream_timeout: float | None = None stream_timeout: float | None = None
@@ -45,6 +48,7 @@ class ClientOptions:
f"{type(self).__name__}(endpoint={self.endpoint!r}, " f"{type(self).__name__}(endpoint={self.endpoint!r}, "
f"api_key={api_key!r}, plaintext={self.plaintext!r}, " f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
f"ca_file={self.ca_file!r}, " f"ca_file={self.ca_file!r}, "
f"require_certificate_validation={self.require_certificate_validation!r}, "
f"server_name_override={self.server_name_override!r}, " f"server_name_override={self.server_name_override!r}, "
f"call_timeout={self.call_timeout!r}, " f"call_timeout={self.call_timeout!r}, "
f"stream_timeout={self.stream_timeout!r}, " f"stream_timeout={self.stream_timeout!r}, "
@@ -69,8 +73,22 @@ class BrowseChildrenOptions:
historized_only: bool = False historized_only: bool = False
def _split_authority(endpoint: str) -> tuple[str, int]:
"""Split a gRPC target (optionally scheme-prefixed) into (host, port)."""
target = endpoint.split("://", 1)[-1]
host, _, port = target.rpartition(":")
return (host or "localhost", int(port) if port else 443)
def create_channel(options: ClientOptions) -> grpc.aio.Channel: def create_channel(options: ClientOptions) -> grpc.aio.Channel:
"""Create a plaintext or TLS `grpc.aio` channel from client options.""" """Create a plaintext or TLS `grpc.aio` channel from client options.
The TLS default is lenient: grpc-python has no per-channel skip-verify, so
the server's presented certificate is fetched once (unverified) and pinned
as the channel's only trust root (trust-on-first-use). Set
`require_certificate_validation=True` to force system-trust verification, or
pass `ca_file` to verify against a specific CA — both bypass the TOFU path.
"""
channel_options: list[tuple[str, str | int]] = [ channel_options: list[tuple[str, str | int]] = [
("grpc.max_receive_message_length", options.max_grpc_message_bytes), ("grpc.max_receive_message_length", options.max_grpc_message_bytes),
@@ -82,11 +100,28 @@ def create_channel(options: ClientOptions) -> grpc.aio.Channel:
if options.plaintext: if options.plaintext:
return grpc.aio.insecure_channel(options.endpoint, options=channel_options) return grpc.aio.insecure_channel(options.endpoint, options=channel_options)
root_certificates = None
if options.ca_file: if options.ca_file:
root_certificates = Path(options.ca_file).read_bytes() root_certificates = Path(options.ca_file).read_bytes()
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
elif options.require_certificate_validation:
credentials = grpc.ssl_channel_credentials()
else:
# Lenient default: grpc-python has no per-channel skip-verify, so fetch the
# server's certificate (unverified) and pin it for this channel (TOFU).
host, port = _split_authority(options.endpoint)
try:
presented = ssl.get_server_certificate((host, port))
except OSError as error:
raise MxGatewayTransportError(
f"failed to fetch TLS certificate from {options.endpoint}: {error}"
) from error
credentials = grpc.ssl_channel_credentials(root_certificates=presented.encode("ascii"))
# The gateway self-signed cert always carries a "localhost" SAN, so default
# the SNI/target-name override to it when none was supplied, tolerating
# dial-by-IP or hostname mismatch.
if not options.server_name_override:
channel_options.append(("grpc.ssl_target_name_override", "localhost"))
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
return grpc.aio.secure_channel( return grpc.aio.secure_channel(
options.endpoint, options.endpoint,
credentials, credentials,
+154
View File
@@ -0,0 +1,154 @@
"""TLS behaviour tests for ``create_channel``.
These spin up a real loopback ``grpc.aio`` server with a freshly generated
self-signed certificate (carrying a ``localhost`` SAN, mirroring the gateway's
auto-generated cert) and assert the lenient TOFU default lets a client connect
without any CA configured.
Marked ``tls`` and skipped unless ``MXGATEWAY_RUN_TLS_TESTS=1`` because loopback
TLS handshakes can be timing-flaky on shared CI runners. This mirrors how the
suite gates anything that depends on real sockets rather than fakes.
"""
from __future__ import annotations
import os
import shutil
import socket
import ssl
import subprocess
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
import grpc
import pytest
import pytest_asyncio
from zb_mom_ww_mxgateway import ClientOptions
from zb_mom_ww_mxgateway.errors import MxGatewayTransportError
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2_grpc as pb_grpc
from zb_mom_ww_mxgateway.options import create_channel
pytestmark = pytest.mark.tls
_RUN_TLS_TESTS = os.environ.get("MXGATEWAY_RUN_TLS_TESTS") == "1"
_OPENSSL = shutil.which("openssl")
requires_tls = pytest.mark.skipif(
not _RUN_TLS_TESTS,
reason="set MXGATEWAY_RUN_TLS_TESTS=1 to run loopback TLS tests",
)
requires_openssl = pytest.mark.skipif(
_OPENSSL is None,
reason="openssl CLI is required to generate a self-signed test certificate",
)
def _generate_self_signed_cert(directory: Path) -> tuple[Path, Path]:
"""Generate a self-signed cert/key pair with a ``localhost`` SAN."""
key_path = directory / "server.key"
cert_path = directory / "server.crt"
subprocess.run(
[
str(_OPENSSL),
"req",
"-x509",
"-newkey",
"rsa:2048",
"-nodes",
"-keyout",
str(key_path),
"-out",
str(cert_path),
"-days",
"1",
"-subj",
"/CN=mxgateway-test",
"-addext",
"subjectAltName=DNS:localhost,IP:127.0.0.1",
],
check=True,
capture_output=True,
)
return cert_path, key_path
def _free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return int(sock.getsockname()[1])
class _StaticGatewayServicer(pb_grpc.MxAccessGatewayServicer):
"""Minimal servicer answering ``OpenSession`` with a fixed session id."""
async def OpenSession( # noqa: N802 - generated gRPC method name
self, request: pb.OpenSessionRequest, context: object
) -> pb.OpenSessionReply:
return pb.OpenSessionReply(session_id="tls-session-1")
@pytest_asyncio.fixture
async def tls_server() -> AsyncIterator[int]:
with tempfile.TemporaryDirectory() as tmp:
cert_path, key_path = _generate_self_signed_cert(Path(tmp))
credentials = grpc.ssl_server_credentials(
[(key_path.read_bytes(), cert_path.read_bytes())]
)
server = grpc.aio.server()
pb_grpc.add_MxAccessGatewayServicer_to_server(_StaticGatewayServicer(), server)
port = _free_port()
server.add_secure_port(f"127.0.0.1:{port}", credentials)
await server.start()
try:
yield port
finally:
await server.stop(grace=None)
@requires_tls
@requires_openssl
@pytest.mark.asyncio
async def test_default_tls_connects_via_tofu(tls_server: int) -> None:
"""Default TLS options (no CA) connect by pinning the presented cert."""
options = ClientOptions(
endpoint=f"127.0.0.1:{tls_server}",
api_key="mxgw_test_secret",
)
channel = create_channel(options)
try:
stub = pb_grpc.MxAccessGatewayStub(channel)
reply = await stub.OpenSession(pb.OpenSessionRequest(), timeout=10)
assert reply.session_id == "tls-session-1"
finally:
await channel.close()
def test_split_authority_parses_host_and_port() -> None:
from zb_mom_ww_mxgateway.options import _split_authority
assert _split_authority("https://10.0.0.5:5120") == ("10.0.0.5", 5120)
assert _split_authority("localhost:5120") == ("localhost", 5120)
assert _split_authority(":5120") == ("localhost", 5120)
def test_tofu_connect_failure_raises_transport_error() -> None:
"""A failed cert pre-fetch surfaces the client's transport error type."""
options = ClientOptions(endpoint=f"127.0.0.1:{_free_port()}")
with pytest.raises(MxGatewayTransportError) as excinfo:
create_channel(options)
assert options.endpoint in str(excinfo.value)
def test_require_certificate_validation_uses_system_trust() -> None:
"""``require_certificate_validation`` must not attempt a TOFU pre-fetch."""
# Pointing at a closed port: with system-trust the channel is created lazily
# (no eager pre-fetch), so create_channel must succeed without connecting.
options = ClientOptions(
endpoint=f"127.0.0.1:{_free_port()}",
require_certificate_validation=True,
)
channel = create_channel(options)
assert isinstance(channel, grpc.aio.Channel)