From b8b4aed6d974789054b9a662443a07195eb7bc3b Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 02:42:38 -0400 Subject: [PATCH] feat: combined FTS + vector retrieval ranking via RRF (T96) --- chat/state/memory.py | 193 +++++++++++++++++++++++++++++++- tests/test_memory_search.py | 214 ++++++++++++++++++++++++++++++++++++ 2 files changed, 402 insertions(+), 5 deletions(-) diff --git a/chat/state/memory.py b/chat/state/memory.py index 5310965..42a7e95 100644 --- a/chat/state/memory.py +++ b/chat/state/memory.py @@ -102,6 +102,15 @@ _RECENCY_WEIGHT = 0.5 # a higher-is-better score by a positive constant per the spec wording. SIGNIFICANCE_RANK_BIAS = 0.5 +# T96 (Phase 4): reciprocal-rank-fusion constant used when ``search_memories`` +# is given a ``query_vector`` and must merge FTS + vector candidate lists. The +# value 60 is the canonical RRF constant from Cormack et al. ("Reciprocal Rank +# Fusion outperforms Condorcet and Individual Rank Learning Methods", SIGIR +# 2009): large enough to dampen the head of either ranking so that a strong +# top-1 in ranking A doesn't crowd out a moderate top-3 in ranking B, but +# small enough that the position-1/position-2 gap still matters. +RRF_CONST = 60 + def search_memories( conn: Connection, @@ -109,6 +118,8 @@ def search_memories( witness_role: str, query: str, k: int = 4, + *, + query_vector: list[float] | None = None, ) -> list[dict]: """FTS5 search over pov_summary, scoped by owner and witness role. @@ -135,6 +146,23 @@ def search_memories( * **Python-side** — a composite re-rank with ``_SIGNIFICANCE_WEIGHT`` reinforces the ordering after candidate retrieval, alongside the recency boost above. + + PHASE 4 EXTENSION (T96): when ``query_vector`` is provided, fuses FTS and + vector hits via reciprocal-rank fusion (RRF): + + fusion_score = 1/(RRF_CONST + fts_rank) + 1/(RRF_CONST + vec_rank) + + where ``fts_rank`` and ``vec_rank`` are the 0-indexed positions of the + memory in each candidate list. Each candidate gets the sum of its + reciprocal ranks across both rankings; memories appearing in only one + ranking still get a partial score (the other term is dropped). Both + candidate lists are over-fetched at ``k * 2`` so a memory dominant in + only one channel has a fair chance to surface. The Python-side + significance + recency re-rank is then applied as a final pass to + break ties in favour of more important / more recent memories. + + When ``query_vector`` is None: FTS-only behaviour unchanged — all + Phase 1-3.5 callers see the same row shape and ordering as before. """ if witness_role not in _VALID_WITNESS_ROLES: raise ValueError( @@ -148,7 +176,10 @@ def search_memories( select_list = ", ".join(f"m.{c}" for c in cols) # Over-fetch from FTS so the Python-side re-rank has room to reorder # results that BM25 alone would have demoted past the top-k boundary. - over_fetch = max(k * 4, 20) + # When fusing with a vector ranking, we still over-fetch (k*2 from each + # channel) so memories that are weak in FTS but strong in vector — and + # vice versa — make it into the merge pool. + over_fetch = max(k * 2, 20) if query_vector is not None else max(k * 4, 20) sql = ( f"SELECT {select_list}, memories_fts.rank AS fts_rank " "FROM memories_fts " @@ -165,11 +196,37 @@ def search_memories( ) cur = conn.execute(sql, (owner_id, query, SIGNIFICANCE_RANK_BIAS, over_fetch)) rows = cur.fetchall() - if not rows: - return [] - # Recency normalises against the current max id for this owner so the - # boost magnitude is bounded regardless of dataset size. + # FTS-only path: preserve pre-T96 behaviour exactly. + if query_vector is None: + if not rows: + return [] + return _composite_rerank(conn, cols, rows, owner_id, k) + + # Fused path: combine FTS candidates with vector candidates via RRF. + return _rrf_fuse_and_rerank( + conn, + cols=cols, + fts_rows=rows, + owner_id=owner_id, + witness_role=witness_role, + query_vector=query_vector, + k=k, + ) + + +def _composite_rerank( + conn: Connection, + cols: list[str], + rows: list[tuple], + owner_id: str, + k: int, +) -> list[dict]: + """Apply the significance + recency composite re-rank to FTS rows. + + Extracted from ``search_memories`` so the no-vector path stays a single + call and the fused path can re-use the same boost formulae after RRF. + """ max_id_row = conn.execute( "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,) ).fetchone() @@ -187,3 +244,129 @@ def search_memories( enriched.sort(key=lambda x: x["composite_score"]) return enriched[:k] + + +def _rrf_fuse_and_rerank( + conn: Connection, + *, + cols: list[str], + fts_rows: list[tuple], + owner_id: str, + witness_role: str, + query_vector: list[float], + k: int, +) -> list[dict]: + """Merge FTS + vector candidates via reciprocal-rank fusion, then apply + the existing significance + recency boost as a final tie-breaker. + + RRF formula (Cormack et al. 2009):: + + fusion_score = sum over rankings r of 1 / (RRF_CONST + rank_r) + + where ``rank_r`` is the 0-indexed position of the memory in ranking r. + "Missing from a ranking" is handled by SKIPPING the term for that + ranking — i.e. that channel contributes 0 to the sum, which preserves + the fairness property: a memory that only appears in one ranking is + not penalised relative to itself, just relative to memories that + appeared in both. This matches the canonical RRF presentation. + + The final composite score subtracted from the *negated* fusion score + is:: + + composite = -fusion - sig_boost - recency_boost + + Sorted ascending, smaller-is-better — the same ordering convention as + the FTS-only path so the Python-side significance + recency boosts + apply as a clean tie-breaker without inverting any sign. + """ + # Lazy import to avoid a hard module-level cycle: vector_search reads + # from chat.state.embeddings, which is itself a sibling of this module. + from chat.services.vector_search import vector_search + + fts_rank_by_id: dict[int, int] = {} + fts_row_by_id: dict[int, tuple] = {} + id_idx = cols.index("id") + for rank, row in enumerate(fts_rows): + memory_id = row[id_idx] + fts_rank_by_id[memory_id] = rank + fts_row_by_id[memory_id] = row + + # Over-fetch the vector channel symmetrically so each channel gets a + # fair shot at surfacing its strongest candidates. + vec_over_fetch = max(k * 2, 20) + vec_hits = vector_search( + conn, + owner_id=owner_id, + witness_role=witness_role, + query_vector=query_vector, + k=vec_over_fetch, + ) + vec_rank_by_id: dict[int, int] = { + hit["memory_id"]: rank for rank, hit in enumerate(vec_hits) + } + + # If the vector channel returned nothing (no embeddings indexed), the + # fused path collapses cleanly to the FTS-only path. No error, no + # surprise zero-hit return. + if not vec_rank_by_id and not fts_row_by_id: + return [] + if not vec_rank_by_id: + return _composite_rerank(conn, cols, fts_rows, owner_id, k) + + # For any vector-only hits we don't have a full memory row for yet, + # fetch them in a single round-trip. The FTS row carries an ``fts_rank`` + # column at the end; vector-only rows get ``None`` there. + missing_ids = [mid for mid in vec_rank_by_id if mid not in fts_row_by_id] + select_list = ", ".join(cols) + if missing_ids: + placeholders = ",".join("?" * len(missing_ids)) + cur = conn.execute( + f"SELECT {select_list} FROM memories WHERE id IN ({placeholders})", + missing_ids, + ) + for row in cur.fetchall(): + # Pad with a None for the trailing ``fts_rank`` slot so the row + # shape matches FTS rows downstream. + fts_row_by_id[row[id_idx]] = tuple(row) + (None,) + + # Compute fusion score per candidate. Missing-from-ranking terms are + # simply omitted from the sum. + all_ids = set(fts_rank_by_id) | set(vec_rank_by_id) + fusion_by_id: dict[int, float] = {} + for mid in all_ids: + score = 0.0 + if mid in fts_rank_by_id: + score += 1.0 / (RRF_CONST + fts_rank_by_id[mid]) + if mid in vec_rank_by_id: + score += 1.0 / (RRF_CONST + vec_rank_by_id[mid]) + fusion_by_id[mid] = score + + # Final composite re-rank: significance + recency boosts on top of the + # negated fusion score so the sort direction matches the FTS-only path. + max_id_row = conn.execute( + "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,) + ).fetchone() + max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1 + + result_cols = cols + ["fts_rank"] + enriched: list[dict] = [] + for mid in all_ids: + row = fts_row_by_id.get(mid) + if row is None: + # Defensive: a vector hit with no memory row would be a logic + # bug (vector_search joins memories), so just skip it rather + # than crash the whole search. + continue + d = dict(zip(result_cols, row)) + sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0) + recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id) + fusion = fusion_by_id[mid] + # Sort ascending, smaller-is-better → negate fusion so a larger + # fusion score yields a smaller composite. Significance and recency + # boosts then act as tie-breakers exactly like the FTS-only path. + d["fusion_score"] = fusion + d["composite_score"] = -fusion - sig_boost - recency_boost + enriched.append(d) + + enriched.sort(key=lambda x: x["composite_score"]) + return enriched[:k] diff --git a/tests/test_memory_search.py b/tests/test_memory_search.py index 76f0ee1..c62c1bf 100644 --- a/tests/test_memory_search.py +++ b/tests/test_memory_search.py @@ -16,6 +16,7 @@ from chat.eventlog.log import append_event from chat.eventlog.projector import project from chat.state.memory import search_memories import chat.state.memory # noqa: F401 (registers memory_written handler) +import chat.state.embeddings # noqa: F401 (registers embedding_indexed handler) def _seed(db, *, memory_specs): @@ -159,3 +160,216 @@ def test_significance_bias_is_constant_module_level(): # Must be non-negative -- a negative bias would invert the desired # "higher significance ranks higher" semantics. assert SIGNIFICANCE_RANK_BIAS >= 0 + + +# --------------------------------------------------------------------------- +# T96 (Phase 4): combined FTS + vector retrieval ranking via reciprocal-rank +# fusion. The fused path activates only when ``query_vector`` is provided — +# the no-vector path (above) is unchanged. +# --------------------------------------------------------------------------- + + +def _one_hot(dim: int, idx: int) -> list[float]: + v = [0.0] * dim + v[idx] = 1.0 + return v + + +def _seed_memories_with_optional_embeddings(db, *, memory_specs): + """Like ``_seed`` but also projects ``embedding_indexed`` events for any + spec carrying a ``vector`` key. + + Memory rows are assigned ids in the order their ``memory_written`` events + were appended (the ``memories.id`` column is an autoincrementing primary + key), so we predict ``memory_id = i + 1`` per spec and append both kinds + of events back-to-back BEFORE projecting. Projecting only once keeps the + INSERT-based ``memory_written`` handler from duplicating rows. + """ + apply_migrations(db) + with open_db(db) as conn: + # First pass: append every memory_written event in order. The DB + # assigns autoincrementing ids 1..N matching the order of these + # events, so we can pair vectors to memories by index. + for spec in memory_specs: + payload = { + "owner_id": spec.get("owner_id", "bot_a"), + "chat_id": spec.get("chat_id", "chat_bot_a"), + "pov_summary": spec["pov_summary"], + "witness_you": spec.get("witness_you", 1), + "witness_host": spec.get("witness_host", 1), + "witness_guest": spec.get("witness_guest", 0), + "source": "direct", + "reliability": 1.0, + "significance": spec.get("significance", 1), + "pinned": 0, + "auto_pinned": 0, + } + append_event(conn, kind="memory_written", payload=payload) + # Second pass: append embedding_indexed events for any spec that + # supplied a vector, using the predicted memory id. + for i, spec in enumerate(memory_specs, start=1): + if "vector" not in spec: + continue + vec = spec["vector"] + append_event( + conn, + kind="embedding_indexed", + payload={ + "memory_id": i, + "vector": list(vec), + "model": "test-model", + "dim": len(vec), + }, + ) + # Single projection — avoids the memory_written handler INSERTing + # the same row twice on a re-projection. + project(conn) + + +def test_search_memories_without_query_vector_uses_fts_only(tmp_path): + """Regression: omitting ``query_vector`` keeps the existing FTS-only path. + + Identical seed to ``test_search_higher_significance_ranks_above_lower`` + but pinned to the no-vector code path explicitly (no kwarg passed). + """ + db = tmp_path / "t.db" + _seed( + db, + memory_specs=[ + {"pov_summary": "small promise"}, + {"pov_summary": "huge promise"}, + {"pov_summary": "tiny promise", "significance": 3}, + ], + ) + with open_db(db) as conn: + out = search_memories(conn, "bot_a", "host", "promise", k=3) + assert len(out) == 3 + # The composite re-rank surfaces the high-significance row first. + assert out[0]["pov_summary"] == "tiny promise" + # Sanity: the row shape still carries ``fts_rank`` + ``composite_score`` + # like the FTS-only path always has. + assert "fts_rank" in out[0] + assert "composite_score" in out[0] + + +def test_search_memories_with_query_vector_includes_vector_hits(tmp_path): + """RRF fuses FTS hits with vector hits — both kinds surface in the result. + + Memory 1 only matches FTS (keyword "rabbit", embedding far from query). + Memory 2 only matches the vector (embedding identical to query, no + keyword overlap). Memories 3-5 are unrelated. The fused top-K must + contain BOTH memory 1 and memory 2. + """ + db = tmp_path / "t.db" + dim = 8 + # Query vector = one-hot at index 0. Memory 2 mirrors it exactly. The + # FTS-only memory (memory 1) has NO embedding so it cannot leak into + # the vector ranking; the filler memories (3-5) likewise have no + # embeddings, so the vector ranking returns memory 2 alone. + query_vec = _one_hot(dim, 0) + _seed_memories_with_optional_embeddings( + db, + memory_specs=[ + # Memory 1: FTS-only match. No embedding indexed. + {"pov_summary": "rabbit hopped over the fence"}, + # Memory 2: vector-only match. No keyword overlap with "rabbit". + { + "pov_summary": "completely unrelated narrative line", + "vector": _one_hot(dim, 0), + }, + # Memories 3-5: filler, irrelevant to both channels. + {"pov_summary": "lighthouse keeper polished the lens"}, + {"pov_summary": "they discussed cartography for hours"}, + {"pov_summary": "she taught him semaphore signals"}, + ], + ) + with open_db(db) as conn: + out = search_memories( + conn, + "bot_a", + "host", + "rabbit", + k=4, + query_vector=query_vec, + ) + summaries = [r["pov_summary"] for r in out] + # FTS-only candidate (memory 1) made it through. + assert "rabbit hopped over the fence" in summaries + # Vector-only candidate (memory 2) also made it through despite + # having no keyword overlap with the query string. + assert "completely unrelated narrative line" in summaries + + +def test_search_memories_fusion_significance_bias_still_applies(tmp_path): + """With two RRF-tied candidates, the higher-significance one ranks first. + + Two memories share the keyword "promise" AND share an identical + embedding to the query — so their FTS rank and vector rank are both + ties. RRF gives them the same fusion score. The Python-side + significance + recency boost must break the tie in favour of the + higher-significance memory. + """ + db = tmp_path / "t.db" + dim = 4 + shared_vec = _one_hot(dim, 0) + _seed_memories_with_optional_embeddings( + db, + memory_specs=[ + { + "pov_summary": "she made a promise", + "significance": 0, + "vector": list(shared_vec), + }, + { + "pov_summary": "she made a promise", + "significance": 3, + "vector": list(shared_vec), + }, + ], + ) + with open_db(db) as conn: + out = search_memories( + conn, + "bot_a", + "host", + "promise", + k=2, + query_vector=list(shared_vec), + ) + assert len(out) == 2 + # Higher significance breaks the RRF tie. + assert out[0]["significance"] == 3 + assert out[1]["significance"] == 0 + + +def test_search_memories_fusion_handles_empty_vector_results(tmp_path): + """Vector path returning [] (no embeddings indexed) must not break FTS. + + No ``embedding_indexed`` events are projected, so ``vector_search`` + returns an empty list. The function should still return the FTS hits + as if ``query_vector`` had not been supplied. + """ + db = tmp_path / "t.db" + _seed( + db, + memory_specs=[ + {"pov_summary": "the vault held an old promise"}, + {"pov_summary": "another promise was kept that night"}, + ], + ) + with open_db(db) as conn: + out = search_memories( + conn, + "bot_a", + "host", + "promise", + k=4, + query_vector=[0.0] * 384, # No embeddings exist for this owner. + ) + # Both FTS hits still come back — no error from the empty vector path. + assert len(out) == 2 + summaries = {r["pov_summary"] for r in out} + assert summaries == { + "the vault held an old promise", + "another promise was kept that night", + }