de7f6624f0
Four changes that compound: 1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root cause of the bulk of the slowness. The embedding worker contends for the WAL write lock while the request handler holds an open transaction; conn.execute's busy-wait does NOT release the GIL, so every state_update LLM call after the narrative was silently freezing the asyncio event loop for ~5s. With 0.1s the worker fails fast and logs (already handled), the chat keeps moving, and any missed embedding can be backfilled out of band. Also takes the test suite from ~290s -> 13s as a bonus. 2) **Parallel state-update pairs** in multi_state_update.py. Each directed (src, tgt) pair becomes a coroutine in asyncio.gather instead of a sequential for-loop. Returned order is preserved. 3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New prefix-based router: model id with mlx-community/ -> local MLX, model == narrative_model -> narrative remote, else -> classifier remote. Settings.classifier_provider_order populates extra_body for the classifier client only (FeatherlessClient now accepts default_extra_body to merge into every chat.completions.create). Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default provider. narrative still routes to mistral-nemo:nitro (Friendli). 4) **Cap classify max_tokens at 512**. A misbehaving classifier (response_format=json_object ignored) could otherwise generate thousands of tokens of prose before classify's JSON validation trips the retry. 512 is generous; usual completions are 50-150. CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr; zero overhead when unset. Useful for finding the slow link. Suite: 464 passed in 13s (was 290s).
150 lines
5.7 KiB
Python
150 lines
5.7 KiB
Python
"""Routed LLM client — splits traffic across multiple backends by model.
|
|
|
|
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
|
|
the 8B classifier model moves to local MLX, and embeddings run on a
|
|
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
|
|
backends, dispatched by the ``model`` argument at every call site.
|
|
|
|
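Routing rule: a ``model`` id that starts with one of the configured
local prefixes (``"mlx-community/"`` by default) goes to the local
backend; a ``model`` equal to the configured ``narrative_model`` goes
to the narrative backend; everything else (most importantly the
classifier model) goes to the classifier backend when one is
configured, otherwise to the narrative backend. Embeddings always run
on the local backend.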
|
|
|
|
Set the env var ``CHAT_LLM_TIMING=1`` to log per-call timing at INFO
|
|
level. Useful for finding the slow link in a turn.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import AsyncIterator, Sequence
|
|
|
|
from .client import LLMClient, Message
|
|
|
|
|
|
# Module-wide logger; the opt-in per-call timing lines are emitted through it.
_log = logging.getLogger(__name__)

# Timing is opt-in: only the exact value "1" switches it on.
_TIMING = os.getenv("CHAT_LLM_TIMING") == "1"

if _TIMING and not _log.handlers:
    # uvicorn leaves non-uvicorn loggers unconfigured by default, so
    # without a handler of our own the timing lines would not show up.
    # Keep records on this logger only (no propagation to ancestors)
    # and let INFO through before attaching the stderr stream handler.
    _log.propagate = False
    _log.setLevel(logging.INFO)
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter(fmt="%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    _log.addHandler(_h)
|
|
|
|
|
|
class RoutedLLMClient:
    """Dispatch LLM calls to a local, narrative, or classifier backend.

    Routing rule for ``generate`` and ``stream``, checked in this order:

    1. A model id starting with one of ``local_prefixes`` goes to the
       local backend (e.g. ``"mlx-community/"`` for models served by
       ``mlx-omni-server``).
    2. A model id equal to ``narrative_model`` goes to the narrative
       backend.
    3. Everything else — most importantly the classifier model — goes
       to the classifier backend when one was supplied, otherwise to
       the narrative backend.

    ``embed`` always routes locally (the remote provider doesn't
    expose a working ``/v1/embeddings``; see
    :class:`chat.llm.featherless.FeatherlessClient.embed`).

    When the module-level ``_TIMING`` flag is set, every call is also
    logged with its backend label, model id, input size and duration.
    """

    def __init__(
        self,
        *,
        narrative: LLMClient,
        local: LLMClient,
        narrative_model: str,
        classifier: LLMClient | None = None,
        local_prefixes: tuple[str, ...] = ("mlx-community/",),
    ) -> None:
        """Store the backends and the routing configuration.

        Args:
            narrative: Client used for ``narrative_model`` and, when no
                classifier client is given, for every other non-local
                model.
            local: Client used for locally-prefixed models and for all
                embeddings.
            narrative_model: Exact model id routed to ``narrative``.
            classifier: Optional dedicated client for non-local,
                non-narrative models.
            local_prefixes: Model-id prefixes that select ``local``.
        """
        # ``classifier`` is an optional separate backend for the
        # classifier model. Useful when classifier and narrative both
        # live on a remote OpenRouter-style provider but need different
        # provider-pinning (e.g. Cerebras for the 8B classifier,
        # default Friendli/etc. for the narrative). When ``classifier``
        # is None, classifier traffic falls through to ``narrative``
        # (the remote client) so old wiring keeps working.
        self._narrative = narrative
        self._classifier = classifier
        self._local = local
        self._narrative_model = narrative_model
        # Normalise to a tuple: ``str.startswith`` accepts a tuple of
        # prefixes but rejects other iterables such as lists.
        self._local_prefixes = tuple(local_prefixes)

    def _pick(self, model: str) -> LLMClient:
        """Return the backend client that should serve ``model``."""
        if model.startswith(self._local_prefixes):
            return self._local
        if model == self._narrative_model:
            return self._narrative
        # Anything else (most importantly, the classifier model) goes
        # to the classifier client when configured, otherwise to the
        # narrative remote client. Compare against None explicitly so a
        # configured client that happens to be falsy is still used.
        if self._classifier is not None:
            return self._classifier
        return self._narrative

    def _backend_name(self, client: LLMClient) -> str:
        """Return the log label for ``client``, matched by identity."""
        if client is self._narrative:
            return "narrative"
        if client is self._classifier:
            return "classifier"
        return "local"

    async def generate(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> str:
        """Generate a completion on the backend selected for ``model``.

        ``messages``, ``model`` and ``params`` are forwarded unchanged
        to the chosen client's ``generate``. With timing enabled, a
        START line is logged before the call and an END line with the
        elapsed time afterwards, even if the call raises.
        """
        client = self._pick(model)
        if not _TIMING:
            return await client.generate(messages, model=model, **params)
        backend = self._backend_name(client)
        in_chars = sum(len(m.content) for m in messages)
        _log.info("LLM generate START [%s] %s in_chars=%d", backend, model, in_chars)
        t0 = time.perf_counter()
        try:
            return await client.generate(messages, model=model, **params)
        finally:
            _log.info(
                "LLM generate END [%s] %s in_chars=%d %.2fs",
                backend, model, in_chars, time.perf_counter() - t0,
            )

    async def stream(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> AsyncIterator[str]:
        """Stream completion chunks from the backend selected for ``model``.

        Chunks from the chosen client's ``stream`` are yielded as they
        arrive. With timing enabled, one line is logged when the stream
        ends (normally or not) with input/output character counts, time
        to first chunk and total time. ``ttft`` is reported as 0.00 if
        no chunk was received.
        """
        client = self._pick(model)
        if not _TIMING:
            async for chunk in client.stream(messages, model=model, **params):
                yield chunk
            return
        backend = self._backend_name(client)
        t0 = time.perf_counter()
        ttft = None
        chars_out = 0
        try:
            async for chunk in client.stream(messages, model=model, **params):
                if ttft is None:
                    # First chunk: record time-to-first-token once.
                    ttft = time.perf_counter() - t0
                chars_out += len(chunk)
                yield chunk
        finally:
            dt = time.perf_counter() - t0
            in_chars = sum(len(m.content) for m in messages)
            _log.info(
                "LLM stream [%s] %s in_chars=%d out_chars=%d ttft=%.2fs total=%.2fs",
                backend, model, in_chars, chars_out, ttft or 0.0, dt,
            )

    async def embed(self, text: str, *, model: str) -> list[float]:
        """Embed ``text`` with ``model`` on the local backend."""
        # Embeddings always run on the local backend — the remote
        # provider doesn't expose a working ``/v1/embeddings`` endpoint.
        if not _TIMING:
            return await self._local.embed(text, model=model)
        t0 = time.perf_counter()
        try:
            return await self._local.embed(text, model=model)
        finally:
            _log.info(
                "LLM embed [local] %s in_chars=%d %.2fs",
                model, len(text), time.perf_counter() - t0,
            )
|