feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless

Adds RoutedLLMClient that dispatches by model name: requests matching
Settings.narrative_model go to Featherless, everything else (classifier
calls, embed) goes to a local MLX server. The local server is
mlx-omni-server (separate venv at .mlx-venv) and exposes the standard
OpenAI surface at http://127.0.0.1:10240/v1.

LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood)
but with a working embed() — Featherless's /v1/embeddings always
returns 500 with completions_error, so the router unconditionally
sends embed traffic to the local backend.

Production deployment overrides via data/config.toml:
- classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB)
- embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB,
  384 dim — matches existing schema, no migration)

Defaults stay remote / pseudo so fresh installs and tests need no
external infra. Smoke-tested live: classifier returns expected output,
BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer
to cat-on-rug than to quantum-mechanics).

scripts/start_mlx_server.sh starts the daemon (foreground or --daemon).
.mlx-venv/ added to .gitignore.

Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
This commit is contained in:
Joseph Doherty
2026-04-27 12:05:41 -04:00
parent b3d78c1603
commit fe9c497038
9 changed files with 433 additions and 12 deletions
+15 -4
View File
@@ -72,17 +72,28 @@ async def lifespan(app: FastAPI):
# (free / lower paid tiers cap at 2). Shared across all
# FeatherlessClient instances in the process.
from chat.llm.featherless import FeatherlessClient
from chat.llm.local_mlx import LocalMLXClient
from chat.llm.router import RoutedLLMClient
FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
LocalMLXClient.configure_concurrency(settings.local_mlx_max_concurrent)
# Background worker for the async significance pass (T22). Each job
# constructs a fresh FeatherlessClient via the factory; tests can
# disable enqueue by toggling ``app.state.background_worker.enabled``.
# Background workers (significance scoring, embedding indexer)
# construct a fresh client per job via the factory. Workers route
# through the same RoutedLLMClient as request-time traffic so the
# narrative model still goes to Featherless and the classifier +
# embeddings hit the local MLX server.
def _factory():
return FeatherlessClient(
narrative = FeatherlessClient(
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
)
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
return RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model=settings.narrative_model,
)
worker = BackgroundWorker(settings, llm_client_factory=_factory)
await worker.start()
+15 -6
View File
@@ -39,13 +39,22 @@ class Settings(BaseModel):
data_dir: Path = REPO_ROOT / "data"
bind_host: str = "127.0.0.1"
bind_port: int = 8000
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
# -> Featherless, otherwise local. ``embed()`` always routes local.
# If no MLX server is running and ``classifier_model`` is set to a
# local model id, classifier calls will fail; ``embed()`` will fall
# back to the zero-vector path with a T107 warning. To use the
# default Featherless-only deployment, leave ``classifier_model``
# at its remote default and ``embedding_model`` at the pseudo.
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
local_mlx_max_concurrent: int = 1
# T112 (Phase 4.5): embedding model identifier. Default is the
# deterministic local pseudo (semantically meaningless but keeps the
# vector pipeline structurally valid). Swap to a real model name
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
# supports embed() — currently FeatherlessClient does NOT, so a
# non-default value will trigger the zero-vector fallback path
# plus a T107 warning until a different provider is wired in.
# deterministic local pseudo so fresh installs / tests don't need
# any external infra. Override via config.toml to a real model id
# (e.g. ``"mlx-community/bge-small-en-v1.5-bf16"``) once a local
# MLX server is running.
embedding_model: str = "pseudo-sha256-384"
def load_settings() -> Settings:
+95
View File
@@ -0,0 +1,95 @@
"""Local MLX OpenAI-compatible client.
Talks to a locally-running MLX server (e.g., ``mlx-omni-server``) over
the same OpenAI surface that :class:`chat.llm.featherless.FeatherlessClient`
uses, via :class:`openai.AsyncOpenAI`. The underlying server runs MLX
models on Apple Silicon (M-series) for chat completions AND embeddings.
Use cases (Phase 4.5+):
- Classifier traffic moved off Featherless to local MLX (cost + latency).
- Embeddings via ``client.embed`` actually work — Featherless's
``/v1/embeddings`` always returns 500.
Constructor takes a ``base_url`` (e.g., ``"http://127.0.0.1:10240/v1"``)
and an optional ``api_key`` (most local MLX servers don't authenticate;
the OpenAI SDK requires *some* string, so we default to a placeholder).
"""
from __future__ import annotations
import asyncio
from typing import AsyncIterator, Sequence
from openai import AsyncOpenAI
from .client import Message
class LocalMLXClient:
"""OpenAI-compatible client for a local MLX server.
The server is single-process by default (``mlx-omni-server`` loads
one model at a time and swaps on demand). The class-level semaphore
serializes concurrent requests so we never queue more than
``max_concurrent`` at a time — defaults to 1, since MLX inference
on a single M-series device is sequential anyway.
"""
_semaphore: asyncio.Semaphore | None = None
@classmethod
def configure_concurrency(cls, max_concurrent: int) -> None:
cls._semaphore = asyncio.Semaphore(max(1, int(max_concurrent)))
@classmethod
def _sem(cls) -> asyncio.Semaphore:
if cls._semaphore is None:
cls._semaphore = asyncio.Semaphore(1)
return cls._semaphore
def __init__(
self,
base_url: str = "http://127.0.0.1:10240/v1",
api_key: str = "not-needed",
):
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
async def generate(
self, messages: Sequence[Message], *, model: str, **params
) -> str:
async with self._sem():
resp = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
**params,
)
return resp.choices[0].message.content or ""
async def stream(
self, messages: Sequence[Message], *, model: str, **params
) -> AsyncIterator[str]:
async with self._sem():
stream = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
stream=True,
**params,
)
async for chunk in stream:
delta = chunk.choices[0].delta.content or ""
if delta:
yield delta
async def embed(self, text: str, *, model: str) -> list[float]:
"""Return an embedding vector for ``text`` using the named model.
Targets ``/v1/embeddings`` on the local MLX server; the server
loads the model on first request and caches it. The embedding
model is independent of the chat model loaded for ``generate``
/ ``stream`` (the server can serve both).
"""
async with self._sem():
resp = await self._client.embeddings.create(
model=model,
input=text,
)
return list(resp.data[0].embedding)
+59
View File
@@ -0,0 +1,59 @@
"""Routed LLM client — splits traffic across multiple backends by model.
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
the 8B classifier model moves to local MLX, and embeddings run on a
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
backends, dispatched by the ``model`` argument at every call site.
Routing rule: requests whose ``model`` argument matches the configured
``narrative_model`` go to the narrative backend; everything else
(classifier, embeddings, future locally-hosted models) goes to the
local backend.
"""
from __future__ import annotations
from typing import AsyncIterator, Sequence
from .client import LLMClient, Message
class RoutedLLMClient:
"""Delegates to one of two underlying clients based on ``model``.
The narrative client is exercised only when ``model`` exactly equals
``narrative_model`` (the configured remote model id). Everything
else — classifier traffic, embeddings, any future locally-hosted
model — routes to the local client. ``embed`` always routes locally
(the remote provider doesn't support it; see
:class:`chat.llm.featherless.FeatherlessClient.embed`).
"""
def __init__(
self,
*,
narrative: LLMClient,
local: LLMClient,
narrative_model: str,
) -> None:
self._narrative = narrative
self._local = local
self._narrative_model = narrative_model
def _pick(self, model: str) -> LLMClient:
return self._narrative if model == self._narrative_model else self._local
async def generate(
self, messages: Sequence[Message], *, model: str, **params
) -> str:
return await self._pick(model).generate(messages, model=model, **params)
async def stream(
self, messages: Sequence[Message], *, model: str, **params
) -> AsyncIterator[str]:
async for chunk in self._pick(model).stream(messages, model=model, **params):
yield chunk
async def embed(self, text: str, *, model: str) -> list[float]:
# Embeddings always run on the local backend — the remote
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
return await self._local.embed(text, model=model)
+18 -2
View File
@@ -32,14 +32,30 @@ router = APIRouter()
def get_llm_client(request: Request) -> LLMClient:
"""Production LLM client. Tests override this via ``app.dependency_overrides``."""
"""Production LLM client. Tests override this via ``app.dependency_overrides``.
Returns a :class:`chat.llm.router.RoutedLLMClient` that splits
traffic: the narrative model goes to Featherless, the classifier
+ embeddings go to the local MLX server (``mlx-omni-server``).
Both backends share the OpenAI-compatible surface, so the routing
is invisible to call sites — they just pass ``model=...`` and the
router picks the backend.
"""
settings = request.app.state.settings
from chat.llm.featherless import FeatherlessClient
from chat.llm.local_mlx import LocalMLXClient
from chat.llm.router import RoutedLLMClient
return FeatherlessClient(
narrative = FeatherlessClient(
api_key=settings.featherless_api_key,
base_url=settings.featherless_base_url,
)
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
return RoutedLLMClient(
narrative=narrative,
local=local,
narrative_model=settings.narrative_model,
)
def _parse_holding(text: str) -> list[str]: