From b9644fad312f141a53cc095155416320e971bcde Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 14:15:17 -0400 Subject: [PATCH] feat: periodic snapshots with retention and cold-load fast-path --- chat/app.py | 28 +++++++ chat/config.py | 8 ++ chat/services/background.py | 32 +++++++ chat/services/snapshot.py | 163 ++++++++++++++++++++++++++++++++++-- tests/test_snapshot.py | 133 +++++++++++++++++++++++++++++ 5 files changed, 355 insertions(+), 9 deletions(-) create mode 100644 tests/test_snapshot.py diff --git a/chat/app.py b/chat/app.py index 0acbf69..709eeea 100644 --- a/chat/app.py +++ b/chat/app.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import logging from contextlib import asynccontextmanager from pathlib import Path @@ -6,8 +8,12 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from chat.config import load_settings +from chat.db.connection import open_db from chat.db.migrate import apply_migrations +from chat.eventlog.log import read_events +from chat.eventlog.projector import apply_event from chat.services.background import BackgroundWorker +from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot # Trigger handler registration: import chat.state.entities # noqa: F401 @@ -25,12 +31,34 @@ from chat.web.settings import router as settings_router from chat.web.sse import router as sse_router from chat.web.turns import router as turns_router +log = logging.getLogger(__name__) + @asynccontextmanager async def lifespan(app: FastAPI): settings = load_settings() settings.db_path.parent.mkdir(parents=True, exist_ok=True) apply_migrations(settings.db_path) + + # T31 cold-load fast-path: if a periodic snapshot exists, restore + # projected tables from it and replay only events past its + # ``last_event_id``. Migrations already ran above, so any new tables + # introduced after the snapshot was taken are present and empty — + # the replay-forward step refills them from the event log. + snapshot_path = latest_snapshot_path(settings.data_dir, kind="periodic") + if snapshot_path is not None: + with open_db(settings.db_path) as conn: + last_event_id = restore_from_snapshot(conn, snapshot_path) + for event in read_events( + conn, branch_id=1, after_id=last_event_id + ): + apply_event(conn, event) + log.info( + "cold-load restored from %s, replayed events past id %d", + snapshot_path, + last_event_id, + ) + app.state.settings = settings # Background worker for the async significance pass (T22). Each job diff --git a/chat/config.py b/chat/config.py index 90d6175..b2aed74 100644 --- a/chat/config.py +++ b/chat/config.py @@ -37,4 +37,12 @@ def load_settings() -> Settings: raw = tomllib.loads(config_path.read_text()) if "CHAT_DB_PATH" in os.environ: raw["db_path"] = Path(os.environ["CHAT_DB_PATH"]) + if "CHAT_DATA_DIR" in os.environ: + raw["data_dir"] = Path(os.environ["CHAT_DATA_DIR"]) + elif "data_dir" not in raw and "db_path" in raw: + # T31: when ``CHAT_DB_PATH`` is overridden (typical in tests) but + # ``data_dir`` isn't, derive ``data_dir`` from the db's parent so + # snapshot/auxiliary files stay alongside the test db rather than + # leaking into the real repo data dir. + raw["data_dir"] = Path(raw["db_path"]).parent return Settings(**raw) diff --git a/chat/services/background.py b/chat/services/background.py index 0ad7f68..f02285c 100644 --- a/chat/services/background.py +++ b/chat/services/background.py @@ -33,6 +33,11 @@ from chat.db.connection import open_db from chat.eventlog.log import append_and_apply from chat.llm.client import LLMClient from chat.services.significance import compute_significance +from chat.services.snapshot import ( + prune_periodic_snapshots, + should_take_periodic_snapshot, + take_snapshot, +) log = logging.getLogger(__name__) @@ -123,6 +128,33 @@ class BackgroundWorker: memory_id=job.memory_id, ) + # T31: piggy-back the periodic snapshot check on the background + # worker so we don't need a separate timer task. The classifier + # pass already runs out-of-band, so snapshot I/O on the same + # worker is a natural fit. Each snapshot opens its own + # connection so we don't conflate the snapshot's read-only view + # with the significance-write transaction above. Failures are + # caught and logged: a flaky disk shouldn't take down the + # significance pipeline. + try: + with open_db(self._settings.db_path) as conn: + if should_take_periodic_snapshot( + conn, self._settings.data_dir + ): + snapshot_path = take_snapshot( + conn, + data_dir=self._settings.data_dir, + kind="periodic", + ) + prune_periodic_snapshots( + self._settings.data_dir, keep=5 + ) + log.info( + "periodic snapshot taken: %s", snapshot_path + ) + except Exception as exc: # noqa: BLE001 — never break the worker + log.exception("periodic snapshot failed: %s", exc) + def _auto_pin_with_cap( conn, diff --git a/chat/services/snapshot.py b/chat/services/snapshot.py index 2cfea68..e1679b8 100644 --- a/chat/services/snapshot.py +++ b/chat/services/snapshot.py @@ -1,26 +1,42 @@ """Snapshot service — write a JSON dump of all projected tables to disk. -Used by the rewind flow (Requirements §10.1, T28) so the user can recover a -pre-rewind state if the rewind was a mistake. Stored under -``data/snapshots/{kind}/`` with a UTC timestamp filename. +Two snapshot kinds, both covered by this module: -The dump captures both the event log (so the original event sequence is -preserved verbatim) and every projected table (so a future restore could -either re-load tables directly or re-project from the saved event log). +* ``rewind`` (T28, Requirements §10.1): pre-rewind safety snapshot so the + user can recover if a rewind was a mistake. Retention: 14 days. +* ``periodic`` (T31, Requirements §10.4): full-state checkpoint taken + every 100 events OR every 30 minutes since the last one. Retention: + the most recent 5 are kept; older ones are pruned on write. + +Both kinds live under ``data/snapshots/{kind}/`` with a UTC timestamp +filename so chronological listing matches creation order. + +The dump captures the event log (so the original event sequence is +preserved verbatim), every projected table, and a top-level +``last_event_id`` recording the highest ``event_log.id`` at snapshot +time. The ``last_event_id`` is what the cold-load fast-path uses to +replay only events past the snapshot rather than the entire log. The FTS shadow table ``memories_fts`` is intentionally skipped — it's a -virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would -rebuild itself on a memories re-load. Snapshotting it would also fail +virtual table maintained by the ``memories_ai/au/ad`` triggers, so it +rebuilds itself on a memories re-load. Snapshotting it would also fail ``PRAGMA table_info`` cleanly since FTS5 reports its columns differently. """ from __future__ import annotations import json +import time from datetime import datetime, timezone from pathlib import Path from sqlite3 import Connection +# Periodic snapshot triggers (Requirements §10.4): "every 100 events OR +# every 30 minutes since last snapshot". Module-level so tests can read +# them and so the values stay together with the policy that uses them. +EVENT_COUNT_THRESHOLD = 100 +TIME_THRESHOLD_SECONDS = 30 * 60 # 30 minutes + # Order doesn't affect correctness for snapshotting (we read, not write), # but listing tables explicitly keeps the snapshot stable across schema # evolution: a new table won't silently change the dump shape until it's @@ -49,13 +65,24 @@ def take_snapshot( directories as needed. Filename is a UTC timestamp in ``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation order. + + The dump's top-level ``last_event_id`` is the highest ``event_log.id`` + at snapshot time (0 if the log is empty). This is what the cold-load + fast-path uses to know which suffix of the log to replay. """ snapshot_dir = data_dir / "snapshots" / kind snapshot_dir.mkdir(parents=True, exist_ok=True) timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") path = snapshot_dir / f"{timestamp}.json" - dump: dict[str, list] = {} + dump: dict = {} + + # Record the high-water-mark id up front so cold-load can replay + # only events past it. ``MAX(id)`` is None on an empty log; treat + # that as 0 (i.e. "replay everything"). + cur = conn.execute("SELECT MAX(id) FROM event_log") + max_id_row = cur.fetchone() + dump["last_event_id"] = max_id_row[0] if max_id_row[0] is not None else 0 # Event log: pull every column we care about. ``ts`` and the # superseded/hidden flags are needed to faithfully reconstruct the @@ -98,3 +125,121 @@ def take_snapshot( # all use TEXT so this is mostly defensive. path.write_text(json.dumps(dump, default=str)) return path + + +def latest_snapshot_path(data_dir: Path, kind: str = "periodic") -> Path | None: + """Return the most recent snapshot file for ``kind``, or None if none exist. + + Sorting by filename works because :func:`take_snapshot` uses a UTC + timestamp in ``YYYYMMDDTHHMMSSZ`` form — lexicographic order matches + chronological order. + """ + snapshot_dir = data_dir / "snapshots" / kind + if not snapshot_dir.exists(): + return None + files = sorted(snapshot_dir.glob("*.json")) + return files[-1] if files else None + + +def should_take_periodic_snapshot( + conn: Connection, data_dir: Path +) -> bool: + """Decide whether a periodic snapshot is due per Requirements §10.4. + + The policy: + + * No prior snapshot and at least one event in the log → take one. + * Time since last snapshot ≥ ``TIME_THRESHOLD_SECONDS`` → take one. + * New events since last snapshot's ``last_event_id`` ≥ + ``EVENT_COUNT_THRESHOLD`` → take one. + + "Time since last snapshot" is measured by the file's mtime — we + don't trust the timestamp embedded in the filename for clock drift + reasons. + """ + latest = latest_snapshot_path(data_dir, kind="periodic") + if latest is None: + # No prior snapshot; take one if there are any events to capture. + cur = conn.execute("SELECT COUNT(*) FROM event_log") + return cur.fetchone()[0] > 0 + + age_seconds = time.time() - latest.stat().st_mtime + if age_seconds >= TIME_THRESHOLD_SECONDS: + return True + + # Count events appended since the last snapshot was written. Reading + # ``last_event_id`` from the dump is cheap (a few KB at most for the + # header) but we still avoid loading the full file by parsing once. + last_dump = json.loads(latest.read_text()) + last_event_id = last_dump.get("last_event_id", 0) + cur = conn.execute( + "SELECT COUNT(*) FROM event_log WHERE id > ?", (last_event_id,) + ) + new_event_count = cur.fetchone()[0] + return new_event_count >= EVENT_COUNT_THRESHOLD + + +def prune_periodic_snapshots(data_dir: Path, keep: int = 5) -> int: + """Delete all but the most recent ``keep`` periodic snapshots. + + Returns the number of files removed. Safe to call when the directory + doesn't exist (returns 0). Sorting is by filename, which is the UTC + timestamp — same ordering :func:`latest_snapshot_path` uses. + """ + snapshot_dir = data_dir / "snapshots" / "periodic" + if not snapshot_dir.exists(): + return 0 + files = sorted(snapshot_dir.glob("*.json")) + to_remove = files[:-keep] if len(files) > keep else [] + for f in to_remove: + f.unlink() + return len(to_remove) + + +def restore_from_snapshot(conn: Connection, snapshot_path: Path) -> int: + """Restore projected tables from ``snapshot_path``. + + Returns the snapshot's ``last_event_id`` so callers (the cold-load + fast-path in :func:`chat.app.lifespan`) know what suffix of the + event log still needs replaying. + + Projected tables are cleared in the same FK-respecting order as + :func:`chat.services.rewind.execute_rewind`, then re-populated from + the dump. ``memories_fts`` is skipped — it's a virtual FTS5 table + that rebuilds itself when rows hit ``memories``. The event log + itself is *not* touched: cold-load assumes the on-disk log is the + source of truth and the snapshot is just a fast-forward to skip + re-projecting old events. + """ + dump = json.loads(snapshot_path.read_text()) + + # Same delete order as rewind: child tables before parents so FK + # ON DELETE doesn't fire on referenced rows. + conn.execute("DELETE FROM memories") + conn.execute("DELETE FROM activity") + conn.execute("DELETE FROM scenes") + conn.execute("DELETE FROM containers") + conn.execute("DELETE FROM chat_state") + conn.execute("DELETE FROM chats") + conn.execute("DELETE FROM edges") + conn.execute("DELETE FROM bots") + conn.execute("DELETE FROM you_entity") + conn.execute("DELETE FROM classifier_failures") + + for table in PROJECTED_TABLES: + if table == "memories_fts": + # Rebuilt by triggers when memories rows are inserted below. + continue + rows = dump.get(table, []) + if not rows: + continue + cols = list(rows[0].keys()) + placeholders = ", ".join("?" * len(cols)) + col_list = ", ".join(cols) + for row in rows: + conn.execute( + f"INSERT INTO {table} ({col_list}) VALUES ({placeholders})", + tuple(row[c] for c in cols), + ) + + return dump.get("last_event_id", 0) diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py new file mode 100644 index 0000000..3a418b1 --- /dev/null +++ b/tests/test_snapshot.py @@ -0,0 +1,133 @@ +"""Tests for Task 31 — periodic snapshots with retention and cold-load fast-path. + +Per Requirements §10.4 the periodic snapshot policy is: + +* Take a snapshot every 100 events OR every 30 minutes since the last one, + whichever comes first. +* Store under ``data/snapshots/periodic/`` with a UTC timestamp filename. +* Retain only the last 5 periodic snapshots; prune older ones on write. +* On cold load, restore from the most recent snapshot and replay events + past the snapshot's ``last_event_id`` to bring projected state forward. + +These tests cover the functional core (snapshot timing, pruning, restore). +Worker- and lifespan-level wiring is covered by the integration tests in +``test_turn_flow`` and the existing app boot tests. +""" + +from __future__ import annotations + +import json + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +from chat.services.snapshot import ( + latest_snapshot_path, + prune_periodic_snapshots, + restore_from_snapshot, + should_take_periodic_snapshot, + take_snapshot, +) + +# Importing the state modules registers their projector handlers as a +# side effect — restoring + replaying needs them present. +import chat.state.entities # noqa: F401 +import chat.state.edges # noqa: F401 +import chat.state.manual_edit # noqa: F401 +import chat.state.memory # noqa: F401 +import chat.state.world # noqa: F401 + + +def _bot_payload(bot_id: str, name: str) -> dict: + return { + "id": bot_id, + "name": name, + "persona": "fancy", + "voice_samples": ["sample"], + "traits": ["shy"], + "backstory": "", + "initial_relationship_to_you": "coworker", + "kickoff_prose": "", + } + + +def test_take_snapshot_includes_last_event_id(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")) + project(conn) + path = take_snapshot(conn, data_dir=tmp_path / "data", kind="periodic") + dump = json.loads(path.read_text()) + assert "last_event_id" in dump + assert dump["last_event_id"] >= 1 + + +def test_should_take_periodic_when_no_prior_and_events_exist(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")) + project(conn) + assert should_take_periodic_snapshot(conn, tmp_path / "data") is True + + +def test_should_not_take_when_recent_and_few_events(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")) + project(conn) + # Take a snapshot to establish a recent baseline. + take_snapshot(conn, data_dir=tmp_path / "data", kind="periodic") + # Right after — should be False (within time threshold and < 100 new events). + assert should_take_periodic_snapshot(conn, tmp_path / "data") is False + + +def test_prune_keeps_last_5(tmp_path): + snapshot_dir = tmp_path / "data" / "snapshots" / "periodic" + snapshot_dir.mkdir(parents=True) + # Create 8 dummy snapshot files with sortable names. + for i in range(8): + p = snapshot_dir / f"2026010{i}T000000Z.json" + p.write_text(json.dumps({"last_event_id": i})) + removed = prune_periodic_snapshots(tmp_path / "data", keep=5) + assert removed == 3 + remaining = sorted(snapshot_dir.glob("*.json")) + assert len(remaining) == 5 + # The 5 most recent (highest names) should remain. + assert remaining[0].name == "20260103T000000Z.json" + assert remaining[-1].name == "20260107T000000Z.json" + + +def test_latest_snapshot_path_returns_none_when_missing(tmp_path): + # No directory yet. + assert latest_snapshot_path(tmp_path / "data", kind="periodic") is None + # Empty directory. + (tmp_path / "data" / "snapshots" / "periodic").mkdir(parents=True) + assert latest_snapshot_path(tmp_path / "data", kind="periodic") is None + + +def test_restore_from_snapshot_repopulates_tables(tmp_path): + # Source DB: seed a bot, snapshot it. + db1 = tmp_path / "t1.db" + apply_migrations(db1) + with open_db(db1) as conn: + append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")) + project(conn) + snapshot_path = take_snapshot( + conn, data_dir=tmp_path / "data", kind="periodic" + ) + + # Fresh DB — restore from the snapshot, no event-log replay needed. + db2 = tmp_path / "t2.db" + apply_migrations(db2) + with open_db(db2) as conn: + last_id = restore_from_snapshot(conn, snapshot_path) + assert last_id >= 1 + bot = conn.execute( + "SELECT name FROM bots WHERE id = 'bot_a'" + ).fetchone() + assert bot is not None + assert bot[0] == "BotA"