"""Tests for LocalMLXClient (Phase 4.5+). Talks to a local mlx-omni-server over the OpenAI-compatible surface. We don't spin up a real server in tests — instead we monkey-patch the underlying ``AsyncOpenAI`` instance to assert on the request shape and return canned responses. The semaphore behavior is shared with FeatherlessClient (same pattern), so we don't re-test that here. """ from __future__ import annotations from types import SimpleNamespace import pytest from chat.llm.client import Message from chat.llm.local_mlx import LocalMLXClient class _FakeChatCompletions: def __init__(self, response): self.response = response self.calls = [] async def create(self, **kw): self.calls.append(kw) return self.response class _FakeEmbeddings: def __init__(self, vector): self.vector = vector self.calls = [] async def create(self, **kw): self.calls.append(kw) return SimpleNamespace(data=[SimpleNamespace(embedding=self.vector)]) @pytest.mark.asyncio async def test_local_mlx_client_generate_calls_chat_completions(): client = LocalMLXClient(base_url="http://localhost:10240/v1") fake_response = SimpleNamespace( choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))] ) fake_chat = _FakeChatCompletions(fake_response) client._client.chat = SimpleNamespace(completions=fake_chat) out = await client.generate( [Message(role="user", content="hi")], model="mlx-community/Hermes-3-Llama-3.1-8B-8bit", ) assert out == "hello" assert len(fake_chat.calls) == 1 assert fake_chat.calls[0]["model"] == "mlx-community/Hermes-3-Llama-3.1-8B-8bit" assert fake_chat.calls[0]["messages"] == [{"role": "user", "content": "hi"}] @pytest.mark.asyncio async def test_local_mlx_client_embed_returns_vector(): """``embed()`` actually works on this client (unlike FeatherlessClient which raises NotImplementedError) — the local MLX server has a real ``/v1/embeddings`` endpoint backed by an MLX-quantized model. """ client = LocalMLXClient() canned = [0.1, 0.2, 0.3, 0.4] fake_embeddings = _FakeEmbeddings(canned) client._client.embeddings = fake_embeddings out = await client.embed("hello", model="mlx-community/bge-small-en-v1.5-bf16") assert out == canned assert fake_embeddings.calls[0]["model"] == "mlx-community/bge-small-en-v1.5-bf16" assert fake_embeddings.calls[0]["input"] == "hello" @pytest.mark.asyncio async def test_local_mlx_client_default_base_url(): """Default base_url targets ``mlx-omni-server`` on its standard port.""" client = LocalMLXClient() # AsyncOpenAI normalizes trailing-slash differences; just check the # configured host:port appears in the underlying client config. assert "127.0.0.1:10240" in str(client._client.base_url)