Files
chat/tests/test_vector_search.py
T
2026-04-27 02:31:06 -04:00

243 lines
6.6 KiB
Python

from __future__ import annotations
import pytest
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
import chat.state.memory # registers memory_written handler
import chat.state.embeddings # registers embedding handlers
from chat.services.vector_search import vector_search
def _base_memory(**overrides):
payload = {
"owner_id": "bot_a",
"chat_id": "chat_bot_a",
"scene_id": 1,
"pov_summary": "She laughed at his joke about owls.",
"witness_you": 1,
"witness_host": 1,
"witness_guest": 0,
"chat_clock_at": "2026-04-26T10:00:00",
"source": "direct",
"reliability": 1.0,
"significance": 1,
"pinned": 0,
"auto_pinned": 0,
}
payload.update(overrides)
return payload
def _one_hot(dim: int, idx: int) -> list[float]:
"""Return a one-hot vector of length ``dim`` with 1.0 at ``idx``."""
v = [0.0] * dim
v[idx] = 1.0
return v
def _seed_memory_with_embedding(
    conn,
    *,
    owner_id: str,
    pov_summary: str,
    vector: list[float],
    significance: int = 1,
    witness_you: int = 1,
    witness_host: int = 1,
    witness_guest: int = 0,
    model: str = "test-model",
) -> int:
    """Seed one memory and its embedding through the event log.

    Appends a ``memory_written`` event and projects it, looks up the id of
    the resulting ``memories`` row, then appends an ``embedding_indexed``
    event for that id and projects again.

    Args:
        conn: Open database connection passed to ``append_event``/``project``.
        owner_id: Owner recorded on the memory payload.
        pov_summary: Summary text; also used to find the row that was written.
        vector: Embedding values; its length is recorded as ``dim``.
        significance: Significance value for the memory payload.
        witness_you: ``witness_you`` flag for the memory payload.
        witness_host: ``witness_host`` flag for the memory payload.
        witness_guest: ``witness_guest`` flag for the memory payload.
        model: Model name recorded on the embedding event.

    Returns:
        The id of the memory row that was created.

    Raises:
        AssertionError: If no ``memories`` row with ``pov_summary`` exists
            after projection.
    """
    append_event(
        conn,
        kind="memory_written",
        payload=_base_memory(
            owner_id=owner_id,
            pov_summary=pov_summary,
            significance=significance,
            witness_you=witness_you,
            witness_host=witness_host,
            witness_guest=witness_guest,
        ),
    )
    project(conn)

    # Take the newest matching row so a summary reused by an earlier memory
    # (e.g. under another owner) does not shadow the one written just now.
    # Assumes ids are assigned in increasing order -- TODO confirm in schema.
    row = conn.execute(
        "SELECT id FROM memories WHERE pov_summary = ? ORDER BY id DESC LIMIT 1",
        (pov_summary,),
    ).fetchone()
    if row is None:
        # Fail with a readable message instead of a TypeError on ``None[0]``.
        raise AssertionError(
            f"no memories row found for pov_summary={pov_summary!r} "
            "after projecting memory_written"
        )
    memory_id = row[0]

    append_event(
        conn,
        kind="embedding_indexed",
        payload={
            "memory_id": memory_id,
            "vector": vector,
            "model": model,
            "dim": len(vector),
        },
    )
    project(conn)
    return memory_id
def test_vector_search_returns_nearest_neighbors(tmp_path):
    """The closest stored vectors come back first, ordered by score.

    Five memories are seeded with distinct one-hot embeddings; the query is
    the one-hot for index 3 with a small component on index 2, so index 3
    must rank first and index 2 second.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        dim = 8
        # ids[i] is the memory whose embedding is the one-hot at position i.
        ids = [
            _seed_memory_with_embedding(
                conn,
                owner_id="bot_a",
                pov_summary=f"Memory {i}.",
                vector=_one_hot(dim, i),
            )
            for i in range(5)
        ]

        # Query close to memory index 3 (one-hot at position 3, plus tiny noise).
        query = _one_hot(dim, 3)
        query[2] = 0.01

        results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="you",
            query_vector=query,
            k=3,
        )

        assert len(results) == 3

        # Top-1 must be memory at index 3.
        best = results[0]
        assert best["memory_id"] == ids[3]
        assert best["pov_summary"] == "Memory 3."
        # Score for the near-perfect match should be very close to 1.0.
        assert best["score"] > 0.99

        # Results sorted by score DESC.
        scores = [r["score"] for r in results]
        assert scores == sorted(scores, reverse=True)

        # Second place should be memory index 2 (the small noise component).
        assert results[1]["memory_id"] == ids[2]
def test_vector_search_respects_witness_filter(tmp_path):
    """A memory is only returned for roles whose witness flag is set.

    One memory is seeded with ``witness_you=1``, ``witness_host=1`` and
    ``witness_guest=0``; searching as ``guest`` must return nothing while
    ``host`` and ``you`` must both get that memory back.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        dim = 4
        # Memory visible to you=1, host=1, guest=0.
        _seed_memory_with_embedding(
            conn,
            owner_id="bot_a",
            pov_summary="Restricted.",
            vector=_one_hot(dim, 0),
            witness_you=1,
            witness_host=1,
            witness_guest=0,
        )

        # Guest sees nothing.
        guest_results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="guest",
            query_vector=_one_hot(dim, 0),
            k=4,
        )
        assert guest_results == []

        # Host sees the memory.
        host_results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="host",
            query_vector=_one_hot(dim, 0),
            k=4,
        )
        assert len(host_results) == 1
        assert host_results[0]["pov_summary"] == "Restricted."

        # You also see it; check the content as well as the count, the same
        # way the host branch does.
        you_results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="you",
            query_vector=_one_hot(dim, 0),
            k=4,
        )
        assert len(you_results) == 1
        assert you_results[0]["pov_summary"] == "Restricted."
def test_vector_search_respects_owner_filter(tmp_path):
    """Each owner only gets back its own memories.

    Two owners each have one memory with the same embedding; a search for
    either owner must return exactly that owner's memory, even though ``k``
    is large enough to fit both.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        dim = 4
        # Identical vectors, so only the owner filter can tell them apart.
        _seed_memory_with_embedding(
            conn,
            owner_id="bot_a",
            pov_summary="Owner A memory.",
            vector=_one_hot(dim, 0),
        )
        _seed_memory_with_embedding(
            conn,
            owner_id="bot_b",
            pov_summary="Owner B memory.",
            vector=_one_hot(dim, 0),
        )

        a_results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="you",
            query_vector=_one_hot(dim, 0),
            k=10,
        )
        assert len(a_results) == 1
        assert a_results[0]["pov_summary"] == "Owner A memory."

        b_results = vector_search(
            conn,
            owner_id="bot_b",
            witness_role="you",
            query_vector=_one_hot(dim, 0),
            k=10,
        )
        assert len(b_results) == 1
        assert b_results[0]["pov_summary"] == "Owner B memory."
def test_vector_search_invalid_witness_role_raises(tmp_path):
    """An unrecognised ``witness_role`` is rejected with ``ValueError``.

    The error message is expected to mention ``witness_role``.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        with pytest.raises(ValueError, match="witness_role"):
            vector_search(
                conn,
                owner_id="bot_a",
                witness_role="invalid",
                query_vector=[1.0, 0.0, 0.0],
                k=4,
            )
def test_vector_search_empty_when_no_embeddings_indexed(tmp_path):
    """A memory without an embedding never shows up in search results.

    Only a ``memory_written`` event is appended (no ``embedding_indexed``),
    so the search must return an empty list.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        # Seed a memory but don't index an embedding for it.
        append_event(
            conn,
            kind="memory_written",
            payload=_base_memory(
                owner_id="bot_a",
                pov_summary="No embedding here.",
            ),
        )
        project(conn)

        results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="you",
            query_vector=[1.0, 0.0, 0.0, 0.0],
            k=4,
        )
        assert results == []