merge: T96 combined FTS + vector retrieval ranking via RRF
This commit is contained in:
+188
-5
@@ -102,6 +102,15 @@ _RECENCY_WEIGHT = 0.5
|
||||
# a higher-is-better score by a positive constant per the spec wording.
|
||||
SIGNIFICANCE_RANK_BIAS = 0.5
|
||||
|
||||
# T96 (Phase 4): reciprocal-rank-fusion constant used when ``search_memories``
|
||||
# is given a ``query_vector`` and must merge FTS + vector candidate lists. The
|
||||
# value 60 is the canonical RRF constant from Cormack et al. ("Reciprocal Rank
|
||||
# Fusion outperforms Condorcet and Individual Rank Learning Methods", SIGIR
|
||||
# 2009): large enough to dampen the head of either ranking so that a strong
|
||||
# top-1 in ranking A doesn't crowd out a moderate top-3 in ranking B, but
|
||||
# small enough that the position-1/position-2 gap still matters.
|
||||
RRF_CONST = 60
|
||||
|
||||
|
||||
def search_memories(
|
||||
conn: Connection,
|
||||
@@ -109,6 +118,8 @@ def search_memories(
|
||||
witness_role: str,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
*,
|
||||
query_vector: list[float] | None = None,
|
||||
) -> list[dict]:
|
||||
"""FTS5 search over pov_summary, scoped by owner and witness role.
|
||||
|
||||
@@ -135,6 +146,23 @@ def search_memories(
|
||||
* **Python-side** — a composite re-rank with ``_SIGNIFICANCE_WEIGHT``
|
||||
reinforces the ordering after candidate retrieval, alongside the
|
||||
recency boost above.
|
||||
|
||||
PHASE 4 EXTENSION (T96): when ``query_vector`` is provided, fuses FTS and
|
||||
vector hits via reciprocal-rank fusion (RRF):
|
||||
|
||||
fusion_score = 1/(RRF_CONST + fts_rank) + 1/(RRF_CONST + vec_rank)
|
||||
|
||||
where ``fts_rank`` and ``vec_rank`` are the 0-indexed positions of the
|
||||
memory in each candidate list. Each candidate gets the sum of its
|
||||
reciprocal ranks across both rankings; memories appearing in only one
|
||||
ranking still get a partial score (the other term is dropped). Both
|
||||
candidate lists are over-fetched at ``k * 2`` so a memory dominant in
|
||||
only one channel has a fair chance to surface. The Python-side
|
||||
significance + recency re-rank is then applied as a final pass to
|
||||
break ties in favour of more important / more recent memories.
|
||||
|
||||
When ``query_vector`` is None: FTS-only behaviour unchanged — all
|
||||
Phase 1-3.5 callers see the same row shape and ordering as before.
|
||||
"""
|
||||
if witness_role not in _VALID_WITNESS_ROLES:
|
||||
raise ValueError(
|
||||
@@ -148,7 +176,10 @@ def search_memories(
|
||||
select_list = ", ".join(f"m.{c}" for c in cols)
|
||||
# Over-fetch from FTS so the Python-side re-rank has room to reorder
|
||||
# results that BM25 alone would have demoted past the top-k boundary.
|
||||
over_fetch = max(k * 4, 20)
|
||||
# When fusing with a vector ranking, we still over-fetch (k*2 from each
|
||||
# channel) so memories that are weak in FTS but strong in vector — and
|
||||
# vice versa — make it into the merge pool.
|
||||
over_fetch = max(k * 2, 20) if query_vector is not None else max(k * 4, 20)
|
||||
sql = (
|
||||
f"SELECT {select_list}, memories_fts.rank AS fts_rank "
|
||||
"FROM memories_fts "
|
||||
@@ -165,11 +196,37 @@ def search_memories(
|
||||
)
|
||||
cur = conn.execute(sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch))
|
||||
rows = cur.fetchall()
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Recency normalises against the current max id for this owner so the
|
||||
# boost magnitude is bounded regardless of dataset size.
|
||||
# FTS-only path: preserve pre-T96 behaviour exactly.
|
||||
if query_vector is None:
|
||||
if not rows:
|
||||
return []
|
||||
return _composite_rerank(conn, cols, rows, owner_id, k)
|
||||
|
||||
# Fused path: combine FTS candidates with vector candidates via RRF.
|
||||
return _rrf_fuse_and_rerank(
|
||||
conn,
|
||||
cols=cols,
|
||||
fts_rows=rows,
|
||||
owner_id=owner_id,
|
||||
witness_role=witness_role,
|
||||
query_vector=query_vector,
|
||||
k=k,
|
||||
)
|
||||
|
||||
|
||||
def _composite_rerank(
|
||||
conn: Connection,
|
||||
cols: list[str],
|
||||
rows: list[tuple],
|
||||
owner_id: str,
|
||||
k: int,
|
||||
) -> list[dict]:
|
||||
"""Apply the significance + recency composite re-rank to FTS rows.
|
||||
|
||||
Extracted from ``search_memories`` so the no-vector path stays a single
|
||||
call and the fused path can re-use the same boost formulae after RRF.
|
||||
"""
|
||||
max_id_row = conn.execute(
|
||||
"SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
|
||||
).fetchone()
|
||||
@@ -187,3 +244,129 @@ def search_memories(
|
||||
|
||||
enriched.sort(key=lambda x: x["composite_score"])
|
||||
return enriched[:k]
|
||||
|
||||
|
||||
def _rrf_fuse_and_rerank(
    conn: Connection,
    *,
    cols: list[str],
    fts_rows: list[tuple],
    owner_id: str,
    witness_role: str,
    query_vector: list[float],
    k: int,
) -> list[dict]:
    """Merge FTS and vector candidates via reciprocal-rank fusion (RRF).

    RRF formula (Cormack et al. 2009)::

        fusion_score = sum over rankings r of 1 / (RRF_CONST + rank_r)

    where ``rank_r`` is the 0-indexed position of the memory in ranking r.
    A memory missing from a ranking simply contributes no term for that
    ranking, so single-channel hits still receive a partial score. If an
    id appears more than once in a ranking, its best (lowest) position is
    the one that counts.

    The value used for sorting is::

        composite = -fusion - sig_boost - recency_boost

    Results are sorted ascending (smaller is better), so the fusion score
    is negated and the significance / recency boosts are subtracted.

    NOTE(review): with ``RRF_CONST = 60`` the fusion score is at most
    ``2 / 60`` (about 0.033), whereas the recency boost alone ranges up to
    ``_RECENCY_WEIGHT``. The boosts can therefore outweigh the fusion score
    rather than merely break ties -- confirm this is the intended balance.

    Args:
        conn: Open database connection.
        cols: Column names of ``memories`` carried by each row; must
            contain ``"id"``.
        fts_rows: FTS candidate rows in rank order. Each row holds the
            ``cols`` values followed by a trailing ``fts_rank`` value.
        owner_id: Owner whose memories are being searched.
        witness_role: Witness role forwarded to ``vector_search``.
        query_vector: Query embedding forwarded to ``vector_search``.
        k: Maximum number of results to return.

    Returns:
        Up to ``k`` dicts keyed by ``cols`` plus ``fts_rank`` (``None`` for
        vector-only hits), ``fusion_score`` and ``composite_score``,
        ordered best first. When the vector channel yields nothing, the
        result of ``_composite_rerank`` over ``fts_rows`` is returned
        instead (or ``[]`` if there are no FTS rows either).
    """
    # Lazy import to avoid a hard module-level cycle: vector_search reads
    # from chat.state.embeddings, which is itself a sibling of this module.
    from chat.services.vector_search import vector_search

    id_idx = cols.index("id")

    # First occurrence wins so that a duplicated id keeps its best rank.
    fts_rank_by_id: dict[int, int] = {}
    row_by_id: dict[int, tuple] = {}
    for rank, row in enumerate(fts_rows):
        memory_id = row[id_idx]
        if memory_id not in fts_rank_by_id:
            fts_rank_by_id[memory_id] = rank
            row_by_id[memory_id] = row

    # Over-fetch the vector channel symmetrically so each channel gets a
    # fair shot at surfacing its strongest candidates.
    vec_over_fetch = max(k * 2, 20)
    vec_hits = vector_search(
        conn,
        owner_id=owner_id,
        witness_role=witness_role,
        query_vector=query_vector,
        k=vec_over_fetch,
    )
    vec_rank_by_id: dict[int, int] = {}
    for rank, hit in enumerate(vec_hits):
        # setdefault keeps the best rank if the same memory is hit twice.
        vec_rank_by_id.setdefault(hit["memory_id"], rank)

    # An empty vector channel collapses the fused path to the FTS-only
    # path: no error, and no surprise zero-hit return when FTS had hits.
    if not vec_rank_by_id:
        if not row_by_id:
            return []
        return _composite_rerank(conn, cols, fts_rows, owner_id, k)

    # Back-fill full rows for vector-only hits in a single round-trip.
    # The query is scoped to ``owner_id`` as defence in depth, so a stray
    # id from the vector channel can never surface another owner's memory.
    missing_ids = [mid for mid in vec_rank_by_id if mid not in row_by_id]
    if missing_ids:
        select_list = ", ".join(cols)
        placeholders = ",".join("?" * len(missing_ids))
        cur = conn.execute(
            f"SELECT {select_list} FROM memories "
            f"WHERE owner_id = ? AND id IN ({placeholders})",
            [owner_id, *missing_ids],
        )
        for row in cur.fetchall():
            # Pad with None for the trailing ``fts_rank`` slot so the row
            # shape matches FTS rows downstream.
            row_by_id[row[id_idx]] = tuple(row) + (None,)

    # Recency is normalised against the owner's current max id; fall back
    # to 1 so the division below is always defined.
    max_id_row = conn.execute(
        "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
    ).fetchone()
    max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1

    result_cols = cols + ["fts_rank"]
    enriched: list[dict] = []
    for mid in set(fts_rank_by_id) | set(vec_rank_by_id):
        row = row_by_id.get(mid)
        if row is None:
            # Defensive: a vector hit with no matching memory row for this
            # owner is skipped rather than crashing the whole search.
            continue

        # Terms for rankings the memory is absent from are simply omitted.
        fusion = 0.0
        if mid in fts_rank_by_id:
            fusion += 1.0 / (RRF_CONST + fts_rank_by_id[mid])
        if mid in vec_rank_by_id:
            fusion += 1.0 / (RRF_CONST + vec_rank_by_id[mid])

        d = dict(zip(result_cols, row))
        sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
        recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
        d["fusion_score"] = fusion
        # Negate fusion so a larger fusion score yields a smaller (better)
        # composite under the ascending sort below.
        d["composite_score"] = -fusion - sig_boost - recency_boost
        enriched.append(d)

    enriched.sort(key=lambda x: x["composite_score"])
    return enriched[:k]
|
||||
|
||||
@@ -16,6 +16,7 @@ from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.state.memory import search_memories
|
||||
import chat.state.memory # noqa: F401 (registers memory_written handler)
|
||||
import chat.state.embeddings # noqa: F401 (registers embedding_indexed handler)
|
||||
|
||||
|
||||
def _seed(db, *, memory_specs):
|
||||
@@ -159,3 +160,216 @@ def test_significance_bias_is_constant_module_level():
|
||||
# Must be non-negative -- a negative bias would invert the desired
|
||||
# "higher significance ranks higher" semantics.
|
||||
assert SIGNIFICANCE_RANK_BIAS >= 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T96 (Phase 4): combined FTS + vector retrieval ranking via reciprocal-rank
|
||||
# fusion. The fused path activates only when ``query_vector`` is provided —
|
||||
# the no-vector path (above) is unchanged.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _one_hot(dim: int, idx: int) -> list[float]:
|
||||
v = [0.0] * dim
|
||||
v[idx] = 1.0
|
||||
return v
|
||||
|
||||
|
||||
def _seed_memories_with_optional_embeddings(db, *, memory_specs):
    """Seed ``memory_written`` events, plus ``embedding_indexed`` events for
    every spec that carries a ``vector`` key, then project once.

    Embeddings are paired with memories by position: the spec at 1-based
    position ``n`` is given ``memory_id = n``. This assumes the projected
    ``memories`` rows receive ids 1..N in event order (per the original
    author's note that ``memories.id`` autoincrements). All events are
    appended before the single ``project`` call; the original note says a
    second projection would make the ``memory_written`` handler insert
    duplicate rows.
    """
    apply_migrations(db)
    with open_db(db) as conn:
        # Pass 1: one memory_written event per spec, in order.
        for entry in memory_specs:
            append_event(
                conn,
                kind="memory_written",
                payload={
                    "owner_id": entry.get("owner_id", "bot_a"),
                    "chat_id": entry.get("chat_id", "chat_bot_a"),
                    "pov_summary": entry["pov_summary"],
                    "witness_you": entry.get("witness_you", 1),
                    "witness_host": entry.get("witness_host", 1),
                    "witness_guest": entry.get("witness_guest", 0),
                    "source": "direct",
                    "reliability": 1.0,
                    "significance": entry.get("significance", 1),
                    "pinned": 0,
                    "auto_pinned": 0,
                },
            )

        # Pass 2: embedding_indexed events keyed by the predicted memory id.
        for predicted_id, entry in enumerate(memory_specs, start=1):
            if "vector" in entry:
                embedding = entry["vector"]
                append_event(
                    conn,
                    kind="embedding_indexed",
                    payload={
                        "memory_id": predicted_id,
                        "vector": list(embedding),
                        "model": "test-model",
                        "dim": len(embedding),
                    },
                )

        # Project exactly once, after every event has been appended.
        project(conn)
|
||||
|
||||
|
||||
def test_search_memories_without_query_vector_uses_fts_only(tmp_path):
    """Regression: calling without ``query_vector`` exercises the FTS-only path.

    Three memories match the keyword "promise"; only one has a raised
    significance. No ``query_vector`` kwarg is passed at all.
    """
    db_path = tmp_path / "t.db"
    specs = [
        {"pov_summary": "small promise"},
        {"pov_summary": "huge promise"},
        {"pov_summary": "tiny promise", "significance": 3},
    ]
    _seed(db_path, memory_specs=specs)

    with open_db(db_path) as conn:
        hits = search_memories(conn, "bot_a", "host", "promise", k=3)

    assert len(hits) == 3
    top = hits[0]
    # The high-significance row is expected to come out first.
    assert top["pov_summary"] == "tiny promise"
    # The row shape still exposes the FTS rank and the composite score.
    assert "fts_rank" in top
    assert "composite_score" in top
|
||||
|
||||
|
||||
def test_search_memories_with_query_vector_includes_vector_hits(tmp_path):
    """Fused search must surface both an FTS-only and a vector-only memory.

    Memory 1 contains the keyword "rabbit" and has no embedding. Memory 2
    has an embedding identical to the query vector and no keyword overlap.
    Memories 3-5 are filler with neither. Both memory 1 and memory 2 must
    appear in the top-k.
    """
    db_path = tmp_path / "t.db"
    dim = 8
    # One-hot at index 0; memory 2's embedding mirrors it exactly.
    query_vec = _one_hot(dim, 0)

    fts_only_text = "rabbit hopped over the fence"
    vector_only_text = "completely unrelated narrative line"
    _seed_memories_with_optional_embeddings(
        db_path,
        memory_specs=[
            # Memory 1: keyword match only, no embedding indexed.
            {"pov_summary": fts_only_text},
            # Memory 2: embedding match only.
            {"pov_summary": vector_only_text, "vector": _one_hot(dim, 0)},
            # Memories 3-5: filler without embeddings.
            {"pov_summary": "lighthouse keeper polished the lens"},
            {"pov_summary": "they discussed cartography for hours"},
            {"pov_summary": "she taught him semaphore signals"},
        ],
    )

    with open_db(db_path) as conn:
        hits = search_memories(
            conn, "bot_a", "host", "rabbit", k=4, query_vector=query_vec
        )

    found = [hit["pov_summary"] for hit in hits]
    # The keyword-only candidate made it through.
    assert fts_only_text in found
    # So did the embedding-only candidate, despite sharing no keyword.
    assert vector_only_text in found
|
||||
|
||||
|
||||
def test_search_memories_fusion_significance_bias_still_applies(tmp_path):
    """On the fused path the higher-significance memory must rank first.

    Two memories have the same text (keyword "promise") and the same
    embedding as the query; they differ only in significance (0 vs 3).
    The significance-3 memory is expected ahead of the significance-0 one.
    """
    db_path = tmp_path / "t.db"
    dim = 4
    shared_vec = _one_hot(dim, 0)
    _seed_memories_with_optional_embeddings(
        db_path,
        memory_specs=[
            {
                "pov_summary": "she made a promise",
                "significance": level,
                "vector": list(shared_vec),
            }
            for level in (0, 3)
        ],
    )

    with open_db(db_path) as conn:
        hits = search_memories(
            conn,
            "bot_a",
            "host",
            "promise",
            k=2,
            query_vector=list(shared_vec),
        )

    assert len(hits) == 2
    first, second = hits
    # Higher significance comes out on top.
    assert first["significance"] == 3
    assert second["significance"] == 0
|
||||
|
||||
|
||||
def test_search_memories_fusion_handles_empty_vector_results(tmp_path):
    """Passing ``query_vector`` with no embeddings seeded must still return FTS hits.

    Seeding goes through ``_seed``, so no ``embedding_indexed`` events are
    appended here. Both keyword matches must still come back without error.
    """
    db_path = tmp_path / "t.db"
    expected = {
        "the vault held an old promise",
        "another promise was kept that night",
    }
    _seed(
        db_path,
        memory_specs=[
            {"pov_summary": "the vault held an old promise"},
            {"pov_summary": "another promise was kept that night"},
        ],
    )

    # No embeddings exist for this owner, whatever the query vector is.
    zero_vec = [0.0] * 384
    with open_db(db_path) as conn:
        hits = search_memories(
            conn, "bot_a", "host", "promise", k=4, query_vector=zero_vec
        )

    # Both FTS hits survive the empty vector channel.
    assert len(hits) == 2
    assert {hit["pov_summary"] for hit in hits} == expected
|
||||
|
||||
Reference in New Issue
Block a user