Files
chat/tests/test_router_client.py
Joseph Doherty de7f6624f0 perf: 18s/turn -> 2.5s/turn (SQLite busy_timeout, parallel state pairs, OpenRouter Cerebras-pinned classifier)
Four changes that compound:

1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. This
   was the root cause of most of the slowness. The embedding worker contends
   for the WAL write lock while the request handler holds an open
   transaction; conn.execute's busy-wait does NOT release the GIL, so
   every state_update LLM call after the narrative was silently
   freezing the asyncio event loop for ~5s. With 0.1s the worker
   fails fast and logs (already handled), the chat keeps moving, and
   any missed embedding can be backfilled out of band. Also takes the
   test suite from ~290s -> 13s as a bonus.

2) **Parallel state-update pairs** in multi_state_update.py. Each
   directed (src, tgt) pair becomes a coroutine in asyncio.gather
   instead of a sequential for-loop. Returned order is preserved.

3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New
   prefix-based router: model id with mlx-community/ -> local MLX,
   model == narrative_model -> narrative remote, else -> classifier
   remote. Settings.classifier_provider_order populates extra_body for
   the classifier client only (FeatherlessClient now accepts
   default_extra_body to merge into every chat.completions.create).
   Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x faster than the
   default provider. The narrative model still routes to mistral-nemo:nitro (Friendli).

4) **Cap classify max_tokens at 512**. A misbehaving classifier
   (response_format=json_object ignored) could otherwise generate
   thousands of tokens of prose before classify's JSON validation
   trips the retry. 512 is generous; typical completions are 50-150 tokens.

CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr;
zero overhead when unset. Useful for finding the slow link.

Suite: 464 passed in 13s (was 290s).
2026-04-27 13:51:27 -04:00

123 lines
4.0 KiB
Python

"""Tests for RoutedLLMClient (Phase 4.5+).
Splits traffic across two underlying clients based on the ``model``
kwarg. We use simple stub clients to assert the router picks the
correct backend for each call.
"""
from __future__ import annotations
from typing import AsyncIterator, Sequence
import pytest
from chat.llm.client import Message
from chat.llm.router import RoutedLLMClient
class _StubClient:
    """Minimal stand-in for an LLM backend that records every call.

    Each coroutine appends the ``model`` it was asked for to a per-method
    log and hands back a value tagged with this stub's ``name``, so a test
    can tell which backend actually served a request.
    """

    def __init__(self, name: str):
        self.name = name
        # Per-method call logs: the ``model`` kwarg of each invocation,
        # in call order.
        self.generate_calls: list[str] = []
        self.stream_calls: list[str] = []
        self.embed_calls: list[str] = []

    def _tag(self, model) -> str:
        """Build the ``"<name>:<model>"`` marker returned/yielded to callers."""
        return f"{self.name}:{model}"

    async def generate(self, messages, *, model, **params) -> str:
        """Log ``model`` and return the tagged marker string."""
        self.generate_calls.append(model)
        return self._tag(model)

    async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
        """Async generator: logs ``model`` on first iteration, yields one chunk."""
        self.stream_calls.append(model)
        yield self._tag(model)

    async def embed(self, text, *, model) -> list[float]:
        """Log ``model`` and return a fixed two-element vector."""
        self.embed_calls.append(model)
        return [1.0, 2.0]
@pytest.mark.asyncio
async def test_router_generate_routes_remote_model_to_remote_backend():
    """A model id without a local prefix is served by the narrative backend.

    The configured narrative model is used as the probe. The local stub
    must receive no ``generate`` traffic at all.
    """
    remote_model = "provider/big-model"
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative,
        local=local,
        narrative_model=remote_model,
        local_prefixes=("mlx-community/",),
    )

    result = await router.generate(
        [Message(role="user", content="hi")], model=remote_model
    )

    assert result == "narrative:provider/big-model"
    assert narrative.generate_calls == [remote_model]
    assert local.generate_calls == []
@pytest.mark.asyncio
async def test_router_generate_routes_local_prefix_to_local_backend():
    """A model id that starts with a local prefix is served by the local backend.

    Only the ``mlx-community/`` prefix decides the route. The rest of the id
    is irrelevant, and the narrative stub must stay idle.
    """
    mlx_model = "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    narrative, local = _StubClient("narrative"), _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative,
        local=local,
        narrative_model="provider/big-model",
        local_prefixes=("mlx-community/",),
    )

    result = await router.generate(
        [Message(role="user", content="hi")],
        model=mlx_model,
    )

    assert result == "local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    assert local.generate_calls == [mlx_model]
    assert narrative.generate_calls == []
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_prefix():
    """``stream`` follows the same prefix routing as ``generate``.

    One stream targets the narrative model id and one targets an
    ``mlx-community/`` id. Each must be served by its own backend, and each
    backend's ``stream`` call log must show exactly that one request.
    """
    remote_model = "provider/big-model"
    local_model = "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative,
        local=local,
        narrative_model=remote_model,
        local_prefixes=("mlx-community/",),
    )

    # Fresh message lists per call, so a router that mutates its input
    # cannot leak state from the first stream into the second.
    chunks_remote = [c async for c in router.stream(
        [Message(role="user", content="hi")], model=remote_model
    )]
    chunks_local = [c async for c in router.stream(
        [Message(role="user", content="hi")], model=local_model
    )]

    assert chunks_remote == ["narrative:provider/big-model"]
    assert chunks_local == ["local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
    # Check the per-stub call logs as well, matching the generate/embed
    # tests. This pins that each backend's own ``stream`` method was
    # invoked once with the expected model id.
    assert narrative.stream_calls == [remote_model]
    assert local.stream_calls == [local_model]
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
    """``embed`` is pinned to the local backend whatever the model name is.

    The remote provider doesn't expose a working ``/v1/embeddings``, so the
    router must never forward embed calls there. The router is built without
    ``local_prefixes`` here, so its defaults apply, and the model name
    deliberately carries no local prefix.
    """
    local = _StubClient("local")
    narrative = _StubClient("narrative")
    router = RoutedLLMClient(
        narrative=narrative, local=local, narrative_model="big-model"
    )

    vector = await router.embed("hello", model="any-embedding-model")

    assert vector == [1.0, 2.0]
    assert local.embed_calls == ["any-embedding-model"]
    assert narrative.embed_calls == []