fix(client/python): reachable cert-validation flag; bounded off-loop TOFU probe; license/marker fixes (Client.Python-027..031)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
"""Tests for auth metadata and connection options."""
|
||||
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from zb_mom_ww_mxgateway.auth import REDACTED, ApiKey, auth_metadata, redact_secret
|
||||
from zb_mom_ww_mxgateway import options as options_module
|
||||
from zb_mom_ww_mxgateway.errors import MxGatewayTransportError
|
||||
from zb_mom_ww_mxgateway.options import ClientOptions, create_channel
|
||||
|
||||
|
||||
@@ -80,7 +83,9 @@ def test_create_channel_uses_tls_channel_tofu_default(monkeypatch: pytest.Monkey
|
||||
_DUMMY_PEM = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
|
||||
get_cert_calls: list[tuple[str, int]] = []
|
||||
|
||||
def fake_get_server_certificate(addr: tuple[str, int]) -> str:
|
||||
def fake_get_server_certificate(
|
||||
addr: tuple[str, int], *, timeout: float | None = None
|
||||
) -> str:
|
||||
get_cert_calls.append(addr)
|
||||
return _DUMMY_PEM
|
||||
|
||||
@@ -133,7 +138,7 @@ def test_create_channel_uses_tls_channel_tofu_respects_server_name_override(
|
||||
monkeypatch.setattr(
|
||||
options_module.ssl,
|
||||
"get_server_certificate",
|
||||
lambda addr: _DUMMY_PEM,
|
||||
lambda addr, *, timeout=None: _DUMMY_PEM,
|
||||
)
|
||||
cred_calls: list[object] = []
|
||||
|
||||
@@ -276,3 +281,46 @@ def test_create_channel_uses_tls_channel_ca_file(
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_tofu_probe_passes_a_bounded_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""The TOFU cert pre-fetch must be bounded so a black-holed host fails fast."""
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_get_server_certificate(addr: object, *, timeout: float | None = None) -> str:
|
||||
captured["timeout"] = timeout
|
||||
return "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
|
||||
|
||||
monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
|
||||
monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", lambda **_: "creds")
|
||||
monkeypatch.setattr(
|
||||
options_module.grpc.aio,
|
||||
"secure_channel",
|
||||
lambda endpoint, credentials, *, options: "tls-channel",
|
||||
)
|
||||
|
||||
create_channel(ClientOptions(endpoint="gateway.example:5001", call_timeout=7.5))
|
||||
|
||||
# A finite, positive timeout must be supplied (bounded by call_timeout here).
|
||||
assert isinstance(captured["timeout"], (int, float))
|
||||
assert 0 < captured["timeout"] <= 7.5
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raised",
|
||||
[socket.timeout("timed out"), TimeoutError("timed out"), OSError("connection refused")],
|
||||
)
|
||||
def test_tofu_probe_timeout_raises_transport_error(
|
||||
monkeypatch: pytest.MonkeyPatch, raised: Exception
|
||||
) -> None:
|
||||
"""A timed-out / failed probe surfaces as MxGatewayTransportError, not a raw error."""
|
||||
|
||||
def fake_get_server_certificate(addr: object, *, timeout: float | None = None) -> str:
|
||||
raise raised
|
||||
|
||||
monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
|
||||
|
||||
options = ClientOptions(endpoint="gateway.example:5001")
|
||||
with pytest.raises(MxGatewayTransportError) as excinfo:
|
||||
create_channel(options)
|
||||
assert options.endpoint in str(excinfo.value)
|
||||
|
||||
@@ -2,14 +2,79 @@
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from zb_mom_ww_mxgateway import __version__
|
||||
from zb_mom_ww_mxgateway_cli import commands as commands_module
|
||||
from zb_mom_ww_mxgateway_cli.commands import main
|
||||
|
||||
_BATCH_EOR = "__MXGW_BATCH_EOR__"
|
||||
|
||||
|
||||
def test_require_certificate_validation_flag_flows_through_connect(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""The --require-certificate-validation CLI flag must reach ClientOptions (Client.Python-027)."""
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_connect(options, **_kwargs):
|
||||
captured["options"] = options
|
||||
# Return a minimal object that supports the async context-manager protocol
|
||||
# used by every CLI command body (async with await _connect(...) as client).
|
||||
return _FakeAsyncClient()
|
||||
|
||||
monkeypatch.setattr(commands_module.GatewayClient, "connect", fake_connect)
|
||||
|
||||
result = CliRunner().invoke(
|
||||
main,
|
||||
[
|
||||
"open-session",
|
||||
"--endpoint",
|
||||
"gateway.example:5001",
|
||||
"--require-certificate-validation",
|
||||
"--json",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert captured["options"].require_certificate_validation is True
|
||||
|
||||
|
||||
def test_require_certificate_validation_defaults_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Without the flag the strict-validation posture stays off (TOFU default)."""
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_connect(options, **_kwargs):
|
||||
captured["options"] = options
|
||||
return _FakeAsyncClient()
|
||||
|
||||
monkeypatch.setattr(commands_module.GatewayClient, "connect", fake_connect)
|
||||
|
||||
result = CliRunner().invoke(
|
||||
main,
|
||||
["open-session", "--endpoint", "gateway.example:5001", "--plaintext", "--json"],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert captured["options"].require_certificate_validation is False
|
||||
|
||||
|
||||
class _FakeAsyncClient:
|
||||
"""Minimal async-context-manager fake satisfying the open-session command body."""
|
||||
|
||||
async def __aenter__(self) -> "_FakeAsyncClient":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_exc: object) -> None:
|
||||
return None
|
||||
|
||||
async def open_session_raw(self, *_args, **_kwargs):
|
||||
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb
|
||||
|
||||
return pb.OpenSessionReply(session_id="cli-test-session")
|
||||
|
||||
|
||||
def test_version_json_is_deterministic() -> None:
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
@@ -8,9 +8,107 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from zb_mom_ww_mxgateway import ClientOptions, GatewayClient, MxAccessError
|
||||
from zb_mom_ww_mxgateway import client as client_module
|
||||
from zb_mom_ww_mxgateway import galaxy as galaxy_module
|
||||
from zb_mom_ww_mxgateway.galaxy import GalaxyRepositoryClient
|
||||
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_connect_forwards_require_certificate_validation(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""The connect convenience kwarg must reach ClientOptions (Client.Python-027)."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def fake_create_channel(options: ClientOptions) -> object:
|
||||
captured["options"] = options
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(client_module, "create_channel", fake_create_channel)
|
||||
monkeypatch.setattr(client_module.pb_grpc, "MxAccessGatewayStub", lambda channel: object())
|
||||
|
||||
await GatewayClient.connect(
|
||||
endpoint="gateway.example:5001",
|
||||
require_certificate_validation=True,
|
||||
)
|
||||
|
||||
assert captured["options"].require_certificate_validation is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_galaxy_connect_forwards_require_certificate_validation(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""GalaxyRepositoryClient.connect must thread the flag too (Client.Python-027)."""
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def fake_create_channel(options: ClientOptions) -> object:
|
||||
captured["options"] = options
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(galaxy_module, "create_channel", fake_create_channel)
|
||||
monkeypatch.setattr(
|
||||
galaxy_module.galaxy_pb_grpc, "GalaxyRepositoryStub", lambda channel: object()
|
||||
)
|
||||
|
||||
await GalaxyRepositoryClient.connect(
|
||||
endpoint="gateway.example:5001",
|
||||
require_certificate_validation=True,
|
||||
)
|
||||
|
||||
assert captured["options"].require_certificate_validation is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_connect_runs_create_channel_off_the_event_loop(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""connect must run the blocking channel factory off the loop (Client.Python-028)."""
|
||||
ran_in_thread: dict[str, bool] = {}
|
||||
|
||||
def fake_create_channel(options: ClientOptions) -> object:
|
||||
# If this runs on the event loop thread, get_running_loop() succeeds.
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
ran_in_thread["off_loop"] = False
|
||||
except RuntimeError:
|
||||
ran_in_thread["off_loop"] = True
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(client_module, "create_channel", fake_create_channel)
|
||||
monkeypatch.setattr(client_module.pb_grpc, "MxAccessGatewayStub", lambda channel: object())
|
||||
|
||||
await GatewayClient.connect(endpoint="gateway.example:5001")
|
||||
|
||||
assert ran_in_thread["off_loop"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_galaxy_connect_runs_create_channel_off_the_event_loop(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""GalaxyRepositoryClient.connect must also run the probe off the loop (Client.Python-028)."""
|
||||
ran_in_thread: dict[str, bool] = {}
|
||||
|
||||
def fake_create_channel(options: ClientOptions) -> object:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
ran_in_thread["off_loop"] = False
|
||||
except RuntimeError:
|
||||
ran_in_thread["off_loop"] = True
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(galaxy_module, "create_channel", fake_create_channel)
|
||||
monkeypatch.setattr(
|
||||
galaxy_module.galaxy_pb_grpc, "GalaxyRepositoryStub", lambda channel: object()
|
||||
)
|
||||
|
||||
await GalaxyRepositoryClient.connect(endpoint="gateway.example:5001")
|
||||
|
||||
assert ran_in_thread["off_loop"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_helpers_send_auth_metadata_and_preserve_raw_replies() -> None:
|
||||
stub = FakeGatewayStub()
|
||||
|
||||
@@ -134,6 +134,17 @@ def test_split_authority_parses_host_and_port() -> None:
|
||||
assert _split_authority(":5120") == ("localhost", 5120)
|
||||
|
||||
|
||||
def test_split_authority_defaults_port_for_portless_endpoint() -> None:
|
||||
from zb_mom_ww_mxgateway.options import _split_authority
|
||||
|
||||
# A bare hostname (no ":port") must default to 443, not crash on int("mygateway").
|
||||
assert _split_authority("mygateway") == ("mygateway", 443)
|
||||
# Scheme-prefixed bare hostname behaves the same.
|
||||
assert _split_authority("https://mygateway") == ("mygateway", 443)
|
||||
# A non-numeric tail after a colon is treated as no explicit port.
|
||||
assert _split_authority("mygateway:") == ("mygateway", 443)
|
||||
|
||||
|
||||
def test_split_authority_strips_ipv6_brackets() -> None:
|
||||
from zb_mom_ww_mxgateway.options import _split_authority
|
||||
|
||||
|
||||
Reference in New Issue
Block a user