394 lines
16 KiB
Python
394 lines
16 KiB
Python
from __future__ import annotations
|
||
from sqlite3 import Connection
|
||
from chat.eventlog.projector import on
|
||
from chat.eventlog.log import Event
|
||
|
||
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
|
||
|
||
|
||
def _row_to_dict(conn: Connection, row: tuple) -> dict:
|
||
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
|
||
return dict(zip(cols, row))
|
||
|
||
|
||
@on("memory_written")
def _apply_memory_written(conn: Connection, e: Event) -> None:
    """Insert a new ``memories`` row from a ``memory_written`` event.

    Required payload keys:
        ``owner_id``, ``chat_id``, ``pov_summary``, ``witness_you``,
        ``witness_host``, ``witness_guest`` (the three witness flags are
        coerced to int).

    Optional payload keys, with the fallback used when absent:
        ``scene_id`` and ``chat_clock_at`` -> NULL, ``source`` -> ``"direct"``,
        ``reliability`` -> 1.0 (coerced to float), ``significance`` -> 1,
        ``pinned`` -> 0, ``auto_pinned`` -> 0 (all coerced to int).

    A missing required key raises ``KeyError`` before anything is written.
    """
    payload = e.payload
    witness_flags = ("witness_you", "witness_host", "witness_guest")
    row_values = (
        payload["owner_id"],
        payload["chat_id"],
        payload.get("scene_id"),
        payload["pov_summary"],
        *(int(payload[flag]) for flag in witness_flags),
        payload.get("chat_clock_at"),
        payload.get("source", "direct"),
        float(payload.get("reliability", 1.0)),
        int(payload.get("significance", 1)),
        int(payload.get("pinned", 0)),
        int(payload.get("auto_pinned", 0)),
    )
    conn.execute(
        "INSERT INTO memories ("
        "owner_id, chat_id, scene_id, pov_summary, "
        "witness_you, witness_host, witness_guest, "
        "chat_clock_at, source, reliability, significance, pinned, auto_pinned"
        ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
        row_values,
    )
|
||
|
||
|
||
@on("memory_significance_set")
def _apply_memory_significance_set(conn: Connection, e: Event) -> None:
    """Overwrite an existing memory's ``significance`` score (T22).

    The async significance worker emits this event after it has scored a
    turn. Both payload keys, ``significance`` and ``memory_id``, are required
    and coerced to int. A ``memory_id`` with no matching row updates nothing.
    """
    payload = e.payload
    new_significance = int(payload["significance"])
    target_id = int(payload["memory_id"])
    conn.execute(
        "UPDATE memories SET significance = ? WHERE id = ?",
        (new_significance, target_id),
    )
|
||
|
||
|
||
@on("memory_pin_changed")
def _apply_memory_pin_changed(conn: Connection, e: Event) -> None:
    """Set a memory's ``pinned`` and ``auto_pinned`` flags (T22, §8.5).

    A single handler serves three cases:
      * auto-pinning a pivotal turn;
      * evicting the oldest auto-pin when the per-owner soft cap is exceeded;
      * manual pins.

    The ``auto_pinned`` flag is what separates manual pins from automatic
    ones, so the eviction query can leave manual pins alone. The payload keys
    ``pinned``, ``auto_pinned`` and ``memory_id`` are all required and are
    coerced to int.
    """
    payload = e.payload
    params = tuple(
        int(payload[key]) for key in ("pinned", "auto_pinned", "memory_id")
    )
    conn.execute(
        "UPDATE memories SET pinned = ?, auto_pinned = ? WHERE id = ?",
        params,
    )
|
||
|
||
|
||
def get_memory(conn: Connection, memory_id: int) -> dict | None:
    """Fetch one memory by primary key.

    Returns the row as a ``{column: value}`` dict (via ``_row_to_dict``), or
    ``None`` when no row has ``id == memory_id``.
    """
    found = conn.execute(
        "SELECT * FROM memories WHERE id = ?", (memory_id,)
    ).fetchone()
    return _row_to_dict(conn, found) if found else None
|
||
|
||
|
||
def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
    """Return every pinned memory belonging to ``owner_id``, newest first.

    Rows are ordered by ``created_at DESC``, with ``id DESC`` breaking ties.
    The filter is only ``pinned = 1``, so manual and automatic pins are both
    included. Each row comes back as a ``{column: value}`` dict. The column
    names are read once per call, not once per row.
    """
    pinned_rows = conn.execute(
        "SELECT * FROM memories WHERE owner_id = ? AND pinned = 1 "
        "ORDER BY created_at DESC, id DESC",
        (owner_id,),
    ).fetchall()
    column_names = [
        info[1]
        for info in conn.execute("PRAGMA table_info(memories)").fetchall()
    ]
    return [dict(zip(column_names, record)) for record in pinned_rows]
|
||
|
||
|
||
# Composite-score weights used by ``search_memories`` (T23, §8 retrieval).
|
||
# FTS5 BM25 ``rank`` is *more negative* for better matches, so subtracting a
|
||
# positive boost from it drives stronger candidates further down (i.e. earlier
|
||
# in an ascending sort). Hardcoded for v1 — tunable in a later pass.
|
||
_SIGNIFICANCE_WEIGHT = 0.3
|
||
_RECENCY_WEIGHT = 0.5
|
||
|
||
# T57 (Phase 3, §11.1): significance multiplier applied to the SQL ORDER BY in
|
||
# ``search_memories`` so that the FTS over-fetch already prefers
|
||
# higher-significance rows for tied / near-tied BM25 ranks. Module-level so it
|
||
# can be tuned without a code change. BM25 ``rank`` is lower-is-better, so the
|
||
# bias is *subtracted* from rank in the ASC ordering — equivalent to multiplying
|
||
# a higher-is-better score by a positive constant per the spec wording.
|
||
SIGNIFICANCE_RANK_BIAS = 0.5
|
||
|
||
# T96 (Phase 4): reciprocal-rank-fusion constant used when ``search_memories``
|
||
# is given a ``query_vector`` and must merge FTS + vector candidate lists. The
|
||
# value 60 is the canonical RRF constant from Cormack et al. ("Reciprocal Rank
|
||
# Fusion outperforms Condorcet and Individual Rank Learning Methods", SIGIR
|
||
# 2009): large enough to dampen the head of either ranking so that a strong
|
||
# top-1 in ranking A doesn't crowd out a moderate top-3 in ranking B, but
|
||
# small enough that the position-1/position-2 gap still matters.
|
||
RRF_CONST = 60
|
||
|
||
|
||
def _max_event_id(conn: Connection, owner_id: str) -> int:
|
||
"""Return the largest ``memories.id`` for ``owner_id`` (1 if none exist).
|
||
|
||
Used as the recency-boost denominator by both ``_composite_rerank`` and
|
||
``_rrf_fuse_and_rerank`` (T104). The row id is a monotonic recency proxy
|
||
— newer memories have larger ids — so dividing by the per-owner max keeps
|
||
the boost in [0, 1] regardless of how many memories the owner has.
|
||
|
||
Returns 1 (not 0) when the owner has no rows so callers can divide by
|
||
the result without a guard. The "no memories" case never actually hits
|
||
this helper because the FTS query above would have returned no rows,
|
||
but the safe default keeps the helper trivially reusable.
|
||
"""
|
||
row = conn.execute(
|
||
"SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
|
||
).fetchone()
|
||
return row[0] if row and row[0] else 1
|
||
|
||
|
||
def search_memories(
    conn: Connection,
    owner_id: str,
    witness_role: str,
    query: str,
    k: int = 4,
    *,
    query_vector: list[float] | None = None,
) -> list[dict]:
    """FTS5 search over ``pov_summary``, scoped by owner and witness role.

    ``witness_role`` must be one of {"you", "host", "guest"}. It selects the
    ``witness_<role>`` flag column a memory must have set to be eligible.

    Returns up to ``k`` rows. The result is an empty list when ``k <= 0`` or
    ``query`` is blank. Rows are ranked by a composite score that combines
    the FTS5 BM25 rank with two boosts (§8 retrieval rules):

    * **significance boost** — ``0.3 * significance`` (0..3 per §11.1).
    * **recency boost** — ``0.5 * (id / max_id)``. The row id serves as a
      monotonic recency proxy, so newer memories tilt above older ones when
      BM25 rank and significance are otherwise tied.

    BM25 scores are negative (lower = better). Both boosts are subtracted,
    so stronger candidates get smaller composite scores. Results are sorted
    ascending and truncated to ``k``. Each returned dict keeps the unmodified
    ``fts_rank`` and a debug-friendly ``composite_score``.

    Significance influences the ordering twice, independently:

    * **SQL-side** — ``ORDER BY (rank - significance * SIGNIFICANCE_RANK_BIAS)``
      biases which rows make it into the FTS over-fetch. Among tied or
      near-tied BM25 ranks, higher-significance memories win (T57, §11.1).
      A NULL ``significance`` counts as 0 here, matching the Python-side
      re-rank.
    * **Python-side** — the composite re-rank with ``_SIGNIFICANCE_WEIGHT``
      reinforces the ordering after retrieval, alongside the recency boost.

    PHASE 4 EXTENSION (T96): when ``query_vector`` is provided, FTS and vector
    hits are fused with reciprocal-rank fusion (RRF)::

        fusion_score = 1/(RRF_CONST + fts_pos) + 1/(RRF_CONST + vec_pos)

    ``fts_pos`` and ``vec_pos`` are the 0-indexed positions of the memory in
    each candidate list. A memory found by only one channel keeps only that
    channel's term. Both channels are over-fetched (``max(k * 2, 20)``), so
    a memory that is strong in only one channel can still surface. The
    Python-side significance + recency boosts are then applied on top of
    the negated fusion score.

    When ``query_vector`` is None the search is FTS-only, with the same row
    shape and ordering as before. Phase 1-3.5 callers rely on this. A blank
    ``query`` returns ``[]`` before either channel runs, even when a
    ``query_vector`` is supplied.

    **Row-shape contract (T104):** every returned dict carries an
    ``fts_rank`` key. For FTS hits this is the BM25 score (a negative float,
    lower-is-better). For *vector-only* hits from the fused path, meaning
    rows that matched the query embedding but not the FTS query,
    ``fts_rank`` is ``None``. Downstream consumers must accept ``None``
    there. ``composite_score`` is always a float.

    Raises:
        ValueError: if ``witness_role`` is not one of the valid roles.
    """
    if witness_role not in _VALID_WITNESS_ROLES:
        raise ValueError(
            f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
            f"got {witness_role!r}"
        )
    # The final ``[:k]`` slice would return all-but-|k| rows for a negative k,
    # breaking the "up to k" contract. Short-circuit before touching the DB.
    if k <= 0 or not query.strip():
        return []

    # Safe to interpolate: witness_role was validated against a fixed set.
    witness_col = f"witness_{witness_role}"
    cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
    select_list = ", ".join(f"m.{c}" for c in cols)

    # Over-fetch from FTS so the Python-side re-rank has room to reorder
    # results that BM25 alone would have demoted past the top-k boundary.
    # When fusing with a vector ranking, we still over-fetch (k*2 from each
    # channel). Memories that are weak in FTS but strong in vector, or vice
    # versa, then make it into the merge pool.
    over_fetch = max(k * 2, 20) if query_vector is not None else max(k * 4, 20)

    # NOTE(review): ``query`` goes to FTS5 ``MATCH`` verbatim, so FTS5 query
    # syntax applies and stray punctuation may raise sqlite3.OperationalError.
    # Presumably callers sanitise it; confirm.
    sql = (
        f"SELECT {select_list}, memories_fts.rank AS fts_rank "
        "FROM memories_fts "
        "JOIN memories m ON m.id = memories_fts.rowid "
        f"WHERE m.owner_id = ? AND m.{witness_col} = 1 "
        "AND memories_fts MATCH ? "
        # T57: the significance multiplier biases the FTS over-fetch order.
        # BM25 ``rank`` is lower-is-better, so subtracting
        # ``significance * BIAS`` lifts higher-significance rows above equally
        # matched ones. COALESCE treats a NULL significance as 0, matching the
        # Python-side re-rank. Without it the whole sort key would be NULL,
        # and SQLite sorts NULLs *first* in ASC order, ahead of every real
        # match.
        "ORDER BY (memories_fts.rank - COALESCE(m.significance, 0) * ?) ASC "
        "LIMIT ?"
    )
    rows = conn.execute(
        sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch)
    ).fetchall()

    if query_vector is None:
        # FTS-only path: preserve pre-T96 behaviour exactly.
        if not rows:
            return []
        return _composite_rerank(conn, cols, rows, owner_id, k)

    # Fused path: combine FTS candidates with vector candidates via RRF.
    return _rrf_fuse_and_rerank(
        conn,
        cols=cols,
        fts_rows=rows,
        owner_id=owner_id,
        witness_role=witness_role,
        query_vector=query_vector,
        k=k,
    )
|
||
|
||
|
||
def _composite_rerank(
    conn: Connection,
    cols: list[str],
    rows: list[tuple],
    owner_id: str,
    k: int,
) -> list[dict]:
    """Score FTS rows with the significance + recency boosts and keep the best ``k``.

    Each entry of ``rows`` holds the ``memories`` columns in ``cols`` order,
    followed by a trailing ``fts_rank`` value. The composite score is::

        composite = fts_rank
                    - _SIGNIFICANCE_WEIGHT * significance
                    - _RECENCY_WEIGHT * (id / max_id)

    A missing or NULL ``fts_rank``, ``significance`` or ``id`` counts as 0.
    Lower scores are better. Rows come back as dicts with ``composite_score``
    added, sorted ascending (stable) and truncated to ``k``.

    Used by the FTS-only path of ``search_memories``. The fused path also
    falls back to it when the vector channel returns nothing.
    """
    max_id = _max_event_id(conn, owner_id)
    keys = [*cols, "fts_rank"]

    def _score(record: dict) -> float:
        base = record.get("fts_rank") or 0.0
        significance_term = _SIGNIFICANCE_WEIGHT * (record.get("significance") or 0)
        recency_term = _RECENCY_WEIGHT * ((record.get("id") or 0) / max_id)
        return base - significance_term - recency_term

    scored: list[dict] = []
    for raw in rows:
        record = dict(zip(keys, raw))
        record["composite_score"] = _score(record)
        scored.append(record)

    return sorted(scored, key=lambda rec: rec["composite_score"])[:k]
|
||
|
||
|
||
def _rrf_fuse_and_rerank(
    conn: Connection,
    *,
    cols: list[str],
    fts_rows: list[tuple],
    owner_id: str,
    witness_role: str,
    query_vector: list[float],
    k: int,
) -> list[dict]:
    """Merge FTS + vector candidates via reciprocal-rank fusion, then apply
    the significance + recency boosts as a final re-rank.

    RRF formula (Cormack et al. 2009)::

        fusion_score = sum over rankings r of 1 / (RRF_CONST + pos_r)

    ``pos_r`` is the 0-indexed position of the memory in ranking r. A memory
    missing from a ranking gets no term for that ranking (it contributes 0).
    Single-channel hits therefore still score, just below comparable hits
    found by both channels. If the vector channel reports the same memory
    more than once, only its best (earliest) position counts.

    Each candidate's final score is::

        composite = -fusion_score - sig_boost - recency_boost

    The list is sorted ascending (smaller is better), the same direction as
    the FTS-only path, so the boosts are subtracted without a sign flip. The
    result is truncated to ``k``. Each dict also carries ``fusion_score``,
    and vector-only rows have ``fts_rank`` set to ``None``.

    Falls back to ``_composite_rerank`` over the FTS rows when the vector
    channel returns nothing. Returns ``[]`` when both channels are empty.
    """
    # Lazy import to avoid a hard module-level cycle: vector_search reads
    # from chat.state.embeddings, which is itself a sibling of this module.
    from chat.services.vector_search import vector_search

    # FTS rows are unique per memory (one FTS rowid per memories.id), so a
    # plain assignment is enough here.
    fts_rank_by_id: dict[int, int] = {}
    fts_row_by_id: dict[int, tuple] = {}
    id_idx = cols.index("id")
    for rank, row in enumerate(fts_rows):
        memory_id = row[id_idx]
        fts_rank_by_id[memory_id] = rank
        fts_row_by_id[memory_id] = row

    # Over-fetch the vector channel symmetrically so each channel gets a
    # fair shot at surfacing its strongest candidates.
    vec_over_fetch = max(k * 2, 20)
    vec_hits = vector_search(
        conn,
        owner_id=owner_id,
        witness_role=witness_role,
        query_vector=query_vector,
        k=vec_over_fetch,
    )
    # Keep the *first* (best) position per memory. A dict comprehension would
    # let a later duplicate hit for the same memory overwrite it with a worse
    # rank. Output is identical when vector_search never repeats an id.
    vec_rank_by_id: dict[int, int] = {}
    for rank, hit in enumerate(vec_hits):
        vec_rank_by_id.setdefault(hit["memory_id"], rank)

    # If the vector channel returned nothing (no embeddings indexed), the
    # fused path collapses to the FTS-only path rather than erroring or
    # returning an unexpected empty list.
    if not vec_rank_by_id and not fts_row_by_id:
        return []
    if not vec_rank_by_id:
        return _composite_rerank(conn, cols, fts_rows, owner_id, k)

    # Fetch full rows for vector-only hits in a single round-trip. FTS rows
    # carry a trailing ``fts_rank`` column; vector-only rows get ``None``.
    missing_ids = [mid for mid in vec_rank_by_id if mid not in fts_row_by_id]
    select_list = ", ".join(cols)
    if missing_ids:
        placeholders = ",".join("?" * len(missing_ids))
        cur = conn.execute(
            f"SELECT {select_list} FROM memories WHERE id IN ({placeholders})",
            missing_ids,
        )
        for row in cur.fetchall():
            # Pad with None for the trailing ``fts_rank`` slot so the row
            # shape matches FTS rows downstream.
            fts_row_by_id[row[id_idx]] = tuple(row) + (None,)

    # Fusion score per candidate. A ranking the memory is absent from
    # contributes no term.
    all_ids = set(fts_rank_by_id) | set(vec_rank_by_id)
    fusion_by_id: dict[int, float] = {}
    for mid in all_ids:
        score = 0.0
        if mid in fts_rank_by_id:
            score += 1.0 / (RRF_CONST + fts_rank_by_id[mid])
        if mid in vec_rank_by_id:
            score += 1.0 / (RRF_CONST + vec_rank_by_id[mid])
        fusion_by_id[mid] = score

    # NOTE(review): fusion scores are at most 2 / RRF_CONST (about 0.033 with
    # RRF_CONST = 60). The boosts reach 0.3 per significance point plus up to
    # 0.5 for recency, so they, not the fusion score, dominate this ordering.
    # If they are meant purely as tie-breakers they need rescaling. Confirm
    # the intended weighting.
    max_id = _max_event_id(conn, owner_id)
    result_cols = cols + ["fts_rank"]
    enriched: list[dict] = []
    for mid in all_ids:
        row = fts_row_by_id.get(mid)
        if row is None:
            # Defensive: a vector hit whose memory row could not be fetched.
            # Skip it rather than fail the whole search.
            continue
        d = dict(zip(result_cols, row))
        sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
        recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
        fusion = fusion_by_id[mid]
        # Negate fusion so that a larger fusion score sorts earlier in the
        # ascending order shared with the FTS-only path.
        d["fusion_score"] = fusion
        d["composite_score"] = -fusion - sig_boost - recency_boost
        enriched.append(d)

    enriched.sort(key=lambda x: x["composite_score"])
    return enriched[:k]
|