Files
chat/tests/test_local_mlx_client.py
Joseph Doherty fe9c497038 feat: split classifier + embeddings to local mlx-omni-server, narrative stays on Featherless
Adds RoutedLLMClient that dispatches by model name: requests matching
Settings.narrative_model go to Featherless, everything else (classifier
calls, embed) goes to a local MLX server. The local server is
mlx-omni-server (separate venv at .mlx-venv) and exposes the standard
OpenAI surface at http://127.0.0.1:10240/v1.

LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood)
but with a working embed() — Featherless's /v1/embeddings always
returns 500 with completions_error, so the router unconditionally
sends embed traffic to the local backend.

Production deployment overrides via data/config.toml:
- classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB)
- embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB,
  384 dim — matches existing schema, no migration)

Defaults stay remote / pseudo so fresh installs and tests need no
external infra. Smoke-tested live: classifier returns expected output,
BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer
to cat-on-rug than to quantum-mechanics).

scripts/start_mlx_server.sh starts the daemon (foreground or --daemon).
.mlx-venv/ added to .gitignore.

Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
2026-04-27 12:05:41 -04:00

85 lines
2.8 KiB
Python

"""Tests for LocalMLXClient (Phase 4.5+).
Talks to a local mlx-omni-server over the OpenAI-compatible surface.
We don't spin up a real server in tests — instead we monkey-patch the
underlying ``AsyncOpenAI`` instance to assert on the request shape and
return canned responses. The semaphore behavior is shared with
FeatherlessClient (same pattern), so we don't re-test that here.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from chat.llm.client import Message
from chat.llm.local_mlx import LocalMLXClient
class _FakeChatCompletions:
    """Stand-in for the ``chat.completions`` endpoint object.

    Records the keyword arguments of every ``create()`` call and always
    answers with the response supplied at construction time.
    """

    def __init__(self, response):
        # ``calls`` is inspected by the tests to assert on request shape.
        self.calls = []
        self.response = response

    async def create(self, **request_kwargs):
        """Log the request kwargs and return the canned response."""
        self.calls.append(request_kwargs)
        return self.response
class _FakeEmbeddings:
    """Stand-in for the ``embeddings`` endpoint object.

    Records the keyword arguments of every ``create()`` call and answers
    with a response object whose ``data[0].embedding`` is the vector
    supplied at construction time.
    """

    def __init__(self, vector):
        # ``calls`` is inspected by the tests to assert on request shape.
        self.calls = []
        self.vector = vector

    async def create(self, **request_kwargs):
        """Log the request kwargs and return a one-item embedding response."""
        self.calls.append(request_kwargs)
        item = SimpleNamespace(embedding=self.vector)
        return SimpleNamespace(data=[item])
@pytest.mark.asyncio
async def test_local_mlx_client_generate_calls_chat_completions():
    """``generate()`` sends one chat-completions request and returns the
    canned message content.
    """
    model_name = "mlx-community/Hermes-3-Llama-3.1-8B-8bit"

    # Arrange: swap the SDK's chat endpoint for a recording fake.
    canned_message = SimpleNamespace(content="hello")
    completions = _FakeChatCompletions(
        SimpleNamespace(choices=[SimpleNamespace(message=canned_message)])
    )
    mlx = LocalMLXClient(base_url="http://localhost:10240/v1")
    mlx._client.chat = SimpleNamespace(completions=completions)

    # Act
    reply = await mlx.generate(
        [Message(role="user", content="hi")],
        model=model_name,
    )

    # Assert: text comes back unchanged and exactly one request was made
    # with the model name and messages serialized as plain dicts.
    assert reply == "hello"
    assert len(completions.calls) == 1
    request = completions.calls[0]
    assert request["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    assert request["messages"] == [{"role": "user", "content": "hi"}]
@pytest.mark.asyncio
async def test_local_mlx_client_embed_returns_vector():
    """``embed()`` actually works on this client (unlike FeatherlessClient
    which raises NotImplementedError) — the local MLX server has a real
    ``/v1/embeddings`` endpoint backed by an MLX-quantized model.
    """
    client = LocalMLXClient()
    canned = [0.1, 0.2, 0.3, 0.4]
    # Replace the SDK's embeddings endpoint with a recording fake so no
    # server is needed.
    fake_embeddings = _FakeEmbeddings(canned)
    client._client.embeddings = fake_embeddings

    out = await client.embed("hello", model="mlx-community/bge-small-en-v1.5-bf16")

    assert out == canned
    # Pin the call count (as the generate test does) so an accidental
    # duplicate request does not slip past the ``calls[0]`` checks below.
    assert len(fake_embeddings.calls) == 1
    assert fake_embeddings.calls[0]["model"] == "mlx-community/bge-small-en-v1.5-bf16"
    assert fake_embeddings.calls[0]["input"] == "hello"
@pytest.mark.asyncio
async def test_local_mlx_client_default_base_url():
    """Default base_url targets ``mlx-omni-server`` on its standard port."""
    mlx = LocalMLXClient()
    # The underlying client may normalize trailing slashes, so match on
    # the host:port substring rather than the exact URL.
    configured_url = str(mlx._client.base_url)
    assert "127.0.0.1:10240" in configured_url