162 lines
5.6 KiB
Python
162 lines
5.6 KiB
Python
"""Task 23: FTS5 memory retrieval with witness filter and ranking boosts.
|
|
|
|
Verifies that ``search_memories`` applies recency + significance boosts on top
|
|
of the FTS5 BM25 rank so that newer / more significant memories surface above
|
|
older / less significant ones for the same match. Existing T8 behaviour
|
|
(witness filter, k limit, FTS match, role validation) is exercised again here
|
|
to lock the contract.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import pytest
|
|
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import append_event
|
|
from chat.eventlog.projector import project
|
|
from chat.state.memory import search_memories
|
|
import chat.state.memory # noqa: F401 (registers memory_written handler)
|
|
|
|
|
|
# Default payload for a seeded ``memory_written`` event. Key order mirrors the
# payload these tests have always emitted. ``pov_summary`` is listed only to fix
# its position in the dict -- it has no usable default and must come from the
# spec.
_MEMORY_PAYLOAD_DEFAULTS = {
    "owner_id": "bot_a",
    "chat_id": "chat_bot_a",
    "pov_summary": None,
    "witness_you": 1,
    "witness_host": 1,
    "witness_guest": 0,
    "source": "direct",
    "reliability": 1.0,
    "significance": 1,
    "pinned": 0,
    "auto_pinned": 0,
}


def build_memory_payload(spec):
    """Return a ``memory_written`` payload: the defaults above, overridden by *spec*.

    Any payload field may be overridden. ``pov_summary`` is required.

    Raises:
        KeyError: if *spec* has no ``pov_summary``.
        ValueError: if *spec* contains a key that is not a payload field, so a
            misspelled override (e.g. ``"signficance"``) fails loudly instead
            of silently falling back to the default.
    """
    unknown = spec.keys() - _MEMORY_PAYLOAD_DEFAULTS.keys()
    if unknown:
        raise ValueError(f"unknown memory_spec keys: {sorted(unknown)}")
    payload = {
        key: spec.get(key, default)
        for key, default in _MEMORY_PAYLOAD_DEFAULTS.items()
    }
    # Required field: index (rather than .get) so a missing value raises KeyError.
    # The key already exists, so its position in the dict is unchanged.
    payload["pov_summary"] = spec["pov_summary"]
    return payload


def _seed(db, *, memory_specs):
    """Apply migrations to *db*, append one memory_written event per spec, project.

    memory_specs: list of dicts. Required key: ``pov_summary``. Any other
    payload field (see ``_MEMORY_PAYLOAD_DEFAULTS``) may be given to override
    its default; unknown keys raise ``ValueError``.
    """
    apply_migrations(db)
    with open_db(db) as conn:
        for spec in memory_specs:
            append_event(
                conn, kind="memory_written", payload=build_memory_payload(spec)
            )
        # Project once, after every event has been appended.
        project(conn)
|
|
|
|
|
|
def test_search_filters_by_witness_bit(tmp_path):
    """A memory is returned only for witness roles whose bit is set on it."""
    db = tmp_path / "t.db"
    _seed(
        db,
        memory_specs=[
            {
                "pov_summary": "BotA mentioned her sister",
                "witness_you": 1,
                "witness_host": 1,
                "witness_guest": 0,
            },
        ],
    )
    with open_db(db) as conn:
        # Witnessed by host -> returned, and it is the row we seeded.
        out = search_memories(conn, "bot_a", "host", "sister", k=4)
        assert len(out) == 1
        assert out[0]["pov_summary"] == "BotA mentioned her sister"
        # NOT witnessed by guest -> filtered out.
        out = search_memories(conn, "bot_a", "guest", "sister", k=4)
        assert out == []
|
|
|
|
|
|
def test_search_higher_significance_ranks_above_lower(tmp_path):
    """A significance-3 match is returned ahead of default-significance matches."""
    db = tmp_path / "t.db"
    _seed(
        db,
        memory_specs=[
            # All three rows match "promise"; the third carries significance 3
            # and should outrank the first two, which carry the default of 1.
            # NOTE(review): the significance-3 row is also the newest insert,
            # so the recency boost alone could put it first. Seeding it first
            # would isolate the significance effect -- confirm the relative
            # weights in search_memories before reordering.
            {"pov_summary": "small promise"},
            {"pov_summary": "huge promise"},
            {"pov_summary": "tiny promise", "significance": 3},
        ],
    )
    with open_db(db) as conn:
        out = search_memories(conn, "bot_a", "host", "promise", k=3)
        assert len(out) == 3
        assert out[0]["pov_summary"] == "tiny promise"
        assert out[0]["significance"] == 3
        # The remaining two slots hold the default-significance rows.
        assert sorted(row["pov_summary"] for row in out[1:]) == [
            "huge promise",
            "small promise",
        ]
        assert all(row["significance"] == 1 for row in out[1:])
|
|
|
|
|
|
def test_search_newer_memory_ranks_above_older_when_same_match(tmp_path):
    """Of two default-significance "hello" memories, the later insert comes first."""
    older = "BotA said hello"
    newer = "BotA said hello again"
    db = tmp_path / "t.db"
    _seed(db, memory_specs=[{"pov_summary": older}, {"pov_summary": newer}])
    with open_db(db) as conn:
        results = search_memories(conn, "bot_a", "host", "hello", k=2)
        assert len(results) == 2
        # The later insert (higher id) should win on the recency boost, since
        # the BM25 rank and significance are otherwise comparable.
        assert results[0]["pov_summary"] == newer
|
|
|
|
|
|
def test_search_respects_k_limit(tmp_path):
    """Five rows match "cat"; a search with k=2 must return exactly two."""
    summaries = (
        "the cat sat",
        "the cat ran",
        "the cat slept",
        "the cat ate",
        "the cat purred",
    )
    db = tmp_path / "t.db"
    _seed(db, memory_specs=[{"pov_summary": text} for text in summaries])
    with open_db(db) as conn:
        hits = search_memories(conn, "bot_a", "host", "cat", k=2)
        assert len(hits) == 2
|
|
|
|
|
|
def test_search_invalid_witness_role_raises(tmp_path):
    """An unrecognised witness role is rejected with ValueError.

    The database only needs its schema; no memories are seeded.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn, pytest.raises(ValueError):
        search_memories(conn, "bot_a", "invalid_role", "anything", k=4)
|
|
|
|
|
|
def test_higher_significance_outranks_equal_rank(tmp_path):
    """T57: significance multiplier biases the SQL ORDER BY.

    Both seeded memories share the exact same FTS-matching text, so their BM25
    ranks are (effectively) tied. The significance bias in the SQL ORDER BY
    must then put the higher-significance row first.
    """
    text = "she swore an oath"
    db = tmp_path / "t.db"
    # Same pov_summary for both rows -> same FTS BM25 rank; only significance
    # (0, then 3) differs.
    # NOTE(review): the significance-3 row is seeded second, so it is also the
    # newer row; recency alone could produce the same order. Consider seeding
    # it first to isolate the significance bias.
    _seed(
        db,
        memory_specs=[
            {"pov_summary": text, "significance": level} for level in (0, 3)
        ],
    )
    with open_db(db) as conn:
        out = search_memories(conn, "bot_a", "host", "oath", k=5)
        assert len(out) == 2
        # Higher significance wins despite the tied FTS rank.
        assert [row["significance"] for row in out] == [3, 0]
|
|
|
|
|
|
def test_significance_bias_is_constant_module_level():
    """T57: pin ``SIGNIFICANCE_RANK_BIAS`` as a tunable module-level numeric.

    The bias must be a real, finite, non-negative number.

    ``bool`` is rejected even though it subclasses ``int``, because a True/False
    flag is not a meaningful tunable weight. NaN and infinity are rejected
    because they would make any arithmetic ordering built on the bias
    meaningless.
    """
    import math

    from chat.state.memory import SIGNIFICANCE_RANK_BIAS

    assert isinstance(SIGNIFICANCE_RANK_BIAS, (int, float))
    assert not isinstance(SIGNIFICANCE_RANK_BIAS, bool)
    assert math.isfinite(SIGNIFICANCE_RANK_BIAS)
    # Must be non-negative -- a negative bias would invert the desired
    # "higher significance ranks higher" semantics.
    assert SIGNIFICANCE_RANK_BIAS >= 0
|