diff --git a/chat/db/migrations/0002_classifier_failures.sql b/chat/db/migrations/0002_classifier_failures.sql
new file mode 100644
index 0000000..7bfa23a
--- /dev/null
+++ b/chat/db/migrations/0002_classifier_failures.sql
@@ -0,0 +1,8 @@
+CREATE TABLE classifier_failures (
+    id INTEGER PRIMARY KEY,
+    kind TEXT NOT NULL,
+    model TEXT NOT NULL,
+    raw_text TEXT,
+    attempt_count INTEGER NOT NULL,
+    created_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
diff --git a/chat/llm/classify.py b/chat/llm/classify.py
new file mode 100644
index 0000000..66517c6
--- /dev/null
+++ b/chat/llm/classify.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+import json
+import asyncio
+import logging
+from typing import TypeVar
+from pydantic import BaseModel, ValidationError
+from .client import LLMClient, Message
+
+T = TypeVar("T", bound=BaseModel)
+
+REFUSAL_PATTERNS = ("i can't", "i cannot", "i'm sorry, but", "as an ai")
+
+# Instruction appended to the caller's system prompt: first attempt vs. retries.
+_FIRST_PROMPT_SUFFIX = "\n\nRespond with JSON only matching the schema."
+_RETRY_PROMPT_SUFFIX = "\n\nRespond with valid JSON ONLY. No prose."
+
+logger = logging.getLogger(__name__)
+
+
+def _looks_like_refusal(text: str) -> bool:
+    """Return True when *text* is not a JSON object and opens with a refusal phrase."""
+    if text.strip().startswith("{"):
+        return False
+    # Only the first 80 characters are checked for a refusal phrase.
+    head = text[:80].lower()
+    return any(pattern in head for pattern in REFUSAL_PATTERNS)
+
+
+async def classify(
+    client: LLMClient,
+    *,
+    model: str,
+    system: str,
+    user: str,
+    schema: type[T],
+    default: T | None = None,
+    timeout_s: float = 10.0,
+    max_attempts: int = 2,
+) -> T:
+    """Ask the model for JSON and validate it against *schema*.
+
+    The first attempt appends a JSON-only instruction to *system*; every
+    later attempt appends a stricter one instead. Each attempt is bounded
+    by *timeout_s*.
+
+    Args:
+        client: LLM client; ``client.generate`` is awaited for the raw text.
+        model: Model identifier forwarded to ``client.generate``.
+        system: Base system prompt; the JSON instruction is appended to it.
+        user: Content of the user message.
+        schema: Pydantic model the response text must validate against.
+        default: Returned when every attempt fails. ``None`` means there is
+            no fallback, so exhausting the attempts raises instead.
+        timeout_s: Per-attempt timeout in seconds.
+        max_attempts: Total number of attempts, including the first.
+
+    Returns:
+        The validated *schema* instance, or *default* once attempts run out.
+
+    Raises:
+        ValueError: If *max_attempts* is less than 1.
+        RuntimeError: If every attempt fails and *default* is ``None``; the
+            last attempt's error is chained as the cause.
+    """
+    if max_attempts < 1:
+        raise ValueError("max_attempts must be at least 1")
+
+    user_msg = Message(role="user", content=user)
+    last_error: Exception | None = None
+    for attempt in range(max_attempts):
+        suffix = _FIRST_PROMPT_SUFFIX if attempt == 0 else _RETRY_PROMPT_SUFFIX
+        msgs = [Message(role="system", content=system + suffix), user_msg]
+        try:
+            text = await asyncio.wait_for(
+                client.generate(msgs, model=model, response_format={"type": "json_object"}),
+                timeout=timeout_s,
+            )
+            if _looks_like_refusal(text):
+                raise ValueError("refusal-shaped response")
+            return schema.model_validate_json(text)
+        except (ValidationError, ValueError, json.JSONDecodeError, asyncio.TimeoutError) as err:
+            last_error = err
+            # Log the error type only; the exception text may embed raw model output.
+            logger.warning(
+                "classify attempt %d/%d failed for schema %s: %s",
+                attempt + 1,
+                max_attempts,
+                schema.__name__,
+                type(err).__name__,
+            )
+    # NOTE(review): migration 0002 adds a classifier_failures table, but nothing
+    # here writes to it -- confirm the recorder lands separately.
+    if default is None:
+        raise RuntimeError(
+            f"classify failed for schema {schema.__name__} with no default"
+        ) from last_error
+    return default
diff --git a/tests/test_classify.py b/tests/test_classify.py
new file mode 100644
index 0000000..d059eaa
--- /dev/null
+++ b/tests/test_classify.py
@@ -0,0 +1,31 @@
+import pytest
+from pydantic import BaseModel
+from chat.llm.mock import MockLLMClient
+from chat.llm.classify import classify
+
+
+class Verdict(BaseModel):
+    score: int
+    reason: str
+
+
+@pytest.mark.asyncio
+async def test_classify_parses_valid_json():
+    mock = MockLLMClient(canned=['{"score": 2, "reason": "notable"}'])
+    result = await classify(mock, model="m", system="x", user="y", schema=Verdict)
+    assert result.score == 2
+
+
+@pytest.mark.asyncio
+async def test_classify_falls_back_on_unparseable_after_retry():
+    mock = MockLLMClient(canned=["nope", "still nope"])
+    default = Verdict(score=1, reason="fallback")
+    result = await classify(mock, model="m", system="x", user="y", schema=Verdict, default=default)
+    assert result.reason == "fallback"
+
+
+@pytest.mark.asyncio
+async def test_classify_raises_without_default():
+    mock = MockLLMClient(canned=["nope", "still nope"])
+    with pytest.raises(RuntimeError):
+        await classify(mock, model="m", system="x", user="y", schema=Verdict)