feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless

Adds RoutedLLMClient that dispatches by model name: requests matching
Settings.narrative_model go to Featherless, everything else (classifier
calls, embed) goes to a local MLX server. The local server is
mlx-omni-server (separate venv at .mlx-venv) and exposes the standard
OpenAI surface at http://127.0.0.1:10240/v1.

LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood)
but with a working embed() — Featherless's /v1/embeddings always
returns 500 with completions_error, so the router unconditionally
sends embed traffic to the local backend.

Production deployment overrides via data/config.toml:
- classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB)
- embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB,
  384 dim — matches existing schema, no migration)

Defaults stay remote / pseudo so fresh installs and tests need no
external infra. Smoke-tested live: classifier returns expected output,
BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer
to cat-on-rug than to quantum-mechanics).

scripts/start_mlx_server.sh starts the daemon (foreground or --daemon).
.mlx-venv/ added to .gitignore.

Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
This commit is contained in:
Joseph Doherty
2026-04-27 12:05:41 -04:00
parent b3d78c1603
commit fe9c497038
9 changed files with 433 additions and 12 deletions
+1
View File
@@ -5,6 +5,7 @@ data/
# Python
.venv/
.mlx-venv/
__pycache__/
*.pyc
.pytest_cache/
+15 -4
View File
@@ -72,17 +72,28 @@ async def lifespan(app: FastAPI):
# (free / lower paid tiers cap at 2). Shared across all
# FeatherlessClient instances in the process.
from chat.llm.featherless import FeatherlessClient
from chat.llm.local_mlx import LocalMLXClient
from chat.llm.router import RoutedLLMClient
FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
LocalMLXClient.configure_concurrency(settings.local_mlx_max_concurrent)
# Background worker for the async significance pass (T22). Each job
# constructs a fresh FeatherlessClient via the factory; tests can
# disable enqueue by toggling ``app.state.background_worker.enabled``.
# Background workers (significance scoring, embedding indexer)
# construct a fresh client per job via the factory. Workers route
# through the same RoutedLLMClient as request-time traffic so the
# narrative model still goes to Featherless and the classifier +
# embeddings hit the local MLX server.
def _factory():
return FeatherlessClient(
narrative = FeatherlessClient(
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
)
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
return RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model=settings.narrative_model,
)
worker = BackgroundWorker(settings, llm_client_factory=_factory)
await worker.start()
+15 -6
View File
@@ -39,13 +39,22 @@ class Settings(BaseModel):
data_dir: Path = REPO_ROOT / "data"
bind_host: str = "127.0.0.1"
bind_port: int = 8000
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
# -> Featherless, otherwise local. ``embed()`` always routes local.
# If no MLX server is running and ``classifier_model`` is set to a
# local model id, classifier calls will fail; ``embed()`` will fall
# back to the zero-vector path with a T107 warning. To use the
# default Featherless-only deployment, leave ``classifier_model``
# at its remote default and ``embedding_model`` at the pseudo.
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
local_mlx_max_concurrent: int = 1
# T112 (Phase 4.5): embedding model identifier. Default is the
# deterministic local pseudo (semantically meaningless but keeps the
# vector pipeline structurally valid). Swap to a real model name
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
# supports embed() — currently FeatherlessClient does NOT, so a
# non-default value will trigger the zero-vector fallback path
# plus a T107 warning until a different provider is wired in.
# deterministic local pseudo so fresh installs / tests don't need
# any external infra. Override via config.toml to a real model id
# (e.g. ``"mlx-community/bge-small-en-v1.5-bf16"``) once a local
# MLX server is running.
embedding_model: str = "pseudo-sha256-384"
def load_settings() -> Settings:
+95
View File
@@ -0,0 +1,95 @@
"""Local MLX OpenAI-compatible client.
Talks to a locally-running MLX server (e.g., ``mlx-omni-server``) over
the same OpenAI surface that :class:`chat.llm.featherless.FeatherlessClient`
uses, via :class:`openai.AsyncOpenAI`. The underlying server runs MLX
models on Apple Silicon (M-series) for chat completions AND embeddings.
Use cases (Phase 4.5+):
- Classifier traffic moved off Featherless to local MLX (cost + latency).
- Embeddings via ``client.embed`` actually work — Featherless's
``/v1/embeddings`` always returns 500.
Constructor takes a ``base_url`` (e.g., ``"http://127.0.0.1:10240/v1"``)
and an optional ``api_key`` (most local MLX servers don't authenticate;
the OpenAI SDK requires *some* string, so we default to a placeholder).
"""
from __future__ import annotations
import asyncio
from typing import AsyncIterator, Sequence
from openai import AsyncOpenAI
from .client import Message
class LocalMLXClient:
"""OpenAI-compatible client for a local MLX server.
The server is single-process by default (``mlx-omni-server`` loads
one model at a time and swaps on demand). The class-level semaphore
serializes concurrent requests so we never queue more than
``max_concurrent`` at a time — defaults to 1, since MLX inference
on a single M-series device is sequential anyway.
"""
_semaphore: asyncio.Semaphore | None = None
@classmethod
def configure_concurrency(cls, max_concurrent: int) -> None:
cls._semaphore = asyncio.Semaphore(max(1, int(max_concurrent)))
@classmethod
def _sem(cls) -> asyncio.Semaphore:
if cls._semaphore is None:
cls._semaphore = asyncio.Semaphore(1)
return cls._semaphore
def __init__(
self,
base_url: str = "http://127.0.0.1:10240/v1",
api_key: str = "not-needed",
):
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
async def generate(
self, messages: Sequence[Message], *, model: str, **params
) -> str:
async with self._sem():
resp = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
**params,
)
return resp.choices[0].message.content or ""
async def stream(
self, messages: Sequence[Message], *, model: str, **params
) -> AsyncIterator[str]:
async with self._sem():
stream = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
stream=True,
**params,
)
async for chunk in stream:
delta = chunk.choices[0].delta.content or ""
if delta:
yield delta
async def embed(self, text: str, *, model: str) -> list[float]:
"""Return an embedding vector for ``text`` using the named model.
Targets ``/v1/embeddings`` on the local MLX server; the server
loads the model on first request and caches it. The embedding
model is independent of the chat model loaded for ``generate``
/ ``stream`` (the server can serve both).
"""
async with self._sem():
resp = await self._client.embeddings.create(
model=model,
input=text,
)
return list(resp.data[0].embedding)
+59
View File
@@ -0,0 +1,59 @@
"""Routed LLM client — splits traffic across multiple backends by model.
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
the 8B classifier model moves to local MLX, and embeddings run on a
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
backends, dispatched by the ``model`` argument at every call site.
Routing rule: requests whose ``model`` argument matches the configured
``narrative_model`` go to the narrative backend; everything else
(classifier, embeddings, future locally-hosted models) goes to the
local backend.
"""
from __future__ import annotations
from typing import AsyncIterator, Sequence
from .client import LLMClient, Message
class RoutedLLMClient:
"""Delegates to one of two underlying clients based on ``model``.
The narrative client is exercised only when ``model`` exactly equals
``narrative_model`` (the configured remote model id). Everything
else — classifier traffic, embeddings, any future locally-hosted
model — routes to the local client. ``embed`` always routes locally
(the remote provider doesn't support it; see
:class:`chat.llm.featherless.FeatherlessClient.embed`).
"""
def __init__(
self,
*,
narrative: LLMClient,
local: LLMClient,
narrative_model: str,
) -> None:
self._narrative = narrative
self._local = local
self._narrative_model = narrative_model
def _pick(self, model: str) -> LLMClient:
return self._narrative if model == self._narrative_model else self._local
async def generate(
self, messages: Sequence[Message], *, model: str, **params
) -> str:
return await self._pick(model).generate(messages, model=model, **params)
async def stream(
self, messages: Sequence[Message], *, model: str, **params
) -> AsyncIterator[str]:
async for chunk in self._pick(model).stream(messages, model=model, **params):
yield chunk
async def embed(self, text: str, *, model: str) -> list[float]:
# Embeddings always run on the local backend — the remote
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
return await self._local.embed(text, model=model)
+18 -2
View File
@@ -32,14 +32,30 @@ router = APIRouter()
def get_llm_client(request: Request) -> LLMClient:
"""Production LLM client. Tests override this via ``app.dependency_overrides``."""
"""Production LLM client. Tests override this via ``app.dependency_overrides``.
Returns a :class:`chat.llm.router.RoutedLLMClient` that splits
traffic: the narrative model goes to Featherless, the classifier
+ embeddings go to the local MLX server (``mlx-omni-server``).
Both backends share the OpenAI-compatible surface, so the routing
is invisible to call sites — they just pass ``model=...`` and the
router picks the backend.
"""
settings = request.app.state.settings
from chat.llm.featherless import FeatherlessClient
from chat.llm.local_mlx import LocalMLXClient
from chat.llm.router import RoutedLLMClient
return FeatherlessClient(
narrative = FeatherlessClient(
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
)
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
return RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model=settings.narrative_model,
)
def _parse_holding(text: str) -> list[str]:
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
# Start the local mlx-omni-server that serves the classifier + embedding
# models. The chat app's RoutedLLMClient routes everything except the
# narrative model to this server; with no MLX server running, classifier
# calls fail and embeddings degrade to the zero-vector fallback.
#
# Run in the foreground:
# ./scripts/start_mlx_server.sh
# Run as a background daemon (logs to data/mlx-server.log):
# ./scripts/start_mlx_server.sh --daemon
#
# Models are pulled from Hugging Face on first request; expect a delay
# the first time you exercise the classifier or embedding path.
set -euo pipefail
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
VENV="${REPO_ROOT}/.mlx-venv"
LOG="${REPO_ROOT}/data/mlx-server.log"
PORT="${MLX_PORT:-10240}"
HOST="${MLX_HOST:-127.0.0.1}"
if [ ! -x "${VENV}/bin/mlx-omni-server" ]; then
echo "error: mlx-omni-server not installed in ${VENV}" >&2
echo "create the venv with:" >&2
echo " python3.12 -m venv ${VENV} && ${VENV}/bin/pip install mlx-omni-server" >&2
exit 1
fi
if [ "${1:-}" = "--daemon" ]; then
mkdir -p "$(dirname "${LOG}")"
nohup "${VENV}/bin/mlx-omni-server" --host "${HOST}" --port "${PORT}" \
>>"${LOG}" 2>&1 &
echo "mlx-omni-server started in background (pid $!)"
echo "logs: ${LOG}"
else
exec "${VENV}/bin/mlx-omni-server" --host "${HOST}" --port "${PORT}"
fi
+84
View File
@@ -0,0 +1,84 @@
"""Tests for LocalMLXClient (Phase 4.5+).
Talks to a local mlx-omni-server over the OpenAI-compatible surface.
We don't spin up a real server in tests — instead we monkey-patch the
underlying ``AsyncOpenAI`` instance to assert on the request shape and
return canned responses. The semaphore behavior is shared with
FeatherlessClient (same pattern), so we don't re-test that here.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from chat.llm.client import Message
from chat.llm.local_mlx import LocalMLXClient
class _FakeChatCompletions:
def __init__(self, response):
self.response = response
self.calls = []
async def create(self, **kw):
self.calls.append(kw)
return self.response
class _FakeEmbeddings:
def __init__(self, vector):
self.vector = vector
self.calls = []
async def create(self, **kw):
self.calls.append(kw)
return SimpleNamespace(data=[SimpleNamespace(embedding=self.vector)])
@pytest.mark.asyncio
async def test_local_mlx_client_generate_calls_chat_completions():
client = LocalMLXClient(base_url="http://localhost:10240/v1")
fake_response = SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))]
)
fake_chat = _FakeChatCompletions(fake_response)
client._client.chat = SimpleNamespace(completions=fake_chat)
out = await client.generate(
[Message(role="user", content="hi")],
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
)
assert out == "hello"
assert len(fake_chat.calls) == 1
assert fake_chat.calls[0]["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
assert fake_chat.calls[0]["messages"] == [{"role": "user", "content": "hi"}]
@pytest.mark.asyncio
async def test_local_mlx_client_embed_returns_vector():
"""``embed()`` actually works on this client (unlike FeatherlessClient
which raises NotImplementedError) — the local MLX server has a real
``/v1/embeddings`` endpoint backed by an MLX-quantized model.
"""
client = LocalMLXClient()
canned = [0.1, 0.2, 0.3, 0.4]
fake_embeddings = _FakeEmbeddings(canned)
client._client.embeddings = fake_embeddings
out = await client.embed("hello", model="mlx-community/bge-small-en-v1.5-bf16")
assert out == canned
assert fake_embeddings.calls[0]["model"] == "mlx-community/bge-small-en-v1.5-bf16"
assert fake_embeddings.calls[0]["input"] == "hello"
@pytest.mark.asyncio
async def test_local_mlx_client_default_base_url():
"""Default base_url targets ``mlx-omni-server`` on its standard port."""
client = LocalMLXClient()
# AsyncOpenAI normalizes trailing-slash differences; just check the
# configured host:port appears in the underlying client config.
assert "127.0.0.1:10240" in str(client._client.base_url)
+108
View File
@@ -0,0 +1,108 @@
"""Tests for RoutedLLMClient (Phase 4.5+).
Splits traffic across two underlying clients based on the ``model``
kwarg. We use simple stub clients to assert the router picks the
correct backend for each call.
"""
from __future__ import annotations
from typing import AsyncIterator, Sequence
import pytest
from chat.llm.client import Message
from chat.llm.router import RoutedLLMClient
class _StubClient:
def __init__(self, name: str):
self.name = name
self.generate_calls: list[str] = []
self.stream_calls: list[str] = []
self.embed_calls: list[str] = []
async def generate(self, messages, *, model, **params) -> str:
self.generate_calls.append(model)
return f"{self.name}:{model}"
async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
self.stream_calls.append(model)
yield f"{self.name}:{model}"
async def embed(self, text, *, model) -> list[float]:
self.embed_calls.append(model)
return [1.0, 2.0]
@pytest.mark.asyncio
async def test_router_generate_dispatches_narrative_to_narrative_backend():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model="big-model",
)
out = await router.generate([Message(role="user", content="hi")], model="big-model")
assert out == "narrative:big-model"
assert narrative.generate_calls == ["big-model"]
assert local.generate_calls == []
@pytest.mark.asyncio
async def test_router_generate_dispatches_classifier_to_local_backend():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model="big-model",
)
out = await router.generate(
[Message(role="user", content="hi")], model="small-model"
)
assert out == "local:small-model"
assert local.generate_calls == ["small-model"]
assert narrative.generate_calls == []
@pytest.mark.asyncio
async def test_router_stream_dispatches_by_model():
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative, local=local, narrative_model="big-model"
)
chunks_big = [c async for c in router.stream(
[Message(role="user", content="hi")], model="big-model"
)]
chunks_small = [c async for c in router.stream(
[Message(role="user", content="hi")], model="other-model"
)]
assert chunks_big == ["narrative:big-model"]
assert chunks_small == ["local:other-model"]
@pytest.mark.asyncio
async def test_router_embed_always_routes_to_local():
"""Embeddings always run locally — the remote provider doesn't
expose a working ``/v1/embeddings``, so the router never sends
embed calls there even if the model name happens to look 'remote'."""
narrative = _StubClient("narrative")
local = _StubClient("local")
router = RoutedLLMClient(
narrative=narrative, local=local, narrative_model="big-model"
)
out = await router.embed("hello", model="any-embedding-model")
assert out == [1.0, 2.0]
assert local.embed_calls == ["any-embedding-model"]
assert narrative.embed_calls == []