243 lines
6.6 KiB
Python
243 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import append_event
|
|
from chat.eventlog.projector import project
|
|
import chat.state.memory # registers memory_written handler
|
|
import chat.state.embeddings # registers embedding handlers
|
|
from chat.services.vector_search import vector_search
|
|
|
|
|
|
def _base_memory(**overrides):
|
|
payload = {
|
|
"owner_id": "bot_a",
|
|
"chat_id": "chat_bot_a",
|
|
"scene_id": 1,
|
|
"pov_summary": "She laughed at his joke about owls.",
|
|
"witness_you": 1,
|
|
"witness_host": 1,
|
|
"witness_guest": 0,
|
|
"chat_clock_at": "2026-04-26T10:00:00",
|
|
"source": "direct",
|
|
"reliability": 1.0,
|
|
"significance": 1,
|
|
"pinned": 0,
|
|
"auto_pinned": 0,
|
|
}
|
|
payload.update(overrides)
|
|
return payload
|
|
|
|
|
|
def _one_hot(dim: int, idx: int) -> list[float]:
|
|
"""Return a one-hot vector of length ``dim`` with 1.0 at ``idx``."""
|
|
v = [0.0] * dim
|
|
v[idx] = 1.0
|
|
return v
|
|
|
|
|
|
def _seed_memory_with_embedding(
    conn,
    *,
    owner_id: str,
    pov_summary: str,
    vector: list[float],
    significance: int = 1,
    witness_you: int = 1,
    witness_host: int = 1,
    witness_guest: int = 0,
    model: str = "test-model",
) -> int:
    """Create a memory and index an embedding for it via the event log.

    Appends a ``memory_written`` event (built by ``_base_memory`` with the
    given overrides) and projects it. It then looks up the resulting row in
    ``memories``, and finally appends and projects an ``embedding_indexed``
    event that carries ``vector`` for that row.

    Returns:
        The ``id`` of the newly projected memory row.

    Raises:
        AssertionError: if projecting the ``memory_written`` event did not
            produce a ``memories`` row with the given ``pov_summary``.
    """
    append_event(
        conn,
        kind="memory_written",
        payload=_base_memory(
            owner_id=owner_id,
            pov_summary=pov_summary,
            significance=significance,
            witness_you=witness_you,
            witness_host=witness_host,
            witness_guest=witness_guest,
        ),
    )
    project(conn)

    # Several memories may share a pov_summary (e.g. across owners). Without
    # ORDER BY, SQLite may return any of them, so take the newest match.
    # Assumes ids grow with insertion order — TODO confirm against schema.
    row = conn.execute(
        "SELECT id FROM memories WHERE pov_summary = ? ORDER BY id DESC LIMIT 1",
        (pov_summary,),
    ).fetchone()
    if row is None:
        raise AssertionError(
            f"memory_written projection produced no row for {pov_summary!r}"
        )
    memory_id = row[0]

    append_event(
        conn,
        kind="embedding_indexed",
        payload={
            "memory_id": memory_id,
            "vector": vector,
            "model": model,
            "dim": len(vector),
        },
    )
    project(conn)
    return memory_id
|
|
|
|
|
|
def test_vector_search_returns_nearest_neighbors(tmp_path):
    """Top-k search ranks one-hot embeddings by similarity to the query."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        dim = 8
        # Five memories; memory i is embedded as the i-th basis vector.
        ids = [
            _seed_memory_with_embedding(
                conn,
                owner_id="bot_a",
                pov_summary=f"Memory {i}.",
                vector=_one_hot(dim, i),
            )
            for i in range(5)
        ]

        # Aim at basis vector 3, with a small component along axis 2 so that
        # memory 2 becomes the unambiguous runner-up.
        query = _one_hot(dim, 3)
        query[2] = 0.01

        results = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="you",
            query_vector=query,
            k=3,
        )

        assert len(results) == 3
        best, runner_up = results[0], results[1]
        assert best["memory_id"] == ids[3]
        assert best["pov_summary"] == "Memory 3."
        # A near-exact match should score almost 1.0.
        assert best["score"] > 0.99
        # Scores must come back in descending order.
        scores = [hit["score"] for hit in results]
        assert scores == sorted(scores, reverse=True)
        # The small noise component makes memory 2 the second-best hit.
        assert runner_up["memory_id"] == ids[2]
|
|
|
|
|
|
def test_vector_search_respects_witness_filter(tmp_path):
    """Only witness roles flagged on a memory can retrieve it."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        dim = 4

        # Witnessed by "you" and "host", but not by "guest".
        _seed_memory_with_embedding(
            conn,
            owner_id="bot_a",
            pov_summary="Restricted.",
            vector=_one_hot(dim, 0),
            witness_you=1,
            witness_host=1,
            witness_guest=0,
        )

        def search_as(role):
            # A fresh query vector per call, exactly like separate literal calls.
            return vector_search(
                conn,
                owner_id="bot_a",
                witness_role=role,
                query_vector=_one_hot(dim, 0),
                k=4,
            )

        # The guest never witnessed it, so nothing comes back.
        assert search_as("guest") == []

        # The host did witness it.
        host_hits = search_as("host")
        assert len(host_hits) == 1
        assert host_hits[0]["pov_summary"] == "Restricted."

        # So did "you".
        assert len(search_as("you")) == 1
|
|
|
|
|
|
def test_vector_search_respects_owner_filter(tmp_path):
    """Search results are scoped to the requested owner's memories."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        dim = 4
        expected = {
            "bot_a": "Owner A memory.",
            "bot_b": "Owner B memory.",
        }

        # Both owners get an identical embedding, so only the owner filter
        # can tell their memories apart.
        for owner, summary in expected.items():
            _seed_memory_with_embedding(
                conn,
                owner_id=owner,
                pov_summary=summary,
                vector=_one_hot(dim, 0),
            )

        for owner, summary in expected.items():
            hits = vector_search(
                conn,
                owner_id=owner,
                witness_role="you",
                query_vector=_one_hot(dim, 0),
                k=10,
            )
            assert len(hits) == 1
            assert hits[0]["pov_summary"] == summary
|
|
|
|
|
|
def test_vector_search_invalid_witness_role_raises(tmp_path):
    """An unknown witness role raises ValueError mentioning ``witness_role``."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    # pytest.raises is the inner context, so it exits before the connection.
    with open_db(db_path) as conn, pytest.raises(ValueError, match="witness_role"):
        vector_search(
            conn,
            owner_id="bot_a",
            witness_role="invalid",
            query_vector=[1.0, 0.0, 0.0],
            k=4,
        )
|
|
|
|
|
|
def test_vector_search_empty_when_no_embeddings_indexed(tmp_path):
    """A memory with no indexed embedding is not returned by vector search."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Project a memory, but deliberately skip the embedding_indexed event.
        payload = _base_memory(owner_id="bot_a", pov_summary="No embedding here.")
        append_event(conn, kind="memory_written", payload=payload)
        project(conn)

        hits = vector_search(
            conn,
            owner_id="bot_a",
            witness_role="you",
            query_vector=[1.0, 0.0, 0.0, 0.0],
            k=4,
        )
        assert hits == []
|