perf: 18s/turn -> 2.5s/turn (SQLite busy_timeout, parallel state pairs, OpenRouter Cerebras-pinned classifier)

Four changes that compound:

1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root
   cause of the bulk of the slowness. The embedding worker contends
   for the WAL write lock while the request handler holds an open
   transaction; conn.execute's busy-wait does NOT release the GIL, so
   every state_update LLM call after the narrative was silently
   freezing the asyncio event loop for ~5s. With 0.1s the worker
   fails fast and logs (already handled), the chat keeps moving, and
   any missed embedding can be backfilled out of band. Also takes the
   test suite from ~290s -> 13s as a bonus.

2) **Parallel state-update pairs** in multi_state_update.py. Each
   directed (src, tgt) pair becomes a coroutine in asyncio.gather
   instead of a sequential for-loop. Returned order is preserved.

3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New
   prefix-based router: model id with mlx-community/ -> local MLX,
   model == narrative_model -> narrative remote, else -> classifier
   remote. Settings.classifier_provider_order populates extra_body for
   the classifier client only (FeatherlessClient now accepts
   default_extra_body to merge into every chat.completions.create).
   Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default
   provider. narrative still routes to mistral-nemo:nitro (Friendli).

4) **Cap classify max_tokens at 512**. A misbehaving classifier
   (response_format=json_object ignored) could otherwise generate
   thousands of tokens of prose before classify's JSON validation
   trips the retry. 512 is generous; usual completions are 50-150.

CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr;
zero overhead when unset. Useful for finding the slow link.

Suite: 464 passed in 13s (was 290s).
This commit is contained in:
Joseph Doherty
2026-04-27 13:51:27 -04:00
parent d656ee8805
commit de7f6624f0
9 changed files with 280 additions and 69 deletions
+12
View File
@@ -88,9 +88,21 @@ async def lifespan(app: FastAPI):
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
)
classifier = None
if settings.classifier_provider_order:
classifier = FeatherlessClient(
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
default_extra_body={
"provider": {
"order": list(settings.classifier_provider_order)
}
},
)
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
return RoutedLLMClient(
narrative=narrative,
classifier=classifier,
local=local,
narrative_model=settings.narrative_model,
)
+12 -8
View File
@@ -44,16 +44,20 @@ class Settings(BaseModel):
bind_host: str = "127.0.0.1"
bind_port: int = 8000
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
# -> Featherless, otherwise local. ``embed()`` always routes local.
# If no MLX server is running and ``classifier_model`` is set to a
# local model id, classifier calls will fail; ``embed()`` will fall
# back to the zero-vector path with a T107 warning. To use the
# default Featherless-only deployment, leave ``classifier_model``
# at its remote default and ``embedding_model`` at the pseudo.
# whose id starts with one of ``local_prefixes`` (default
# ``"mlx-community/"``). The :class:`RoutedLLMClient` inspects the
# ``model`` kwarg at call time: local-prefix -> local, else -> remote.
# ``embed()`` always routes local.
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
local_mlx_max_concurrent: int = 1
# Optional OpenRouter-style provider pinning for the classifier
# client. Maps to the ``provider`` field on chat.completions.create
# via ``extra_body``; the FeatherlessClient (which is just an
# AsyncOpenAI wrapper) merges it into every call. Useful for forcing
# Llama-3.1-8B classifier traffic onto Cerebras (~423 tok/s, 10x
# the default Nebius). Empty list = no pin (provider is
# OpenRouter's choice).
classifier_provider_order: list[str] = Field(default_factory=list)
# T112 (Phase 4.5): embedding model identifier. Default is the
# deterministic local pseudo so fresh installs / tests don't need
# any external infra. Override via config.toml to a real model id
+14 -1
View File
@@ -7,7 +7,20 @@ from pathlib import Path
@contextmanager
def open_db(path: Path, *, check_same_thread: bool = True):
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(path, check_same_thread=check_same_thread)
# ``timeout`` here sets SQLite's busy_timeout, in seconds: how long
# ``conn.execute`` blocks when another connection holds the WAL
# write lock. The Python default is 5.0, which is fatal for the
# async chat app: ``conn.execute``'s busy-wait does NOT release the
# GIL, so a contending background worker (e.g. the embedding worker
# writing ``embedding_indexed`` while the request handler holds an
# open transaction) freezes the whole asyncio event loop for up to
# 5 seconds — silently turning every concurrent LLM call into a 5s
# wall-clock hit. 0.1s lets contending writers fail fast; callers
# that need durability should retry, and the embedding worker
# already logs failures so a missed embedding can be backfilled.
conn = sqlite3.connect(
path, check_same_thread=check_same_thread, timeout=0.1
)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
try:
+12 -1
View File
@@ -31,6 +31,7 @@ async def classify(
schema: type[T],
default: T | None = None,
timeout_s: float = 10.0,
max_tokens: int = 512,
) -> T:
schema_json = json.dumps(schema.model_json_schema(), indent=2)
schema_block = (
@@ -41,10 +42,20 @@ async def classify(
Message(role="system", content=system + schema_block),
Message(role="user", content=user),
]
# Cap output length so a misbehaving model (e.g. one that ignores
# ``response_format=json_object`` and generates prose) can't burn
# several seconds on tokens we'll never use. Classifier responses
# are small JSON objects — 512 tokens is generous; usual completions
# are 50-150.
for attempt in range(3):
try:
text = await asyncio.wait_for(
client.generate(msgs, model=model, response_format={"type": "json_object"}),
client.generate(
msgs,
model=model,
response_format={"type": "json_object"},
max_tokens=max_tokens,
),
timeout=timeout_s,
)
cleaned = _strip_json_fences(text)
+42 -1
View File
@@ -29,19 +29,60 @@ class FeatherlessClient:
cls._semaphore = asyncio.Semaphore(2)
return cls._semaphore
def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"):
def __init__(
self,
api_key: str,
base_url: str = "https://api.featherless.ai/v1",
*,
default_extra_body: dict | None = None,
):
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
# ``default_extra_body`` is merged into every chat.completions.create
# call's ``extra_body``. Useful with OpenRouter to pin specific
# upstream providers (e.g. ``{"provider": {"order": ["Cerebras"]}}``
# for 10x throughput on Llama-3.1-8B). Featherless ignores the
# field, so it's safe to leave set even when ``base_url`` points
# back at Featherless.
self._default_extra_body = default_extra_body or {}
def _merge_extra_body(self, params: dict) -> dict:
if not self._default_extra_body:
return params
eb = dict(self._default_extra_body)
eb.update(params.pop("extra_body", {}) or {})
params["extra_body"] = eb
return params
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
params = self._merge_extra_body(dict(params))
async with self._sem():
resp = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
**params,
)
# Diagnostic: stash provider+usage on a side-channel for the
# router timing log to pick up. OpenRouter sticks a 'provider'
# field on the response (not part of the OAI spec, but the
# SDK passes it through on its model dict).
try: # pragma: no cover — diagnostic only
import os as _os
if _os.environ.get("CHAT_LLM_TIMING") == "1":
prov = getattr(resp, "provider", None)
usage = getattr(resp, "usage", None)
ct = getattr(usage, "completion_tokens", "?") if usage else "?"
pt = getattr(usage, "prompt_tokens", "?") if usage else "?"
import logging as _logging
_logging.getLogger("chat.llm.router").info(
" ↪ provider=%s prompt_toks=%s completion_toks=%s",
prov, pt, ct,
)
except Exception: # pragma: no cover
pass
return resp.choices[0].message.content or ""
async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
params = self._merge_extra_body(dict(params))
async with self._sem():
stream = await self._client.chat.completions.create(
model=model,
+100 -10
View File
@@ -9,22 +9,44 @@ Routing rule: requests whose ``model`` argument matches the configured
``narrative_model`` go to the narrative backend; everything else
(classifier, embeddings, future locally-hosted models) goes to the
local backend.
Set the env var ``CHAT_LLM_TIMING=1`` to log per-call timing at INFO
level. Useful for finding the slow link in a turn.
"""
from __future__ import annotations
import logging
import os
import time
from typing import AsyncIterator, Sequence
from .client import LLMClient, Message
_log = logging.getLogger(__name__)
_TIMING = os.environ.get("CHAT_LLM_TIMING") == "1"
if _TIMING and not _log.handlers:
# Wire a stderr handler when timing is enabled so the per-call
# logs show up under uvicorn (which doesn't configure non-uvicorn
# loggers by default).
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
_log.addHandler(_h)
_log.setLevel(logging.INFO)
_log.propagate = False
class RoutedLLMClient:
"""Delegates to one of two underlying clients based on ``model``.
The narrative client is exercised only when ``model`` exactly equals
``narrative_model`` (the configured remote model id). Everything
else — classifier traffic, embeddings, any future locally-hosted
model — routes to the local client. ``embed`` always routes locally
(the remote provider doesn't support it; see
Routing rule: any model id starting with one of ``local_prefixes``
goes to the local backend (e.g. ``"mlx-community/"`` for models
served by ``mlx-omni-server``). Everything else — narrative model,
remote classifiers, anything on a hosted provider — routes to the
remote backend.
``embed`` always routes locally (the remote provider doesn't
expose a working ``/v1/embeddings``; see
:class:`chat.llm.featherless.FeatherlessClient.embed`).
"""
@@ -34,26 +56,94 @@ class RoutedLLMClient:
narrative: LLMClient,
local: LLMClient,
narrative_model: str,
classifier: LLMClient | None = None,
local_prefixes: tuple[str, ...] = ("mlx-community/",),
) -> None:
# ``classifier`` is an optional separate backend for the
# classifier model. Useful when classifier and narrative both
# live on a remote OpenRouter-style provider but need different
# provider-pinning (e.g. Cerebras for the 8B classifier,
# default Friendli/etc. for the narrative). When ``classifier``
# is None, classifier traffic falls through to ``narrative``
# (the remote client) so old wiring keeps working.
self._narrative = narrative
self._classifier = classifier
self._local = local
self._narrative_model = narrative_model
self._local_prefixes = local_prefixes
def _pick(self, model: str) -> LLMClient:
return self._narrative if model == self._narrative_model else self._local
if any(model.startswith(p) for p in self._local_prefixes):
return self._local
if model == self._narrative_model:
return self._narrative
# Anything else (most importantly, the classifier model) goes
# to the classifier client when configured, otherwise to the
# narrative remote client.
return self._classifier or self._narrative
async def generate(
self, messages: Sequence[Message], *, model: str, **params
) -> str:
return await self._pick(model).generate(messages, model=model, **params)
client = self._pick(model)
backend = (
"narrative" if client is self._narrative else
"classifier" if client is self._classifier else
"local"
)
if not _TIMING:
return await client.generate(messages, model=model, **params)
in_chars = sum(len(m.content) for m in messages)
_log.info("LLM generate START [%s] %s in_chars=%d", backend, model, in_chars)
t0 = time.perf_counter()
try:
return await client.generate(messages, model=model, **params)
finally:
_log.info(
"LLM generate END [%s] %s in_chars=%d %.2fs",
backend, model, in_chars, time.perf_counter() - t0,
)
async def stream(
self, messages: Sequence[Message], *, model: str, **params
) -> AsyncIterator[str]:
async for chunk in self._pick(model).stream(messages, model=model, **params):
yield chunk
client = self._pick(model)
backend = (
"narrative" if client is self._narrative else
"classifier" if client is self._classifier else
"local"
)
if not _TIMING:
async for chunk in client.stream(messages, model=model, **params):
yield chunk
return
t0 = time.perf_counter()
ttft = None
chars_out = 0
try:
async for chunk in client.stream(messages, model=model, **params):
if ttft is None:
ttft = time.perf_counter() - t0
chars_out += len(chunk)
yield chunk
finally:
dt = time.perf_counter() - t0
in_chars = sum(len(m.content) for m in messages)
_log.info(
"LLM stream [%s] %s in_chars=%d out_chars=%d ttft=%.2fs total=%.2fs",
backend, model, in_chars, chars_out, ttft or 0.0, dt,
)
async def embed(self, text: str, *, model: str) -> list[float]:
# Embeddings always run on the local backend — the remote
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
return await self._local.embed(text, model=model)
if not _TIMING:
return await self._local.embed(text, model=model)
t0 = time.perf_counter()
try:
return await self._local.embed(text, model=model)
finally:
_log.info(
"LLM embed [local] %s in_chars=%d %.2fs",
model, len(text), time.perf_counter() - t0,
)
+41 -30
View File
@@ -4,13 +4,15 @@ Wraps single-pair compute_state_update to run state updates for ALL
directed pairs of present entities. With 3 present entities (you, host,
guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs.
Calls run sequentially to respect Featherless's 2-connection cap (the
client-level semaphore would serialize them anyway, but doing it here
keeps the failure surface clean — a hung pair doesn't queue behind
itself).
Pairs run concurrently via :func:`asyncio.gather`; the underlying
client should impose its own concurrency cap if the upstream provider
needs it (e.g., Featherless's 2-conn semaphore). Returning order is
preserved (natural iteration over ``present_ids x present_ids``,
src != tgt) so downstream event-append order stays deterministic.
"""
from __future__ import annotations
import asyncio
from chat.llm.client import LLMClient
from chat.services.state_update import StateUpdate, compute_state_update
@@ -28,35 +30,44 @@ async def compute_state_updates_for_present(
timeout_s: float = 30.0,
) -> list[tuple[str, str, StateUpdate]]:
"""Run compute_state_update for every directed pair (src != tgt) over
``present_ids``. Returns list of ``(source_id, target_id, update)``
tuples in the natural iteration order over ``present_ids x present_ids``.
``present_ids``, concurrently. Returns list of
``(source_id, target_id, update)`` tuples in the natural iteration
order over ``present_ids x present_ids`` — concurrent dispatch does
not change the returned order.
A single failing pair falls back to the schema-default StateUpdate
(zero deltas, empty facts) inside ``compute_state_update``; the batch
keeps going.
(zero deltas, empty facts) inside ``compute_state_update``; sibling
pairs continue independently because each call is wrapped in its
own try/except inside ``compute_state_update``.
"""
out: list[tuple[str, str, StateUpdate]] = []
for src in present_ids:
for tgt in present_ids:
if src == tgt:
continue
edge = prior_edges.get((src, tgt), {})
update = await compute_state_update(
client,
model=classifier_model,
source_id=src,
target_id=tgt,
source_name=present_names.get(src, src),
source_persona=personas.get(src, "") or "",
target_name=present_names.get(tgt, tgt),
prior_affinity=int(edge.get("affinity", 50)),
prior_trust=int(edge.get("trust", 50)),
prior_summary=edge.get("summary", "") or "",
recent_dialogue=recent_dialogue,
timeout_s=timeout_s,
)
out.append((src, tgt, update))
return out
pair_keys: list[tuple[str, str]] = [
(src, tgt)
for src in present_ids
for tgt in present_ids
if src != tgt
]
if not pair_keys:
return []
async def _one(src: str, tgt: str) -> StateUpdate:
edge = prior_edges.get((src, tgt), {})
return await compute_state_update(
client,
model=classifier_model,
source_id=src,
target_id=tgt,
source_name=present_names.get(src, src),
source_persona=personas.get(src, "") or "",
target_name=present_names.get(tgt, tgt),
prior_affinity=int(edge.get("affinity", 50)),
prior_trust=int(edge.get("trust", 50)),
prior_summary=edge.get("summary", "") or "",
recent_dialogue=recent_dialogue,
timeout_s=timeout_s,
)
updates = await asyncio.gather(*(_one(src, tgt) for src, tgt in pair_keys))
return [(src, tgt, upd) for (src, tgt), upd in zip(pair_keys, updates)]
__all__ = ["compute_state_updates_for_present"]
+15
View File
@@ -50,9 +50,24 @@ def get_llm_client(request: Request) -> LLMClient:
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
)
# Dedicated classifier client when a provider pin is configured —
# routes Llama-3.1-8B (or whatever ``classifier_model`` is) onto a
# specific upstream like Cerebras for ~10x throughput. When the
# pin is empty, ``classifier`` is None and the router falls back
# to the narrative client for classifier traffic.
classifier = None
if settings.classifier_provider_order:
classifier = FeatherlessClient(
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
default_extra_body={
"provider": {"order": list(settings.classifier_provider_order)}
},
)
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
return RoutedLLMClient(
narrative=narrative,
classifier=classifier,
local=local,
narrative_model=settings.narrative_model,
)
+32 -18
View File
@@ -36,58 +36,72 @@ class _StubClient:
@pytest.mark.asyncio
async def test_router_generate_dispatches_narrative_to_narrative_backend():
async def test_router_generate_routes_remote_model_to_remote_backend():
"""Any model id NOT starting with a local prefix goes to the remote
backend — narrative model, remote classifiers, anything else."""
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model="big-model",
narrative_model="provider/big-model",
local_prefixes=("mlx-community/",),
)
out = await router.generate([Message(role="user", content="hi")], model="big-model")
out = await router.generate(
[Message(role="user", content="hi")], model="provider/big-model"
)
assert out == "narrative:big-model"
assert narrative.generate_calls == ["big-model"]
assert out == "narrative:provider/big-model"
assert narrative.generate_calls == ["provider/big-model"]
assert local.generate_calls == []
@pytest.mark.asyncio
async def test_router_generate_dispatches_classifier_to_local_backend():
async def test_router_generate_routes_local_prefix_to_local_backend():
"""Models prefixed with a local prefix (e.g. ``mlx-community/``)
go to the local MLX backend regardless of whether the rest of the
path looks like a remote provider id."""
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model="big-model",
narrative_model="provider/big-model",
local_prefixes=("mlx-community/",),
)
out = await router.generate(
[Message(role="user", content="hi")], model="small-model"
[Message(role="user", content="hi")],
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
)
assert out == "local:small-model"
assert local.generate_calls == ["small-model"]
assert out == "local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"
assert local.generate_calls == ["mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
assert narrative.generate_calls == []
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_model():
async def test_router_stream_dispatches_by_prefix():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative, local=local, narrative_model="big-model"
narrative=narrative,
local=local,
narrative_model="provider/big-model",
local_prefixes=("mlx-community/",),
)
chunks_big = [c async for c in router.stream(
[Message(role="user", content="hi")], model="big-model"
chunks_remote = [c async for c in router.stream(
[Message(role="user", content="hi")], model="provider/big-model"
)]
chunks_small = [c async for c in router.stream(
[Message(role="user", content="hi")], model="other-model"
chunks_local = [c async for c in router.stream(
[Message(role="user", content="hi")],
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
)]
assert chunks_big == ["narrative:big-model"]
assert chunks_small == ["local:other-model"]
assert chunks_remote == ["narrative:provider/big-model"]
assert chunks_local == ["local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
@pytest.mark.asyncio