Files
mxaccessgw/clients/python/tests/test_auth_options.py

279 lines
9.3 KiB
Python

"""Tests for auth metadata and connection options."""
from pathlib import Path

import pytest

from zb_mom_ww_mxgateway import options as options_module
from zb_mom_ww_mxgateway.auth import REDACTED, ApiKey, auth_metadata, redact_secret
from zb_mom_ww_mxgateway.options import ClientOptions, create_channel
def test_auth_metadata_adds_bearer_api_key() -> None:
    """auth_metadata() yields exactly one Bearer ``authorization`` header."""
    expected = (("authorization", "Bearer mxgw_test_secret"),)
    assert auth_metadata("mxgw_test_secret") == expected
def test_api_key_repr_is_redacted() -> None:
    """repr() of an ApiKey hides the raw secret and shows the REDACTED marker."""
    rendered = repr(ApiKey("mxgw_test_secret"))
    assert "mxgw_test_secret" not in rendered
    assert REDACTED in rendered
def test_redact_secret_replaces_known_values() -> None:
    """redact_secret() swaps each listed secret in the text for REDACTED."""
    known_secrets = ["mxgw_test_secret"]
    result = redact_secret("authorization failed for mxgw_test_secret", known_secrets)
    expected = f"authorization failed for {REDACTED}"
    assert result == expected
def test_client_options_reject_plaintext_with_ca_file() -> None:
    """plaintext=True together with a ca_file is contradictory and must raise ValueError."""
    conflicting = {
        "endpoint": "localhost:5000",
        "plaintext": True,
        "ca_file": "ca.pem",
    }
    with pytest.raises(ValueError, match="ca_file"):
        ClientOptions(**conflicting)
def test_client_options_repr_redacts_api_key() -> None:
    """repr() of ClientOptions must not leak the configured API key."""
    text = repr(ClientOptions(endpoint="localhost:5000", api_key="mxgw_test_secret"))
    assert "mxgw_test_secret" not in text
    assert REDACTED in text
def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch) -> None:
    """plaintext=True builds the channel via grpc.aio.insecure_channel.

    The channel receives the 16 MiB send/receive message-size options.
    """
    recorded: list[tuple[str, object]] = []

    def fake_insecure_channel(endpoint: str, *, options: object) -> str:
        recorded.append((endpoint, options))
        return "plain-channel"

    monkeypatch.setattr(options_module.grpc.aio, "insecure_channel", fake_insecure_channel)

    channel = create_channel(ClientOptions(endpoint="localhost:5000", plaintext=True))

    sixteen_mib = 16 * 1024 * 1024
    expected_options = [
        ("grpc.max_receive_message_length", sixteen_mib),
        ("grpc.max_send_message_length", sixteen_mib),
    ]
    assert channel == "plain-channel"
    assert recorded == [("localhost:5000", expected_options)]
def test_create_channel_uses_tls_channel_tofu_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """Default TLS (no ca_file, no require_certificate_validation) uses TOFU.

    The server certificate is fetched unverified from the endpoint and pinned
    as ``root_certificates``. ``grpc.ssl_target_name_override = "localhost"``
    is added automatically.
    """
    dummy_pem = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
    fetched_from: list[tuple[str, int]] = []
    pinned_roots: list[object] = []
    opened: list[tuple[str, object, object]] = []

    def fake_get_server_certificate(addr: tuple[str, int]) -> str:
        fetched_from.append(addr)
        return dummy_pem

    def fake_credentials(*, root_certificates: object = None) -> str:
        pinned_roots.append(root_certificates)
        return "creds"

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        opened.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(ClientOptions(endpoint="gateway.example:5001"))

    assert channel == "tls-channel"
    # TOFU: the certificate was fetched from the endpoint's (host, port).
    assert fetched_from == [("gateway.example", 5001)]
    # The fetched PEM, ASCII-encoded to bytes, is what gets pinned as root_certificates.
    assert pinned_roots == [dummy_pem.encode("ascii")]
    # No server_name_override was supplied, so the localhost override is injected.
    sixteen_mib = 16 * 1024 * 1024
    assert opened == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", sixteen_mib),
                ("grpc.max_send_message_length", sixteen_mib),
                ("grpc.ssl_target_name_override", "localhost"),
            ],
        ),
    ]
def test_create_channel_uses_tls_channel_tofu_respects_server_name_override(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When server_name_override is set, TOFU still runs.

    It does NOT add the auto-localhost override, because the explicit
    override is already in the channel options. The certificate prefetch
    still targets the endpoint's (host, port), not the override name.
    """
    dummy_pem = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
    get_cert_calls: list[tuple[str, int]] = []

    # Record the prefetch address so the test actually proves TOFU ran,
    # instead of only inferring it from the pinned PEM.
    def fake_get_server_certificate(addr: tuple[str, int]) -> str:
        get_cert_calls.append(addr)
        return dummy_pem

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            server_name_override="gateway.test",
        ),
    )

    assert channel == "tls-channel"
    # TOFU ran exactly once, against the real endpoint rather than the override name.
    assert get_cert_calls == [("gateway.example", 5001)]
    assert cred_calls == [dummy_pem.encode("ascii")]
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
                # Explicit override from ClientOptions — not the auto-localhost one
                ("grpc.ssl_target_name_override", "gateway.test"),
            ],
        ),
    ]
def test_create_channel_uses_tls_channel_require_cert_validation(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """require_certificate_validation=True relies on system trust.

    No TOFU prefetch occurs, and ssl_channel_credentials() is called with no
    keyword arguments (so no root_certificates).
    """
    prefetches: list[object] = []
    credential_kwargs: list[object] = []
    opened: list[tuple[str, object, object]] = []

    def fake_get_server_certificate(addr: object) -> str:  # pragma: no cover
        prefetches.append(addr)
        return "SHOULD_NOT_BE_CALLED"

    def fake_credentials(**kwargs: object) -> str:
        credential_kwargs.append(kwargs)
        return "creds"

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        opened.append((endpoint, credentials, options))
        return "tls-channel"

    for target, attribute, replacement in (
        (options_module.ssl, "get_server_certificate", fake_get_server_certificate),
        (options_module.grpc, "ssl_channel_credentials", fake_credentials),
        (options_module.grpc.aio, "secure_channel", fake_secure_channel),
    ):
        monkeypatch.setattr(target, attribute, replacement)

    strict_options = ClientOptions(
        endpoint="gateway.example:5001",
        require_certificate_validation=True,
    )
    channel = create_channel(strict_options)

    assert channel == "tls-channel"
    # The TOFU prefetch must never run.
    assert not prefetches
    # System trust: ssl_channel_credentials() received no keyword arguments at all.
    assert credential_kwargs == [{}]
    size_limit = 16 * 1024 * 1024
    assert opened == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", size_limit),
                ("grpc.max_send_message_length", size_limit),
            ],
        ),
    ]
def test_create_channel_uses_tls_channel_ca_file(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """ca_file path: reads the PEM file, passes bytes as root_certificates, skips TOFU.

    ``tmp_path`` is pytest's per-test temporary directory (a ``pathlib.Path``).
    The CA bundle is written there so the test never touches real files.
    """
    ca_pem = b"-----BEGIN CERTIFICATE-----\nY2FkYXRh\n-----END CERTIFICATE-----\n"
    ca_file = tmp_path / "ca.pem"
    ca_file.write_bytes(ca_pem)
    get_cert_called = False

    # Record the call via a flag instead of raising, so the explicit
    # assertion below is what reports an unexpected TOFU prefetch.
    def fake_get_server_certificate(addr: object) -> str:  # pragma: no cover
        nonlocal get_cert_called
        get_cert_called = True
        return "SHOULD_NOT_BE_CALLED"

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            ca_file=str(ca_file),
        ),
    )

    assert channel == "tls-channel"
    # A configured CA bundle must bypass the TOFU prefetch entirely.
    assert not get_cert_called
    # The file's raw bytes are passed through unchanged as root_certificates.
    assert cred_calls == [ca_pem]
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
            ],
        ),
    ]