feat: memory schema with witness flags and FTS5 index

This commit is contained in:
Joseph Doherty
2026-04-26 11:56:32 -04:00
parent bc97d425ef
commit 30e6648122
3 changed files with 360 additions and 0 deletions
+35
View File
@@ -0,0 +1,35 @@
CREATE TABLE memories (
id INTEGER PRIMARY KEY,
owner_id TEXT NOT NULL,
chat_id TEXT NOT NULL,
scene_id INTEGER,
pov_summary TEXT NOT NULL,
witness_you INTEGER NOT NULL,
witness_host INTEGER NOT NULL,
witness_guest INTEGER NOT NULL,
chat_clock_at TEXT,
source TEXT,
reliability REAL NOT NULL DEFAULT 1.0,
significance INTEGER NOT NULL DEFAULT 1,
pinned INTEGER NOT NULL DEFAULT 0,
auto_pinned INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX idx_memories_owner ON memories(owner_id);
CREATE VIRTUAL TABLE memories_fts USING fts5(
pov_summary, content='memories', content_rowid='id'
);
CREATE TRIGGER memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, pov_summary) VALUES (new.id, new.pov_summary);
END;
CREATE TRIGGER memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, pov_summary)
VALUES('delete', old.id, old.pov_summary);
INSERT INTO memories_fts(rowid, pov_summary) VALUES (new.id, new.pov_summary);
END;
CREATE TRIGGER memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, pov_summary)
VALUES('delete', old.id, old.pov_summary);
END;
+96
View File
@@ -0,0 +1,96 @@
from __future__ import annotations
from sqlite3 import Connection
from chat.eventlog.projector import on
from chat.eventlog.log import Event
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
def _row_to_dict(conn: Connection, row: tuple) -> dict:
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
return dict(zip(cols, row))
@on("memory_written")
def _apply_memory_written(conn: Connection, e: Event) -> None:
p = e.payload
conn.execute(
"INSERT INTO memories ("
"owner_id, chat_id, scene_id, pov_summary, "
"witness_you, witness_host, witness_guest, "
"chat_clock_at, source, reliability, significance, pinned, auto_pinned"
") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
p["owner_id"],
p["chat_id"],
p.get("scene_id"),
p["pov_summary"],
int(p["witness_you"]),
int(p["witness_host"]),
int(p["witness_guest"]),
p.get("chat_clock_at"),
p.get("source", "direct"),
float(p.get("reliability", 1.0)),
int(p.get("significance", 1)),
int(p.get("pinned", 0)),
int(p.get("auto_pinned", 0)),
),
)
def get_memory(conn: Connection, memory_id: int) -> dict | None:
row = conn.execute(
"SELECT * FROM memories WHERE id = ?", (memory_id,)
).fetchone()
if not row:
return None
return _row_to_dict(conn, row)
def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
cur = conn.execute(
"SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
"ORDER BY created_at DESC, id DESC",
(owner_id,),
)
rows = cur.fetchall()
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
return [dict(zip(cols, row)) for row in rows]
def search_memories(
conn: Connection,
owner_id: str,
witness_role: str,
query: str,
k: int = 4,
) -> list[dict]:
"""FTS5 search over pov_summary, scoped by owner and witness role.
witness_role must be one of {"you", "host", "guest"} per the witness flags
on each memory row. Returns up to k rows ordered by FTS5 bm25 rank.
"""
if witness_role not in _VALID_WITNESS_ROLES:
raise ValueError(
f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
f"got {witness_role!r}"
)
witness_col = f"witness_{witness_role}"
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
select_list = ", ".join(f"m.{c}" for c in cols)
sql = (
f"SELECT {select_list}, memories_fts.rank AS fts_rank "
"FROM memories_fts "
"JOIN memories m ON m.id = memories_fts.rowid "
f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
"AND memories_fts MATCH ? "
"ORDER BY memories_fts.rank "
"LIMIT ?"
)
cur = conn.execute(sql, (owner_id, query, k))
rows = cur.fetchall()
out: list[dict] = []
result_cols = cols + ["fts_rank"]
for row in rows:
out.append(dict(zip(result_cols, row)))
return out
+229
View File
@@ -0,0 +1,229 @@
from __future__ import annotations
import pytest
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
import chat.state.memory # registers memory_written handler
from chat.state.memory import get_memory, get_pinned, search_memories
def _base_memory(**overrides):
payload = {
"owner_id": "bot_a",
"chat_id": "chat_bot_a",
"scene_id": 1,
"pov_summary": "She laughed at his joke about owls.",
"witness_you": 1,
"witness_host": 1,
"witness_guest": 0,
"chat_clock_at": "2026-04-26T10:00:00",
"source": "direct",
"reliability": 1.0,
"significance": 1,
"pinned": 0,
"auto_pinned": 0,
}
payload.update(overrides)
return payload
def test_memory_written_is_projected_and_readable(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory())
project(conn)
row = conn.execute("SELECT id FROM memories").fetchone()
assert row is not None
mem = get_memory(conn, row[0])
assert mem is not None
assert mem["owner_id"] == "bot_a"
assert mem["chat_id"] == "chat_bot_a"
assert mem["scene_id"] == 1
assert mem["pov_summary"] == "She laughed at his joke about owls."
assert mem["witness_you"] == 1
assert mem["witness_host"] == 1
assert mem["witness_guest"] == 0
assert mem["chat_clock_at"] == "2026-04-26T10:00:00"
assert mem["source"] == "direct"
assert mem["reliability"] == 1.0
assert mem["significance"] == 1
assert mem["pinned"] == 0
assert mem["auto_pinned"] == 0
def test_get_memory_returns_none_for_missing_id(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
assert get_memory(conn, 9999) is None
def test_search_memories_filters_out_non_witness(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="The cat sat on the mat.",
witness_you=1, witness_host=1, witness_guest=0,
))
project(conn)
# guest did not witness => excluded
results = search_memories(conn, "bot_a", "guest", "cat")
assert results == []
def test_search_memories_includes_witnesses(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="The cat sat on the mat.",
witness_you=1, witness_host=1, witness_guest=0,
))
project(conn)
results = search_memories(conn, "bot_a", "host", "cat")
assert len(results) == 1
assert results[0]["pov_summary"] == "The cat sat on the mat."
assert "fts_rank" in results[0]
def test_search_memories_fts_matches_only_relevant_text(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="She loves owls and stars.",
))
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="He fixed the broken kettle.",
))
project(conn)
results = search_memories(conn, "bot_a", "you", "owls")
assert len(results) == 1
assert results[0]["pov_summary"] == "She loves owls and stars."
def test_search_memories_filters_by_owner(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
owner_id="bot_a",
pov_summary="Owls hooted at midnight.",
))
append_event(conn, kind="memory_written", payload=_base_memory(
owner_id="bot_b",
pov_summary="Owls hooted at midnight.",
))
project(conn)
results_a = search_memories(conn, "bot_a", "you", "owls")
results_b = search_memories(conn, "bot_b", "you", "owls")
assert len(results_a) == 1
assert results_a[0]["owner_id"] == "bot_a"
assert len(results_b) == 1
assert results_b[0]["owner_id"] == "bot_b"
def test_search_memories_returns_empty_on_no_match(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="The cat sat on the mat.",
))
project(conn)
results = search_memories(conn, "bot_a", "you", "spaceship")
assert results == []
def test_search_memories_invalid_witness_role_raises(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
with pytest.raises(ValueError):
search_memories(conn, "bot_a", "everyone", "cat")
def test_search_memories_respects_k_limit(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
for i in range(6):
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary=f"Owls hooted at midnight number {i}.",
))
project(conn)
results = search_memories(conn, "bot_a", "you", "owls", k=4)
assert len(results) == 4
def test_get_pinned_returns_only_pinned(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="Pinned moment.",
pinned=1,
))
append_event(conn, kind="memory_written", payload=_base_memory(
pov_summary="Unpinned moment.",
pinned=0,
))
project(conn)
pinned = get_pinned(conn, "bot_a")
assert len(pinned) == 1
assert pinned[0]["pov_summary"] == "Pinned moment."
assert pinned[0]["pinned"] == 1
def test_get_pinned_filters_by_owner(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload=_base_memory(
owner_id="bot_a", pov_summary="A's pin.", pinned=1,
))
append_event(conn, kind="memory_written", payload=_base_memory(
owner_id="bot_b", pov_summary="B's pin.", pinned=1,
))
project(conn)
pinned_a = get_pinned(conn, "bot_a")
assert len(pinned_a) == 1
assert pinned_a[0]["owner_id"] == "bot_a"
def test_memory_payload_defaults_when_optional_missing(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
append_event(conn, kind="memory_written", payload={
"owner_id": "bot_a",
"chat_id": "chat_bot_a",
"pov_summary": "Bare minimum memory.",
"witness_you": 1,
"witness_host": 1,
"witness_guest": 1,
})
project(conn)
row = conn.execute("SELECT id FROM memories").fetchone()
mem = get_memory(conn, row[0])
assert mem["scene_id"] is None
assert mem["chat_clock_at"] is None
assert mem["source"] == "direct"
assert mem["reliability"] == 1.0
assert mem["significance"] == 1
assert mem["pinned"] == 0
assert mem["auto_pinned"] == 0
def test_schema_version_after_migration_is_6(tmp_path):
db = tmp_path / "t.db"
apply_migrations(db)
with open_db(db) as conn:
row = conn.execute(
"SELECT value FROM meta WHERE key = 'schema_version'"
).fetchone()
assert int(row[0]) == 6