diff --git a/chat/services/thread_detection.py b/chat/services/thread_detection.py
new file mode 100644
index 0000000..ca1c0d8
--- /dev/null
+++ b/chat/services/thread_detection.py
@@ -0,0 +1,109 @@
+"""Thread-detection service (T55).
+
+On scene close, classify the transcript into thread open/update/close
+candidates. Returns ThreadCandidate list; caller (T58 scene compression)
+emits one thread_opened/thread_updated/thread_closed event per candidate.
+"""
+
+from __future__ import annotations
+
+from pydantic import BaseModel, Field
+
+from chat.llm.classify import classify
+from chat.llm.client import LLMClient
+
+
+# One proposed thread change. The per-action requirements noted on the
+# fields are documentation only; nothing in this model enforces them.
+class ThreadCandidate(BaseModel):
+    action: str  # "open" | "update" | "close"
+    title: str = ""  # required for "open"; ignored otherwise
+    summary: str = ""
+    existing_thread_id: str | None = None  # required for "update" / "close"
+
+
+# Schema passed to ``classify``; ``candidates`` defaults to an empty list.
+class ThreadDetectionResult(BaseModel):
+    candidates: list[ThreadCandidate] = Field(default_factory=list)
+
+
+_SYSTEM = (
+    "You analyze a closed scene's transcript to identify narrative "
+    "threads (unresolved arcs, dangling questions, promises made, "
+    "open obligations). Choose actions:\n"
+    "- 'open': a NEW thread the scene introduced. Provide title (short "
+    "noun phrase) + summary (one sentence).\n"
+    "- 'update': an EXISTING open thread that the scene developed. "
+    "Provide existing_thread_id + new summary.\n"
+    "- 'close': an EXISTING open thread that the scene resolved. "
+    "Provide existing_thread_id; summary may capture the resolution.\n"
+    "Conservative bias: most scenes do NOT open new threads. Only "
+    "produce candidates when the transcript clearly justifies them. "
+    "Output strict JSON matching the schema."
+)
+
+
+async def detect_threads(
+    client: LLMClient,
+    *,
+    classifier_model: str,
+    scene_transcript: list[dict],  # [{speaker, text}, ...]
+    open_threads: list[dict],  # [{thread_id, title, summary}, ...]
+    timeout_s: float = 30.0,
+) -> ThreadDetectionResult:
+    """Classify scene close into thread open/update/close candidates.
+
+    Args:
+        client: LLM client forwarded to ``classify``.
+        classifier_model: Model name forwarded to ``classify``.
+        scene_transcript: Turns of the closed scene. A missing ``speaker``
+            renders as ``unknown`` and a missing ``text`` as empty.
+        open_threads: Threads currently open. ``thread_id`` is mandatory;
+            ``title`` and ``summary`` default to empty strings.
+        timeout_s: Timeout forwarded to ``classify``.
+
+    Returns:
+        The result of ``classify`` for the ``ThreadDetectionResult``
+        schema, or an empty result when ``scene_transcript`` is empty.
+
+    Raises:
+        KeyError: If an entry in ``open_threads`` lacks ``thread_id``.
+    """
+    if not scene_transcript:
+        # Nothing to classify, so skip the classifier call entirely.
+        return ThreadDetectionResult()
+
+    transcript_lines = [
+        f"{turn.get('speaker', 'unknown')}: {turn.get('text', '')}"
+        for turn in scene_transcript
+    ]
+    if open_threads:
+        threads_lines = ["Currently open threads:"]
+        threads_lines.extend(
+            f"- thread_id={t['thread_id']} "
+            f"title={t.get('title', '')} "
+            f"summary={t.get('summary', '')}"
+            for t in open_threads
+        )
+    else:
+        threads_lines = ["No currently open threads."]
+
+    user = (
+        "Scene transcript:\n"
+        + "\n".join(transcript_lines)
+        + "\n\n"
+        + "\n".join(threads_lines)
+    )
+
+    return await classify(
+        client,
+        model=classifier_model,
+        system=_SYSTEM,
+        user=user,
+        schema=ThreadDetectionResult,
+        default=ThreadDetectionResult(),
+        timeout_s=timeout_s,
+    )
+
+
+__all__ = ["ThreadCandidate", "ThreadDetectionResult", "detect_threads"]
diff --git a/tests/test_thread_detection.py b/tests/test_thread_detection.py
new file mode 100644
index 0000000..0249407
--- /dev/null
+++ b/tests/test_thread_detection.py
@@ -0,0 +1,183 @@
+"""Tests for the thread-detection service (T55).
+
+On scene close, the transcript is classified to detect open threads
+(unresolved arcs, dangling questions, promises made). The service can
+also signal **update** to an existing thread when the scene developed
+it, or **close** when the scene resolved it.
+
+These tests cover:
+
+* A new thread the scene introduced — action="open" with a fresh title.
+* An update to an existing thread — action="update" with
+  ``existing_thread_id`` referencing the prior thread.
+* A resolved existing thread — action="close" with
+  ``existing_thread_id`` referencing the prior thread.
+* Classifier failure — three bad responses degrade to an empty
+  candidates list (graceful degradation, §3.3).
+* Empty transcript short-circuits before any classifier call.
+"""
+
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from chat.llm.mock import MockLLMClient
+from chat.services.thread_detection import (
+    ThreadCandidate,
+    ThreadDetectionResult,
+    detect_threads,
+)
+
+
+@pytest.mark.asyncio
+async def test_detects_new_thread_open():
+    """A scene that introduces a new arc yields a single "open" candidate."""
+    canned = json.dumps(
+        {
+            "candidates": [
+                {
+                    "action": "open",
+                    "title": "Maya's job hunt",
+                    "summary": "Maya is looking for a new job",
+                    "existing_thread_id": None,
+                }
+            ]
+        }
+    )
+    mock = MockLLMClient(canned=[canned])
+    result = await detect_threads(
+        mock,
+        classifier_model="x",
+        scene_transcript=[
+            {"speaker": "Maya", "text": "I need to find a new job soon."},
+            {"speaker": "Sam", "text": "What kind of role are you looking for?"},
+        ],
+        open_threads=[],
+    )
+    assert isinstance(result, ThreadDetectionResult)
+    assert len(result.candidates) == 1
+    cand = result.candidates[0]
+    assert isinstance(cand, ThreadCandidate)
+    assert cand.action == "open"
+    assert cand.title == "Maya's job hunt"
+    assert cand.summary == "Maya is looking for a new job"
+    assert cand.existing_thread_id is None
+
+
+@pytest.mark.asyncio
+async def test_detects_update_to_existing_thread():
+    """A scene developing a known thread yields an "update" candidate."""
+    canned = json.dumps(
+        {
+            "candidates": [
+                {
+                    "action": "update",
+                    "title": "",
+                    "summary": "Maya interviewed at Acme today",
+                    "existing_thread_id": "thr_jobhunt",
+                }
+            ]
+        }
+    )
+    mock = MockLLMClient(canned=[canned])
+    result = await detect_threads(
+        mock,
+        classifier_model="x",
+        scene_transcript=[
+            {"speaker": "Maya", "text": "I had the Acme interview today."},
+            {"speaker": "Sam", "text": "How did it go?"},
+        ],
+        open_threads=[
+            {
+                "thread_id": "thr_jobhunt",
+                "title": "Maya's job hunt",
+                "summary": "Maya is looking for a new job",
+            }
+        ],
+    )
+    assert len(result.candidates) == 1
+    cand = result.candidates[0]
+    assert cand.action == "update"
+    assert cand.existing_thread_id == "thr_jobhunt"
+    assert cand.summary == "Maya interviewed at Acme today"
+
+
+@pytest.mark.asyncio
+async def test_detects_close_of_existing_thread():
+    """A scene resolving a known thread yields a "close" candidate."""
+    canned = json.dumps(
+        {
+            "candidates": [
+                {
+                    "action": "close",
+                    "title": "",
+                    "summary": "Maya accepted the Acme offer",
+                    "existing_thread_id": "thr_jobhunt",
+                }
+            ]
+        }
+    )
+    mock = MockLLMClient(canned=[canned])
+    result = await detect_threads(
+        mock,
+        classifier_model="x",
+        scene_transcript=[
+            {"speaker": "Maya", "text": "I accepted the Acme offer!"},
+            {"speaker": "Sam", "text": "Congratulations!"},
+        ],
+        open_threads=[
+            {
+                "thread_id": "thr_jobhunt",
+                "title": "Maya's job hunt",
+                "summary": "Maya is looking for a new job",
+            }
+        ],
+    )
+    assert len(result.candidates) == 1
+    cand = result.candidates[0]
+    assert cand.action == "close"
+    assert cand.existing_thread_id == "thr_jobhunt"
+    assert cand.summary == "Maya accepted the Acme offer"
+
+
+@pytest.mark.asyncio
+async def test_classifier_failure_returns_empty():
+    """Three malformed classifier responses → empty candidates list."""
+    mock = MockLLMClient(canned=["not json", "still not json", "{bad"])
+    result = await detect_threads(
+        mock,
+        classifier_model="x",
+        scene_transcript=[
+            {"speaker": "Maya", "text": "Anything could happen here."},
+        ],
+        open_threads=[],
+    )
+    assert result.candidates == []
+
+
+@pytest.mark.asyncio
+async def test_empty_transcript_short_circuits(monkeypatch):
+    """Empty transcript short-circuits — classifier must not be called."""
+    classify_calls = []
+
+    async def _spy_classify(*args, **kwargs):
+        classify_calls.append((args, kwargs))
+        return ThreadDetectionResult()
+
+    # ``candidates == []`` alone cannot prove the short-circuit: a failed
+    # classifier call also degrades to an empty result. Spy on the name
+    # detect_threads resolves at call time instead.
+    monkeypatch.setattr(
+        "chat.services.thread_detection.classify", _spy_classify
+    )
+    mock = MockLLMClient(canned=[])
+    result = await detect_threads(
+        mock,
+        classifier_model="x",
+        scene_transcript=[],
+        open_threads=[],
+    )
+    assert result.candidates == []
+    assert classify_calls == []