feat(client-python): accept gateway cert by default via TOFU pre-fetch

This commit is contained in:
Joseph Doherty
2026-06-01 07:10:55 -04:00
parent f47bbaea95
commit 4c093a64fa
2 changed files with 192 additions and 3 deletions
@@ -2,6 +2,7 @@
from __future__ import annotations
import ssl
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
@@ -9,6 +10,7 @@ from pathlib import Path
import grpc
from .auth import REDACTED, ApiKey
from .errors import MxGatewayTransportError
@dataclass(frozen=True)
@@ -19,6 +21,7 @@ class ClientOptions:
api_key: str | ApiKey | None = None
plaintext: bool = False
ca_file: str | None = None
require_certificate_validation: bool = False
server_name_override: str | None = None
call_timeout: float | None = 30.0
stream_timeout: float | None = None
@@ -45,6 +48,7 @@ class ClientOptions:
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
f"ca_file={self.ca_file!r}, "
f"require_certificate_validation={self.require_certificate_validation!r}, "
f"server_name_override={self.server_name_override!r}, "
f"call_timeout={self.call_timeout!r}, "
f"stream_timeout={self.stream_timeout!r}, "
@@ -69,8 +73,22 @@ class BrowseChildrenOptions:
historized_only: bool = False
def _split_authority(endpoint: str) -> tuple[str, int]:
"""Split a gRPC target (optionally scheme-prefixed) into (host, port)."""
target = endpoint.split("://", 1)[-1]
host, _, port = target.rpartition(":")
return (host or "localhost", int(port) if port else 443)
def create_channel(options: ClientOptions) -> grpc.aio.Channel:
"""Create a plaintext or TLS `grpc.aio` channel from client options."""
"""Create a plaintext or TLS `grpc.aio` channel from client options.
The TLS default is lenient: grpc-python has no per-channel skip-verify, so
the server's presented certificate is fetched once (unverified) and pinned
as the channel's only trust root (trust-on-first-use). Set
`require_certificate_validation=True` to force system-trust verification, or
pass `ca_file` to verify against a specific CA — both bypass the TOFU path.
"""
channel_options: list[tuple[str, str | int]] = [
("grpc.max_receive_message_length", options.max_grpc_message_bytes),
@@ -82,11 +100,28 @@ def create_channel(options: ClientOptions) -> grpc.aio.Channel:
if options.plaintext:
return grpc.aio.insecure_channel(options.endpoint, options=channel_options)
root_certificates = None
if options.ca_file:
root_certificates = Path(options.ca_file).read_bytes()
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
elif options.require_certificate_validation:
credentials = grpc.ssl_channel_credentials()
else:
# Lenient default: grpc-python has no per-channel skip-verify, so fetch the
# server's certificate (unverified) and pin it for this channel (TOFU).
host, port = _split_authority(options.endpoint)
try:
presented = ssl.get_server_certificate((host, port))
except OSError as error:
raise MxGatewayTransportError(
f"failed to fetch TLS certificate from {options.endpoint}: {error}"
) from error
credentials = grpc.ssl_channel_credentials(root_certificates=presented.encode("ascii"))
# The gateway self-signed cert always carries a "localhost" SAN, so default
# the SNI/target-name override to it when none was supplied, tolerating
# dial-by-IP or hostname mismatch.
if not options.server_name_override:
channel_options.append(("grpc.ssl_target_name_override", "localhost"))
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
return grpc.aio.secure_channel(
options.endpoint,
credentials,
+154
View File
@@ -0,0 +1,154 @@
"""TLS behaviour tests for ``create_channel``.
These spin up a real loopback ``grpc.aio`` server with a freshly generated
self-signed certificate (carrying a ``localhost`` SAN, mirroring the gateway's
auto-generated cert) and assert the lenient TOFU default lets a client connect
without any CA configured.
Marked ``tls`` and skipped unless ``MXGATEWAY_RUN_TLS_TESTS=1`` because loopback
TLS handshakes can be timing-flaky on shared CI runners. This mirrors how the
suite gates anything that depends on real sockets rather than fakes.
"""
from __future__ import annotations
import os
import shutil
import socket
import ssl
import subprocess
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
import grpc
import pytest
import pytest_asyncio
from zb_mom_ww_mxgateway import ClientOptions
from zb_mom_ww_mxgateway.errors import MxGatewayTransportError
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2_grpc as pb_grpc
from zb_mom_ww_mxgateway.options import create_channel
pytestmark = pytest.mark.tls
_RUN_TLS_TESTS = os.environ.get("MXGATEWAY_RUN_TLS_TESTS") == "1"
_OPENSSL = shutil.which("openssl")
requires_tls = pytest.mark.skipif(
not _RUN_TLS_TESTS,
reason="set MXGATEWAY_RUN_TLS_TESTS=1 to run loopback TLS tests",
)
requires_openssl = pytest.mark.skipif(
_OPENSSL is None,
reason="openssl CLI is required to generate a self-signed test certificate",
)
def _generate_self_signed_cert(directory: Path) -> tuple[Path, Path]:
"""Generate a self-signed cert/key pair with a ``localhost`` SAN."""
key_path = directory / "server.key"
cert_path = directory / "server.crt"
subprocess.run(
[
str(_OPENSSL),
"req",
"-x509",
"-newkey",
"rsa:2048",
"-nodes",
"-keyout",
str(key_path),
"-out",
str(cert_path),
"-days",
"1",
"-subj",
"/CN=mxgateway-test",
"-addext",
"subjectAltName=DNS:localhost,IP:127.0.0.1",
],
check=True,
capture_output=True,
)
return cert_path, key_path
def _free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return int(sock.getsockname()[1])
class _StaticGatewayServicer(pb_grpc.MxAccessGatewayServicer):
"""Minimal servicer answering ``OpenSession`` with a fixed session id."""
async def OpenSession( # noqa: N802 - generated gRPC method name
self, request: pb.OpenSessionRequest, context: object
) -> pb.OpenSessionReply:
return pb.OpenSessionReply(session_id="tls-session-1")
@pytest_asyncio.fixture
async def tls_server() -> AsyncIterator[int]:
with tempfile.TemporaryDirectory() as tmp:
cert_path, key_path = _generate_self_signed_cert(Path(tmp))
credentials = grpc.ssl_server_credentials(
[(key_path.read_bytes(), cert_path.read_bytes())]
)
server = grpc.aio.server()
pb_grpc.add_MxAccessGatewayServicer_to_server(_StaticGatewayServicer(), server)
port = _free_port()
server.add_secure_port(f"127.0.0.1:{port}", credentials)
await server.start()
try:
yield port
finally:
await server.stop(grace=None)
@requires_tls
@requires_openssl
@pytest.mark.asyncio
async def test_default_tls_connects_via_tofu(tls_server: int) -> None:
"""Default TLS options (no CA) connect by pinning the presented cert."""
options = ClientOptions(
endpoint=f"127.0.0.1:{tls_server}",
api_key="mxgw_test_secret",
)
channel = create_channel(options)
try:
stub = pb_grpc.MxAccessGatewayStub(channel)
reply = await stub.OpenSession(pb.OpenSessionRequest(), timeout=10)
assert reply.session_id == "tls-session-1"
finally:
await channel.close()
def test_split_authority_parses_host_and_port() -> None:
from zb_mom_ww_mxgateway.options import _split_authority
assert _split_authority("https://10.0.0.5:5120") == ("10.0.0.5", 5120)
assert _split_authority("localhost:5120") == ("localhost", 5120)
assert _split_authority(":5120") == ("localhost", 5120)
def test_tofu_connect_failure_raises_transport_error() -> None:
"""A failed cert pre-fetch surfaces the client's transport error type."""
options = ClientOptions(endpoint=f"127.0.0.1:{_free_port()}")
with pytest.raises(MxGatewayTransportError) as excinfo:
create_channel(options)
assert options.endpoint in str(excinfo.value)
def test_require_certificate_validation_uses_system_trust() -> None:
"""``require_certificate_validation`` must not attempt a TOFU pre-fetch."""
# Pointing at a closed port: with system-trust the channel is created lazily
# (no eager pre-fetch), so create_channel must succeed without connecting.
options = ClientOptions(
endpoint=f"127.0.0.1:{_free_port()}",
require_certificate_validation=True,
)
channel = create_channel(options)
assert isinstance(channel, grpc.aio.Channel)