"""Tests for auth metadata and connection options."""

from pathlib import Path

import pytest

from zb_mom_ww_mxgateway.auth import REDACTED, ApiKey, auth_metadata, redact_secret
from zb_mom_ww_mxgateway import options as options_module
from zb_mom_ww_mxgateway.options import ClientOptions, create_channel

# Placeholder PEM text handed back by the stubbed ``ssl.get_server_certificate``
# in the TOFU tests; the tests only check that it is forwarded as ASCII bytes.
_DUMMY_PEM = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"


def test_auth_metadata_adds_bearer_api_key() -> None:
    """auth_metadata returns a single ``authorization: Bearer <key>`` pair."""
    assert auth_metadata("mxgw_test_secret") == (
        ("authorization", "Bearer mxgw_test_secret"),
    )


def test_api_key_repr_is_redacted() -> None:
    """repr(ApiKey) shows the REDACTED marker and never the raw secret."""
    api_key = ApiKey("mxgw_test_secret")

    assert "mxgw_test_secret" not in repr(api_key)
    assert REDACTED in repr(api_key)


def test_redact_secret_replaces_known_values() -> None:
    """redact_secret swaps each listed secret in the message for REDACTED."""
    redacted = redact_secret(
        "authorization failed for mxgw_test_secret",
        ["mxgw_test_secret"],
    )

    assert redacted == f"authorization failed for {REDACTED}"


def test_client_options_reject_plaintext_with_ca_file() -> None:
    """Combining plaintext=True with ca_file raises ValueError mentioning ca_file."""
    with pytest.raises(ValueError, match="ca_file"):
        ClientOptions(
            endpoint="localhost:5000",
            plaintext=True,
            ca_file="ca.pem",
        )


def test_client_options_repr_redacts_api_key() -> None:
    """repr(ClientOptions) shows the REDACTED marker and never the raw API key."""
    options = ClientOptions(endpoint="localhost:5000", api_key="mxgw_test_secret")

    assert "mxgw_test_secret" not in repr(options)
    assert REDACTED in repr(options)


def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch) -> None:
    """plaintext=True goes through grpc.aio.insecure_channel with the message-size options."""
    calls: list[tuple[str, object]] = []

    def fake_insecure_channel(endpoint: str, *, options: object) -> str:
        calls.append((endpoint, options))
        return "plain-channel"

    monkeypatch.setattr(
        options_module.grpc.aio,
        "insecure_channel",
        fake_insecure_channel,
    )

    channel = create_channel(ClientOptions(endpoint="localhost:5000", plaintext=True))

    assert channel == "plain-channel"
    assert calls == [
        (
            "localhost:5000",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
            ],
        ),
    ]


def test_create_channel_uses_tls_channel_tofu_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """Default TLS (no ca_file, no require_certificate_validation) uses TOFU: fetches
    the server cert unverified, pins it as root_certificates, and adds
    grpc.ssl_target_name_override = "localhost" automatically.
    """
    get_cert_calls: list[tuple[str, int]] = []

    def fake_get_server_certificate(addr: tuple[str, int]) -> str:
        get_cert_calls.append(addr)
        return _DUMMY_PEM

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(endpoint="gateway.example:5001"),
    )

    assert channel == "tls-channel"
    # TOFU: should have fetched the cert from the server (host, port)
    assert get_cert_calls == [("gateway.example", 5001)]
    # Pinned the fetched PEM bytes as root_certificates
    assert cred_calls == [_DUMMY_PEM.encode("ascii")]
    # Auto-injected localhost override (no server_name_override supplied)
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
                ("grpc.ssl_target_name_override", "localhost"),
            ],
        ),
    ]


def test_create_channel_uses_tls_channel_tofu_respects_server_name_override(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When server_name_override is set, TOFU still runs but does NOT add the
    auto-localhost override (the explicit override is already in channel_options).
    """
    monkeypatch.setattr(
        options_module.ssl,
        "get_server_certificate",
        lambda addr: _DUMMY_PEM,
    )

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            server_name_override="gateway.test",
        ),
    )

    assert channel == "tls-channel"
    assert cred_calls == [_DUMMY_PEM.encode("ascii")]
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
                # Explicit override from ClientOptions — not the auto-localhost one
                ("grpc.ssl_target_name_override", "gateway.test"),
            ],
        ),
    ]


def test_create_channel_uses_tls_channel_require_cert_validation(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """require_certificate_validation=True uses system trust (no TOFU, no root_certificates)."""
    get_cert_called = False

    def fake_get_server_certificate(addr: object) -> str:  # pragma: no cover
        nonlocal get_cert_called
        get_cert_called = True
        return "SHOULD_NOT_BE_CALLED"

    cred_calls: list[object] = []

    # Accepts arbitrary keywords so the test can assert that none were passed.
    def fake_credentials(**kwargs: object) -> str:
        cred_calls.append(kwargs)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            require_certificate_validation=True,
        ),
    )

    assert channel == "tls-channel"
    # Must NOT call TOFU prefetch
    assert not get_cert_called
    # ssl_channel_credentials() called with NO keyword args (system trust)
    assert cred_calls == [{}]
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
            ],
        ),
    ]


def test_create_channel_uses_tls_channel_ca_file(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """ca_file path: reads the PEM file, passes bytes as root_certificates, skips TOFU."""
    # ``tmp_path`` is pytest's per-test temporary directory (a pathlib.Path),
    # so the CA bundle written here is isolated to this test.
    ca_pem = b"-----BEGIN CERTIFICATE-----\nY2FkYXRh\n-----END CERTIFICATE-----\n"
    ca_file = tmp_path / "ca.pem"
    ca_file.write_bytes(ca_pem)

    get_cert_called = False

    def fake_get_server_certificate(addr: object) -> str:  # pragma: no cover
        nonlocal get_cert_called
        get_cert_called = True
        return "SHOULD_NOT_BE_CALLED"

    cred_calls: list[object] = []

    def fake_credentials(*, root_certificates: object = None) -> str:
        cred_calls.append(root_certificates)
        return "creds"

    channel_calls: list[tuple[str, object, object]] = []

    def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
        channel_calls.append((endpoint, credentials, options))
        return "tls-channel"

    monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
    monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
    monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)

    channel = create_channel(
        ClientOptions(
            endpoint="gateway.example:5001",
            ca_file=str(ca_file),
        ),
    )

    assert channel == "tls-channel"
    assert not get_cert_called
    assert cred_calls == [ca_pem]
    assert channel_calls == [
        (
            "gateway.example:5001",
            "creds",
            [
                ("grpc.max_receive_message_length", 16 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
            ],
        ),
    ]