feat: thread-detection service (T55)
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
"""Thread-detection service (T55).
|
||||
|
||||
On scene close, classify the transcript into thread open/update/close
|
||||
candidates. Returns ThreadCandidate list; caller (T58 scene compression)
|
||||
emits one thread_opened/thread_updated/thread_closed event per candidate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from chat.llm.classify import classify
|
||||
from chat.llm.client import LLMClient
|
||||
|
||||
|
||||
class ThreadCandidate(BaseModel):
|
||||
action: str # "open" | "update" | "close"
|
||||
title: str = "" # required for "open"; ignored otherwise
|
||||
summary: str = ""
|
||||
existing_thread_id: str | None = None # required for "update" / "close"
|
||||
|
||||
|
||||
class ThreadDetectionResult(BaseModel):
|
||||
candidates: list[ThreadCandidate] = Field(default_factory=list)
|
||||
|
||||
|
||||
_SYSTEM = (
|
||||
"You analyze a closed scene's transcript to identify narrative "
|
||||
"threads (unresolved arcs, dangling questions, promises made, "
|
||||
"open obligations). Choose actions:\n"
|
||||
"- 'open': a NEW thread the scene introduced. Provide title (short "
|
||||
"noun phrase) + summary (one sentence).\n"
|
||||
"- 'update': an EXISTING open thread that the scene developed. "
|
||||
"Provide existing_thread_id + new summary.\n"
|
||||
"- 'close': an EXISTING open thread that the scene resolved. "
|
||||
"Provide existing_thread_id; summary may capture the resolution.\n"
|
||||
"Conservative bias: most scenes do NOT open new threads. Only "
|
||||
"produce candidates when the transcript clearly justifies them. "
|
||||
"Output strict JSON matching the schema."
|
||||
)
|
||||
|
||||
|
||||
async def detect_threads(
|
||||
client: LLMClient,
|
||||
*,
|
||||
classifier_model: str,
|
||||
scene_transcript: list[dict], # [{speaker, text}, ...]
|
||||
open_threads: list[dict], # [{thread_id, title, summary}, ...]
|
||||
timeout_s: float = 30.0,
|
||||
) -> ThreadDetectionResult:
|
||||
"""Classify scene close into thread open/update/close candidates."""
|
||||
if not scene_transcript:
|
||||
return ThreadDetectionResult()
|
||||
|
||||
transcript_lines = [
|
||||
f"{turn.get('speaker', 'unknown')}: {turn.get('text', '')}"
|
||||
for turn in scene_transcript
|
||||
]
|
||||
threads_lines = []
|
||||
if open_threads:
|
||||
threads_lines.append("Currently open threads:")
|
||||
for t in open_threads:
|
||||
threads_lines.append(
|
||||
f"- thread_id={t['thread_id']} "
|
||||
f"title={t.get('title', '')} "
|
||||
f"summary={t.get('summary', '')}"
|
||||
)
|
||||
else:
|
||||
threads_lines.append("No currently open threads.")
|
||||
|
||||
user = (
|
||||
"Scene transcript:\n"
|
||||
+ "\n".join(transcript_lines)
|
||||
+ "\n\n"
|
||||
+ "\n".join(threads_lines)
|
||||
)
|
||||
|
||||
return await classify(
|
||||
client,
|
||||
model=classifier_model,
|
||||
system=_SYSTEM,
|
||||
user=user,
|
||||
schema=ThreadDetectionResult,
|
||||
default=ThreadDetectionResult(),
|
||||
timeout_s=timeout_s,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ThreadCandidate", "ThreadDetectionResult", "detect_threads"]
|
||||
Reference in New Issue
Block a user