diff --git a/chat/llm/__init__.py b/chat/llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/chat/llm/client.py b/chat/llm/client.py
new file mode 100644
--- /dev/null
+++ b/chat/llm/client.py
@@ -0,0 +1,31 @@
+"""Provider-agnostic chat message type and LLM client interface."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import AsyncIterator, Protocol, Sequence
+
+
+@dataclass
+class Message:
+    """A single chat turn passed to an LLM client."""
+
+    role: str  # "system" | "user" | "assistant"
+    content: str
+
+
+class LLMClient(Protocol):
+    """Structural interface that chat-completion backends satisfy."""
+
+    async def generate(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> str:
+        """Return the complete reply to ``messages`` as one string."""
+        ...
+
+    # Deliberately a plain ``def``: implementations are async generators, and
+    # calling one returns the AsyncIterator directly, without an ``await``.
+    def stream(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> AsyncIterator[str]:
+        """Return an async iterator over the reply's text fragments."""
+        ...
diff --git a/chat/llm/featherless.py b/chat/llm/featherless.py
new file mode 100644
--- /dev/null
+++ b/chat/llm/featherless.py
@@ -0,0 +1,60 @@
+"""LLM client for Featherless, accessed through the OpenAI SDK."""
+from __future__ import annotations
+
+from typing import AsyncIterator, Sequence
+
+from openai import AsyncOpenAI
+
+from .client import Message
+
+
+def _to_payload(messages: Sequence[Message]) -> list[dict[str, str]]:
+    """Convert ``Message`` objects into the role/content dicts sent to the SDK."""
+    return [{"role": m.role, "content": m.content} for m in messages]
+
+
+class FeatherlessClient:
+    """Chat client for a Featherless endpoint, built on ``AsyncOpenAI``."""
+
+    def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"):
+        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    async def generate(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> str:
+        """Return the full text of the first completion choice.
+
+        Extra keyword arguments in ``params`` are forwarded unchanged to
+        ``chat.completions.create``. A reply whose content is ``None`` is
+        returned as an empty string.
+        """
+        resp = await self._client.chat.completions.create(
+            model=model,
+            messages=_to_payload(messages),
+            **params,
+        )
+        return resp.choices[0].message.content or ""
+
+    async def stream(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> AsyncIterator[str]:
+        """Yield the non-empty text fragments of a streamed completion.
+
+        Extra keyword arguments in ``params`` are forwarded unchanged to
+        ``chat.completions.create``.
+        """
+        chunks = await self._client.chat.completions.create(
+            model=model,
+            messages=_to_payload(messages),
+            stream=True,
+            **params,
+        )
+        async for chunk in chunks:
+            # OpenAI-compatible streams may emit a chunk with an empty
+            # ``choices`` list (e.g. the usage-only chunk requested through
+            # ``stream_options``); indexing it blindly would raise IndexError.
+            if not chunk.choices:
+                continue
+            delta = chunk.choices[0].delta.content
+            if delta:
+                yield delta
diff --git a/chat/llm/mock.py b/chat/llm/mock.py
new file mode 100644
--- /dev/null
+++ b/chat/llm/mock.py
@@ -0,0 +1,39 @@
+"""In-memory LLM client that replays canned responses."""
+from __future__ import annotations
+
+from collections import deque
+from typing import AsyncIterator, Sequence
+
+from .client import Message
+
+
+class MockLLMClient:
+    """Fake client that hands out pre-recorded replies in FIFO order."""
+
+    def __init__(self, canned: list[str]):
+        # Copy into a deque: the caller's list is left untouched and taking
+        # from the front is O(1) instead of ``list.pop(0)``'s O(n).
+        self._canned = deque(canned)
+
+    def _next_response(self) -> str:
+        """Consume and return the next canned reply.
+
+        Raises:
+            IndexError: if every canned reply has already been consumed.
+        """
+        if not self._canned:
+            raise IndexError("MockLLMClient has no canned responses left")
+        return self._canned.popleft()
+
+    async def generate(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> str:
+        """Return the next canned reply; ``messages`` and ``model`` are unused."""
+        return self._next_response()
+
+    async def stream(
+        self, messages: Sequence[Message], *, model: str, **params
+    ) -> AsyncIterator[str]:
+        """Yield the next canned reply one character at a time."""
+        for ch in self._next_response():
+            yield ch
diff --git a/tests/test_llm_mock.py b/tests/test_llm_mock.py
new file mode 100644
--- /dev/null
+++ b/tests/test_llm_mock.py
@@ -0,0 +1,29 @@
+import pytest
+
+from chat.llm.client import Message
+from chat.llm.mock import MockLLMClient
+
+
+@pytest.mark.asyncio
+async def test_mock_returns_canned_response():
+    client = MockLLMClient(canned=["Hello, world."])
+    msgs = [Message(role="user", content="hi")]
+    out = await client.generate(msgs, model="any")
+    assert out == "Hello, world."
+
+
+@pytest.mark.asyncio
+async def test_mock_streams_tokens():
+    client = MockLLMClient(canned=["abcd"])
+    msgs = [Message(role="user", content="hi")]
+    chunks = [chunk async for chunk in client.stream(msgs, model="any")]
+    assert "".join(chunks) == "abcd"
+
+
+@pytest.mark.asyncio
+async def test_mock_raises_when_canned_responses_run_out():
+    client = MockLLMClient(canned=["only"])
+    msgs = [Message(role="user", content="hi")]
+    assert await client.generate(msgs, model="any") == "only"
+    with pytest.raises(IndexError):
+        await client.generate(msgs, model="any")