fe9c497038
Adds RoutedLLMClient, which dispatches by model name: requests matching Settings.narrative_model go to Featherless; everything else (classifier calls, embed) goes to a local MLX server.

The local server is mlx-omni-server (separate venv at .mlx-venv) and exposes the standard OpenAI surface at http://127.0.0.1:10240/v1. LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood) but has a working embed(). Featherless's /v1/embeddings always returns 500 with completions_error, so the router unconditionally sends embed traffic to the local backend.

Production deployment overrides these via data/config.toml:
- classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB)
- embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB, 384 dim — matches the existing schema, no migration)

Defaults stay remote / pseudo, so fresh installs and tests need no external infrastructure.

Smoke-tested live: the classifier returns the expected output, and BGE produces correctly clustering 384-dim vectors (cat-on-mat is closer to cat-on-rug than to quantum-mechanics).

scripts/start_mlx_server.sh starts the server (in the foreground, or as a daemon with --daemon). .mlx-venv/ is added to .gitignore.

Suite: 464 passed (was 457; +7 new tests across LocalMLXClient and the router).
109 lines
3.3 KiB
Python
109 lines
3.3 KiB
Python
"""Tests for RoutedLLMClient (Phase 4.5+).

RoutedLLMClient splits traffic across two underlying clients based on
the ``model`` kwarg. These tests use simple stub clients to assert that
the router picks the correct backend for each call.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import AsyncIterator, Sequence
|
|
|
|
import pytest
|
|
|
|
from chat.llm.client import Message
|
|
from chat.llm.router import RoutedLLMClient
|
|
|
|
|
|
class _StubClient:
|
|
def __init__(self, name: str):
|
|
self.name = name
|
|
self.generate_calls: list[str] = []
|
|
self.stream_calls: list[str] = []
|
|
self.embed_calls: list[str] = []
|
|
|
|
async def generate(self, messages, *, model, **params) -> str:
|
|
self.generate_calls.append(model)
|
|
return f"{self.name}:{model}"
|
|
|
|
async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
|
|
self.stream_calls.append(model)
|
|
yield f"{self.name}:{model}"
|
|
|
|
async def embed(self, text, *, model) -> list[float]:
|
|
self.embed_calls.append(model)
|
|
return [1.0, 2.0]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_generate_dispatches_narrative_to_narrative_backend():
    """A ``generate`` call for the narrative model reaches only the narrative backend."""
    narrative_backend = _StubClient("narrative")
    local_backend = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_backend,
        local=local_backend,
        narrative_model="big-model",
    )
    prompt = [Message(role="user", content="hi")]

    result = await router.generate(prompt, model="big-model")

    assert result == "narrative:big-model"
    assert narrative_backend.generate_calls == ["big-model"]
    assert local_backend.generate_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_generate_dispatches_classifier_to_local_backend():
    """A ``generate`` call for any non-narrative model reaches only the local backend."""
    narrative_backend = _StubClient("narrative")
    local_backend = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative_backend,
        local=local_backend,
        narrative_model="big-model",
    )
    prompt = [Message(role="user", content="hi")]

    result = await router.generate(prompt, model="small-model")

    assert result == "local:small-model"
    assert local_backend.generate_calls == ["small-model"]
    assert narrative_backend.generate_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_model():
    """``stream`` follows the same model-name routing as ``generate``.

    The narrative model's stream must come from the narrative backend and
    any other model's stream from the local backend. Besides the echoed
    chunks, each backend's call log is checked, matching the ``generate``
    tests, so the routing decision is verified on both sides.
    """
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative, local=local, narrative_model="big-model"
    )

    chunks_big = [c async for c in router.stream(
        [Message(role="user", content="hi")], model="big-model"
    )]
    chunks_small = [c async for c in router.stream(
        [Message(role="user", content="hi")], model="other-model"
    )]

    assert chunks_big == ["narrative:big-model"]
    assert chunks_small == ["local:other-model"]
    # Each backend must have seen exactly its own stream call, no more.
    assert narrative.stream_calls == ["big-model"]
    assert local.stream_calls == ["other-model"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
    """Embeddings always run locally — the remote provider doesn't
    expose a working ``/v1/embeddings``, so the router never sends
    embed calls there even if the model name happens to look 'remote'.

    Covers both an arbitrary embedding-model name and the narrative
    model's own name, which would route to the narrative backend for
    ``generate``/``stream``.
    """
    narrative = _StubClient("narrative")
    local = _StubClient("local")
    router = RoutedLLMClient(
        narrative=narrative, local=local, narrative_model="big-model"
    )

    out = await router.embed("hello", model="any-embedding-model")
    # The narrative model name is the strongest "looks remote" case: the
    # same name sends generate/stream traffic to the narrative backend.
    out_narrative_name = await router.embed("hello", model="big-model")

    assert out == [1.0, 2.0]
    assert out_narrative_name == [1.0, 2.0]
    # Both stubs return the same vector, so the call logs are what prove
    # the routing decision.
    assert local.embed_calls == ["any-embedding-model", "big-model"]
    assert narrative.embed_calls == []
|