feat: LLMClient protocol with Featherless and mock implementations
This commit is contained in:
@@ -0,0 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol, AsyncIterator, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Message:
    """One chat message: who is speaking and what they said."""

    # Speaker of the message: "system" | "user" | "assistant".
    role: str
    # Text body of the message.
    content: str
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient(Protocol):
    """Structural interface shared by chat-completion backends.

    Any class providing ``generate`` and ``stream`` with these signatures
    satisfies the protocol; no inheritance is required.
    """

    async def generate(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> str:
        """Return the complete reply to *messages* as a single string.

        Args:
            messages: Conversation so far, in order.
            model: Identifier of the model to query (keyword-only).
            **params: Extra backend-specific options.
        """
        ...

    def stream(
        self, messages: Sequence[Message], *, model: str, **params
    ) -> AsyncIterator[str]:
        """Return an async iterator over fragments of the reply.

        Declared as a plain ``def`` on purpose: implementations written as
        async generators return their iterator directly when called, so
        callers write ``async for part in client.stream(...)`` with no
        ``await``.

        Args:
            messages: Conversation so far, in order.
            model: Identifier of the model to query (keyword-only).
            **params: Extra backend-specific options.
        """
        ...
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import AsyncIterator, Sequence
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from .client import Message
|
||||||
|
|
||||||
|
|
||||||
|
class FeatherlessClient:
    """LLM client that talks to Featherless through the ``AsyncOpenAI`` SDK."""

    def __init__(self, api_key: str, base_url: str = "https://api.featherless.ai/v1"):
        """Create the underlying async client.

        Args:
            api_key: Credential handed to ``AsyncOpenAI``.
            base_url: API root; defaults to the Featherless endpoint.
        """
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    @staticmethod
    def _to_payload(messages: Sequence[Message]) -> list[dict[str, str]]:
        """Convert messages into the role/content dicts sent to the API."""
        return [{"role": m.role, "content": m.content} for m in messages]

    async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
        """Request one completion and return its text.

        Args:
            messages: Conversation so far, in order.
            model: Model identifier forwarded to the API.
            **params: Extra keyword arguments forwarded verbatim to
                ``chat.completions.create``.

        Returns:
            The content of the first choice, or ``""`` when that content
            is empty or ``None``.
        """
        resp = await self._client.chat.completions.create(
            model=model,
            messages=self._to_payload(messages),
            **params,
        )
        return resp.choices[0].message.content or ""

    async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
        """Yield the non-empty text fragments of a streamed completion.

        Args:
            messages: Conversation so far, in order.
            model: Model identifier forwarded to the API.
            **params: Extra keyword arguments forwarded verbatim to
                ``chat.completions.create`` (``stream=True`` is always set
                by this method).

        Yields:
            Each non-empty ``delta.content`` string, in arrival order.
        """
        chunks = await self._client.chat.completions.create(
            model=model,
            messages=self._to_payload(messages),
            stream=True,
            **params,
        )
        async for chunk in chunks:
            # A streamed chunk may carry an empty ``choices`` list (for
            # example a usage-only chunk); indexing it would raise
            # IndexError and abort the stream, so skip it instead.
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            if text:
                yield text
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import AsyncIterator, Sequence
|
||||||
|
from .client import Message
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMClient:
    """In-memory client that replays scripted replies in order.

    Each call to ``generate`` or ``stream`` consumes the next reply; the
    ``messages``, ``model`` and ``params`` arguments are accepted but
    ignored.
    """

    def __init__(self, canned: list[str]):
        """Store the replies to hand out.

        Args:
            canned: Replies in the order they should be returned.
        """
        # Copy so consuming replies never mutates the caller's list.
        self._canned = list(canned)

    def _next_reply(self) -> str:
        """Remove and return the oldest remaining scripted reply.

        Raises:
            IndexError: If every scripted reply has already been consumed.
        """
        if not self._canned:
            # Same exception type as a bare ``pop`` on an empty list, but
            # with a message that points at the real cause.
            raise IndexError("MockLLMClient has no canned responses left")
        return self._canned.pop(0)

    async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
        """Return the next scripted reply in full.

        Raises:
            IndexError: If no scripted replies remain.
        """
        return self._next_reply()

    async def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
        """Yield the next scripted reply one character at a time.

        The reply is consumed when iteration starts, not when this method
        is called.

        Raises:
            IndexError: On first iteration, if no scripted replies remain.
        """
        for ch in self._next_reply():
            yield ch
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
from chat.llm.mock import MockLLMClient
|
||||||
|
from chat.llm.client import Message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mock_returns_canned_response():
    """``generate`` returns the scripted reply verbatim."""
    client = MockLLMClient(canned=["Hello, world."])
    msgs = [Message(role="user", content="hi")]

    out = await client.generate(msgs, model="any")

    assert out == "Hello, world."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mock_streams_tokens():
    """Fragments yielded by ``stream`` join back into the scripted reply."""
    client = MockLLMClient(canned=["abcd"])
    msgs = [Message(role="user", content="hi")]

    # ``stream`` is iterated directly; the call itself is not awaited.
    chunks = [chunk async for chunk in client.stream(msgs, model="any")]

    assert "".join(chunks) == "abcd"
|
||||||
Reference in New Issue
Block a user