de7f6624f0
Four changes that compound: 1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root cause of the bulk of the slowness. The embedding worker contends for the WAL write lock while the request handler holds an open transaction; conn.execute's busy-wait does NOT release the GIL, so every state_update LLM call after the narrative was silently freezing the asyncio event loop for ~5s. With 0.1s the worker fails fast and logs (already handled), the chat keeps moving, and any missed embedding can be backfilled out of band. Also takes the test suite from ~290s -> 13s as a bonus. 2) **Parallel state-update pairs** in multi_state_update.py. Each directed (src, tgt) pair becomes a coroutine in asyncio.gather instead of a sequential for-loop. Returned order is preserved. 3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New prefix-based router: model id with mlx-community/ -> local MLX, model == narrative_model -> narrative remote, else -> classifier remote. Settings.classifier_provider_order populates extra_body for the classifier client only (FeatherlessClient now accepts default_extra_body to merge into every chat.completions.create). Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default provider. narrative still routes to mistral-nemo:nitro (Friendli). 4) **Cap classify max_tokens at 512**. A misbehaving classifier (response_format=json_object ignored) could otherwise generate thousands of tokens of prose before classify's JSON validation trips the retry. 512 is generous; usual completions are 50-150. CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr; zero overhead when unset. Useful for finding the slow link. Suite: 464 passed in 13s (was 290s).
123 lines
4.0 KiB
Python
123 lines
4.0 KiB
Python
"""Tests for RoutedLLMClient (Phase 4.5+).
|
|
|
|
Splits traffic across two underlying clients based on the ``model``
|
|
kwarg. We use simple stub clients to assert the router picks the
|
|
correct backend for each call.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import AsyncIterator, Sequence
|
|
|
|
import pytest
|
|
|
|
from chat.llm.client import Message
|
|
from chat.llm.router import RoutedLLMClient
|
|
|
|
|
|
class _StubClient:
    """Minimal in-memory stand-in for an LLM backend.

    Every call records the ``model`` it was given on a per-method list, so
    a test can check which backend the router picked. ``generate`` and
    ``stream`` answer with ``"<name>:<model>"``; ``embed`` answers with a
    fixed two-element vector.
    """

    def __init__(self, name: str):
        # One call log per entry point, holding model ids in call order.
        self.generate_calls: list[str] = []
        self.stream_calls: list[str] = []
        self.embed_calls: list[str] = []
        self.name = name

    def _label(self, model) -> str:
        """Build the ``"<name>:<model>"`` marker echoed back to callers."""
        return f"{self.name}:{model}"

    async def generate(self, messages, *, model, **params) -> str:
        """Log ``model`` and return the backend marker as the completion."""
        self.generate_calls.append(model)
        return self._label(model)

    async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
        """Log ``model`` and yield the backend marker as the only chunk."""
        self.stream_calls.append(model)
        yield self._label(model)

    async def embed(self, text, *, model) -> list[float]:
        """Log ``model`` and return a constant embedding."""
        self.embed_calls.append(model)
        return [1.0, 2.0]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_generate_routes_remote_model_to_remote_backend():
    """A model id that starts with no local prefix is served remotely.

    That covers the narrative model, remote classifiers, and anything
    else; this case uses the narrative model id itself.
    """
    remote_model = "provider/big-model"
    narrative_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_stub,
        local=local_stub,
        narrative_model=remote_model,
        local_prefixes=("mlx-community/",),
    )
    prompt = [Message(role="user", content="hi")]

    reply = await router.generate(prompt, model=remote_model)

    # The stub echoes its own name, so the reply identifies the backend.
    assert reply == "narrative:provider/big-model"
    assert narrative_stub.generate_calls == [remote_model]
    assert local_stub.generate_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_generate_routes_local_prefix_to_local_backend():
    """A model id carrying a local prefix is served by the local backend.

    The prefix (e.g. ``mlx-community/``) decides the route even when the
    remainder of the id resembles a remote provider path.
    """
    local_model = "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    narrative_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_stub,
        local=local_stub,
        narrative_model="provider/big-model",
        local_prefixes=("mlx-community/",),
    )
    prompt = [Message(role="user", content="hi")]

    reply = await router.generate(prompt, model=local_model)

    # The stub echoes its own name, so the reply identifies the backend.
    assert reply == "local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    assert local_stub.generate_calls == [local_model]
    assert narrative_stub.generate_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_prefix():
    """Streaming is dispatched by model prefix, one backend per call.

    A model id carrying a local prefix is streamed by the local stub and
    any other id by the narrative stub. Besides the chunks themselves, the
    per-backend call logs are checked (as the ``generate`` tests do) so a
    stream that also reached the wrong backend cannot pass unnoticed.
    """
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative,
        local=local,
        narrative_model="provider/big-model",
        local_prefixes=("mlx-community/",),
    )

    chunks_remote = [
        c
        async for c in router.stream(
            [Message(role="user", content="hi")], model="provider/big-model"
        )
    ]
    chunks_local = [
        c
        async for c in router.stream(
            [Message(role="user", content="hi")],
            model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
        )
    ]

    # Each stub prefixes its chunk with its own name, identifying the route.
    assert chunks_remote == ["narrative:provider/big-model"]
    assert chunks_local == ["local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"]

    # Each backend must have seen exactly its own model and nothing else.
    assert narrative.stream_calls == ["provider/big-model"]
    assert local.stream_calls == ["mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
    """Embedding requests go to the local backend whatever the model id.

    The remote provider doesn't expose a working ``/v1/embeddings``, so the
    router never sends embed calls there, even when the model name happens
    to look 'remote'.
    """
    embedding_model = "any-embedding-model"
    narrative_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_stub,
        local=local_stub,
        narrative_model="big-model",
    )

    vector = await router.embed("hello", model=embedding_model)

    # [1.0, 2.0] is the stub's fixed embedding vector.
    assert vector == [1.0, 2.0]
    assert local_stub.embed_calls == [embedding_model]
    assert narrative_stub.embed_calls == []
|