feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless
Adds RoutedLLMClient that dispatches by model name: requests matching Settings.narrative_model go to Featherless, everything else (classifier calls, embed) goes to a local MLX server. The local server is mlx-omni-server (separate venv at .mlx-venv) and exposes the standard OpenAI surface at http://127.0.0.1:10240/v1. LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood) but with a working embed() — Featherless's /v1/embeddings always returns 500 with completions_error, so the router unconditionally sends embed traffic to the local backend. Production deployment overrides via data/config.toml: - classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB) - embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB, 384 dim — matches existing schema, no migration) Defaults stay remote / pseudo so fresh installs and tests need no external infra. Smoke-tested live: classifier returns expected output, BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer to cat-on-rug than to quantum-mechanics). scripts/start_mlx_server.sh starts the daemon (foreground or --daemon). .mlx-venv/ added to .gitignore. Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
This commit is contained in:
@@ -5,6 +5,7 @@ data/
|
|||||||
|
|
||||||
# Python
|
# Python
|
||||||
.venv/
|
.venv/
|
||||||
|
.mlx-venv/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
|
|||||||
+15
-4
@@ -72,17 +72,28 @@ async def lifespan(app: FastAPI):
|
|||||||
# (free / lower paid tiers cap at 2). Shared across all
|
# (free / lower paid tiers cap at 2). Shared across all
|
||||||
# FeatherlessClient instances in the process.
|
# FeatherlessClient instances in the process.
|
||||||
from chat.llm.featherless import FeatherlessClient
|
from chat.llm.featherless import FeatherlessClient
|
||||||
|
from chat.llm.local_mlx import LocalMLXClient
|
||||||
|
from chat.llm.router import RoutedLLMClient
|
||||||
|
|
||||||
FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
|
FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
|
||||||
|
LocalMLXClient.configure_concurrency(settings.local_mlx_max_concurrent)
|
||||||
|
|
||||||
# Background worker for the async significance pass (T22). Each job
|
# Background workers (significance scoring, embedding indexer)
|
||||||
# constructs a fresh FeatherlessClient via the factory; tests can
|
# construct a fresh client per job via the factory. Workers route
|
||||||
# disable enqueue by toggling ``app.state.background_worker.enabled``.
|
# through the same RoutedLLMClient as request-time traffic so the
|
||||||
|
# narrative model still goes to Featherless and the classifier +
|
||||||
|
# embeddings hit the local MLX server.
|
||||||
def _factory():
|
def _factory():
|
||||||
return FeatherlessClient(
|
narrative = FeatherlessClient(
|
||||||
api_key=settings.featherless_api_key,
|
api_key=settings.featherless_api_key,
|
||||||
base_url=settings.featherless_base_url,
|
base_url=settings.featherless_base_url,
|
||||||
)
|
)
|
||||||
|
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||||
|
return RoutedLLMClient(
|
||||||
|
narrative=narrative,
|
||||||
|
local=local,
|
||||||
|
narrative_model=settings.narrative_model,
|
||||||
|
)
|
||||||
|
|
||||||
worker = BackgroundWorker(settings, llm_client_factory=_factory)
|
worker = BackgroundWorker(settings, llm_client_factory=_factory)
|
||||||
await worker.start()
|
await worker.start()
|
||||||
|
|||||||
+15
-6
@@ -39,13 +39,22 @@ class Settings(BaseModel):
|
|||||||
data_dir: Path = REPO_ROOT / "data"
|
data_dir: Path = REPO_ROOT / "data"
|
||||||
bind_host: str = "127.0.0.1"
|
bind_host: str = "127.0.0.1"
|
||||||
bind_port: int = 8000
|
bind_port: int = 8000
|
||||||
|
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
|
||||||
|
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
|
||||||
|
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
|
||||||
|
# -> Featherless, otherwise local. ``embed()`` always routes local.
|
||||||
|
# If no MLX server is running and ``classifier_model`` is set to a
|
||||||
|
# local model id, classifier calls will fail; ``embed()`` will fall
|
||||||
|
# back to the zero-vector path with a T107 warning. To use the
|
||||||
|
# default Featherless-only deployment, leave ``classifier_model``
|
||||||
|
# at its remote default and ``embedding_model`` at the pseudo.
|
||||||
|
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
|
||||||
|
local_mlx_max_concurrent: int = 1
|
||||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||||
# deterministic local pseudo (semantically meaningless but keeps the
|
# deterministic local pseudo so fresh installs / tests don't need
|
||||||
# vector pipeline structurally valid). Swap to a real model name
|
# any external infra. Override via config.toml to a real model id
|
||||||
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
|
# (e.g. ``"mlx-community/bge-small-en-v1.5-bf16"``) once a local
|
||||||
# supports embed() — currently FeatherlessClient does NOT, so a
|
# MLX server is running.
|
||||||
# non-default value will trigger the zero-vector fallback path
|
|
||||||
# plus a T107 warning until a different provider is wired in.
|
|
||||||
embedding_model: str = "pseudo-sha256-384"
|
embedding_model: str = "pseudo-sha256-384"
|
||||||
|
|
||||||
def load_settings() -> Settings:
|
def load_settings() -> Settings:
|
||||||
|
|||||||
@@ -0,0 +1,95 @@
|
|||||||
|
"""Local MLX OpenAI-compatible client.
|
||||||
|
|
||||||
|
Talks to a locally-running MLX server (e.g., ``mlx-omni-server``) over
|
||||||
|
the same OpenAI surface that :class:`chat.llm.featherless.FeatherlessClient`
|
||||||
|
uses, via :class:`openai.AsyncOpenAI`. The underlying server runs MLX
|
||||||
|
models on Apple Silicon (M-series) for chat completions AND embeddings.
|
||||||
|
|
||||||
|
Use cases (Phase 4.5+):
|
||||||
|
- Classifier traffic moved off Featherless to local MLX (cost + latency).
|
||||||
|
- Embeddings via ``client.embed`` actually work — Featherless's
|
||||||
|
``/v1/embeddings`` always returns 500.
|
||||||
|
|
||||||
|
Constructor takes a ``base_url`` (e.g., ``"http://127.0.0.1:10240/v1"``)
|
||||||
|
and an optional ``api_key`` (most local MLX servers don't authenticate;
|
||||||
|
the OpenAI SDK requires *some* string, so we default to a placeholder).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from .client import Message
|
||||||
|
|
||||||
|
|
||||||
|
class LocalMLXClient:
    """OpenAI-compatible client for a local MLX server.

    The server is single-process by default (``mlx-omni-server`` loads
    one model at a time and swaps on demand). A class-level semaphore,
    shared by every instance of this class, caps the number of requests
    in flight at ``max_concurrent`` — 1 unless
    :meth:`configure_concurrency` is called, since MLX inference on a
    single M-series device is sequential anyway.
    """

    # Shared by all instances. Created explicitly by
    # ``configure_concurrency`` or lazily (limit 1) by ``_sem``.
    _semaphore: asyncio.Semaphore | None = None

    @classmethod
    def configure_concurrency(cls, max_concurrent: int) -> None:
        """Install a fresh shared semaphore admitting ``max_concurrent`` requests.

        The value is coerced with ``int()`` and clamped to a minimum of 1.
        """
        cls._semaphore = asyncio.Semaphore(max(1, int(max_concurrent)))

    @classmethod
    def _sem(cls) -> asyncio.Semaphore:
        """Return the shared semaphore, creating a limit-1 one on first use."""
        if cls._semaphore is None:
            cls._semaphore = asyncio.Semaphore(1)
        return cls._semaphore

    def __init__(
        self,
        base_url: str = "http://127.0.0.1:10240/v1",
        api_key: str = "not-needed",
    ):
        """Build the underlying ``AsyncOpenAI`` client.

        ``api_key`` defaults to a placeholder string; it is passed to
        the SDK as-is.
        """
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    async def generate(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> str:
        """Run one non-streaming chat completion and return its text.

        ``messages`` are converted to ``{"role", "content"}`` dicts;
        ``params`` are forwarded to ``chat.completions.create``
        unchanged. A ``None`` message content is returned as ``""``.
        """
        async with self._sem():
            resp = await self._client.chat.completions.create(
                model=model,
                messages=[{"role": m.role, "content": m.content} for m in messages],
                **params,
            )
            return resp.choices[0].message.content or ""

    async def stream(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> AsyncIterator[str]:
        """Yield the non-empty text deltas of a streaming chat completion.

        The shared semaphore is held for the whole iteration, so the
        slot is released only once the stream has been consumed (or the
        generator is closed).
        """
        async with self._sem():
            stream = await self._client.chat.completions.create(
                model=model,
                messages=[{"role": m.role, "content": m.content} for m in messages],
                stream=True,
                **params,
            )
            async for chunk in stream:
                # Some OpenAI-compatible servers emit chunks with an
                # empty ``choices`` list (e.g. a usage-only chunk);
                # indexing [0] on those would raise IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content or ""
                if delta:
                    yield delta

    async def embed(self, text: str, *, model: str) -> list[float]:
        """Return an embedding vector for ``text`` using the named model.

        Calls ``embeddings.create`` (``/v1/embeddings``) on the local
        server with ``text`` as the single input and returns the first
        result's vector as a plain ``list``.
        """
        async with self._sem():
            resp = await self._client.embeddings.create(
                model=model,
                input=text,
            )
            return list(resp.data[0].embedding)
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
"""Routed LLM client — splits traffic across multiple backends by model.
|
||||||
|
|
||||||
|
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
|
||||||
|
the 8B classifier model moves to local MLX, and embeddings run on a
|
||||||
|
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
|
||||||
|
backends, dispatched by the ``model`` argument at every call site.
|
||||||
|
|
||||||
|
Routing rule: requests whose ``model`` argument matches the configured
|
||||||
|
``narrative_model`` go to the narrative backend; everything else
|
||||||
|
(classifier, embeddings, future locally-hosted models) goes to the
|
||||||
|
local backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
from .client import LLMClient, Message
|
||||||
|
|
||||||
|
|
||||||
|
class RoutedLLMClient:
    """Two-backend LLM client that chooses a backend from the ``model`` id.

    A call whose ``model`` is exactly ``narrative_model`` is served by
    the narrative backend. Any other model id is served by the local
    backend. ``embed`` never consults the model id: it always uses the
    local backend.
    """

    def __init__(
        self,
        *,
        narrative: LLMClient,
        local: LLMClient,
        narrative_model: str,
    ) -> None:
        """Store the two backends and the model id that selects the narrative one."""
        self._narrative_model = narrative_model
        self._narrative = narrative
        self._local = local

    def _pick(self, model: str) -> LLMClient:
        """Return the backend responsible for ``model`` (exact string match)."""
        if model == self._narrative_model:
            return self._narrative
        return self._local

    async def generate(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> str:
        """Delegate a non-streaming completion to the backend chosen for ``model``."""
        backend = self._pick(model)
        return await backend.generate(messages, model=model, **params)

    async def stream(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> AsyncIterator[str]:
        """Re-yield every chunk streamed by the backend chosen for ``model``."""
        backend = self._pick(model)
        async for piece in backend.stream(messages, model=model, **params):
            yield piece

    async def embed(self, text: str, *, model: str) -> list[float]:
        """Embed ``text`` on the local backend, whatever ``model`` is."""
        # The narrative backend is deliberately bypassed for embeddings.
        local_backend = self._local
        return await local_backend.embed(text, model=model)
|
||||||
+18
-2
@@ -32,14 +32,30 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
def get_llm_client(request: Request) -> LLMClient:
|
def get_llm_client(request: Request) -> LLMClient:
|
||||||
"""Production LLM client. Tests override this via ``app.dependency_overrides``."""
|
"""Production LLM client. Tests override this via ``app.dependency_overrides``.
|
||||||
|
|
||||||
|
Returns a :class:`chat.llm.router.RoutedLLMClient` that splits
|
||||||
|
traffic: the narrative model goes to Featherless, the classifier
|
||||||
|
+ embeddings go to the local MLX server (``mlx-omni-server``).
|
||||||
|
Both backends share the OpenAI-compatible surface, so the routing
|
||||||
|
is invisible to call sites — they just pass ``model=...`` and the
|
||||||
|
router picks the backend.
|
||||||
|
"""
|
||||||
settings = request.app.state.settings
|
settings = request.app.state.settings
|
||||||
from chat.llm.featherless import FeatherlessClient
|
from chat.llm.featherless import FeatherlessClient
|
||||||
|
from chat.llm.local_mlx import LocalMLXClient
|
||||||
|
from chat.llm.router import RoutedLLMClient
|
||||||
|
|
||||||
return FeatherlessClient(
|
narrative = FeatherlessClient(
|
||||||
api_key=settings.featherless_api_key,
|
api_key=settings.featherless_api_key,
|
||||||
base_url=settings.featherless_base_url,
|
base_url=settings.featherless_base_url,
|
||||||
)
|
)
|
||||||
|
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||||
|
return RoutedLLMClient(
|
||||||
|
narrative=narrative,
|
||||||
|
local=local,
|
||||||
|
narrative_model=settings.narrative_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_holding(text: str) -> list[str]:
|
def _parse_holding(text: str) -> list[str]:
|
||||||
|
|||||||
Executable
+38
@@ -0,0 +1,38 @@
|
|||||||
|
#!/usr/bin/env bash
# Launch the local mlx-omni-server that hosts the classifier and
# embedding models. The chat app's RoutedLLMClient sends every model
# except the narrative one to this server; when it is not running,
# classifier calls fail and embeddings degrade to the zero-vector
# fallback.
#
# Foreground:
#   ./scripts/start_mlx_server.sh
# Background daemon (output appended to data/mlx-server.log):
#   ./scripts/start_mlx_server.sh --daemon
#
# MLX_HOST / MLX_PORT override the bind address (127.0.0.1:10240).
#
# Models are pulled from Hugging Face on first request; expect a delay
# the first time the classifier or embedding path is exercised.

set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
VENV="${REPO_ROOT}/.mlx-venv"
SERVER_BIN="${VENV}/bin/mlx-omni-server"
LOG="${REPO_ROOT}/data/mlx-server.log"
HOST="${MLX_HOST:-127.0.0.1}"
PORT="${MLX_PORT:-10240}"

# Bail out early, with setup instructions, if the dedicated venv has no
# executable server binary.
if [[ ! -x "${SERVER_BIN}" ]]; then
    {
        echo "error: mlx-omni-server not installed in ${VENV}"
        echo "create the venv with:"
        echo " python3.12 -m venv ${VENV} && ${VENV}/bin/pip install mlx-omni-server"
    } >&2
    exit 1
fi

case "${1:-}" in
    --daemon)
        # Detach from the terminal and append stdout + stderr to the log.
        mkdir -p "$(dirname "${LOG}")"
        nohup "${SERVER_BIN}" --host "${HOST}" --port "${PORT}" \
            >>"${LOG}" 2>&1 &
        echo "mlx-omni-server started in background (pid $!)"
        echo "logs: ${LOG}"
        ;;
    *)
        # Replace this shell with the server so signals reach it directly.
        exec "${SERVER_BIN}" --host "${HOST}" --port "${PORT}"
        ;;
esac
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""Tests for LocalMLXClient (Phase 4.5+).
|
||||||
|
|
||||||
|
Talks to a local mlx-omni-server over the OpenAI-compatible surface.
|
||||||
|
We don't spin up a real server in tests — instead we monkey-patch the
|
||||||
|
underlying ``AsyncOpenAI`` instance to assert on the request shape and
|
||||||
|
return canned responses. The semaphore behavior is shared with
|
||||||
|
FeatherlessClient (same pattern), so we don't re-test that here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.llm.client import Message
|
||||||
|
from chat.llm.local_mlx import LocalMLXClient
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeChatCompletions:
    """Stand-in for ``client.chat.completions`` that records every request.

    ``create`` stores its keyword arguments in ``calls`` and hands back
    the canned ``response`` given at construction.
    """

    def __init__(self, response):
        self.calls = []
        self.response = response

    async def create(self, **request):
        """Record the request kwargs, then return the canned response."""
        self.calls.append(request)
        return self.response
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeEmbeddings:
    """Stand-in for ``client.embeddings`` that records every request.

    ``create`` stores its keyword arguments in ``calls`` and returns an
    object shaped like an embeddings response: ``.data[0].embedding``
    is the ``vector`` given at construction.
    """

    def __init__(self, vector):
        self.calls = []
        self.vector = vector

    async def create(self, **request):
        """Record the request kwargs, then return the canned vector wrapped as a response."""
        self.calls.append(request)
        item = SimpleNamespace(embedding=self.vector)
        return SimpleNamespace(data=[item])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_local_mlx_client_generate_calls_chat_completions():
    """``generate`` forwards model + messages to chat completions and returns the text."""
    canned = SimpleNamespace(
        choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))]
    )
    completions = _FakeChatCompletions(canned)
    client = LocalMLXClient(base_url="http://localhost:10240/v1")
    # Swap the SDK's chat namespace for the recording fake.
    client._client.chat = SimpleNamespace(completions=completions)

    reply = await client.generate(
        [Message(role="user", content="hi")],
        model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
    )

    assert reply == "hello"
    assert len(completions.calls) == 1
    request = completions.calls[0]
    assert request["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    assert request["messages"] == [{"role": "user", "content": "hi"}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_local_mlx_client_embed_returns_vector():
    """``embed`` returns the server's vector and sends ``model`` + ``input``.

    The SDK's ``embeddings`` namespace is replaced with a recording
    fake, so no server is contacted.
    """
    expected = [0.1, 0.2, 0.3, 0.4]
    embeddings = _FakeEmbeddings(expected)
    client = LocalMLXClient()
    client._client.embeddings = embeddings

    vector = await client.embed(
        "hello", model="mlx-community/bge-small-en-v1.5-bf16"
    )

    assert vector == expected
    request = embeddings.calls[0]
    assert request["model"] == "mlx-community/bge-small-en-v1.5-bf16"
    assert request["input"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_local_mlx_client_default_base_url():
    """With no arguments the client targets 127.0.0.1 on port 10240."""
    # Compare on host:port only, so the exact URL formatting the SDK
    # stores (e.g. a trailing slash) does not matter.
    configured_url = str(LocalMLXClient()._client.base_url)
    assert "127.0.0.1:10240" in configured_url
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
"""Tests for RoutedLLMClient (Phase 4.5+).
|
||||||
|
|
||||||
|
Splits traffic across two underlying clients based on the ``model``
|
||||||
|
kwarg. We use simple stub clients to assert the router picks the
|
||||||
|
correct backend for each call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.llm.client import Message
|
||||||
|
from chat.llm.router import RoutedLLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class _StubClient:
    """Minimal LLM-client double that logs the model id of each call.

    ``generate`` and ``stream`` answer with ``"<name>:<model>"`` so a
    test can tell which stub served a request; ``embed`` always returns
    ``[1.0, 2.0]``.
    """

    def __init__(self, name: str):
        self.name = name
        # One list per method, holding the model ids in call order.
        self.generate_calls: list[str] = []
        self.stream_calls: list[str] = []
        self.embed_calls: list[str] = []

    async def generate(self, messages, *, model, **params) -> str:
        """Log ``model`` and return the tagged reply."""
        self.generate_calls.append(model)
        tagged = f"{self.name}:{model}"
        return tagged

    async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
        """Log ``model`` and yield the tagged reply as a single chunk."""
        self.stream_calls.append(model)
        tagged = f"{self.name}:{model}"
        yield tagged

    async def embed(self, text, *, model) -> list[float]:
        """Log ``model`` and return a fixed two-element vector."""
        self.embed_calls.append(model)
        return [1.0, 2.0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_router_generate_dispatches_narrative_to_narrative_backend():
    """A ``generate`` call for the narrative model reaches only the narrative stub."""
    remote_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=remote_stub, local=local_stub, narrative_model="big-model"
    )
    prompt = [Message(role="user", content="hi")]

    reply = await router.generate(prompt, model="big-model")

    assert reply == "narrative:big-model"
    assert remote_stub.generate_calls == ["big-model"]
    assert local_stub.generate_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_router_generate_dispatches_classifier_to_local_backend():
    """A ``generate`` call for any non-narrative model reaches only the local stub."""
    remote_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=remote_stub, local=local_stub, narrative_model="big-model"
    )
    prompt = [Message(role="user", content="hi")]

    reply = await router.generate(prompt, model="small-model")

    assert reply == "local:small-model"
    assert local_stub.generate_calls == ["small-model"]
    assert remote_stub.generate_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_model():
    """``stream`` applies the same routing rule as ``generate``."""
    router = RoutedLLMClient(
        narrative=_StubClient("narrative"),
        local=_StubClient("local"),
        narrative_model="big-model",
    )

    # Narrative model first, then a model id the router does not know.
    from_narrative = []
    async for piece in router.stream(
        [Message(role="user", content="hi")], model="big-model"
    ):
        from_narrative.append(piece)

    from_local = []
    async for piece in router.stream(
        [Message(role="user", content="hi")], model="other-model"
    ):
        from_local.append(piece)

    assert from_narrative == ["narrative:big-model"]
    assert from_local == ["local:other-model"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
    """``embed`` is served by the local stub and never touches the narrative stub."""
    remote_stub = _StubClient("narrative")
    local_stub = _StubClient("local")
    router = RoutedLLMClient(
        narrative=remote_stub, local=local_stub, narrative_model="big-model"
    )

    vector = await router.embed("hello", model="any-embedding-model")

    assert vector == [1.0, 2.0]
    assert local_stub.embed_calls == ["any-embedding-model"]
    assert remote_stub.embed_calls == []
|
||||||
Reference in New Issue
Block a user