diff --git a/chat/services/regenerate.py b/chat/services/regenerate.py index 1317903..5e39b5c 100644 --- a/chat/services/regenerate.py +++ b/chat/services/regenerate.py @@ -68,7 +68,9 @@ Phase 2.5 changes: from __future__ import annotations +import asyncio import json +import logging from sqlite3 import Connection from chat.config import Settings @@ -79,6 +81,10 @@ from chat.services.interjection import detect_interjection from chat.services.memory_write import record_turn_memory_for_present from chat.services.multi_state_update import compute_state_updates_for_present from chat.services.prompt import assemble_narrative_prompt +from chat.services.turn_common import ( + gather_prior_edges, + read_recent_dialogue, +) from chat.state.edges import get_edge from chat.state.entities import get_bot, get_you from chat.state.events import list_active_events @@ -86,6 +92,8 @@ from chat.state.world import active_scene, get_chat from chat.web.pubsub import publish from chat.web.render import render_turn_html +_log = logging.getLogger(__name__) + async def regenerate_assistant_turn( conn: Connection, @@ -104,6 +112,19 @@ async def regenerate_assistant_turn( Raises :class:`ValueError` when the chat or the assistant_turn event cannot be found — the FastAPI route translates this to 404. + + .. note:: + **Lifecycle-rollback limitation (T83.4, Phase 4 follow-up).** + When the superseded turn already produced lifecycle transitions + (``event_started`` / ``event_completed`` / ``event_cancelled``), + this function does NOT roll those rows back before re-running + ``detect_event_transitions`` against the regenerated text. A + regenerate-after-completion can therefore double-emit promotion + artifacts if the new text re-completes the same event. Phase 3.5 + only documents the gap and emits a WARNING log naming the + affected event_log ids; the actual undo pass is invasive + (re-projection / inverse-handler dispatch) and is deferred to + Phase 4. 
See the ``# T83.4`` block below for the warning emit. """ chat = get_chat(conn, chat_id) if chat is None: @@ -136,6 +157,40 @@ async def regenerate_assistant_turn( original_assistant_payload = json.loads(row[0]) original_user_turn_id = original_assistant_payload.get("user_turn_id") + # T83.4: scan for downstream lifecycle transitions emitted by the + # superseded turn — they're not being rolled back (see method + # docstring). Heuristic: any ``event_started`` / ``event_completed`` + # / ``event_cancelled`` event_log row with id strictly greater than + # the original assistant_turn's id was emitted as part of (or after) + # that turn's processing. Lifecycle events don't carry ``chat_id`` + # in their payload (their payload references an ``event_id`` FK to + # the ``events`` table, which holds chat_id), so we join through + # ``events`` to scope to this chat. + # + # A WARNING log surfaces the affected event ids so operators can + # spot double-emit cases until the Phase 4 rollback pass lands. + unrolled_lifecycle = conn.execute( + "SELECT el.id, el.kind FROM event_log AS el " + "JOIN events AS ev " + " ON ev.event_id = json_extract(el.payload_json, '$.event_id') " + "WHERE el.kind IN (" + " 'event_started', 'event_completed', 'event_cancelled'" + " ) " + " AND ev.chat_id = ? " + " AND el.id > ? " + "ORDER BY el.id ASC", + (chat_id, original_assistant_event_id), + ).fetchall() + if unrolled_lifecycle: + _log.warning( + "regenerate_assistant_turn: %d lifecycle transition(s) from " + "superseded turn %s are NOT being rolled back (Phase 4 " + "follow-up). Affected event ids: %s", + len(unrolled_lifecycle), + original_assistant_event_id, + [r[0] for r in unrolled_lifecycle], + ) + # 1a. Look up any sibling interjection beat in the same turn group # (T73.2). The original group is (primary + optional interjection), # both pinned to the same ``user_turn_id``. 
The interjection has a @@ -143,6 +198,13 @@ async def regenerate_assistant_turn( # the silent witness (the bot that wasn't the primary addressee). # Filter on ``superseded_by IS NULL`` so prior regenerates of this # group don't reappear as siblings. + # + # T83.3: push the chat_id filter into SQL via ``json_extract`` so + # the query doesn't scan every assistant_turn row across the whole + # database. ``LIMIT 50`` bounds worst-case work even when chat_id + # isn't selective (e.g. a single chat with many turns) — we only + # need the one matching sibling. Mirrors the SQL pattern in + # ``chat.web.meanwhile._last_meanwhile_speaker``. original_interjection_event_id: int | None = None original_interjection_payload: dict | None = None if original_user_turn_id is not None: @@ -150,8 +212,11 @@ async def regenerate_assistant_turn( "SELECT id, payload_json FROM event_log " "WHERE kind = 'assistant_turn' " " AND id != ? " - " AND superseded_by IS NULL", - (original_assistant_event_id,), + " AND superseded_by IS NULL " + " AND json_extract(payload_json, '$.chat_id') = ? " + "ORDER BY id DESC " + "LIMIT 50", + (original_assistant_event_id, chat_id), ) for sib_id, sib_payload_json in sibling_cur.fetchall(): sib_payload = json.loads(sib_payload_json) @@ -208,33 +273,30 @@ async def regenerate_assistant_turn( # assistant_turn explicitly (we haven't superseded it yet — that # update lands at the end so the new event_id is known) and use the # standard ``superseded_by IS NULL AND hidden = 0`` filter so any - # prior regenerates also drop out. + # prior regenerates also drop out. T83.2: shared helper handles the + # SQL + filtering; we post-process to map speaker ids to display + # names for the prompt. you_entity = get_you(conn) or {"name": "you", "persona": ""} you_name = you_entity.get("name", "you") - cur = conn.execute( - "SELECT id, kind, payload_json FROM event_log " - "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " - " AND id != ? 
" - " AND superseded_by IS NULL AND hidden = 0 " - "ORDER BY id DESC LIMIT 20", - (original_assistant_event_id,), + raw_recent = read_recent_dialogue( + conn, + chat_id, + limit=20, + exclude_event_id=original_assistant_event_id, ) - rows = list(reversed(cur.fetchall())) recent: list[dict] = [] - for _eid, kind, payload_json in rows: - p = json.loads(payload_json) - if p.get("chat_id") != chat_id: + for entry in raw_recent: + spk = entry.get("speaker", "bot") + if spk == "you": + recent.append({"speaker": you_name, "text": entry.get("text", "")}) continue - if kind in ("user_turn", "user_turn_edit"): - recent.append({"speaker": you_name, "text": p.get("prose", "")}) - else: - spk = p.get("speaker_id", "bot") + if spk == host_bot_id: spk_name = host_bot.get("name", "bot") - if spk == host_bot_id: - spk_name = host_bot.get("name", "bot") - elif guest_bot is not None and spk == guest_bot.get("id"): - spk_name = guest_bot.get("name", "bot") - recent.append({"speaker": spk_name, "text": p.get("text", "")}) + elif guest_bot is not None and spk == guest_bot.get("id"): + spk_name = guest_bot.get("name", "bot") + else: + spk_name = host_bot.get("name", "bot") + recent.append({"speaker": spk_name, "text": entry.get("text", "")}) # 4. Assemble the narrative prompt. ``recent`` already excludes the # current user prose, which we pass through ``user_turn_prose``. @@ -250,19 +312,37 @@ async def regenerate_assistant_turn( guest_id=guest_bot_id, ) - # 5. Stream the new narrative. + # 5. Stream the new narrative. T83.1: register the streaming Task in + # the chat-keyed in-flight registry so POST /chats//turns/cancel + # can call ``.cancel()`` on a mid-regenerate stream. We import the + # underscore name from turns.py deliberately — same single-process + # registry the cancel route reads, mirrors the meanwhile registration + # pattern in chat/web/meanwhile.py. 
+ from chat.web.turns import _in_flight_tasks # noqa: PLC0415 + accumulated: list[str] = [] - async for chunk in client.stream( - messages, - model=settings.narrative_model, - max_tokens=settings.narrative_max_tokens, - temperature=settings.narrative_temperature, - ): - accumulated.append(chunk) - await publish( - chat_id, - {"event": "token", "text": chunk, "speaker_id": speaker_bot_id}, - ) + + async def _stream_primary() -> None: + async for chunk in client.stream( + messages, + model=settings.narrative_model, + max_tokens=settings.narrative_max_tokens, + temperature=settings.narrative_temperature, + ): + accumulated.append(chunk) + await publish( + chat_id, + {"event": "token", "text": chunk, "speaker_id": speaker_bot_id}, + ) + + stream_task = asyncio.create_task(_stream_primary()) + _in_flight_tasks[chat_id] = stream_task + try: + await stream_task + finally: + # Always unregister so a subsequent turn / regenerate can register + # a fresh task. Mirrors the cleanup in turns.py::post_turn. + _in_flight_tasks.pop(chat_id, None) new_text = "".join(accumulated) # 6. Append the new assistant_turn event. ``user_turn_id`` points at @@ -354,17 +434,8 @@ async def regenerate_assistant_turn( present_names[guest_bot_id] = guest_bot.get("name", "bot") personas[guest_bot_id] = guest_bot.get("persona") or "" - prior_edges: dict[tuple[str, str], dict] = {} - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = get_edge(conn, src, tgt) or { - "affinity": 50, - "trust": 50, - "summary": "", - } - prior_edges[(src, tgt)] = edge + # T83.2: shared helper builds the directed-pair edge dict. + prior_edges = gather_prior_edges(conn, present_ids) state_updates = await compute_state_updates_for_present( client, @@ -453,34 +524,27 @@ async def regenerate_assistant_turn( ) if decision.should_interject: - # Re-read recent so the just-appended primary is in the prompt. 
- interject_cur = conn.execute( - "SELECT id, kind, payload_json FROM event_log " - "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " - " AND superseded_by IS NULL AND hidden = 0 " - "ORDER BY id DESC LIMIT 20", - ) - interject_rows = list(reversed(interject_cur.fetchall())) + # Re-read recent so the just-appended primary is in the + # prompt. T83.2: shared helper + the same id->name mapping + # as the primary read above. + raw_interject = read_recent_dialogue(conn, chat_id, limit=20) interject_recent: list[dict] = [] - for _eid, kind, payload_json in interject_rows: - p = json.loads(payload_json) - if p.get("chat_id") != chat_id: + for entry in raw_interject: + spk = entry.get("speaker", "bot") + if spk == "you": + interject_recent.append( + {"speaker": you_name, "text": entry.get("text", "")} + ) continue - if kind in ("user_turn", "user_turn_edit"): - interject_recent.append( - {"speaker": you_name, "text": p.get("prose", "")} - ) + if spk == host_bot_id: + spk_name = host_bot.get("name", "bot") + elif spk == guest_bot.get("id"): + spk_name = guest_bot.get("name", "bot") else: - spk = p.get("speaker_id", "bot") - if spk == host_bot_id: - spk_name = host_bot.get("name", "bot") - elif spk == guest_bot.get("id"): - spk_name = guest_bot.get("name", "bot") - else: - spk_name = "bot" - interject_recent.append( - {"speaker": spk_name, "text": p.get("text", "")} - ) + spk_name = "bot" + interject_recent.append( + {"speaker": spk_name, "text": entry.get("text", "")} + ) if interject_recent and interject_recent[-1].get("speaker") == you_name: interject_recent = interject_recent[:-1] @@ -497,21 +561,32 @@ async def regenerate_assistant_turn( ) interject_accumulated: list[str] = [] - async for chunk in client.stream( - interject_messages, - model=settings.narrative_model, - max_tokens=settings.narrative_max_tokens, - temperature=settings.narrative_temperature, - ): - interject_accumulated.append(chunk) - await publish( - chat_id, - { - "event": "token", - 
"text": chunk, - "speaker_id": silent_witness_id, - }, - ) + + async def _stream_interjection() -> None: + async for chunk in client.stream( + interject_messages, + model=settings.narrative_model, + max_tokens=settings.narrative_max_tokens, + temperature=settings.narrative_temperature, + ): + interject_accumulated.append(chunk) + await publish( + chat_id, + { + "event": "token", + "text": chunk, + "speaker_id": silent_witness_id, + }, + ) + + # T83.1: register the interjection sub-stream in the same + # in-flight registry so /turns/cancel collapses it too. + interject_task = asyncio.create_task(_stream_interjection()) + _in_flight_tasks[chat_id] = interject_task + try: + await interject_task + finally: + _in_flight_tasks.pop(chat_id, None) interject_text = "".join(interject_accumulated) new_interjection_event_id = append_event( @@ -573,17 +648,8 @@ async def regenerate_assistant_turn( "text": interject_text, } ] - prior_edges_post: dict[tuple[str, str], dict] = {} - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = get_edge(conn, src, tgt) or { - "affinity": 50, - "trust": 50, - "summary": "", - } - prior_edges_post[(src, tgt)] = edge + # T83.2: shared helper handles the directed-pair edge dict. + prior_edges_post = gather_prior_edges(conn, present_ids) state_updates_post = await compute_state_updates_for_present( client, @@ -620,23 +686,28 @@ async def regenerate_assistant_turn( (new_assistant_event_id, original_interjection_event_id), ) - # 10. Event-lifecycle detection (Phase 3, T61). Mirrors the post_turn - # block: classify whether any active events transitioned in the - # regenerated narrative and append the corresponding event_started / + # 9a. Event-lifecycle detection (Phase 3, T61). 
T83.5 cosmetic + # ordering: mirrors ``chat.web.turns.post_turn``'s 8a block — runs + # AFTER the interjection branch (and AFTER the post-interjection + # state-update + memory passes) so the classifier sees the same + # narrative-text input post_turn does. Numbering uses ``9a`` to + # match post_turn's ``8a`` shape (the interjection branch is step 9 + # in regenerate vs step 8 in post_turn; lifecycle is the immediate + # follow-on in both). Behaviour identical to the prior ``step 10`` + # placement — the block was already structurally last in regenerate + # because there's no scene-close pass here. + # + # Classify whether any active events transitioned in the regenerated + # narrative and append the corresponding event_started / # event_completed / event_cancelled. ``promote_completed_event`` # runs inline after a completion so promotion artifacts land in the # same regenerate path. # - # Phase 3.5 follow-up: when a regenerate replaces a turn that had - # already produced event transitions, those original transitions are - # NOT undone here. The superseded ``assistant_turn`` group keeps its - # prior ``event_started`` / ``event_completed`` events in the log - # (they remain projected onto the events table). Phase 3.5 will add - # an "undo lifecycle" step to roll back the prior transitions before - # re-classifying the regenerated text. For v3 we accept that a - # regenerate-after-completion will double-emit promotion artifacts - # if the new text re-completes the same event — narratively rare, - # and a true fix needs the lifecycle-undo pass. + # T83.4 follow-up: when a regenerate replaces a turn that had + # already produced event transitions, those original transitions + # are NOT undone here (Phase 4 work). A WARNING log earlier in this + # function names the affected event_log ids — see the T83.4 block + # near the function entry. 
new_active_events = list_active_events(conn, chat_id) if new_active_events: lifecycle_decision = await detect_event_transitions( diff --git a/chat/services/turn_common.py b/chat/services/turn_common.py new file mode 100644 index 0000000..e4c5444 --- /dev/null +++ b/chat/services/turn_common.py @@ -0,0 +1,118 @@ +"""Shared helpers for turn flows (T83.2). + +Both ``chat.web.turns.post_turn`` and +``chat.services.regenerate.regenerate_assistant_turn`` need to: + +1. Pull a chronological tail of user-side and assistant_turn events for + prompt assembly + state-update inputs. +2. Build a directed-edge dict over a fixed set of "present" entity ids + for the multi-pair state-update pass (with the schema 50/50 default + filled in for missing rows). + +Before T83.2 each call site had its own copy of these blocks. The two +copies drifted on details (T73.1 added ``user_turn_edit`` handling to +turns.py; regenerate.py had a slightly different recent-window query). +This module is the single source so a future change to either lands in +both flows by construction. + +Note on overlap with ``chat.services.scene_summarize._read_recent_dialogue``: +that helper has a ``since_event_id`` clamp (T80.2 thread-detection +scope) and intentionally does NOT include ``user_turn_edit`` events — +its callers want the *original* prose, not edits. Deduplicating it +into here would either (a) require a new flag on the shared helper for +``user_turn_edit`` inclusion, or (b) silently change scene_summarize's +read shape. Both feel more invasive than the duplication is bad, so +that helper is left alone for now. 
+"""
+
+from __future__ import annotations
+
+import json
+from sqlite3 import Connection
+
+from chat.state.edges import get_edge
+
+
+def read_recent_dialogue(
+    conn: Connection,
+    chat_id: str,
+    *,
+    limit: int = 50,
+    exclude_event_id: int | None = None,
+) -> list[dict]:
+    """Return the last ``limit`` dialogue events of ``chat_id``, oldest first.
+
+    Each entry is ``{"speaker": <speaker>, "text": <text>}``. ``speaker``
+    is the literal ``"you"`` for ``user_turn`` / ``user_turn_edit`` rows
+    (text from ``prose``) and the payload's ``speaker_id`` (default
+    ``"bot"``) for ``assistant_turn`` rows (text from ``text``).
+
+    Only rows with ``superseded_by IS NULL AND hidden = 0`` are read, so
+    regenerated or hidden turns drop out of the window.
+
+    The chat filter runs in SQL (``json_extract`` on ``$.chat_id``), so
+    ``LIMIT`` counts only this chat's rows. Applying the chat filter after
+    ``LIMIT`` would let other chats' turns crowd this chat out.
+
+    ``exclude_event_id`` optionally skips one event_log id. Regenerate
+    uses it to hide the original assistant_turn before that row has been
+    marked superseded.
+    """
+    # Build the optional clause and its bound parameter together.
+    exclude_clause = ""
+    params: list[object] = [chat_id]
+    if exclude_event_id is not None:
+        exclude_clause = " AND id != ? "
+        params.append(exclude_event_id)
+    params.append(limit)
+    cur = conn.execute(
+        "SELECT id, kind, payload_json FROM event_log "
+        "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') "
+        " AND superseded_by IS NULL AND hidden = 0 "
+        " AND json_extract(payload_json, '$.chat_id') = ? "
+        f"{exclude_clause}"
+        "ORDER BY id DESC LIMIT ?",
+        params,
+    )
+    out: list[dict] = []
+    # The query returns newest-first; reverse into chronological order.
+    for _row_id, kind, payload_json in reversed(cur.fetchall()):
+        p = json.loads(payload_json)
+        if kind in ("user_turn", "user_turn_edit"):
+            out.append({"speaker": "you", "text": p.get("prose", "")})
+        else:
+            out.append(
+                {
+                    "speaker": p.get("speaker_id", "bot"),
+                    "text": p.get("text", ""),
+                }
+            )
+    return out
+
+
+def gather_prior_edges(
+    conn: Connection, present_ids: list[str]
+) -> dict[tuple[str, str], dict]:
+    """Build ``{(src, tgt): {affinity, trust, summary}}`` for every
+    directed pair where both ``src`` and ``tgt`` are in ``present_ids``
+    and ``src != tgt``.
+
+    Missing rows fall back to the schema default 50/50 baseline (mirrors
+    the Phase 1 single-pair flow). Used by post_turn and regenerate to
+    seed the multi-pair state-update classifier.
+    """
+    prior_edges: dict[tuple[str, str], dict] = {}
+    for src in present_ids:
+        for tgt in present_ids:
+            if src == tgt:
+                continue
+            edge = get_edge(conn, src, tgt) or {
+                "affinity": 50,
+                "trust": 50,
+                "summary": "",
+            }
+            prior_edges[(src, tgt)] = edge
+    return prior_edges
+
+
+__all__ = ["read_recent_dialogue", "gather_prior_edges"]
diff --git a/chat/web/turns.py b/chat/web/turns.py
index 0368d8b..3505c42 100644
--- a/chat/web/turns.py
+++ b/chat/web/turns.py
@@ -71,6 +71,10 @@ from chat.services.prompt import (
 from chat.services.rewind import compute_rewind_preview, execute_rewind
 from chat.services.scene_close import detect_scene_close
 from chat.services.scene_summarize import apply_scene_close_summary
+from chat.services.turn_common import (
+    gather_prior_edges,
+    read_recent_dialogue,
+)
 from chat.services.turn_parse import ParsedTurn, parse_turn
 from chat.state.edges import get_edge
 from chat.state.entities import get_bot, get_you
@@ -113,38 +117,13 @@ def _strip_ooc_for_prompt(parsed: ParsedTurn)
-> str: def _read_recent_dialogue(conn, chat_id: str, limit: int = 200) -> list[dict]: """Return user-side and assistant_turn events for ``chat_id``. - Includes ``user_turn``, ``user_turn_edit`` (T29 edited prose), and - ``assistant_turn``. Ordered oldest-first; superseded/hidden rows are - skipped so regenerated turns (T29) drop out of the rendered timeline. - Each entry is shaped ``{"speaker": , "text": }`` - for the prompt assembler and the chat-detail template. + T83.2: thin delegate over + :func:`chat.services.turn_common.read_recent_dialogue` so post_turn + and regenerate share one implementation. The wrapper survives so + the chat-detail template and other callers in this module don't all + have to update at once. """ - cur = conn.execute( - "SELECT id, kind, payload_json FROM event_log " - "WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') " - " AND superseded_by IS NULL AND hidden = 0 " - "ORDER BY id DESC LIMIT ?", - (limit,), - ) - rows = cur.fetchall() - rows.reverse() # back to chronological order - out: list[dict] = [] - for _row_id, kind, payload_json in rows: - p = json.loads(payload_json) - if p.get("chat_id") != chat_id: - continue - if kind in ("user_turn", "user_turn_edit"): - # Edited prose substitutes for the original user_turn (the - # original is marked superseded_by and filtered above). 
- out.append({"speaker": "you", "text": p.get("prose", "")}) - else: - out.append( - { - "speaker": p.get("speaker_id", "bot"), - "text": p.get("text", ""), - } - ) - return out + return read_recent_dialogue(conn, chat_id, limit=limit) def _detect_addressee_id( @@ -211,17 +190,8 @@ def _gather_state_update_inputs( present_names[guest_bot["id"]] = guest_bot["name"] personas[guest_bot["id"]] = guest_bot.get("persona") or "" - prior_edges: dict[tuple[str, str], dict] = {} - for src in present_ids: - for tgt in present_ids: - if src == tgt: - continue - edge = get_edge(conn, src, tgt) or { - "affinity": 50, - "trust": 50, - "summary": "", - } - prior_edges[(src, tgt)] = edge + # T83.2: directed-edge gather is shared with regenerate.py. + prior_edges = gather_prior_edges(conn, present_ids) return present_ids, present_names, personas, prior_edges diff --git a/tests/test_regenerate.py b/tests/test_regenerate.py index 7fa22bc..d8a2d65 100644 --- a/tests/test_regenerate.py +++ b/tests/test_regenerate.py @@ -662,3 +662,356 @@ def test_regenerate_drops_interjection_when_classifier_returns_false( new_primary_payload = json.loads(cur[0][0]) assert new_primary_payload["text"] == "New primary text." assert "interjection_of" not in new_primary_payload + + +def test_regenerate_with_prior_lifecycle_logs_warning(tmp_path, monkeypatch, caplog): + """T83.4: when the superseded assistant_turn already produced + lifecycle transitions (event_started / event_completed / + event_cancelled), regenerate emits a WARNING naming the un-rolled- + back transitions. Phase 3.5 documents the gap; the actual rollback + is Phase 4 work. 
+ """ + import asyncio + import logging + + from chat.config import Settings + from chat.db.migrate import apply_migrations + from chat.eventlog.log import append_and_apply + from chat.services.regenerate import regenerate_assistant_turn + + db_path = tmp_path / "test.db" + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db_path)) + apply_migrations(db_path) + + _ut_id, at_id = _seed_with_one_turn(db_path) + + # After the assistant_turn lands, simulate that the turn flow + # produced an event_completed transition. ``append_and_apply`` is + # the standard path so the events projection updates. + with open_db(db_path) as conn: + append_and_apply( + conn, + kind="event_planned", + payload={ + "event_id": "evt_x", + "chat_id": "chat_bot_a", + "kind": "story_event", + "props": {}, + "planned_for": "2026-04-30T18:00:00+00:00", + }, + ) + append_and_apply( + conn, + kind="event_started", + payload={ + "event_id": "evt_x", + "started_at": "2026-04-30T19:00:00+00:00", + }, + ) + completed_id = append_and_apply( + conn, + kind="event_completed", + payload={ + "event_id": "evt_x", + "completed_at": "2026-04-30T19:30:00+00:00", + }, + ) + assert completed_id is not None + + state_canned = json.dumps( + {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []} + ) + mock_client = MockLLMClient( + canned=["Refreshed reply.", state_canned, state_canned] + ) + settings = Settings(featherless_api_key="test") + + caplog.set_level(logging.WARNING, logger="chat.services.regenerate") + + with open_db(db_path) as conn: + asyncio.run( + regenerate_assistant_turn( + conn, + mock_client, + settings=settings, + chat_id="chat_bot_a", + original_assistant_event_id=at_id, + ) + ) + + # The warning records the count and at least one of the affected + # event_log ids (event_started + event_completed = at minimum 2). 
+ warnings = [ + r for r in caplog.records if r.levelname == "WARNING" + ] + matching = [w for w in warnings if "lifecycle transition" in w.getMessage()] + assert matching, ( + "expected a WARNING about un-rolled-back lifecycle transitions; " + f"got: {[w.getMessage() for w in warnings]}" + ) + msg = matching[0].getMessage() + # Reference the original superseded turn's id and the event_completed + # row's id. + assert str(at_id) in msg + assert str(completed_id) in msg + + +def test_regenerate_sibling_lookup_scoped_to_chat(tmp_path, monkeypatch): + """T83.3: regenerate's sibling-interjection lookup is scoped to the + chat being regenerated. + + Setup: TWO chats, each with a primary + interjection turn group whose + rows happen to share the same ``user_turn_id`` value (the projector + assigns event_log ids monotonically across the whole database, so + when each chat is seeded back-to-back the chat A primary lands on a + different ``user_turn_id`` than chat B's — but in older versions the + sibling query had no chat predicate, so it could in principle latch + onto a row from a different chat if ids collided in some unusual + flow). We construct the seeding so chat B's interjection has the + SAME ``interjection_of`` value as the chat A primary's speaker_id — + pre-T83.3 the global query could have picked it up. + + Assert: regenerating the chat A primary leaves chat B's rows + untouched (no supersede), and the regenerated chat A turn group's + interjection (the only one regenerate should regenerate) has its + ``regenerated_from`` pointing at the chat A original interjection, + not chat B's. 
+ """ + import asyncio + + from chat.config import Settings + from chat.db.migrate import apply_migrations + from chat.services import regenerate as regenerate_module + from chat.services.interjection import InterjectionDecision + from chat.services.regenerate import regenerate_assistant_turn + + db_path = tmp_path / "test.db" + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db_path)) + apply_migrations(db_path) + + # Seed chat A's interjection group. + a_ut_id, a_primary_id, a_interjection_id = _seed_with_interjection_group( + db_path + ) + + # Seed chat B with the same shape but a different chat_id and bot + # ids, then add an interjection group whose ``interjection_of`` + # points at "bot_a" so a global (unscoped) query could collide. + with open_db(db_path) as conn: + for bot_id, name in (("bot_c", "BotC"), ("bot_d", "BotD")): + append_event( + conn, + kind="bot_authored", + payload={ + "id": bot_id, + "name": name, + "persona": "", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_other", + "host_bot_id": "bot_c", + "guest_bot_id": "bot_d", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + b_ut_id = append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "chat_other", + "prose": "different chat", + "segments": [], + }, + ) + b_primary_id = append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_other", + "speaker_id": "bot_c", + "text": "Other primary.", + "truncated": False, + "user_turn_id": b_ut_id, + }, + ) + # The chat B interjection's ``interjection_of`` references + # "bot_a" — the chat A primary's speaker. Pre-T83.3 the global + # sibling query could mis-match this row. 
+ b_interjection_id = append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_other", + "speaker_id": "bot_d", + "text": "Cross-chat noise.", + "truncated": False, + "user_turn_id": b_ut_id, + "interjection_of": "bot_a", + }, + ) + + # Stub the interjection classifier to return True so the regenerate + # actively walks the sibling-discovery path. + async def _stub_should_interject(*_args, **_kwargs): + return InterjectionDecision(should_interject=True, reason="fired") + + monkeypatch.setattr( + regenerate_module, "detect_interjection", _stub_should_interject + ) + + state_canned = json.dumps( + {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []} + ) + canned: list[str] = ( + ["New chat A primary."] + + [state_canned] * 6 + + ["New chat A interjection."] + + [state_canned] * 6 + ) + mock_client = MockLLMClient(canned=list(canned)) + settings = Settings(featherless_api_key="test") + + with open_db(db_path) as conn: + new_text = asyncio.run( + regenerate_assistant_turn( + conn, + mock_client, + settings=settings, + chat_id="chat_multi", + original_assistant_event_id=a_primary_id, + ) + ) + assert new_text == "New chat A primary." + + # Chat B rows are untouched — neither superseded nor referenced. + b_primary_super = conn.execute( + "SELECT superseded_by FROM event_log WHERE id = ?", + (b_primary_id,), + ).fetchone()[0] + b_interjection_super = conn.execute( + "SELECT superseded_by FROM event_log WHERE id = ?", + (b_interjection_id,), + ).fetchone()[0] + assert b_primary_super is None + assert b_interjection_super is None + + # Chat A's regenerated interjection has its ``regenerated_from`` + # pointing at chat A's original interjection — NOT chat B's. + cur = conn.execute( + "SELECT payload_json FROM event_log " + "WHERE kind = 'assistant_turn' " + " AND id NOT IN (?, ?, ?, ?) 
" + " AND superseded_by IS NULL", + (a_primary_id, a_interjection_id, b_primary_id, b_interjection_id), + ).fetchall() + # Two new rows: regenerated primary + regenerated interjection. + assert len(cur) == 2 + payloads = [json.loads(row[0]) for row in cur] + # Find the regenerated interjection (carries interjection_of). + new_interject_payloads = [ + p for p in payloads if p.get("interjection_of") + ] + assert len(new_interject_payloads) == 1 + assert new_interject_payloads[0]["regenerated_from"] == a_interjection_id + # Pin chat scope on every new row. + for p in payloads: + assert p["chat_id"] == "chat_multi" + + +def test_regenerate_registers_task_in_in_flight_tasks(tmp_path, monkeypatch): + """T83.1: regenerate's streaming Task is registered in the chat-keyed + ``_in_flight_tasks`` dict so the /turns/cancel route can cancel a + mid-regenerate stream. Mirrors the meanwhile registration pattern + pinned by tests/test_meanwhile_turn_flow.py. + + Snapshot pattern: a custom MockLLMClient subclass captures the + presence of the chat_id in ``_in_flight_tasks`` at the first stream + yield (when the regenerate coroutine is awaiting our generator and + the task is alive). Post-flight, the entry must be cleaned up so the + next regenerate / turn registers a fresh task. 
+ """ + import asyncio + from typing import AsyncIterator, Sequence + + from chat.config import Settings + from chat.db.migrate import apply_migrations + from chat.llm.client import Message + from chat.services.regenerate import regenerate_assistant_turn + from chat.web.turns import _in_flight_tasks + + db_path = tmp_path / "test.db" + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db_path)) + apply_migrations(db_path) + + _ut_id, at_id = _seed_with_one_turn(db_path) + + in_flight_snapshot: dict = {} + + class _SnapshotMock(MockLLMClient): + async def stream( + self, messages: Sequence[Message], *, model: str, **params + ) -> AsyncIterator[str]: + text = self._canned.pop(0) + for i, ch in enumerate(text): + if i == 0: + in_flight_snapshot["present"] = ( + "chat_bot_a" in _in_flight_tasks + ) + in_flight_snapshot["task"] = _in_flight_tasks.get( + "chat_bot_a" + ) + yield ch + + state_canned = json.dumps( + {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []} + ) + mock_client = _SnapshotMock( + canned=["Refreshed reply.", state_canned, state_canned] + ) + + settings = Settings(featherless_api_key="test") + + # Pre-condition: registry empty for this chat. + assert "chat_bot_a" not in _in_flight_tasks + + with open_db(db_path) as conn: + new_text = asyncio.run( + regenerate_assistant_turn( + conn, + mock_client, + settings=settings, + chat_id="chat_bot_a", + original_assistant_event_id=at_id, + ) + ) + assert new_text == "Refreshed reply." + + # Mid-flight: the streaming task was present in the registry, and + # the captured value was an asyncio.Task. + assert in_flight_snapshot.get("present") is True, ( + "_in_flight_tasks was empty at first yield — regenerate stream " + "isn't registering its task" + ) + assert isinstance(in_flight_snapshot.get("task"), asyncio.Task) + # Post-flight: the entry has been cleaned up. 
+ assert "chat_bot_a" not in _in_flight_tasks diff --git a/tests/test_turn_common.py b/tests/test_turn_common.py new file mode 100644 index 0000000..f4b4f9b --- /dev/null +++ b/tests/test_turn_common.py @@ -0,0 +1,215 @@ +"""Shared turn helpers (T83.2). + +``chat.services.turn_common`` extracts two snippets that were duplicated +between ``chat.web.turns`` and ``chat.services.regenerate``: the recent +user-side / assistant_turn read, and the directed-pair edge gather for +the multi-pair state-update pass. These tests pin the helpers' behavior +independently of either call site. +""" + +from __future__ import annotations + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +from chat.services.turn_common import gather_prior_edges, read_recent_dialogue + + +def _seed_basic_chat(db_path): + """Seed bot + chat + a couple of edges + one round of user/assistant + turns. Returns ``(user_turn_id, assistant_turn_id)``. 
+ """ + apply_migrations(db_path) + with open_db(db_path) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "thoughtful", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + append_event( + conn, + kind="edge_update", + payload={ + "source_id": "bot_a", + "target_id": "you", + "chat_id": "chat_a", + "affinity_delta": 7, + "trust_delta": 3, + }, + ) + append_event( + conn, + kind="edge_update", + payload={ + "source_id": "you", + "target_id": "bot_a", + "chat_id": "chat_a", + "affinity_delta": 2, + "trust_delta": 1, + }, + ) + ut_id = append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "chat_a", + "prose": "hello", + "segments": [], + }, + ) + at_id = append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_a", + "speaker_id": "bot_a", + "text": "Original.", + "truncated": False, + "user_turn_id": ut_id, + }, + ) + project(conn) + return ut_id, at_id + + +def test_read_recent_dialogue_returns_chronological_pairs(tmp_path): + """``read_recent_dialogue`` returns oldest-first ``{speaker, text}`` + entries scoped to the requested chat. Speaker is "you" for user-side + rows and the assistant_turn's ``speaker_id`` for bot rows. + """ + db = tmp_path / "test.db" + _seed_basic_chat(db) + + with open_db(db) as conn: + out = read_recent_dialogue(conn, "chat_a", limit=10) + + assert out == [ + {"speaker": "you", "text": "hello"}, + {"speaker": "bot_a", "text": "Original."}, + ] + + +def test_read_recent_dialogue_filters_superseded_and_other_chats(tmp_path): + """Superseded rows drop out (regenerate-aware). Rows scoped to a + different chat are also filtered. 
``exclude_event_id`` excludes a + specific row even when it isn't superseded yet (regenerate uses this + to drop the original assistant_turn before the supersede UPDATE + lands). + """ + db = tmp_path / "test.db" + ut_id, at_id = _seed_basic_chat(db) + + with open_db(db) as conn: + # Append a second user/assistant pair. + ut_id2 = append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "chat_a", + "prose": "how are you", + "segments": [], + }, + ) + at_id2 = append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_a", + "speaker_id": "bot_a", + "text": "Second.", + "truncated": False, + "user_turn_id": ut_id2, + }, + ) + # And a row scoped to a different chat — must NOT appear. + append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "other_chat", + "prose": "should be filtered", + "segments": [], + }, + ) + # Mark the first assistant_turn as superseded — must drop out. + conn.execute( + "UPDATE event_log SET superseded_by = ? WHERE id = ?", + (at_id2, at_id), + ) + + out = read_recent_dialogue(conn, "chat_a", limit=10) + # First (superseded) assistant turn dropped; "other_chat" rows + # filtered; first user_turn still present. + speakers = [(e["speaker"], e["text"]) for e in out] + assert speakers == [ + ("you", "hello"), + ("you", "how are you"), + ("bot_a", "Second."), + ] + + # exclude_event_id drops at_id2 even though it's not superseded. + out2 = read_recent_dialogue( + conn, "chat_a", limit=10, exclude_event_id=at_id2 + ) + speakers2 = [(e["speaker"], e["text"]) for e in out2] + assert ("bot_a", "Second.") not in speakers2 + assert ("you", "how are you") in speakers2 + + # Ensure ut_id is still part of the dataset (sanity for the seed). + assert ut_id is not None + + +def test_gather_prior_edges_fills_missing_with_default(tmp_path): + """``gather_prior_edges`` returns one entry per directed pair across + ``present_ids``. 
Missing rows fall back to the schema default + 50/50 baseline; existing rows carry their stored values. + """ + db = tmp_path / "test.db" + _seed_basic_chat(db) + + with open_db(db) as conn: + out = gather_prior_edges(conn, ["bot_a", "you"]) + + # 2 entities -> 2 directed pairs (a->b and b->a, no self-pairs). + assert set(out.keys()) == {("bot_a", "you"), ("you", "bot_a")} + bot_to_you = out[("bot_a", "you")] + you_to_bot = out[("you", "bot_a")] + # Both edges seeded with deltas — they must reflect the projected + # affinity/trust (not the default 50/50). + assert bot_to_you["affinity"] == 57 # 50 + 7 + assert bot_to_you["trust"] == 53 # 50 + 3 + assert you_to_bot["affinity"] == 52 + assert you_to_bot["trust"] == 51 + + # A pair with no row yet falls back to 50/50. + with open_db(db) as conn: + out_with_missing = gather_prior_edges( + conn, ["bot_a", "you", "ghost_bot"] + ) + # 3 entities -> 6 directed pairs. + assert len(out_with_missing) == 6 + fallback = out_with_missing[("bot_a", "ghost_bot")] + assert fallback["affinity"] == 50 + assert fallback["trust"] == 50 + assert fallback["summary"] == ""