fe9c497038
Adds RoutedLLMClient that dispatches by model name: requests matching Settings.narrative_model go to Featherless, everything else (classifier calls, embed) goes to a local MLX server. The local server is mlx-omni-server (separate venv at .mlx-venv) and exposes the standard OpenAI surface at http://127.0.0.1:10240/v1. LocalMLXClient mirrors FeatherlessClient (AsyncOpenAI under the hood) but with a working embed() — Featherless's /v1/embeddings always returns 500 with completions_error, so the router unconditionally sends embed traffic to the local backend. Production deployment overrides via data/config.toml: - classifier_model = mlx-community/Hermes-3-Llama-3.1-8B-8bit (~8 GB) - embedding_model = mlx-community/bge-small-en-v1.5-bf16 (~150 MB, 384 dim — matches existing schema, no migration) Defaults stay remote / pseudo so fresh installs and tests need no external infra. Smoke-tested live: classifier returns expected output, BGE produces correctly-clustering 384-dim vectors (cat-on-mat closer to cat-on-rug than to quantum-mechanics). scripts/start_mlx_server.sh starts the daemon (foreground or --daemon). .mlx-venv/ added to .gitignore. Suite: 464 passed (was 457 → +7 new across LocalMLXClient + Router).
85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
"""Tests for LocalMLXClient (Phase 4.5+).
|
|
|
|
Talks to a local mlx-omni-server over the OpenAI-compatible surface.
|
|
We don't spin up a real server in tests — instead we monkey-patch the
|
|
underlying ``AsyncOpenAI`` instance to assert on the request shape and
|
|
return canned responses. The semaphore behavior is shared with
|
|
FeatherlessClient (same pattern), so we don't re-test that here.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from chat.llm.client import Message
|
|
from chat.llm.local_mlx import LocalMLXClient
|
|
|
|
|
|
class _FakeChatCompletions:
    """Recording stand-in for a ``chat.completions`` resource.

    Every ``create`` call has its keyword arguments appended to ``calls``
    and receives the canned ``response`` supplied at construction time.
    """

    def __init__(self, response):
        self.calls = []
        self.response = response

    async def create(self, **request_kwargs):
        """Record ``request_kwargs`` and return the canned response."""
        recorded = self.calls
        recorded.append(request_kwargs)
        return self.response
|
|
|
|
|
|
class _FakeEmbeddings:
    """Recording stand-in for an ``embeddings`` resource.

    Every ``create`` call has its keyword arguments appended to ``calls``
    and receives a response object whose ``data`` list holds a single item
    exposing the canned ``vector`` as ``embedding``.
    """

    def __init__(self, vector):
        self.calls = []
        self.vector = vector

    async def create(self, **request_kwargs):
        """Record ``request_kwargs`` and return the canned embedding response."""
        self.calls.append(request_kwargs)
        item = SimpleNamespace(embedding=self.vector)
        return SimpleNamespace(data=[item])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_local_mlx_client_generate_calls_chat_completions():
    """``generate()`` issues one chat-completions request and returns its text.

    The client's ``chat`` attribute is swapped for a recording fake, so the
    test can check the ``model`` and ``messages`` keyword arguments that
    reach ``create`` without a running server.
    """
    model_name = "mlx-community/Hermes-3-Llama-3.1-8B-8bit"

    # Canned reply shaped like a chat-completions response: choices -> message -> content.
    reply_message = SimpleNamespace(content="hello")
    canned_reply = SimpleNamespace(choices=[SimpleNamespace(message=reply_message)])
    completions_stub = _FakeChatCompletions(canned_reply)

    client = LocalMLXClient(base_url="http://localhost:10240/v1")
    client._client.chat = SimpleNamespace(completions=completions_stub)

    prompt = [Message(role="user", content="hi")]
    out = await client.generate(prompt, model=model_name)

    assert out == "hello"
    assert len(completions_stub.calls) == 1
    request = completions_stub.calls[0]
    assert request["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit"
    assert request["messages"] == [{"role": "user", "content": "hi"}]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_local_mlx_client_embed_returns_vector():
    """``embed()`` returns the vector produced by the embeddings endpoint.

    Unlike FeatherlessClient (which raises NotImplementedError), this client
    has a working ``embed()``: the local MLX server exposes a real
    ``/v1/embeddings`` endpoint backed by an MLX-quantized model. Here the
    endpoint is replaced by a recording fake.
    """
    embedding_model = "mlx-community/bge-small-en-v1.5-bf16"
    expected_vector = [0.1, 0.2, 0.3, 0.4]
    embeddings_stub = _FakeEmbeddings(expected_vector)

    client = LocalMLXClient()
    client._client.embeddings = embeddings_stub

    out = await client.embed("hello", model=embedding_model)

    assert out == expected_vector
    request = embeddings_stub.calls[0]
    assert request["model"] == "mlx-community/bge-small-en-v1.5-bf16"
    assert request["input"] == "hello"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_local_mlx_client_default_base_url():
    """With no arguments, the client targets ``mlx-omni-server`` on its standard port."""
    # Substring match rather than equality: AsyncOpenAI normalizes
    # trailing-slash differences, so only the configured host:port is checked.
    configured_url = str(LocalMLXClient()._client.base_url)
    assert "127.0.0.1:10240" in configured_url
|