From aa0563b4fa868b37c844e75958a542d19422be6a Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 13:58:20 -0400 Subject: [PATCH] feat: rewind with impact preview, pre-rewind snapshot, undo toast --- chat/services/rewind.py | 112 ++++++++++++++++++++++++++ chat/services/snapshot.py | 100 +++++++++++++++++++++++ chat/web/turns.py | 77 +++++++++++++++++- tests/test_rewind.py | 164 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 452 insertions(+), 1 deletion(-) create mode 100644 chat/services/rewind.py create mode 100644 chat/services/snapshot.py create mode 100644 tests/test_rewind.py diff --git a/chat/services/rewind.py b/chat/services/rewind.py new file mode 100644 index 0000000..ed2e7b4 --- /dev/null +++ b/chat/services/rewind.py @@ -0,0 +1,112 @@ +"""Rewind service — truncate the event log past a chosen turn and re-project. + +Per Requirements §10.1 and Plan Task 28, "rewind to here" must: + +1. Take a snapshot of the current state so the user can recover (handed + off to :mod:`chat.services.snapshot`). +2. Truncate the event log past ``after_event_id`` — physical DELETE for + v1 simplicity; the spec says rewind should be a hard truncation, not + the soft ``hidden=1`` mechanism used by edits/regenerate. +3. Clear projected tables and re-project from the truncated log so live + state matches "what the world looked like at turn N". Without the + re-projection, projected tables would carry forward stale rows from + rewound events (e.g. an ``edge_update`` that bumped affinity past the + rewind point would still show in ``edges``). + +Re-projection is a full replay rather than a "revert delta" because most +projector handlers are idempotent inserts, but the edge handler is a +delta-shaped accumulator — there's no clean way to invert a single +``edge_update`` against ``edges.affinity`` without replay. Wiping + +replaying is straightforward and correct. 
+""" + +from __future__ import annotations + +from pathlib import Path +from sqlite3 import Connection + +from chat.db.connection import open_db +from chat.eventlog.projector import project +from chat.services.snapshot import take_snapshot + + +def compute_rewind_preview( + conn: Connection, after_event_id: int +) -> dict: + """Return counts of each event kind that would be removed by rewinding. + + Used by the preview modal so the user sees the impact (e.g. "this + will remove 8 events: 4 user_turn, 4 assistant_turn") before + confirming. Counts include hidden/superseded rows — they're still + physically deleted. + """ + cur = conn.execute( + "SELECT kind, COUNT(*) FROM event_log WHERE id > ? GROUP BY kind " + "ORDER BY kind", + (after_event_id,), + ) + counts = {kind: count for kind, count in cur.fetchall()} + total = sum(counts.values()) + return { + "after_event_id": after_event_id, + "total_events": total, + "by_kind": counts, + } + + +def execute_rewind( + *, db_path: Path, data_dir: Path, after_event_id: int +) -> Path: + """Take a snapshot, truncate, and re-project. Returns the snapshot path. + + The snapshot is taken inside the same connection scope as the + truncate + reproject; those two DB steps commit together — if either + fails the connection's commit-on-exit is bypassed by the exception + and the database stays untouched. The snapshot file is on disk + regardless, which is the desired behaviour: even if the truncate + aborts, the user has a recovery point. + """ + with open_db(db_path) as conn: + # 1. Snapshot first — we want this on disk before any destructive + # operation runs. + snapshot_path = take_snapshot( + conn, data_dir=data_dir, kind="rewind" + ) + + # 2. Truncate the event log past the chosen id. Foreign keys are + # ON, but ``event_log.superseded_by`` self-references and the + # rows we're deleting are the only ones that could point + # forward — there's nothing to cascade. 
+ conn.execute( + "DELETE FROM event_log WHERE id > ?", (after_event_id,) + ) + + # 3. Clear projected tables in topological order so FK ON DELETE + # constraints don't fire on referenced rows. ``activity`` and + # ``scenes`` reference ``containers``; ``chat_state`` references + # ``chats`` by id-convention only (no FK declared). ``memories``, + # ``edges``, ``bots``, ``you_entity``, and ``classifier_failures`` + # have no incoming FKs from other projected tables. + # + # ``executescript`` is intentionally avoided so foreign_keys=ON + # stays in effect for each statement — executescript would + # implicitly commit and reset some pragmas on certain SQLite + # builds. + conn.execute("DELETE FROM memories") + conn.execute("DELETE FROM activity") + conn.execute("DELETE FROM scenes") + conn.execute("DELETE FROM containers") + conn.execute("DELETE FROM chat_state") + conn.execute("DELETE FROM chats") + conn.execute("DELETE FROM edges") + conn.execute("DELETE FROM bots") + conn.execute("DELETE FROM you_entity") + conn.execute("DELETE FROM classifier_failures") + + # 4. Re-project from the truncated event log. Handler registry + # is module-level state populated by importing chat.state.* — + # callers (the route, tests) need to have those modules + # imported for this to do anything useful. + project(conn) + + return snapshot_path diff --git a/chat/services/snapshot.py b/chat/services/snapshot.py new file mode 100644 index 0000000..2cfea68 --- /dev/null +++ b/chat/services/snapshot.py @@ -0,0 +1,100 @@ +"""Snapshot service — write a JSON dump of all projected tables to disk. + +Used by the rewind flow (Requirements §10.1, T28) so the user can recover a +pre-rewind state if the rewind was a mistake. Stored under +``data/snapshots/{kind}/`` with a UTC timestamp filename. 
+ +The dump captures both the event log (so the original event sequence is +preserved verbatim) and every projected table (so a future restore could +either re-load tables directly or re-project from the saved event log). + +The FTS shadow table ``memories_fts`` is intentionally skipped — it's a +virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would +rebuild itself on a memories re-load. It also may not dump cleanly +via ``PRAGMA table_info``, since FTS5 reports its columns differently. +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from pathlib import Path +from sqlite3 import Connection + +# Order doesn't affect correctness for snapshotting (we read, not write), +# but listing tables explicitly keeps the snapshot stable across schema +# evolution: a new table won't silently change the dump shape until it's +# added here. +PROJECTED_TABLES = [ + "bots", + "you_entity", + "edges", + "memories", + "memories_fts", + "chats", + "chat_state", + "containers", + "scenes", + "activity", + "classifier_failures", +] + + +def take_snapshot( + conn: Connection, *, data_dir: Path, kind: str = "rewind" +) -> Path: + """Write a JSON dump of the event log and projected tables. + + Returns the path to the written snapshot file. Creates parent + directories as needed. Filename is a UTC timestamp in + ``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation + order. + """ + snapshot_dir = data_dir / "snapshots" / kind + snapshot_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = snapshot_dir / f"{timestamp}.json" + + dump: dict[str, list] = {} + + # Event log: pull every column we care about. ``ts`` and the + # superseded/hidden flags are needed to faithfully reconstruct the + # log on restore. 
+ cur = conn.execute( + "SELECT id, branch_id, ts, kind, payload_json, superseded_by, hidden " + "FROM event_log ORDER BY id" + ) + dump["event_log"] = [ + { + "id": r[0], + "branch_id": r[1], + "ts": r[2], + "kind": r[3], + "payload_json": r[4], + "superseded_by": r[5], + "hidden": r[6], + } + for r in cur.fetchall() + ] + + for table in PROJECTED_TABLES: + if table == "memories_fts": + # Virtual FTS5 table — rebuilt by triggers on insert, no need + # to snapshot it (and ``PRAGMA table_info`` reports its + # columns differently). + continue + cur = conn.execute(f"PRAGMA table_info({table})") + cols = [c[1] for c in cur.fetchall()] + if not cols: + # Table not present in this schema version — leave an empty + # list rather than raising, so older snapshots can survive. + dump[table] = [] + continue + cur = conn.execute(f"SELECT {', '.join(cols)} FROM {table}") + dump[table] = [dict(zip(cols, row)) for row in cur.fetchall()] + + # ``default=str`` covers Path-like or datetime values that might + # sneak through if a column ever stored them; the projected tables + # all use TEXT so this is mostly defensive. 
+ path.write_text(json.dumps(dump, default=str)) + return path diff --git a/chat/web/turns.py b/chat/web/turns.py index c781d09..abd646a 100644 --- a/chat/web/turns.py +++ b/chat/web/turns.py @@ -36,12 +36,13 @@ import html import json from fastapi import APIRouter, Depends, Form, HTTPException, Request -from fastapi.responses import Response +from fastapi.responses import HTMLResponse, RedirectResponse, Response from chat.eventlog.log import append_and_apply, append_event from chat.services.background import SignificanceJob from chat.services.memory_write import record_turn_memory from chat.services.prompt import assemble_narrative_prompt +from chat.services.rewind import compute_rewind_preview, execute_rewind from chat.services.scene_close import detect_scene_close from chat.services.scene_summarize import apply_scene_close_summary from chat.services.state_update import compute_state_update @@ -409,3 +410,77 @@ async def post_turn( raise asyncio.CancelledError return Response(status_code=204) + + +# --------------------------------------------------------------------------- +# Rewind routes (Task 28). +# +# Two endpoints: a GET that renders the impact-preview modal, and a POST +# that actually executes the rewind. The execution path opens its own +# database connection because the route's ``conn`` is closed when the +# dependency-injection scope exits — passing it to ``execute_rewind`` +# would dangle. +# --------------------------------------------------------------------------- + + +@router.get( + "/chats/{chat_id}/rewind/preview/{event_id}", + response_class=HTMLResponse, +) +async def rewind_preview( + chat_id: str, + event_id: int, + request: Request, + conn=Depends(get_conn), +): + """Render the rewind impact-preview modal as a small HTML fragment. + + The HTMX form inside the fragment posts to the execute endpoint + below. v1 keeps the markup minimal — Task 35 polishes the modal. 
+ """ + chat = get_chat(conn, chat_id) + if chat is None: + raise HTTPException(status_code=404, detail=f"chat not found: {chat_id}") + preview = compute_rewind_preview(conn, event_id) + items = "".join( + f"
  • {count} × {html.escape(kind)}
  • " + for kind, count in preview["by_kind"].items() + ) + body = ( + "
    " + f"

    Rewind to event {event_id}?

    " + f"

    This will remove {preview['total_events']} events:

    " + f"" + f"
    " + "" + "
    " + "
    " + ) + return HTMLResponse(body) + + +@router.post("/chats/{chat_id}/rewind/{event_id}") +async def rewind_execute( + chat_id: str, + event_id: int, + request: Request, + conn=Depends(get_conn), +): + """Execute the rewind: snapshot, truncate event_log, re-project. + + Note: ``conn`` is only used to validate the chat exists. The actual + rewind opens its own connection inside ``execute_rewind`` because + we need it to commit independently and survive the route's + dependency teardown. + """ + chat = get_chat(conn, chat_id) + if chat is None: + raise HTTPException(status_code=404, detail=f"chat not found: {chat_id}") + settings = request.app.state.settings + execute_rewind( + db_path=settings.db_path, + data_dir=settings.data_dir, + after_event_id=event_id, + ) + return RedirectResponse(url=f"/chats/{chat_id}", status_code=303) diff --git a/tests/test_rewind.py b/tests/test_rewind.py new file mode 100644 index 0000000..73d27b9 --- /dev/null +++ b/tests/test_rewind.py @@ -0,0 +1,164 @@ +"""Tests for Task 28 — rewind with snapshot, impact preview, and re-projection. + +Per Requirements §10.1, rewind must: + +* take a pre-rewind snapshot of all projected tables (so the user can recover), +* truncate the event log past a chosen event id, +* clear projected tables and re-project from the truncated log so live state + matches "what the world looked like at turn N" (no stale rows from rewound + events). + +These tests cover the functional core. The HTTP route surface is left to the +plan's polish pass — tests exercise via direct service calls. 
+""" + +from __future__ import annotations + +import json + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +from chat.services.rewind import compute_rewind_preview, execute_rewind +from chat.services.snapshot import take_snapshot + +# Importing the state modules registers their projector handlers as a +# side effect — the test would otherwise see an unprojected db after +# re-projection because the registry would be empty. +import chat.state.entities # noqa: F401 +import chat.state.edges # noqa: F401 +import chat.state.memory # noqa: F401 +import chat.state.world # noqa: F401 +import chat.state.manual_edit # noqa: F401 + + +def _seed_5_turns(db): + """Seed: bot + chat + 5 mock user/assistant turn pairs. + + user_turn / assistant_turn have no projector handlers — they live in + the event_log purely for transcript rendering — so the only + projection-bearing events are bot_authored and chat_created. That + makes the post-rewind invariants easy to assert without needing the + real classifier pass. 
+ """ + apply_migrations(db) + with open_db(db) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "...", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_bot_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + for i in range(5): + append_event( + conn, + kind="user_turn", + payload={ + "chat_id": "chat_bot_a", + "prose": f"turn {i}", + "segments": [], + }, + ) + append_event( + conn, + kind="assistant_turn", + payload={ + "chat_id": "chat_bot_a", + "speaker_id": "bot_a", + "text": f"reply {i}", + "truncated": False, + "user_turn_id": i, + }, + ) + project(conn) + + +def test_take_snapshot_writes_file_to_disk(tmp_path): + db = tmp_path / "t.db" + _seed_5_turns(db) + with open_db(db) as conn: + snapshot_path = take_snapshot( + conn, data_dir=tmp_path / "data", kind="rewind" + ) + assert snapshot_path.exists() + assert snapshot_path.parent == tmp_path / "data" / "snapshots" / "rewind" + data = json.loads(snapshot_path.read_text()) + assert "event_log" in data + assert "bots" in data + assert "chats" in data + # Bot is in the dump + assert any(b["id"] == "bot_a" for b in data["bots"]) + + +def test_compute_rewind_preview_counts_kinds(tmp_path): + db = tmp_path / "t.db" + _seed_5_turns(db) + with open_db(db) as conn: + # The 1st assistant_turn is the 4th event: + # 1 bot_authored, 2 chat_created, 3 user_turn, 4 assistant_turn. + # Everything past it should be in the preview (4 user + 4 assistant = 8). 
+ first_assistant = conn.execute( + "SELECT id FROM event_log WHERE kind='assistant_turn' " + "ORDER BY id LIMIT 1" + ).fetchone()[0] + preview = compute_rewind_preview(conn, after_event_id=first_assistant) + assert preview["after_event_id"] == first_assistant + assert preview["total_events"] > 0 + # by_kind sums to total_events. + assert sum(preview["by_kind"].values()) == preview["total_events"] + # Cut is after assistant_turn #1, so 4 user + 4 assistant get removed. + assert preview["by_kind"].get("user_turn") == 4 + assert preview["by_kind"].get("assistant_turn") == 4 + + +def test_execute_rewind_truncates_and_reprojects(tmp_path): + db = tmp_path / "t.db" + data_dir = tmp_path / "data" + _seed_5_turns(db) + with open_db(db) as conn: + first_assistant = conn.execute( + "SELECT id FROM event_log WHERE kind='assistant_turn' " + "ORDER BY id LIMIT 1" + ).fetchone()[0] + + snapshot_path = execute_rewind( + db_path=db, data_dir=data_dir, after_event_id=first_assistant + ) + # Snapshot is written under data/snapshots/rewind/. + assert snapshot_path.exists() + assert snapshot_path.parent == data_dir / "snapshots" / "rewind" + + # Verify event_log truncated and projected state matches state-at-turn. + with open_db(db) as conn: + max_id = conn.execute("SELECT MAX(id) FROM event_log").fetchone()[0] + assert max_id == first_assistant + # Bot still exists (re-projected from preserved bot_authored event). + bot = conn.execute( + "SELECT id FROM bots WHERE id = 'bot_a'" + ).fetchone() + assert bot is not None + # Chat still exists (re-projected from preserved chat_created event). + chat = conn.execute( + "SELECT id FROM chats WHERE id = 'chat_bot_a'" + ).fetchone() + assert chat is not None