feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless

Adds RoutedLLMClient that dispatches by model name: requests matching
Settings.narrative_model go to Featherless, everything else (classifier
calls, embed) goes to a local MLX server. The local server is
mlx-omni-server (separate venv at .mlx-venv) and exposes the standard
OpenAI surface at http://127.0.0.1:10240/v1.

LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood)
but with a working embed() — Featherless's /v1/embeddings always
returns 500 with completions_error, so the router unconditionally
sends embed traffic to the local backend.

Production deployment overrides via data/config.toml:
- classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB)
- embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB,
  384 dim — matches existing schema, no migration)

Defaults stay remote / pseudo so fresh installs and tests need no
external infra. Smoke-tested live: classifier returns expected output,
BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer
to cat-on-rug than to quantum-mechanics).

scripts/start_mlx_server.sh starts the daemon (foreground or --daemon).
.mlx-venv/ added to .gitignore.

Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
This commit is contained in:
Joseph Doherty
2026-04-27 12:05:41 -04:00
parent b3d78c1603
commit fe9c497038
9 changed files with 433 additions and 12 deletions
+84
View File
@@ -0,0 +1,84 @@
"""Tests for LocalMLXClient (Phase 4.5+).
Talks to a local mlx-omni-server over the OpenAI-compatible surface.
We don't spin up a real server in tests — instead we monkey-patch the
underlying ``AsyncOpenAI`` instance to assert on the request shape and
return canned responses. The semaphore behavior is shared with
FeatherlessClient (same pattern), so we don't re-test that here.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from chat.llm.client import Message
from chat.llm.local_mlx import LocalMLXClient
class _FakeChatCompletions:
def __init__(self, response):
self.response = response
self.calls = []
async def create(self, **kw):
self.calls.append(kw)
return self.response
class _FakeEmbeddings:
def __init__(self, vector):
self.vector = vector
self.calls = []
async def create(self, **kw):
self.calls.append(kw)
return SimpleNamespace(data=[SimpleNamespace(embedding=self.vector)])
@pytest.mark.asyncio
async def test_local_mlx_client_generate_calls_chat_completions():
client = LocalMLXClient(base_url="http://localhost:10240/v1")
fake_response = SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))]
)
fake_chat = _FakeChatCompletions(fake_response)
client._client.chat = SimpleNamespace(completions=fake_chat)
out = await client.generate(
[Message(role="user", content="hi")],
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
)
assert out == "hello"
assert len(fake_chat.calls) == 1
assert fake_chat.calls[0]["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
assert fake_chat.calls[0]["messages"] == [{"role": "user", "content": "hi"}]
@pytest.mark.asyncio
async def test_local_mlx_client_embed_returns_vector():
"""``embed()`` actually works on this client (unlike FeatherlessClient
which raises NotImplementedError) — the local MLX server has a real
``/v1/embeddings`` endpoint backed by an MLX-quantized model.
"""
client = LocalMLXClient()
canned = [0.1, 0.2, 0.3, 0.4]
fake_embeddings = _FakeEmbeddings(canned)
client._client.embeddings = fake_embeddings
out = await client.embed("hello", model="mlx-community/bge-small-en-v1.5-bf16")
assert out == canned
assert fake_embeddings.calls[0]["model"] == "mlx-community/bge-small-en-v1.5-bf16"
assert fake_embeddings.calls[0]["input"] == "hello"
@pytest.mark.asyncio
async def test_local_mlx_client_default_base_url():
"""Default base_url targets ``mlx-omni-server`` on its standard port."""
client = LocalMLXClient()
# AsyncOpenAI normalizes trailing-slash differences; just check the
# configured host:port appears in the underlying client config.
assert "127.0.0.1:10240" in str(client._client.base_url)
+108
View File
@@ -0,0 +1,108 @@
"""Tests for RoutedLLMClient (Phase 4.5+).
Splits traffic across two underlying clients based on the ``model``
kwarg. We use simple stub clients to assert the router picks the
correct backend for each call.
"""
from __future__ import annotations
from typing import AsyncIterator, Sequence
import pytest
from chat.llm.client import Message
from chat.llm.router import RoutedLLMClient
class _StubClient:
def __init__(self, name: str):
self.name = name
self.generate_calls: list[str] = []
self.stream_calls: list[str] = []
self.embed_calls: list[str] = []
async def generate(self, messages, *, model, **params) -> str:
self.generate_calls.append(model)
return f"{self.name}:{model}"
async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
self.stream_calls.append(model)
yield f"{self.name}:{model}"
async def embed(self, text, *, model) -> list[float]:
self.embed_calls.append(model)
return [1.0, 2.0]
@pytest.mark.asyncio
async def test_router_generate_dispatches_narrative_to_narrative_backend():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model="big-model",
)
out = await router.generate([Message(role="user", content="hi")], model="big-model")
assert out == "narrative:big-model"
assert narrative.generate_calls == ["big-model"]
assert local.generate_calls == []
@pytest.mark.asyncio
async def test_router_generate_dispatches_classifier_to_local_backend():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model="big-model",
)
out = await router.generate(
[Message(role="user", content="hi")], model="small-model"
)
assert out == "local:small-model"
assert local.generate_calls == ["small-model"]
assert narrative.generate_calls == []
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_model():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative, local=local, narrative_model="big-model"
)
chunks_big = [c async for c in router.stream(
[Message(role="user", content="hi")], model="big-model"
)]
chunks_small = [c async for c in router.stream(
[Message(role="user", content="hi")], model="other-model"
)]
assert chunks_big == ["narrative:big-model"]
assert chunks_small == ["local:other-model"]
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
"""Embeddings always run locally — the remote provider doesn't
expose a working ``/v1/embeddings``, so the router never sends
embed calls there even if the model name happens to look 'remote'."""
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative, local=local, narrative_model="big-model"
)
out = await router.embed("hello", model="any-embedding-model")
assert out == [1.0, 2.0]
assert local.embed_calls == ["any-embedding-model"]
assert narrative.embed_calls == []