diff --git a/clients/python/tests/test_auth_options.py b/clients/python/tests/test_auth_options.py index d8242ce..fa6c36c 100644 --- a/clients/python/tests/test_auth_options.py +++ b/clients/python/tests/test_auth_options.py @@ -72,27 +72,83 @@ def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch) ] -def test_create_channel_uses_tls_channel(monkeypatch: pytest.MonkeyPatch) -> None: - calls: list[tuple[str, object, object]] = [] +def test_create_channel_uses_tls_channel_tofu_default(monkeypatch: pytest.MonkeyPatch) -> None: + """Default TLS (no ca_file, no require_certificate_validation) uses TOFU: + fetches the server cert unverified, pins it as root_certificates, and adds + grpc.ssl_target_name_override = "localhost" automatically. + """ + _DUMMY_PEM = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n" + get_cert_calls: list[tuple[str, int]] = [] - def fake_credentials(*, root_certificates: object) -> str: - assert root_certificates is None + def fake_get_server_certificate(addr: tuple[str, int]) -> str: + get_cert_calls.append(addr) + return _DUMMY_PEM + + cred_calls: list[object] = [] + + def fake_credentials(*, root_certificates: object = None) -> str: + cred_calls.append(root_certificates) return "creds" + channel_calls: list[tuple[str, object, object]] = [] + def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str: - calls.append((endpoint, credentials, options)) + channel_calls.append((endpoint, credentials, options)) return "tls-channel" - monkeypatch.setattr( - options_module.grpc, - "ssl_channel_credentials", - fake_credentials, + monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate) + monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials) + monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel) + + channel = create_channel( + ClientOptions(endpoint="gateway.example:5001"), ) + + assert channel == "tls-channel" + # TOFU: should have fetched the cert from the server (host, port) + assert get_cert_calls == [("gateway.example", 5001)] + # Pinned the fetched PEM bytes as root_certificates + assert cred_calls == [_DUMMY_PEM.encode("ascii")] + # Auto-injected localhost override (no server_name_override supplied) + assert channel_calls == [ + ( + "gateway.example:5001", + "creds", + [ + ("grpc.max_receive_message_length", 16 * 1024 * 1024), + ("grpc.max_send_message_length", 16 * 1024 * 1024), + ("grpc.ssl_target_name_override", "localhost"), + ], + ), + ] + + +def test_create_channel_uses_tls_channel_tofu_respects_server_name_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When server_name_override is set, TOFU still runs but does NOT add the + auto-localhost override (the explicit override is already in channel_options). + """ + _DUMMY_PEM = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n" monkeypatch.setattr( - options_module.grpc.aio, - "secure_channel", - fake_secure_channel, + options_module.ssl, + "get_server_certificate", + lambda addr: _DUMMY_PEM, ) + cred_calls: list[object] = [] + + def fake_credentials(*, root_certificates: object = None) -> str: + cred_calls.append(root_certificates) + return "creds" + + channel_calls: list[tuple[str, object, object]] = [] + + def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str: + channel_calls.append((endpoint, credentials, options)) + return "tls-channel" + + monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials) + monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel) channel = create_channel( ClientOptions( @@ -102,14 +158,121 @@ def test_create_channel_uses_tls_channel(monkeypatch: pytest.MonkeyPatch) -> Non ) assert channel == "tls-channel" - assert calls == [ - ( - "gateway.example:5001", - "creds", - [ - ("grpc.max_receive_message_length", 16 * 1024 * 1024), - ("grpc.max_send_message_length", 16 * 1024 * 1024), - ("grpc.ssl_target_name_override", "gateway.test"), - ], - ), - ] + assert cred_calls == [_DUMMY_PEM.encode("ascii")] + assert channel_calls == [ + ( + "gateway.example:5001", + "creds", + [ + ("grpc.max_receive_message_length", 16 * 1024 * 1024), + ("grpc.max_send_message_length", 16 * 1024 * 1024), + # Explicit override from ClientOptions — not the auto-localhost one + ("grpc.ssl_target_name_override", "gateway.test"), + ], + ), + ] + + +def test_create_channel_uses_tls_channel_require_cert_validation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """require_certificate_validation=True uses system trust (no TOFU, no root_certificates).""" + get_cert_called = False + + def fake_get_server_certificate(addr: object) -> str: # pragma: no cover + nonlocal get_cert_called + get_cert_called = True + return "SHOULD_NOT_BE_CALLED" + + cred_calls: list[object] = [] + + def fake_credentials(**kwargs: object) -> str: + cred_calls.append(kwargs) + return "creds" + + channel_calls: list[tuple[str, object, object]] = [] + + def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str: + channel_calls.append((endpoint, credentials, options)) + return "tls-channel" + + monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate) + monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials) + monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel) + + channel = create_channel( + ClientOptions( + endpoint="gateway.example:5001", + require_certificate_validation=True, + ), + ) + + assert channel == "tls-channel" + # Must NOT call TOFU prefetch + assert not get_cert_called + # ssl_channel_credentials() called with NO keyword args (system trust) + assert cred_calls == [{}] + assert channel_calls == [ + ( + "gateway.example:5001", + "creds", + [ + ("grpc.max_receive_message_length", 16 * 1024 * 1024), + ("grpc.max_send_message_length", 16 * 1024 * 1024), + ], + ), + ] + + +def test_create_channel_uses_tls_channel_ca_file( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pytest.TempPathFactory, +) -> None: + """ca_file path: reads the PEM file, passes bytes as root_certificates, skips TOFU.""" + ca_pem = b"-----BEGIN CERTIFICATE-----\nY2FkYXRh\n-----END CERTIFICATE-----\n" + ca_file = tmp_path / "ca.pem" + ca_file.write_bytes(ca_pem) + + get_cert_called = False + + def fake_get_server_certificate(addr: object) -> str: # pragma: no cover + nonlocal get_cert_called + get_cert_called = True + return "SHOULD_NOT_BE_CALLED" + + cred_calls: list[object] = [] + + def fake_credentials(*, root_certificates: object = None) -> str: + cred_calls.append(root_certificates) + return "creds" + + channel_calls: list[tuple[str, object, object]] = [] + + def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str: + channel_calls.append((endpoint, credentials, options)) + return "tls-channel" + + monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate) + monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials) + monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel) + + channel = create_channel( + ClientOptions( + endpoint="gateway.example:5001", + ca_file=str(ca_file), + ), + ) + + assert channel == "tls-channel" + assert not get_cert_called + assert cred_calls == [ca_pem] + assert channel_calls == [ + ( + "gateway.example:5001", + "creds", + [ + ("grpc.max_receive_message_length", 16 * 1024 * 1024), + ("grpc.max_send_message_length", 16 * 1024 * 1024), + ], + ), + ]