"""Tests for RoutedLLMClient (Phase 4.5+).

Splits traffic across two underlying clients based on the ``model`` kwarg.
We use simple stub clients to assert the router picks the correct backend
for each call.
"""

from __future__ import annotations

from typing import AsyncIterator, Sequence

import pytest

from chat.llm.client import Message
from chat.llm.router import RoutedLLMClient

# Shared router configuration for the generate/stream routing tests.
NARRATIVE_MODEL = "provider/big-model"
LOCAL_MODEL = "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
LOCAL_PREFIXES = ("mlx-community/",)


class _StubClient:
    """Stand-in LLM client that records the ``model`` used by every call.

    Each method appends the ``model`` it received to a per-method list so
    tests can assert which backend the router dispatched to. ``generate`` and
    ``stream`` echo ``"<name>:<model>"`` so the output also identifies the
    backend; ``embed`` returns a fixed two-element vector.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.generate_calls: list[str] = []
        self.stream_calls: list[str] = []
        self.embed_calls: list[str] = []

    async def generate(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> str:
        self.generate_calls.append(model)
        return f"{self.name}:{model}"

    async def stream(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> AsyncIterator[str]:
        # Async generator: records the call when iteration starts, then
        # yields a single chunk identifying this backend.
        self.stream_calls.append(model)
        yield f"{self.name}:{model}"

    async def embed(self, text: str, *, model: str) -> list[float]:
        self.embed_calls.append(model)
        return [1.0, 2.0]


def _make_router() -> tuple[RoutedLLMClient, _StubClient, _StubClient]:
    """Build a router with the shared test config; return it and both stubs.

    Returns ``(router, narrative, local)`` so tests can inspect each stub's
    recorded calls after exercising the router.
    """
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative,
        local=local,
        narrative_model=NARRATIVE_MODEL,
        local_prefixes=LOCAL_PREFIXES,
    )
    return router, narrative, local


@pytest.mark.asyncio
async def test_router_generate_routes_remote_model_to_remote_backend():
    """Any model id NOT starting with a local prefix goes to the remote
    backend — narrative model, remote classifiers, anything else."""
    router, narrative, local = _make_router()

    out = await router.generate(
        [Message(role="user", content="hi")], model=NARRATIVE_MODEL
    )

    assert out == f"narrative:{NARRATIVE_MODEL}"
    assert narrative.generate_calls == [NARRATIVE_MODEL]
    assert local.generate_calls == []


@pytest.mark.asyncio
async def test_router_generate_routes_local_prefix_to_local_backend():
    """Models prefixed with a local prefix (e.g. ``mlx-community/``) go to
    the local MLX backend regardless of whether the rest of the path looks
    like a remote provider id."""
    router, narrative, local = _make_router()

    out = await router.generate(
        [Message(role="user", content="hi")], model=LOCAL_MODEL
    )

    assert out == f"local:{LOCAL_MODEL}"
    assert local.generate_calls == [LOCAL_MODEL]
    assert narrative.generate_calls == []


@pytest.mark.asyncio
async def test_router_stream_dispatches_by_prefix():
    """Streaming follows the same prefix rule as ``generate``, and each
    backend is invoked exactly once for its own model only."""
    router, narrative, local = _make_router()

    chunks_remote = [
        chunk
        async for chunk in router.stream(
            [Message(role="user", content="hi")], model=NARRATIVE_MODEL
        )
    ]
    chunks_local = [
        chunk
        async for chunk in router.stream(
            [Message(role="user", content="hi")], model=LOCAL_MODEL
        )
    ]

    assert chunks_remote == [f"narrative:{NARRATIVE_MODEL}"]
    assert chunks_local == [f"local:{LOCAL_MODEL}"]
    # Chunk equality alone would not catch a router that also invoked the
    # other backend and discarded its output; check the recorded calls too.
    assert narrative.stream_calls == [NARRATIVE_MODEL]
    assert local.stream_calls == [LOCAL_MODEL]


@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
    """Embeddings always run locally — the remote provider doesn't expose a
    working ``/v1/embeddings``, so the router never sends embed calls there
    even if the model name happens to look 'remote'."""
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    # Deliberately omit ``local_prefixes`` so the router's default is used.
    router = RoutedLLMClient(
        narrative=narrative, local=local, narrative_model="big-model"
    )

    out = await router.embed("hello", model="any-embedding-model")
    # The narrative model id itself is the most "remote"-looking id the
    # router knows about; it must still be embedded locally.
    out_remote_looking = await router.embed("hello", model="big-model")

    assert out == [1.0, 2.0]
    assert out_remote_looking == [1.0, 2.0]
    assert local.embed_calls == ["any-embedding-model", "big-model"]
    assert narrative.embed_calls == []