feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless
Adds RoutedLLMClient that dispatches by model name: requests matching Settings.narrative_model go to Featherless, everything else (classifier calls, embed) goes to a local MLX server. The local server is mlx-omni-server (separate venv at .mlx-venv) and exposes the standard OpenAI surface at http://127.0.0.1:10240/v1. LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood) but with a working embed() — Featherless's /v1/embeddings always returns 500 with completions_error, so the router unconditionally sends embed traffic to the local backend. Production deployment overrides via data/config.toml: - classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB) - embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB, 384 dim — matches existing schema, no migration) Defaults stay remote / pseudo so fresh installs and tests need no external infra. Smoke-tested live: classifier returns expected output, BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer to cat-on-rug than to quantum-mechanics). scripts/start_mlx_server.sh starts the daemon (foreground or --daemon). .mlx-venv/ added to .gitignore. Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
This commit is contained in:
@@ -5,6 +5,7 @@ data/
|
||||
|
||||
# Python
|
||||
.venv/
|
||||
.mlx-venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.pytest_cache/
|
||||
|
||||
+15
-4
@@ -72,17 +72,28 @@ async def lifespan(app: FastAPI):
|
||||
# (free / lower paid tiers cap at 2). Shared across all
|
||||
# FeatherlessClient instances in the process.
|
||||
from chat.llm.featherless import FeatherlessClient
|
||||
from chat.llm.local_mlx import LocalMLXClient
|
||||
from chat.llm.router import RoutedLLMClient
|
||||
|
||||
FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
|
||||
LocalMLXClient.configure_concurrency(settings.local_mlx_max_concurrent)
|
||||
|
||||
# Background worker for the async significance pass (T22). Each job
|
||||
# constructs a fresh FeatherlessClient via the factory; tests can
|
||||
# disable enqueue by toggling ``app.state.background_worker.enabled``.
|
||||
# Background workers (significance scoring, embedding indexer)
|
||||
# construct a fresh client per job via the factory. Workers route
|
||||
# through the same RoutedLLMClient as request-time traffic so the
|
||||
# narrative model still goes to Featherless and the classifier +
|
||||
# embeddings hit the local MLX server.
|
||||
def _factory():
|
||||
return FeatherlessClient(
|
||||
narrative = FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||
return RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model=settings.narrative_model,
|
||||
)
|
||||
|
||||
worker = BackgroundWorker(settings, llm_client_factory=_factory)
|
||||
await worker.start()
|
||||
|
||||
+15
-6
@@ -39,13 +39,22 @@ class Settings(BaseModel):
|
||||
data_dir: Path = REPO_ROOT / "data"
|
||||
bind_host: str = "127.0.0.1"
|
||||
bind_port: int = 8000
|
||||
# Local MLX server (e.g. ``mlx-omni-server``) — serves any model
|
||||
# whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
|
||||
# inspects the ``model`` kwarg at call time: ``model == narrative_model``
|
||||
# -> Featherless, otherwise local. ``embed()`` always routes local.
|
||||
# If no MLX server is running and ``classifier_model`` is set to a
|
||||
# local model id, classifier calls will fail; ``embed()`` will fall
|
||||
# back to the zero-vector path with a T107 warning. To use the
|
||||
# default Featherless-only deployment, leave ``classifier_model``
|
||||
# at its remote default and ``embedding_model`` at the pseudo.
|
||||
local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
|
||||
local_mlx_max_concurrent: int = 1
|
||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||
# deterministic local pseudo (semantically meaningless but keeps the
|
||||
# vector pipeline structurally valid). Swap to a real model name
|
||||
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
|
||||
# supports embed() — currently FeatherlessClient does NOT, so a
|
||||
# non-default value will trigger the zero-vector fallback path
|
||||
# plus a T107 warning until a different provider is wired in.
|
||||
# deterministic local pseudo so fresh installs / tests don't need
|
||||
# any external infra. Override via config.toml to a real model id
|
||||
# (e.g. ``"mlx-community/bge-small-en-v1.5-bf16"``) once a local
|
||||
# MLX server is running.
|
||||
embedding_model: str = "pseudo-sha256-384"
|
||||
|
||||
def load_settings() -> Settings:
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Local MLX OpenAI-compatible client.
|
||||
|
||||
Talks to a locally-running MLX server (e.g., ``mlx-omni-server``) over
|
||||
the same OpenAI surface that :class:`chat.llm.featherless.FeatherlessClient`
|
||||
uses, via :class:`openai.AsyncOpenAI`. The underlying server runs MLX
|
||||
models on Apple Silicon (M-series) for chat completions AND embeddings.
|
||||
|
||||
Use cases (Phase 4.5+):
|
||||
- Classifier traffic moved off Featherless to local MLX (cost + latency).
|
||||
- Embeddings via ``client.embed`` actually work — Featherless's
|
||||
``/v1/embeddings`` always returns 500.
|
||||
|
||||
Constructor takes a ``base_url`` (e.g., ``"http://127.0.0.1:10240/v1"``)
|
||||
and an optional ``api_key`` (most local MLX servers don't authenticate;
|
||||
the OpenAI SDK requires *some* string, so we default to a placeholder).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import AsyncIterator, Sequence
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from .client import Message
|
||||
|
||||
|
||||
class LocalMLXClient:
|
||||
"""OpenAI-compatible client for a local MLX server.
|
||||
|
||||
The server is single-process by default (``mlx-omni-server`` loads
|
||||
one model at a time and swaps on demand). The class-level semaphore
|
||||
serializes concurrent requests so we never queue more than
|
||||
``max_concurrent`` at a time — defaults to 1, since MLX inference
|
||||
on a single M-series device is sequential anyway.
|
||||
"""
|
||||
|
||||
_semaphore: asyncio.Semaphore | None = None
|
||||
|
||||
@classmethod
|
||||
def configure_concurrency(cls, max_concurrent: int) -> None:
|
||||
cls._semaphore = asyncio.Semaphore(max(1, int(max_concurrent)))
|
||||
|
||||
@classmethod
|
||||
def _sem(cls) -> asyncio.Semaphore:
|
||||
if cls._semaphore is None:
|
||||
cls._semaphore = asyncio.Semaphore(1)
|
||||
return cls._semaphore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://127.0.0.1:10240/v1",
|
||||
api_key: str = "not-needed",
|
||||
):
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
async def generate(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> str:
|
||||
async with self._sem():
|
||||
resp = await self._client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||
**params,
|
||||
)
|
||||
return resp.choices[0].message.content or ""
|
||||
|
||||
async def stream(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> AsyncIterator[str]:
|
||||
async with self._sem():
|
||||
stream = await self._client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||
stream=True,
|
||||
**params,
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta.content or ""
|
||||
if delta:
|
||||
yield delta
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
"""Return an embedding vector for ``text`` using the named model.
|
||||
|
||||
Targets ``/v1/embeddings`` on the local MLX server; the server
|
||||
loads the model on first request and caches it. The embedding
|
||||
model is independent of the chat model loaded for ``generate``
|
||||
/ ``stream`` (the server can serve both).
|
||||
"""
|
||||
async with self._sem():
|
||||
resp = await self._client.embeddings.create(
|
||||
model=model,
|
||||
input=text,
|
||||
)
|
||||
return list(resp.data[0].embedding)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Routed LLM client — splits traffic across multiple backends by model.
|
||||
|
||||
Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
|
||||
the 8B classifier model moves to local MLX, and embeddings run on a
|
||||
local BGE/MLX model. One :class:`LLMClient` interface, two underlying
|
||||
backends, dispatched by the ``model`` argument at every call site.
|
||||
|
||||
Routing rule: requests whose ``model`` argument matches the configured
|
||||
``narrative_model`` go to the narrative backend; everything else
|
||||
(classifier, embeddings, future locally-hosted models) goes to the
|
||||
local backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import AsyncIterator, Sequence
|
||||
|
||||
from .client import LLMClient, Message
|
||||
|
||||
|
||||
class RoutedLLMClient:
|
||||
"""Delegates to one of two underlying clients based on ``model``.
|
||||
|
||||
The narrative client is exercised only when ``model`` exactly equals
|
||||
``narrative_model`` (the configured remote model id). Everything
|
||||
else — classifier traffic, embeddings, any future locally-hosted
|
||||
model — routes to the local client. ``embed`` always routes locally
|
||||
(the remote provider doesn't support it; see
|
||||
:class:`chat.llm.featherless.FeatherlessClient.embed`).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
narrative: LLMClient,
|
||||
local: LLMClient,
|
||||
narrative_model: str,
|
||||
) -> None:
|
||||
self._narrative = narrative
|
||||
self._local = local
|
||||
self._narrative_model = narrative_model
|
||||
|
||||
def _pick(self, model: str) -> LLMClient:
|
||||
return self._narrative if model == self._narrative_model else self._local
|
||||
|
||||
async def generate(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> str:
|
||||
return await self._pick(model).generate(messages, model=model, **params)
|
||||
|
||||
async def stream(
|
||||
self, messages: Sequence[Message], *, model: str, **params
|
||||
) -> AsyncIterator[str]:
|
||||
async for chunk in self._pick(model).stream(messages, model=model, **params):
|
||||
yield chunk
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
# Embeddings always run on the local backend — the remote
|
||||
# provider doesn't expose a working ``/v1/embeddings`` endpoint.
|
||||
return await self._local.embed(text, model=model)
|
||||
+18
-2
@@ -32,14 +32,30 @@ router = APIRouter()
|
||||
|
||||
|
||||
def get_llm_client(request: Request) -> LLMClient:
|
||||
"""Production LLM client. Tests override this via ``app.dependency_overrides``."""
|
||||
"""Production LLM client. Tests override this via ``app.dependency_overrides``.
|
||||
|
||||
Returns a :class:`chat.llm.router.RoutedLLMClient` that splits
|
||||
traffic: the narrative model goes to Featherless, the classifier
|
||||
+ embeddings go to the local MLX server (``mlx-omni-server``).
|
||||
Both backends share the OpenAI-compatible surface, so the routing
|
||||
is invisible to call sites — they just pass ``model=...`` and the
|
||||
router picks the backend.
|
||||
"""
|
||||
settings = request.app.state.settings
|
||||
from chat.llm.featherless import FeatherlessClient
|
||||
from chat.llm.local_mlx import LocalMLXClient
|
||||
from chat.llm.router import RoutedLLMClient
|
||||
|
||||
return FeatherlessClient(
|
||||
narrative = FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
local = LocalMLXClient(base_url=settings.local_mlx_base_url)
|
||||
return RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model=settings.narrative_model,
|
||||
)
|
||||
|
||||
|
||||
def _parse_holding(text: str) -> list[str]:
|
||||
|
||||
Executable
+38
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env bash
|
||||
# Start the local mlx-omni-server that serves the classifier + embedding
|
||||
# models. The chat app's RoutedLLMClient routes everything except the
|
||||
# narrative model to this server; with no MLX server running, classifier
|
||||
# calls fail and embeddings degrade to the zero-vector fallback.
|
||||
#
|
||||
# Run in the foreground:
|
||||
# ./scripts/start_mlx_server.sh
|
||||
# Run as a background daemon (logs to data/mlx-server.log):
|
||||
# ./scripts/start_mlx_server.sh --daemon
|
||||
#
|
||||
# Models are pulled from Hugging Face on first request; expect a delay
|
||||
# the first time you exercise the classifier or embedding path.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
VENV="${REPO_ROOT}/.mlx-venv"
|
||||
LOG="${REPO_ROOT}/data/mlx-server.log"
|
||||
PORT="${MLX_PORT:-10240}"
|
||||
HOST="${MLX_HOST:-127.0.0.1}"
|
||||
|
||||
if [ ! -x "${VENV}/bin/mlx-omni-server" ]; then
|
||||
echo "error: mlx-omni-server not installed in ${VENV}" >&2
|
||||
echo "create the venv with:" >&2
|
||||
echo " python3.12 -m venv ${VENV} && ${VENV}/bin/pip install mlx-omni-server" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "${1:-}" = "--daemon" ]; then
|
||||
mkdir -p "$(dirname "${LOG}")"
|
||||
nohup "${VENV}/bin/mlx-omni-server" --host "${HOST}" --port "${PORT}" \
|
||||
>>"${LOG}" 2>&1 &
|
||||
echo "mlx-omni-server started in background (pid $!)"
|
||||
echo "logs: ${LOG}"
|
||||
else
|
||||
exec "${VENV}/bin/mlx-omni-server" --host "${HOST}" --port "${PORT}"
|
||||
fi
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Tests for LocalMLXClient (Phase 4.5+).
|
||||
|
||||
Talks to a local mlx-omni-server over the OpenAI-compatible surface.
|
||||
We don't spin up a real server in tests — instead we monkey-patch the
|
||||
underlying ``AsyncOpenAI`` instance to assert on the request shape and
|
||||
return canned responses. The semaphore behavior is shared with
|
||||
FeatherlessClient (same pattern), so we don't re-test that here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from chat.llm.client import Message
|
||||
from chat.llm.local_mlx import LocalMLXClient
|
||||
|
||||
|
||||
class _FakeChatCompletions:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.calls = []
|
||||
|
||||
async def create(self, **kw):
|
||||
self.calls.append(kw)
|
||||
return self.response
|
||||
|
||||
|
||||
class _FakeEmbeddings:
|
||||
def __init__(self, vector):
|
||||
self.vector = vector
|
||||
self.calls = []
|
||||
|
||||
async def create(self, **kw):
|
||||
self.calls.append(kw)
|
||||
return SimpleNamespace(data=[SimpleNamespace(embedding=self.vector)])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mlx_client_generate_calls_chat_completions():
|
||||
client = LocalMLXClient(base_url="http://localhost:10240/v1")
|
||||
fake_response = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))]
|
||||
)
|
||||
fake_chat = _FakeChatCompletions(fake_response)
|
||||
client._client.chat = SimpleNamespace(completions=fake_chat)
|
||||
|
||||
out = await client.generate(
|
||||
[Message(role="user", content="hi")],
|
||||
model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
|
||||
)
|
||||
|
||||
assert out == "hello"
|
||||
assert len(fake_chat.calls) == 1
|
||||
assert fake_chat.calls[0]["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
|
||||
assert fake_chat.calls[0]["messages"] == [{"role": "user", "content": "hi"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mlx_client_embed_returns_vector():
|
||||
"""``embed()`` actually works on this client (unlike FeatherlessClient
|
||||
which raises NotImplementedError) — the local MLX server has a real
|
||||
``/v1/embeddings`` endpoint backed by an MLX-quantized model.
|
||||
"""
|
||||
client = LocalMLXClient()
|
||||
canned = [0.1, 0.2, 0.3, 0.4]
|
||||
fake_embeddings = _FakeEmbeddings(canned)
|
||||
client._client.embeddings = fake_embeddings
|
||||
|
||||
out = await client.embed("hello", model="mlx-community/bge-small-en-v1.5-bf16")
|
||||
|
||||
assert out == canned
|
||||
assert fake_embeddings.calls[0]["model"] == "mlx-community/bge-small-en-v1.5-bf16"
|
||||
assert fake_embeddings.calls[0]["input"] == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mlx_client_default_base_url():
|
||||
"""Default base_url targets ``mlx-omni-server`` on its standard port."""
|
||||
client = LocalMLXClient()
|
||||
# AsyncOpenAI normalizes trailing-slash differences; just check the
|
||||
# configured host:port appears in the underlying client config.
|
||||
assert "127.0.0.1:10240" in str(client._client.base_url)
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Tests for RoutedLLMClient (Phase 4.5+).
|
||||
|
||||
Splits traffic across two underlying clients based on the ``model``
|
||||
kwarg. We use simple stub clients to assert the router picks the
|
||||
correct backend for each call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncIterator, Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
from chat.llm.client import Message
|
||||
from chat.llm.router import RoutedLLMClient
|
||||
|
||||
|
||||
class _StubClient:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.generate_calls: list[str] = []
|
||||
self.stream_calls: list[str] = []
|
||||
self.embed_calls: list[str] = []
|
||||
|
||||
async def generate(self, messages, *, model, **params) -> str:
|
||||
self.generate_calls.append(model)
|
||||
return f"{self.name}:{model}"
|
||||
|
||||
async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
|
||||
self.stream_calls.append(model)
|
||||
yield f"{self.name}:{model}"
|
||||
|
||||
async def embed(self, text, *, model) -> list[float]:
|
||||
self.embed_calls.append(model)
|
||||
return [1.0, 2.0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_generate_dispatches_narrative_to_narrative_backend():
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model="big-model",
|
||||
)
|
||||
|
||||
out = await router.generate([Message(role="user", content="hi")], model="big-model")
|
||||
|
||||
assert out == "narrative:big-model"
|
||||
assert narrative.generate_calls == ["big-model"]
|
||||
assert local.generate_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_generate_dispatches_classifier_to_local_backend():
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative,
|
||||
local=local,
|
||||
narrative_model="big-model",
|
||||
)
|
||||
|
||||
out = await router.generate(
|
||||
[Message(role="user", content="hi")], model="small-model"
|
||||
)
|
||||
|
||||
assert out == "local:small-model"
|
||||
assert local.generate_calls == ["small-model"]
|
||||
assert narrative.generate_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_stream_dispatches_by_model():
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative, local=local, narrative_model="big-model"
|
||||
)
|
||||
|
||||
chunks_big = [c async for c in router.stream(
|
||||
[Message(role="user", content="hi")], model="big-model"
|
||||
)]
|
||||
chunks_small = [c async for c in router.stream(
|
||||
[Message(role="user", content="hi")], model="other-model"
|
||||
)]
|
||||
|
||||
assert chunks_big == ["narrative:big-model"]
|
||||
assert chunks_small == ["local:other-model"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_embed_always_routes_to_local():
|
||||
"""Embeddings always run locally — the remote provider doesn't
|
||||
expose a working ``/v1/embeddings``, so the router never sends
|
||||
embed calls there even if the model name happens to look 'remote'."""
|
||||
narrative = _StubClient("narrative")
|
||||
local = _StubClient("local")
|
||||
router = RoutedLLMClient(
|
||||
narrative=narrative, local=local, narrative_model="big-model"
|
||||
)
|
||||
|
||||
out = await router.embed("hello", model="any-embedding-model")
|
||||
|
||||
assert out == [1.0, 2.0]
|
||||
assert local.embed_calls == ["any-embedding-model"]
|
||||
assert narrative.embed_calls == []
|
||||
Reference in New Issue
Block a user