perf: 18s/turn -> 2.5s/turn (SQLite busy_timeout, parallel state pairs, OpenRouter Cerebras-pinned classifier)
Four changes that compound: 1) **SQLite busy_timeout 5.0s -> 0.1s** in chat/db/connection.py. Root cause of the bulk of the slowness. The embedding worker contends for the WAL write lock while the request handler holds an open transaction; conn.execute's busy-wait does NOT release the GIL, so every state_update LLM call after the narrative was silently freezing the asyncio event loop for ~5s. With 0.1s the worker fails fast and logs (already handled), the chat keeps moving, and any missed embedding can be backfilled out of band. Also takes the test suite from ~290s -> 13s as a bonus. 2) **Parallel state-update pairs** in multi_state_update.py. Each directed (src, tgt) pair becomes a coroutine in asyncio.gather instead of a sequential for-loop. Returned order is preserved. 3) **Classifier on OpenRouter, provider-pinned to Cerebras**. New prefix-based router: model id with mlx-community/ -> local MLX, model == narrative_model -> narrative remote, else -> classifier remote. Settings.classifier_provider_order populates extra_body for the classifier client only (FeatherlessClient now accepts default_extra_body to merge into every chat.completions.create). Llama-3.1-8B on Cerebras runs at ~423 tok/s, ~10x the default provider. narrative still routes to mistral-nemo:nitro (Friendli). 4) **Cap classify max_tokens at 512**. A misbehaving classifier (response_format=json_object ignored) could otherwise generate thousands of tokens of prose before classify's JSON validation trips the retry. 512 is generous; usual completions are 50-150. CHAT_LLM_TIMING=1 env var enables per-call timing logs on stderr; zero overhead when unset. Useful for finding the slow link. Suite: 464 passed in 13s (was 290s).
This commit is contained in:
+12
@@ -88,9 +88,21 @@ async def lifespan(app: FastAPI):
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
classifier = None
|
||||
if settings.classifier_provider_order:
|
||||
classifier = FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
default_extra_body={
|
||||
"provider": {
|
||||
"order": list(settings.classifier_provider_order)
|
||||
}
|
||||
},
|
||||
)
|
||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||
return RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
classifier=classifier,
|
||||
local=local,
|
||||
narrative_model=settings.narrative_model,
|
||||
)
|
||||
|
||||
+12
-8
@@ -44,16 +44,20 @@ class Settings(BaseModel):
|
||||
bind_host: str = "127.0.0.1"
|
||||
bind_port: int = 8000
|
||||
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
|
||||
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
|
||||
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
|
||||
# -> Featherless, otherwise local. ``embed()`` always routes local.
|
||||
# If no MLX server is running and ``classifier_model`` is set to a
|
||||
# local model id, classifier calls will fail; ``embed()`` will fall
|
||||
# back to the zero-vector path with a T107 warning. To use the
|
||||
# default Featherless-only deployment, leave ``classifier_model``
|
||||
# at its remote default and ``embedding_model`` at the pseudo.
|
||||
# whose id starts with one of ``local_prefixes`` (default
|
||||
# ``"mlx-community/"``). The :class:`RoutedLLMClient` inspects the
|
||||
# ``model`` kwarg at call time: local-prefix -> local, else -> remote.
|
||||
# ``embed()`` always routes local.
|
||||
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
|
||||
local_mlx_max_concurrent: int = 1
|
||||
# Optional OpenRouter-style provider pinning for the classifier
|
||||
# client. Maps to the ``provider`` field on chat.completions.create
|
||||
# via ``extra_body``; the FeatherlessClient (which is just an
|
||||
# AsyncOpenAI wrapper) merges it into every call. Useful for forcing
|
||||
# Llama-3.1-8B classifier traffic onto Cerebras (~423 tok/s, 10x
|
||||
# the default Nebius). Empty list = no pin (provider is
|
||||
# OpenRouter's choice).
|
||||
classifier_provider_order: list[str] = Field(default_factory=list)
|
||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||
# deterministic local pseudo so fresh installs / tests don't need
|
||||
# any external infra. Override via config.toml to a real model id
|
||||
|
||||
+14
-1
@@ -7,7 +7,20 @@ from pathlib import Path
|
||||
@contextmanager
|
||||
def open_db(path: Path, *, check_same_thread: bool = True):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(path, check_same_thread=check_same_thread)
|
||||
# ``timeout`` here sets SQLite's busy_timeout, in seconds: how long
|
||||
# ``conn.execute`` blocks when another connection holds the WAL
|
||||
# write lock. The Python default is 5.0, which is fatal for the
|
||||
# async chat app: ``conn.execute``'s busy-wait does NOT release the
|
||||
# GIL, so a contending background worker (e.g. the embedding worker
|
||||
# writing ``embedding_indexed`` while the request handler holds an
|
||||
# open transaction) freezes the whole asyncio event loop for up to
|
||||
# 5 seconds — silently turning every concurrent LLM call into a 5s
|
||||
# wall-clock hit. 0.1s lets contending writers fail fast; callers
|
||||
# that need durability should retry, and the embedding worker
|
||||
# already logs failures so a missed embedding can be backfilled.
|
||||
conn = sqlite3.connect(
|
||||
path, check_same_thread=check_same_thread, timeout=0.1
|
||||
)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
try:
|
||||
|
||||
+12
-1
@@ -31,6 +31,7 @@ async def classify(
|
||||
schema: type[T],
|
||||
default: T | None = None,
|
||||
timeout_s: float = 10.0,
|
||||
max_tokens: int = 512,
|
||||
) -> T:
|
||||
schema_json = json.dumps(schema.model_json_schema(), indent=2)
|
||||
schema_block = (
|
||||
@@ -41,10 +42,20 @@ async def classify(
|
||||
Message(role="system", content=system + schema_block),
|
||||
Message(role="user", content=user),
|
||||
]
|
||||
# Cap output length so a misbehaving model (e.g. one that ignores
|
||||
# ``response_format=json_object`` and generates prose) can't burn
|
||||
# several seconds on tokens we'll never use. Classifier responses
|
||||
# are small JSON objects — 512 tokens is generous; usual completions
|
||||
# are 50-150.
|
||||
for attempt in range(3):
|
||||
try:
|
||||
text = await asyncio.wait_for(
|
||||
client.generate(msgs, model=model, response_format={"type": "json_object"}),
|
||||
client.generate(
|
||||
msgs,
|
||||
model=model,
|
||||
response_format={"type": "json_object"},
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
cleaned = _strip_json_fences(text)
|
||||
|
||||
+42
-1
@@ -29,19 +29,60 @@ class FeatherlessClient:
|
||||
cls._semaphore = asyncio.Semaphore(2)
|
||||
return cls._semaphore
|
||||
|
||||
def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.featherless.ai/v1",
|
||||
*,
|
||||
default_extra_body: dict | None = None,
|
||||
):
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
# ``default_extra_body`` is merged into every chat.completions.create
|
||||
# call's ``extra_body``. Useful with OpenRouter to pin specific
|
||||
# upstream providers (e.g. ``{"provider": {"order": ["Cerebras"]}}``
|
||||
# for 10x throughput on Llama-3.1-8B). Featherless ignores the
|
||||
# field, so it's safe to leave set even when ``base_url`` points
|
||||
# back at Featherless.
|
||||
self._default_extra_body = default_extra_body or {}
|
||||
|
||||
def _merge_extra_body(self, params: dict) -> dict:
|
||||
if not self._default_extra_body:
|
||||
return params
|
||||
eb = dict(self._default_extra_body)
|
||||
eb.update(params.pop("extra_body", {}) or {})
|
||||
params["extra_body"] = eb
|
||||
return params
|
||||
|
||||
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
||||
params = self._merge_extra_body(dict(params))
|
||||
async with self._sem():
|
||||
resp = await self._client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||
**params,
|
||||
)
|
||||
# Diagnostic: stash provider+usage on a side-channel for the
|
||||
# router timing log to pick up. OpenRouter sticks a 'provider'
|
||||
# field on the response (not part of the OAI spec, but the
|
||||
# SDK passes it through on its model dict).
|
||||
try: # pragma: no cover — diagnostic only
|
||||
import os as _os
|
||||
if _os.environ.get("CHAT_LLM_TIMING") == "1":
|
||||
prov = getattr(resp, "provider", None)
|
||||
usage = getattr(resp, "usage", None)
|
||||
ct = getattr(usage, "completion_tokens", "?") if usage else "?"
|
||||
pt = getattr(usage, "prompt_tokens", "?") if usage else "?"
|
||||
import logging as _logging
|
||||
_logging.getLogger("chat.llm.router").info(
|
||||
" ↪ provider=%s prompt_toks=%s completion_toks=%s",
|
||||
prov, pt, ct,
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
return resp.choices[0].message.content or ""
|
||||
|
||||
async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
|
||||
params = self._merge_extra_body(dict(params))
|
||||
async with self._sem():
|
||||
stream = await self._client.chat.completions.create(
|
||||
model=model,
|
||||
|
||||
+100
-10
@@ -9,22 +9,44 @@ Routing rule: requests whose ``model`` argument matches the configured
|
||||
``narrative_model`` go to the narrative backend; everything else
|
||||
(classifier, embeddings, future locally-hosted models) goes to the
|
||||
local backend.
|
||||
|
||||
Set the env var ``CHAT_LLM_TIMING=1`` to log per-call timing at INFO
|
||||
level. Useful for finding the slow link in a turn.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncIterator, Sequence
|
||||
|
||||
from .client import LLMClient, Message
|
||||
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
_TIMING = os.environ.get("CHAT_LLM_TIMING") == "1"
|
||||
if _TIMING and not _log.handlers:
|
||||
# Wire a stderr handler when timing is enabled so the per-call
|
||||
# logs show up under uvicorn (which doesn't configure non-uvicorn
|
||||
# loggers by default).
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
|
||||
_log.addHandler(_h)
|
||||
_log.setLevel(logging.INFO)
|
||||
_log.propagate = False
|
||||
|
||||
|
||||
class RoutedLLMClient:
|
||||
"""Delegates to one of two underlying clients based on ``model``.
|
||||
|
||||
The narrative client is exercised only when ``model`` exactly equals
|
||||
``narrative_model`` (the configured remote model id). Everything
|
||||
else — classifier traffic, embeddings, any future locally-hosted
|
||||
model — routes to the local client. ``embed`` always routes locally
|
||||
(the remote provider doesn't support it; see
|
||||
Routing rule: any model id starting with one of ``local_prefixes``
|
||||
goes to the local backend (e.g. ``"mlx-community/"`` for models
|
||||
served by ``mlx-omni-server``). Everything else — narrative model,
|
||||
remote classifiers, anything on a hosted provider — routes to the
|
||||
remote backend.
|
||||
|
||||
``embed`` always routes locally (the remote provider doesn't
|
||||
expose a working ``/v1/embeddings``; see
|
||||
:class:`chat.llm.featherless.FeatherlessClient.embed`).
|
||||
"""
|
||||
|
||||
@@ -34,26 +56,94 @@ class RoutedLLMClient:
|
||||
narrative: LLMClient,
|
||||
local: LLMClient,
|
||||
narrative_model: str,
|
||||
classifier: LLMClient | None = None,
|
||||
local_prefixes: tuple[str, ...] = ("mlx-community/",),
|
||||
) -> None:
|
||||
# ``classifier`` is an optional separate backend for the
|
||||
# classifier model. Useful when classifier and narrative both
|
||||
# live on a remote OpenRouter-style provider but need different
|
||||
# provider-pinning (e.g. Cerebras for the 8B classifier,
|
||||
# default Friendli/etc. for the narrative). When ``classifier``
|
||||
# is None, classifier traffic falls through to ``narrative``
|
||||
# (the remote client) so old wiring keeps working.
|
||||
self._narrative = narrative
|
||||
self._classifier = classifier
|
||||
self._local = local
|
||||
self._narrative_model = narrative_model
|
||||
self._local_prefixes = local_prefixes
|
||||
|
||||
def _pick(self, model: str) -> LLMClient:
|
||||
return self._narrative if model == self._narrative_model else self._local
|
||||
if any(model.startswith(p) for p in self._local_prefixes):
|
||||
return self._local
|
||||
if model == self._narrative_model:
|
||||
return self._narrative
|
||||
# Anything else (most importantly, the classifier model) goes
|
||||
# to the classifier client when configured, otherwise to the
|
||||
# narrative remote client.
|
||||
return self._classifier or self._narrative
|
||||
|
||||
async def generate(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> str:
|
||||
return await self._pick(model).generate(messages, model=model, **params)
|
||||
client = self._pick(model)
|
||||
backend = (
|
||||
"narrative" if client is self._narrative else
|
||||
"classifier" if client is self._classifier else
|
||||
"local"
|
||||
)
|
||||
if not _TIMING:
|
||||
return await client.generate(messages, model=model, **params)
|
||||
in_chars = sum(len(m.content) for m in messages)
|
||||
_log.info("LLM generate START [%s] %s in_chars=%d", backend, model, in_chars)
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
return await client.generate(messages, model=model, **params)
|
||||
finally:
|
||||
_log.info(
|
||||
"LLM generate END [%s] %s in_chars=%d %.2fs",
|
||||
backend, model, in_chars, time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
async def stream(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> AsyncIterator[str]:
|
||||
async for chunk in self._pick(model).stream(messages, model=model, **params):
|
||||
yield chunk
|
||||
client = self._pick(model)
|
||||
backend = (
|
||||
"narrative" if client is self._narrative else
|
||||
"classifier" if client is self._classifier else
|
||||
"local"
|
||||
)
|
||||
if not _TIMING:
|
||||
async for chunk in client.stream(messages, model=model, **params):
|
||||
yield chunk
|
||||
return
|
||||
t0 = time.perf_counter()
|
||||
ttft = None
|
||||
chars_out = 0
|
||||
try:
|
||||
async for chunk in client.stream(messages, model=model, **params):
|
||||
if ttft is None:
|
||||
ttft = time.perf_counter() - t0
|
||||
chars_out += len(chunk)
|
||||
yield chunk
|
||||
finally:
|
||||
dt = time.perf_counter() - t0
|
||||
in_chars = sum(len(m.content) for m in messages)
|
||||
_log.info(
|
||||
"LLM stream [%s] %s in_chars=%d out_chars=%d ttft=%.2fs total=%.2fs",
|
||||
backend, model, in_chars, chars_out, ttft or 0.0, dt,
|
||||
)
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
# Embeddings always run on the local backend — the remote
|
||||
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
|
||||
return await self._local.embed(text, model=model)
|
||||
if not _TIMING:
|
||||
return await self._local.embed(text, model=model)
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
return await self._local.embed(text, model=model)
|
||||
finally:
|
||||
_log.info(
|
||||
"LLM embed [local] %s in_chars=%d %.2fs",
|
||||
model, len(text), time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
@@ -4,13 +4,15 @@ Wraps single-pair compute_state_update to run state updates for ALL
|
||||
directed pairs of present entities. With 3 present entities (you, host,
|
||||
guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs.
|
||||
|
||||
Calls run sequentially to respect Featherless's 2-connection cap (the
|
||||
client-level semaphore would serialize them anyway, but doing it here
|
||||
keeps the failure surface clean — a hung pair doesn't queue behind
|
||||
itself).
|
||||
Pairs run concurrently via :func:`asyncio.gather`; the underlying
|
||||
client should impose its own concurrency cap if the upstream provider
|
||||
needs it (e.g., Featherless's 2-conn semaphore). Returning order is
|
||||
preserved (natural iteration over ``present_ids x present_ids``,
|
||||
src != tgt) so downstream event-append order stays deterministic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
|
||||
from chat.llm.client import LLMClient
|
||||
from chat.services.state_update import StateUpdate, compute_state_update
|
||||
@@ -28,35 +30,44 @@ async def compute_state_updates_for_present(
|
||||
timeout_s: float = 30.0,
|
||||
) -> list[tuple[str, str, StateUpdate]]:
|
||||
"""Run compute_state_update for every directed pair (src != tgt) over
|
||||
``present_ids``. Returns list of ``(source_id, target_id, update)``
|
||||
tuples in the natural iteration order over ``present_ids x present_ids``.
|
||||
``present_ids``, concurrently. Returns list of
|
||||
``(source_id, target_id, update)`` tuples in the natural iteration
|
||||
order over ``present_ids x present_ids`` — concurrent dispatch does
|
||||
not change the returned order.
|
||||
|
||||
A single failing pair falls back to the schema-default StateUpdate
|
||||
(zero deltas, empty facts) inside ``compute_state_update``; the batch
|
||||
keeps going.
|
||||
(zero deltas, empty facts) inside ``compute_state_update``; sibling
|
||||
pairs continue independently because each call is wrapped in its
|
||||
own try/except inside ``compute_state_update``.
|
||||
"""
|
||||
out: list[tuple[str, str, StateUpdate]] = []
|
||||
for src in present_ids:
|
||||
for tgt in present_ids:
|
||||
if src == tgt:
|
||||
continue
|
||||
edge = prior_edges.get((src, tgt), {})
|
||||
update = await compute_state_update(
|
||||
client,
|
||||
model=classifier_model,
|
||||
source_id=src,
|
||||
target_id=tgt,
|
||||
source_name=present_names.get(src, src),
|
||||
source_persona=personas.get(src, "") or "",
|
||||
target_name=present_names.get(tgt, tgt),
|
||||
prior_affinity=int(edge.get("affinity", 50)),
|
||||
prior_trust=int(edge.get("trust", 50)),
|
||||
prior_summary=edge.get("summary", "") or "",
|
||||
recent_dialogue=recent_dialogue,
|
||||
timeout_s=timeout_s,
|
||||
)
|
||||
out.append((src, tgt, update))
|
||||
return out
|
||||
pair_keys: list[tuple[str, str]] = [
|
||||
(src, tgt)
|
||||
for src in present_ids
|
||||
for tgt in present_ids
|
||||
if src != tgt
|
||||
]
|
||||
if not pair_keys:
|
||||
return []
|
||||
|
||||
async def _one(src: str, tgt: str) -> StateUpdate:
|
||||
edge = prior_edges.get((src, tgt), {})
|
||||
return await compute_state_update(
|
||||
client,
|
||||
model=classifier_model,
|
||||
source_id=src,
|
||||
target_id=tgt,
|
||||
source_name=present_names.get(src, src),
|
||||
source_persona=personas.get(src, "") or "",
|
||||
target_name=present_names.get(tgt, tgt),
|
||||
prior_affinity=int(edge.get("affinity", 50)),
|
||||
prior_trust=int(edge.get("trust", 50)),
|
||||
prior_summary=edge.get("summary", "") or "",
|
||||
recent_dialogue=recent_dialogue,
|
||||
timeout_s=timeout_s,
|
||||
)
|
||||
|
||||
updates = await asyncio.gather(*(_one(src, tgt) for src, tgt in pair_keys))
|
||||
return [(src, tgt, upd) for (src, tgt), upd in zip(pair_keys, updates)]
|
||||
|
||||
|
||||
__all__ = ["compute_state_updates_for_present"]
|
||||
|
||||
@@ -50,9 +50,24 @@ def get_llm_client(request: Request) -> LLMClient:
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
# Dedicated classifier client when a provider pin is configured —
|
||||
# routes Llama-3.1-8B (or whatever ``classifier_model`` is) onto a
|
||||
# specific upstream like Cerebras for ~10x throughput. When the
|
||||
# pin is empty, ``classifier`` is None and the router falls back
|
||||
# to the narrative client for classifier traffic.
|
||||
classifier = None
|
||||
if settings.classifier_provider_order:
|
||||
classifier = FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
default_extra_body={
|
||||
"provider": {"order": list(settings.classifier_provider_order)}
|
||||
},
|
||||
)
|
||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||
return RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
classifier=classifier,
|
||||
local=local,
|
||||
narrative_model=settings.narrative_model,
|
||||
)
|
||||
|
||||
+32
-18
@@ -36,58 +36,72 @@ class _StubClient:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_generate_dispatches_narrative_to_narrative_backend():
|
||||
async def test_router_generate_routes_remote_model_to_remote_backend():
|
||||
"""Any model id NOT starting with a local prefix goes to the remote
|
||||
backend — narrative model, remote classifiers, anything else."""
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model="big-model",
|
||||
narrative_model="provider/big-model",
|
||||
local_prefixes=("mlx-community/",),
|
||||
)
|
||||
|
||||
out = await router.generate([Message(role="user", content="hi")], model="big-model")
|
||||
out = await router.generate(
|
||||
[Message(role="user", content="hi")], model="provider/big-model"
|
||||
)
|
||||
|
||||
assert out == "narrative:big-model"
|
||||
assert narrative.generate_calls == ["big-model"]
|
||||
assert out == "narrative:provider/big-model"
|
||||
assert narrative.generate_calls == ["provider/big-model"]
|
||||
assert local.generate_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_generate_dispatches_classifier_to_local_backend():
|
||||
async def test_router_generate_routes_local_prefix_to_local_backend():
|
||||
"""Models prefixed with a local prefix (e.g. ``mlx-community/``)
|
||||
go to the local MLX backend regardless of whether the rest of the
|
||||
path looks like a remote provider id."""
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model="big-model",
|
||||
narrative_model="provider/big-model",
|
||||
local_prefixes=("mlx-community/",),
|
||||
)
|
||||
|
||||
out = await router.generate(
|
||||
[Message(role="user", content="hi")], model="small-model"
|
||||
[Message(role="user", content="hi")],
|
||||
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
|
||||
)
|
||||
|
||||
assert out == "local:small-model"
|
||||
assert local.generate_calls == ["small-model"]
|
||||
assert out == "local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"
|
||||
assert local.generate_calls == ["mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
|
||||
assert narrative.generate_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_stream_dispatches_by_model():
|
||||
async def test_router_stream_dispatches_by_prefix():
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative, local=local, narrative_model="big-model"
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model="provider/big-model",
|
||||
local_prefixes=("mlx-community/",),
|
||||
)
|
||||
|
||||
chunks_big = [c async for c in router.stream(
|
||||
[Message(role="user", content="hi")], model="big-model"
|
||||
chunks_remote = [c async for c in router.stream(
|
||||
[Message(role="user", content="hi")], model="provider/big-model"
|
||||
)]
|
||||
chunks_small = [c async for c in router.stream(
|
||||
[Message(role="user", content="hi")], model="other-model"
|
||||
chunks_local = [c async for c in router.stream(
|
||||
[Message(role="user", content="hi")],
|
||||
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
|
||||
)]
|
||||
|
||||
assert chunks_big == ["narrative:big-model"]
|
||||
assert chunks_small == ["local:other-model"]
|
||||
assert chunks_remote == ["narrative:provider/big-model"]
|
||||
assert chunks_local == ["local:mlx-community/Hermes-3-Llama-3.1-8B-8bit"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user