feat: turn input parser via classifier
This commit is contained in:
@@ -0,0 +1,94 @@
|
|||||||
|
"""Turn input parser.
|
||||||
|
|
||||||
|
Service-layer function that splits a user's authored turn into typed
|
||||||
|
segments — ``dialogue``, ``action``, or ``ooc`` (out-of-character).
|
||||||
|
|
||||||
|
Per Requirements §6.1 a turn is mixed prose with three conventions:
|
||||||
|
|
||||||
|
- ``*action*`` (single asterisks around prose) → action segment.
|
||||||
|
- Quoted text, or bare prose between the conventions → dialogue.
|
||||||
|
- ``((double parens))`` → OOC, the author talking to the system rather
|
||||||
|
than the bot. Downstream (T19) strips OOC from the prompt sent to the
|
||||||
|
bot but keeps it in the transcript display.
|
||||||
|
|
||||||
|
A regex-based splitter would be brittle on edge cases (unclosed asterisks,
|
||||||
|
nested quotes, mixed punctuation), so v1 delegates the segmentation to
|
||||||
|
the classifier. The configurable ``Settings.ooc_marker`` is *not* read
|
||||||
|
here: the classifier figures OOC out from ``((`` ``))`` regardless of
|
||||||
|
config-time choice; marker-based stripping is a downstream concern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from chat.llm.classify import classify
|
||||||
|
from chat.llm.client import LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class TurnSegment(BaseModel):
    """A single typed slice of a user's turn.

    ``kind`` is deliberately a plain ``str`` rather than a ``Literal``:
    if the classifier emits a label outside the expected set, parsing
    still succeeds and callers that depend on specific values are
    expected to check them defensively.
    """

    # Segment label; expected values are "dialogue" | "action" | "ooc".
    kind: str
    # The segment's text content.
    text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ParsedTurn(BaseModel):
    """Result of parsing a turn: its typed segments, in order."""

    # May be empty (``parse_turn`` returns no segments for blank prose).
    segments: list[TurnSegment]
|
||||||
|
|
||||||
|
|
||||||
|
# System prompt handed to the classifier by ``parse_turn``. It describes
# the three authoring conventions (``*action*``, quoted/bare dialogue,
# ``((ooc))``) and asks for a JSON object whose shape matches
# ``ParsedTurn``: an ordered ``segments`` list of ``{"kind", "text"}``
# entries with the marker characters stripped from ``text``.
# NOTE: this text is sent to the model verbatim; edit wording with care.
_SYSTEM_PROMPT = (
    "You are splitting a roleplay turn into typed segments. The input "
    "is mixed prose with three conventions:\n"
    "- *text in single asterisks* is an ACTION segment.\n"
    "- \"quoted text\" or bare prose between conventions is a DIALOGUE segment.\n"
    "- ((text in double parens)) is an OOC (out-of-character) segment — "
    "the author talking to the system, not the in-fiction bot.\n\n"
    "Output a JSON object with shape "
    '{"segments": [{"kind": "...", "text": "..."}, ...]} '
    "where each ``kind`` is exactly one of: dialogue, action, ooc. "
    "Preserve the original substring text as ``text``: do not rewrite, "
    "translate, or normalize punctuation — strip only the marker "
    "characters (asterisks, surrounding quotes, double parens) so "
    "``text`` is the inner content. Emit segments in the order they "
    "appear in the input."
)
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_turn(
    client: LLMClient,
    *,
    model: str,
    prose: str,
    timeout_s: float = 10.0,
) -> ParsedTurn:
    """Split an authored turn into ordered, typed segments.

    Segmentation is delegated to :func:`chat.llm.classify.classify`,
    which is called with the module's system prompt, the prose wrapped
    as ``INPUT:\\n<prose>``, and ``ParsedTurn`` as the response schema.

    Blank input (empty or whitespace-only ``prose``) never reaches the
    classifier: the answer is unambiguously "no segments", and the
    classifier would error on empty input anyway, so an empty
    ``ParsedTurn`` is returned directly.

    No fallback value is supplied to the classifier. When it fails
    twice a ``RuntimeError`` propagates to the caller (T19's turn
    flow), which is responsible for surfacing it to the user.
    """
    has_content = bool(prose.strip())
    if has_content:
        return await classify(
            client,
            model=model,
            system=_SYSTEM_PROMPT,
            user=f"INPUT:\n{prose}",
            schema=ParsedTurn,
            timeout_s=timeout_s,
        )

    # Nothing to classify: skip the LLM round-trip entirely.
    return ParsedTurn(segments=[])
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.llm.mock import MockLLMClient
|
||||||
|
from chat.services.turn_parse import (
|
||||||
|
ParsedTurn,
|
||||||
|
TurnSegment,
|
||||||
|
parse_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_parse_turn_three_segment_happy_path():
    """A turn mixing all three conventions yields three ordered segments."""
    # (kind, text) pairs the canned classifier reply carries, in input order.
    expected = [
        ("action", "walks over"),
        ("dialogue", "Hey."),
        ("ooc", "player note"),
    ]
    payload = {
        "segments": [{"kind": kind, "text": text} for kind, text in expected]
    }
    llm = MockLLMClient(canned=[json.dumps(payload)])

    parsed = await parse_turn(
        llm,
        model="m",
        prose='*walks over* "Hey." ((player note))',
    )

    assert isinstance(parsed, ParsedTurn)
    # One comparison covers count, order, kinds and texts together.
    assert [(seg.kind, seg.text) for seg in parsed.segments] == expected
    for seg in parsed.segments:
        assert isinstance(seg, TurnSegment)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_parse_turn_pure_dialogue_single_segment():
    """A turn that is only quoted speech comes back as one dialogue segment."""
    payload = {"segments": [{"kind": "dialogue", "text": "Hello there"}]}
    llm = MockLLMClient(canned=[json.dumps(payload)])

    parsed = await parse_turn(llm, model="m", prose='"Hello there"')

    assert isinstance(parsed, ParsedTurn)
    assert len(parsed.segments) == 1
    only = parsed.segments[0]
    assert only.kind == "dialogue"
    assert only.text == "Hello there"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_parse_turn_empty_prose_short_circuits_without_classifier_call():
    """Blank prose returns an empty result without touching the classifier."""
    # The mock holds no canned responses, so any classify() call would
    # raise IndexError on the empty list. Passing therefore proves the
    # short-circuit happened.
    llm = MockLLMClient(canned=[])

    # Both the empty string and whitespace-only prose must short-circuit.
    for blank in ("", " \n\t "):
        parsed = await parse_turn(llm, model="m", prose=blank)
        assert isinstance(parsed, ParsedTurn)
        assert parsed.segments == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_parse_turn_raises_when_classifier_fails_twice():
    """Two unparseable classifier replies in a row surface as RuntimeError."""
    # Neither canned reply is valid JSON for the ParsedTurn schema.
    bad_replies = ["nope", "still nope"]
    llm = MockLLMClient(canned=bad_replies)

    with pytest.raises(RuntimeError):
        await parse_turn(llm, model="m", prose='*shrugs* "whatever"')
|
||||||
Reference in New Issue
Block a user