feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless
Adds RoutedLLMClient that dispatches by model name: requests matching Settings.narrative_model go to Featherless, everything else (classifier calls, embed) goes to a local MLX server. The local server is mlx-omni-server (separate venv at .mlx-venv) and exposes the standard OpenAI surface at http://127.0.0.1:10240/v1. LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood) but with a working embed() — Featherless's /v1/embeddings always returns 500 with completions_error, so the router unconditionally sends embed traffic to the local backend. Production deployment overrides via data/config.toml: - classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB) - embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB, 384 dim — matches existing schema, no migration) Defaults stay remote / pseudo so fresh installs and tests need no external infra. Smoke-tested live: classifier returns expected output, BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer to cat-on-rug than to quantum-mechanics). scripts/start_mlx_server.sh starts the daemon (foreground or --daemon). .mlx-venv/ added to .gitignore. Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
This commit is contained in:
+15
-4
@@ -72,17 +72,28 @@ async def lifespan(app: FastAPI):
|
||||
# (free / lower paid tiers cap at 2). Shared across all
|
||||
# FeatherlessClient instances in the process.
|
||||
from chat.llm.featherless import FeatherlessClient
|
||||
from chat.llm.local_mlx import LocalMLXClient
|
||||
from chat.llm.router import RoutedLLMClient
|
||||
|
||||
FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
|
||||
LocalMLXClient.configure_concurrency(settings.local_mlx_max_concurrent)
|
||||
|
||||
# Background worker for the async significance pass (T22). Each job
|
||||
# constructs a fresh FeatherlessClient via the factory; tests can
|
||||
# disable enqueue by toggling ``app.state.background_worker.enabled``.
|
||||
# Background workers (significance scoring, embedding indexer)
|
||||
# construct a fresh client per job via the factory. Workers route
|
||||
# through the same RoutedLLMClient as request-time traffic so the
|
||||
# narrative model still goes to Featherless and the classifier +
|
||||
# embeddings hit the local MLX server.
|
||||
def _factory():
|
||||
return FeatherlessClient(
|
||||
narrative = FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||
return RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model=settings.narrative_model,
|
||||
)
|
||||
|
||||
worker = BackgroundWorker(settings, llm_client_factory=_factory)
|
||||
await worker.start()
|
||||
|
||||
+15
-6
@@ -39,13 +39,22 @@ class Settings(BaseModel):
|
||||
data_dir: Path = REPO_ROOT / "data"
|
||||
bind_host: str = "127.0.0.1"
|
||||
bind_port: int = 8000
|
||||
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
|
||||
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
|
||||
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
|
||||
# -> Featherless, otherwise local. ``embed()`` always routes local.
|
||||
# If no MLX server is running and ``classifier_model`` is set to a
|
||||
# local model id, classifier calls will fail; ``embed()`` will fall
|
||||
# back to the zero-vector path with a T107 warning. To use the
|
||||
# default Featherless-only deployment, leave ``classifier_model``
|
||||
# at its remote default and ``embedding_model`` at the pseudo.
|
||||
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
|
||||
local_mlx_max_concurrent: int = 1
|
||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||
# deterministic local pseudo (semantically meaningless but keeps the
|
||||
# vector pipeline structurally valid). Swap to a real model name
|
||||
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
|
||||
# supports embed() — currently FeatherlessClient does NOT, so a
|
||||
# non-default value will trigger the zero-vector fallback path
|
||||
# plus a T107 warning until a different provider is wired in.
|
||||
# deterministic local pseudo so fresh installs / tests don't need
|
||||
# any external infra. Override via config.toml to a real model id
|
||||
# (e.g. ``"mlx-community/bge-small-en-v1.5-bf16"``) once a local
|
||||
# MLX server is running.
|
||||
embedding_model: str = "pseudo-sha256-384"
|
||||
|
||||
def load_settings() -> Settings:
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Local MLX OpenAI-compatible client.
|
||||
|
||||
Talks to a locally-running MLX server (e.g., ``mlx-omni-server``) over
|
||||
the same OpenAI surface that :class:`chat.llm.featherless.FeatherlessClient`
|
||||
uses, via :class:`openai.AsyncOpenAI`. The underlying server runs MLX
|
||||
models on Apple Silicon (M-series) for chat completions AND embeddings.
|
||||
|
||||
Use cases (Phase 4.5+):
|
||||
- Classifier traffic moved off Featherless to local MLX (cost + latency).
|
||||
- Embeddings via ``client.embed`` actually work — Featherless's
|
||||
``/v1/embeddings`` always returns 500.
|
||||
|
||||
Constructor takes a ``base_url`` (e.g., ``"http://127.0.0.1:10240/v1"``)
|
||||
and an optional ``api_key`` (most local MLX servers don't authenticate;
|
||||
the OpenAI SDK requires *some* string, so we default to a placeholder).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import AsyncIterator, Sequence
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from .client import Message
|
||||
|
||||
|
||||
class LocalMLXClient:
|
||||
"""OpenAI-compatible client for a local MLX server.
|
||||
|
||||
The server is single-process by default (``mlx-omni-server`` loads
|
||||
one model at a time and swaps on demand). The class-level semaphore
|
||||
serializes concurrent requests so we never queue more than
|
||||
``max_concurrent`` at a time — defaults to 1, since MLX inference
|
||||
on a single M-series device is sequential anyway.
|
||||
"""
|
||||
|
||||
_semaphore: asyncio.Semaphore | None = None
|
||||
|
||||
@classmethod
|
||||
def configure_concurrency(cls, max_concurrent: int) -> None:
|
||||
cls._semaphore = asyncio.Semaphore(max(1, int(max_concurrent)))
|
||||
|
||||
@classmethod
|
||||
def _sem(cls) -> asyncio.Semaphore:
|
||||
if cls._semaphore is None:
|
||||
cls._semaphore = asyncio.Semaphore(1)
|
||||
return cls._semaphore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://127.0.0.1:10240/v1",
|
||||
api_key: str = "not-needed",
|
||||
):
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
async def generate(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> str:
|
||||
async with self._sem():
|
||||
resp = await self._client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||
**params,
|
||||
)
|
||||
return resp.choices[0].message.content or ""
|
||||
|
||||
async def stream(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> AsyncIterator[str]:
|
||||
async with self._sem():
|
||||
stream = await self._client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||
stream=True,
|
||||
**params,
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta.content or ""
|
||||
if delta:
|
||||
yield delta
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
"""Return an embedding vector for ``text`` using the named model.
|
||||
|
||||
Targets ``/v1/embeddings`` on the local MLX server; the server
|
||||
loads the model on first request and caches it. The embedding
|
||||
model is independent of the chat model loaded for ``generate``
|
||||
/ ``stream`` (the server can serve both).
|
||||
"""
|
||||
async with self._sem():
|
||||
resp = await self._client.embeddings.create(
|
||||
model=model,
|
||||
input=text,
|
||||
)
|
||||
return list(resp.data[0].embedding)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Routed LLM client — splits traffic across multiple backends by model.
|
||||
|
||||
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
|
||||
the 8B classifier model moves to local MLX, and embeddings run on a
|
||||
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
|
||||
backends, dispatched by the ``model`` argument at every call site.
|
||||
|
||||
Routing rule: requests whose ``model`` argument matches the configured
|
||||
``narrative_model`` go to the narrative backend; everything else
|
||||
(classifier, embeddings, future locally-hosted models) goes to the
|
||||
local backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import AsyncIterator, Sequence
|
||||
|
||||
from .client import LLMClient, Message
|
||||
|
||||
|
||||
class RoutedLLMClient:
|
||||
"""Delegates to one of two underlying clients based on ``model``.
|
||||
|
||||
The narrative client is exercised only when ``model`` exactly equals
|
||||
``narrative_model`` (the configured remote model id). Everything
|
||||
else — classifier traffic, embeddings, any future locally-hosted
|
||||
model — routes to the local client. ``embed`` always routes locally
|
||||
(the remote provider doesn't support it; see
|
||||
:class:`chat.llm.featherless.FeatherlessClient.embed`).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
narrative: LLMClient,
|
||||
local: LLMClient,
|
||||
narrative_model: str,
|
||||
) -> None:
|
||||
self._narrative = narrative
|
||||
self._local = local
|
||||
self._narrative_model = narrative_model
|
||||
|
||||
def _pick(self, model: str) -> LLMClient:
|
||||
return self._narrative if model == self._narrative_model else self._local
|
||||
|
||||
async def generate(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> str:
|
||||
return await self._pick(model).generate(messages, model=model, **params)
|
||||
|
||||
async def stream(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> AsyncIterator[str]:
|
||||
async for chunk in self._pick(model).stream(messages, model=model, **params):
|
||||
yield chunk
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
# Embeddings always run on the local backend — the remote
|
||||
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
|
||||
return await self._local.embed(text, model=model)
|
||||
+18
-2
@@ -32,14 +32,30 @@ router = APIRouter()
|
||||
|
||||
|
||||
def get_llm_client(request: Request) -> LLMClient:
|
||||
"""Production LLM client. Tests override this via ``app.dependency_overrides``."""
|
||||
"""Production LLM client. Tests override this via ``app.dependency_overrides``.
|
||||
|
||||
Returns a :class:`chat.llm.router.RoutedLLMClient` that splits
|
||||
traffic: the narrative model goes to Featherless, the classifier
|
||||
+ embeddings go to the local MLX server (``mlx-omni-server``).
|
||||
Both backends share the OpenAI-compatible surface, so the routing
|
||||
is invisible to call sites — they just pass ``model=...`` and the
|
||||
router picks the backend.
|
||||
"""
|
||||
settings = request.app.state.settings
|
||||
from chat.llm.featherless import FeatherlessClient
|
||||
from chat.llm.local_mlx import LocalMLXClient
|
||||
from chat.llm.router import RoutedLLMClient
|
||||
|
||||
return FeatherlessClient(
|
||||
narrative = FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||
return RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model=settings.narrative_model,
|
||||
)
|
||||
|
||||
|
||||
def _parse_holding(text: str) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user