test(client-python): update TLS default-channel test for TOFU behavior
This commit is contained in:
@@ -72,27 +72,83 @@ def test_create_channel_uses_plaintext_channel(monkeypatch: pytest.MonkeyPatch)
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_create_channel_uses_tls_channel(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_create_channel_uses_tls_channel_tofu_default(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
calls: list[tuple[str, object, object]] = []
|
"""Default TLS (no ca_file, no require_certificate_validation) uses TOFU:
|
||||||
|
fetches the server cert unverified, pins it as root_certificates, and adds
|
||||||
|
grpc.ssl_target_name_override = "localhost" automatically.
|
||||||
|
"""
|
||||||
|
_DUMMY_PEM = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
|
||||||
|
get_cert_calls: list[tuple[str, int]] = []
|
||||||
|
|
||||||
def fake_credentials(*, root_certificates: object) -> str:
|
def fake_get_server_certificate(addr: tuple[str, int]) -> str:
|
||||||
assert root_certificates is None
|
get_cert_calls.append(addr)
|
||||||
|
return _DUMMY_PEM
|
||||||
|
|
||||||
|
cred_calls: list[object] = []
|
||||||
|
|
||||||
|
def fake_credentials(*, root_certificates: object = None) -> str:
|
||||||
|
cred_calls.append(root_certificates)
|
||||||
return "creds"
|
return "creds"
|
||||||
|
|
||||||
|
channel_calls: list[tuple[str, object, object]] = []
|
||||||
|
|
||||||
def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
|
def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
|
||||||
calls.append((endpoint, credentials, options))
|
channel_calls.append((endpoint, credentials, options))
|
||||||
return "tls-channel"
|
return "tls-channel"
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
|
||||||
options_module.grpc,
|
monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
|
||||||
"ssl_channel_credentials",
|
monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)
|
||||||
fake_credentials,
|
|
||||||
|
channel = create_channel(
|
||||||
|
ClientOptions(endpoint="gateway.example:5001"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert channel == "tls-channel"
|
||||||
|
# TOFU: should have fetched the cert from the server (host, port)
|
||||||
|
assert get_cert_calls == [("gateway.example", 5001)]
|
||||||
|
# Pinned the fetched PEM bytes as root_certificates
|
||||||
|
assert cred_calls == [_DUMMY_PEM.encode("ascii")]
|
||||||
|
# Auto-injected localhost override (no server_name_override supplied)
|
||||||
|
assert channel_calls == [
|
||||||
|
(
|
||||||
|
"gateway.example:5001",
|
||||||
|
"creds",
|
||||||
|
[
|
||||||
|
("grpc.max_receive_message_length", 16 * 1024 * 1024),
|
||||||
|
("grpc.max_send_message_length", 16 * 1024 * 1024),
|
||||||
|
("grpc.ssl_target_name_override", "localhost"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_channel_uses_tls_channel_tofu_respects_server_name_override(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""When server_name_override is set, TOFU still runs but does NOT add the
|
||||||
|
auto-localhost override (the explicit override is already in channel_options).
|
||||||
|
"""
|
||||||
|
_DUMMY_PEM = "-----BEGIN CERTIFICATE-----\nZmFrZQ==\n-----END CERTIFICATE-----\n"
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
options_module.grpc.aio,
|
options_module.ssl,
|
||||||
"secure_channel",
|
"get_server_certificate",
|
||||||
fake_secure_channel,
|
lambda addr: _DUMMY_PEM,
|
||||||
)
|
)
|
||||||
|
cred_calls: list[object] = []
|
||||||
|
|
||||||
|
def fake_credentials(*, root_certificates: object = None) -> str:
|
||||||
|
cred_calls.append(root_certificates)
|
||||||
|
return "creds"
|
||||||
|
|
||||||
|
channel_calls: list[tuple[str, object, object]] = []
|
||||||
|
|
||||||
|
def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
|
||||||
|
channel_calls.append((endpoint, credentials, options))
|
||||||
|
return "tls-channel"
|
||||||
|
|
||||||
|
monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
|
||||||
|
monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)
|
||||||
|
|
||||||
channel = create_channel(
|
channel = create_channel(
|
||||||
ClientOptions(
|
ClientOptions(
|
||||||
@@ -102,14 +158,121 @@ def test_create_channel_uses_tls_channel(monkeypatch: pytest.MonkeyPatch) -> Non
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert channel == "tls-channel"
|
assert channel == "tls-channel"
|
||||||
assert calls == [
|
assert cred_calls == [_DUMMY_PEM.encode("ascii")]
|
||||||
(
|
assert channel_calls == [
|
||||||
"gateway.example:5001",
|
(
|
||||||
"creds",
|
"gateway.example:5001",
|
||||||
[
|
"creds",
|
||||||
("grpc.max_receive_message_length", 16 * 1024 * 1024),
|
[
|
||||||
("grpc.max_send_message_length", 16 * 1024 * 1024),
|
("grpc.max_receive_message_length", 16 * 1024 * 1024),
|
||||||
("grpc.ssl_target_name_override", "gateway.test"),
|
("grpc.max_send_message_length", 16 * 1024 * 1024),
|
||||||
],
|
# Explicit override from ClientOptions — not the auto-localhost one
|
||||||
),
|
("grpc.ssl_target_name_override", "gateway.test"),
|
||||||
]
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_channel_uses_tls_channel_require_cert_validation(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""require_certificate_validation=True uses system trust (no TOFU, no root_certificates)."""
|
||||||
|
get_cert_called = False
|
||||||
|
|
||||||
|
def fake_get_server_certificate(addr: object) -> str: # pragma: no cover
|
||||||
|
nonlocal get_cert_called
|
||||||
|
get_cert_called = True
|
||||||
|
return "SHOULD_NOT_BE_CALLED"
|
||||||
|
|
||||||
|
cred_calls: list[object] = []
|
||||||
|
|
||||||
|
def fake_credentials(**kwargs: object) -> str:
|
||||||
|
cred_calls.append(kwargs)
|
||||||
|
return "creds"
|
||||||
|
|
||||||
|
channel_calls: list[tuple[str, object, object]] = []
|
||||||
|
|
||||||
|
def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
|
||||||
|
channel_calls.append((endpoint, credentials, options))
|
||||||
|
return "tls-channel"
|
||||||
|
|
||||||
|
monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
|
||||||
|
monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
|
||||||
|
monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)
|
||||||
|
|
||||||
|
channel = create_channel(
|
||||||
|
ClientOptions(
|
||||||
|
endpoint="gateway.example:5001",
|
||||||
|
require_certificate_validation=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel == "tls-channel"
|
||||||
|
# Must NOT call TOFU prefetch
|
||||||
|
assert not get_cert_called
|
||||||
|
# ssl_channel_credentials() called with NO keyword args (system trust)
|
||||||
|
assert cred_calls == [{}]
|
||||||
|
assert channel_calls == [
|
||||||
|
(
|
||||||
|
"gateway.example:5001",
|
||||||
|
"creds",
|
||||||
|
[
|
||||||
|
("grpc.max_receive_message_length", 16 * 1024 * 1024),
|
||||||
|
("grpc.max_send_message_length", 16 * 1024 * 1024),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_channel_uses_tls_channel_ca_file(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: pytest.TempPathFactory,
|
||||||
|
) -> None:
|
||||||
|
"""ca_file path: reads the PEM file, passes bytes as root_certificates, skips TOFU."""
|
||||||
|
ca_pem = b"-----BEGIN CERTIFICATE-----\nY2FkYXRh\n-----END CERTIFICATE-----\n"
|
||||||
|
ca_file = tmp_path / "ca.pem"
|
||||||
|
ca_file.write_bytes(ca_pem)
|
||||||
|
|
||||||
|
get_cert_called = False
|
||||||
|
|
||||||
|
def fake_get_server_certificate(addr: object) -> str: # pragma: no cover
|
||||||
|
nonlocal get_cert_called
|
||||||
|
get_cert_called = True
|
||||||
|
return "SHOULD_NOT_BE_CALLED"
|
||||||
|
|
||||||
|
cred_calls: list[object] = []
|
||||||
|
|
||||||
|
def fake_credentials(*, root_certificates: object = None) -> str:
|
||||||
|
cred_calls.append(root_certificates)
|
||||||
|
return "creds"
|
||||||
|
|
||||||
|
channel_calls: list[tuple[str, object, object]] = []
|
||||||
|
|
||||||
|
def fake_secure_channel(endpoint: str, credentials: object, *, options: object) -> str:
|
||||||
|
channel_calls.append((endpoint, credentials, options))
|
||||||
|
return "tls-channel"
|
||||||
|
|
||||||
|
monkeypatch.setattr(options_module.ssl, "get_server_certificate", fake_get_server_certificate)
|
||||||
|
monkeypatch.setattr(options_module.grpc, "ssl_channel_credentials", fake_credentials)
|
||||||
|
monkeypatch.setattr(options_module.grpc.aio, "secure_channel", fake_secure_channel)
|
||||||
|
|
||||||
|
channel = create_channel(
|
||||||
|
ClientOptions(
|
||||||
|
endpoint="gateway.example:5001",
|
||||||
|
ca_file=str(ca_file),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel == "tls-channel"
|
||||||
|
assert not get_cert_called
|
||||||
|
assert cred_calls == [ca_pem]
|
||||||
|
assert channel_calls == [
|
||||||
|
(
|
||||||
|
"gateway.example:5001",
|
||||||
|
"creds",
|
||||||
|
[
|
||||||
|
("grpc.max_receive_message_length", 16 * 1024 * 1024),
|
||||||
|
("grpc.max_send_message_length", 16 * 1024 * 1024),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user