diff --git a/chat/services/regenerate.py b/chat/services/regenerate.py index cb6b23c..6565aea 100644 --- a/chat/services/regenerate.py +++ b/chat/services/regenerate.py @@ -80,6 +80,10 @@ from chat.services.interjection import detect_interjection from chat.services.memory_write import record_turn_memory_for_present from chat.services.multi_state_update import compute_state_updates_for_present from chat.services.prompt import assemble_narrative_prompt +from chat.services.turn_common import ( + gather_prior_edges, + read_recent_dialogue, +) from chat.state.edges import get_edge from chat.state.entities import get_bot, get_you from chat.state.events import list_active_events @@ -209,33 +213,30 @@ async def regenerate_assistant_turn( # assistant_turn explicitly (we haven't superseded it yet — that # update lands at the end so the new event_id is known) and use the # standard ``superseded_by IS NULL AND hidden = 0`` filter so any - # prior regenerates also drop out. + # prior regenerates also drop out. T83.2: shared helper handles the + # SQL + filtering; we post-process to map speaker ids to display + # names for the prompt. you_entity = get_you(conn) or {"name": "you", "persona": ""} you_name = you_entity.get("name", "you") - cur = conn.execute( - "SELECT id, kind, payload_json FROM event_log " - "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " - " AND id != ? 
" - " AND superseded_by IS NULL AND hidden = 0 " - "ORDER BY id DESC LIMIT 20", - (original_assistant_event_id,), + raw_recent = read_recent_dialogue( + conn, + chat_id, + limit=20, + exclude_event_id=original_assistant_event_id, ) - rows = list(reversed(cur.fetchall())) recent: list[dict] = [] - for _eid, kind, payload_json in rows: - p = json.loads(payload_json) - if p.get("chat_id") != chat_id: + for entry in raw_recent: + spk = entry.get("speaker", "bot") + if spk == "you": + recent.append({"speaker": you_name, "text": entry.get("text", "")}) continue - if kind in ("user_turn", "user_turn_edit"): - recent.append({"speaker": you_name, "text": p.get("prose", "")}) - else: - spk = p.get("speaker_id", "bot") + if spk == host_bot_id: spk_name = host_bot.get("name", "bot") - if spk == host_bot_id: - spk_name = host_bot.get("name", "bot") - elif guest_bot is not None and spk == guest_bot.get("id"): - spk_name = guest_bot.get("name", "bot") - recent.append({"speaker": spk_name, "text": p.get("text", "")}) + elif guest_bot is not None and spk == guest_bot.get("id"): + spk_name = guest_bot.get("name", "bot") + else: + spk_name = host_bot.get("name", "bot") + recent.append({"speaker": spk_name, "text": entry.get("text", "")}) # 4. Assemble the narrative prompt. ``recent`` already excludes the # current user prose, which we pass through ``user_turn_prose``. @@ -373,17 +374,8 @@ async def regenerate_assistant_turn( present_names[guest_bot_id] = guest_bot.get("name", "bot") personas[guest_bot_id] = guest_bot.get("persona") or "" - prior_edges: dict[tuple[str, str], dict] = {} - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = get_edge(conn, src, tgt) or { - "affinity": 50, - "trust": 50, - "summary": "", - } - prior_edges[(src, tgt)] = edge + # T83.2: shared helper builds the directed-pair edge dict. 
+ prior_edges = gather_prior_edges(conn, present_ids) state_updates = await compute_state_updates_for_present( client, @@ -472,34 +464,27 @@ async def regenerate_assistant_turn( ) if decision.should_interject: - # Re-read recent so the just-appended primary is in the prompt. - interject_cur = conn.execute( - "SELECT id, kind, payload_json FROM event_log " - "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " - " AND superseded_by IS NULL AND hidden = 0 " - "ORDER BY id DESC LIMIT 20", - ) - interject_rows = list(reversed(interject_cur.fetchall())) + # Re-read recent so the just-appended primary is in the + # prompt. T83.2: shared helper + the same id->name mapping + # as the primary read above. + raw_interject = read_recent_dialogue(conn, chat_id, limit=20) interject_recent: list[dict] = [] - for _eid, kind, payload_json in interject_rows: - p = json.loads(payload_json) - if p.get("chat_id") != chat_id: + for entry in raw_interject: + spk = entry.get("speaker", "bot") + if spk == "you": + interject_recent.append( + {"speaker": you_name, "text": entry.get("text", "")} + ) continue - if kind in ("user_turn", "user_turn_edit"): - interject_recent.append( - {"speaker": you_name, "text": p.get("prose", "")} - ) + if spk == host_bot_id: + spk_name = host_bot.get("name", "bot") + elif spk == guest_bot.get("id"): + spk_name = guest_bot.get("name", "bot") else: - spk = p.get("speaker_id", "bot") - if spk == host_bot_id: - spk_name = host_bot.get("name", "bot") - elif spk == guest_bot.get("id"): - spk_name = guest_bot.get("name", "bot") - else: - spk_name = "bot" - interject_recent.append( - {"speaker": spk_name, "text": p.get("text", "")} - ) + spk_name = "bot" + interject_recent.append( + {"speaker": spk_name, "text": entry.get("text", "")} + ) if interject_recent and interject_recent[-1].get("speaker") == you_name: interject_recent = interject_recent[:-1] @@ -603,17 +588,8 @@ async def regenerate_assistant_turn( "text": interject_text, } ] - 
prior_edges_post: dict[tuple[str, str], dict] = {} - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = get_edge(conn, src, tgt) or { - "affinity": 50, - "trust": 50, - "summary": "", - } - prior_edges_post[(src, tgt)] = edge + # T83.2: shared helper handles the directed-pair edge dict. + prior_edges_post = gather_prior_edges(conn, present_ids) state_updates_post = await compute_state_updates_for_present( client, diff --git a/chat/services/turn_common.py b/chat/services/turn_common.py new file mode 100644 index 0000000..e4c5444 --- /dev/null +++ b/chat/services/turn_common.py @@ -0,0 +1,118 @@ +"""Shared helpers for turn flows (T83.2). + +Both ``chat.web.turns.post_turn`` and +``chat.services.regenerate.regenerate_assistant_turn`` need to: + +1. Pull a chronological tail of user-side and assistant_turn events for + prompt assembly + state-update inputs. +2. Build a directed-edge dict over a fixed set of "present" entity ids + for the multi-pair state-update pass (with the schema 50/50 default + filled in for missing rows). + +Before T83.2 each call site had its own copy of these blocks. The two +copies drifted on details (T73.1 added ``user_turn_edit`` handling to +turns.py; regenerate.py had a slightly different recent-window query). +This module is the single source so a future change to either lands in +both flows by construction. + +Note on overlap with ``chat.services.scene_summarize._read_recent_dialogue``: +that helper has a ``since_event_id`` clamp (T80.2 thread-detection +scope) and intentionally does NOT include ``user_turn_edit`` events — +its callers want the *original* prose, not edits. Deduplicating it +into here would either (a) require a new flag on the shared helper for +``user_turn_edit`` inclusion, or (b) silently change scene_summarize's +read shape. Both feel more invasive than the duplication is bad, so +that helper is left alone for now. 
+""" + +from __future__ import annotations + +import json +from sqlite3 import Connection + +from chat.state.edges import get_edge + + +def read_recent_dialogue( + conn: Connection, + chat_id: str, + *, + limit: int = 50, + exclude_event_id: int | None = None, +) -> list[dict]: + """Pull the last ``limit`` user-side / assistant_turn events for + ``chat_id`` as ``[{"speaker": <"you" | speaker_id>, "text": <str>}]``, + chronologically ordered (oldest first). + + Filters: ``superseded_by IS NULL AND hidden = 0`` — regenerated + rows drop out so the timeline reflects the current state. Includes + ``user_turn``, ``user_turn_edit`` (T29 edited prose substitutes for + the original — the original is marked superseded above), and + ``assistant_turn`` rows. + + ``exclude_event_id`` is an optional event_log id to skip — used by + regenerate to drop the original assistant_turn from its prompt + context window before that row has been marked superseded (the + supersede UPDATE lands at the end so the new event_id is known). + """ + if exclude_event_id is None: + cur = conn.execute( + "SELECT id, kind, payload_json FROM event_log " + "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " + " AND superseded_by IS NULL AND hidden = 0 " + "ORDER BY id DESC LIMIT ?", + (limit,), + ) + else: + cur = conn.execute( + "SELECT id, kind, payload_json FROM event_log " + "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " + " AND id != ? 
" + " AND superseded_by IS NULL AND hidden = 0 " + "ORDER BY id DESC LIMIT ?", + (exclude_event_id, limit), + ) + rows = list(reversed(cur.fetchall())) + out: list[dict] = [] + for _row_id, kind, payload_json in rows: + p = json.loads(payload_json) + if p.get("chat_id") != chat_id: + continue + if kind in ("user_turn", "user_turn_edit"): + out.append({"speaker": "you", "text": p.get("prose", "")}) + else: + out.append( + { + "speaker": p.get("speaker_id", "bot"), + "text": p.get("text", ""), + } + ) + return out + + +def gather_prior_edges( + conn: Connection, present_ids: list[str] +) -> dict[tuple[str, str], dict]: + """Build ``{(src, tgt): {affinity, trust, summary}}`` for every + directed pair where both ``src`` and ``tgt`` are in ``present_ids`` + and ``src != tgt``. + + Missing rows fall back to the schema default 50/50 baseline (mirrors + the Phase 1 single-pair flow). Used by post_turn and regenerate to + seed the multi-pair state-update classifier. + """ + prior_edges: dict[tuple[str, str], dict] = {} + for src in present_ids: + for tgt in present_ids: + if src == tgt: + continue + edge = get_edge(conn, src, tgt) or { + "affinity": 50, + "trust": 50, + "summary": "", + } + prior_edges[(src, tgt)] = edge + return prior_edges + + +__all__ = ["read_recent_dialogue", "gather_prior_edges"] diff --git a/chat/web/turns.py b/chat/web/turns.py index 0368d8b..3505c42 100644 --- a/chat/web/turns.py +++ b/chat/web/turns.py @@ -71,6 +71,10 @@ from chat.services.prompt import ( from chat.services.rewind import compute_rewind_preview, execute_rewind from chat.services.scene_close import detect_scene_close from chat.services.scene_summarize import apply_scene_close_summary +from chat.services.turn_common import ( + gather_prior_edges, + read_recent_dialogue, +) from chat.services.turn_parse import ParsedTurn, parse_turn from chat.state.edges import get_edge from chat.state.entities import get_bot, get_you @@ -113,38 +117,13 @@ def _strip_ooc_for_prompt(parsed: ParsedTurn) 
-> str: def _read_recent_dialogue(conn, chat_id: str, limit: int = 200) -> list[dict]: """Return user-side and assistant_turn events for ``chat_id``. - Includes ``user_turn``, ``user_turn_edit`` (T29 edited prose), and - ``assistant_turn``. Ordered oldest-first; superseded/hidden rows are - skipped so regenerated turns (T29) drop out of the rendered timeline. - Each entry is shaped ``{"speaker": , "text": }`` - for the prompt assembler and the chat-detail template. + T83.2: thin delegate over + :func:`chat.services.turn_common.read_recent_dialogue` so post_turn + and regenerate share one implementation. The wrapper survives so + the chat-detail template and other callers in this module don't all + have to update at once. """ - cur = conn.execute( - "SELECT id, kind, payload_json FROM event_log " - "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " - " AND superseded_by IS NULL AND hidden = 0 " - "ORDER BY id DESC LIMIT ?", - (limit,), - ) - rows = cur.fetchall() - rows.reverse() # back to chronological order - out: list[dict] = [] - for _row_id, kind, payload_json in rows: - p = json.loads(payload_json) - if p.get("chat_id") != chat_id: - continue - if kind in ("user_turn", "user_turn_edit"): - # Edited prose substitutes for the original user_turn (the - # original is marked superseded_by and filtered above). 
- out.append({"speaker": "you", "text": p.get("prose", "")}) - else: - out.append( - { - "speaker": p.get("speaker_id", "bot"), - "text": p.get("text", ""), - } - ) - return out + return read_recent_dialogue(conn, chat_id, limit=limit) def _detect_addressee_id( @@ -211,17 +190,8 @@ def _gather_state_update_inputs( present_names[guest_bot["id"]] = guest_bot["name"] personas[guest_bot["id"]] = guest_bot.get("persona") or "" - prior_edges: dict[tuple[str, str], dict] = {} - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = get_edge(conn, src, tgt) or { - "affinity": 50, - "trust": 50, - "summary": "", - } - prior_edges[(src, tgt)] = edge + # T83.2: directed-edge gather is shared with regenerate.py. + prior_edges = gather_prior_edges(conn, present_ids) return present_ids, present_names, personas, prior_edges diff --git a/tests/test_turn_common.py b/tests/test_turn_common.py new file mode 100644 index 0000000..f4b4f9b --- /dev/null +++ b/tests/test_turn_common.py @@ -0,0 +1,215 @@ +"""Shared turn helpers (T83.2). + +``chat.services.turn_common`` extracts two snippets that were duplicated +between ``chat.web.turns`` and ``chat.services.regenerate``: the recent +user-side / assistant_turn read, and the directed-pair edge gather for +the multi-pair state-update pass. These tests pin the helpers' behavior +independently of either call site. +""" + +from __future__ import annotations + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +from chat.services.turn_common import gather_prior_edges, read_recent_dialogue + + +def _seed_basic_chat(db_path): + """Seed bot + chat + a couple of edges + one round of user/assistant + turns. Returns ``(user_turn_id, assistant_turn_id)``. 
+ """ + apply_migrations(db_path) + with open_db(db_path) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "thoughtful", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + append_event( + conn, + kind="edge_update", + payload={ + "source_id": "bot_a", + "target_id": "you", + "chat_id": "chat_a", + "affinity_delta": 7, + "trust_delta": 3, + }, + ) + append_event( + conn, + kind="edge_update", + payload={ + "source_id": "you", + "target_id": "bot_a", + "chat_id": "chat_a", + "affinity_delta": 2, + "trust_delta": 1, + }, + ) + ut_id = append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "chat_a", + "prose": "hello", + "segments": [], + }, + ) + at_id = append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_a", + "speaker_id": "bot_a", + "text": "Original.", + "truncated": False, + "user_turn_id": ut_id, + }, + ) + project(conn) + return ut_id, at_id + + +def test_read_recent_dialogue_returns_chronological_pairs(tmp_path): + """``read_recent_dialogue`` returns oldest-first ``{speaker, text}`` + entries scoped to the requested chat. Speaker is "you" for user-side + rows and the assistant_turn's ``speaker_id`` for bot rows. + """ + db = tmp_path / "test.db" + _seed_basic_chat(db) + + with open_db(db) as conn: + out = read_recent_dialogue(conn, "chat_a", limit=10) + + assert out == [ + {"speaker": "you", "text": "hello"}, + {"speaker": "bot_a", "text": "Original."}, + ] + + +def test_read_recent_dialogue_filters_superseded_and_other_chats(tmp_path): + """Superseded rows drop out (regenerate-aware). Rows scoped to a + different chat are also filtered. 
``exclude_event_id`` excludes a + specific row even when it isn't superseded yet (regenerate uses this + to drop the original assistant_turn before the supersede UPDATE + lands). + """ + db = tmp_path / "test.db" + ut_id, at_id = _seed_basic_chat(db) + + with open_db(db) as conn: + # Append a second user/assistant pair. + ut_id2 = append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "chat_a", + "prose": "how are you", + "segments": [], + }, + ) + at_id2 = append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_a", + "speaker_id": "bot_a", + "text": "Second.", + "truncated": False, + "user_turn_id": ut_id2, + }, + ) + # And a row scoped to a different chat — must NOT appear. + append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "other_chat", + "prose": "should be filtered", + "segments": [], + }, + ) + # Mark the first assistant_turn as superseded — must drop out. + conn.execute( + "UPDATE event_log SET superseded_by = ? WHERE id = ?", + (at_id2, at_id), + ) + + out = read_recent_dialogue(conn, "chat_a", limit=10) + # First (superseded) assistant turn dropped; "other_chat" rows + # filtered; first user_turn still present. + speakers = [(e["speaker"], e["text"]) for e in out] + assert speakers == [ + ("you", "hello"), + ("you", "how are you"), + ("bot_a", "Second."), + ] + + # exclude_event_id drops at_id2 even though it's not superseded. + out2 = read_recent_dialogue( + conn, "chat_a", limit=10, exclude_event_id=at_id2 + ) + speakers2 = [(e["speaker"], e["text"]) for e in out2] + assert ("bot_a", "Second.") not in speakers2 + assert ("you", "how are you") in speakers2 + + # Ensure ut_id is still part of the dataset (sanity for the seed). + assert ut_id is not None + + +def test_gather_prior_edges_fills_missing_with_default(tmp_path): + """``gather_prior_edges`` returns one entry per directed pair across + ``present_ids``. 
Missing rows fall back to the schema default + 50/50 baseline; existing rows carry their stored values. + """ + db = tmp_path / "test.db" + _seed_basic_chat(db) + + with open_db(db) as conn: + out = gather_prior_edges(conn, ["bot_a", "you"]) + + # 2 entities -> 2 directed pairs (a->b and b->a, no self-pairs). + assert set(out.keys()) == {("bot_a", "you"), ("you", "bot_a")} + bot_to_you = out[("bot_a", "you")] + you_to_bot = out[("you", "bot_a")] + # Both edges seeded with deltas — they must reflect the projected + # affinity/trust (not the default 50/50). + assert bot_to_you["affinity"] == 57 # 50 + 7 + assert bot_to_you["trust"] == 53 # 50 + 3 + assert you_to_bot["affinity"] == 52 + assert you_to_bot["trust"] == 51 + + # A pair with no row yet falls back to 50/50. + with open_db(db) as conn: + out_with_missing = gather_prior_edges( + conn, ["bot_a", "you", "ghost_bot"] + ) + # 3 entities -> 6 directed pairs. + assert len(out_with_missing) == 6 + fallback = out_with_missing[("bot_a", "ghost_bot")] + assert fallback["affinity"] == 50 + assert fallback["trust"] == 50 + assert fallback["summary"] == ""