From 30e664812216c688e8a6fd80e8299abd161fd653 Mon Sep 17 00:00:00 2001
From: Joseph Doherty
Date: Sun, 26 Apr 2026 11:56:32 -0400
Subject: [PATCH] feat: memory schema with witness flags and FTS5 index

---
Reviewer notes (this section is not part of the commit message):
- Line structure restored; the patch had been collapsed onto a few long
  lines. Indentation is reconstructed at 4 spaces.
- chat/state/memory.py: the PRAGMA column lookup is shared through
  _memory_columns(); search_memories() returns [] for a blank query or
  k <= 0 instead of raising / returning an unbounded result.
- tests/test_memory.py: two tests added for those guards.
- The "index" blob ids below are carried over, not recomputed.

 chat/db/migrations/0006_memories.sql |  35 ++++
 chat/state/memory.py                 | 104 +++++++++++
 tests/test_memory.py                 | 249 +++++++++++++++++++++++++++
 3 files changed, 388 insertions(+)
 create mode 100644 chat/db/migrations/0006_memories.sql
 create mode 100644 chat/state/memory.py
 create mode 100644 tests/test_memory.py

diff --git a/chat/db/migrations/0006_memories.sql b/chat/db/migrations/0006_memories.sql
new file mode 100644
index 0000000..a9c9fa4
--- /dev/null
+++ b/chat/db/migrations/0006_memories.sql
@@ -0,0 +1,35 @@
+CREATE TABLE memories (
+    id INTEGER PRIMARY KEY,
+    owner_id TEXT NOT NULL,
+    chat_id TEXT NOT NULL,
+    scene_id INTEGER,
+    pov_summary TEXT NOT NULL,
+    witness_you INTEGER NOT NULL,
+    witness_host INTEGER NOT NULL,
+    witness_guest INTEGER NOT NULL,
+    chat_clock_at TEXT,
+    source TEXT,
+    reliability REAL NOT NULL DEFAULT 1.0,
+    significance INTEGER NOT NULL DEFAULT 1,
+    pinned INTEGER NOT NULL DEFAULT 0,
+    auto_pinned INTEGER NOT NULL DEFAULT 0,
+    created_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
+CREATE INDEX idx_memories_owner ON memories(owner_id);
+
+CREATE VIRTUAL TABLE memories_fts USING fts5(
+    pov_summary, content='memories', content_rowid='id'
+);
+
+CREATE TRIGGER memories_ai AFTER INSERT ON memories BEGIN
+    INSERT INTO memories_fts(rowid, pov_summary) VALUES (new.id, new.pov_summary);
+END;
+CREATE TRIGGER memories_au AFTER UPDATE ON memories BEGIN
+    INSERT INTO memories_fts(memories_fts, rowid, pov_summary)
+    VALUES('delete', old.id, old.pov_summary);
+    INSERT INTO memories_fts(rowid, pov_summary) VALUES (new.id, new.pov_summary);
+END;
+CREATE TRIGGER memories_ad AFTER DELETE ON memories BEGIN
+    INSERT INTO memories_fts(memories_fts, rowid, pov_summary)
+    VALUES('delete', old.id, old.pov_summary);
+END;
diff --git a/chat/state/memory.py b/chat/state/memory.py
new file mode 100644
index 0000000..b2655b0
--- /dev/null
+++ b/chat/state/memory.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+from sqlite3 import Connection
+from chat.eventlog.projector import on
+from chat.eventlog.log import Event
+
+_VALID_WITNESS_ROLES = {"you", "host", "guest"}
+
+
+def _memory_columns(conn: Connection) -> list[str]:
+    """Return the memories table's column names in declaration order."""
+    return [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
+
+
+def _row_to_dict(conn: Connection, row: tuple) -> dict:
+    """Pair one ``SELECT *`` row from memories with its column names."""
+    return dict(zip(_memory_columns(conn), row))
+
+
+@on("memory_written")
+def _apply_memory_written(conn: Connection, e: Event) -> None:
+    """Project a memory_written event into a new memories row."""
+    p = e.payload
+    conn.execute(
+        "INSERT INTO memories ("
+        "owner_id, chat_id, scene_id, pov_summary, "
+        "witness_you, witness_host, witness_guest, "
+        "chat_clock_at, source, reliability, significance, pinned, auto_pinned"
+        ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+        (
+            p["owner_id"],
+            p["chat_id"],
+            p.get("scene_id"),
+            p["pov_summary"],
+            int(p["witness_you"]),
+            int(p["witness_host"]),
+            int(p["witness_guest"]),
+            p.get("chat_clock_at"),
+            p.get("source", "direct"),
+            float(p.get("reliability", 1.0)),
+            int(p.get("significance", 1)),
+            int(p.get("pinned", 0)),
+            int(p.get("auto_pinned", 0)),
+        ),
+    )
+
+
+def get_memory(conn: Connection, memory_id: int) -> dict | None:
+    """Return the memory with the given id as a dict, or None if absent."""
+    row = conn.execute(
+        "SELECT * FROM memories WHERE id = ?", (memory_id,)
+    ).fetchone()
+    if row is None:
+        return None
+    return _row_to_dict(conn, row)
+
+
+def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
+    """Return owner_id's pinned memories, newest first."""
+    rows = conn.execute(
+        "SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
+        "ORDER BY created_at DESC, id DESC",
+        (owner_id,),
+    ).fetchall()
+    cols = _memory_columns(conn)
+    return [dict(zip(cols, row)) for row in rows]
+
+
+def search_memories(
+    conn: Connection,
+    owner_id: str,
+    witness_role: str,
+    query: str,
+    k: int = 4,
+) -> list[dict]:
+    """FTS5 search over pov_summary, scoped by owner and witness role.
+
+    witness_role must be one of {"you", "host", "guest"} per the witness flags
+    on each memory row. Returns up to k rows ordered by FTS5 bm25 rank, each
+    carrying an extra "fts_rank" key. query is passed to FTS5 MATCH verbatim.
+    A blank query or k <= 0 returns [] without touching the index.
+    """
+    if witness_role not in _VALID_WITNESS_ROLES:
+        raise ValueError(
+            f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
+            f"got {witness_role!r}"
+        )
+    # An empty MATCH string is an FTS5 syntax error, and SQLite treats a
+    # negative LIMIT as "no limit"; neither honours the "up to k" contract.
+    if k <= 0 or not (query and query.strip()):
+        return []
+    # Interpolation is safe: witness_role was checked against the valid set.
+    witness_col = f"witness_{witness_role}"
+    sql = (
+        "SELECT m.*, memories_fts.rank AS fts_rank "
+        "FROM memories_fts "
+        "JOIN memories m ON m.id = memories_fts.rowid "
+        f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
+        "AND memories_fts MATCH ? "
+        "ORDER BY memories_fts.rank "
+        "LIMIT ?"
+    )
+    rows = conn.execute(sql, (owner_id, query, k)).fetchall()
+    result_cols = _memory_columns(conn) + ["fts_rank"]
+    return [dict(zip(result_cols, row)) for row in rows]
diff --git a/tests/test_memory.py b/tests/test_memory.py
new file mode 100644
index 0000000..3d174b0
--- /dev/null
+++ b/tests/test_memory.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+import pytest
+
+from chat.db.connection import open_db
+from chat.db.migrate import apply_migrations
+from chat.eventlog.log import append_event
+from chat.eventlog.projector import project
+import chat.state.memory  # registers memory_written handler
+from chat.state.memory import get_memory, get_pinned, search_memories
+
+
+def _base_memory(**overrides):
+    payload = {
+        "owner_id": "bot_a",
+        "chat_id": "chat_bot_a",
+        "scene_id": 1,
+        "pov_summary": "She laughed at his joke about owls.",
+        "witness_you": 1,
+        "witness_host": 1,
+        "witness_guest": 0,
+        "chat_clock_at": "2026-04-26T10:00:00",
+        "source": "direct",
+        "reliability": 1.0,
+        "significance": 1,
+        "pinned": 0,
+        "auto_pinned": 0,
+    }
+    payload.update(overrides)
+    return payload
+
+
+def test_memory_written_is_projected_and_readable(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory())
+        project(conn)
+        row = conn.execute("SELECT id FROM memories").fetchone()
+        assert row is not None
+        mem = get_memory(conn, row[0])
+        assert mem is not None
+        assert mem["owner_id"] == "bot_a"
+        assert mem["chat_id"] == "chat_bot_a"
+        assert mem["scene_id"] == 1
+        assert mem["pov_summary"] == "She laughed at his joke about owls."
+        assert mem["witness_you"] == 1
+        assert mem["witness_host"] == 1
+        assert mem["witness_guest"] == 0
+        assert mem["chat_clock_at"] == "2026-04-26T10:00:00"
+        assert mem["source"] == "direct"
+        assert mem["reliability"] == 1.0
+        assert mem["significance"] == 1
+        assert mem["pinned"] == 0
+        assert mem["auto_pinned"] == 0
+
+
+def test_get_memory_returns_none_for_missing_id(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        assert get_memory(conn, 9999) is None
+
+
+def test_search_memories_filters_out_non_witness(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="The cat sat on the mat.",
+            witness_you=1, witness_host=1, witness_guest=0,
+        ))
+        project(conn)
+        # guest did not witness => excluded
+        results = search_memories(conn, "bot_a", "guest", "cat")
+        assert results == []
+
+
+def test_search_memories_includes_witnesses(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="The cat sat on the mat.",
+            witness_you=1, witness_host=1, witness_guest=0,
+        ))
+        project(conn)
+        results = search_memories(conn, "bot_a", "host", "cat")
+        assert len(results) == 1
+        assert results[0]["pov_summary"] == "The cat sat on the mat."
+        assert "fts_rank" in results[0]
+
+
+def test_search_memories_fts_matches_only_relevant_text(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="She loves owls and stars.",
+        ))
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="He fixed the broken kettle.",
+        ))
+        project(conn)
+        results = search_memories(conn, "bot_a", "you", "owls")
+        assert len(results) == 1
+        assert results[0]["pov_summary"] == "She loves owls and stars."
+
+
+def test_search_memories_filters_by_owner(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            owner_id="bot_a",
+            pov_summary="Owls hooted at midnight.",
+        ))
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            owner_id="bot_b",
+            pov_summary="Owls hooted at midnight.",
+        ))
+        project(conn)
+        results_a = search_memories(conn, "bot_a", "you", "owls")
+        results_b = search_memories(conn, "bot_b", "you", "owls")
+        assert len(results_a) == 1
+        assert results_a[0]["owner_id"] == "bot_a"
+        assert len(results_b) == 1
+        assert results_b[0]["owner_id"] == "bot_b"
+
+
+def test_search_memories_returns_empty_on_no_match(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="The cat sat on the mat.",
+        ))
+        project(conn)
+        results = search_memories(conn, "bot_a", "you", "spaceship")
+        assert results == []
+
+
+def test_search_memories_invalid_witness_role_raises(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        with pytest.raises(ValueError):
+            search_memories(conn, "bot_a", "everyone", "cat")
+
+
+def test_search_memories_respects_k_limit(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        for i in range(6):
+            append_event(conn, kind="memory_written", payload=_base_memory(
+                pov_summary=f"Owls hooted at midnight number {i}.",
+            ))
+        project(conn)
+        results = search_memories(conn, "bot_a", "you", "owls", k=4)
+        assert len(results) == 4
+
+
+def test_search_memories_blank_query_returns_empty(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory())
+        project(conn)
+        assert search_memories(conn, "bot_a", "you", "") == []
+        assert search_memories(conn, "bot_a", "you", "   ") == []
+
+
+def test_search_memories_non_positive_k_returns_empty(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory())
+        project(conn)
+        assert search_memories(conn, "bot_a", "you", "owls", k=0) == []
+        assert search_memories(conn, "bot_a", "you", "owls", k=-1) == []
+
+
+def test_get_pinned_returns_only_pinned(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="Pinned moment.",
+            pinned=1,
+        ))
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            pov_summary="Unpinned moment.",
+            pinned=0,
+        ))
+        project(conn)
+        pinned = get_pinned(conn, "bot_a")
+        assert len(pinned) == 1
+        assert pinned[0]["pov_summary"] == "Pinned moment."
+        assert pinned[0]["pinned"] == 1
+
+
+def test_get_pinned_filters_by_owner(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            owner_id="bot_a", pov_summary="A's pin.", pinned=1,
+        ))
+        append_event(conn, kind="memory_written", payload=_base_memory(
+            owner_id="bot_b", pov_summary="B's pin.", pinned=1,
+        ))
+        project(conn)
+        pinned_a = get_pinned(conn, "bot_a")
+        assert len(pinned_a) == 1
+        assert pinned_a[0]["owner_id"] == "bot_a"
+
+
+def test_memory_payload_defaults_when_optional_missing(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        append_event(conn, kind="memory_written", payload={
+            "owner_id": "bot_a",
+            "chat_id": "chat_bot_a",
+            "pov_summary": "Bare minimum memory.",
+            "witness_you": 1,
+            "witness_host": 1,
+            "witness_guest": 1,
+        })
+        project(conn)
+        row = conn.execute("SELECT id FROM memories").fetchone()
+        mem = get_memory(conn, row[0])
+        assert mem["scene_id"] is None
+        assert mem["chat_clock_at"] is None
+        assert mem["source"] == "direct"
+        assert mem["reliability"] == 1.0
+        assert mem["significance"] == 1
+        assert mem["pinned"] == 0
+        assert mem["auto_pinned"] == 0
+
+
+def test_schema_version_after_migration_is_6(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        row = conn.execute(
+            "SELECT value FROM meta WHERE key = 'schema_version'"
+        ).fetchone()
+        assert int(row[0]) == 6