diff --git a/.gitignore b/.gitignore
index 687ab3d..a12360c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ data/
 
 # Python
 .venv/
+.mlx-venv/
 __pycache__/
 *.pyc
 .pytest_cache/
diff --git a/chat/app.py b/chat/app.py
index 7241cd0..3e7fcb9 100644
--- a/chat/app.py
+++ b/chat/app.py
@@ -72,17 +72,28 @@ async def lifespan(app: FastAPI):
     # (free / lower paid tiers cap at 2). Shared across all
     # FeatherlessClient instances in the process.
     from chat.llm.featherless import FeatherlessClient
+    from chat.llm.local_mlx import LocalMLXClient
+    from chat.llm.router import RoutedLLMClient
 
     FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)
+    LocalMLXClient.configure_concurrency(settings.local_mlx_max_concurrent)
 
-    # Background worker for the async significance pass (T22). Each job
-    # constructs a fresh FeatherlessClient via the factory; tests can
-    # disable enqueue by toggling ``app.state.background_worker.enabled``.
+    # Background workers (significance scoring, embedding indexer)
+    # construct a fresh client per job via the factory. Workers route
+    # through the same RoutedLLMClient as request-time traffic so the
+    # narrative model still goes to Featherless and the classifier +
+    # embeddings hit the local MLX server.
     def _factory():
-        return FeatherlessClient(
+        narrative = FeatherlessClient(
             api_key=settings.featherless_api_key,
             base_url=settings.featherless_base_url,
         )
+        local = LocalMLXClient(base_url=settings.local_mlx_base_url)
+        return RoutedLLMClient(
+            narrative=narrative,
+            local=local,
+            narrative_model=settings.narrative_model,
+        )
 
     worker = BackgroundWorker(settings, llm_client_factory=_factory)
     await worker.start()
diff --git a/chat/config.py b/chat/config.py
index d10dea4..fa08803 100644
--- a/chat/config.py
+++ b/chat/config.py
@@ -39,13 +39,22 @@ class Settings(BaseModel):
     data_dir: Path = REPO_ROOT / "data"
     bind_host: str = "127.0.0.1"
     bind_port: int = 8000
+    # Local MLX server (e.g. ``mlx-omni-server``) — serves any model
+    # whose id is NOT ``narrative_model``. The :class:`RoutedLLMClient`
+    # inspects the ``model`` kwarg at call time: ``model == narrative_model``
+    # -> Featherless, otherwise local. ``embed()`` always routes local.
+    # If no MLX server is running and ``classifier_model`` is set to a
+    # local model id, classifier calls will fail; ``embed()`` will fall
+    # back to the zero-vector path with a T107 warning. To use the
+    # default Featherless-only deployment, leave ``classifier_model``
+    # at its remote default and ``embedding_model`` at the pseudo.
+    local_mlx_base_url: str = "http://127.0.0.1:10240/v1"
+    local_mlx_max_concurrent: int = 1
     # T112 (Phase 4.5): embedding model identifier. Default is the
-    # deterministic local pseudo (semantically meaningless but keeps the
-    # vector pipeline structurally valid). Swap to a real model name
-    # (e.g. "bge-small-en-v1.5") once the LLMClient implementation
-    # supports embed() — currently FeatherlessClient does NOT, so a
-    # non-default value will trigger the zero-vector fallback path
-    # plus a T107 warning until a different provider is wired in.
+    # deterministic local pseudo so fresh installs / tests don't need
+    # any external infra. Override via config.toml to a real model id
+    # (e.g. ``"mlx-community/bge-small-en-v1.5-bf16"``) once a local
+    # MLX server is running.
     embedding_model: str = "pseudo-sha256-384"
 
 def load_settings() -> Settings:
diff --git a/chat/llm/local_mlx.py b/chat/llm/local_mlx.py
new file mode 100644
index 0000000..ed22bb6
--- /dev/null
+++ b/chat/llm/local_mlx.py
@@ -0,0 +1,102 @@
+"""Local MLX OpenAI-compatible client.
+
+Talks to a locally-running MLX server (e.g., ``mlx-omni-server``) over
+the same OpenAI surface that :class:`chat.llm.featherless.FeatherlessClient`
+uses, via :class:`openai.AsyncOpenAI`. The underlying server runs MLX
+models on Apple Silicon (M-series) for chat completions AND embeddings.
+
+Use cases (Phase 4.5+):
+- Classifier traffic moved off Featherless to local MLX (cost + latency).
+- Embeddings via ``client.embed`` actually work — Featherless's
+  ``/v1/embeddings`` always returns 500.
+
+Constructor takes a ``base_url`` (e.g., ``"http://127.0.0.1:10240/v1"``)
+and an optional ``api_key`` (most local MLX servers don't authenticate;
+the OpenAI SDK requires *some* string, so we default to a placeholder).
+"""
+
+from __future__ import annotations
+import asyncio
+from typing import AsyncIterator, Sequence
+
+from openai import AsyncOpenAI
+
+from .client import Message
+
+
+class LocalMLXClient:
+    """OpenAI-compatible client for a local MLX server.
+
+    The server is single-process by default (``mlx-omni-server`` loads
+    one model at a time and swaps on demand). The class-level semaphore
+    serializes concurrent requests so we never queue more than
+    ``max_concurrent`` at a time — defaults to 1, since MLX inference
+    on a single M-series device is sequential anyway.
+    """
+
+    _semaphore: asyncio.Semaphore | None = None
+
+    @classmethod
+    def configure_concurrency(cls, max_concurrent: int) -> None:
+        cls._semaphore = asyncio.Semaphore(max(1, int(max_concurrent)))
+
+    @classmethod
+    def _sem(cls) -> asyncio.Semaphore:
+        if cls._semaphore is None:
+            cls._semaphore = asyncio.Semaphore(1)
+        return cls._semaphore
+
+    def __init__(
+        self,
+        base_url: str = "http://127.0.0.1:10240/v1",
+        api_key: str = "not-needed",
+    ):
+        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    async def generate(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> str:
+        """Return the completion text for ``messages`` (``""`` when empty)."""
+        async with self._sem():
+            resp = await self._client.chat.completions.create(
+                model=model,
+                messages=[{"role": m.role, "content": m.content} for m in messages],
+                **params,
+            )
+        return resp.choices[0].message.content or ""
+
+    async def stream(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> AsyncIterator[str]:
+        """Yield non-empty content deltas from a streamed chat completion."""
+        async with self._sem():
+            stream = await self._client.chat.completions.create(
+                model=model,
+                messages=[{"role": m.role, "content": m.content} for m in messages],
+                stream=True,
+                **params,
+            )
+            async for chunk in stream:
+                # OpenAI-compatible servers may emit chunks with an empty
+                # ``choices`` list (e.g. the trailing usage chunk); skip
+                # them instead of raising IndexError mid-stream.
+                if not chunk.choices:
+                    continue
+                delta = chunk.choices[0].delta.content or ""
+                if delta:
+                    yield delta
+
+    async def embed(self, text: str, *, model: str) -> list[float]:
+        """Return an embedding vector for ``text`` using the named model.
+
+        Targets ``/v1/embeddings`` on the local MLX server; the server
+        loads the model on first request and caches it. The embedding
+        model is independent of the chat model loaded for ``generate``
+        / ``stream`` (the server can serve both).
+        """
+        async with self._sem():
+            resp = await self._client.embeddings.create(
+                model=model,
+                input=text,
+            )
+        return list(resp.data[0].embedding)
diff --git a/chat/llm/router.py b/chat/llm/router.py
new file mode 100644
index 0000000..b113a66
--- /dev/null
+++ b/chat/llm/router.py
@@ -0,0 +1,59 @@
+"""Routed LLM client — splits traffic across multiple backends by model.
+
+Phase 4.5+ deployment: the 24B narrative model stays on Featherless,
+the 8B classifier model moves to local MLX, and embeddings run on a
+local BGE/MLX model. One :class:`LLMClient` interface, two underlying
+backends, dispatched by the ``model`` argument at every call site.
+
+Routing rule: requests whose ``model`` argument matches the configured
+``narrative_model`` go to the narrative backend; everything else
+(classifier, embeddings, future locally-hosted models) goes to the
+local backend.
+"""
+
+from __future__ import annotations
+from typing import AsyncIterator, Sequence
+
+from .client import LLMClient, Message
+
+
+class RoutedLLMClient:
+    """Delegates to one of two underlying clients based on ``model``.
+
+    The narrative client is exercised only when ``model`` exactly equals
+    ``narrative_model`` (the configured remote model id). Everything
+    else — classifier traffic, embeddings, any future locally-hosted
+    model — routes to the local client. ``embed`` always routes locally
+    (the remote provider doesn't support it; see
+    :class:`chat.llm.featherless.FeatherlessClient.embed`).
+    """
+
+    def __init__(
+        self,
+        *,
+        narrative: LLMClient,
+        local: LLMClient,
+        narrative_model: str,
+    ) -> None:
+        self._narrative = narrative
+        self._local = local
+        self._narrative_model = narrative_model
+
+    def _pick(self, model: str) -> LLMClient:
+        return self._narrative if model == self._narrative_model else self._local
+
+    async def generate(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> str:
+        return await self._pick(model).generate(messages, model=model, **params)
+
+    async def stream(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> AsyncIterator[str]:
+        async for chunk in self._pick(model).stream(messages, model=model, **params):
+            yield chunk
+
+    async def embed(self, text: str, *, model: str) -> list[float]:
+        # Embeddings always run on the local backend — the remote
+        # provider doesn't expose a working ``/v1/embeddings`` endpoint.
+        return await self._local.embed(text, model=model)
diff --git a/chat/web/kickoff.py b/chat/web/kickoff.py
index 320b771..b214147 100644
--- a/chat/web/kickoff.py
+++ b/chat/web/kickoff.py
@@ -32,14 +32,30 @@ router = APIRouter()
 
 
 def get_llm_client(request: Request) -> LLMClient:
-    """Production LLM client. Tests override this via ``app.dependency_overrides``."""
+    """Production LLM client. Tests override this via ``app.dependency_overrides``.
+
+    Returns a :class:`chat.llm.router.RoutedLLMClient` that splits
+    traffic: the narrative model goes to Featherless, the classifier +
+    embeddings go to the local MLX server (``mlx-omni-server``).
+    Both backends share the OpenAI-compatible surface, so the routing
+    is invisible to call sites — they just pass ``model=...`` and the
+    router picks the backend.
+    """
     settings = request.app.state.settings
     from chat.llm.featherless import FeatherlessClient
+    from chat.llm.local_mlx import LocalMLXClient
+    from chat.llm.router import RoutedLLMClient
 
-    return FeatherlessClient(
+    narrative = FeatherlessClient(
         api_key=settings.featherless_api_key,
         base_url=settings.featherless_base_url,
     )
+    local = LocalMLXClient(base_url=settings.local_mlx_base_url)
+    return RoutedLLMClient(
+        narrative=narrative,
+        local=local,
+        narrative_model=settings.narrative_model,
+    )
 
 
 def _parse_holding(text: str) -> list[str]:
diff --git a/scripts/start_mlx_server.sh b/scripts/start_mlx_server.sh
new file mode 100755
index 0000000..3615c94
--- /dev/null
+++ b/scripts/start_mlx_server.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+# Start the local mlx-omni-server that serves the classifier + embedding
+# models. The chat app's RoutedLLMClient routes everything except the
+# narrative model to this server; with no MLX server running, classifier
+# calls fail and embeddings degrade to the zero-vector fallback.
+#
+# Run in the foreground:
+#   ./scripts/start_mlx_server.sh
+# Run as a background daemon (logs to data/mlx-server.log):
+#   ./scripts/start_mlx_server.sh --daemon
+#
+# Models are pulled from Hugging Face on first request; expect a delay
+# the first time you exercise the classifier or embedding path.
+
+set -euo pipefail
+
+REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
+VENV="${REPO_ROOT}/.mlx-venv"
+LOG="${REPO_ROOT}/data/mlx-server.log"
+PORT="${MLX_PORT:-10240}"
+HOST="${MLX_HOST:-127.0.0.1}"
+
+if [ ! -x "${VENV}/bin/mlx-omni-server" ]; then
+    echo "error: mlx-omni-server not installed in ${VENV}" >&2
+    echo "create the venv with:" >&2
+    echo " python3.12 -m venv ${VENV} && ${VENV}/bin/pip install mlx-omni-server" >&2
+    exit 1
+fi
+
+if [ "${1:-}" = "--daemon" ]; then
+    mkdir -p "$(dirname "${LOG}")"
+    nohup "${VENV}/bin/mlx-omni-server" --host "${HOST}" --port "${PORT}" \
+        >>"${LOG}" 2>&1 &
+    echo "mlx-omni-server started in background (pid $!)"
+    echo "logs: ${LOG}"
+else
+    exec "${VENV}/bin/mlx-omni-server" --host "${HOST}" --port "${PORT}"
+fi
diff --git a/tests/test_local_mlx_client.py b/tests/test_local_mlx_client.py
new file mode 100644
index 0000000..7893875
--- /dev/null
+++ b/tests/test_local_mlx_client.py
@@ -0,0 +1,84 @@
+"""Tests for LocalMLXClient (Phase 4.5+).
+
+Talks to a local mlx-omni-server over the OpenAI-compatible surface.
+We don't spin up a real server in tests — instead we monkey-patch the
+underlying ``AsyncOpenAI`` instance to assert on the request shape and
+return canned responses. The semaphore behavior is shared with
+FeatherlessClient (same pattern), so we don't re-test that here.
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+from chat.llm.client import Message
+from chat.llm.local_mlx import LocalMLXClient
+
+
+class _FakeChatCompletions:
+    def __init__(self, response):
+        self.response = response
+        self.calls = []
+
+    async def create(self, **kw):
+        self.calls.append(kw)
+        return self.response
+
+
+class _FakeEmbeddings:
+    def __init__(self, vector):
+        self.vector = vector
+        self.calls = []
+
+    async def create(self, **kw):
+        self.calls.append(kw)
+        return SimpleNamespace(data=[SimpleNamespace(embedding=self.vector)])
+
+
+@pytest.mark.asyncio
+async def test_local_mlx_client_generate_calls_chat_completions():
+    client = LocalMLXClient(base_url="http://localhost:10240/v1")
+    fake_response = SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))]
+    )
+    fake_chat = _FakeChatCompletions(fake_response)
+    client._client.chat = SimpleNamespace(completions=fake_chat)
+
+    out = await client.generate(
+        [Message(role="user", content="hi")],
+        model="mlx-community/Hermes-3-Llama-3.1-8B-8bit",
+    )
+
+    assert out == "hello"
+    assert len(fake_chat.calls) == 1
+    assert fake_chat.calls[0]["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
+    assert fake_chat.calls[0]["messages"] == [{"role": "user", "content": "hi"}]
+
+
+@pytest.mark.asyncio
+async def test_local_mlx_client_embed_returns_vector():
+    """``embed()`` actually works on this client (unlike FeatherlessClient
+    which raises NotImplementedError) — the local MLX server has a real
+    ``/v1/embeddings`` endpoint backed by an MLX-quantized model.
+    """
+    client = LocalMLXClient()
+    canned = [0.1, 0.2, 0.3, 0.4]
+    fake_embeddings = _FakeEmbeddings(canned)
+    client._client.embeddings = fake_embeddings
+
+    out = await client.embed("hello", model="mlx-community/bge-small-en-v1.5-bf16")
+
+    assert out == canned
+    assert fake_embeddings.calls[0]["model"] == "mlx-community/bge-small-en-v1.5-bf16"
+    assert fake_embeddings.calls[0]["input"] == "hello"
+
+
+@pytest.mark.asyncio
+async def test_local_mlx_client_default_base_url():
+    """Default base_url targets ``mlx-omni-server`` on its standard port."""
+    client = LocalMLXClient()
+    # AsyncOpenAI normalizes trailing-slash differences; just check the
+    # configured host:port appears in the underlying client config.
+    assert "127.0.0.1:10240" in str(client._client.base_url)
diff --git a/tests/test_router_client.py b/tests/test_router_client.py
new file mode 100644
index 0000000..3ff9c45
--- /dev/null
+++ b/tests/test_router_client.py
@@ -0,0 +1,108 @@
+"""Tests for RoutedLLMClient (Phase 4.5+).
+
+Splits traffic across two underlying clients based on the ``model``
+kwarg. We use simple stub clients to assert the router picks the
+correct backend for each call.
+"""
+
+from __future__ import annotations
+
+from typing import AsyncIterator, Sequence
+
+import pytest
+
+from chat.llm.client import Message
+from chat.llm.router import RoutedLLMClient
+
+
+class _StubClient:
+    def __init__(self, name: str):
+        self.name = name
+        self.generate_calls: list[str] = []
+        self.stream_calls: list[str] = []
+        self.embed_calls: list[str] = []
+
+    async def generate(self, messages, *, model, **params) -> str:
+        self.generate_calls.append(model)
+        return f"{self.name}:{model}"
+
+    async def stream(self, messages, *, model, **params) -> AsyncIterator[str]:
+        self.stream_calls.append(model)
+        yield f"{self.name}:{model}"
+
+    async def embed(self, text, *, model) -> list[float]:
+        self.embed_calls.append(model)
+        return [1.0, 2.0]
+
+
+@pytest.mark.asyncio
+async def test_router_generate_dispatches_narrative_to_narrative_backend():
+    narrative = _StubClient("narrative")
+    local = _StubClient("local")
+    router = RoutedLLMClient(
+        narrative=narrative,
+        local=local,
+        narrative_model="big-model",
+    )
+
+    out = await router.generate([Message(role="user", content="hi")], model="big-model")
+
+    assert out == "narrative:big-model"
+    assert narrative.generate_calls == ["big-model"]
+    assert local.generate_calls == []
+
+
+@pytest.mark.asyncio
+async def test_router_generate_dispatches_classifier_to_local_backend():
+    narrative = _StubClient("narrative")
+    local = _StubClient("local")
+    router = RoutedLLMClient(
+        narrative=narrative,
+        local=local,
+        narrative_model="big-model",
+    )
+
+    out = await router.generate(
+        [Message(role="user", content="hi")], model="small-model"
+    )
+
+    assert out == "local:small-model"
+    assert local.generate_calls == ["small-model"]
+    assert narrative.generate_calls == []
+
+
+@pytest.mark.asyncio
+async def test_router_stream_dispatches_by_model():
+    narrative = _StubClient("narrative")
+    local = _StubClient("local")
+    router = RoutedLLMClient(
+        narrative=narrative, local=local, narrative_model="big-model"
+    )
+
+    chunks_big = [c async for c in router.stream(
+        [Message(role="user", content="hi")], model="big-model"
+    )]
+    chunks_small = [c async for c in router.stream(
+        [Message(role="user", content="hi")], model="other-model"
+    )]
+
+    assert chunks_big == ["narrative:big-model"]
+    assert chunks_small == ["local:other-model"]
+
+
+@pytest.mark.asyncio
+async def test_router_embed_always_routes_to_local():
+    """Embeddings always run locally — the remote provider doesn't
+    expose a working ``/v1/embeddings``, so the router never sends
+    embed calls there even if the model name happens to look 'remote'."""
+    narrative = _StubClient("narrative")
+    local = _StubClient("local")
+    router = RoutedLLMClient(
+        narrative=narrative, local=local, narrative_model="big-model"
+    )
+
+    out = await router.embed("hello", model="any-embedding-model")
+
+    assert out == [1.0, 2.0]
+    assert local.embed_calls == ["any-embedding-model"]
+    assert narrative.embed_calls == []