feat: periodic snapshots with retention and cold-load fast-path

This commit is contained in:
Joseph Doherty
2026-04-26 14:15:17 -04:00
parent 82be8b3f51
commit b9644fad31
5 changed files with 355 additions and 9 deletions
+32
View File
@@ -33,6 +33,11 @@ from chat.db.connection import open_db
from chat.eventlog.log import append_and_apply
from chat.llm.client import LLMClient
from chat.services.significance import compute_significance
from chat.services.snapshot import (
prune_periodic_snapshots,
should_take_periodic_snapshot,
take_snapshot,
)
log = logging.getLogger(__name__)
@@ -123,6 +128,33 @@ class BackgroundWorker:
memory_id=job.memory_id,
)
# T31: piggy-back the periodic snapshot check on the background
# worker so we don't need a separate timer task. The classifier
# pass already runs out-of-band, so snapshot I/O on the same
# worker is a natural fit. Each snapshot opens its own
# connection so we don't conflate the snapshot's read-only view
# with the significance-write transaction above. Failures are
# caught and logged: a flaky disk shouldn't take down the
# significance pipeline.
try:
with open_db(self._settings.db_path) as conn:
if should_take_periodic_snapshot(
conn, self._settings.data_dir
):
snapshot_path = take_snapshot(
conn,
data_dir=self._settings.data_dir,
kind="periodic",
)
prune_periodic_snapshots(
self._settings.data_dir, keep=5
)
log.info(
"periodic snapshot taken: %s", snapshot_path
)
except Exception as exc: # noqa: BLE001 — never break the worker
log.exception("periodic snapshot failed: %s", exc)
def _auto_pin_with_cap(
conn,
+154 -9
View File
@@ -1,26 +1,42 @@
"""Snapshot service — write a JSON dump of all projected tables to disk.
Used by the rewind flow (Requirements §10.1, T28) so the user can recover a
pre-rewind state if the rewind was a mistake. Stored under
``data/snapshots/{kind}/`` with a UTC timestamp filename.
Two snapshot kinds, both covered by this module:
The dump captures both the event log (so the original event sequence is
preserved verbatim) and every projected table (so a future restore could
either re-load tables directly or re-project from the saved event log).
* ``rewind`` (T28, Requirements §10.1): pre-rewind safety snapshot so the
user can recover if a rewind was a mistake. Retention: 14 days.
* ``periodic`` (T31, Requirements §10.4): full-state checkpoint taken
every 100 events OR every 30 minutes since the last one. Retention:
the most recent 5 are kept; older ones are pruned on write.
Both kinds live under ``data/snapshots/{kind}/`` with a UTC timestamp
filename so chronological listing matches creation order.
The dump captures the event log (so the original event sequence is
preserved verbatim), every projected table, and a top-level
``last_event_id`` recording the highest ``event_log.id`` at snapshot
time. The ``last_event_id`` is what the cold-load fast-path uses to
replay only events past the snapshot rather than the entire log.
The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would
rebuild itself on a memories re-load. Snapshotting it would also fail
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it
rebuilds itself on a memories re-load. Snapshotting it would also fail
``PRAGMA table_info`` cleanly since FTS5 reports its columns differently.
"""
from __future__ import annotations
import json
import time
from datetime import datetime, timezone
from pathlib import Path
from sqlite3 import Connection
# Periodic snapshot triggers (Requirements §10.4): "every 100 events OR
# every 30 minutes since last snapshot". Module-level so tests can read
# them and so the values stay together with the policy that uses them.
EVENT_COUNT_THRESHOLD = 100
TIME_THRESHOLD_SECONDS = 30 * 60 # 30 minutes
# Order doesn't affect correctness for snapshotting (we read, not write),
# but listing tables explicitly keeps the snapshot stable across schema
# evolution: a new table won't silently change the dump shape until it's
@@ -49,13 +65,24 @@ def take_snapshot(
directories as needed. Filename is a UTC timestamp in
``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation
order.
The dump's top-level ``last_event_id`` is the highest ``event_log.id``
at snapshot time (0 if the log is empty). This is what the cold-load
fast-path uses to know which suffix of the log to replay.
"""
snapshot_dir = data_dir / "snapshots" / kind
snapshot_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
path = snapshot_dir / f"{timestamp}.json"
dump: dict[str, list] = {}
dump: dict = {}
# Record the high-water-mark id up front so cold-load can replay
# only events past it. ``MAX(id)`` is None on an empty log; treat
# that as 0 (i.e. "replay everything").
cur = conn.execute("SELECT MAX(id) FROM event_log")
max_id_row = cur.fetchone()
dump["last_event_id"] = max_id_row[0] if max_id_row[0] is not None else 0
# Event log: pull every column we care about. ``ts`` and the
# superseded/hidden flags are needed to faithfully reconstruct the
@@ -98,3 +125,121 @@ def take_snapshot(
# all use TEXT so this is mostly defensive.
path.write_text(json.dumps(dump, default=str))
return path
def latest_snapshot_path(data_dir: Path, kind: str = "periodic") -> Path | None:
"""Return the most recent snapshot file for ``kind``, or None if none exist.
Sorting by filename works because :func:`take_snapshot` uses a UTC
timestamp in ``YYYYMMDDTHHMMSSZ`` form — lexicographic order matches
chronological order.
"""
snapshot_dir = data_dir / "snapshots" / kind
if not snapshot_dir.exists():
return None
files = sorted(snapshot_dir.glob("*.json"))
return files[-1] if files else None
def should_take_periodic_snapshot(
conn: Connection, data_dir: Path
) -> bool:
"""Decide whether a periodic snapshot is due per Requirements §10.4.
The policy:
* No prior snapshot and at least one event in the log → take one.
* Time since last snapshot ≥ ``TIME_THRESHOLD_SECONDS`` → take one.
* New events since last snapshot's ``last_event_id`` ≥
``EVENT_COUNT_THRESHOLD`` → take one.
"Time since last snapshot" is measured by the file's mtime — we
don't trust the timestamp embedded in the filename for clock drift
reasons.
"""
latest = latest_snapshot_path(data_dir, kind="periodic")
if latest is None:
# No prior snapshot; take one if there are any events to capture.
cur = conn.execute("SELECT COUNT(*) FROM event_log")
return cur.fetchone()[0] > 0
age_seconds = time.time() - latest.stat().st_mtime
if age_seconds >= TIME_THRESHOLD_SECONDS:
return True
# Count events appended since the last snapshot was written. Reading
# ``last_event_id`` from the dump is cheap (a few KB at most for the
# header) but we still avoid loading the full file by parsing once.
last_dump = json.loads(latest.read_text())
last_event_id = last_dump.get("last_event_id", 0)
cur = conn.execute(
"SELECT COUNT(*) FROM event_log WHERE id > ?", (last_event_id,)
)
new_event_count = cur.fetchone()[0]
return new_event_count >= EVENT_COUNT_THRESHOLD
def prune_periodic_snapshots(data_dir: Path, keep: int = 5) -> int:
"""Delete all but the most recent ``keep`` periodic snapshots.
Returns the number of files removed. Safe to call when the directory
doesn't exist (returns 0). Sorting is by filename, which is the UTC
timestamp — same ordering :func:`latest_snapshot_path` uses.
"""
snapshot_dir = data_dir / "snapshots" / "periodic"
if not snapshot_dir.exists():
return 0
files = sorted(snapshot_dir.glob("*.json"))
to_remove = files[:-keep] if len(files) > keep else []
for f in to_remove:
f.unlink()
return len(to_remove)
def restore_from_snapshot(conn: Connection, snapshot_path: Path) -> int:
"""Restore projected tables from ``snapshot_path``.
Returns the snapshot's ``last_event_id`` so callers (the cold-load
fast-path in :func:`chat.app.lifespan`) know what suffix of the
event log still needs replaying.
Projected tables are cleared in the same FK-respecting order as
:func:`chat.services.rewind.execute_rewind`, then re-populated from
the dump. ``memories_fts`` is skipped — it's a virtual FTS5 table
that rebuilds itself when rows hit ``memories``. The event log
itself is *not* touched: cold-load assumes the on-disk log is the
source of truth and the snapshot is just a fast-forward to skip
re-projecting old events.
"""
dump = json.loads(snapshot_path.read_text())
# Same delete order as rewind: child tables before parents so FK
# ON DELETE doesn't fire on referenced rows.
conn.execute("DELETE FROM memories")
conn.execute("DELETE FROM activity")
conn.execute("DELETE FROM scenes")
conn.execute("DELETE FROM containers")
conn.execute("DELETE FROM chat_state")
conn.execute("DELETE FROM chats")
conn.execute("DELETE FROM edges")
conn.execute("DELETE FROM bots")
conn.execute("DELETE FROM you_entity")
conn.execute("DELETE FROM classifier_failures")
for table in PROJECTED_TABLES:
if table == "memories_fts":
# Rebuilt by triggers when memories rows are inserted below.
continue
rows = dump.get(table, [])
if not rows:
continue
cols = list(rows[0].keys())
placeholders = ", ".join("?" * len(cols))
col_list = ", ".join(cols)
for row in rows:
conn.execute(
f"INSERT INTO {table} ({col_list}) VALUES ({placeholders})",
tuple(row[c] for c in cols),
)
return dump.get("last_event_id", 0)