feat: memory schema with witness flags and FTS5 index
This commit is contained in:
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
from sqlite3 import Connection
|
||||
from chat.eventlog.projector import on
|
||||
from chat.eventlog.log import Event
|
||||
|
||||
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
|
||||
|
||||
|
||||
def _row_to_dict(conn: Connection, row: tuple) -> dict:
|
||||
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
|
||||
return dict(zip(cols, row))
|
||||
|
||||
|
||||
@on("memory_written")
|
||||
def _apply_memory_written(conn: Connection, e: Event) -> None:
|
||||
p = e.payload
|
||||
conn.execute(
|
||||
"INSERT INTO memories ("
|
||||
"owner_id, chat_id, scene_id, pov_summary, "
|
||||
"witness_you, witness_host, witness_guest, "
|
||||
"chat_clock_at, source, reliability, significance, pinned, auto_pinned"
|
||||
") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
p["owner_id"],
|
||||
p["chat_id"],
|
||||
p.get("scene_id"),
|
||||
p["pov_summary"],
|
||||
int(p["witness_you"]),
|
||||
int(p["witness_host"]),
|
||||
int(p["witness_guest"]),
|
||||
p.get("chat_clock_at"),
|
||||
p.get("source", "direct"),
|
||||
float(p.get("reliability", 1.0)),
|
||||
int(p.get("significance", 1)),
|
||||
int(p.get("pinned", 0)),
|
||||
int(p.get("auto_pinned", 0)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_memory(conn: Connection, memory_id: int) -> dict | None:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM memories WHERE id = ?", (memory_id,)
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return _row_to_dict(conn, row)
|
||||
|
||||
|
||||
def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
|
||||
cur = conn.execute(
|
||||
"SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
|
||||
"ORDER BY created_at DESC, id DESC",
|
||||
(owner_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
|
||||
return [dict(zip(cols, row)) for row in rows]
|
||||
|
||||
|
||||
def search_memories(
|
||||
conn: Connection,
|
||||
owner_id: str,
|
||||
witness_role: str,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
) -> list[dict]:
|
||||
"""FTS5 search over pov_summary, scoped by owner and witness role.
|
||||
|
||||
witness_role must be one of {"you", "host", "guest"} per the witness flags
|
||||
on each memory row. Returns up to k rows ordered by FTS5 bm25 rank.
|
||||
"""
|
||||
if witness_role not in _VALID_WITNESS_ROLES:
|
||||
raise ValueError(
|
||||
f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
|
||||
f"got {witness_role!r}"
|
||||
)
|
||||
witness_col = f"witness_{witness_role}"
|
||||
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
|
||||
select_list = ", ".join(f"m.{c}" for c in cols)
|
||||
sql = (
|
||||
f"SELECT {select_list}, memories_fts.rank AS fts_rank "
|
||||
"FROM memories_fts "
|
||||
"JOIN memories m ON m.id = memories_fts.rowid "
|
||||
f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
|
||||
"AND memories_fts MATCH ? "
|
||||
"ORDER BY memories_fts.rank "
|
||||
"LIMIT ?"
|
||||
)
|
||||
cur = conn.execute(sql, (owner_id, query, k))
|
||||
rows = cur.fetchall()
|
||||
out: list[dict] = []
|
||||
result_cols = cols + ["fts_rank"]
|
||||
for row in rows:
|
||||
out.append(dict(zip(result_cols, row)))
|
||||
return out
|
||||
Reference in New Issue
Block a user