Files
chat/chat/llm/classify.py
T
Joseph Doherty 49be3cf4b9 fix: parse_turn falls back gracefully + classify logs flapping classifiers
The turn endpoint was 500ing in multi-bot scenes whenever the
classifier provider hiccuped on parse_turn — particularly visible
after a guest was added and bots started exchanging turns. The
traceback ended in RuntimeError('classify failed for schema ParsedTurn
with no default') because parse_turn was the only classify caller that
did not pass a default.

Two changes:

- chat/services/turn_parse.py: parse_turn now passes a default that
  wraps the whole prose as one 'dialogue' segment. The narrative
  still fires on the prose; we lose finer-grained segment kinds
  (action vs dialogue vs ooc) on this turn, but the request returns
  cleanly. Updated the existing test that pinned the old
  RuntimeError contract.

- chat/llm/classify.py: when retries are exhausted, log a WARNING
  with the schema name, last error type, and a snippet of the last
  raw text the model returned. This surfaces flapping classifiers in
  the uvicorn log for diagnosis without taking down the request.

Suite: 471 passed in 11.7s.
2026-04-27 15:07:39 -04:00

91 lines
3.3 KiB
Python

from __future__ import annotations
import json
import asyncio
import logging
from typing import TypeVar
from pydantic import BaseModel, ValidationError
from .client import LLMClient, Message
# Module logger; classify() warns through it when every attempt fails.
_log = logging.getLogger(__name__)

# classify() returns an instance of whichever pydantic model class it is
# handed as ``schema``; this TypeVar ties the two together for type checkers.
T = TypeVar("T", bound=BaseModel)

# Lower-case phrases that, when they appear near the start of a reply that
# is not a JSON object, mark the reply as a refusal rather than an answer.
REFUSAL_PATTERNS = (
    "i can't",
    "i cannot",
    "i'm sorry, but",
    "as an ai",
)
def _strip_json_fences(text: str) -> str:
    """Strip ```json ... ``` markdown fences if the model wraps its JSON output.

    Handles the usual multi-line form: an opening line of three backticks,
    optionally followed by a language tag such as ``json``, then the payload,
    then a closing fence. It also handles a fence collapsed onto a single line,
    where the payload directly follows the opening backticks and optional
    ``json`` tag. A trailing closing fence is removed even when no opening
    fence was present, because a valid JSON document can never end in three
    backticks.

    Args:
        text: Raw model output.

    Returns:
        The payload with fences and surrounding whitespace removed. Text
        without fences is returned with surrounding whitespace trimmed only.
    """
    s = text.strip()
    if s.startswith("```"):
        if "\n" in s:
            # Drop the whole first fence line (``` or ```json or ```JSON ...).
            s = s.split("\n", 1)[1]
        else:
            # Single-line fence: drop the backticks and any ``json`` tag that
            # would otherwise be left glued to the front of the payload.
            s = s[3:]
            if s[:4].lower() == "json":
                s = s[4:]
    # Drop the trailing fence (also covers a model that forgot the opener).
    s = s.rstrip()
    if s.endswith("```"):
        s = s[:-3]
    return s.strip()
async def classify(
    client: LLMClient,
    *,
    model: str,
    system: str,
    user: str,
    schema: type[T],
    default: T | None = None,
    timeout_s: float = 10.0,
    max_tokens: int = 512,
    max_attempts: int = 3,
) -> T:
    """Ask *model* for a JSON object and validate it against *schema*.

    The schema's JSON Schema is appended to *system* so the model sees the
    exact field names it must use. Each attempt calls ``client.generate``
    with ``response_format={"type": "json_object"}``, strips markdown fences,
    rejects refusal-shaped replies, and validates the result with
    ``schema.model_validate_json``. After a failed attempt (timeout, refusal,
    unparseable JSON, or schema mismatch), the system prompt is tightened and
    the request is retried.

    Args:
        client: LLM client used to generate the completion.
        model: Model identifier passed through to ``client.generate``.
        system: System prompt; the schema block is appended to it.
        user: User message content.
        schema: Pydantic model class the response must validate against.
        default: Value returned when every attempt fails. If ``None``, a
            ``RuntimeError`` is raised instead.
        timeout_s: Per-attempt timeout, in seconds.
        max_tokens: Output-token cap applied to each attempt.
        max_attempts: Number of attempts before giving up (must be >= 1).

    Returns:
        A validated ``schema`` instance, or *default* once all attempts fail.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        RuntimeError: If every attempt fails and no *default* was given. The
            last attempt's error is chained as ``__cause__``.

    Exceptions from ``client.generate`` that are not among the caught types
    (``ValidationError``, ``ValueError``, ``json.JSONDecodeError``, or
    ``asyncio.TimeoutError``) propagate unchanged.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    schema_json = json.dumps(schema.model_json_schema(), indent=2)
    schema_block = (
        f"\n\nRespond with a single JSON object matching this exact schema. "
        f"Use these field names exactly; do not invent your own keys:\n```json\n{schema_json}\n```"
    )
    msgs = [
        Message(role="system", content=system + schema_block),
        Message(role="user", content=user),
    ]

    last_text: str | None = None
    last_error: BaseException | None = None
    for _ in range(max_attempts):
        try:
            # Cap output length so a misbehaving model (e.g. one that ignores
            # ``response_format=json_object`` and generates prose) can't burn
            # several seconds on tokens we'll never use. Classifier responses
            # are small JSON objects — 512 tokens is generous; usual
            # completions are 50-150.
            text = await asyncio.wait_for(
                client.generate(
                    msgs,
                    model=model,
                    response_format={"type": "json_object"},
                    max_tokens=max_tokens,
                ),
                timeout=timeout_s,
            )
            last_text = text
            cleaned = _strip_json_fences(text)
            # Only treat the reply as a refusal when a refusal phrase appears
            # early AND the reply is not a JSON object.
            head = cleaned.lower()[:80]
            if any(p in head for p in REFUSAL_PATTERNS) and not cleaned.lstrip().startswith("{"):
                raise ValueError("refusal-shaped response")
            return schema.model_validate_json(cleaned)
        except (ValidationError, ValueError, json.JSONDecodeError, asyncio.TimeoutError) as exc:
            last_error = exc
            # Tighten the instructions for the next attempt.
            msgs[0] = Message(
                role="system",
                content=system + schema_block + "\n\nRespond with valid JSON ONLY. No prose, no markdown fences.",
            )

    # Log when we're falling back so flapping classifiers are
    # diagnosable without taking down the request.
    snippet = (last_text or "")[:200].replace("\n", " ")
    _log.warning(
        "classify(%s) exhausted %d attempts; last_error=%s last_text=%r; "
        "falling back to %s",
        schema.__name__,
        max_attempts,
        type(last_error).__name__ if last_error is not None else "?",
        snippet,
        "default" if default is not None else "RuntimeError (no default)",
    )
    if default is None:
        # Chain the underlying failure so tracebacks show *why* classify gave up.
        raise RuntimeError(
            f"classify failed for schema {schema.__name__} with no default"
        ) from last_error
    return default