feat: rewind with impact preview, pre-rewind snapshot, undo toast

This commit is contained in:
Joseph Doherty
2026-04-26 13:58:20 -04:00
parent b5175aefaa
commit aa0563b4fa
4 changed files with 452 additions and 1 deletions
+112
View File
@@ -0,0 +1,112 @@
"""Rewind service — truncate the event log past a chosen turn and re-project.
Per Requirements §10.1 and Plan Task 28, "rewind to here" must:
1. Take a snapshot of the current state so the user can recover (handed
off to :mod:`chat.services.snapshot`).
2. Truncate the event log past ``after_event_id`` — physical DELETE for
v1 simplicity; the spec says rewind should be a hard truncation, not
the soft ``hidden=1`` mechanism used by edits/regenerate.
3. Clear projected tables and re-project from the truncated log so live
state matches "what the world looked like at turn N". Without the
re-projection, projected tables would carry forward stale rows from
rewound events (e.g. an ``edge_update`` that bumped affinity past the
rewind point would still show in ``edges``).
Re-projection is a full replay rather than a "revert delta" because most
projector handlers are idempotent inserts, but the edge handler is a
delta-shaped accumulator — there's no clean way to invert a single
``edge_update`` against ``edges.affinity`` without replay. Wiping +
replaying is straightforward and correct.
"""
from __future__ import annotations
from pathlib import Path
from sqlite3 import Connection
from chat.db.connection import open_db
from chat.eventlog.projector import project
from chat.services.snapshot import take_snapshot
def compute_rewind_preview(
conn: Connection, after_event_id: int
) -> dict:
"""Return counts of each event kind that would be removed by rewinding.
Used by the preview modal so the user sees the impact (e.g. "this
will remove 8 events: 4 user_turn, 4 assistant_turn") before
confirming. Counts include hidden/superseded rows — they're still
physically deleted.
"""
cur = conn.execute(
"SELECT kind, COUNT(*) FROM event_log WHERE id > ? GROUP BY kind "
"ORDER BY kind",
(after_event_id,),
)
counts = {kind: count for kind, count in cur.fetchall()}
total = sum(counts.values())
return {
"after_event_id": after_event_id,
"total_events": total,
"by_kind": counts,
}
def execute_rewind(
*, db_path: Path, data_dir: Path, after_event_id: int
) -> Path:
"""Take a snapshot, truncate, and re-project. Returns the snapshot path.
The snapshot is taken inside the same connection scope as the
truncate + reproject so all three commit together — if any step
fails the connection's commit-on-exit is bypassed by the exception
and the database stays untouched. The snapshot file is on disk
regardless, which is the desired behaviour: even if the truncate
aborts, the user has a recovery point.
"""
with open_db(db_path) as conn:
# 1. Snapshot first — we want this on disk before any destructive
# operation runs.
snapshot_path = take_snapshot(
conn, data_dir=data_dir, kind="rewind"
)
# 2. Truncate the event log past the chosen id. Foreign keys are
# ON, but ``event_log.superseded_by`` self-references and the
# rows we're deleting are the only ones that could point
# forward — there's nothing to cascade.
conn.execute(
"DELETE FROM event_log WHERE id > ?", (after_event_id,)
)
# 3. Clear projected tables in topological order so FK ON DELETE
# constraints don't fire on referenced rows. ``activity`` and
# ``scenes`` reference ``containers``; ``chat_state`` references
# ``chats`` by id-convention only (no FK declared). ``memories``,
# ``edges``, ``bots``, ``you_entity``, and ``classifier_failures``
# have no incoming FKs from other projected tables.
#
# ``executescript`` is intentionally avoided so foreign_keys=ON
# stays in effect for each statement — executescript would
# implicitly commit and reset some pragmas on certain SQLite
# builds.
conn.execute("DELETE FROM memories")
conn.execute("DELETE FROM activity")
conn.execute("DELETE FROM scenes")
conn.execute("DELETE FROM containers")
conn.execute("DELETE FROM chat_state")
conn.execute("DELETE FROM chats")
conn.execute("DELETE FROM edges")
conn.execute("DELETE FROM bots")
conn.execute("DELETE FROM you_entity")
conn.execute("DELETE FROM classifier_failures")
# 4. Re-project from the truncated event log. Handler registry
# is module-level state populated by importing chat.state.* —
# callers (the route, tests) need to have those modules
# imported for this to do anything useful.
project(conn)
return snapshot_path
+100
View File
@@ -0,0 +1,100 @@
"""Snapshot service — write a JSON dump of all projected tables to disk.
Used by the rewind flow (Requirements §10.1, T28) so the user can recover a
pre-rewind state if the rewind was a mistake. Stored under
``data/snapshots/{kind}/`` with a UTC timestamp filename.
The dump captures both the event log (so the original event sequence is
preserved verbatim) and every projected table (so a future restore could
either re-load tables directly or re-project from the saved event log).
The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would
rebuild itself on a memories re-load. Snapshotting it would also fail
``PRAGMA table_info`` cleanly since FTS5 reports its columns differently.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
from sqlite3 import Connection
# Order doesn't affect correctness for snapshotting (we read, not write),
# but listing tables explicitly keeps the snapshot stable across schema
# evolution: a new table won't silently change the dump shape until it's
# added here.
PROJECTED_TABLES = [
"bots",
"you_entity",
"edges",
"memories",
"memories_fts",
"chats",
"chat_state",
"containers",
"scenes",
"activity",
"classifier_failures",
]
def take_snapshot(
conn: Connection, *, data_dir: Path, kind: str = "rewind"
) -> Path:
"""Write a JSON dump of the event log and projected tables.
Returns the path to the written snapshot file. Creates parent
directories as needed. Filename is a UTC timestamp in
``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation
order.
"""
snapshot_dir = data_dir / "snapshots" / kind
snapshot_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
path = snapshot_dir / f"{timestamp}.json"
dump: dict[str, list] = {}
# Event log: pull every column we care about. ``ts`` and the
# superseded/hidden flags are needed to faithfully reconstruct the
# log on restore.
cur = conn.execute(
"SELECT id, branch_id, ts, kind, payload_json, superseded_by, hidden "
"FROM event_log ORDER BY id"
)
dump["event_log"] = [
{
"id": r[0],
"branch_id": r[1],
"ts": r[2],
"kind": r[3],
"payload_json": r[4],
"superseded_by": r[5],
"hidden": r[6],
}
for r in cur.fetchall()
]
for table in PROJECTED_TABLES:
if table == "memories_fts":
# Virtual FTS5 table — rebuilt by triggers on insert, no need
# to snapshot it (and ``PRAGMA table_info`` reports its
# columns differently).
continue
cur = conn.execute(f"PRAGMA table_info({table})")
cols = [c[1] for c in cur.fetchall()]
if not cols:
# Table not present in this schema version — leave an empty
# list rather than raising, so older snapshots can survive.
dump[table] = []
continue
cur = conn.execute(f"SELECT {', '.join(cols)} FROM {table}")
dump[table] = [dict(zip(cols, row)) for row in cur.fetchall()]
# ``default=str`` covers Path-like or datetime values that might
# sneak through if a column ever stored them; the projected tables
# all use TEXT so this is mostly defensive.
path.write_text(json.dumps(dump, default=str))
return path