feat: combined FTS + vector retrieval ranking via RRF (T96)
This commit is contained in:
+188
-5
@@ -102,6 +102,15 @@ _RECENCY_WEIGHT = 0.5
|
||||
# a higher-is-better score by a positive constant per the spec wording.
|
||||
SIGNIFICANCE_RANK_BIAS = 0.5
|
||||
|
||||
# T96 (Phase 4): reciprocal-rank-fusion constant used when ``search_memories``
|
||||
# is given a ``query_vector`` and must merge FTS + vector candidate lists. The
|
||||
# value 60 is the canonical RRF constant from Cormack et al. ("Reciprocal Rank
|
||||
# Fusion outperforms Condorcet and Individual Rank Learning Methods", SIGIR
|
||||
# 2009): large enough to dampen the head of either ranking so that a strong
|
||||
# top-1 in ranking A doesn't crowd out a moderate top-3 in ranking B, but
|
||||
# small enough that the position-1/position-2 gap still matters.
|
||||
RRF_CONST = 60
|
||||
|
||||
|
||||
def search_memories(
|
||||
conn: Connection,
|
||||
@@ -109,6 +118,8 @@ def search_memories(
|
||||
witness_role: str,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
*,
|
||||
query_vector: list[float] | None = None,
|
||||
) -> list[dict]:
|
||||
"""FTS5 search over pov_summary, scoped by owner and witness role.
|
||||
|
||||
@@ -135,6 +146,23 @@ def search_memories(
|
||||
* **Python-side** — a composite re-rank with ``_SIGNIFICANCE_WEIGHT``
|
||||
reinforces the ordering after candidate retrieval, alongside the
|
||||
recency boost above.
|
||||
|
||||
PHASE 4 EXTENSION (T96): when ``query_vector`` is provided, fuses FTS and
|
||||
vector hits via reciprocal-rank fusion (RRF):
|
||||
|
||||
fusion_score = 1/(RRF_CONST + fts_rank) + 1/(RRF_CONST + vec_rank)
|
||||
|
||||
where ``fts_rank`` and ``vec_rank`` are the 0-indexed positions of the
|
||||
memory in each candidate list. Each candidate gets the sum of its
|
||||
reciprocal ranks across both rankings; memories appearing in only one
|
||||
ranking still get a partial score (the other term is dropped). Both
|
||||
candidate lists are over-fetched at ``k * 2`` so a memory dominant in
|
||||
only one channel has a fair chance to surface. The Python-side
|
||||
significance + recency re-rank is then applied as a final pass to
|
||||
break ties in favour of more important / more recent memories.
|
||||
|
||||
When ``query_vector`` is None: FTS-only behaviour unchanged — all
|
||||
Phase 1-3.5 callers see the same row shape and ordering as before.
|
||||
"""
|
||||
if witness_role not in _VALID_WITNESS_ROLES:
|
||||
raise ValueError(
|
||||
@@ -148,7 +176,10 @@ def search_memories(
|
||||
select_list = ", ".join(f"m.{c}" for c in cols)
|
||||
# Over-fetch from FTS so the Python-side re-rank has room to reorder
|
||||
# results that BM25 alone would have demoted past the top-k boundary.
|
||||
over_fetch = max(k * 4, 20)
|
||||
# When fusing with a vector ranking, we still over-fetch (k*2 from each
|
||||
# channel) so memories that are weak in FTS but strong in vector — and
|
||||
# vice versa — make it into the merge pool.
|
||||
over_fetch = max(k * 2, 20) if query_vector is not None else max(k * 4, 20)
|
||||
sql = (
|
||||
f"SELECT {select_list}, memories_fts.rank AS fts_rank "
|
||||
"FROM memories_fts "
|
||||
@@ -165,11 +196,37 @@ def search_memories(
|
||||
)
|
||||
cur = conn.execute(sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch))
|
||||
rows = cur.fetchall()
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Recency normalises against the current max id for this owner so the
|
||||
# boost magnitude is bounded regardless of dataset size.
|
||||
# FTS-only path: preserve pre-T96 behaviour exactly.
|
||||
if query_vector is None:
|
||||
if not rows:
|
||||
return []
|
||||
return _composite_rerank(conn, cols, rows, owner_id, k)
|
||||
|
||||
# Fused path: combine FTS candidates with vector candidates via RRF.
|
||||
return _rrf_fuse_and_rerank(
|
||||
conn,
|
||||
cols=cols,
|
||||
fts_rows=rows,
|
||||
owner_id=owner_id,
|
||||
witness_role=witness_role,
|
||||
query_vector=query_vector,
|
||||
k=k,
|
||||
)
|
||||
|
||||
|
||||
def _composite_rerank(
|
||||
conn: Connection,
|
||||
cols: list[str],
|
||||
rows: list[tuple],
|
||||
owner_id: str,
|
||||
k: int,
|
||||
) -> list[dict]:
|
||||
"""Apply the significance + recency composite re-rank to FTS rows.
|
||||
|
||||
Extracted from ``search_memories`` so the no-vector path stays a single
|
||||
call and the fused path can re-use the same boost formulae after RRF.
|
||||
"""
|
||||
max_id_row = conn.execute(
|
||||
"SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
|
||||
).fetchone()
|
||||
@@ -187,3 +244,129 @@ def search_memories(
|
||||
|
||||
enriched.sort(key=lambda x: x["composite_score"])
|
||||
return enriched[:k]
|
||||
|
||||
|
||||
def _rrf_fuse_and_rerank(
|
||||
conn: Connection,
|
||||
*,
|
||||
cols: list[str],
|
||||
fts_rows: list[tuple],
|
||||
owner_id: str,
|
||||
witness_role: str,
|
||||
query_vector: list[float],
|
||||
k: int,
|
||||
) -> list[dict]:
|
||||
"""Merge FTS + vector candidates via reciprocal-rank fusion, then apply
|
||||
the existing significance + recency boost as a final tie-breaker.
|
||||
|
||||
RRF formula (Cormack et al. 2009)::
|
||||
|
||||
fusion_score = sum over rankings r of 1 / (RRF_CONST + rank_r)
|
||||
|
||||
where ``rank_r`` is the 0-indexed position of the memory in ranking r.
|
||||
"Missing from a ranking" is handled by SKIPPING the term for that
|
||||
ranking — i.e. that channel contributes 0 to the sum, which preserves
|
||||
the fairness property: a memory that only appears in one ranking is
|
||||
not penalised relative to itself, just relative to memories that
|
||||
appeared in both. This matches the canonical RRF presentation.
|
||||
|
||||
The final composite score subtracted from the *negated* fusion score
|
||||
is::
|
||||
|
||||
composite = -fusion - sig_boost - recency_boost
|
||||
|
||||
Sorted ascending, smaller-is-better — the same ordering convention as
|
||||
the FTS-only path so the Python-side significance + recency boosts
|
||||
apply as a clean tie-breaker without inverting any sign.
|
||||
"""
|
||||
# Lazy import to avoid a hard module-level cycle: vector_search reads
|
||||
# from chat.state.embeddings, which is itself a sibling of this module.
|
||||
from chat.services.vector_search import vector_search
|
||||
|
||||
fts_rank_by_id: dict[int, int] = {}
|
||||
fts_row_by_id: dict[int, tuple] = {}
|
||||
id_idx = cols.index("id")
|
||||
for rank, row in enumerate(fts_rows):
|
||||
memory_id = row[id_idx]
|
||||
fts_rank_by_id[memory_id] = rank
|
||||
fts_row_by_id[memory_id] = row
|
||||
|
||||
# Over-fetch the vector channel symmetrically so each channel gets a
|
||||
# fair shot at surfacing its strongest candidates.
|
||||
vec_over_fetch = max(k * 2, 20)
|
||||
vec_hits = vector_search(
|
||||
conn,
|
||||
owner_id=owner_id,
|
||||
witness_role=witness_role,
|
||||
query_vector=query_vector,
|
||||
k=vec_over_fetch,
|
||||
)
|
||||
vec_rank_by_id: dict[int, int] = {
|
||||
hit["memory_id"]: rank for rank, hit in enumerate(vec_hits)
|
||||
}
|
||||
|
||||
# If the vector channel returned nothing (no embeddings indexed), the
|
||||
# fused path collapses cleanly to the FTS-only path. No error, no
|
||||
# surprise zero-hit return.
|
||||
if not vec_rank_by_id and not fts_row_by_id:
|
||||
return []
|
||||
if not vec_rank_by_id:
|
||||
return _composite_rerank(conn, cols, fts_rows, owner_id, k)
|
||||
|
||||
# For any vector-only hits we don't have a full memory row for yet,
|
||||
# fetch them in a single round-trip. The FTS row carries an ``fts_rank``
|
||||
# column at the end; vector-only rows get ``None`` there.
|
||||
missing_ids = [mid for mid in vec_rank_by_id if mid not in fts_row_by_id]
|
||||
select_list = ", ".join(cols)
|
||||
if missing_ids:
|
||||
placeholders = ",".join("?" * len(missing_ids))
|
||||
cur = conn.execute(
|
||||
f"SELECT {select_list} FROM memories WHERE id IN ({placeholders})",
|
||||
missing_ids,
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
# Pad with a None for the trailing ``fts_rank`` slot so the row
|
||||
# shape matches FTS rows downstream.
|
||||
fts_row_by_id[row[id_idx]] = tuple(row) + (None,)
|
||||
|
||||
# Compute fusion score per candidate. Missing-from-ranking terms are
|
||||
# simply omitted from the sum.
|
||||
all_ids = set(fts_rank_by_id) | set(vec_rank_by_id)
|
||||
fusion_by_id: dict[int, float] = {}
|
||||
for mid in all_ids:
|
||||
score = 0.0
|
||||
if mid in fts_rank_by_id:
|
||||
score += 1.0 / (RRF_CONST + fts_rank_by_id[mid])
|
||||
if mid in vec_rank_by_id:
|
||||
score += 1.0 / (RRF_CONST + vec_rank_by_id[mid])
|
||||
fusion_by_id[mid] = score
|
||||
|
||||
# Final composite re-rank: significance + recency boosts on top of the
|
||||
# negated fusion score so the sort direction matches the FTS-only path.
|
||||
max_id_row = conn.execute(
|
||||
"SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
|
||||
).fetchone()
|
||||
max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1
|
||||
|
||||
result_cols = cols + ["fts_rank"]
|
||||
enriched: list[dict] = []
|
||||
for mid in all_ids:
|
||||
row = fts_row_by_id.get(mid)
|
||||
if row is None:
|
||||
# Defensive: a vector hit with no memory row would be a logic
|
||||
# bug (vector_search joins memories), so just skip it rather
|
||||
# than crash the whole search.
|
||||
continue
|
||||
d = dict(zip(result_cols, row))
|
||||
sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
|
||||
recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
|
||||
fusion = fusion_by_id[mid]
|
||||
# Sort ascending, smaller-is-better → negate fusion so a larger
|
||||
# fusion score yields a smaller composite. Significance and recency
|
||||
# boosts then act as tie-breakers exactly like the FTS-only path.
|
||||
d["fusion_score"] = fusion
|
||||
d["composite_score"] = -fusion - sig_boost - recency_boost
|
||||
enriched.append(d)
|
||||
|
||||
enriched.sort(key=lambda x: x["composite_score"])
|
||||
return enriched[:k]
|
||||
|
||||
Reference in New Issue
Block a user