feat: FTS5 memory retrieval with witness filter and ranking boosts
This commit is contained in:
+46
-5
@@ -87,6 +87,14 @@ def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
|
||||
return [dict(zip(cols, row)) for row in rows]
|
||||
|
||||
|
||||
# Composite-score weights used by ``search_memories`` (T23, §8 retrieval).
|
||||
# FTS5 BM25 ``rank`` is *more negative* for better matches, so subtracting a
|
||||
# positive boost from it drives stronger candidates further down (i.e. earlier
|
||||
# in an ascending sort). Hardcoded for v1 — tunable in a later pass.
|
||||
_SIGNIFICANCE_WEIGHT = 0.3
|
||||
_RECENCY_WEIGHT = 0.5
|
||||
|
||||
|
||||
def search_memories(
|
||||
conn: Connection,
|
||||
owner_id: str,
|
||||
@@ -97,16 +105,32 @@ def search_memories(
|
||||
"""FTS5 search over pov_summary, scoped by owner and witness role.
|
||||
|
||||
witness_role must be one of {"you", "host", "guest"} per the witness flags
|
||||
on each memory row. Returns up to k rows ordered by FTS5 bm25 rank.
|
||||
on each memory row. Returns up to ``k`` rows ranked by a composite score
|
||||
that combines the FTS5 BM25 rank with two boosts (§8 retrieval rules):
|
||||
|
||||
* **significance boost** — ``0.3 * significance`` (0..3 per §11.1).
|
||||
* **recency boost** — ``0.5 * (id / max_id)``, using the row id as a
|
||||
monotonic recency proxy. Newer memories therefore tilt above older ones
|
||||
when the BM25 rank and significance are otherwise tied.
|
||||
|
||||
BM25 returns negative scores (lower = better). Both boosts are subtracted
|
||||
so that stronger candidates yield smaller composite scores; the result is
|
||||
sorted ascending and truncated to ``k``. The unmodified ``fts_rank`` and a
|
||||
debug-friendly ``composite_score`` are kept on each returned dict.
|
||||
"""
|
||||
if witness_role not in _VALID_WITNESS_ROLES:
|
||||
raise ValueError(
|
||||
f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
|
||||
f"got {witness_role!r}"
|
||||
)
|
||||
if not query.strip():
|
||||
return []
|
||||
witness_col = f"witness_{witness_role}"
|
||||
cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
|
||||
select_list = ", ".join(f"m.{c}" for c in cols)
|
||||
# Over-fetch from FTS so the Python-side re-rank has room to reorder
|
||||
# results that BM25 alone would have demoted past the top-k boundary.
|
||||
over_fetch = max(k * 4, 20)
|
||||
sql = (
|
||||
f"SELECT {select_list}, memories_fts.rank AS fts_rank "
|
||||
"FROM memories_fts "
|
||||
@@ -116,10 +140,27 @@ def search_memories(
|
||||
"ORDER BY memories_fts.rank "
|
||||
"LIMIT ?"
|
||||
)
|
||||
cur = conn.execute(sql, (owner_id, query, k))
|
||||
cur = conn.execute(sql, (owner_id, query, over_fetch))
|
||||
rows = cur.fetchall()
|
||||
out: list[dict] = []
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Recency normalises against the current max id for this owner so the
|
||||
# boost magnitude is bounded regardless of dataset size.
|
||||
max_id_row = conn.execute(
|
||||
"SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
|
||||
).fetchone()
|
||||
max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1
|
||||
|
||||
result_cols = cols + ["fts_rank"]
|
||||
enriched: list[dict] = []
|
||||
for row in rows:
|
||||
out.append(dict(zip(result_cols, row)))
|
||||
return out
|
||||
d = dict(zip(result_cols, row))
|
||||
fts_rank = d.get("fts_rank") or 0.0
|
||||
sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
|
||||
recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
|
||||
d["composite_score"] = fts_rank - sig_boost - recency_boost
|
||||
enriched.append(d)
|
||||
|
||||
enriched.sort(key=lambda x: x["composite_score"])
|
||||
return enriched[:k]
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
"""Task 23: FTS5 memory retrieval with witness filter and ranking boosts.
|
||||
|
||||
Verifies that ``search_memories`` applies recency + significance boosts on top
|
||||
of the FTS5 BM25 rank so that newer / more significant memories surface above
|
||||
older / less significant ones for the same match. Existing T8 behaviour
|
||||
(witness filter, k limit, FTS match, role validation) is exercised again here
|
||||
to lock the contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.state.memory import search_memories
|
||||
import chat.state.memory # noqa: F401 (registers memory_written handler)
|
||||
|
||||
|
||||
def _seed(db, *, memory_specs):
|
||||
"""Apply migrations + project a list of memory_written events.
|
||||
|
||||
memory_specs: list of dicts. Required key: ``pov_summary``. Optional keys
|
||||
override the defaults below.
|
||||
"""
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
for spec in memory_specs:
|
||||
payload = {
|
||||
"owner_id": spec.get("owner_id", "bot_a"),
|
||||
"chat_id": spec.get("chat_id", "chat_bot_a"),
|
||||
"pov_summary": spec["pov_summary"],
|
||||
"witness_you": spec.get("witness_you", 1),
|
||||
"witness_host": spec.get("witness_host", 1),
|
||||
"witness_guest": spec.get("witness_guest", 0),
|
||||
"source": "direct",
|
||||
"reliability": 1.0,
|
||||
"significance": spec.get("significance", 1),
|
||||
"pinned": 0,
|
||||
"auto_pinned": 0,
|
||||
}
|
||||
append_event(conn, kind="memory_written", payload=payload)
|
||||
project(conn)
|
||||
|
||||
|
||||
def test_search_filters_by_witness_bit(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
_seed(
|
||||
db,
|
||||
memory_specs=[
|
||||
{
|
||||
"pov_summary": "BotA mentioned her sister",
|
||||
"witness_you": 1,
|
||||
"witness_host": 1,
|
||||
"witness_guest": 0,
|
||||
},
|
||||
],
|
||||
)
|
||||
with open_db(db) as conn:
|
||||
# Witnessed by host -> returned.
|
||||
out = search_memories(conn, "bot_a", "host", "sister", k=4)
|
||||
assert len(out) == 1
|
||||
# NOT witnessed by guest -> filtered out.
|
||||
out = search_memories(conn, "bot_a", "guest", "sister", k=4)
|
||||
assert out == []
|
||||
|
||||
|
||||
def test_search_higher_significance_ranks_above_lower(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
_seed(
|
||||
db,
|
||||
memory_specs=[
|
||||
# Both match "promise"; the third row carries significance 3 and
|
||||
# should outrank the first two, which carry the default of 1.
|
||||
{"pov_summary": "small promise"},
|
||||
{"pov_summary": "huge promise"},
|
||||
{"pov_summary": "tiny promise", "significance": 3},
|
||||
],
|
||||
)
|
||||
with open_db(db) as conn:
|
||||
out = search_memories(conn, "bot_a", "host", "promise", k=3)
|
||||
assert len(out) == 3
|
||||
assert out[0]["pov_summary"] == "tiny promise"
|
||||
assert out[0]["significance"] == 3
|
||||
|
||||
|
||||
def test_search_newer_memory_ranks_above_older_when_same_match(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
_seed(
|
||||
db,
|
||||
memory_specs=[
|
||||
{"pov_summary": "BotA said hello"},
|
||||
{"pov_summary": "BotA said hello again"},
|
||||
],
|
||||
)
|
||||
with open_db(db) as conn:
|
||||
out = search_memories(conn, "bot_a", "host", "hello", k=2)
|
||||
assert len(out) == 2
|
||||
# Newer (higher id, "again") wins on the recency boost when the BM25
|
||||
# rank and significance are otherwise comparable.
|
||||
assert out[0]["pov_summary"] == "BotA said hello again"
|
||||
|
||||
|
||||
def test_search_respects_k_limit(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
_seed(
|
||||
db,
|
||||
memory_specs=[
|
||||
{"pov_summary": "the cat sat"},
|
||||
{"pov_summary": "the cat ran"},
|
||||
{"pov_summary": "the cat slept"},
|
||||
{"pov_summary": "the cat ate"},
|
||||
{"pov_summary": "the cat purred"},
|
||||
],
|
||||
)
|
||||
with open_db(db) as conn:
|
||||
out = search_memories(conn, "bot_a", "host", "cat", k=2)
|
||||
assert len(out) == 2
|
||||
|
||||
|
||||
def test_search_invalid_witness_role_raises(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
with pytest.raises(ValueError):
|
||||
search_memories(conn, "bot_a", "invalid_role", "anything", k=4)
|
||||
Reference in New Issue
Block a user