feat: rewind with impact preview, pre-rewind snapshot, undo toast

This commit is contained in:
Joseph Doherty
2026-04-26 13:58:20 -04:00
parent b5175aefaa
commit aa0563b4fa
4 changed files with 452 additions and 1 deletions
+112
View File
@@ -0,0 +1,112 @@
"""Rewind service — truncate the event log past a chosen turn and re-project.
Per Requirements §10.1 and Plan Task 28, "rewind to here" must:
1. Take a snapshot of the current state so the user can recover (handed
off to :mod:`chat.services.snapshot`).
2. Truncate the event log past ``after_event_id`` — physical DELETE for
v1 simplicity; the spec says rewind should be a hard truncation, not
the soft ``hidden=1`` mechanism used by edits/regenerate.
3. Clear projected tables and re-project from the truncated log so live
state matches "what the world looked like at turn N". Without the
re-projection, projected tables would carry forward stale rows from
rewound events (e.g. an ``edge_update`` that bumped affinity past the
rewind point would still show in ``edges``).
Re-projection is a full replay rather than a "revert delta" because most
projector handlers are idempotent inserts, but the edge handler is a
delta-shaped accumulator — there's no clean way to invert a single
``edge_update`` against ``edges.affinity`` without replay. Wiping +
replaying is straightforward and correct.
"""
from __future__ import annotations
from pathlib import Path
from sqlite3 import Connection
from chat.db.connection import open_db
from chat.eventlog.projector import project
from chat.services.snapshot import take_snapshot
def compute_rewind_preview(
    conn: Connection, after_event_id: int
) -> dict:
    """Summarise what a rewind to ``after_event_id`` would delete.

    Feeds the confirmation modal, which shows the user the blast radius
    (e.g. "this will remove 8 events: 4 user_turn, 4 assistant_turn")
    before anything destructive happens.

    Every row with ``id > after_event_id`` is counted; the query does not
    filter on the hidden/superseded columns, because those rows are
    physically deleted by the rewind as well.

    Returns a dict with three keys: ``after_event_id`` (echoed back),
    ``total_events`` (sum of all counts) and ``by_kind`` (event kind ->
    count, in alphabetical kind order).
    """
    rows = conn.execute(
        "SELECT kind, COUNT(*) FROM event_log WHERE id > ? GROUP BY kind "
        "ORDER BY kind",
        (after_event_id,),
    )
    by_kind = {}
    for event_kind, how_many in rows:
        by_kind[event_kind] = how_many
    return {
        "after_event_id": after_event_id,
        "total_events": sum(by_kind.values()),
        "by_kind": by_kind,
    }
def execute_rewind(
    *, db_path: Path, data_dir: Path, after_event_id: int
) -> Path:
    """Take a snapshot, truncate, and re-project. Returns the snapshot path.

    All database work (detach, truncate, clear, re-project) happens on one
    connection opened via ``open_db``. Per the original design note this
    relies on ``open_db`` committing on clean exit and skipping the commit
    when an exception propagates, so a failure in any step leaves the
    database untouched (``open_db`` is not visible here — confirm).

    The snapshot *file* is written before any destructive statement and
    stays on disk even if a later step aborts; that is intentional, the
    user keeps a recovery point either way.

    Args:
        db_path: Path of the SQLite database to rewind.
        data_dir: Base directory handed to ``take_snapshot``.
        after_event_id: Highest event id to keep; every row with a larger
            id is physically deleted.

    Returns:
        Path of the pre-rewind snapshot file.
    """
    with open_db(db_path) as conn:
        # 1. Snapshot first — we want this on disk before any destructive
        #    operation runs.
        snapshot_path = take_snapshot(
            conn, data_dir=data_dir, kind="rewind"
        )

        # 2. Detach surviving rows from rows that are about to disappear.
        #    ``superseded_by`` points *forward*: an early turn that was
        #    edited/regenerated after the rewind point references a row
        #    with id > after_event_id. Leaving that reference in place
        #    either trips the self-referencing foreign key on the DELETE
        #    below or leaves the survivor dangling and hidden with no
        #    replacement. Restoring it matches "what the log looked like
        #    at turn N".
        #    NOTE(review): un-hiding assumes ``hidden`` was set as part of
        #    the supersession (edit/regenerate) — confirm against the
        #    edit flow.
        conn.execute(
            "UPDATE event_log SET superseded_by = NULL, hidden = 0 "
            "WHERE id <= ? AND superseded_by > ?",
            (after_event_id, after_event_id),
        )

        # 3. Truncate the event log past the chosen id. After step 2 the
        #    only rows referencing the deleted range are themselves in
        #    the deleted range.
        conn.execute(
            "DELETE FROM event_log WHERE id > ?", (after_event_id,)
        )

        # 4. Clear projected tables in topological order so FK ON DELETE
        #    constraints don't fire on referenced rows. ``activity`` and
        #    ``scenes`` reference ``containers``; ``chat_state`` references
        #    ``chats`` by id-convention only (no FK declared). ``memories``,
        #    ``edges``, ``bots``, ``you_entity``, and ``classifier_failures``
        #    have no incoming FKs from other projected tables.
        #
        #    One ``execute`` per table rather than ``executescript`` so
        #    foreign_keys=ON stays in effect for each statement —
        #    executescript would implicitly commit and reset some pragmas
        #    on certain SQLite builds.
        for table in (
            "memories",
            "activity",
            "scenes",
            "containers",
            "chat_state",
            "chats",
            "edges",
            "bots",
            "you_entity",
            "classifier_failures",
        ):
            conn.execute(f"DELETE FROM {table}")

        # 5. Re-project from the truncated event log. Handler registry
        #    is module-level state populated by importing chat.state.* —
        #    callers (the route, tests) need to have those modules
        #    imported for this to do anything useful.
        project(conn)
    return snapshot_path
+100
View File
@@ -0,0 +1,100 @@
"""Snapshot service — write a JSON dump of all projected tables to disk.
Used by the rewind flow (Requirements §10.1, T28) so the user can recover a
pre-rewind state if the rewind was a mistake. Stored under
``data/snapshots/{kind}/`` with a UTC timestamp filename.
The dump captures both the event log (so the original event sequence is
preserved verbatim) and every projected table (so a future restore could
either re-load tables directly or re-project from the saved event log).
The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would
rebuild itself on a memories re-load. It also would not fit the generic
``PRAGMA table_info`` column discovery used for the other tables, since FTS5
reports its columns differently.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
from sqlite3 import Connection
# Tables dumped by ``take_snapshot`` (in addition to ``event_log``, which
# it reads separately).
#
# Order doesn't affect correctness for snapshotting (we read, not write),
# but listing tables explicitly keeps the snapshot stable across schema
# evolution: a new table won't silently change the dump shape until it's
# added here.
PROJECTED_TABLES = [
    "bots",
    "you_entity",
    "edges",
    "memories",
    # Listed for completeness only: ``take_snapshot`` skips this entry
    # explicitly, so it never appears as a key in the written dump.
    "memories_fts",
    "chats",
    "chat_state",
    "containers",
    "scenes",
    "activity",
    "classifier_failures",
]
def take_snapshot(
    conn: Connection, *, data_dir: Path, kind: str = "rewind"
) -> Path:
    """Write a JSON dump of the event log and projected tables.

    The file lands in ``data_dir / "snapshots" / kind`` (parents are
    created as needed). Its name is a UTC timestamp in
    ``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation
    order. The timestamp only has one-second resolution, so when a file
    with that name already exists a ``_NN`` suffix is appended instead of
    overwriting it — an earlier recovery point is never clobbered, and
    the suffixed name still sorts after the unsuffixed one.

    Args:
        conn: Open SQLite connection to read from.
        data_dir: Base data directory.
        kind: Sub-directory name grouping snapshots by purpose.

    Returns:
        Path of the snapshot file that was written.
    """
    snapshot_dir = data_dir / "snapshots" / kind
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")

    dump: dict[str, list] = {}

    # Event log: pull every column we care about. ``ts`` and the
    # superseded/hidden flags are needed to faithfully reconstruct the
    # log on restore.
    event_columns = (
        "id",
        "branch_id",
        "ts",
        "kind",
        "payload_json",
        "superseded_by",
        "hidden",
    )
    cur = conn.execute(
        "SELECT id, branch_id, ts, kind, payload_json, superseded_by, hidden "
        "FROM event_log ORDER BY id"
    )
    dump["event_log"] = [dict(zip(event_columns, row)) for row in cur.fetchall()]

    for table in PROJECTED_TABLES:
        if table == "memories_fts":
            # Virtual FTS5 table — rebuilt by triggers on insert, no need
            # to snapshot it (and ``PRAGMA table_info`` reports its
            # columns differently).
            continue
        cur = conn.execute(f"PRAGMA table_info({table})")
        cols = [c[1] for c in cur.fetchall()]
        if not cols:
            # Table not present in this schema version — leave an empty
            # list rather than raising, so older snapshots can survive.
            dump[table] = []
            continue
        cur = conn.execute(f"SELECT {', '.join(cols)} FROM {table}")
        dump[table] = [dict(zip(cols, row)) for row in cur.fetchall()]

    # ``default=str`` covers Path-like or datetime values that might
    # sneak through if a column ever stored them; the projected tables
    # all use TEXT so this is mostly defensive.
    serialized = json.dumps(dump, default=str)

    # Exclusive-create ("x") so an existing snapshot from the same second
    # is never overwritten; on collision fall back to ``<stamp>_NN.json``.
    # "_" sorts after "." so directory listings stay chronological.
    path = snapshot_dir / f"{timestamp}.json"
    attempt = 0
    while True:
        try:
            with path.open("x", encoding="utf-8") as fh:
                fh.write(serialized)
        except FileExistsError:
            attempt += 1
            path = snapshot_dir / f"{timestamp}_{attempt:02d}.json"
        else:
            return path
+76 -1
View File
@@ -36,12 +36,13 @@ import html
import json
from fastapi import APIRouter, Depends, Form, HTTPException, Request
from fastapi.responses import Response
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from chat.eventlog.log import append_and_apply, append_event
from chat.services.background import SignificanceJob
from chat.services.memory_write import record_turn_memory
from chat.services.prompt import assemble_narrative_prompt
from chat.services.rewind import compute_rewind_preview, execute_rewind
from chat.services.scene_close import detect_scene_close
from chat.services.scene_summarize import apply_scene_close_summary
from chat.services.state_update import compute_state_update
@@ -409,3 +410,77 @@ async def post_turn(
raise asyncio.CancelledError
return Response(status_code=204)
# ---------------------------------------------------------------------------
# Rewind routes (Task 28).
#
# Two endpoints: a GET that renders the impact-preview modal, and a POST
# that actually executes the rewind. The execution path opens its own
# database connection because the route's ``conn`` is closed when the
# dependency-injection scope exits — passing it to ``execute_rewind``
# would dangle.
# ---------------------------------------------------------------------------
@router.get(
    "/chats/{chat_id}/rewind/preview/{event_id}",
    response_class=HTMLResponse,
)
async def rewind_preview(
    chat_id: str,
    event_id: int,
    request: Request,
    conn=Depends(get_conn),
):
    """Return the rewind confirmation modal as an HTML fragment.

    Responds 404 when ``chat_id`` does not resolve to a chat. Otherwise
    the fragment lists, per event kind, how many events a rewind to
    ``event_id`` would remove, and embeds an HTMX form that POSTs to the
    execute endpoint for the same chat and event. The markup is kept
    deliberately bare for v1 — Task 35 polishes the modal.
    """
    if get_chat(conn, chat_id) is None:
        raise HTTPException(status_code=404, detail=f"chat not found: {chat_id}")

    summary = compute_rewind_preview(conn, event_id)

    # One <li> per event kind; kinds are escaped before interpolation.
    bullets = [
        f"<li>{how_many} × {html.escape(event_kind)}</li>"
        for event_kind, how_many in summary["by_kind"].items()
    ]
    post_url = f"/chats/{html.escape(chat_id)}/rewind/{event_id}"

    fragment = "".join(
        [
            "<div class='rewind-modal'>",
            f"<h3>Rewind to event {event_id}?</h3>",
            f"<p>This will remove {summary['total_events']} events:</p>",
            "<ul>" + "".join(bullets) + "</ul>",
            f"<form hx-post='{post_url}' ",
            "hx-target='body' hx-swap='innerHTML'>",
            "<button type='submit'>Confirm Rewind</button>",
            "</form>",
            "</div>",
        ]
    )
    return HTMLResponse(fragment)
@router.post("/chats/{chat_id}/rewind/{event_id}")
async def rewind_execute(
    chat_id: str,
    event_id: int,
    request: Request,
    conn=Depends(get_conn),
):
    """Run the rewind (snapshot, truncate event_log, re-project).

    The injected ``conn`` serves a single purpose here: confirming that
    the chat exists (404 otherwise). The rewind itself is delegated to
    ``execute_rewind``, which is given the database path and opens a
    connection of its own so it can commit independently of this route's
    dependency teardown. On success the client is redirected (303) to
    the chat page.
    """
    if get_chat(conn, chat_id) is None:
        raise HTTPException(status_code=404, detail=f"chat not found: {chat_id}")

    app_settings = request.app.state.settings
    rewind_args = {
        "db_path": app_settings.db_path,
        "data_dir": app_settings.data_dir,
        "after_event_id": event_id,
    }
    execute_rewind(**rewind_args)

    chat_url = f"/chats/{chat_id}"
    return RedirectResponse(url=chat_url, status_code=303)
+164
View File
@@ -0,0 +1,164 @@
"""Tests for Task 28 — rewind with snapshot, impact preview, and re-projection.
Per Requirements §10.1, rewind must:
* take a pre-rewind snapshot of all projected tables (so the user can recover),
* truncate the event log past a chosen event id,
* clear projected tables and re-project from the truncated log so live state
matches "what the world looked like at turn N" (no stale rows from rewound
events).
These tests cover the functional core. The HTTP route surface is left to the
plan's polish pass — tests exercise via direct service calls.
"""
from __future__ import annotations
import json
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.services.rewind import compute_rewind_preview, execute_rewind
from chat.services.snapshot import take_snapshot
# Importing the state modules registers their projector handlers as a
# side effect — the test would otherwise see an unprojected db after
# re-projection because the registry would be empty.
import chat.state.entities # noqa: F401
import chat.state.edges # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
import chat.state.manual_edit # noqa: F401
def _seed_5_turns(db):
    """Build a fixture database: one bot, one chat, five turn pairs.

    Appends ``bot_authored`` and ``chat_created`` followed by five
    alternating ``user_turn`` / ``assistant_turn`` events, then runs the
    projector once. Per the original fixture note, the turn events have
    no projector handlers (they exist only for transcript rendering), so
    the bot and chat rows are the only projected state — which keeps the
    post-rewind assertions simple and classifier-free.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }

    apply_migrations(db)
    with open_db(db) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)

        for turn_no in range(5):
            user_payload = {
                "chat_id": "chat_bot_a",
                "prose": f"turn {turn_no}",
                "segments": [],
            }
            assistant_payload = {
                "chat_id": "chat_bot_a",
                "speaker_id": "bot_a",
                "text": f"reply {turn_no}",
                "truncated": False,
                "user_turn_id": turn_no,
            }
            append_event(conn, kind="user_turn", payload=user_payload)
            append_event(conn, kind="assistant_turn", payload=assistant_payload)

        project(conn)
def test_take_snapshot_writes_file_to_disk(tmp_path):
    """A snapshot lands under data/snapshots/rewind and contains the bot."""
    db_file = tmp_path / "t.db"
    data_root = tmp_path / "data"
    _seed_5_turns(db_file)

    with open_db(db_file) as conn:
        written = take_snapshot(conn, data_dir=data_root, kind="rewind")

    assert written.exists()
    assert written.parent == tmp_path / "data" / "snapshots" / "rewind"

    dumped = json.loads(written.read_text())
    for section in ("event_log", "bots", "chats"):
        assert section in dumped

    # The seeded bot must be part of the dump.
    bot_ids = [row["id"] for row in dumped["bots"]]
    assert "bot_a" in bot_ids
def test_compute_rewind_preview_counts_kinds(tmp_path):
    """The preview reports exactly the events past the chosen id, by kind."""
    db = tmp_path / "t.db"
    _seed_5_turns(db)
    with open_db(db) as conn:
        # Seed order is: bot_authored, chat_created, then five
        # user_turn/assistant_turn pairs. Rewinding to the first
        # assistant_turn therefore removes the remaining four pairs
        # (4 user + 4 assistant = 8 events).
        first_assistant = conn.execute(
            "SELECT id FROM event_log WHERE kind='assistant_turn' "
            "ORDER BY id LIMIT 1"
        ).fetchone()[0]
        preview = compute_rewind_preview(conn, after_event_id=first_assistant)

    assert preview["after_event_id"] == first_assistant
    # Pin the exact impact instead of a bare "> 0": a preview that
    # over- or under-counts must fail this test.
    assert preview["total_events"] == 8
    # by_kind sums to total_events.
    assert sum(preview["by_kind"].values()) == preview["total_events"]
    # These are the events that would be *removed* by the rewind — only
    # turn events, four of each kind.
    assert preview["by_kind"] == {"assistant_turn": 4, "user_turn": 4}
def test_execute_rewind_truncates_and_reprojects(tmp_path):
    """Rewind keeps events up to the cutoff and re-projects bot and chat."""
    db_file = tmp_path / "t.db"
    data_root = tmp_path / "data"
    _seed_5_turns(db_file)

    with open_db(db_file) as conn:
        cutoff_row = conn.execute(
            "SELECT id FROM event_log WHERE kind='assistant_turn' "
            "ORDER BY id LIMIT 1"
        ).fetchone()
    cutoff = cutoff_row[0]

    written = execute_rewind(
        db_path=db_file, data_dir=data_root, after_event_id=cutoff
    )

    # The pre-rewind snapshot lives under data/snapshots/rewind/.
    assert written.exists()
    assert written.parent == data_root / "snapshots" / "rewind"

    # The log ends at the cutoff, and projected state was rebuilt from
    # the events that survived.
    with open_db(db_file) as conn:
        newest_id = conn.execute("SELECT MAX(id) FROM event_log").fetchone()[0]
        bot_row = conn.execute(
            "SELECT id FROM bots WHERE id = 'bot_a'"
        ).fetchone()
        chat_row = conn.execute(
            "SELECT id FROM chats WHERE id = 'chat_bot_a'"
        ).fetchone()

    assert newest_id == cutoff
    # bot_authored precedes the cutoff, so the bot is re-projected.
    assert bot_row is not None
    # chat_created precedes the cutoff, so the chat is re-projected.
    assert chat_row is not None