perf: 18s/turn -> 2.5s/turn (SQLite busy_timeout, parallel state pairs, OpenRouter Cerebras-pinned classifier)
Four changes that compound: 1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root cause of the bulk of the slowness. The embedding worker contends for the WAL write lock while the request handler holds an open transaction; conn.execute's busy-wait does NOT release the GIL, so every state_update LLM call after the narrative was silently freezing the asyncio event loop for ~5s. With 0.1s the worker fails fast and logs (already handled), the chat keeps moving, and any missed embedding can be backfilled out of band. Also takes the test suite from ~290s -> 13s as a bonus. 2) **Parallel state-update pairs** in multi_state_update.py. Each directed (src, tgt) pair becomes a coroutine in asyncio.gather instead of a sequential for-loop. Returned order is preserved. 3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New prefix-based router: model id with mlx-community/ -> local MLX, model == narrative_model -> narrative remote, else -> classifier remote. Settings.classifier_provider_order populates extra_body for the classifier client only (FeatherlessClient now accepts default_extra_body to merge into every chat.completions.create). Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default provider. narrative still routes to mistral-nemo:nitro (Friendli). 4) **Cap classify max_tokens at 512**. A misbehaving classifier (response_format=json_object ignored) could otherwise generate thousands of tokens of prose before classify's JSON validation trips the retry. 512 is generous; usual completions are 50-150. CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr; zero overhead when unset. Useful for finding the slow link. Suite: 464 passed in 13s (was 290s).
This commit is contained in:
+12
@@ -88,9 +88,21 @@ async def lifespan(app: FastAPI):
|
|||||||
api_key=settings.featherless_api_key,
|
api_key=settings.featherless_api_key,
|
||||||
base_url=settings.featherless_base_url,
|
base_url=settings.featherless_base_url,
|
||||||
)
|
)
|
||||||
|
classifier = None
|
||||||
|
if settings.classifier_provider_order:
|
||||||
|
classifier = FeatherlessClient(
|
||||||
|
api_key=settings.featherless_api_key,
|
||||||
|
base_url=settings.featherless_base_url,
|
||||||
|
default_extra_body={
|
||||||
|
"provider": {
|
||||||
|
"order": list(settings.classifier_provider_order)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||||
return RoutedLLMClient(
|
return RoutedLLMClient(
|
||||||
narrative=narrative,
|
narrative=narrative,
|
||||||
|
classifier=classifier,
|
||||||
local=local,
|
local=local,
|
||||||
narrative_model=settings.narrative_model,
|
narrative_model=settings.narrative_model,
|
||||||
)
|
)
|
||||||
|
|||||||
+12
-8
@@ -44,16 +44,20 @@ class Settings(BaseModel):
|
|||||||
bind_host: str = "127.0.0.1"
|
bind_host: str = "127.0.0.1"
|
||||||
bind_port: int = 8000
|
bind_port: int = 8000
|
||||||
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
|
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
|
||||||
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
|
# whose id starts with one of ``local_prefixes`` (default
|
||||||
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
|
# ``"mlx-community/"``). The :class:`RoutedLLMClient` inspects the
|
||||||
# -> Featherless, otherwise local. ``embed()`` always routes local.
|
# ``model`` kwarg at call time: local-prefix -> local, else -> remote.
|
||||||
# If no MLX server is running and ``classifier_model`` is set to a
|
# ``embed()`` always routes local.
|
||||||
# local model id, classifier calls will fail; ``embed()`` will fall
|
|
||||||
# back to the zero-vector path with a T107 warning. To use the
|
|
||||||
# default Featherless-only deployment, leave ``classifier_model``
|
|
||||||
# at its remote default and ``embedding_model`` at the pseudo.
|
|
||||||
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
|
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
|
||||||
local_mlx_max_concurrent: int = 1
|
local_mlx_max_concurrent: int = 1
|
||||||
|
# Optional OpenRouter-style provider pinning for the classifier
|
||||||
|
# client. Maps to the ``provider`` field on chat.completions.create
|
||||||
|
# via ``extra_body``; the FeatherlessClient (which is just an
|
||||||
|
# AsyncOpenAI wrapper) merges it into every call. Useful for forcing
|
||||||
|
# Llama-3.1-8B classifier traffic onto Cerebras (~423 tok/s, 10x
|
||||||
|
# the default Nebius). Empty list = no pin (provider is
|
||||||
|
# OpenRouter's choice).
|
||||||
|
classifier_provider_order: list[str] = Field(default_factory=list)
|
||||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||||
# deterministic local pseudo so fresh installs / tests don't need
|
# deterministic local pseudo so fresh installs / tests don't need
|
||||||
# any external infra. Override via config.toml to a real model id
|
# any external infra. Override via config.toml to a real model id
|
||||||
|
|||||||
+14
-1
@@ -7,7 +7,20 @@ from pathlib import Path
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def open_db(path: Path, *, check_same_thread: bool = True):
|
def open_db(path: Path, *, check_same_thread: bool = True):
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(path, check_same_thread=check_same_thread)
|
# ``timeout`` here sets SQLite's busy_timeout, in seconds: how long
|
||||||
|
# ``conn.execute`` blocks when another connection holds the WAL
|
||||||
|
# write lock. The Python default is 5.0, which is fatal for the
|
||||||
|
# async chat app: ``conn.execute``'s busy-wait does NOT release the
|
||||||
|
# GIL, so a contending background worker (e.g. the embedding worker
|
||||||
|
# writing ``embedding_indexed`` while the request handler holds an
|
||||||
|
# open transaction) freezes the whole asyncio event loop for up to
|
||||||
|
# 5 seconds — silently turning every concurrent LLM call into a 5s
|
||||||
|
# wall-clock hit. 0.1s lets contending writers fail fast; callers
|
||||||
|
# that need durability should retry, and the embedding worker
|
||||||
|
# already logs failures so a missed embedding can be backfilled.
|
||||||
|
conn = sqlite3.connect(
|
||||||
|
path, check_same_thread=check_same_thread, timeout=0.1
|
||||||
|
)
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
conn.execute("PRAGMA foreign_keys=ON")
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
try:
|
try:
|
||||||
|
|||||||
+12
-1
@@ -31,6 +31,7 @@ async def classify(
|
|||||||
schema: type[T],
|
schema: type[T],
|
||||||
default: T | None = None,
|
default: T | None = None,
|
||||||
timeout_s: float = 10.0,
|
timeout_s: float = 10.0,
|
||||||
|
max_tokens: int = 512,
|
||||||
) -> T:
|
) -> T:
|
||||||
schema_json = json.dumps(schema.model_json_schema(), indent=2)
|
schema_json = json.dumps(schema.model_json_schema(), indent=2)
|
||||||
schema_block = (
|
schema_block = (
|
||||||
@@ -41,10 +42,20 @@ async def classify(
|
|||||||
Message(role="system", content=system + schema_block),
|
Message(role="system", content=system + schema_block),
|
||||||
Message(role="user", content=user),
|
Message(role="user", content=user),
|
||||||
]
|
]
|
||||||
|
# Cap output length so a misbehaving model (e.g. one that ignores
|
||||||
|
# ``response_format=json_object`` and generates prose) can't burn
|
||||||
|
# several seconds on tokens we'll never use. Classifier responses
|
||||||
|
# are small JSON objects — 512 tokens is generous; usual completions
|
||||||
|
# are 50-150.
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
try:
|
try:
|
||||||
text = await asyncio.wait_for(
|
text = await asyncio.wait_for(
|
||||||
client.generate(msgs, model=model, response_format={"type": "json_object"}),
|
client.generate(
|
||||||
|
msgs,
|
||||||
|
model=model,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
),
|
||||||
timeout=timeout_s,
|
timeout=timeout_s,
|
||||||
)
|
)
|
||||||
cleaned = _strip_json_fences(text)
|
cleaned = _strip_json_fences(text)
|
||||||
|
|||||||
+42
-1
@@ -29,19 +29,60 @@ class FeatherlessClient:
|
|||||||
cls._semaphore = asyncio.Semaphore(2)
|
cls._semaphore = asyncio.Semaphore(2)
|
||||||
return cls._semaphore
|
return cls._semaphore
|
||||||
|
|
||||||
def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = "https://api.featherless.ai/v1",
|
||||||
|
*,
|
||||||
|
default_extra_body: dict | None = None,
|
||||||
|
):
|
||||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
# ``default_extra_body`` is merged into every chat.completions.create
|
||||||
|
# call's ``extra_body``. Useful with OpenRouter to pin specific
|
||||||
|
# upstream providers (e.g. ``{"provider": {"order": ["Cerebras"]}}``
|
||||||
|
# for 10x throughput on Llama-3.1-8B). Featherless ignores the
|
||||||
|
# field, so it's safe to leave set even when ``base_url`` points
|
||||||
|
# back at Featherless.
|
||||||
|
self._default_extra_body = default_extra_body or {}
|
||||||
|
|
||||||
|
def _merge_extra_body(self, params: dict) -> dict:
|
||||||
|
if not self._default_extra_body:
|
||||||
|
return params
|
||||||
|
eb = dict(self._default_extra_body)
|
||||||
|
eb.update(params.pop("extra_body", {}) or {})
|
||||||
|
params["extra_body"] = eb
|
||||||
|
return params
|
||||||
|
|
||||||
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
||||||
|
params = self._merge_extra_body(dict(params))
|
||||||
async with self._sem():
|
async with self._sem():
|
||||||
resp = await self._client.chat.completions.create(
|
resp = await self._client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{"role": m.role, "content": m.content} for m in messages],
|
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||||
**params,
|
**params,
|
||||||
)
|
)
|
||||||
|
# Diagnostic: stash provider+usage on a side-channel for the
|
||||||
|
# router timing log to pick up. OpenRouter sticks a 'provider'
|
||||||
|
# field on the response (not part of the OAI spec, but the
|
||||||
|
# SDK passes it through on its model dict).
|
||||||
|
try: # pragma: no cover — diagnostic only
|
||||||
|
import os as _os
|
||||||
|
if _os.environ.get("CHAT_LLM_TIMING") == "1":
|
||||||
|
prov = getattr(resp, "provider", None)
|
||||||
|
usage = getattr(resp, "usage", None)
|
||||||
|
ct = getattr(usage, "completion_tokens", "?") if usage else "?"
|
||||||
|
pt = getattr(usage, "prompt_tokens", "?") if usage else "?"
|
||||||
|
import logging as _logging
|
||||||
|
_logging.getLogger("chat.llm.router").info(
|
||||||
|
" ↪ provider=%s prompt_toks=%s completion_toks=%s",
|
||||||
|
prov, pt, ct,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
pass
|
||||||
return resp.choices[0].message.content or ""
|
return resp.choices[0].message.content or ""
|
||||||
|
|
||||||
async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
|
async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
|
||||||
|
params = self._merge_extra_body(dict(params))
|
||||||
async with self._sem():
|
async with self._sem():
|
||||||
stream = await self._client.chat.completions.create(
|
stream = await self._client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
+100
-10
@@ -9,22 +9,44 @@ Routing rule: requests whose ``model`` argument matches the configured
|
|||||||
``narrative_model`` go to the narrative backend; everything else
|
``narrative_model`` go to the narrative backend; everything else
|
||||||
(classifier, embeddings, future locally-hosted models) goes to the
|
(classifier, embeddings, future locally-hosted models) goes to the
|
||||||
local backend.
|
local backend.
|
||||||
|
|
||||||
|
Set the env var ``CHAT_LLM_TIMING=1`` to log per-call timing at INFO
|
||||||
|
level. Useful for finding the slow link in a turn.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from typing import AsyncIterator, Sequence
|
from typing import AsyncIterator, Sequence
|
||||||
|
|
||||||
from .client import LLMClient, Message
|
from .client import LLMClient, Message
|
||||||
|
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_TIMING = os.environ.get("CHAT_LLM_TIMING") == "1"
|
||||||
|
if _TIMING and not _log.handlers:
|
||||||
|
# Wire a stderr handler when timing is enabled so the per-call
|
||||||
|
# logs show up under uvicorn (which doesn't configure non-uvicorn
|
||||||
|
# loggers by default).
|
||||||
|
_h = logging.StreamHandler()
|
||||||
|
_h.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
|
||||||
|
_log.addHandler(_h)
|
||||||
|
_log.setLevel(logging.INFO)
|
||||||
|
_log.propagate = False
|
||||||
|
|
||||||
|
|
||||||
class RoutedLLMClient:
|
class RoutedLLMClient:
|
||||||
"""Delegates to one of two underlying clients based on ``model``.
|
"""Delegates to one of two underlying clients based on ``model``.
|
||||||
|
|
||||||
The narrative client is exercised only when ``model`` exactly equals
|
Routing rule: any model id starting with one of ``local_prefixes``
|
||||||
``narrative_model`` (the configured remote model id). Everything
|
goes to the local backend (e.g. ``"mlx-community/"`` for models
|
||||||
else — classifier traffic, embeddings, any future locally-hosted
|
served by ``mlx-omni-server``). Everything else — narrative model,
|
||||||
model — routes to the local client. ``embed`` always routes locally
|
remote classifiers, anything on a hosted provider — routes to the
|
||||||
(the remote provider doesn't support it; see
|
remote backend.
|
||||||
|
|
||||||
|
``embed`` always routes locally (the remote provider doesn't
|
||||||
|
expose a working ``/v1/embeddings``; see
|
||||||
:class:`chat.llm.featherless.FeatherlessClient.embed`).
|
:class:`chat.llm.featherless.FeatherlessClient.embed`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -34,26 +56,94 @@ class RoutedLLMClient:
|
|||||||
narrative: LLMClient,
|
narrative: LLMClient,
|
||||||
local: LLMClient,
|
local: LLMClient,
|
||||||
narrative_model: str,
|
narrative_model: str,
|
||||||
|
classifier: LLMClient | None = None,
|
||||||
|
local_prefixes: tuple[str, ...] = ("mlx-community/",),
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# ``classifier`` is an optional separate backend for the
|
||||||
|
# classifier model. Useful when classifier and narrative both
|
||||||
|
# live on a remote OpenRouter-style provider but need different
|
||||||
|
# provider-pinning (e.g. Cerebras for the 8B classifier,
|
||||||
|
# default Friendli/etc. for the narrative). When ``classifier``
|
||||||
|
# is None, classifier traffic falls through to ``narrative``
|
||||||
|
# (the remote client) so old wiring keeps working.
|
||||||
self._narrative = narrative
|
self._narrative = narrative
|
||||||
|
self._classifier = classifier
|
||||||
self._local = local
|
self._local = local
|
||||||
self._narrative_model = narrative_model
|
self._narrative_model = narrative_model
|
||||||
|
self._local_prefixes = local_prefixes
|
||||||
|
|
||||||
def _pick(self, model: str) -> LLMClient:
|
def _pick(self, model: str) -> LLMClient:
|
||||||
return self._narrative if model == self._narrative_model else self._local
|
if any(model.startswith(p) for p in self._local_prefixes):
|
||||||
|
return self._local
|
||||||
|
if model == self._narrative_model:
|
||||||
|
return self._narrative
|
||||||
|
# Anything else (most importantly, the classifier model) goes
|
||||||
|
# to the classifier client when configured, otherwise to the
|
||||||
|
# narrative remote client.
|
||||||
|
return self._classifier or self._narrative
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self, messages: Sequence[Message], *, model: str, **params
|
self, messages: Sequence[Message], *, model: str, **params
|
||||||
) -> str:
|
) -> str:
|
||||||
return await self._pick(model).generate(messages, model=model, **params)
|
client = self._pick(model)
|
||||||
|
backend = (
|
||||||
|
"narrative" if client is self._narrative else
|
||||||
|
"classifier" if client is self._classifier else
|
||||||
|
"local"
|
||||||
|
)
|
||||||
|
if not _TIMING:
|
||||||
|
return await client.generate(messages, model=model, **params)
|
||||||
|
in_chars = sum(len(m.content) for m in messages)
|
||||||
|
_log.info("LLM generate START [%s] %s in_chars=%d", backend, model, in_chars)
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
return await client.generate(messages, model=model, **params)
|
||||||
|
finally:
|
||||||
|
_log.info(
|
||||||
|
"LLM generate END [%s] %s in_chars=%d %.2fs",
|
||||||
|
backend, model, in_chars, time.perf_counter() - t0,
|
||||||
|
)
|
||||||
|
|
||||||
async def stream(
|
async def stream(
|
||||||
self, messages: Sequence[Message], *, model: str, **params
|
self, messages: Sequence[Message], *, model: str, **params
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
async for chunk in self._pick(model).stream(messages, model=model, **params):
|
client = self._pick(model)
|
||||||
yield chunk
|
backend = (
|
||||||
|
"narrative" if client is self._narrative else
|
||||||
|
"classifier" if client is self._classifier else
|
||||||
|
"local"
|
||||||
|
)
|
||||||
|
if not _TIMING:
|
||||||
|
async for chunk in client.stream(messages, model=model, **params):
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ttft = None
|
||||||
|
chars_out = 0
|
||||||
|
try:
|
||||||
|
async for chunk in client.stream(messages, model=model, **params):
|
||||||
|
if ttft is None:
|
||||||
|
ttft = time.perf_counter() - t0
|
||||||
|
chars_out += len(chunk)
|
||||||
|
yield chunk
|
||||||
|
finally:
|
||||||
|
dt = time.perf_counter() - t0
|
||||||
|
in_chars = sum(len(m.content) for m in messages)
|
||||||
|
_log.info(
|
||||||
|
"LLM stream [%s] %s in_chars=%d out_chars=%d ttft=%.2fs total=%.2fs",
|
||||||
|
backend, model, in_chars, chars_out, ttft or 0.0, dt,
|
||||||
|
)
|
||||||
|
|
||||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||||
# Embeddings always run on the local backend — the remote
|
# Embeddings always run on the local backend — the remote
|
||||||
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
|
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
|
||||||
return await self._local.embed(text, model=model)
|
if not _TIMING:
|
||||||
|
return await self._local.embed(text, model=model)
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
return await self._local.embed(text, model=model)
|
||||||
|
finally:
|
||||||
|
_log.info(
|
||||||
|
"LLM embed [local] %s in_chars=%d %.2fs",
|
||||||
|
model, len(text), time.perf_counter() - t0,
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ Wraps single-pair compute_state_update to run state updates for ALL
|
|||||||
directed pairs of present entities. With 3 present entities (you, host,
|
directed pairs of present entities. With 3 present entities (you, host,
|
||||||
guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs.
|
guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs.
|
||||||
|
|
||||||
Calls run sequentially to respect Featherless's 2-connection cap (the
|
Pairs run concurrently via :func:`asyncio.gather`; the underlying
|
||||||
client-level semaphore would serialize them anyway, but doing it here
|
client should impose its own concurrency cap if the upstream provider
|
||||||
keeps the failure surface clean — a hung pair doesn't queue behind
|
needs it (e.g., Featherless's 2-conn semaphore). Returning order is
|
||||||
itself).
|
preserved (natural iteration over ``present_ids x present_ids``,
|
||||||
|
src != tgt) so downstream event-append order stays deterministic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from chat.llm.client import LLMClient
|
from chat.llm.client import LLMClient
|
||||||
from chat.services.state_update import StateUpdate, compute_state_update
|
from chat.services.state_update import StateUpdate, compute_state_update
|
||||||
@@ -28,35 +30,44 @@ async def compute_state_updates_for_present(
|
|||||||
timeout_s: float = 30.0,
|
timeout_s: float = 30.0,
|
||||||
) -> list[tuple[str, str, StateUpdate]]:
|
) -> list[tuple[str, str, StateUpdate]]:
|
||||||
"""Run compute_state_update for every directed pair (src != tgt) over
|
"""Run compute_state_update for every directed pair (src != tgt) over
|
||||||
``present_ids``. Returns list of ``(source_id, target_id, update)``
|
``present_ids``, concurrently. Returns list of
|
||||||
tuples in the natural iteration order over ``present_ids x present_ids``.
|
``(source_id, target_id, update)`` tuples in the natural iteration
|
||||||
|
order over ``present_ids x present_ids`` — concurrent dispatch does
|
||||||
|
not change the returned order.
|
||||||
|
|
||||||
A single failing pair falls back to the schema-default StateUpdate
|
A single failing pair falls back to the schema-default StateUpdate
|
||||||
(zero deltas, empty facts) inside ``compute_state_update``; the batch
|
(zero deltas, empty facts) inside ``compute_state_update``; sibling
|
||||||
keeps going.
|
pairs continue independently because each call is wrapped in its
|
||||||
|
own try/except inside ``compute_state_update``.
|
||||||
"""
|
"""
|
||||||
out: list[tuple[str, str, StateUpdate]] = []
|
pair_keys: list[tuple[str, str]] = [
|
||||||
for src in present_ids:
|
(src, tgt)
|
||||||
for tgt in present_ids:
|
for src in present_ids
|
||||||
if src == tgt:
|
for tgt in present_ids
|
||||||
continue
|
if src != tgt
|
||||||
edge = prior_edges.get((src, tgt), {})
|
]
|
||||||
update = await compute_state_update(
|
if not pair_keys:
|
||||||
client,
|
return []
|
||||||
model=classifier_model,
|
|
||||||
source_id=src,
|
async def _one(src: str, tgt: str) -> StateUpdate:
|
||||||
target_id=tgt,
|
edge = prior_edges.get((src, tgt), {})
|
||||||
source_name=present_names.get(src, src),
|
return await compute_state_update(
|
||||||
source_persona=personas.get(src, "") or "",
|
client,
|
||||||
target_name=present_names.get(tgt, tgt),
|
model=classifier_model,
|
||||||
prior_affinity=int(edge.get("affinity", 50)),
|
source_id=src,
|
||||||
prior_trust=int(edge.get("trust", 50)),
|
target_id=tgt,
|
||||||
prior_summary=edge.get("summary", "") or "",
|
source_name=present_names.get(src, src),
|
||||||
recent_dialogue=recent_dialogue,
|
source_persona=personas.get(src, "") or "",
|
||||||
timeout_s=timeout_s,
|
target_name=present_names.get(tgt, tgt),
|
||||||
)
|
prior_affinity=int(edge.get("affinity", 50)),
|
||||||
out.append((src, tgt, update))
|
prior_trust=int(edge.get("trust", 50)),
|
||||||
return out
|
prior_summary=edge.get("summary", "") or "",
|
||||||
|
recent_dialogue=recent_dialogue,
|
||||||
|
timeout_s=timeout_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
updates = await asyncio.gather(*(_one(src, tgt) for src, tgt in pair_keys))
|
||||||
|
return [(src, tgt, upd) for (src, tgt), upd in zip(pair_keys, updates)]
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["compute_state_updates_for_present"]
|
__all__ = ["compute_state_updates_for_present"]
|
||||||
|
|||||||
@@ -50,9 +50,24 @@ def get_llm_client(request: Request) -> LLMClient:
|
|||||||
api_key=settings.featherless_api_key,
|
api_key=settings.featherless_api_key,
|
||||||
base_url=settings.featherless_base_url,
|
base_url=settings.featherless_base_url,
|
||||||
)
|
)
|
||||||
|
# Dedicated classifier client when a provider pin is configured —
|
||||||
|
# routes Llama-3.1-8B (or whatever ``classifier_model`` is) onto a
|
||||||
|
# specific upstream like Cerebras for ~10x throughput. When the
|
||||||
|
# pin is empty, ``classifier`` is None and the router falls back
|
||||||
|
# to the narrative client for classifier traffic.
|
||||||
|
classifier = None
|
||||||
|
if settings.classifier_provider_order:
|
||||||
|
classifier = FeatherlessClient(
|
||||||
|
api_key=settings.featherless_api_key,
|
||||||
|
base_url=settings.featherless_base_url,
|
||||||
|
default_extra_body={
|
||||||
|
"provider": {"order": list(settings.classifier_provider_order)}
|
||||||
|
},
|
||||||
|
)
|
||||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||||
return RoutedLLMClient(
|
return RoutedLLMClient(
|
||||||
narrative=narrative,
|
narrative=narrative,
|
||||||
|
classifier=classifier,
|
||||||
local=local,
|
local=local,
|
||||||
narrative_model=settings.narrative_model,
|
narrative_model=settings.narrative_model,
|
||||||
)
|
)
|
||||||
|
|||||||
+32
-18
@@ -36,58 +36,72 @@ class _StubClient:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_router_generate_dispatches_narrative_to_narrative_backend():
|
async def test_router_generate_routes_remote_model_to_remote_backend():
|
||||||
|
"""Any model id NOT starting with a local prefix goes to the remote
|
||||||
|
backend — narrative model, remote classifiers, anything else."""
|
||||||
narrative = _StubClient("narrative")
|
narrative = _StubClient("narrative")
|
||||||
local = _StubClient("local")
|
local = _StubClient("local")
|
||||||
router = RoutedLLMClient(
|
router = RoutedLLMClient(
|
||||||
narrative=narrative,
|
narrative=narrative,
|
||||||
local=local,
|
local=local,
|
||||||
narrative_model="big-model",
|
narrative_model="provider/big-model",
|
||||||
|
local_prefixes=("mlx-community/",),
|
||||||
)
|
)
|
||||||
|
|
||||||
out = await router.generate([Message(role="user", content="hi")], model="big-model")
|
out = await router.generate(
|
||||||
|
[Message(role="user", content="hi")], model="provider/big-model"
|
||||||
|
)
|
||||||
|
|
||||||
assert out == "narrative:big-model"
|
assert out == "narrative:provider/big-model"
|
||||||
assert narrative.generate_calls == ["big-model"]
|
assert narrative.generate_calls == ["provider/big-model"]
|
||||||
assert local.generate_calls == []
|
assert local.generate_calls == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_router_generate_dispatches_classifier_to_local_backend():
|
async def test_router_generate_routes_local_prefix_to_local_backend():
|
||||||
|
"""Models prefixed with a local prefix (e.g. ``mlx-community/``)
|
||||||
|
go to the local MLX backend regardless of whether the rest of the
|
||||||
|
path looks like a remote provider id."""
|
||||||
narrative = _StubClient("narrative")
|
narrative = _StubClient("narrative")
|
||||||
local = _StubClient("local")
|
local = _StubClient("local")
|
||||||
router = RoutedLLMClient(
|
router = RoutedLLMClient(
|
||||||
narrative=narrative,
|
narrative=narrative,
|
||||||
local=local,
|
local=local,
|
||||||
narrative_model="big-model",
|
narrative_model="provider/big-model",
|
||||||
|
local_prefixes=("mlx-community/",),
|
||||||
)
|
)
|
||||||
|
|
||||||
out = await router.generate(
|
out = await router.generate(
|
||||||
[Message(role="user", content="hi")], model="small-model"
|
[Message(role="user", content="hi")],
|
||||||
|
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert out == "local:small-model"
|
assert out == "local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"
|
||||||
assert local.generate_calls == ["small-model"]
|
assert local.generate_calls == ["mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
|
||||||
assert narrative.generate_calls == []
|
assert narrative.generate_calls == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_router_stream_dispatches_by_model():
|
async def test_router_stream_dispatches_by_prefix():
|
||||||
narrative = _StubClient("narrative")
|
narrative = _StubClient("narrative")
|
||||||
local = _StubClient("local")
|
local = _StubClient("local")
|
||||||
router = RoutedLLMClient(
|
router = RoutedLLMClient(
|
||||||
narrative=narrative, local=local, narrative_model="big-model"
|
narrative=narrative,
|
||||||
|
local=local,
|
||||||
|
narrative_model="provider/big-model",
|
||||||
|
local_prefixes=("mlx-community/",),
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks_big = [c async for c in router.stream(
|
chunks_remote = [c async for c in router.stream(
|
||||||
[Message(role="user", content="hi")], model="big-model"
|
[Message(role="user", content="hi")], model="provider/big-model"
|
||||||
)]
|
)]
|
||||||
chunks_small = [c async for c in router.stream(
|
chunks_local = [c async for c in router.stream(
|
||||||
[Message(role="user", content="hi")], model="other-model"
|
[Message(role="user", content="hi")],
|
||||||
|
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
|
||||||
)]
|
)]
|
||||||
|
|
||||||
assert chunks_big == ["narrative:big-model"]
|
assert chunks_remote == ["narrative:provider/big-model"]
|
||||||
assert chunks_small == ["local:other-model"]
|
assert chunks_local == ["local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user