From de7f6624f0c09a757f96c666877008b2fc30c98c Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 13:51:27 -0400 Subject: [PATCH] perf: 18s/turn -> 2.5s/turn (SQLite busy_timeout, parallel state pairs, OpenRouter Cerebras-pinned classifier) Four changes that compound: 1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root cause of the bulk of the slowness. The embedding worker contends for the WAL write lock while the request handler holds an open transaction; conn.execute's busy-wait blocks the calling thread, so every state_update LLM call after the narrative was silently freezing the asyncio event loop for ~5s. With 0.1s the worker fails fast and logs (already handled), the chat keeps moving, and any missed embedding can be backfilled out of band. Also takes the test suite from ~290s -> 13s as a bonus. 2) **Parallel state-update pairs** in multi_state_update.py. Each directed (src, tgt) pair becomes a coroutine in asyncio.gather instead of a sequential for-loop. Returned order is preserved. 3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New prefix-based router: model id with mlx-community/ -> local MLX, model == narrative_model -> narrative remote, else -> classifier remote. Settings.classifier_provider_order populates extra_body for the classifier client only (FeatherlessClient now accepts default_extra_body to merge into every chat.completions.create). Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default provider. narrative still routes to mistral-nemo:nitro (Friendli). 4) **Cap classify max_tokens at 512**. A misbehaving classifier (response_format=json_object ignored) could otherwise generate thousands of tokens of prose before classify's JSON validation trips the retry. 512 is generous; usual completions are 50-150. CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr; negligible overhead when unset. Useful for finding the slow link. Suite: 464 passed in 13s (was 290s). 
--- chat/app.py | 12 +++ chat/config.py | 20 +++-- chat/db/connection.py | 15 +++- chat/llm/classify.py | 13 +++- chat/llm/featherless.py | 43 ++++++++++- chat/llm/router.py | 110 +++++++++++++++++++++++++--- chat/services/multi_state_update.py | 71 ++++++++++-------- chat/web/kickoff.py | 15 ++++ tests/test_router_client.py | 50 ++++++++----- 9 files changed, 280 insertions(+), 69 deletions(-) diff --git a/chat/app.py b/chat/app.py index 3e7fcb9..239c3d9 100644 --- a/chat/app.py +++ b/chat/app.py @@ -88,9 +88,21 @@ async def lifespan(app: FastAPI): api_key=settings.featherless_api_key, base_url=settings.featherless_base_url, ) + classifier = None + if settings.classifier_provider_order: + classifier = FeatherlessClient( + api_key=settings.featherless_api_key, + base_url=settings.featherless_base_url, + default_extra_body={ + "provider": { + "order": list(settings.classifier_provider_order) + } + }, + ) local = LocalMLXClient(base_url=settings.local_mlx_base_url) return RoutedLLMClient( narrative=narrative, + classifier=classifier, local=local, narrative_model=settings.narrative_model, ) diff --git a/chat/config.py b/chat/config.py index 8d633b3..332c95c 100644 --- a/chat/config.py +++ b/chat/config.py @@ -44,16 +44,20 @@ class Settings(BaseModel): bind_host: str = "127.0.0.1" bind_port: int = 8000 # Local MLX server (e.g. ``mlx-omni-server``) — serves any model - # whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient` - # inspects the ``model`` kwarg at call time: ``model == narrative_model`` - # -> Featherless, otherwise local. ``embed()`` always routes local. - # If no MLX server is running and ``classifier_model`` is set to a - # local model id, classifier calls will fail; ``embed()`` will fall - # back to the zero-vector path with a T107 warning. To use the - # default Featherless-only deployment, leave ``classifier_model`` - # at its remote default and ``embedding_model`` at the pseudo. 
+ # whose id starts with one of ``local_prefixes`` (default + # ``"mlx-community/"``). The :class:`RoutedLLMClient` inspects the + # ``model`` kwarg at call time: local-prefix -> local, else -> remote. + # ``embed()`` always routes local. local_mlx_base_url: str = "http://127.0.0.1:10240/v1" local_mlx_max_concurrent: int = 1 + # Optional OpenRouter-style provider pinning for the classifier + # client. Maps to the ``provider`` field on chat.completions.create + # via ``extra_body``; the FeatherlessClient (which is just an + # AsyncOpenAI wrapper) merges it into every call. Useful for forcing + # Llama-3.1-8B classifier traffic onto Cerebras (~423 tok/s, 10x + # the default Nebius). Empty list = no pin (provider is + # OpenRouter's choice). + classifier_provider_order: list[str] = Field(default_factory=list) # T112 (Phase 4.5): embedding model identifier. Default is the # deterministic local pseudo so fresh installs / tests don't need # any external infra. Override via config.toml to a real model id diff --git a/chat/db/connection.py b/chat/db/connection.py index f293aca..aaadf06 100644 --- a/chat/db/connection.py +++ b/chat/db/connection.py @@ -7,7 +7,20 @@ from pathlib import Path @contextmanager def open_db(path: Path, *, check_same_thread: bool = True): path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(path, check_same_thread=check_same_thread) + # ``timeout`` here sets SQLite's busy_timeout, in seconds: how long + # ``conn.execute`` blocks when another connection holds the WAL + # write lock. The Python default is 5.0, which is fatal for the + # async chat app: ``conn.execute``'s busy-wait is a blocking + # call, so a contending background worker (e.g. the embedding worker + # writing ``embedding_indexed`` while the request handler holds an + # open transaction) freezes the whole asyncio event loop for up to + # 5 seconds — silently turning every concurrent LLM call into a 5s + # wall-clock hit. 
0.1s lets contending writers fail fast; callers + # that need durability should retry, and the embedding worker + # already logs failures so a missed embedding can be backfilled. + conn = sqlite3.connect( + path, check_same_thread=check_same_thread, timeout=0.1 + ) conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA foreign_keys=ON") try: diff --git a/chat/llm/classify.py b/chat/llm/classify.py index 7074c3b..bf46554 100644 --- a/chat/llm/classify.py +++ b/chat/llm/classify.py @@ -31,6 +31,7 @@ async def classify( schema: type[T], default: T | None = None, timeout_s: float = 10.0, + max_tokens: int = 512, ) -> T: schema_json = json.dumps(schema.model_json_schema(), indent=2) schema_block = ( @@ -41,10 +42,20 @@ async def classify( Message(role="system", content=system + schema_block), Message(role="user", content=user), ] + # Cap output length so a misbehaving model (e.g. one that ignores + # ``response_format=json_object`` and generates prose) can't burn + # several seconds on tokens we'll never use. Classifier responses + # are small JSON objects — 512 tokens is generous; usual completions + # are 50-150. 
for attempt in range(3): try: text = await asyncio.wait_for( - client.generate(msgs, model=model, response_format={"type": "json_object"}), + client.generate( + msgs, + model=model, + response_format={"type": "json_object"}, + max_tokens=max_tokens, + ), timeout=timeout_s, ) cleaned = _strip_json_fences(text) diff --git a/chat/llm/featherless.py b/chat/llm/featherless.py index 00fc9ce..945c8e8 100644 --- a/chat/llm/featherless.py +++ b/chat/llm/featherless.py @@ -29,19 +29,60 @@ class FeatherlessClient: cls._semaphore = asyncio.Semaphore(2) return cls._semaphore - def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"): + def __init__( + self, + api_key: str, + base_url: str = "https://api.featherless.ai/v1", + *, + default_extra_body: dict | None = None, + ): self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) + # ``default_extra_body`` is merged into every chat.completions.create + # call's ``extra_body``. Useful with OpenRouter to pin specific + # upstream providers (e.g. ``{"provider": {"order": ["Cerebras"]}}`` + # for 10x throughput on Llama-3.1-8B). Featherless ignores the + # field, so it's safe to leave set even when ``base_url`` points + # back at Featherless. + self._default_extra_body = default_extra_body or {} + + def _merge_extra_body(self, params: dict) -> dict: + if not self._default_extra_body: + return params + eb = dict(self._default_extra_body) + eb.update(params.pop("extra_body", {}) or {}) + params["extra_body"] = eb + return params async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str: + params = self._merge_extra_body(dict(params)) async with self._sem(): resp = await self._client.chat.completions.create( model=model, messages=[{"role": m.role, "content": m.content} for m in messages], **params, ) + # Diagnostic: stash provider+usage on a side-channel for the + # router timing log to pick up. 
OpenRouter sticks a 'provider' + # field on the response (not part of the OAI spec, but the + # SDK passes it through on its model dict). + try: # pragma: no cover — diagnostic only + import os as _os + if _os.environ.get("CHAT_LLM_TIMING") == "1": + prov = getattr(resp, "provider", None) + usage = getattr(resp, "usage", None) + ct = getattr(usage, "completion_tokens", "?") if usage else "?" + pt = getattr(usage, "prompt_tokens", "?") if usage else "?" + import logging as _logging + _logging.getLogger("chat.llm.router").info( + " ↪ provider=%s prompt_toks=%s completion_toks=%s", + prov, pt, ct, + ) + except Exception: # pragma: no cover + pass return resp.choices[0].message.content or "" async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]: + params = self._merge_extra_body(dict(params)) async with self._sem(): stream = await self._client.chat.completions.create( model=model, diff --git a/chat/llm/router.py b/chat/llm/router.py index b113a66..df3a05c 100644 --- a/chat/llm/router.py +++ b/chat/llm/router.py @@ -9,22 +9,44 @@ Routing rule: requests whose ``model`` argument matches the configured ``narrative_model`` go to the narrative backend; everything else (classifier, embeddings, future locally-hosted models) goes to the local backend. + +Set the env var ``CHAT_LLM_TIMING=1`` to log per-call timing at INFO +level. Useful for finding the slow link in a turn. """ from __future__ import annotations +import logging +import os +import time from typing import AsyncIterator, Sequence from .client import LLMClient, Message +_log = logging.getLogger(__name__) +_TIMING = os.environ.get("CHAT_LLM_TIMING") == "1" +if _TIMING and not _log.handlers: + # Wire a stderr handler when timing is enabled so the per-call + # logs show up under uvicorn (which doesn't configure non-uvicorn + # loggers by default). 
+ _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) + _log.addHandler(_h) + _log.setLevel(logging.INFO) + _log.propagate = False + + class RoutedLLMClient: """Delegates to one of two underlying clients based on ``model``. - The narrative client is exercised only when ``model`` exactly equals - ``narrative_model`` (the configured remote model id). Everything - else — classifier traffic, embeddings, any future locally-hosted - model — routes to the local client. ``embed`` always routes locally - (the remote provider doesn't support it; see + Routing rule: any model id starting with one of ``local_prefixes`` + goes to the local backend (e.g. ``"mlx-community/"`` for models + served by ``mlx-omni-server``). Everything else — narrative model, + remote classifiers, anything on a hosted provider — routes to the + remote backend. + + ``embed`` always routes locally (the remote provider doesn't + expose a working ``/v1/embeddings``; see :class:`chat.llm.featherless.FeatherlessClient.embed`). """ @@ -34,26 +56,94 @@ class RoutedLLMClient: narrative: LLMClient, local: LLMClient, narrative_model: str, + classifier: LLMClient | None = None, + local_prefixes: tuple[str, ...] = ("mlx-community/",), ) -> None: + # ``classifier`` is an optional separate backend for the + # classifier model. Useful when classifier and narrative both + # live on a remote OpenRouter-style provider but need different + # provider-pinning (e.g. Cerebras for the 8B classifier, + # default Friendli/etc. for the narrative). When ``classifier`` + # is None, classifier traffic falls through to ``narrative`` + # (the remote client) so old wiring keeps working. 
self._narrative = narrative + self._classifier = classifier self._local = local self._narrative_model = narrative_model + self._local_prefixes = local_prefixes def _pick(self, model: str) -> LLMClient: - return self._narrative if model == self._narrative_model else self._local + if any(model.startswith(p) for p in self._local_prefixes): + return self._local + if model == self._narrative_model: + return self._narrative + # Anything else (most importantly, the classifier model) goes + # to the classifier client when configured, otherwise to the + # narrative remote client. + return self._classifier or self._narrative async def generate( self, messages: Sequence[Message], *, model: str, **params ) -> str: - return await self._pick(model).generate(messages, model=model, **params) + client = self._pick(model) + backend = ( + "narrative" if client is self._narrative else + "classifier" if client is self._classifier else + "local" + ) + if not _TIMING: + return await client.generate(messages, model=model, **params) + in_chars = sum(len(m.content) for m in messages) + _log.info("LLM generate START [%s] %s in_chars=%d", backend, model, in_chars) + t0 = time.perf_counter() + try: + return await client.generate(messages, model=model, **params) + finally: + _log.info( + "LLM generate END [%s] %s in_chars=%d %.2fs", + backend, model, in_chars, time.perf_counter() - t0, + ) async def stream( self, messages: Sequence[Message], *, model: str, **params ) -> AsyncIterator[str]: - async for chunk in self._pick(model).stream(messages, model=model, **params): - yield chunk + client = self._pick(model) + backend = ( + "narrative" if client is self._narrative else + "classifier" if client is self._classifier else + "local" + ) + if not _TIMING: + async for chunk in client.stream(messages, model=model, **params): + yield chunk + return + t0 = time.perf_counter() + ttft = None + chars_out = 0 + try: + async for chunk in client.stream(messages, model=model, **params): + if ttft is None: + 
ttft = time.perf_counter() - t0 + chars_out += len(chunk) + yield chunk + finally: + dt = time.perf_counter() - t0 + in_chars = sum(len(m.content) for m in messages) + _log.info( + "LLM stream [%s] %s in_chars=%d out_chars=%d ttft=%.2fs total=%.2fs", + backend, model, in_chars, chars_out, ttft or 0.0, dt, + ) async def embed(self, text: str, *, model: str) -> list[float]: # Embeddings always run on the local backend — the remote # provider doesn't expose a working ``/v1/embeddings`` endpoint. - return await self._local.embed(text, model=model) + if not _TIMING: + return await self._local.embed(text, model=model) + t0 = time.perf_counter() + try: + return await self._local.embed(text, model=model) + finally: + _log.info( + "LLM embed [local] %s in_chars=%d %.2fs", + model, len(text), time.perf_counter() - t0, + ) diff --git a/chat/services/multi_state_update.py b/chat/services/multi_state_update.py index f6eeb8e..4d0519f 100644 --- a/chat/services/multi_state_update.py +++ b/chat/services/multi_state_update.py @@ -4,13 +4,15 @@ Wraps single-pair compute_state_update to run state updates for ALL directed pairs of present entities. With 3 present entities (you, host, guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs. -Calls run sequentially to respect Featherless's 2-connection cap (the -client-level semaphore would serialize them anyway, but doing it here -keeps the failure surface clean — a hung pair doesn't queue behind -itself). +Pairs run concurrently via :func:`asyncio.gather`; the underlying +client should impose its own concurrency cap if the upstream provider +needs it (e.g., Featherless's 2-conn semaphore). Returning order is +preserved (natural iteration over ``present_ids x present_ids``, +src != tgt) so downstream event-append order stays deterministic. 
""" from __future__ import annotations +import asyncio from chat.llm.client import LLMClient from chat.services.state_update import StateUpdate, compute_state_update @@ -28,35 +30,44 @@ async def compute_state_updates_for_present( timeout_s: float = 30.0, ) -> list[tuple[str, str, StateUpdate]]: """Run compute_state_update for every directed pair (src != tgt) over - ``present_ids``. Returns list of ``(source_id, target_id, update)`` - tuples in the natural iteration order over ``present_ids x present_ids``. + ``present_ids``, concurrently. Returns list of + ``(source_id, target_id, update)`` tuples in the natural iteration + order over ``present_ids x present_ids`` — concurrent dispatch does + not change the returned order. A single failing pair falls back to the schema-default StateUpdate - (zero deltas, empty facts) inside ``compute_state_update``; the batch - keeps going. + (zero deltas, empty facts) inside ``compute_state_update``; sibling + pairs continue independently because each call is wrapped in its + own try/except inside ``compute_state_update``. 
""" - out: list[tuple[str, str, StateUpdate]] = [] - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = prior_edges.get((src, tgt), {}) - update = await compute_state_update( - client, - model=classifier_model, - source_id=src, - target_id=tgt, - source_name=present_names.get(src, src), - source_persona=personas.get(src, "") or "", - target_name=present_names.get(tgt, tgt), - prior_affinity=int(edge.get("affinity", 50)), - prior_trust=int(edge.get("trust", 50)), - prior_summary=edge.get("summary", "") or "", - recent_dialogue=recent_dialogue, - timeout_s=timeout_s, - ) - out.append((src, tgt, update)) - return out + pair_keys: list[tuple[str, str]] = [ + (src, tgt) + for src in present_ids + for tgt in present_ids + if src != tgt + ] + if not pair_keys: + return [] + + async def _one(src: str, tgt: str) -> StateUpdate: + edge = prior_edges.get((src, tgt), {}) + return await compute_state_update( + client, + model=classifier_model, + source_id=src, + target_id=tgt, + source_name=present_names.get(src, src), + source_persona=personas.get(src, "") or "", + target_name=present_names.get(tgt, tgt), + prior_affinity=int(edge.get("affinity", 50)), + prior_trust=int(edge.get("trust", 50)), + prior_summary=edge.get("summary", "") or "", + recent_dialogue=recent_dialogue, + timeout_s=timeout_s, + ) + + updates = await asyncio.gather(*(_one(src, tgt) for src, tgt in pair_keys)) + return [(src, tgt, upd) for (src, tgt), upd in zip(pair_keys, updates)] __all__ = ["compute_state_updates_for_present"] diff --git a/chat/web/kickoff.py b/chat/web/kickoff.py index b214147..9e85944 100644 --- a/chat/web/kickoff.py +++ b/chat/web/kickoff.py @@ -50,9 +50,24 @@ def get_llm_client(request: Request) -> LLMClient: api_key=settings.featherless_api_key, base_url=settings.featherless_base_url, ) + # Dedicated classifier client when a provider pin is configured — + # routes Llama-3.1-8B (or whatever ``classifier_model`` is) onto a + # specific upstream like 
Cerebras for ~10x throughput. When the + # pin is empty, ``classifier`` is None and the router falls back + # to the narrative client for classifier traffic. + classifier = None + if settings.classifier_provider_order: + classifier = FeatherlessClient( + api_key=settings.featherless_api_key, + base_url=settings.featherless_base_url, + default_extra_body={ + "provider": {"order": list(settings.classifier_provider_order)} + }, + ) local = LocalMLXClient(base_url=settings.local_mlx_base_url) return RoutedLLMClient( narrative=narrative, + classifier=classifier, local=local, narrative_model=settings.narrative_model, ) diff --git a/tests/test_router_client.py b/tests/test_router_client.py index 3ff9c45..3ecf0d7 100644 --- a/tests/test_router_client.py +++ b/tests/test_router_client.py @@ -36,58 +36,72 @@ class _StubClient: @pytest.mark.asyncio -async def test_router_generate_dispatches_narrative_to_narrative_backend(): +async def test_router_generate_routes_remote_model_to_remote_backend(): + """Any model id NOT starting with a local prefix goes to the remote + backend — narrative model, remote classifiers, anything else.""" narrative = _StubClient("narrative") local = _StubClient("local") router = RoutedLLMClient( narrative=narrative, local=local, - narrative_model="big-model", + narrative_model="provider/big-model", + local_prefixes=("mlx-community/",), ) - out = await router.generate([Message(role="user", content="hi")], model="big-model") + out = await router.generate( + [Message(role="user", content="hi")], model="provider/big-model" + ) - assert out == "narrative:big-model" - assert narrative.generate_calls == ["big-model"] + assert out == "narrative:provider/big-model" + assert narrative.generate_calls == ["provider/big-model"] assert local.generate_calls == [] @pytest.mark.asyncio -async def test_router_generate_dispatches_classifier_to_local_backend(): +async def test_router_generate_routes_local_prefix_to_local_backend(): + """Models prefixed with a local 
prefix (e.g. ``mlx-community/``) + go to the local MLX backend regardless of whether the rest of the + path looks like a remote provider id.""" narrative = _StubClient("narrative") local = _StubClient("local") router = RoutedLLMClient( narrative=narrative, local=local, - narrative_model="big-model", + narrative_model="provider/big-model", + local_prefixes=("mlx-community/",), ) out = await router.generate( - [Message(role="user", content="hi")], model="small-model" + [Message(role="user", content="hi")], + model="mlx-community/Hermes-3-Llama-3.1-8B-8bit", ) - assert out == "local:small-model" - assert local.generate_calls == ["small-model"] + assert out == "local:mlx-community/Hermes-3-Llama-3.1-8B-8bit" + assert local.generate_calls == ["mlx-community/Hermes-3-Llama-3.1-8B-8bit"] assert narrative.generate_calls == [] @pytest.mark.asyncio -async def test_router_stream_dispatches_by_model(): +async def test_router_stream_dispatches_by_prefix(): narrative = _StubClient("narrative") local = _StubClient("local") router = RoutedLLMClient( - narrative=narrative, local=local, narrative_model="big-model" + narrative=narrative, + local=local, + narrative_model="provider/big-model", + local_prefixes=("mlx-community/",), ) - chunks_big = [c async for c in router.stream( - [Message(role="user", content="hi")], model="big-model" + chunks_remote = [c async for c in router.stream( + [Message(role="user", content="hi")], model="provider/big-model" )] - chunks_small = [c async for c in router.stream( - [Message(role="user", content="hi")], model="other-model" + chunks_local = [c async for c in router.stream( + [Message(role="user", content="hi")], + model="mlx-community/Hermes-3-Llama-3.1-8B-8bit", )] - assert chunks_big == ["narrative:big-model"] - assert chunks_small == ["local:other-model"] + assert chunks_remote == ["narrative:provider/big-model"] + assert chunks_local == ["local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"] @pytest.mark.asyncio