diff --git a/clients/python/src/zb_mom_ww_mxgateway/options.py b/clients/python/src/zb_mom_ww_mxgateway/options.py index 060d50c..b4e1645 100644 --- a/clients/python/src/zb_mom_ww_mxgateway/options.py +++ b/clients/python/src/zb_mom_ww_mxgateway/options.py @@ -2,6 +2,7 @@ from __future__ import annotations +import ssl from collections.abc import Sequence from dataclasses import dataclass, field from pathlib import Path @@ -9,6 +10,7 @@ from pathlib import Path import grpc from .auth import REDACTED, ApiKey +from .errors import MxGatewayTransportError @dataclass(frozen=True) @@ -19,6 +21,7 @@ class ClientOptions: api_key: str | ApiKey | None = None plaintext: bool = False ca_file: str | None = None + require_certificate_validation: bool = False server_name_override: str | None = None call_timeout: float | None = 30.0 stream_timeout: float | None = None @@ -45,6 +48,7 @@ class ClientOptions: f"{type(self).__name__}(endpoint={self.endpoint!r}, " f"api_key={api_key!r}, plaintext={self.plaintext!r}, " f"ca_file={self.ca_file!r}, " + f"require_certificate_validation={self.require_certificate_validation!r}, " f"server_name_override={self.server_name_override!r}, " f"call_timeout={self.call_timeout!r}, " f"stream_timeout={self.stream_timeout!r}, " @@ -69,8 +73,22 @@ class BrowseChildrenOptions: historized_only: bool = False +def _split_authority(endpoint: str) -> tuple[str, int]: + """Split a gRPC target (optionally scheme-prefixed) into (host, port).""" + target = endpoint.split("://", 1)[-1] + host, _, port = target.rpartition(":") + return (host or "localhost", int(port) if port else 443) + + def create_channel(options: ClientOptions) -> grpc.aio.Channel: - """Create a plaintext or TLS `grpc.aio` channel from client options.""" + """Create a plaintext or TLS `grpc.aio` channel from client options. + + The TLS default is lenient: grpc-python has no per-channel skip-verify, so + the server's presented certificate is fetched once (unverified) and pinned + as the channel's only trust root (trust-on-first-use). Set + `require_certificate_validation=True` to force system-trust verification, or + pass `ca_file` to verify against a specific CA — both bypass the TOFU path. + """ channel_options: list[tuple[str, str | int]] = [ ("grpc.max_receive_message_length", options.max_grpc_message_bytes), @@ -82,11 +100,28 @@ def create_channel(options: ClientOptions) -> grpc.aio.Channel: if options.plaintext: return grpc.aio.insecure_channel(options.endpoint, options=channel_options) - root_certificates = None if options.ca_file: root_certificates = Path(options.ca_file).read_bytes() + credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates) + elif options.require_certificate_validation: + credentials = grpc.ssl_channel_credentials() + else: + # Lenient default: grpc-python has no per-channel skip-verify, so fetch the + # server's certificate (unverified) and pin it for this channel (TOFU). + host, port = _split_authority(options.endpoint) + try: + presented = ssl.get_server_certificate((host, port)) + except OSError as error: + raise MxGatewayTransportError( + f"failed to fetch TLS certificate from {options.endpoint}: {error}" + ) from error + credentials = grpc.ssl_channel_credentials(root_certificates=presented.encode("ascii")) + # The gateway self-signed cert always carries a "localhost" SAN, so default + # the SNI/target-name override to it when none was supplied, tolerating + # dial-by-IP or hostname mismatch. + if not options.server_name_override: + channel_options.append(("grpc.ssl_target_name_override", "localhost")) - credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates) return grpc.aio.secure_channel( options.endpoint, credentials, diff --git a/clients/python/tests/test_tls.py b/clients/python/tests/test_tls.py new file mode 100644 index 0000000..2547011 --- /dev/null +++ b/clients/python/tests/test_tls.py @@ -0,0 +1,154 @@ +"""TLS behaviour tests for ``create_channel``. + +These spin up a real loopback ``grpc.aio`` server with a freshly generated +self-signed certificate (carrying a ``localhost`` SAN, mirroring the gateway's +auto-generated cert) and assert the lenient TOFU default lets a client connect +without any CA configured. + +Marked ``tls`` and skipped unless ``MXGATEWAY_RUN_TLS_TESTS=1`` because loopback +TLS handshakes can be timing-flaky on shared CI runners. This mirrors how the +suite gates anything that depends on real sockets rather than fakes. +""" + +from __future__ import annotations + +import os +import shutil +import socket +import ssl +import subprocess +import tempfile +from collections.abc import AsyncIterator +from pathlib import Path + +import grpc +import pytest +import pytest_asyncio + +from zb_mom_ww_mxgateway import ClientOptions +from zb_mom_ww_mxgateway.errors import MxGatewayTransportError +from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb +from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2_grpc as pb_grpc +from zb_mom_ww_mxgateway.options import create_channel + +pytestmark = pytest.mark.tls + +_RUN_TLS_TESTS = os.environ.get("MXGATEWAY_RUN_TLS_TESTS") == "1" +_OPENSSL = shutil.which("openssl") + +requires_tls = pytest.mark.skipif( + not _RUN_TLS_TESTS, + reason="set MXGATEWAY_RUN_TLS_TESTS=1 to run loopback TLS tests", +) +requires_openssl = pytest.mark.skipif( + _OPENSSL is None, + reason="openssl CLI is required to generate a self-signed test certificate", +) + + +def _generate_self_signed_cert(directory: Path) -> tuple[Path, Path]: + """Generate a self-signed cert/key pair with a ``localhost`` SAN.""" + key_path = directory / "server.key" + cert_path = directory / "server.crt" + subprocess.run( + [ + str(_OPENSSL), + "req", + "-x509", + "-newkey", + "rsa:2048", + "-nodes", + "-keyout", + str(key_path), + "-out", + str(cert_path), + "-days", + "1", + "-subj", + "/CN=mxgateway-test", + "-addext", + "subjectAltName=DNS:localhost,IP:127.0.0.1", + ], + check=True, + capture_output=True, + ) + return cert_path, key_path + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +class _StaticGatewayServicer(pb_grpc.MxAccessGatewayServicer): + """Minimal servicer answering ``OpenSession`` with a fixed session id.""" + + async def OpenSession( # noqa: N802 - generated gRPC method name + self, request: pb.OpenSessionRequest, context: object + ) -> pb.OpenSessionReply: + return pb.OpenSessionReply(session_id="tls-session-1") + + +@pytest_asyncio.fixture +async def tls_server() -> AsyncIterator[int]: + with tempfile.TemporaryDirectory() as tmp: + cert_path, key_path = _generate_self_signed_cert(Path(tmp)) + credentials = grpc.ssl_server_credentials( + [(key_path.read_bytes(), cert_path.read_bytes())] + ) + server = grpc.aio.server() + pb_grpc.add_MxAccessGatewayServicer_to_server(_StaticGatewayServicer(), server) + port = _free_port() + server.add_secure_port(f"127.0.0.1:{port}", credentials) + await server.start() + try: + yield port + finally: + await server.stop(grace=None) + + +@requires_tls +@requires_openssl +@pytest.mark.asyncio +async def test_default_tls_connects_via_tofu(tls_server: int) -> None: + """Default TLS options (no CA) connect by pinning the presented cert.""" + options = ClientOptions( + endpoint=f"127.0.0.1:{tls_server}", + api_key="mxgw_test_secret", + ) + channel = create_channel(options) + try: + stub = pb_grpc.MxAccessGatewayStub(channel) + reply = await stub.OpenSession(pb.OpenSessionRequest(), timeout=10) + assert reply.session_id == "tls-session-1" + finally: + await channel.close() + + +def test_split_authority_parses_host_and_port() -> None: + from zb_mom_ww_mxgateway.options import _split_authority + + assert _split_authority("https://10.0.0.5:5120") == ("10.0.0.5", 5120) + assert _split_authority("localhost:5120") == ("localhost", 5120) + assert _split_authority(":5120") == ("localhost", 5120) + + +def test_tofu_connect_failure_raises_transport_error() -> None: + """A failed cert pre-fetch surfaces the client's transport error type.""" + options = ClientOptions(endpoint=f"127.0.0.1:{_free_port()}") + with pytest.raises(MxGatewayTransportError) as excinfo: + create_channel(options) + assert options.endpoint in str(excinfo.value) + + +def test_require_certificate_validation_uses_system_trust() -> None: + """``require_certificate_validation`` must not attempt a TOFU pre-fetch.""" + # Pointing at a closed port: with system-trust the channel is created lazily + # (no eager pre-fetch), so create_channel must succeed without connecting. + options = ClientOptions( + endpoint=f"127.0.0.1:{_free_port()}", + require_certificate_validation=True, + ) + channel = create_channel(options) + assert isinstance(channel, grpc.aio.Channel)