feat: periodic snapshots with retention and cold-load fast-path
This commit is contained in:
+28
@@ -1,4 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -6,8 +8,12 @@ from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from chat.config import load_settings
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import read_events
|
||||
from chat.eventlog.projector import apply_event
|
||||
from chat.services.background import BackgroundWorker
|
||||
from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot
|
||||
|
||||
# Trigger handler registration:
|
||||
import chat.state.entities # noqa: F401
|
||||
@@ -25,12 +31,34 @@ from chat.web.settings import router as settings_router
|
||||
from chat.web.sse import router as sse_router
|
||||
from chat.web.turns import router as turns_router
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
settings = load_settings()
|
||||
settings.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
apply_migrations(settings.db_path)
|
||||
|
||||
# T31 cold-load fast-path: if a periodic snapshot exists, restore
|
||||
# projected tables from it and replay only events past its
|
||||
# ``last_event_id``. Migrations already ran above, so any new tables
|
||||
# introduced after the snapshot was taken are present and empty —
|
||||
# the replay-forward step refills them from the event log.
|
||||
snapshot_path = latest_snapshot_path(settings.data_dir, kind="periodic")
|
||||
if snapshot_path is not None:
|
||||
with open_db(settings.db_path) as conn:
|
||||
last_event_id = restore_from_snapshot(conn, snapshot_path)
|
||||
for event in read_events(
|
||||
conn, branch_id=1, after_id=last_event_id
|
||||
):
|
||||
apply_event(conn, event)
|
||||
log.info(
|
||||
"cold-load restored from %s, replayed events past id %d",
|
||||
snapshot_path,
|
||||
last_event_id,
|
||||
)
|
||||
|
||||
app.state.settings = settings
|
||||
|
||||
# Background worker for the async significance pass (T22). Each job
|
||||
|
||||
@@ -37,4 +37,12 @@ def load_settings() -> Settings:
|
||||
raw = tomllib.loads(config_path.read_text())
|
||||
if "CHAT_DB_PATH" in os.environ:
|
||||
raw["db_path"] = Path(os.environ["CHAT_DB_PATH"])
|
||||
if "CHAT_DATA_DIR" in os.environ:
|
||||
raw["data_dir"] = Path(os.environ["CHAT_DATA_DIR"])
|
||||
elif "data_dir" not in raw and "db_path" in raw:
|
||||
# T31: when ``CHAT_DB_PATH`` is overridden (typical in tests) but
|
||||
# ``data_dir`` isn't, derive ``data_dir`` from the db's parent so
|
||||
# snapshot/auxiliary files stay alongside the test db rather than
|
||||
# leaking into the real repo data dir.
|
||||
raw["data_dir"] = Path(raw["db_path"]).parent
|
||||
return Settings(**raw)
|
||||
|
||||
@@ -33,6 +33,11 @@ from chat.db.connection import open_db
|
||||
from chat.eventlog.log import append_and_apply
|
||||
from chat.llm.client import LLMClient
|
||||
from chat.services.significance import compute_significance
|
||||
from chat.services.snapshot import (
|
||||
prune_periodic_snapshots,
|
||||
should_take_periodic_snapshot,
|
||||
take_snapshot,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -123,6 +128,33 @@ class BackgroundWorker:
|
||||
memory_id=job.memory_id,
|
||||
)
|
||||
|
||||
# T31: piggy-back the periodic snapshot check on the background
|
||||
# worker so we don't need a separate timer task. The classifier
|
||||
# pass already runs out-of-band, so snapshot I/O on the same
|
||||
# worker is a natural fit. Each snapshot opens its own
|
||||
# connection so we don't conflate the snapshot's read-only view
|
||||
# with the significance-write transaction above. Failures are
|
||||
# caught and logged: a flaky disk shouldn't take down the
|
||||
# significance pipeline.
|
||||
try:
|
||||
with open_db(self._settings.db_path) as conn:
|
||||
if should_take_periodic_snapshot(
|
||||
conn, self._settings.data_dir
|
||||
):
|
||||
snapshot_path = take_snapshot(
|
||||
conn,
|
||||
data_dir=self._settings.data_dir,
|
||||
kind="periodic",
|
||||
)
|
||||
prune_periodic_snapshots(
|
||||
self._settings.data_dir, keep=5
|
||||
)
|
||||
log.info(
|
||||
"periodic snapshot taken: %s", snapshot_path
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 — never break the worker
|
||||
log.exception("periodic snapshot failed: %s", exc)
|
||||
|
||||
|
||||
def _auto_pin_with_cap(
|
||||
conn,
|
||||
|
||||
+154
-9
@@ -1,26 +1,42 @@
|
||||
"""Snapshot service — write a JSON dump of all projected tables to disk.
|
||||
|
||||
Used by the rewind flow (Requirements §10.1, T28) so the user can recover a
|
||||
pre-rewind state if the rewind was a mistake. Stored under
|
||||
``data/snapshots/{kind}/`` with a UTC timestamp filename.
|
||||
Two snapshot kinds, both covered by this module:
|
||||
|
||||
The dump captures both the event log (so the original event sequence is
|
||||
preserved verbatim) and every projected table (so a future restore could
|
||||
either re-load tables directly or re-project from the saved event log).
|
||||
* ``rewind`` (T28, Requirements §10.1): pre-rewind safety snapshot so the
|
||||
user can recover if a rewind was a mistake. Retention: 14 days.
|
||||
* ``periodic`` (T31, Requirements §10.4): full-state checkpoint taken
|
||||
every 100 events OR every 30 minutes since the last one. Retention:
|
||||
the most recent 5 are kept; older ones are pruned on write.
|
||||
|
||||
Both kinds live under ``data/snapshots/{kind}/`` with a UTC timestamp
|
||||
filename so chronological listing matches creation order.
|
||||
|
||||
The dump captures the event log (so the original event sequence is
|
||||
preserved verbatim), every projected table, and a top-level
|
||||
``last_event_id`` recording the highest ``event_log.id`` at snapshot
|
||||
time. The ``last_event_id`` is what the cold-load fast-path uses to
|
||||
replay only events past the snapshot rather than the entire log.
|
||||
|
||||
The FTS shadow table ``memories_fts`` is intentionally skipped — it's a
|
||||
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it would
|
||||
rebuild itself on a memories re-load. Snapshotting it would also fail
|
||||
virtual table maintained by the ``memories_ai/au/ad`` triggers, so it
|
||||
rebuilds itself on a memories re-load. Snapshotting it would also fail
|
||||
``PRAGMA table_info`` cleanly since FTS5 reports its columns differently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from sqlite3 import Connection
|
||||
|
||||
# Periodic snapshot triggers (Requirements §10.4): "every 100 events OR
|
||||
# every 30 minutes since last snapshot". Module-level so tests can read
|
||||
# them and so the values stay together with the policy that uses them.
|
||||
EVENT_COUNT_THRESHOLD = 100
|
||||
TIME_THRESHOLD_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
# Order doesn't affect correctness for snapshotting (we read, not write),
|
||||
# but listing tables explicitly keeps the snapshot stable across schema
|
||||
# evolution: a new table won't silently change the dump shape until it's
|
||||
@@ -49,13 +65,24 @@ def take_snapshot(
|
||||
directories as needed. Filename is a UTC timestamp in
|
||||
``YYYYMMDDTHHMMSSZ`` form so chronological listing matches creation
|
||||
order.
|
||||
|
||||
The dump's top-level ``last_event_id`` is the highest ``event_log.id``
|
||||
at snapshot time (0 if the log is empty). This is what the cold-load
|
||||
fast-path uses to know which suffix of the log to replay.
|
||||
"""
|
||||
snapshot_dir = data_dir / "snapshots" / kind
|
||||
snapshot_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
path = snapshot_dir / f"{timestamp}.json"
|
||||
|
||||
dump: dict[str, list] = {}
|
||||
dump: dict = {}
|
||||
|
||||
# Record the high-water-mark id up front so cold-load can replay
|
||||
# only events past it. ``MAX(id)`` is None on an empty log; treat
|
||||
# that as 0 (i.e. "replay everything").
|
||||
cur = conn.execute("SELECT MAX(id) FROM event_log")
|
||||
max_id_row = cur.fetchone()
|
||||
dump["last_event_id"] = max_id_row[0] if max_id_row[0] is not None else 0
|
||||
|
||||
# Event log: pull every column we care about. ``ts`` and the
|
||||
# superseded/hidden flags are needed to faithfully reconstruct the
|
||||
@@ -98,3 +125,121 @@ def take_snapshot(
|
||||
# all use TEXT so this is mostly defensive.
|
||||
path.write_text(json.dumps(dump, default=str))
|
||||
return path
|
||||
|
||||
|
||||
def latest_snapshot_path(data_dir: Path, kind: str = "periodic") -> Path | None:
    """Locate the newest snapshot file of the given ``kind``.

    Candidates are the ``*.json`` files directly under
    ``<data_dir>/snapshots/<kind>/``. They are ranked by path, and the
    greatest one wins; this matches creation order as long as the
    filenames are sortable UTC timestamps (``YYYYMMDDTHHMMSSZ``).

    Args:
        data_dir: Root data directory that contains ``snapshots/``.
        kind: Snapshot sub-directory to search (default ``"periodic"``).

    Returns:
        The path of the highest-sorting snapshot file, or ``None`` when
        the directory is missing or holds no ``*.json`` files.
    """
    folder = data_dir / "snapshots" / kind
    if folder.exists():
        ordered = sorted(folder.glob("*.json"))
        if ordered:
            return ordered[-1]
    return None
|
||||
|
||||
|
||||
def should_take_periodic_snapshot(
    conn: Connection, data_dir: Path
) -> bool:
    """Decide whether a periodic snapshot is due per Requirements §10.4.

    The checks run in this order, and the first one that applies wins:

    * No prior periodic snapshot → due iff ``event_log`` has any rows.
    * Latest snapshot file is at least ``TIME_THRESHOLD_SECONDS`` old
      (by file mtime, not by the timestamp in its name) → due.
    * Latest snapshot cannot be read or parsed → due, so a fresh
      checkpoint replaces the unusable one.
    * At least ``EVENT_COUNT_THRESHOLD`` rows in ``event_log`` have an
      ``id`` greater than the snapshot's ``last_event_id`` → due.

    Args:
        conn: Open SQLite connection to the database holding ``event_log``.
        data_dir: Root data directory that contains ``snapshots/``.

    Returns:
        ``True`` when a new periodic snapshot should be taken.
    """
    latest = latest_snapshot_path(data_dir, kind="periodic")
    if latest is None:
        # No prior snapshot; take one if there are any events to capture.
        cur = conn.execute("SELECT COUNT(*) FROM event_log")
        return cur.fetchone()[0] > 0

    age_seconds = time.time() - latest.stat().st_mtime
    if age_seconds >= TIME_THRESHOLD_SECONDS:
        return True

    # NOTE: this parses the *entire* dump just to read ``last_event_id``,
    # so the cost grows with the snapshot size.
    try:
        last_dump = json.loads(latest.read_text())
    except (OSError, ValueError):
        # Unreadable or truncated file (``json.JSONDecodeError`` is a
        # ``ValueError``). Raising here would repeat on every call until
        # the time threshold passes, so ask for a fresh snapshot instead.
        return True
    if not isinstance(last_dump, dict):
        # Valid JSON but not a snapshot object; treat it as unusable too.
        return True

    last_event_id = last_dump.get("last_event_id", 0)
    cur = conn.execute(
        "SELECT COUNT(*) FROM event_log WHERE id > ?", (last_event_id,)
    )
    new_event_count = cur.fetchone()[0]
    return new_event_count >= EVENT_COUNT_THRESHOLD
|
||||
|
||||
|
||||
def prune_periodic_snapshots(data_dir: Path, keep: int = 5) -> int:
    """Delete all but the most recent ``keep`` periodic snapshots.

    Files are the ``*.json`` entries under
    ``<data_dir>/snapshots/periodic/``, ordered by filename (the UTC
    timestamp), so the lowest-sorting files are the ones removed.

    Args:
        data_dir: Root data directory that contains ``snapshots/``.
        keep: How many of the newest snapshots to retain. ``0`` removes
            every periodic snapshot.

    Returns:
        The number of files removed; ``0`` when the directory does not
        exist or nothing exceeds the retention count.

    Raises:
        ValueError: If ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError(f"keep must be non-negative, got {keep}")

    snapshot_dir = data_dir / "snapshots" / "periodic"
    if not snapshot_dir.exists():
        return 0

    files = sorted(snapshot_dir.glob("*.json"))
    # Slice by an explicit count rather than ``files[:-keep]``: with
    # ``keep == 0`` that slice is ``files[:0]`` and would delete nothing.
    excess = max(len(files) - keep, 0)
    to_remove = files[:excess]
    for stale in to_remove:
        stale.unlink()
    return len(to_remove)
|
||||
|
||||
|
||||
def restore_from_snapshot(conn: Connection, snapshot_path: Path) -> int:
    """Restore projected tables from the JSON dump at ``snapshot_path``.

    The dump is parsed and shape-checked *before* anything is deleted,
    so a malformed snapshot cannot leave the projected tables wiped.
    The listed tables are then cleared and re-populated from the dump.
    ``event_log`` is not modified: the caller replays the events past
    the returned id on top of the restored rows.

    Args:
        conn: Open SQLite connection to restore into.
        snapshot_path: Snapshot file to load; expected to hold a JSON
            object mapping table names to lists of row objects.

    Returns:
        The dump's ``last_event_id`` (``0`` if the key is absent), i.e.
        the id after which the event log still needs replaying.

    Raises:
        ValueError: If the file is not valid JSON or its top level is
            not a JSON object.
        KeyError: If a row lacks a column that the table's first row has.
    """
    dump = json.loads(snapshot_path.read_text())
    if not isinstance(dump, dict):
        raise ValueError(
            f"snapshot {snapshot_path} is not a JSON object; refusing to restore"
        )

    # Delete order is kept exactly as originally written (described there
    # as child tables before parents, matching the rewind flow).
    for table in (
        "memories",
        "activity",
        "scenes",
        "containers",
        "chat_state",
        "chats",
        "edges",
        "bots",
        "you_entity",
        "classifier_failures",
    ):
        conn.execute(f"DELETE FROM {table}")

    for table in PROJECTED_TABLES:
        if table == "memories_fts":
            # Not restored directly; per the module notes it is an FTS
            # table maintained by triggers on ``memories``.
            continue
        rows = dump.get(table, [])
        if not rows:
            continue
        # The first row defines the column list for the whole table.
        cols = list(rows[0].keys())
        placeholders = ", ".join("?" * len(cols))
        col_list = ", ".join(cols)
        conn.executemany(
            f"INSERT INTO {table} ({col_list}) VALUES ({placeholders})",
            [tuple(row[c] for c in cols) for row in rows],
        )

    return dump.get("last_event_id", 0)
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
"""Tests for Task 31 — periodic snapshots with retention and cold-load fast-path.
|
||||
|
||||
Per Requirements §10.4 the periodic snapshot policy is:
|
||||
|
||||
* Take a snapshot every 100 events OR every 30 minutes since the last one,
|
||||
whichever comes first.
|
||||
* Store under ``data/snapshots/periodic/`` with a UTC timestamp filename.
|
||||
* Retain only the last 5 periodic snapshots; prune older ones on write.
|
||||
* On cold load, restore from the most recent snapshot and replay events
|
||||
past the snapshot's ``last_event_id`` to bring projected state forward.
|
||||
|
||||
These tests cover the functional core (snapshot timing, pruning, restore).
|
||||
Worker- and lifespan-level wiring is covered by the integration tests in
|
||||
``test_turn_flow`` and the existing app boot tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.snapshot import (
|
||||
latest_snapshot_path,
|
||||
prune_periodic_snapshots,
|
||||
restore_from_snapshot,
|
||||
should_take_periodic_snapshot,
|
||||
take_snapshot,
|
||||
)
|
||||
|
||||
# Importing the state modules registers their projector handlers as a
|
||||
# side effect — restoring + replaying needs them present.
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.edges # noqa: F401
|
||||
import chat.state.manual_edit # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
|
||||
def _bot_payload(bot_id: str, name: str) -> dict:
|
||||
return {
|
||||
"id": bot_id,
|
||||
"name": name,
|
||||
"persona": "fancy",
|
||||
"voice_samples": ["sample"],
|
||||
"traits": ["shy"],
|
||||
"backstory": "",
|
||||
"initial_relationship_to_you": "coworker",
|
||||
"kickoff_prose": "",
|
||||
}
|
||||
|
||||
|
||||
def test_take_snapshot_includes_last_event_id(tmp_path):
    """A periodic snapshot taken after one event records ``last_event_id`` >= 1."""
    db_file = tmp_path / "t.db"
    apply_migrations(db_file)
    with open_db(db_file) as conn:
        append_event(
            conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")
        )
        project(conn)
        snapshot_file = take_snapshot(
            conn, data_dir=tmp_path / "data", kind="periodic"
        )
        contents = json.loads(snapshot_file.read_text())
        assert "last_event_id" in contents
        assert contents["last_event_id"] >= 1
|
||||
|
||||
|
||||
def test_should_take_periodic_when_no_prior_and_events_exist(tmp_path):
    """With no snapshot on disk and one event logged, a snapshot is due."""
    db_file = tmp_path / "t.db"
    data_root = tmp_path / "data"
    apply_migrations(db_file)
    with open_db(db_file) as conn:
        append_event(
            conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")
        )
        project(conn)
        verdict = should_take_periodic_snapshot(conn, data_root)
        assert verdict is True
|
||||
|
||||
|
||||
def test_should_not_take_when_recent_and_few_events(tmp_path):
    """Right after a snapshot, with no new events, another one is not due."""
    db_file = tmp_path / "t.db"
    data_root = tmp_path / "data"
    apply_migrations(db_file)
    with open_db(db_file) as conn:
        append_event(
            conn, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")
        )
        project(conn)
        # Establish a fresh baseline snapshot first.
        take_snapshot(conn, data_dir=data_root, kind="periodic")
        # Immediately afterwards neither the time nor the event-count
        # threshold has been reached.
        verdict = should_take_periodic_snapshot(conn, data_root)
        assert verdict is False
|
||||
|
||||
|
||||
def test_prune_keeps_last_5(tmp_path):
    """Pruning 8 snapshots with ``keep=5`` removes the 3 lowest-sorting files."""
    data_root = tmp_path / "data"
    snapshot_dir = data_root / "snapshots" / "periodic"
    snapshot_dir.mkdir(parents=True)

    # Eight dummy snapshots whose names sort in creation order.
    for i in range(8):
        dummy = snapshot_dir / f"2026010{i}T000000Z.json"
        dummy.write_text(json.dumps({"last_event_id": i}))

    assert prune_periodic_snapshots(data_root, keep=5) == 3

    survivors = sorted(snapshot_dir.glob("*.json"))
    assert len(survivors) == 5
    # Only the five highest-sorting names are left.
    oldest_kept, newest_kept = survivors[0], survivors[-1]
    assert oldest_kept.name == "20260103T000000Z.json"
    assert newest_kept.name == "20260107T000000Z.json"
|
||||
|
||||
|
||||
def test_latest_snapshot_path_returns_none_when_missing(tmp_path):
    """``None`` is returned for both a missing and an empty snapshot directory."""
    data_root = tmp_path / "data"

    # Directory has not been created yet.
    assert latest_snapshot_path(data_root, kind="periodic") is None

    # Directory exists but contains no snapshot files.
    periodic_dir = data_root / "snapshots" / "periodic"
    periodic_dir.mkdir(parents=True)
    assert latest_snapshot_path(data_root, kind="periodic") is None
|
||||
|
||||
|
||||
def test_restore_from_snapshot_repopulates_tables(tmp_path):
    """Restoring into a fresh database brings back the snapshotted bot row."""
    # Source database: author one bot, project it, then snapshot.
    source_db = tmp_path / "t1.db"
    apply_migrations(source_db)
    with open_db(source_db) as src:
        append_event(
            src, kind="bot_authored", payload=_bot_payload("bot_a", "BotA")
        )
        project(src)
        snapshot_path = take_snapshot(
            src, data_dir=tmp_path / "data", kind="periodic"
        )

    # Target database: migrated but empty; restore only, no event replay.
    target_db = tmp_path / "t2.db"
    apply_migrations(target_db)
    with open_db(target_db) as dst:
        restored_id = restore_from_snapshot(dst, snapshot_path)
        assert restored_id >= 1
        row = dst.execute(
            "SELECT name FROM bots WHERE id = 'bot_a'"
        ).fetchone()
        assert row is not None
        assert row[0] == "BotA"
|
||||
Reference in New Issue
Block a user