"""Snapshot service — write a JSON dump of all projected tables to disk.

Two snapshot kinds, both covered by this module:

* ``rewind`` (T28, Requirements §10.1): pre-rewind safety snapshot so the
  user can recover if a rewind was a mistake. Retention: 14 days.
* ``periodic`` (T31, Requirements §10.4): full-state checkpoint taken every
  100 events OR every 30 minutes since the last one. Retention: the most
  recent 5 are kept; older ones are pruned on write.

Both kinds live under ``data/snapshots/{kind}/`` with a UTC timestamp
filename so chronological listing matches creation order.

The dump captures the event log (so the original event sequence is
preserved verbatim), every projected table, and a top-level
``last_event_id`` recording the highest ``event_log.id`` at snapshot time.
The ``last_event_id`` is what the cold-load fast-path uses to replay only
events past the snapshot rather than the entire log.

The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it
rebuilds itself on a memories re-load. Snapshotting it would also fail
``PRAGMA table_info`` cleanly since FTS5 reports its columns differently.
"""

from __future__ import annotations

import json
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from sqlite3 import Connection

# Periodic snapshot triggers (Requirements §10.4): "every 100 events OR
# every 30 minutes since last snapshot". Module-level so tests can read
# them and so the values stay together with the policy that uses them.
EVENT_COUNT_THRESHOLD = 100
TIME_THRESHOLD_SECONDS = 30 * 60  # 30 minutes

# Order doesn't affect correctness for snapshotting (we read, not write),
# but listing tables explicitly keeps the snapshot stable across schema
# evolution: a new table won't silently change the dump shape until it's
# added here.
PROJECTED_TABLES = [
    "bots",
    "you_entity",
    "edges",
    "memories",
    "memories_fts",
    "chats",
    "chat_state",
    "containers",
    "scenes",
    "activity",
    "classifier_failures",
]

# Virtual FTS5 table: listed in PROJECTED_TABLES but never dumped or
# restored, because the triggers on ``memories`` keep it in sync.
_FTS_TABLE = "memories_fts"

# Delete order used by restore: child tables before parents so FK
# ON DELETE doesn't fire on referenced rows.
_RESTORE_DELETE_ORDER = (
    "memories",
    "activity",
    "scenes",
    "containers",
    "chat_state",
    "chats",
    "edges",
    "bots",
    "you_entity",
    "classifier_failures",
)

# ``take_snapshot`` writes ``last_event_id`` as the first key of the dump,
# so it can be recovered from the first few characters of the file.
_LAST_EVENT_ID_HEADER = re.compile(
    r'\A\{\s*"last_event_id"\s*:\s*(-?\d+)\s*[,}]'
)
_HEADER_PROBE_CHARS = 128


def _quote_identifier(name: str) -> str:
    """Return ``name`` as a double-quoted SQL identifier."""
    return '"' + name.replace('"', '""') + '"'


def _unused_snapshot_path(snapshot_dir: Path, timestamp: str) -> Path:
    """Return a not-yet-existing ``.json`` path for ``timestamp``.

    The plain ``{timestamp}.json`` name is preferred. If it is taken
    (two snapshots within the same second), a zero-padded ``_NNN``
    suffix is appended; ``_`` sorts after ``.`` so filename order still
    matches creation order.
    """
    candidate = snapshot_dir / f"{timestamp}.json"
    serial = 0
    while candidate.exists():
        serial += 1
        candidate = snapshot_dir / f"{timestamp}_{serial:03d}.json"
    return candidate


def _write_atomically(path: Path, text: str) -> None:
    """Write ``text`` to ``path`` via a sibling temp file and a rename.

    The temp name ends in ``.tmp`` so ``glob("*.json")`` never picks up
    a half-written snapshot.
    """
    tmp_path = path.with_name(f"{path.name}.tmp")
    try:
        tmp_path.write_text(text, encoding="utf-8")
        tmp_path.replace(path)
    finally:
        # After a successful rename the temp file is gone; this only
        # cleans up when the write or the rename failed.
        if tmp_path.exists():
            tmp_path.unlink()


def _read_last_event_id(snapshot_path: Path) -> int:
    """Return the ``last_event_id`` recorded in a snapshot file.

    Only the first few characters are read when the file starts with the
    ``last_event_id`` key (as files written by :func:`take_snapshot` do);
    otherwise the whole file is parsed. Raises ``OSError`` if the file
    can't be read and ``ValueError`` if it isn't a JSON object.
    """
    with snapshot_path.open(encoding="utf-8") as fh:
        head = fh.read(_HEADER_PROBE_CHARS)
        header = _LAST_EVENT_ID_HEADER.match(head)
        if header:
            return int(header.group(1))
        payload = json.loads(head + fh.read())
    if not isinstance(payload, dict):
        raise ValueError(f"snapshot {snapshot_path} is not a JSON object")
    return payload.get("last_event_id", 0) or 0


def take_snapshot(
    conn: Connection, *, data_dir: Path, kind: str = "rewind"
) -> Path:
    """Write a JSON dump of the event log and projected tables.

    Returns the path to the written snapshot file. Creates parent
    directories as needed. Filename is a UTC timestamp in
    ``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation
    order; if that name already exists, a ``_NNN`` suffix is added
    instead of overwriting the earlier snapshot.

    The dump's top-level ``last_event_id`` is the highest
    ``event_log.id`` at snapshot time (0 if the log is empty). This is
    what the cold-load fast-path uses to know which suffix of the log to
    replay.

    The file is written to a temporary name and renamed into place, so a
    failed write never leaves a truncated ``.json`` behind.
    """
    snapshot_dir = data_dir / "snapshots" / kind
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    path = _unused_snapshot_path(snapshot_dir, timestamp)

    dump: dict = {}

    # Record the high-water-mark id up front so cold-load can replay
    # only events past it. ``MAX(id)`` is None on an empty log; treat
    # that as 0 (i.e. "replay everything"). Keeping this the first key
    # lets readers recover it from the head of the file.
    max_id = conn.execute("SELECT MAX(id) FROM event_log").fetchone()[0]
    dump["last_event_id"] = max_id if max_id is not None else 0

    # Event log: pull every column we care about. ``ts`` and the
    # superseded/hidden flags are needed to faithfully reconstruct the
    # log on restore.
    event_cols = (
        "id",
        "branch_id",
        "ts",
        "kind",
        "payload_json",
        "superseded_by",
        "hidden",
    )
    cur = conn.execute(
        "SELECT id, branch_id, ts, kind, payload_json, superseded_by, hidden "
        "FROM event_log ORDER BY id"
    )
    dump["event_log"] = [dict(zip(event_cols, row)) for row in cur.fetchall()]

    for table in PROJECTED_TABLES:
        if table == _FTS_TABLE:
            # Virtual FTS5 table — rebuilt by triggers on insert, no need
            # to snapshot it (and ``PRAGMA table_info`` reports its
            # columns differently).
            continue
        cur = conn.execute(f"PRAGMA table_info({table})")
        cols = [info[1] for info in cur.fetchall()]
        if not cols:
            # Table not present in this schema version — leave an empty
            # list rather than raising, so older snapshots can survive.
            dump[table] = []
            continue
        # Quote column names so a column that happens to be an SQL
        # keyword doesn't break the SELECT.
        select_list = ", ".join(_quote_identifier(c) for c in cols)
        cur = conn.execute(f"SELECT {select_list} FROM {table}")
        dump[table] = [dict(zip(cols, row)) for row in cur.fetchall()]

    # ``default=str`` covers Path-like or datetime values that might
    # sneak through if a column ever stored them; the projected tables
    # all use TEXT so this is mostly defensive.
    _write_atomically(path, json.dumps(dump, default=str))
    return path


def latest_snapshot_path(data_dir: Path, kind: str = "periodic") -> Path | None:
    """Return the most recent snapshot file for ``kind``, or None if none exist.

    Ordering by filename works because :func:`take_snapshot` uses a UTC
    timestamp in ``YYYYMMDDTHHMMSSZ`` form — lexicographic order matches
    chronological order.
    """
    snapshot_dir = data_dir / "snapshots" / kind
    if not snapshot_dir.exists():
        return None
    return max(snapshot_dir.glob("*.json"), default=None)


def should_take_periodic_snapshot(
    conn: Connection, data_dir: Path
) -> bool:
    """Decide whether a periodic snapshot is due per Requirements §10.4.

    The policy:

    * No prior snapshot and at least one event in the log → take one.
    * Time since last snapshot ≥ ``TIME_THRESHOLD_SECONDS`` → take one.
    * New events since last snapshot's ``last_event_id`` ≥
      ``EVENT_COUNT_THRESHOLD`` → take one.
    * The latest snapshot's ``last_event_id`` can't be read → take one,
      since there is no usable checkpoint to measure against.

    "Time since last snapshot" is measured by the file's mtime — we
    don't trust the timestamp embedded in the filename for clock drift
    reasons.
    """
    latest = latest_snapshot_path(data_dir, kind="periodic")
    if latest is None:
        # No prior snapshot; take one if there are any events to capture.
        cur = conn.execute("SELECT COUNT(*) FROM event_log")
        return cur.fetchone()[0] > 0

    age_seconds = time.time() - latest.stat().st_mtime
    if age_seconds >= TIME_THRESHOLD_SECONDS:
        return True

    # Count events appended since the last snapshot was written. Only the
    # head of the file is read when it starts with ``last_event_id``, so
    # this check doesn't load the whole dump each time it runs.
    try:
        last_event_id = _read_last_event_id(latest)
    except (OSError, ValueError):
        return True

    cur = conn.execute(
        "SELECT COUNT(*) FROM event_log WHERE id > ?", (last_event_id,)
    )
    new_event_count = cur.fetchone()[0]
    return new_event_count >= EVENT_COUNT_THRESHOLD


def prune_periodic_snapshots(data_dir: Path, keep: int = 5) -> int:
    """Delete all but the most recent ``keep`` periodic snapshots.

    Returns the number of files removed. Safe to call when the directory
    doesn't exist (returns 0). ``keep=0`` removes every periodic
    snapshot. Sorting is by filename, which is the UTC timestamp — same
    ordering :func:`latest_snapshot_path` uses.

    Raises:
        ValueError: if ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError(f"keep must be non-negative, got {keep}")

    snapshot_dir = data_dir / "snapshots" / "periodic"
    if not snapshot_dir.exists():
        return 0

    files = sorted(snapshot_dir.glob("*.json"))
    # Slice by an explicit count: ``files[:-keep]`` would be empty for
    # ``keep == 0`` and remove nothing.
    stale = files[: max(len(files) - keep, 0)]
    for stale_file in stale:
        stale_file.unlink()
    return len(stale)


def restore_from_snapshot(conn: Connection, snapshot_path: Path) -> int:
    """Restore projected tables from ``snapshot_path``.

    Returns the snapshot's ``last_event_id`` so callers (the cold-load
    fast-path in :func:`chat.app.lifespan`) know what suffix of the
    event log still needs replaying.

    Projected tables are cleared in the same FK-respecting order as
    :func:`chat.services.rewind.execute_rewind`, then re-populated from
    the dump. ``memories_fts`` is skipped — it's a virtual FTS5 table
    that rebuilds itself when rows hit ``memories``.

    The event log itself is *not* touched: cold-load assumes the on-disk
    log is the source of truth and the snapshot is just a fast-forward
    to skip re-projecting old events.

    This function does not commit; that is left to the caller.
    """
    dump = json.loads(snapshot_path.read_text(encoding="utf-8"))

    for table in _RESTORE_DELETE_ORDER:
        conn.execute(f"DELETE FROM {table}")

    for table in PROJECTED_TABLES:
        if table == _FTS_TABLE:
            # Rebuilt by triggers when memories rows are inserted below.
            continue
        rows = dump.get(table, [])
        if not rows:
            continue
        cols = list(rows[0])
        # Column names come from the snapshot file, so quote them rather
        # than splicing them into the statement verbatim.
        col_list = ", ".join(_quote_identifier(c) for c in cols)
        placeholders = ", ".join("?" * len(cols))
        conn.executemany(
            f"INSERT INTO {table} ({col_list}) VALUES ({placeholders})",
            (tuple(row[c] for c in cols) for row in rows),
        )

    return dump.get("last_event_id", 0)