feat: periodic snapshots with retention and cold-load fast-path

This commit is contained in:
Joseph Doherty
2026-04-26 14:15:17 -04:00
parent 82be8b3f51
commit b9644fad31
5 changed files with 355 additions and 9 deletions
+28
View File
@@ -1,4 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
@@ -6,8 +8,12 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from chat.config import load_settings from chat.config import load_settings
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations from chat.db.migrate import apply_migrations
from chat.eventlog.log import read_events
from chat.eventlog.projector import apply_event
from chat.services.background import BackgroundWorker from chat.services.background import BackgroundWorker
from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot
# Trigger handler registration: # Trigger handler registration:
import chat.state.entities # noqa: F401 import chat.state.entities # noqa: F401
@@ -25,12 +31,34 @@ from chat.web.settings import router as settings_router
from chat.web.sse import router as sse_router from chat.web.sse import router as sse_router
from chat.web.turns import router as turns_router from chat.web.turns import router as turns_router
log = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
settings = load_settings() settings = load_settings()
settings.db_path.parent.mkdir(parents=True, exist_ok=True) settings.db_path.parent.mkdir(parents=True, exist_ok=True)
apply_migrations(settings.db_path) apply_migrations(settings.db_path)
# T31 cold-load fast-path: if a periodic snapshot exists, restore
# projected tables from it and replay only events past its
# ``last_event_id``. Migrations already ran above, so any new tables
# introduced after the snapshot was taken are present and empty —
# the replay-forward step refills them from the event log.
snapshot_path = latest_snapshot_path(settings.data_dir, kind="periodic")
if snapshot_path is not None:
with open_db(settings.db_path) as conn:
last_event_id = restore_from_snapshot(conn, snapshot_path)
for event in read_events(
conn, branch_id=1, after_id=last_event_id
):
apply_event(conn, event)
log.info(
"cold-load restored from %s, replayed events past id %d",
snapshot_path,
last_event_id,
)
app.state.settings = settings app.state.settings = settings
# Background worker for the async significance pass (T22). Each job # Background worker for the async significance pass (T22). Each job
+8
View File
@@ -37,4 +37,12 @@ def load_settings() -> Settings:
raw = tomllib.loads(config_path.read_text()) raw = tomllib.loads(config_path.read_text())
if "CHAT_DB_PATH" in os.environ: if "CHAT_DB_PATH" in os.environ:
raw["db_path"] = Path(os.environ["CHAT_DB_PATH"]) raw["db_path"] = Path(os.environ["CHAT_DB_PATH"])
if "CHAT_DATA_DIR" in os.environ:
raw["data_dir"] = Path(os.environ["CHAT_DATA_DIR"])
elif "data_dir" not in raw and "db_path" in raw:
# T31: when ``CHAT_DB_PATH`` is overridden (typical in tests) but
# ``data_dir`` isn't, derive ``data_dir`` from the db's parent so
# snapshot/auxiliary files stay alongside the test db rather than
# leaking into the real repo data dir.
raw["data_dir"] = Path(raw["db_path"]).parent
return Settings(**raw) return Settings(**raw)
+32
View File
@@ -33,6 +33,11 @@ from chat.db.connection import open_db
from chat.eventlog.log import append_and_apply from chat.eventlog.log import append_and_apply
from chat.llm.client import LLMClient from chat.llm.client import LLMClient
from chat.services.significance import compute_significance from chat.services.significance import compute_significance
from chat.services.snapshot import (
prune_periodic_snapshots,
should_take_periodic_snapshot,
take_snapshot,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -123,6 +128,33 @@ class BackgroundWorker:
memory_id=job.memory_id, memory_id=job.memory_id,
) )
# T31: piggy-back the periodic snapshot check on the background
# worker so we don't need a separate timer task. The classifier
# pass already runs out-of-band, so snapshot I/O on the same
# worker is a natural fit. Each snapshot opens its own
# connection so we don't conflate the snapshot's read-only view
# with the significance-write transaction above. Failures are
# caught and logged: a flaky disk shouldn't take down the
# significance pipeline.
try:
with open_db(self._settings.db_path) as conn:
if should_take_periodic_snapshot(
conn, self._settings.data_dir
):
snapshot_path = take_snapshot(
conn,
data_dir=self._settings.data_dir,
kind="periodic",
)
prune_periodic_snapshots(
self._settings.data_dir, keep=5
)
log.info(
"periodic snapshot taken: %s", snapshot_path
)
except Exception as exc: # noqa: BLE001 — never break the worker
log.exception("periodic snapshot failed: %s", exc)
def _auto_pin_with_cap( def _auto_pin_with_cap(
conn, conn,
+154 -9
View File
@@ -1,26 +1,42 @@
"""Snapshot service — write a JSON dump of all projected tables to disk. """Snapshot service — write a JSON dump of all projected tables to disk.
Used by the rewind flow (Requirements §10.1, T28) so the user can recover a Two snapshot kinds, both covered by this module:
pre-rewind state if the rewind was a mistake. Stored under
``data/snapshots/{kind}/`` with a UTC timestamp filename.
The dump captures both the event log (so the original event sequence is * ``rewind`` (T28, Requirements §10.1): pre-rewind safety snapshot so the
preserved verbatim) and every projected table (so a future restore could user can recover if a rewind was a mistake. Retention: 14 days.
either re-load tables directly or re-project from the saved event log). * ``periodic`` (T31, Requirements §10.4): full-state checkpoint taken
every 100 events OR every 30 minutes since the last one. Retention:
the most recent 5 are kept; older ones are pruned on write.
Both kinds live under ``data/snapshots/{kind}/`` with a UTC timestamp
filename so chronological listing matches creation order.
The dump captures the event log (so the original event sequence is
preserved verbatim), every projected table, and a top-level
``last_event_id`` recording the highest ``event_log.id`` at snapshot
time. The ``last_event_id`` is what the cold-load fast-path uses to
replay only events past the snapshot rather than the entire log.
The FTS shadow table ``memories_fts`` is intentionally skipped — it's a The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would virtual table maintained by the ``memories_ai/au/ad`` triggers, so it
rebuild itself on a memories re-load. Snapshotting it would also fail rebuilds itself on a memories re-load. Snapshotting it would also fail
``PRAGMA table_info`` cleanly since FTS5 reports its columns differently. ``PRAGMA table_info`` cleanly since FTS5 reports its columns differently.
""" """
from __future__ import annotations from __future__ import annotations
import json import json
import time
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from sqlite3 import Connection from sqlite3 import Connection
# Periodic snapshot triggers (Requirements §10.4): "every 100 events OR
# every 30 minutes since last snapshot". Module-level so tests can read
# them and so the values stay together with the policy that uses them.
EVENT_COUNT_THRESHOLD = 100
TIME_THRESHOLD_SECONDS = 30 * 60 # 30 minutes
# Order doesn't affect correctness for snapshotting (we read, not write), # Order doesn't affect correctness for snapshotting (we read, not write),
# but listing tables explicitly keeps the snapshot stable across schema # but listing tables explicitly keeps the snapshot stable across schema
# evolution: a new table won't silently change the dump shape until it's # evolution: a new table won't silently change the dump shape until it's
@@ -49,13 +65,24 @@ def take_snapshot(
directories as needed. Filename is a UTC timestamp in directories as needed. Filename is a UTC timestamp in
``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation ``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation
order. order.
The dump's top-level ``last_event_id`` is the highest ``event_log.id``
at snapshot time (0 if the log is empty). This is what the cold-load
fast-path uses to know which suffix of the log to replay.
""" """
snapshot_dir = data_dir / "snapshots" / kind snapshot_dir = data_dir / "snapshots" / kind
snapshot_dir.mkdir(parents=True, exist_ok=True) snapshot_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
path = snapshot_dir / f"{timestamp}.json" path = snapshot_dir / f"{timestamp}.json"
dump: dict[str, list] = {} dump: dict = {}
# Record the high-water-mark id up front so cold-load can replay
# only events past it. ``MAX(id)`` is None on an empty log; treat
# that as 0 (i.e. "replay everything").
cur = conn.execute("SELECT MAX(id) FROM event_log")
max_id_row = cur.fetchone()
dump["last_event_id"] = max_id_row[0] if max_id_row[0] is not None else 0
# Event log: pull every column we care about. ``ts`` and the # Event log: pull every column we care about. ``ts`` and the
# superseded/hidden flags are needed to faithfully reconstruct the # superseded/hidden flags are needed to faithfully reconstruct the
@@ -98,3 +125,121 @@ def take_snapshot(
# all use TEXT so this is mostly defensive. # all use TEXT so this is mostly defensive.
path.write_text(json.dumps(dump, default=str)) path.write_text(json.dumps(dump, default=str))
return path return path
def latest_snapshot_path(data_dir: Path, kind: str = "periodic") -> Path | None:
"""Return the most recent snapshot file for ``kind``, or None if none exist.
Sorting by filename works because :func:`take_snapshot` uses a UTC
timestamp in ``YYYYMMDDTHHMMSSZ`` form — lexicographic order matches
chronological order.
"""
snapshot_dir = data_dir / "snapshots" / kind
if not snapshot_dir.exists():
return None
files = sorted(snapshot_dir.glob("*.json"))
return files[-1] if files else None
def should_take_periodic_snapshot(
conn: Connection, data_dir: Path
) -> bool:
"""Decide whether a periodic snapshot is due per Requirements §10.4.
The policy:
* No prior snapshot and at least one event in the log → take one.
* Time since last snapshot ≥ ``TIME_THRESHOLD_SECONDS`` → take one.
* New events since last snapshot's ``last_event_id`` ≥
``EVENT_COUNT_THRESHOLD`` → take one.
"Time since last snapshot" is measured by the file's mtime — we
don't trust the timestamp embedded in the filename for clock drift
reasons.
"""
latest = latest_snapshot_path(data_dir, kind="periodic")
if latest is None:
# No prior snapshot; take one if there are any events to capture.
cur = conn.execute("SELECT COUNT(*) FROM event_log")
return cur.fetchone()[0] > 0
age_seconds = time.time() - latest.stat().st_mtime
if age_seconds >= TIME_THRESHOLD_SECONDS:
return True
# Count events appended since the last snapshot was written. Reading
# ``last_event_id`` from the dump is cheap (a few KB at most for the
# header) but we still avoid loading the full file by parsing once.
last_dump = json.loads(latest.read_text())
last_event_id = last_dump.get("last_event_id", 0)
cur = conn.execute(
"SELECT COUNT(*) FROM event_log WHERE id > ?", (last_event_id,)
)
new_event_count = cur.fetchone()[0]
return new_event_count >= EVENT_COUNT_THRESHOLD
def prune_periodic_snapshots(data_dir: Path, keep: int = 5) -> int:
"""Delete all but the most recent ``keep`` periodic snapshots.
Returns the number of files removed. Safe to call when the directory
doesn't exist (returns 0). Sorting is by filename, which is the UTC
timestamp — same ordering :func:`latest_snapshot_path` uses.
"""
snapshot_dir = data_dir / "snapshots" / "periodic"
if not snapshot_dir.exists():
return 0
files = sorted(snapshot_dir.glob("*.json"))
to_remove = files[:-keep] if len(files) > keep else []
for f in to_remove:
f.unlink()
return len(to_remove)
def restore_from_snapshot(conn: Connection, snapshot_path: Path) -> int:
"""Restore projected tables from ``snapshot_path``.
Returns the snapshot's ``last_event_id`` so callers (the cold-load
fast-path in :func:`chat.app.lifespan`) know what suffix of the
event log still needs replaying.
Projected tables are cleared in the same FK-respecting order as
:func:`chat.services.rewind.execute_rewind`, then re-populated from
the dump. ``memories_fts`` is skipped — it's a virtual FTS5 table
that rebuilds itself when rows hit ``memories``. The event log
itself is *not* touched: cold-load assumes the on-disk log is the
source of truth and the snapshot is just a fast-forward to skip
re-projecting old events.
"""
dump = json.loads(snapshot_path.read_text())
# Same delete order as rewind: child tables before parents so FK
# ON DELETE doesn't fire on referenced rows.
conn.execute("DELETE FROM memories")
conn.execute("DELETE FROM activity")
conn.execute("DELETE FROM scenes")
conn.execute("DELETE FROM containers")
conn.execute("DELETE FROM chat_state")
conn.execute("DELETE FROM chats")
conn.execute("DELETE FROM edges")
conn.execute("DELETE FROM bots")
conn.execute("DELETE FROM you_entity")
conn.execute("DELETE FROM classifier_failures")
for table in PROJECTED_TABLES:
if table == "memories_fts":
# Rebuilt by triggers when memories rows are inserted below.
continue
rows = dump.get(table, [])
if not rows:
continue
cols = list(rows[0].keys())
placeholders = ", ".join("?" * len(cols))
col_list = ", ".join(cols)
for row in rows:
conn.execute(
f"INSERT INTO {table} ({col_list}) VALUES ({placeholders})",
tuple(row[c] for c in cols),
)
return dump.get("last_event_id", 0)
+133
View File
@@ -0,0 +1,133 @@
"""Tests for Task 31 — periodic snapshots with retention and cold-load fast-path.
Per Requirements §10.4 the periodic snapshot policy is:
* Take a snapshot every 100 events OR every 30 minutes since the last one,
whichever comes first.
* Store under ``data/snapshots/periodic/`` with a UTC timestamp filename.
* Retain only the last 5 periodic snapshots; prune older ones on write.
* On cold load, restore from the most recent snapshot and replay events
past the snapshot's ``last_event_id`` to bring projected state forward.
These tests cover the functional core (snapshot timing, pruning, restore).
Worker- and lifespan-level wiring is covered by the integration tests in
``test_turn_flow`` and the existing app boot tests.
"""
from __future__ import annotations
import json
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.services.snapshot import (
latest_snapshot_path,
prune_periodic_snapshots,
restore_from_snapshot,
should_take_periodic_snapshot,
take_snapshot,
)
# Importing the state modules registers their projector handlers as a
# side effect — restoring + replaying needs them present.
import chat.state.entities # noqa: F401
import chat.state.edges # noqa: F401
import chat.state.manual_edit # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
def _bot_payload(bot_id: str, name: str) -> dict:
return {
"id": bot_id,
"name": name,
"persona": "fancy",
"voice_samples": ["sample"],
"traits": ["shy"],
"backstory": "",
"initial_relationship_to_you": "coworker",
"kickoff_prose": "",
}
def test_take_snapshot_includes_last_event_id(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA"))
project(conn)
path = take_snapshot(conn, data_dir=tmp_path / "data", kind="periodic")
dump = json.loads(path.read_text())
assert "last_event_id" in dump
assert dump["last_event_id"] >= 1
def test_should_take_periodic_when_no_prior_and_events_exist(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA"))
project(conn)
assert should_take_periodic_snapshot(conn, tmp_path / "data") is True
def test_should_not_take_when_recent_and_few_events(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA"))
project(conn)
# Take a snapshot to establish a recent baseline.
take_snapshot(conn, data_dir=tmp_path / "data", kind="periodic")
# Right after — should be False (within time threshold and < 100 new events).
assert should_take_periodic_snapshot(conn, tmp_path / "data") is False
def test_prune_keeps_last_5(tmp_path):
snapshot_dir = tmp_path / "data" / "snapshots" / "periodic"
snapshot_dir.mkdir(parents=True)
# Create 8 dummy snapshot files with sortable names.
for i in range(8):
p = snapshot_dir / f"2026010{i}T000000Z.json"
p.write_text(json.dumps({"last_event_id": i}))
removed = prune_periodic_snapshots(tmp_path / "data", keep=5)
assert removed == 3
remaining = sorted(snapshot_dir.glob("*.json"))
assert len(remaining) == 5
# The 5 most recent (highest names) should remain.
assert remaining[0].name == "20260103T000000Z.json"
assert remaining[-1].name == "20260107T000000Z.json"
def test_latest_snapshot_path_returns_none_when_missing(tmp_path):
# No directory yet.
assert latest_snapshot_path(tmp_path / "data", kind="periodic") is None
# Empty directory.
(tmp_path / "data" / "snapshots" / "periodic").mkdir(parents=True)
assert latest_snapshot_path(tmp_path / "data", kind="periodic") is None
def test_restore_from_snapshot_repopulates_tables(tmp_path):
# Source DB: seed a bot, snapshot it.
db1 = tmp_path / "t1.db"
apply_migrations(db1)
with open_db(db1) as conn:
append_event(conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA"))
project(conn)
snapshot_path = take_snapshot(
conn, data_dir=tmp_path / "data", kind="periodic"
)
# Fresh DB — restore from the snapshot, no event-log replay needed.
db2 = tmp_path / "t2.db"
apply_migrations(db2)
with open_db(db2) as conn:
last_id = restore_from_snapshot(conn, snapshot_path)
assert last_id >= 1
bot = conn.execute(
"SELECT name FROM bots WHERE id = 'bot_a'"
).fetchone()
assert bot is not None
assert bot[0] == "BotA"