279 lines
9.3 KiB
Python
279 lines
9.3 KiB
Python
"""Tests for auth metadata and connection options."""
|
|
|
|
from pathlib import Path

import pytest

from zb_mom_ww_mxgateway import options as options_module
from zb_mom_ww_mxgateway.auth import REDACTED, ApiKey, auth_metadata, redact_secret
from zb_mom_ww_mxgateway.options import ClientOptions, create_channel
|
|
|
|
|
|
def test_auth_metadata_adds_bearer_api_key() -> None:
    """auth_metadata yields exactly one Bearer ``authorization`` header."""
    secret = "mxgw_test_secret"
    expected = (("authorization", f"Bearer {secret}"),)
    assert auth_metadata(secret) == expected
|
|
|
|
|
|
def test_api_key_repr_is_redacted() -> None:
    """repr(ApiKey) must hide the raw secret and show the REDACTED marker."""
    rendered = repr(ApiKey("mxgw_test_secret"))
    assert "mxgw_test_secret" not in rendered
    assert REDACTED in rendered
|
|
|
|
|
|
def test_redact_secret_replaces_known_values() -> None:
    """redact_secret swaps each listed secret in the text for REDACTED."""
    secret = "mxgw_test_secret"
    message = f"authorization failed for {secret}"
    assert redact_secret(message, [secret]) == f"authorization failed for {REDACTED}"
|
|
|
|
|
|
def test_client_options_reject_plaintext_with_ca_file() -> None:
    """Combining plaintext=True with a ca_file raises a ValueError naming ca_file."""
    conflicting = {"endpoint": "localhost:5000", "plaintext": True, "ca_file": "ca.pem"}
    with pytest.raises(ValueError, match="ca_file"):
        ClientOptions(**conflicting)
|
|
|
|
|
|
def test_client_options_repr_redacts_api_key() -> None:
    """repr(ClientOptions) must not leak the configured API key."""
    rendered = repr(ClientOptions(endpoint="localhost:5000", api_key="mxgw_test_secret"))
    assert "mxgw_test_secret" not in rendered
    assert REDACTED in rendered
|
|
|
|
|
|
def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch) -> None:
    """plaintext=True goes through grpc.aio.insecure_channel with 16 MiB message limits."""
    message_limit = 16 * 1024 * 1024
    recorded: list[tuple[str, object]] = []

    def fake_insecure_channel(endpoint: str, *, options: object) -> str:
        recorded.append((endpoint, options))
        return "plain-channel"

    monkeypatch.setattr(options_module.grpc.aio, "insecure_channel", fake_insecure_channel)

    result = create_channel(ClientOptions(endpoint="localhost:5000", plaintext=True))

    assert result == "plain-channel"
    expected_options = [
        ("grpc.max_receive_message_length", message_limit),
        ("grpc.max_send_message_length", message_limit),
    ]
    assert recorded == [("localhost:5000", expected_options)]
|
|
|
|
|
|
def test_create_channel_uses_tls_channel_tofu_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """Default TLS (no ca_file, no require_certificate_validation) uses TOFU.

    The server certificate is fetched unverified, pinned as root_certificates,
    and grpc.ssl_target_name_override = "localhost" is added automatically.
    """
    dummy_pem = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
    message_limit = 16 * 1024 * 1024
    fetched_addrs: list[tuple[str, int]] = []
    pinned_roots: list[object] = []
    opened_channels: list[tuple[str, object, object]] = []

    def fake_get_server_certificate(addr: tuple[str, int]) -> str:
        fetched_addrs.append(addr)
        return dummy_pem

    def fake_credentials(*, root_certificates: object = None) -> str:
        pinned_roots.append(root_certificates)
        return "creds"

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        opened_channels.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    result = create_channel(ClientOptions(endpoint="gateway.example:5001"))

    assert result == "tls-channel"
    # TOFU: the certificate is fetched from the server's (host, port).
    assert fetched_addrs == [("gateway.example", 5001)]
    # The fetched PEM, as ASCII bytes, is passed as root_certificates.
    assert pinned_roots == [dummy_pem.encode("ascii")]
    # With no server_name_override supplied, "localhost" is injected.
    expected_options = [
        ("grpc.max_receive_message_length", message_limit),
        ("grpc.max_send_message_length", message_limit),
        ("grpc.ssl_target_name_override", "localhost"),
    ]
    assert opened_channels == [("gateway.example:5001", "creds", expected_options)]
|
|
|
|
|
|
def test_create_channel_uses_tls_channel_tofu_respects_server_name_override(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An explicit server_name_override replaces the auto-localhost override.

    TOFU still runs: the server certificate is fetched from the endpoint and
    pinned. The only grpc.ssl_target_name_override in the channel options is
    the explicit one from ClientOptions.
    """
    dummy_pem = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
    get_cert_calls: list[tuple[str, int]] = []

    # Record the fetch address so the "TOFU still runs" claim is actually
    # checked, not just implied by the pinned PEM.
    def fake_get_server_certificate(addr: tuple[str, int]) -> str:
        get_cert_calls.append(addr)
        return dummy_pem

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            server_name_override="gateway.test",
        ),
    )

    assert channel == "tls-channel"
    # TOFU fetch still targets the real endpoint, not the override name.
    assert get_cert_calls == [("gateway.example", 5001)]
    assert cred_calls == [dummy_pem.encode("ascii")]
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
                # Explicit override from ClientOptions — not the auto-localhost one
                ("grpc.ssl_target_name_override", "gateway.test"),
            ],
        ),
    ]
|
|
|
|
|
|
def test_create_channel_uses_tls_channel_require_cert_validation(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """require_certificate_validation=True uses system trust.

    No TOFU prefetch happens, ssl_channel_credentials() receives no keyword
    arguments, and no name override is added to the channel options.
    """
    message_limit = 16 * 1024 * 1024
    unexpected_fetches: list[object] = []
    credential_kwargs: list[object] = []
    opened_channels: list[tuple[str, object, object]] = []

    def fake_get_server_certificate(addr: object) -> str:  # pragma: no cover
        unexpected_fetches.append(addr)
        return "SHOULD_NOT_BE_CALLED"

    def fake_credentials(**kwargs: object) -> str:
        credential_kwargs.append(kwargs)
        return "creds"

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        opened_channels.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    validated = ClientOptions(endpoint="gateway.example:5001", require_certificate_validation=True)
    result = create_channel(validated)

    assert result == "tls-channel"
    # The TOFU prefetch must be skipped entirely.
    assert not unexpected_fetches
    # Called with no keyword args at all (system trust).
    assert credential_kwargs == [{}]
    expected_options = [
        ("grpc.max_receive_message_length", message_limit),
        ("grpc.max_send_message_length", message_limit),
    ]
    assert opened_channels == [("gateway.example:5001", "creds", expected_options)]
|
|
|
|
|
|
def test_create_channel_uses_tls_channel_ca_file(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """A ca_file pins the file's bytes as root_certificates and skips TOFU.

    The ``tmp_path`` fixture is a ``pathlib.Path`` (not a ``TempPathFactory``,
    which is the type of ``tmp_path_factory``).
    """
    ca_pem = b"-----BEGIN CERTIFICATE-----\nY2FkYXRh\n-----END CERTIFICATE-----\n"
    ca_file = tmp_path / "ca.pem"
    ca_file.write_bytes(ca_pem)

    get_cert_called = False

    def fake_get_server_certificate(addr: object) -> str:  # pragma: no cover
        nonlocal get_cert_called
        get_cert_called = True
        return "SHOULD_NOT_BE_CALLED"

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            ca_file=str(ca_file),
        ),
    )

    assert channel == "tls-channel"
    # An explicit CA bundle means no TOFU prefetch.
    assert not get_cert_called
    # The file's exact bytes are handed to gRPC as the trusted roots.
    assert cred_calls == [ca_pem]
    # No name override is injected on the ca_file path.
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
            ],
        ),
    ]
|