Files
chat/tests/test_router_client.py
T
Joseph Doherty fe9c497038 feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless
Adds RoutedLLMClient that dispatches by model name: requests matching
Settings.narrative_model go to Featherless, everything else (classifier
calls, embed) goes to a local MLX server. The local server is
mlx-omni-server (separate venv at .mlx-venv) and exposes the standard
OpenAI surface at http://127.0.0.1:10240/v1.

LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood)
but with a working embed() — Featherless's /v1/embeddings always
returns 500 with completions_error, so the router unconditionally
sends embed traffic to the local backend.

Production deployment overrides via data/config.toml:
- classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB)
- embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB,
  384 dim — matches existing schema, no migration)

Defaults stay remote / pseudo so fresh installs and tests need no
external infra. Smoke-tested live: classifier returns expected output,
BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer
to cat-on-rug than to quantum-mechanics).

scripts/start_mlx_server.sh starts the daemon (foreground or --daemon).
.mlx-venv/ added to .gitignore.

Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
2026-04-27 12:05:41 -04:00

109 lines
3.3 KiB
Python

"""Tests for RoutedLLMClient (Phase 4.5+).
Splits traffic across two underlying clients based on the ``model``
kwarg. We use simple stub clients to assert the router picks the
correct backend for each call.
"""
from __future__ import annotations
from typing import AsyncIterator, Sequence
import pytest
from chat.llm.client import Message
from chat.llm.router import RoutedLLMClient
class _StubClient:
    """In-memory stand-in for an LLM client that logs which model it was asked for.

    Every method appends the ``model`` keyword it received to a per-method
    call log and answers with a value built from the stub's ``name``, so a
    test can tell which backend handled a call.
    """

    def __init__(self, name: str):
        self.name = name
        # One log per method; each entry is the ``model`` kwarg of one call.
        self.embed_calls: list[str] = []
        self.stream_calls: list[str] = []
        self.generate_calls: list[str] = []

    def _tag(self, model) -> str:
        """Return the ``<stub name>:<model>`` marker used as fake output."""
        return f"{self.name}:{model}"

    async def generate(self, messages, *, model, **params) -> str:
        """Log ``model`` and return the ``<name>:<model>`` marker.

        ``messages`` and any extra ``params`` are accepted but ignored.
        """
        self.generate_calls += [model]
        return self._tag(model)

    async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
        """Log ``model`` and yield a single ``<name>:<model>`` chunk.

        This is an async generator, so the call is only logged once
        iteration starts.
        """
        self.stream_calls += [model]
        yield self._tag(model)

    async def embed(self, text, *, model) -> list[float]:
        """Log ``model`` and return a fixed two-element vector."""
        self.embed_calls += [model]
        return [1.0, 2.0]
@pytest.mark.asyncio
async def test_router_generate_dispatches_narrative_to_narrative_backend():
    """``generate`` with the configured narrative model hits only the narrative backend."""
    # Arrange: two distinguishable stubs behind one router.
    narrative_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_stub,
        local=local_stub,
        narrative_model="big-model",
    )
    prompt = [Message(role="user", content="hi")]

    # Act: request the exact model name registered as the narrative model.
    reply = await router.generate(prompt, model="big-model")

    # Assert: the reply carries the narrative stub's marker and the local
    # stub never saw the call.
    assert reply == "narrative:big-model"
    assert narrative_stub.generate_calls == ["big-model"]
    assert local_stub.generate_calls == []
@pytest.mark.asyncio
async def test_router_generate_dispatches_classifier_to_local_backend():
    """``generate`` with any non-narrative model hits only the local backend."""
    # Arrange: two distinguishable stubs behind one router.
    narrative_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_stub,
        local=local_stub,
        narrative_model="big-model",
    )
    prompt = [Message(role="user", content="hi")]

    # Act: request a model name that differs from the narrative model.
    reply = await router.generate(prompt, model="small-model")

    # Assert: the reply carries the local stub's marker and the narrative
    # stub never saw the call.
    assert reply == "local:small-model"
    assert local_stub.generate_calls == ["small-model"]
    assert narrative_stub.generate_calls == []
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_model():
    """``stream`` picks its backend by model name, just like ``generate``.

    Streams once with the narrative model and once with another model,
    then checks both the yielded chunks and each stub's call log.
    """
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative, local=local, narrative_model="big-model"
    )
    prompt = [Message(role="user", content="hi")]

    chunks_big = [c async for c in router.stream(prompt, model="big-model")]
    chunks_small = [c async for c in router.stream(prompt, model="other-model")]

    assert chunks_big == ["narrative:big-model"]
    assert chunks_small == ["local:other-model"]
    # Mirror the generate/embed tests: each backend must have been asked
    # exactly once, for its own model only. Checking the chunks alone would
    # miss a stray or duplicated call to the wrong backend.
    assert narrative.stream_calls == ["big-model"]
    assert local.stream_calls == ["other-model"]
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
    """Embedding requests must go to the local backend regardless of model name.

    The remote provider has no working ``/v1/embeddings`` endpoint, so the
    router is expected to keep every embed call local, even when the model
    name looks like a remote one.
    """
    # Arrange: two distinguishable stubs behind one router.
    narrative_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_stub,
        local=local_stub,
        narrative_model="big-model",
    )

    # Act
    vector = await router.embed("hello", model="any-embedding-model")

    # Assert: the stub's fixed vector comes back and only the local stub
    # logged the call.
    assert vector == [1.0, 2.0]
    assert local_stub.embed_calls == ["any-embedding-model"]
    assert narrative_stub.embed_calls == []