180 lines
6.9 KiB
Python
180 lines
6.9 KiB
Python
from __future__ import annotations
|
||
from sqlite3 import Connection
|
||
from chat.eventlog.projector import on
|
||
from chat.eventlog.log import Event
|
||
|
||
# Witness roles accepted by ``search_memories``; each one selects a
# ``witness_<role>`` flag column on ``memories``. Because the role is
# interpolated into SQL as a column name, the allow-list is a frozenset so it
# cannot be mutated at runtime.
_VALID_WITNESS_ROLES = frozenset({"you", "host", "guest"})
|
||
|
||
|
||
def _row_to_dict(conn: Connection, row: tuple) -> dict:
    """Map a positional ``memories`` row onto its column names.

    Column names are looked up with ``PRAGMA table_info(memories)`` on every
    call, so ``row`` is expected to hold values in table-declaration order
    (as produced by ``SELECT * FROM memories``). Extra or missing trailing
    values are dropped by ``zip``.
    """
    table_info = conn.execute("PRAGMA table_info(memories)").fetchall()
    names = [info[1] for info in table_info]
    return {name: value for name, value in zip(names, row)}
|
||
|
||
|
||
@on("memory_written")
def _apply_memory_written(conn: Connection, e: Event) -> None:
    """Project a ``memory_written`` event into a new ``memories`` row.

    Required payload keys: ``owner_id``, ``chat_id``, ``pov_summary`` and the
    three ``witness_you`` / ``witness_host`` / ``witness_guest`` flags (coerced
    to ``int``). Optional keys fall back to defaults: ``scene_id`` and
    ``chat_clock_at`` -> ``None``, ``source`` -> ``"direct"``,
    ``reliability`` -> ``1.0``, ``significance`` -> ``1``, ``pinned`` and
    ``auto_pinned`` -> ``0``.
    """
    payload = e.payload
    # Values are listed in the same order as the INSERT column list below;
    # the witness flags are expanded in place to keep evaluation order.
    values = (
        payload["owner_id"],
        payload["chat_id"],
        payload.get("scene_id"),
        payload["pov_summary"],
        *(
            int(payload[flag])
            for flag in ("witness_you", "witness_host", "witness_guest")
        ),
        payload.get("chat_clock_at"),
        payload.get("source", "direct"),
        float(payload.get("reliability", 1.0)),
        int(payload.get("significance", 1)),
        int(payload.get("pinned", 0)),
        int(payload.get("auto_pinned", 0)),
    )
    conn.execute(
        "INSERT INTO memories ("
        "owner_id, chat_id, scene_id, pov_summary, "
        "witness_you, witness_host, witness_guest, "
        "chat_clock_at, source, reliability, significance, pinned, auto_pinned"
        ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
        values,
    )
|
||
|
||
|
||
@on("memory_significance_set")
def _apply_memory_significance_set(conn: Connection, e: Event) -> None:
    """Overwrite the significance score of an existing memory (T22).

    The async significance worker emits this event once it has scored the
    turn. Both payload keys, ``significance`` and ``memory_id``, are required
    and coerced to ``int``; an id with no matching row updates nothing.
    """
    payload = e.payload
    new_score = int(payload["significance"])
    target_id = int(payload["memory_id"])
    conn.execute(
        "UPDATE memories SET significance = ? WHERE id = ?",
        (new_score, target_id),
    )
|
||
|
||
|
||
@on("memory_pin_changed")
def _apply_memory_pin_changed(conn: Connection, e: Event) -> None:
    """Set a memory's ``pinned`` and ``auto_pinned`` flags (T22, §8.5).

    A single handler covers three cases: auto-pinning a pivotal turn, evicting
    the oldest auto-pin once the per-owner soft cap is exceeded, and manual
    pins. The ``auto_pinned`` flag is what lets the eviction query tell
    automatic pins apart from manual ones and leave the latter alone.

    Payload keys ``pinned``, ``auto_pinned`` and ``memory_id`` are all
    required and coerced to ``int``.
    """
    payload = e.payload
    # Ordered to match the three placeholders: pinned, auto_pinned, id.
    params = tuple(
        int(payload[key]) for key in ("pinned", "auto_pinned", "memory_id")
    )
    conn.execute(
        "UPDATE memories SET pinned = ?, auto_pinned = ? WHERE id = ?",
        params,
    )
|
||
|
||
|
||
def get_memory(conn: Connection, memory_id: int) -> dict | None:
    """Fetch a single memory by primary key.

    Returns the row as a ``{column: value}`` dict, or ``None`` when no row in
    ``memories`` has ``id == memory_id``.
    """
    found = conn.execute(
        "SELECT * FROM memories WHERE id = ?", (memory_id,)
    ).fetchone()
    return _row_to_dict(conn, found) if found else None
|
||
|
||
|
||
def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
    """Return every pinned memory (manual or automatic) for ``owner_id``.

    Rows are ordered newest first: ``created_at`` descending, with ``id``
    descending as the tie-breaker. Each row is returned as a
    ``{column: value}`` dict; an owner with no pinned memories yields ``[]``.
    """
    cur = conn.execute(
        "SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
        "ORDER BY created_at DESC, id DESC",
        (owner_id,),
    )
    # Take the names from this cursor's own result description so keys always
    # line up with the ``SELECT *`` tuples. A separate ``PRAGMA table_info``
    # lookup omits generated columns (which ``SELECT *`` does return), which
    # would shift every later key onto the wrong value.
    names = [desc[0] for desc in cur.description]
    return [dict(zip(names, row)) for row in cur.fetchall()]
|
||
|
||
|
||
# Weights for the Python-side composite re-rank in ``search_memories`` (T23,
# §8 retrieval). FTS5's BM25 ``rank`` becomes *more negative* as the match
# improves, so each positive boost is subtracted from it: a larger boost gives
# a numerically lower composite score, which places the row earlier in the
# ascending sort. Fixed values for v1; expected to be tuned in a later pass.
_SIGNIFICANCE_WEIGHT: float = 0.3
_RECENCY_WEIGHT: float = 0.5

# T57 (Phase 3, §11.1): per-point significance bias applied inside the SQL
# ``ORDER BY`` of ``search_memories``. It makes the FTS over-fetch itself
# favour higher-significance rows when BM25 ranks are tied or nearly tied.
# It is kept at module level so it can be tuned in one place. §11.1 phrases
# this as a significance multiplier on a higher-is-better score. Because BM25
# ``rank`` is lower-is-better, it is realised here by subtracting
# ``significance * SIGNIFICANCE_RANK_BIAS`` within the ascending ordering.
SIGNIFICANCE_RANK_BIAS: float = 0.5
|
||
|
||
|
||
def search_memories(
    conn: Connection,
    owner_id: str,
    witness_role: str,
    query: str,
    k: int = 4,
) -> list[dict]:
    """FTS5 search over pov_summary, scoped by owner and witness role.

    Only rows owned by ``owner_id`` whose ``witness_<witness_role>`` flag is 1
    are considered. Returns up to ``k`` rows ranked by a composite score that
    combines the FTS5 BM25 rank with two boosts (§8 retrieval rules):

    * **significance boost** — ``_SIGNIFICANCE_WEIGHT * significance``
      (currently ``0.3 * significance``; significance is 0..3 per §11.1).
    * **recency boost** — ``_RECENCY_WEIGHT * (id / max_id)`` (currently
      ``0.5 * (id / max_id)``). It uses the row id as a monotonic recency proxy
      and ``max_id`` as the owner's largest memory id. Newer memories therefore
      tilt above older ones when BM25 rank and significance are otherwise tied.

    BM25 returns negative scores (lower = better). Both boosts are subtracted
    so that stronger candidates yield smaller composite scores; the result is
    sorted ascending and truncated to ``k``.

    Args:
        conn: Connection holding the ``memories`` table and the
            ``memories_fts`` full-text index joined on ``rowid``.
        owner_id: Owner whose memories are searched.
        witness_role: One of ``"you"``, ``"host"``, ``"guest"``.
        query: Passed verbatim as the FTS5 ``MATCH`` expression. A blank or
            whitespace-only query returns ``[]`` without touching the database.
        k: Maximum number of results. ``k <= 0`` returns ``[]``.

    Returns:
        Up to ``k`` dicts keyed by ``memories`` column name. Each dict also
        carries the unmodified ``fts_rank`` and a debug-friendly
        ``composite_score``.

    Raises:
        ValueError: If ``witness_role`` is not a recognised role.
    """
    if witness_role not in _VALID_WITNESS_ROLES:
        raise ValueError(
            f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
            f"got {witness_role!r}"
        )
    # Guard k explicitly: a negative k would otherwise reach ``enriched[:k]``
    # and slice from the END, returning nearly every candidate instead of none.
    if k <= 0 or not query.strip():
        return []

    # Safe to interpolate: witness_role was checked against the allow-list.
    witness_col = f"witness_{witness_role}"
    cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
    select_list = ", ".join(f"m.{c}" for c in cols)
    # Over-fetch from FTS so the Python-side re-rank has room to reorder
    # results that BM25 alone would have demoted past the top-k boundary.
    over_fetch = max(k * 4, 20)
    # NOTE(review): ``query`` goes straight into FTS5 MATCH, where unquoted
    # punctuation is a syntax error (sqlite3.OperationalError). If callers
    # pass raw chat text, it may need quoting first — confirm at call sites.
    sql = (
        f"SELECT {select_list}, memories_fts.rank AS fts_rank "
        "FROM memories_fts "
        "JOIN memories m ON m.id = memories_fts.rowid "
        f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
        "AND memories_fts MATCH ? "
        # T57: the significance bias orders the FTS over-fetch. BM25 ``rank``
        # is lower-is-better, so subtracting ``significance * BIAS`` surfaces
        # higher-significance rows above lower-significance rows of equal or
        # near-equal match strength. COALESCE treats a NULL significance as 0,
        # matching the Python re-rank below. Otherwise the whole expression
        # would be NULL, and SQLite sorts NULL first in ASC order.
        "ORDER BY (memories_fts.rank - COALESCE(m.significance, 0) * ?) ASC "
        "LIMIT ?"
    )
    cur = conn.execute(sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch))
    rows = cur.fetchall()
    if not rows:
        return []

    # Recency normalises against the current max id for this owner so the
    # boost magnitude is bounded regardless of dataset size.
    max_id_row = conn.execute(
        "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
    ).fetchone()
    max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1

    result_cols = cols + ["fts_rank"]
    enriched: list[dict] = []
    for row in rows:
        d = dict(zip(result_cols, row))
        fts_rank = d.get("fts_rank") or 0.0
        sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
        recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
        d["composite_score"] = fts_rank - sig_boost - recency_boost
        enriched.append(d)

    enriched.sort(key=lambda x: x["composite_score"])
    return enriched[:k]
|