feat: LLMClient protocol with Featherless and mock implementations

This commit is contained in:
Joseph Doherty
2026-04-26 11:35:57 -04:00
parent 67517926aa
commit e627356168
5 changed files with 80 additions and 0 deletions
View File
+14
View File
@@ -0,0 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, AsyncIterator, Sequence
@dataclass
class Message:
role: str # "system" | "user" | "assistant"
content: str
class LLMClient(Protocol):
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str: ...
def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]: ...
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
from typing import AsyncIterator, Sequence
from openai import AsyncOpenAI
from .client import Message
class FeatherlessClient:
def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"):
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
resp = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
**params,
)
return resp.choices[0].message.content or ""
async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
stream = await self._client.chat.completions.create(
model=model,
messages=[{"role": m.role, "content": m.content} for m in messages],
stream=True,
**params,
)
async for chunk in stream:
delta = chunk.choices[0].delta.content or ""
if delta:
yield delta
+16
View File
@@ -0,0 +1,16 @@
from __future__ import annotations
from typing import AsyncIterator, Sequence
from .client import Message
class MockLLMClient:
def __init__(self, canned: list[str]):
self._canned = list(canned)
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
return self._canned.pop(0)
async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
text = self._canned.pop(0)
for ch in text:
yield ch
+21
View File
@@ -0,0 +1,21 @@
import pytest
from chat.llm.mock import MockLLMClient
from chat.llm.client import Message
@pytest.mark.asyncio
async def test_mock_returns_canned_response():
client = MockLLMClient(canned=["Hello, world."])
msgs = [Message(role="user", content="hi")]
out = await client.generate(msgs, model="any")
assert out == "Hello, world."
@pytest.mark.asyncio
async def test_mock_streams_tokens():
client = MockLLMClient(canned=["abcd"])
msgs = [Message(role="user", content="hi")]
chunks = []
async for chunk in client.stream(msgs, model="any"):
chunks.append(chunk)
assert "".join(chunks) == "abcd"