Files
chat/chat/state/memory.py
T

394 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from sqlite3 import Connection
from chat.eventlog.projector import on
from chat.eventlog.log import Event
# Witness roles accepted by ``search_memories``. Each role corresponds to a
# ``witness_<role>`` flag column on the ``memories`` table.
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
def _row_to_dict(conn: Connection, row: tuple) -> dict:
    """Map a raw ``memories`` row onto a ``{column_name: value}`` dict.

    Column names are read from ``PRAGMA table_info(memories)`` on every call
    and paired positionally with ``row``, so ``row`` is expected to list its
    values in table-column order (e.g. the result of a ``SELECT *``). If
    ``row`` is shorter than the column list, the extra columns are dropped.
    """
    table_info = conn.execute("PRAGMA table_info(memories)").fetchall()
    # Index 1 of each ``table_info`` record is the column name.
    column_names = (info[1] for info in table_info)
    return dict(zip(column_names, row))
@on("memory_written")
def _apply_memory_written(conn: Connection, e: Event) -> None:
    """Project a ``memory_written`` event into a new ``memories`` row.

    ``owner_id``, ``chat_id``, ``pov_summary`` and the three ``witness_*``
    flags are read with ``[]`` and therefore must be present in the payload.
    The remaining fields are optional: ``scene_id`` and ``chat_clock_at``
    fall back to ``None``, ``source`` to ``"direct"``, ``reliability`` to
    ``1.0``, ``significance`` to ``1`` and both pin flags to ``0``.
    """
    payload = e.payload
    insert_sql = (
        "INSERT INTO memories ("
        "owner_id, chat_id, scene_id, pov_summary, "
        "witness_you, witness_host, witness_guest, "
        "chat_clock_at, source, reliability, significance, pinned, auto_pinned"
        ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    # Values are listed in exactly the column order of ``insert_sql``.
    # Flags and scores are coerced so the stored values are plain ints/floats.
    values = (
        payload["owner_id"],
        payload["chat_id"],
        payload.get("scene_id"),
        payload["pov_summary"],
        int(payload["witness_you"]),
        int(payload["witness_host"]),
        int(payload["witness_guest"]),
        payload.get("chat_clock_at"),
        payload.get("source", "direct"),
        float(payload.get("reliability", 1.0)),
        int(payload.get("significance", 1)),
        int(payload.get("pinned", 0)),
        int(payload.get("auto_pinned", 0)),
    )
    conn.execute(insert_sql, values)
@on("memory_significance_set")
def _apply_memory_significance_set(conn: Connection, e: Event) -> None:
    """Overwrite the significance score of an existing memory (T22).

    Emitted by the async significance worker after it scores the turn. The
    payload must carry ``significance`` and ``memory_id``; both are coerced
    to ``int`` before being bound to the UPDATE.
    """
    payload = e.payload
    new_significance = int(payload["significance"])
    target_id = int(payload["memory_id"])
    conn.execute(
        "UPDATE memories SET significance = ? WHERE id = ?",
        (new_significance, target_id),
    )
@on("memory_pin_changed")
def _apply_memory_pin_changed(conn: Connection, e: Event) -> None:
    """Set the ``pinned`` / ``auto_pinned`` flags of one memory (T22, §8.5).

    The same handler serves auto-pinning a pivotal turn, evicting the oldest
    auto-pin when the per-owner soft cap is exceeded, and manual pins. The
    ``auto_pinned`` flag is what tells those cases apart, so the eviction
    query can leave manual pins alone.

    The payload must carry ``pinned``, ``auto_pinned`` and ``memory_id``;
    all three are coerced to ``int``.
    """
    payload = e.payload
    pinned_flag = int(payload["pinned"])
    auto_flag = int(payload["auto_pinned"])
    target_id = int(payload["memory_id"])
    conn.execute(
        "UPDATE memories SET pinned = ?, auto_pinned = ? WHERE id = ?",
        (pinned_flag, auto_flag, target_id),
    )
def get_memory(conn: Connection, memory_id: int) -> dict | None:
    """Fetch a single memory by primary key.

    Returns the row as a ``{column_name: value}`` dict, or ``None`` when no
    row in ``memories`` has ``id == memory_id``.
    """
    cursor = conn.execute("SELECT * FROM memories WHERE id = ?", (memory_id,))
    found = cursor.fetchone()
    return _row_to_dict(conn, found) if found else None
def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
    """List every pinned memory belonging to ``owner_id``.

    Rows are ordered by ``created_at`` descending, with ``id`` descending as
    the tie-breaker, and each one is returned as a ``{column_name: value}``
    dict. An owner with no pinned memories yields an empty list.
    """
    pinned_rows = conn.execute(
        "SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
        "ORDER BY created_at DESC, id DESC",
        (owner_id,),
    ).fetchall()
    # Resolve the column names once and reuse them for every row.
    table_info = conn.execute("PRAGMA table_info(memories)").fetchall()
    column_names = [info[1] for info in table_info]
    return [dict(zip(column_names, pinned_row)) for pinned_row in pinned_rows]
# Composite-score weights used by ``search_memories`` (T23, §8 retrieval).
# FTS5 BM25 ``rank`` is *more negative* for better matches, so subtracting a
# positive boost from it drives stronger candidates further down (i.e. earlier
# in an ascending sort). Hardcoded for v1 — tunable in a later pass.
# ``_SIGNIFICANCE_WEIGHT`` scales a row's ``significance`` value and
# ``_RECENCY_WEIGHT`` scales ``id / max_id`` (the row id relative to the
# owner's largest memory id) in the Python-side re-rank.
_SIGNIFICANCE_WEIGHT = 0.3
_RECENCY_WEIGHT = 0.5
# T57 (Phase 3, §11.1): significance multiplier applied to the SQL ORDER BY in
# ``search_memories`` so that the FTS over-fetch already prefers
# higher-significance rows for tied / near-tied BM25 ranks. Module-level so it
# can be tuned without a code change. BM25 ``rank`` is lower-is-better, so the
# bias is *subtracted* from rank in the ASC ordering — equivalent to multiplying
# a higher-is-better score by a positive constant per the spec wording.
# It is bound as a SQL parameter, not interpolated into the query text.
SIGNIFICANCE_RANK_BIAS = 0.5
# T96 (Phase 4): reciprocal-rank-fusion constant used when ``search_memories``
# is given a ``query_vector`` and must merge FTS + vector candidate lists. The
# value 60 is the canonical RRF constant from Cormack et al. ("Reciprocal Rank
# Fusion outperforms Condorcet and Individual Rank Learning Methods", SIGIR
# 2009): large enough to dampen the head of either ranking so that a strong
# top-1 in ranking A doesn't crowd out a moderate top-3 in ranking B, but
# small enough that the position-1/position-2 gap still matters.
# It is added to each candidate's 0-indexed position in the RRF denominator.
RRF_CONST = 60
def _max_event_id(conn: Connection, owner_id: str) -> int:
    """Return the largest ``memories.id`` owned by ``owner_id``, or 1 if none.

    The re-rank helpers use this value as the denominator of the recency
    boost: the row id serves as a monotonic recency proxy (newer memories
    have larger ids), so dividing a row's id by the per-owner maximum keeps
    the boost within [0, 1] no matter how many memories the owner has.

    The fallback is 1 rather than 0 so callers can divide by the result
    without a guard, which also covers an owner with no rows at all.
    """
    (largest_id,) = conn.execute(
        "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
    ).fetchone() or (None,)
    # ``MAX`` over zero rows is NULL (``None``); treat that like "no id".
    return largest_id or 1
def search_memories(
    conn: Connection,
    owner_id: str,
    witness_role: str,
    query: str,
    k: int = 4,
    *,
    query_vector: list[float] | None = None,
) -> list[dict]:
    """Search an owner's memories by FTS5, optionally fused with vector hits.

    The search runs FTS5 ``MATCH`` over ``pov_summary`` and only considers
    rows that belong to ``owner_id`` and whose ``witness_<witness_role>``
    flag is 1.

    **FTS-only path** (``query_vector is None``). Results are ranked by a
    composite score that combines the FTS5 BM25 rank with two boosts
    (§8 retrieval rules):

    * **significance boost** — ``_SIGNIFICANCE_WEIGHT * significance``
      (significance is 0..3 per §11.1).
    * **recency boost** — ``_RECENCY_WEIGHT * (id / max_id)``, using the row
      id as a monotonic recency proxy, so newer memories tilt above older
      ones when BM25 rank and significance are otherwise tied.

    BM25 scores are negative (lower = better). Both boosts are subtracted so
    stronger candidates get smaller composite scores; the result is sorted
    ascending and truncated to ``k``.

    Significance is therefore applied twice, independently:

    * **SQL-side** — ``ORDER BY (rank - significance *
      SIGNIFICANCE_RANK_BIAS)`` decides which rows make it into the
      over-fetched candidate set (T57, §11.1).
    * **Python-side** — the composite re-rank described above.

    **Fused path** (T96, ``query_vector`` given). FTS and vector candidates
    are merged with reciprocal-rank fusion::

        fusion_score = 1/(RRF_CONST + fts_rank) + 1/(RRF_CONST + vec_rank)

    where the ranks are 0-indexed positions in each candidate list. A memory
    present in only one list simply lacks the other term. The significance
    and recency boosts are then applied on top of the fused score.

    **Over-fetch.** The FTS query fetches ``max(k * 4, 20)`` rows on the
    FTS-only path and ``max(k * 2, 20)`` rows on the fused path, so the
    Python-side re-rank has room to promote rows that BM25 alone would have
    left just outside the top ``k``.

    **Row-shape contract (T104).** Every returned dict carries the
    ``memories`` columns plus ``fts_rank`` and ``composite_score``.
    ``fts_rank`` is the BM25 score (a negative float) for FTS hits and
    ``None`` for vector-only hits surfaced by the fused path, so consumers
    must not assume it is numeric. ``composite_score`` is always a float.

    Args:
        conn: Open SQLite connection holding ``memories`` and
            ``memories_fts``.
        owner_id: Owner whose memories are searched.
        witness_role: One of ``"you"``, ``"host"``, ``"guest"``.
        query: FTS5 match expression; passed to ``MATCH`` unchanged.
        k: Maximum number of rows to return. A value of zero or less
            returns an empty list.
        query_vector: Optional query embedding; enables the fused path.

    Returns:
        Up to ``k`` row dicts, best match first.

    Raises:
        ValueError: if ``witness_role`` is not a valid role.
    """
    if witness_role not in _VALID_WITNESS_ROLES:
        raise ValueError(
            f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
            f"got {witness_role!r}"
        )
    # A non-positive k can never yield results. Return early so a negative k
    # does not reach the final ``[:k]`` slice, where it would mean "all but
    # the last |k| candidates" instead of "nothing".
    if k <= 0:
        return []
    if not query.strip():
        return []
    # Safe to interpolate: ``witness_role`` was checked against the fixed
    # whitelist above, and the column names come from the schema itself.
    witness_col = f"witness_{witness_role}"
    cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
    select_list = ", ".join(f"m.{c}" for c in cols)
    # Over-fetch from FTS so the Python-side re-rank has room to reorder
    # results that BM25 alone would have demoted past the top-k boundary.
    # When fusing with a vector ranking we still over-fetch from each channel
    # so memories that are weak in FTS but strong in vector — and vice
    # versa — make it into the merge pool.
    over_fetch = max(k * 2, 20) if query_vector is not None else max(k * 4, 20)
    sql = (
        f"SELECT {select_list}, memories_fts.rank AS fts_rank "
        "FROM memories_fts "
        "JOIN memories m ON m.id = memories_fts.rowid "
        f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
        "AND memories_fts MATCH ? "
        # T57: significance multiplier biases the FTS over-fetch order. BM25
        # ``rank`` is lower-is-better, so subtracting ``significance * BIAS``
        # surfaces higher-significance rows above lower-significance rows with
        # equal/near-equal match strength. Equivalent to ``score × constant``
        # per §11.1 once the rank is inverted to a higher-is-better score.
        "ORDER BY (memories_fts.rank - m.significance * ?) ASC "
        "LIMIT ?"
    )
    cur = conn.execute(sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch))
    rows = cur.fetchall()
    # FTS-only path: preserve pre-T96 behaviour exactly.
    if query_vector is None:
        if not rows:
            return []
        return _composite_rerank(conn, cols, rows, owner_id, k)
    # Fused path: combine FTS candidates with vector candidates via RRF.
    return _rrf_fuse_and_rerank(
        conn,
        cols=cols,
        fts_rows=rows,
        owner_id=owner_id,
        witness_role=witness_role,
        query_vector=query_vector,
        k=k,
    )
def _composite_rerank(
    conn: Connection,
    cols: list[str],
    rows: list[tuple],
    owner_id: str,
    k: int,
) -> list[dict]:
    """Re-rank FTS rows by BM25 rank minus significance and recency boosts.

    Each row in ``rows`` must hold the values for ``cols`` followed by one
    trailing ``fts_rank`` value. Every row becomes a dict with an added
    ``composite_score``::

        composite_score = fts_rank - _SIGNIFICANCE_WEIGHT * significance
                                   - _RECENCY_WEIGHT * (id / max_id)

    A missing or ``None`` ``fts_rank``, ``significance`` or ``id`` counts as
    zero. Rows are returned sorted ascending by ``composite_score``
    (smaller is better), truncated to the first ``k``.
    """
    newest_id = _max_event_id(conn, owner_id)
    keys = cols + ["fts_rank"]

    def _scored(raw: tuple) -> dict:
        # Build the row dict and attach its composite score.
        record = dict(zip(keys, raw))
        bm25 = record.get("fts_rank") or 0.0
        significance_term = _SIGNIFICANCE_WEIGHT * (record.get("significance") or 0)
        recency_term = _RECENCY_WEIGHT * ((record.get("id") or 0) / newest_id)
        record["composite_score"] = bm25 - significance_term - recency_term
        return record

    ranked = sorted(
        map(_scored, rows), key=lambda record: record["composite_score"]
    )
    return ranked[:k]
def _rrf_fuse_and_rerank(
    conn: Connection,
    *,
    cols: list[str],
    fts_rows: list[tuple],
    owner_id: str,
    witness_role: str,
    query_vector: list[float],
    k: int,
) -> list[dict]:
    """Fuse FTS and vector candidates with RRF, then apply the boosts.

    Reciprocal-rank fusion (Cormack et al. 2009)::

        fusion_score = sum over rankings r of 1 / (RRF_CONST + rank_r)

    ``rank_r`` is the memory's 0-indexed position in ranking r. A memory that
    is absent from a ranking contributes nothing for that ranking, so a
    single-channel hit is only at a disadvantage relative to memories found
    by both channels.

    The sort key is::

        composite = -fusion - sig_boost - recency_boost

    sorted ascending (smaller is better), the same convention as the
    FTS-only path. The fusion score is negated so that a larger fusion score
    produces a smaller composite.

    Each returned dict has the ``cols`` values, ``fts_rank`` (``None`` for
    vector-only hits), ``fusion_score`` and ``composite_score``. If the
    vector channel returns nothing, the result is exactly what
    ``_composite_rerank`` gives for ``fts_rows``.

    NOTE(review): with ``RRF_CONST = 60`` the fusion score is at most about
    0.033, while the significance and recency boosts can each reach several
    tenths. The boosts therefore look able to outweigh the fusion term
    rather than merely break ties — confirm this is the intended balance.
    """
    # Imported lazily to avoid a hard module-level cycle: vector_search reads
    # from chat.state.embeddings, which is itself a sibling of this module.
    from chat.services.vector_search import vector_search

    id_pos = cols.index("id")

    # Index the FTS candidates by memory id: list position and full row.
    fts_position: dict[int, int] = {
        fts_row[id_pos]: position for position, fts_row in enumerate(fts_rows)
    }
    rows_by_id: dict[int, tuple] = {fts_row[id_pos]: fts_row for fts_row in fts_rows}

    # The vector channel is over-fetched with the same floor as the FTS
    # channel so each one gets a fair shot at surfacing its best candidates.
    vec_position: dict[int, int] = {
        hit["memory_id"]: position
        for position, hit in enumerate(
            vector_search(
                conn,
                owner_id=owner_id,
                witness_role=witness_role,
                query_vector=query_vector,
                k=max(k * 2, 20),
            )
        )
    }

    # No vector hits (e.g. nothing indexed yet): collapse to the FTS-only
    # behaviour, or to an empty result when FTS found nothing either.
    if not vec_position:
        if not rows_by_id:
            return []
        return _composite_rerank(conn, cols, fts_rows, owner_id, k)

    # Vector-only hits have no row yet; load them all in one round-trip and
    # pad each with ``None`` in the trailing ``fts_rank`` slot so every row
    # has the same shape as an FTS row.
    # NOTE(review): this lookup filters by id only; it presumably relies on
    # vector_search having applied the owner / witness scoping — confirm.
    vector_only_ids = [mid for mid in vec_position if mid not in rows_by_id]
    if vector_only_ids:
        select_list = ", ".join(cols)
        placeholders = ",".join("?" * len(vector_only_ids))
        fetched_rows = conn.execute(
            f"SELECT {select_list} FROM memories WHERE id IN ({placeholders})",
            vector_only_ids,
        ).fetchall()
        for fetched_row in fetched_rows:
            rows_by_id[fetched_row[id_pos]] = tuple(fetched_row) + (None,)

    def _rrf_term(positions: dict[int, int], mid: int) -> float:
        # One ranking's reciprocal-rank contribution; 0.0 when absent.
        if mid in positions:
            return 1.0 / (RRF_CONST + positions[mid])
        return 0.0

    candidate_ids = set(fts_position) | set(vec_position)
    newest_id = _max_event_id(conn, owner_id)
    keys = cols + ["fts_rank"]
    scored: list[dict] = []
    for mid in candidate_ids:
        if mid not in rows_by_id:
            # Defensive: a vector hit without a memory row would be a logic
            # bug upstream, so skip it rather than crash the whole search.
            continue
        record = dict(zip(keys, rows_by_id[mid]))
        significance_term = _SIGNIFICANCE_WEIGHT * (record.get("significance") or 0)
        recency_term = _RECENCY_WEIGHT * ((record.get("id") or 0) / newest_id)
        fusion = _rrf_term(fts_position, mid) + _rrf_term(vec_position, mid)
        record["fusion_score"] = fusion
        record["composite_score"] = -fusion - significance_term - recency_term
        scored.append(record)

    return sorted(scored, key=lambda record: record["composite_score"])[:k]