Files
chat/chat/state/memory.py
T
2026-04-26 20:15:19 -04:00

180 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import re
import sqlite3
from sqlite3 import Connection

from chat.eventlog.projector import on
from chat.eventlog.log import Event
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
def _row_to_dict(conn: Connection, row: tuple) -> dict:
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
return dict(zip(cols, row))
@on("memory_written")
def _apply_memory_written(conn: Connection, e: Event) -> None:
    """Insert a new ``memories`` row from a ``memory_written`` event.

    Required payload keys: ``owner_id``, ``chat_id``, ``pov_summary`` and the
    ``witness_you`` / ``witness_host`` / ``witness_guest`` flags (stored as
    ints). Optional keys and their fallbacks: ``scene_id`` and
    ``chat_clock_at`` -> ``None``; ``source`` -> ``"direct"``;
    ``reliability`` -> ``1.0`` (as float); ``significance`` -> ``1``,
    ``pinned`` -> ``0``, ``auto_pinned`` -> ``0`` (as ints).
    """
    payload = e.payload
    # (column, value) pairs kept side by side so the column list and the
    # bound parameters cannot drift apart. Values are evaluated in column
    # order, so a missing required key raises KeyError for the first one.
    fields = (
        ("owner_id", payload["owner_id"]),
        ("chat_id", payload["chat_id"]),
        ("scene_id", payload.get("scene_id")),
        ("pov_summary", payload["pov_summary"]),
        ("witness_you", int(payload["witness_you"])),
        ("witness_host", int(payload["witness_host"])),
        ("witness_guest", int(payload["witness_guest"])),
        ("chat_clock_at", payload.get("chat_clock_at")),
        ("source", payload.get("source", "direct")),
        ("reliability", float(payload.get("reliability", 1.0))),
        ("significance", int(payload.get("significance", 1))),
        ("pinned", int(payload.get("pinned", 0))),
        ("auto_pinned", int(payload.get("auto_pinned", 0))),
    )
    column_list = ", ".join(name for name, _ in fields)
    placeholders = ", ".join("?" for _ in fields)
    conn.execute(
        f"INSERT INTO memories ({column_list}) VALUES ({placeholders})",
        tuple(value for _, value in fields),
    )
@on("memory_significance_set")
def _apply_memory_significance_set(conn: Connection, e: Event) -> None:
    """Overwrite the ``significance`` of an existing memory (T22).

    Emitted by the async significance worker once it has scored the turn.
    Both ``significance`` and ``memory_id`` are coerced to int before the
    UPDATE; an unknown ``memory_id`` simply updates no rows.
    """
    payload = e.payload
    new_score = int(payload["significance"])
    target_id = int(payload["memory_id"])
    conn.execute(
        "UPDATE memories SET significance = ? WHERE id = ?",
        (new_score, target_id),
    )
@on("memory_pin_changed")
def _apply_memory_pin_changed(conn: Connection, e: Event) -> None:
    """Set a memory's ``pinned`` and ``auto_pinned`` flags (T22, §8.5).

    Used both for auto-pinning a pivotal turn and for evicting the oldest
    auto-pin when the per-owner soft cap is exceeded. Manual pins go through
    the same handler; the ``auto_pinned`` flag distinguishes them so the
    eviction query can leave manual pins alone. All three payload values
    (``pinned``, ``auto_pinned``, ``memory_id``) are coerced to int.
    """
    payload = e.payload
    pinned_flag = int(payload["pinned"])
    auto_flag = int(payload["auto_pinned"])
    target_id = int(payload["memory_id"])
    conn.execute(
        "UPDATE memories SET pinned = ?, auto_pinned = ? WHERE id = ?",
        (pinned_flag, auto_flag, target_id),
    )
def get_memory(conn: Connection, memory_id: int) -> dict | None:
    """Fetch one memory by primary key.

    Returns the row as a ``{column_name: value}`` dict (via
    ``_row_to_dict``), or ``None`` when no row has that ``id``.
    """
    found = conn.execute(
        "SELECT * FROM memories WHERE id = ?", (memory_id,)
    ).fetchone()
    return _row_to_dict(conn, found) if found else None
def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
    """Return every pinned memory belonging to ``owner_id``, newest first.

    Rows are ordered by ``created_at`` descending, with ``id`` descending as
    the tie-breaker. Each row becomes a ``{column_name: value}`` dict; the
    column names are read once from ``PRAGMA table_info(memories)``.
    """
    pinned_rows = conn.execute(
        "SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
        "ORDER BY created_at DESC, id DESC",
        (owner_id,),
    ).fetchall()
    table_info = conn.execute("PRAGMA table_info(memories)").fetchall()
    names = [info[1] for info in table_info]
    return [
        {name: value for name, value in zip(names, values)}
        for values in pinned_rows
    ]
# Composite-score weights used by ``search_memories`` (T23, §8 retrieval).
# FTS5 BM25 ``rank`` is *more negative* for better matches, so subtracting a
# positive boost makes a candidate's score more negative, i.e. moves it
# earlier in an ascending sort. Hardcoded for v1 — tunable in a later pass.
_SIGNIFICANCE_WEIGHT = 0.3
_RECENCY_WEIGHT = 0.5
# T57 (Phase 3, §11.1): significance multiplier applied to the SQL ORDER BY in
# ``search_memories`` so that the FTS over-fetch already prefers
# higher-significance rows for tied / near-tied BM25 ranks. Module-level so it
# can be tuned without a code change. BM25 ``rank`` is lower-is-better, so the
# bias is *subtracted* from rank in the ASC ordering — equivalent to multiplying
# a higher-is-better score by a positive constant per the spec wording.
SIGNIFICANCE_RANK_BIAS = 0.5
# Runs of word characters, used to rebuild a free-text query that FTS5
# refused to parse.
_FTS5_TERM_RE = re.compile(r"\w+")


def _fts5_quote_terms(query: str) -> str:
    """Rebuild *query* as space-separated, double-quoted FTS5 strings.

    Each run of word characters becomes its own quoted string, so punctuation
    that FTS5 rejects outside quotes (``?``, ``!``, ``'``, ``:`` …) is dropped
    and words like ``AND``/``OR``/``NOT`` are matched as plain text rather
    than operators. Adjacent strings combine with FTS5's implicit AND.
    Returns ``""`` when *query* contains no word characters.
    """
    return " ".join(f'"{term}"' for term in _FTS5_TERM_RE.findall(query))


def search_memories(
    conn: Connection,
    owner_id: str,
    witness_role: str,
    query: str,
    k: int = 4,
) -> list[dict]:
    """FTS5 search over pov_summary, scoped by owner and witness role.

    ``witness_role`` must be one of {"you", "host", "guest"}; only rows whose
    ``witness_<role>`` flag is 1 are considered. Returns up to ``k`` rows
    ranked by a composite score that combines the FTS5 BM25 rank with two
    boosts (§8 retrieval rules):

    * **significance boost** — ``0.3 * significance`` (0..3 per §11.1).
    * **recency boost** — ``0.5 * (id / max_id)``, using the row id as a
      monotonic recency proxy, normalised by the owner's largest id. Newer
      memories therefore tilt above older ones when the BM25 rank and
      significance are otherwise tied.

    BM25 returns negative scores (lower = better). Both boosts are subtracted
    so that stronger candidates yield smaller composite scores; the result is
    sorted ascending and truncated to ``k``. The unmodified ``fts_rank`` and a
    debug-friendly ``composite_score`` are kept on each returned dict.

    ``query`` is first handed to FTS5 unchanged, so FTS5 query syntax still
    works. If SQLite rejects it (typically free text containing punctuation
    such as ``?`` or ``'``), the search is retried once with the query reduced
    to quoted word terms (see ``_fts5_quote_terms``).

    Returns ``[]`` for a blank query or when ``k <= 0``.

    Raises:
        ValueError: ``witness_role`` is not a recognised role.
        sqlite3.OperationalError: the query failed and the quoted-term
            fallback either does not apply or failed as well.
    """
    if witness_role not in _VALID_WITNESS_ROLES:
        raise ValueError(
            f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
            f"got {witness_role!r}"
        )
    # Guard k explicitly: a negative k would otherwise reach ``enriched[:k]``
    # and slice from the end, returning nearly every candidate.
    if k <= 0 or not query.strip():
        return []
    witness_col = f"witness_{witness_role}"
    cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
    select_list = ", ".join(f"m.{c}" for c in cols)
    # Over-fetch from FTS so the Python-side re-rank has room to reorder
    # results that BM25 alone would have demoted past the top-k boundary.
    over_fetch = max(k * 4, 20)
    sql = (
        f"SELECT {select_list}, memories_fts.rank AS fts_rank "
        "FROM memories_fts "
        "JOIN memories m ON m.id = memories_fts.rowid "
        f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
        "AND memories_fts MATCH ? "
        # T57: significance multiplier biases the FTS over-fetch order. BM25
        # ``rank`` is lower-is-better, so subtracting ``significance * BIAS``
        # surfaces higher-significance rows above lower-significance rows with
        # equal/near-equal match strength. COALESCE keeps a NULL significance
        # from turning the whole sort key NULL (which SQLite sorts first in
        # ASC order), matching the ``or 0`` treatment in the Python re-rank.
        "ORDER BY (memories_fts.rank - COALESCE(m.significance, 0) * ?) ASC "
        "LIMIT ?"
    )
    try:
        rows = conn.execute(
            sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch)
        ).fetchall()
    except sqlite3.OperationalError:
        # Free text often is not valid FTS5 syntax. Retry once with quoted
        # word terms; if that cannot change the query, surface the original
        # error instead of masking it.
        fallback = _fts5_quote_terms(query)
        if not fallback or fallback == query:
            raise
        rows = conn.execute(
            sql, (owner_id, fallback, SIGNIFICANCE_RANK_BIAS, over_fetch)
        ).fetchall()
    if not rows:
        return []
    # Recency normalises against the current max id for this owner so the
    # boost magnitude is bounded regardless of dataset size.
    max_id_row = conn.execute(
        "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
    ).fetchone()
    max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1
    result_cols = cols + ["fts_rank"]
    enriched: list[dict] = []
    for row in rows:
        d = dict(zip(result_cols, row))
        fts_rank = d.get("fts_rank") or 0.0
        sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
        recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
        d["composite_score"] = fts_rank - sig_boost - recency_boost
        enriched.append(d)
    enriched.sort(key=lambda x: x["composite_score"])
    return enriched[:k]