Files
chat/chat/llm/router.py
T
Joseph Doherty de7f6624f0 perf: 18s/turn -> 2.5s/turn (SQLite busy_timeout, parallel state pairs, OpenRouter Cerebras-pinned classifier)
Four changes that compound:

1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root
   cause of the bulk of the slowness. The embedding worker contends
   for the WAL write lock while the request handler holds an open
   transaction; conn.execute's busy-wait does NOT release the GIL, so
   every state_update LLM call after the narrative was silently
   freezing the asyncio event loop for ~5s. With 0.1s the worker
   fails fast and logs (already handled), the chat keeps moving, and
   any missed embedding can be backfilled out of band. Also takes the
   test suite from ~290s -> 13s as a bonus.

2) **Parallel state-update pairs** in multi_state_update.py. Each
   directed (src, tgt) pair becomes a coroutine in asyncio.gather
   instead of a sequential for-loop. Returned order is preserved.

3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New
   prefix-based router: model id with mlx-community/ -> local MLX,
   model == narrative_model -> narrative remote, else -> classifier
   remote. Settings.classifier_provider_order populates extra_body for
   the classifier client only (FeatherlessClient now accepts
   default_extra_body to merge into every chat.completions.create).
   Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default
   provider. narrative still routes to mistral-nemo:nitro (Friendli).

4) **Cap classify max_tokens at 512**. A misbehaving classifier
   (response_format=json_object ignored) could otherwise generate
   thousands of tokens of prose before classify's JSON validation
   trips the retry. 512 is generous; usual completions are 50-150.

CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr;
zero overhead when unset. Useful for finding the slow link.

Suite: 464 passed in 13s (was 290s).
2026-04-27 13:51:27 -04:00

150 lines
5.7 KiB
Python

"""Routed LLM client — splits traffic across multiple backends by model.
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
the 8B classifier model moves to local MLX, and embeddings run on a
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
backends, dispatched by the ``model`` argument at every call site.
Routing rule: requests whose ``model`` argument matches the configured
``narrative_model`` go to the narrative backend; everything else
(classifier, embeddings, future locally-hosted models) goes to the
local backend.
Set the env var ``CHAT_LLM_TIMING=1`` to log per-call timing at INFO
level. Useful for finding the slow link in a turn.
"""
from __future__ import annotations
import logging
import os
import time
from typing import AsyncIterator, Sequence
from .client import LLMClient, Message
_log = logging.getLogger(__name__)

# Opt-in per-call timing switch. Evaluated a single time at import, so
# the env var has to be in place before this module is first imported.
_TIMING = os.getenv("CHAT_LLM_TIMING") == "1"

if _TIMING and not _log.handlers:
    # uvicorn doesn't configure non-uvicorn loggers by default, so give
    # this logger its own stderr handler (StreamHandler's default
    # stream) or the timing lines would never show up. The level is
    # opened up to INFO and propagation is switched off so records are
    # handled only here, not by ancestor loggers' handlers as well.
    _log.setLevel(logging.INFO)
    _log.propagate = False
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    _log.addHandler(_h)
class RoutedLLMClient:
    """Delegates each call to one of up to three clients based on ``model``.

    Routing rule, checked in this order (see :meth:`_route`):

    1. a ``model`` id starting with one of ``local_prefixes`` goes to the
       local backend (e.g. ``"mlx-community/"`` for models served by
       ``mlx-omni-server``);
    2. a ``model`` equal to ``narrative_model`` goes to the narrative
       backend;
    3. everything else — most importantly the classifier model — goes
       to the ``classifier`` backend when one was supplied, otherwise
       falls through to the narrative backend.

    ``embed`` always routes locally (the remote provider doesn't expose
    a working ``/v1/embeddings``; see
    :class:`chat.llm.featherless.FeatherlessClient.embed`).

    When the module-level timing switch (``CHAT_LLM_TIMING=1``) is on,
    each call is also timed and logged at INFO together with the label
    of the backend it was routed to.
    """

    def __init__(
        self,
        *,
        narrative: LLMClient,
        local: LLMClient,
        narrative_model: str,
        classifier: LLMClient | None = None,
        local_prefixes: Sequence[str] = ("mlx-community/",),
    ) -> None:
        """Wire up the backends.

        Args:
            narrative: client for ``narrative_model``; also the fallback
                for non-local models when ``classifier`` is None.
            local: client for models matching ``local_prefixes`` and for
                every ``embed`` call.
            narrative_model: exact model id routed to ``narrative``.
            classifier: optional separate backend for every non-local,
                non-narrative model. Useful when classifier and narrative
                both live on a remote OpenRouter-style provider but need
                different provider-pinning (e.g. Cerebras for the 8B
                classifier, default Friendli/etc. for the narrative).
                When None, that traffic falls through to ``narrative``
                so old wiring keeps working.
            local_prefixes: model-id prefixes served by ``local``. Any
                sequence of strings is accepted; it is copied to a tuple.
        """
        self._narrative = narrative
        self._classifier = classifier
        self._local = local
        self._narrative_model = narrative_model
        # ``str.startswith`` only takes a str or a tuple of str, so
        # normalise once here; this also lets callers pass a list.
        self._local_prefixes = tuple(local_prefixes)

    def _route(self, model: str) -> tuple[str, LLMClient]:
        """Return ``(backend_label, client)`` for ``model``.

        The label names the routing decision ("local", "narrative" or
        "classifier") rather than being re-derived from object identity,
        so it stays correct even if one client object is wired into
        several slots. It is only used in timing logs.
        """
        if model.startswith(self._local_prefixes):
            return "local", self._local
        if model == self._narrative_model:
            return "narrative", self._narrative
        # Anything else (most importantly, the classifier model) goes to
        # the classifier client when configured, otherwise to the
        # narrative remote client. ``is not None`` rather than truthiness
        # so a client object that happens to be falsy is still honoured.
        if self._classifier is not None:
            return "classifier", self._classifier
        return "narrative", self._narrative

    def _pick(self, model: str) -> LLMClient:
        """Return the client that should serve ``model``."""
        return self._route(model)[1]

    async def generate(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> str:
        """Return a full completion from the backend that serves ``model``.

        ``params`` are forwarded unchanged to the chosen client. With
        timing on, START and END lines are logged; END is emitted from a
        ``finally`` so it also appears when the call raises.
        """
        backend, client = self._route(model)
        if not _TIMING:
            return await client.generate(messages, model=model, **params)
        in_chars = sum(len(m.content) for m in messages)
        _log.info("LLM generate START [%s] %s in_chars=%d", backend, model, in_chars)
        t0 = time.perf_counter()
        try:
            return await client.generate(messages, model=model, **params)
        finally:
            _log.info(
                "LLM generate END [%s] %s in_chars=%d %.2fs",
                backend, model, in_chars, time.perf_counter() - t0,
            )

    async def stream(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> AsyncIterator[str]:
        """Yield completion chunks from the backend that serves ``model``.

        With timing on, one summary line is logged when the stream
        finishes, fails, or the generator is closed, giving
        time-to-first-chunk and total duration. ``ttft`` is logged as
        0.00s if no chunk was produced.
        """
        backend, client = self._route(model)
        if not _TIMING:
            async for chunk in client.stream(messages, model=model, **params):
                yield chunk
            return
        in_chars = sum(len(m.content) for m in messages)
        t0 = time.perf_counter()
        ttft: float | None = None
        chars_out = 0
        try:
            async for chunk in client.stream(messages, model=model, **params):
                if ttft is None:
                    ttft = time.perf_counter() - t0
                chars_out += len(chunk)
                yield chunk
        finally:
            _log.info(
                "LLM stream [%s] %s in_chars=%d out_chars=%d ttft=%.2fs total=%.2fs",
                backend, model, in_chars, chars_out, ttft or 0.0,
                time.perf_counter() - t0,
            )

    async def embed(self, text: str, *, model: str) -> list[float]:
        """Embed ``text`` on the local backend, whatever ``model`` is.

        Embeddings always run locally — the remote provider doesn't
        expose a working ``/v1/embeddings`` endpoint.
        """
        if not _TIMING:
            return await self._local.embed(text, model=model)
        t0 = time.perf_counter()
        try:
            return await self._local.embed(text, model=model)
        finally:
            _log.info(
                "LLM embed [local] %s in_chars=%d %.2fs",
                model, len(text), time.perf_counter() - t0,
            )