Files
chat/tests/test_cross_chat_search.py
T
2026-04-27 02:31:31 -04:00

156 lines
5.3 KiB
Python

"""T93 (Phase 4): cross-chat FTS5 search across all owners and chats.
Verifies that ``chat.services.cross_chat_search.search_all_memories``:
* surfaces matches across multiple owner_ids (the per-owner restriction
used by ``state.memory.search_memories`` is intentionally absent),
* applies no witness filter (admin/power-user surface),
* orders results by FTS5 BM25 rank (lower = stronger match, surfaced
first), and
* honours the ``k`` LIMIT and the empty-query fast-path.
"""
from __future__ import annotations
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.services.cross_chat_search import search_all_memories
import chat.state.memory # noqa: F401 (registers memory_written handler)
def _seed(db, *, memory_specs):
    """Migrate ``db`` and project one ``memory_written`` event per spec.

    Each spec dict must provide ``pov_summary``. ``owner_id``, ``chat_id``,
    ``witness_you``/``witness_host``/``witness_guest`` and ``significance``
    fall back to fixed defaults when omitted; ``source``, ``reliability``,
    ``pinned`` and ``auto_pinned`` are always the same constant values.
    """
    apply_migrations(db)
    with open_db(db) as conn:
        for memory in memory_specs:
            event_payload = dict(
                owner_id=memory.get("owner_id", "bot_a"),
                chat_id=memory.get("chat_id", "chat_bot_a"),
                pov_summary=memory["pov_summary"],
                witness_you=memory.get("witness_you", 1),
                witness_host=memory.get("witness_host", 1),
                witness_guest=memory.get("witness_guest", 0),
                source="direct",
                reliability=1.0,
                significance=memory.get("significance", 1),
                pinned=0,
                auto_pinned=0,
            )
            append_event(conn, kind="memory_written", payload=event_payload)
        # Run the projector so the appended events reach the queried tables.
        project(conn)
def test_search_all_memories_returns_matches_across_owners(tmp_path):
    """Cross-owner: a single query must surface memories from every owner.

    The per-owner ``owner_id = ?`` predicate that ``search_memories`` uses
    is intentionally absent here, so a "rabbit" memory under ``bot_a`` and
    one under ``bot_b`` should both come back from a single call, while
    the non-matching ``bot_a`` distractor must not.
    """
    db = tmp_path / "t.db"
    _seed(
        db,
        memory_specs=[
            {
                "owner_id": "bot_a",
                "chat_id": "chat_bot_a",
                "pov_summary": "the rabbit darted into the brambles",
            },
            {
                "owner_id": "bot_b",
                "chat_id": "chat_bot_b",
                "pov_summary": "a white rabbit watched from the hedge",
            },
            # Distractor: must not appear for "rabbit".
            {
                "owner_id": "bot_a",
                "chat_id": "chat_bot_a",
                "pov_summary": "the kettle whistled",
            },
        ],
    )
    with open_db(db) as conn:
        out = search_all_memories(conn, query="rabbit")
        owners = {row["owner_id"] for row in out}
        assert owners == {"bot_a", "bot_b"}
        assert len(out) == 2
        # Owner set + count alone would still pass if the bot_a distractor
        # were returned in place of bot_a's rabbit memory; pin the exact rows.
        assert {row["pov_summary"] for row in out} == {
            "the rabbit darted into the brambles",
            "a white rabbit watched from the hedge",
        }
        # Returned shape contract.
        for row in out:
            assert set(row.keys()) >= {
                "memory_id",
                "owner_id",
                "chat_id",
                "scene_id",
                "pov_summary",
                "significance",
                "ts",
                "fts_rank",
            }
def test_search_all_memories_orders_by_fts_rank(tmp_path):
    """Stronger BM25 match must come first (rank ASC = lower is better).

    The strong match is seeded *between* two weaker ones, so neither
    insertion order nor reverse-insertion (recency) order can put it first
    by accident; only ordering by ``fts_rank`` does.
    """
    db = tmp_path / "t.db"
    _seed(
        db,
        memory_specs=[
            # Single occurrence -> weaker BM25 score.
            {
                "owner_id": "bot_a",
                "chat_id": "chat_bot_a",
                "pov_summary": "a rabbit appeared",
            },
            # Triple occurrence in a row of the same token length
            # -> stronger BM25 score.
            {
                "owner_id": "bot_b",
                "chat_id": "chat_bot_b",
                "pov_summary": "rabbit rabbit rabbit",
            },
            # Single occurrence again, seeded last -> weaker BM25 score.
            {
                "owner_id": "bot_c",
                "chat_id": "chat_bot_c",
                "pov_summary": "one rabbit hopped",
            },
        ],
    )
    with open_db(db) as conn:
        out = search_all_memories(conn, query="rabbit", k=5)
        assert len(out) == 3
        # Stronger match first, and fts_rank monotonically non-decreasing
        # across the whole result (lower-is-better, so ASC).
        assert out[0]["pov_summary"] == "rabbit rabbit rabbit"
        ranks = [row["fts_rank"] for row in out]
        assert ranks == sorted(ranks)
def test_search_all_memories_respects_k_limit(tmp_path):
    """LIMIT ? must cap result count even when more matches exist.

    A second call with ``k`` above the match count checks that ``k`` is
    actually bound into the query, rather than a fixed limit that merely
    happens to equal 3.
    """
    db = tmp_path / "t.db"
    _seed(
        db,
        memory_specs=[
            {
                "owner_id": f"bot_{i}",
                "chat_id": f"chat_{i}",
                "pov_summary": f"rabbit sighting number {i}",
            }
            for i in range(10)
        ],
    )
    with open_db(db) as conn:
        out = search_all_memories(conn, query="rabbit", k=3)
        assert len(out) == 3
        # The capped rows must be three distinct memories, not repeats.
        assert len({row["memory_id"] for row in out}) == 3
        # A k larger than the number of matches returns every match.
        assert len(search_all_memories(conn, query="rabbit", k=20)) == 10
def test_search_all_memories_empty_query_returns_empty(tmp_path):
    """Blank queries (empty or whitespace-only) must short-circuit to []."""
    db = tmp_path / "t.db"
    seeded = [
        {
            "owner_id": "bot_a",
            "chat_id": "chat_bot_a",
            "pov_summary": "the rabbit darted into the brambles",
        },
    ]
    _seed(db, memory_specs=seeded)
    with open_db(db) as conn:
        # A matchable row exists, so an empty result proves the fast-path
        # rather than an empty table.
        for blank in ("", " "):
            assert search_all_memories(conn, query=blank) == []