diff --git a/chat/state/memory.py b/chat/state/memory.py
index f0420ca..0426067 100644
--- a/chat/state/memory.py
+++ b/chat/state/memory.py
@@ -87,6 +87,14 @@ def get_pinned(conn: Connection, owner_id: str) -> list[dict]:
     return [dict(zip(cols, row)) for row in rows]
 
 
+# Composite-score weights used by ``search_memories`` (T23, §8 retrieval).
+# FTS5 BM25 ``rank`` is *more negative* for better matches, so subtracting a
+# positive boost from it drives stronger candidates further down (i.e. earlier
+# in an ascending sort). Hardcoded for v1 — tunable in a later pass.
+_SIGNIFICANCE_WEIGHT = 0.3
+_RECENCY_WEIGHT = 0.5
+
+
 def search_memories(
     conn: Connection,
     owner_id: str,
@@ -97,16 +105,32 @@ def search_memories(
     """FTS5 search over pov_summary, scoped by owner and witness role.
 
     witness_role must be one of {"you", "host", "guest"} per the witness flags
-    on each memory row. Returns up to k rows ordered by FTS5 bm25 rank.
+    on each memory row. Returns up to ``k`` rows ranked by a composite score
+    that combines the FTS5 BM25 rank with two boosts (§8 retrieval rules):
+
+    * **significance boost** — ``0.3 * significance`` (0..3 per §11.1).
+    * **recency boost** — ``0.5 * (id / max_id)``, using the row id as a
+      monotonic recency proxy. Newer memories therefore tilt above older ones
+      when the BM25 rank and significance are otherwise tied.
+
+    BM25 returns negative scores (lower = better). Both boosts are subtracted
+    so that stronger candidates yield smaller composite scores; the result is
+    sorted ascending and truncated to ``k``. The unmodified ``fts_rank`` and a
+    debug-friendly ``composite_score`` are kept on each returned dict.
     """
     if witness_role not in _VALID_WITNESS_ROLES:
         raise ValueError(
             f"witness_role must be one of {sorted(_VALID_WITNESS_ROLES)}, "
             f"got {witness_role!r}"
         )
+    if not query.strip():
+        return []
     witness_col = f"witness_{witness_role}"
     cols = [c[1] for c in conn.execute("PRAGMA table_info(memories)").fetchall()]
     select_list = ", ".join(f"m.{c}" for c in cols)
+    # Over-fetch from FTS so the Python-side re-rank has room to reorder
+    # results that BM25 alone would have demoted past the top-k boundary.
+    over_fetch = max(k * 4, 20)
     sql = (
         f"SELECT {select_list}, memories_fts.rank AS fts_rank "
         "FROM memories_fts "
@@ -116,10 +140,27 @@ def search_memories(
         "ORDER BY memories_fts.rank "
         "LIMIT ?"
     )
-    cur = conn.execute(sql, (owner_id, query, k))
+    cur = conn.execute(sql, (owner_id, query, over_fetch))
     rows = cur.fetchall()
-    out: list[dict] = []
+    if not rows:
+        return []
+
+    # Recency normalises against the current max id for this owner so the
+    # boost magnitude is bounded regardless of dataset size.
+    max_id_row = conn.execute(
+        "SELECT MAX(id) FROM memories WHERE owner_id = ?", (owner_id,)
+    ).fetchone()
+    max_id = max_id_row[0] if max_id_row and max_id_row[0] else 1
+
     result_cols = cols + ["fts_rank"]
+    enriched: list[dict] = []
     for row in rows:
-        out.append(dict(zip(result_cols, row)))
-    return out
+        d = dict(zip(result_cols, row))
+        fts_rank = d.get("fts_rank") or 0.0
+        sig_boost = _SIGNIFICANCE_WEIGHT * (d.get("significance") or 0)
+        recency_boost = _RECENCY_WEIGHT * ((d.get("id") or 0) / max_id)
+        d["composite_score"] = fts_rank - sig_boost - recency_boost
+        enriched.append(d)
+
+    enriched.sort(key=lambda x: x["composite_score"])
+    return enriched[:k]
diff --git a/tests/test_memory_search.py b/tests/test_memory_search.py
new file mode 100644
index 0000000..dad7e84
--- /dev/null
+++ b/tests/test_memory_search.py
@@ -0,0 +1,127 @@
+"""Task 23: FTS5 memory retrieval with witness filter and ranking boosts.
+
+Verifies that ``search_memories`` applies recency + significance boosts on top
+of the FTS5 BM25 rank so that newer / more significant memories surface above
+older / less significant ones for the same match. Existing T8 behaviour
+(witness filter, k limit, FTS match, role validation) is exercised again here
+to lock the contract.
+"""
+
+from __future__ import annotations
+import pytest
+
+from chat.db.connection import open_db
+from chat.db.migrate import apply_migrations
+from chat.eventlog.log import append_event
+from chat.eventlog.projector import project
+from chat.state.memory import search_memories
+import chat.state.memory  # noqa: F401 (registers memory_written handler)
+
+
+def _seed(db, *, memory_specs):
+    """Apply migrations + project a list of memory_written events.
+
+    memory_specs: list of dicts. Required key: ``pov_summary``. Optional keys
+    override the defaults below.
+    """
+    apply_migrations(db)
+    with open_db(db) as conn:
+        for spec in memory_specs:
+            payload = {
+                "owner_id": spec.get("owner_id", "bot_a"),
+                "chat_id": spec.get("chat_id", "chat_bot_a"),
+                "pov_summary": spec["pov_summary"],
+                "witness_you": spec.get("witness_you", 1),
+                "witness_host": spec.get("witness_host", 1),
+                "witness_guest": spec.get("witness_guest", 0),
+                "source": "direct",
+                "reliability": 1.0,
+                "significance": spec.get("significance", 1),
+                "pinned": 0,
+                "auto_pinned": 0,
+            }
+            append_event(conn, kind="memory_written", payload=payload)
+        project(conn)
+
+
+def test_search_filters_by_witness_bit(tmp_path):
+    db = tmp_path / "t.db"
+    _seed(
+        db,
+        memory_specs=[
+            {
+                "pov_summary": "BotA mentioned her sister",
+                "witness_you": 1,
+                "witness_host": 1,
+                "witness_guest": 0,
+            },
+        ],
+    )
+    with open_db(db) as conn:
+        # Witnessed by host -> returned.
+        out = search_memories(conn, "bot_a", "host", "sister", k=4)
+        assert len(out) == 1
+        # NOT witnessed by guest -> filtered out.
+        out = search_memories(conn, "bot_a", "guest", "sister", k=4)
+        assert out == []
+
+
+def test_search_higher_significance_ranks_above_lower(tmp_path):
+    db = tmp_path / "t.db"
+    _seed(
+        db,
+        memory_specs=[
+            # All three match "promise". The significance-3 row is seeded first
+            # (lowest id, weakest recency boost) so only significance can lift it.
+            {"pov_summary": "tiny promise", "significance": 3},
+            {"pov_summary": "small promise"},
+            {"pov_summary": "huge promise"},
+        ],
+    )
+    with open_db(db) as conn:
+        out = search_memories(conn, "bot_a", "host", "promise", k=3)
+        assert len(out) == 3
+        assert out[0]["pov_summary"] == "tiny promise"
+        assert out[0]["significance"] == 3
+
+
+def test_search_newer_memory_ranks_above_older_when_same_match(tmp_path):
+    db = tmp_path / "t.db"
+    _seed(
+        db,
+        memory_specs=[
+            {"pov_summary": "BotA said hello"},
+            {"pov_summary": "BotA said hello again"},
+        ],
+    )
+    with open_db(db) as conn:
+        out = search_memories(conn, "bot_a", "host", "hello", k=2)
+        assert len(out) == 2
+        # Newer (higher id, "again") wins on the recency boost when the BM25
+        # rank and significance are otherwise comparable.
+        assert out[0]["pov_summary"] == "BotA said hello again"
+
+
+def test_search_respects_k_limit(tmp_path):
+    db = tmp_path / "t.db"
+    _seed(
+        db,
+        memory_specs=[
+            {"pov_summary": "the cat sat"},
+            {"pov_summary": "the cat ran"},
+            {"pov_summary": "the cat slept"},
+            {"pov_summary": "the cat ate"},
+            {"pov_summary": "the cat purred"},
+        ],
+    )
+    with open_db(db) as conn:
+        out = search_memories(conn, "bot_a", "host", "cat", k=2)
+        assert len(out) == 2
+
+
+def test_search_invalid_witness_role_raises(tmp_path):
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    with open_db(db) as conn:
+        with pytest.raises(ValueError):
+            search_memories(conn, "bot_a", "invalid_role", "anything", k=4)