feat: rewind with impact preview, pre-rewind snapshot, undo toast
This commit is contained in:
@@ -0,0 +1,112 @@
|
||||
"""Rewind service — truncate the event log past a chosen turn and re-project.
|
||||
|
||||
Per Requirements §10.1 and Plan Task 28, "rewind to here" must:
|
||||
|
||||
1. Take a snapshot of the current state so the user can recover (handed
|
||||
off to :mod:`chat.services.snapshot`).
|
||||
2. Truncate the event log past ``after_event_id`` — physical DELETE for
|
||||
v1 simplicity; the spec says rewind should be a hard truncation, not
|
||||
the soft ``hidden=1`` mechanism used by edits/regenerate.
|
||||
3. Clear projected tables and re-project from the truncated log so live
|
||||
state matches "what the world looked like at turn N". Without the
|
||||
re-projection, projected tables would carry forward stale rows from
|
||||
rewound events (e.g. an ``edge_update`` that bumped affinity past the
|
||||
rewind point would still show in ``edges``).
|
||||
|
||||
Re-projection is a full replay rather than a "revert delta" because most
|
||||
projector handlers are idempotent inserts, but the edge handler is a
|
||||
delta-shaped accumulator — there's no clean way to invert a single
|
||||
``edge_update`` against ``edges.affinity`` without replay. Wiping +
|
||||
replaying is straightforward and correct.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from sqlite3 import Connection
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.snapshot import take_snapshot
|
||||
|
||||
|
||||
def compute_rewind_preview(conn: Connection, after_event_id: int) -> dict:
    """Summarise the events a rewind to ``after_event_id`` would delete.

    Feeds the preview modal so the impact is visible before the user
    confirms (e.g. "this will remove 8 events: 4 user_turn,
    4 assistant_turn"). The query has no ``hidden``/``superseded_by``
    filter, so hidden and superseded rows are counted too — the rewind
    deletes them physically as well.

    Args:
        conn: Open SQLite connection to the database holding ``event_log``.
        after_event_id: Every event with ``id`` strictly greater than this
            value is counted.

    Returns:
        A dict with ``after_event_id`` (echoed back), ``total_events``
        (sum over all kinds) and ``by_kind`` (kind -> row count, in
        alphabetical kind order).
    """
    rows = conn.execute(
        "SELECT kind, COUNT(*) FROM event_log WHERE id > ? GROUP BY kind "
        "ORDER BY kind",
        (after_event_id,),
    )

    # Build the per-kind mapping and the grand total in a single pass.
    by_kind = {}
    total_events = 0
    for kind, count in rows:
        by_kind[kind] = count
        total_events += count

    return {
        "after_event_id": after_event_id,
        "total_events": total_events,
        "by_kind": by_kind,
    }
|
||||
|
||||
|
||||
def execute_rewind(
    *, db_path: Path, data_dir: Path, after_event_id: int
) -> Path:
    """Snapshot the current state, truncate the event log, and re-project.

    All database work happens on a single connection obtained from
    ``open_db``, so the truncate, the projected-table wipe and the
    re-projection are committed together; an exception in any step skips
    the commit-on-exit and leaves the database as it was. The snapshot
    file is written before anything destructive runs and is deliberately
    left on disk even when a later step fails, so a recovery point always
    exists.

    Args:
        db_path: Path of the SQLite database to rewind.
        data_dir: Root data directory; the snapshot is stored beneath it.
        after_event_id: Events with ``id`` greater than this are deleted.

    Returns:
        Path of the pre-rewind snapshot file.
    """
    # Projected tables listed in the order they must be emptied so that
    # foreign-key ON DELETE constraints never fire on referenced rows:
    # ``activity`` and ``scenes`` reference ``containers``; ``chat_state``
    # points at ``chats`` by id convention only (no declared FK). The
    # remaining tables have no incoming FKs from other projected tables.
    wipe_order = (
        "memories",
        "activity",
        "scenes",
        "containers",
        "chat_state",
        "chats",
        "edges",
        "bots",
        "you_entity",
        "classifier_failures",
    )

    with open_db(db_path) as conn:
        # Step 1 — recovery point first, before any destructive statement.
        snapshot_path = take_snapshot(conn, data_dir=data_dir, kind="rewind")

        # Step 2 — hard truncation of the log past the chosen event.
        # ``event_log.superseded_by`` is self-referencing; the original
        # author's reasoning is that only the deleted rows can point
        # forward, so nothing needs to cascade.
        # NOTE(review): a surviving row whose ``superseded_by`` names a
        # deleted id would be left dangling (or rejected with FKs on) —
        # confirm against the schema.
        conn.execute(
            "DELETE FROM event_log WHERE id > ?", (after_event_id,)
        )

        # Step 3 — empty every projected table. One ``execute`` per
        # statement (rather than ``executescript``) keeps
        # ``foreign_keys=ON`` in force and avoids the implicit commit
        # that ``executescript`` performs.
        for table in wipe_order:
            conn.execute(f"DELETE FROM {table}")

        # Step 4 — rebuild projections from what is left of the log. The
        # handler registry is module-level state filled in by importing
        # ``chat.state.*``; callers must have imported those modules or
        # this replays nothing useful.
        project(conn)

    return snapshot_path
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Snapshot service — write a JSON dump of all projected tables to disk.
|
||||
|
||||
Used by the rewind flow (Requirements §10.1, T28) so the user can recover a
|
||||
pre-rewind state if the rewind was a mistake. Stored under
|
||||
``data/snapshots/{kind}/`` with a UTC timestamp filename.
|
||||
|
||||
The dump captures both the event log (so the original event sequence is
|
||||
preserved verbatim) and every projected table (so a future restore could
|
||||
either re-load tables directly or re-project from the saved event log).
|
||||
|
||||
The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
|
||||
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would
|
||||
rebuild itself on a memories re-load. It also cannot be snapshotted cleanly
via ``PRAGMA table_info``, because FTS5 reports its columns differently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from sqlite3 import Connection
|
||||
|
||||
# Order doesn't affect correctness for snapshotting (we read, not write),
|
||||
# but listing tables explicitly keeps the snapshot stable across schema
|
||||
# evolution: a new table won't silently change the dump shape until it's
|
||||
# added here.
|
||||
PROJECTED_TABLES = [
    "bots",
    "you_entity",
    "edges",
    "memories",
    "memories_fts",
    "chats",
    "chat_state",
    "containers",
    "scenes",
    "activity",
    "classifier_failures",
]


def take_snapshot(
    conn: Connection, *, data_dir: Path, kind: str = "rewind"
) -> Path:
    """Write a JSON dump of the event log and projected tables.

    The file lands in ``data_dir / "snapshots" / kind`` (parent
    directories are created as needed) and is named after the current UTC
    time in ``YYYYMMDDTHHMMSSZ`` form, so a chronological listing matches
    creation order. The file is created exclusively: if a snapshot with
    the same one-second timestamp already exists, a zero-padded ``_NN``
    suffix is appended instead of overwriting it, so an earlier recovery
    point is never lost. The suffix sorts after the bare timestamp.

    Args:
        conn: Open SQLite connection to read from.
        data_dir: Root data directory under which ``snapshots/`` lives.
        kind: Sub-directory name grouping snapshots by their trigger.

    Returns:
        Path of the snapshot file that was written.
    """
    snapshot_dir = data_dir / "snapshots" / kind
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")

    dump: dict[str, list] = {}

    # Event log: ``ts`` and the superseded/hidden flags are included so
    # the log can be reconstructed faithfully on restore.
    event_cols = (
        "id",
        "branch_id",
        "ts",
        "kind",
        "payload_json",
        "superseded_by",
        "hidden",
    )
    cur = conn.execute(
        f"SELECT {', '.join(event_cols)} FROM event_log ORDER BY id"
    )
    dump["event_log"] = [dict(zip(event_cols, row)) for row in cur.fetchall()]

    for table in PROJECTED_TABLES:
        if table == "memories_fts":
            # Virtual FTS5 table — rebuilt by triggers on insert, so it
            # is not dumped (and ``PRAGMA table_info`` reports its
            # columns differently).
            continue
        cur = conn.execute(f"PRAGMA table_info({table})")
        cols = [c[1] for c in cur.fetchall()]
        if not cols:
            # Table absent in this schema version — record an empty list
            # rather than raising, so older databases still snapshot.
            dump[table] = []
            continue
        # Quote identifiers so a column named after an SQL keyword still
        # selects correctly; dict keys keep the bare column names.
        select_list = ", ".join(f'"{c}"' for c in cols)
        cur = conn.execute(f"SELECT {select_list} FROM {table}")
        dump[table] = [dict(zip(cols, row)) for row in cur.fetchall()]

    # ``default=str`` is defensive: it stringifies any value json cannot
    # encode natively instead of failing the whole snapshot.
    payload = json.dumps(dump, default=str)

    # Exclusive creation ("x") makes the collision check atomic: two
    # snapshots in the same second get distinct files rather than the
    # later one clobbering the earlier.
    attempt = 0
    while True:
        suffix = "" if attempt == 0 else f"_{attempt:02d}"
        path = snapshot_dir / f"{timestamp}{suffix}.json"
        try:
            with path.open("x", encoding="utf-8") as fh:
                fh.write(payload)
        except FileExistsError:
            attempt += 1
            continue
        return path
|
||||
+76
-1
@@ -36,12 +36,13 @@ import html
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, Response
|
||||
|
||||
from chat.eventlog.log import append_and_apply, append_event
|
||||
from chat.services.background import SignificanceJob
|
||||
from chat.services.memory_write import record_turn_memory
|
||||
from chat.services.prompt import assemble_narrative_prompt
|
||||
from chat.services.rewind import compute_rewind_preview, execute_rewind
|
||||
from chat.services.scene_close import detect_scene_close
|
||||
from chat.services.scene_summarize import apply_scene_close_summary
|
||||
from chat.services.state_update import compute_state_update
|
||||
@@ -409,3 +410,77 @@ async def post_turn(
|
||||
raise asyncio.CancelledError
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rewind routes (Task 28).
|
||||
#
|
||||
# Two endpoints: a GET that renders the impact-preview modal, and a POST
|
||||
# that actually executes the rewind. The execution path opens its own
|
||||
# database connection because the route's ``conn`` is closed when the
|
||||
# dependency-injection scope exits — passing it to ``execute_rewind``
|
||||
# would dangle.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/chats/{chat_id}/rewind/preview/{event_id}",
    response_class=HTMLResponse,
)
async def rewind_preview(
    chat_id: str,
    event_id: int,
    request: Request,
    conn=Depends(get_conn),
):
    """Return the rewind impact-preview modal as an HTML fragment.

    Responds 404 when ``chat_id`` is unknown. Otherwise lists how many
    events of each kind lie past ``event_id`` and embeds an HTMX form
    that posts to the rewind-execute endpoint for the same chat and
    event. Markup is intentionally minimal for v1 (Task 35 polishes it).
    """
    if get_chat(conn, chat_id) is None:
        raise HTTPException(status_code=404, detail=f"chat not found: {chat_id}")

    impact = compute_rewind_preview(conn, event_id)

    # One <li> per event kind; kind names are escaped before embedding.
    bullet_parts = []
    for kind, count in impact["by_kind"].items():
        bullet_parts.append(f"<li>{count} × {html.escape(kind)}</li>")
    items = "".join(bullet_parts)

    fragments = [
        "<div class='rewind-modal'>",
        f"<h3>Rewind to event {event_id}?</h3>",
        f"<p>This will remove {impact['total_events']} events:</p>",
        f"<ul>{items}</ul>",
        f"<form hx-post='/chats/{html.escape(chat_id)}/rewind/{event_id}' ",
        "hx-target='body' hx-swap='innerHTML'>",
        "<button type='submit'>Confirm Rewind</button>",
        "</form>",
        "</div>",
    ]
    return HTMLResponse("".join(fragments))
|
||||
|
||||
|
||||
@router.post("/chats/{chat_id}/rewind/{event_id}")
async def rewind_execute(
    chat_id: str,
    event_id: int,
    request: Request,
    conn=Depends(get_conn),
):
    """Run the rewind (snapshot, truncate ``event_log``, re-project).

    The injected ``conn`` is used solely to confirm the chat exists
    (404 otherwise). ``execute_rewind`` opens a connection of its own
    from the configured database path so that its work commits
    independently of this route's dependency teardown. On success the
    client is redirected (303) back to the chat page.
    """
    if get_chat(conn, chat_id) is None:
        raise HTTPException(status_code=404, detail=f"chat not found: {chat_id}")

    app_settings = request.app.state.settings
    execute_rewind(
        db_path=app_settings.db_path,
        data_dir=app_settings.data_dir,
        after_event_id=event_id,
    )

    chat_url = f"/chats/{chat_id}"
    return RedirectResponse(url=chat_url, status_code=303)
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Tests for Task 28 — rewind with snapshot, impact preview, and re-projection.
|
||||
|
||||
Per Requirements §10.1, rewind must:
|
||||
|
||||
* take a pre-rewind snapshot of all projected tables (so the user can recover),
|
||||
* truncate the event log past a chosen event id,
|
||||
* clear projected tables and re-project from the truncated log so live state
|
||||
matches "what the world looked like at turn N" (no stale rows from rewound
|
||||
events).
|
||||
|
||||
These tests cover the functional core. The HTTP route surface is left to the
|
||||
plan's polish pass — tests exercise via direct service calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.rewind import compute_rewind_preview, execute_rewind
|
||||
from chat.services.snapshot import take_snapshot
|
||||
|
||||
# Importing the state modules registers their projector handlers as a
|
||||
# side effect — the test would otherwise see an unprojected db after
|
||||
# re-projection because the registry would be empty.
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.edges # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
import chat.state.manual_edit # noqa: F401
|
||||
|
||||
|
||||
def _seed_5_turns(db):
    """Populate ``db`` with one bot, one chat and five user/assistant pairs.

    ``user_turn`` and ``assistant_turn`` have no projector handlers —
    they exist in the event log only for transcript rendering — so
    ``bot_authored`` and ``chat_created`` are the only events that
    produce projected rows. That keeps the post-rewind invariants simple
    to assert without running the real classifier pass.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }

    apply_migrations(db)
    with open_db(db) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)

        for turn_no in range(5):
            user_payload = {
                "chat_id": "chat_bot_a",
                "prose": f"turn {turn_no}",
                "segments": [],
            }
            assistant_payload = {
                "chat_id": "chat_bot_a",
                "speaker_id": "bot_a",
                "text": f"reply {turn_no}",
                "truncated": False,
                # Loop index, not the id of the stored user_turn event.
                "user_turn_id": turn_no,
            }
            append_event(conn, kind="user_turn", payload=user_payload)
            append_event(conn, kind="assistant_turn", payload=assistant_payload)

        # Build the projected tables once the whole log is in place.
        project(conn)
|
||||
|
||||
|
||||
def test_take_snapshot_writes_file_to_disk(tmp_path):
    """A snapshot lands under ``snapshots/rewind`` and contains the bot."""
    db = tmp_path / "t.db"
    data_root = tmp_path / "data"
    _seed_5_turns(db)

    with open_db(db) as conn:
        written = take_snapshot(conn, data_dir=data_root, kind="rewind")

    # File exists in the expected per-kind directory.
    assert written.exists()
    assert written.parent == tmp_path / "data" / "snapshots" / "rewind"

    # The dump carries the event log plus the projected tables we seeded.
    dumped = json.loads(written.read_text())
    for section in ("event_log", "bots", "chats"):
        assert section in dumped

    # Bot is in the dump.
    bot_ids = [row["id"] for row in dumped["bots"]]
    assert "bot_a" in bot_ids
|
||||
|
||||
|
||||
def test_compute_rewind_preview_counts_kinds(tmp_path):
    """The preview reports per-kind counts for everything past the cut."""
    db = tmp_path / "t.db"
    _seed_5_turns(db)

    first_assistant_sql = (
        "SELECT id FROM event_log WHERE kind='assistant_turn' "
        "ORDER BY id LIMIT 1"
    )
    with open_db(db) as conn:
        # Seed order: 1 bot_authored, 2 chat_created, 3 user_turn,
        # 4 assistant_turn — so the first assistant_turn is event 4 and
        # everything after it (4 user + 4 assistant = 8) is previewed.
        (cut_id,) = conn.execute(first_assistant_sql).fetchone()[:1]
        preview = compute_rewind_preview(conn, after_event_id=cut_id)

    assert preview["after_event_id"] == cut_id
    assert preview["total_events"] > 0

    by_kind = preview["by_kind"]
    # Per-kind counts add up to the reported total.
    assert sum(by_kind.values()) == preview["total_events"]
    # Cutting after assistant_turn #1 leaves 4 user + 4 assistant turns
    # beyond the cut.
    assert by_kind.get("user_turn") == 4
    assert by_kind.get("assistant_turn") == 4
|
||||
|
||||
|
||||
def test_execute_rewind_truncates_and_reprojects(tmp_path):
    """Rewinding truncates the log and rebuilds bot/chat projections."""
    db = tmp_path / "t.db"
    data_dir = tmp_path / "data"
    _seed_5_turns(db)

    with open_db(db) as conn:
        cut_row = conn.execute(
            "SELECT id FROM event_log WHERE kind='assistant_turn' "
            "ORDER BY id LIMIT 1"
        ).fetchone()
        cut_id = cut_row[0]

    snapshot_file = execute_rewind(
        db_path=db, data_dir=data_dir, after_event_id=cut_id
    )

    # Snapshot is written under data/snapshots/rewind/.
    assert snapshot_file.exists()
    assert snapshot_file.parent == data_dir / "snapshots" / "rewind"

    # The log now ends at the cut, and projected state matches it.
    with open_db(db) as conn:
        newest_id = conn.execute("SELECT MAX(id) FROM event_log").fetchone()[0]
        assert newest_id == cut_id

        # Bot survives: re-projected from the preserved bot_authored event.
        bot_row = conn.execute(
            "SELECT id FROM bots WHERE id = 'bot_a'"
        ).fetchone()
        assert bot_row is not None

        # Chat survives: re-projected from the preserved chat_created event.
        chat_row = conn.execute(
            "SELECT id FROM chats WHERE id = 'chat_bot_a'"
        ).fetchone()
        assert chat_row is not None
|
||||
Reference in New Issue
Block a user