feat(client-python): accept gateway cert by default via TOFU pre-fetch
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ssl
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -9,6 +10,7 @@ from pathlib import Path
|
|||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
from .auth import REDACTED, ApiKey
|
from .auth import REDACTED, ApiKey
|
||||||
|
from .errors import MxGatewayTransportError
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -19,6 +21,7 @@ class ClientOptions:
|
|||||||
api_key: str | ApiKey | None = None
|
api_key: str | ApiKey | None = None
|
||||||
plaintext: bool = False
|
plaintext: bool = False
|
||||||
ca_file: str | None = None
|
ca_file: str | None = None
|
||||||
|
require_certificate_validation: bool = False
|
||||||
server_name_override: str | None = None
|
server_name_override: str | None = None
|
||||||
call_timeout: float | None = 30.0
|
call_timeout: float | None = 30.0
|
||||||
stream_timeout: float | None = None
|
stream_timeout: float | None = None
|
||||||
@@ -45,6 +48,7 @@ class ClientOptions:
|
|||||||
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
|
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
|
||||||
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
|
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
|
||||||
f"ca_file={self.ca_file!r}, "
|
f"ca_file={self.ca_file!r}, "
|
||||||
|
f"require_certificate_validation={self.require_certificate_validation!r}, "
|
||||||
f"server_name_override={self.server_name_override!r}, "
|
f"server_name_override={self.server_name_override!r}, "
|
||||||
f"call_timeout={self.call_timeout!r}, "
|
f"call_timeout={self.call_timeout!r}, "
|
||||||
f"stream_timeout={self.stream_timeout!r}, "
|
f"stream_timeout={self.stream_timeout!r}, "
|
||||||
@@ -69,8 +73,22 @@ class BrowseChildrenOptions:
|
|||||||
historized_only: bool = False
|
historized_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _split_authority(endpoint: str) -> tuple[str, int]:
|
||||||
|
"""Split a gRPC target (optionally scheme-prefixed) into (host, port)."""
|
||||||
|
target = endpoint.split("://", 1)[-1]
|
||||||
|
host, _, port = target.rpartition(":")
|
||||||
|
return (host or "localhost", int(port) if port else 443)
|
||||||
|
|
||||||
|
|
||||||
def create_channel(options: ClientOptions) -> grpc.aio.Channel:
|
def create_channel(options: ClientOptions) -> grpc.aio.Channel:
|
||||||
"""Create a plaintext or TLS `grpc.aio` channel from client options."""
|
"""Create a plaintext or TLS `grpc.aio` channel from client options.
|
||||||
|
|
||||||
|
The TLS default is lenient: grpc-python has no per-channel skip-verify, so
|
||||||
|
the server's presented certificate is fetched once (unverified) and pinned
|
||||||
|
as the channel's only trust root (trust-on-first-use). Set
|
||||||
|
`require_certificate_validation=True` to force system-trust verification, or
|
||||||
|
pass `ca_file` to verify against a specific CA — both bypass the TOFU path.
|
||||||
|
"""
|
||||||
|
|
||||||
channel_options: list[tuple[str, str | int]] = [
|
channel_options: list[tuple[str, str | int]] = [
|
||||||
("grpc.max_receive_message_length", options.max_grpc_message_bytes),
|
("grpc.max_receive_message_length", options.max_grpc_message_bytes),
|
||||||
@@ -82,11 +100,28 @@ def create_channel(options: ClientOptions) -> grpc.aio.Channel:
|
|||||||
if options.plaintext:
|
if options.plaintext:
|
||||||
return grpc.aio.insecure_channel(options.endpoint, options=channel_options)
|
return grpc.aio.insecure_channel(options.endpoint, options=channel_options)
|
||||||
|
|
||||||
root_certificates = None
|
|
||||||
if options.ca_file:
|
if options.ca_file:
|
||||||
root_certificates = Path(options.ca_file).read_bytes()
|
root_certificates = Path(options.ca_file).read_bytes()
|
||||||
|
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
|
||||||
|
elif options.require_certificate_validation:
|
||||||
|
credentials = grpc.ssl_channel_credentials()
|
||||||
|
else:
|
||||||
|
# Lenient default: grpc-python has no per-channel skip-verify, so fetch the
|
||||||
|
# server's certificate (unverified) and pin it for this channel (TOFU).
|
||||||
|
host, port = _split_authority(options.endpoint)
|
||||||
|
try:
|
||||||
|
presented = ssl.get_server_certificate((host, port))
|
||||||
|
except OSError as error:
|
||||||
|
raise MxGatewayTransportError(
|
||||||
|
f"failed to fetch TLS certificate from {options.endpoint}: {error}"
|
||||||
|
) from error
|
||||||
|
credentials = grpc.ssl_channel_credentials(root_certificates=presented.encode("ascii"))
|
||||||
|
# The gateway self-signed cert always carries a "localhost" SAN, so default
|
||||||
|
# the SNI/target-name override to it when none was supplied, tolerating
|
||||||
|
# dial-by-IP or hostname mismatch.
|
||||||
|
if not options.server_name_override:
|
||||||
|
channel_options.append(("grpc.ssl_target_name_override", "localhost"))
|
||||||
|
|
||||||
credentials = grpc.ssl_channel_credentials(root_certificates=root_certificates)
|
|
||||||
return grpc.aio.secure_channel(
|
return grpc.aio.secure_channel(
|
||||||
options.endpoint,
|
options.endpoint,
|
||||||
credentials,
|
credentials,
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
"""TLS behaviour tests for ``create_channel``.
|
||||||
|
|
||||||
|
These spin up a real loopback ``grpc.aio`` server with a freshly generated
|
||||||
|
self-signed certificate (carrying a ``localhost`` SAN, mirroring the gateway's
|
||||||
|
auto-generated cert) and assert the lenient TOFU default lets a client connect
|
||||||
|
without any CA configured.
|
||||||
|
|
||||||
|
Marked ``tls`` and skipped unless ``MXGATEWAY_RUN_TLS_TESTS=1`` because loopback
|
||||||
|
TLS handshakes can be timing-flaky on shared CI runners. This mirrors how the
|
||||||
|
suite gates anything that depends on real sockets rather than fakes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from zb_mom_ww_mxgateway import ClientOptions
|
||||||
|
from zb_mom_ww_mxgateway.errors import MxGatewayTransportError
|
||||||
|
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb
|
||||||
|
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2_grpc as pb_grpc
|
||||||
|
from zb_mom_ww_mxgateway.options import create_channel
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.tls
|
||||||
|
|
||||||
|
_RUN_TLS_TESTS = os.environ.get("MXGATEWAY_RUN_TLS_TESTS") == "1"
|
||||||
|
_OPENSSL = shutil.which("openssl")
|
||||||
|
|
||||||
|
requires_tls = pytest.mark.skipif(
|
||||||
|
not _RUN_TLS_TESTS,
|
||||||
|
reason="set MXGATEWAY_RUN_TLS_TESTS=1 to run loopback TLS tests",
|
||||||
|
)
|
||||||
|
requires_openssl = pytest.mark.skipif(
|
||||||
|
_OPENSSL is None,
|
||||||
|
reason="openssl CLI is required to generate a self-signed test certificate",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_self_signed_cert(directory: Path) -> tuple[Path, Path]:
|
||||||
|
"""Generate a self-signed cert/key pair with a ``localhost`` SAN."""
|
||||||
|
key_path = directory / "server.key"
|
||||||
|
cert_path = directory / "server.crt"
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
str(_OPENSSL),
|
||||||
|
"req",
|
||||||
|
"-x509",
|
||||||
|
"-newkey",
|
||||||
|
"rsa:2048",
|
||||||
|
"-nodes",
|
||||||
|
"-keyout",
|
||||||
|
str(key_path),
|
||||||
|
"-out",
|
||||||
|
str(cert_path),
|
||||||
|
"-days",
|
||||||
|
"1",
|
||||||
|
"-subj",
|
||||||
|
"/CN=mxgateway-test",
|
||||||
|
"-addext",
|
||||||
|
"subjectAltName=DNS:localhost,IP:127.0.0.1",
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
)
|
||||||
|
return cert_path, key_path
|
||||||
|
|
||||||
|
|
||||||
|
def _free_port() -> int:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||||
|
sock.bind(("127.0.0.1", 0))
|
||||||
|
return int(sock.getsockname()[1])
|
||||||
|
|
||||||
|
|
||||||
|
class _StaticGatewayServicer(pb_grpc.MxAccessGatewayServicer):
|
||||||
|
"""Minimal servicer answering ``OpenSession`` with a fixed session id."""
|
||||||
|
|
||||||
|
async def OpenSession( # noqa: N802 - generated gRPC method name
|
||||||
|
self, request: pb.OpenSessionRequest, context: object
|
||||||
|
) -> pb.OpenSessionReply:
|
||||||
|
return pb.OpenSessionReply(session_id="tls-session-1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def tls_server() -> AsyncIterator[int]:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
cert_path, key_path = _generate_self_signed_cert(Path(tmp))
|
||||||
|
credentials = grpc.ssl_server_credentials(
|
||||||
|
[(key_path.read_bytes(), cert_path.read_bytes())]
|
||||||
|
)
|
||||||
|
server = grpc.aio.server()
|
||||||
|
pb_grpc.add_MxAccessGatewayServicer_to_server(_StaticGatewayServicer(), server)
|
||||||
|
port = _free_port()
|
||||||
|
server.add_secure_port(f"127.0.0.1:{port}", credentials)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
yield port
|
||||||
|
finally:
|
||||||
|
await server.stop(grace=None)
|
||||||
|
|
||||||
|
|
||||||
|
@requires_tls
|
||||||
|
@requires_openssl
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_tls_connects_via_tofu(tls_server: int) -> None:
|
||||||
|
"""Default TLS options (no CA) connect by pinning the presented cert."""
|
||||||
|
options = ClientOptions(
|
||||||
|
endpoint=f"127.0.0.1:{tls_server}",
|
||||||
|
api_key="mxgw_test_secret",
|
||||||
|
)
|
||||||
|
channel = create_channel(options)
|
||||||
|
try:
|
||||||
|
stub = pb_grpc.MxAccessGatewayStub(channel)
|
||||||
|
reply = await stub.OpenSession(pb.OpenSessionRequest(), timeout=10)
|
||||||
|
assert reply.session_id == "tls-session-1"
|
||||||
|
finally:
|
||||||
|
await channel.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_authority_parses_host_and_port() -> None:
|
||||||
|
from zb_mom_ww_mxgateway.options import _split_authority
|
||||||
|
|
||||||
|
assert _split_authority("https://10.0.0.5:5120") == ("10.0.0.5", 5120)
|
||||||
|
assert _split_authority("localhost:5120") == ("localhost", 5120)
|
||||||
|
assert _split_authority(":5120") == ("localhost", 5120)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tofu_connect_failure_raises_transport_error() -> None:
|
||||||
|
"""A failed cert pre-fetch surfaces the client's transport error type."""
|
||||||
|
options = ClientOptions(endpoint=f"127.0.0.1:{_free_port()}")
|
||||||
|
with pytest.raises(MxGatewayTransportError) as excinfo:
|
||||||
|
create_channel(options)
|
||||||
|
assert options.endpoint in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_certificate_validation_uses_system_trust() -> None:
|
||||||
|
"""``require_certificate_validation`` must not attempt a TOFU pre-fetch."""
|
||||||
|
# Pointing at a closed port: with system-trust the channel is created lazily
|
||||||
|
# (no eager pre-fetch), so create_channel must succeed without connecting.
|
||||||
|
options = ClientOptions(
|
||||||
|
endpoint=f"127.0.0.1:{_free_port()}",
|
||||||
|
require_certificate_validation=True,
|
||||||
|
)
|
||||||
|
channel = create_channel(options)
|
||||||
|
assert isinstance(channel, grpc.aio.Channel)
|
||||||
Reference in New Issue
Block a user