feat: pure-Python cosine vector search service (T92)
This commit is contained in:
@@ -0,0 +1,79 @@
|
|||||||
|
"""Vector search service (T92, Phase 4).
|
||||||
|
|
||||||
|
Pure-Python cosine similarity over the embeddings table. Phase 4 ships
|
||||||
|
this without sqlite-vec because the host Python build doesn't support
|
||||||
|
loadable extensions. For single-user scale (< few thousand memories
|
||||||
|
per owner), iterating in Python is sub-millisecond.
|
||||||
|
|
||||||
|
Phase 4.5+ may swap to sqlite-vec when the host Python supports
|
||||||
|
enable_load_extension; the public API stays stable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import math
|
||||||
|
from sqlite3 import Connection
|
||||||
|
|
||||||
|
from chat.state.embeddings import list_embeddings_for_owner
|
||||||
|
|
||||||
|
|
||||||
|
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||||
|
"""Cosine similarity. Assumes both vectors are non-zero."""
|
||||||
|
if len(a) != len(b):
|
||||||
|
return 0.0
|
||||||
|
dot = sum(x * y for x, y in zip(a, b))
|
||||||
|
norm_a = math.sqrt(sum(x * x for x in a)) or 1.0
|
||||||
|
norm_b = math.sqrt(sum(x * x for x in b)) or 1.0
|
||||||
|
return dot / (norm_a * norm_b)
|
||||||
|
|
||||||
|
|
||||||
|
def vector_search(
|
||||||
|
conn: Connection,
|
||||||
|
*,
|
||||||
|
owner_id: str,
|
||||||
|
witness_role: str, # "you" | "host" | "guest"
|
||||||
|
query_vector: list[float],
|
||||||
|
k: int = 4,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Return top-K memories by cosine similarity to query_vector,
|
||||||
|
witness-filtered for the viewer's POV. Returns rows with
|
||||||
|
{memory_id, pov_summary, significance, score} sorted by score
|
||||||
|
DESC. Empty list if no embeddings indexed for this owner.
|
||||||
|
"""
|
||||||
|
if witness_role not in _VALID_WITNESS_ROLES:
|
||||||
|
raise ValueError(
|
||||||
|
f"witness_role must be one of {_VALID_WITNESS_ROLES}, got {witness_role!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = list_embeddings_for_owner(conn, owner_id)
|
||||||
|
if not rows:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Witness-filter by the requesting role.
|
||||||
|
witness_key = f"witness_{witness_role}"
|
||||||
|
filtered = [r for r in rows if r.get(witness_key) == 1]
|
||||||
|
if not filtered:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scored: list[tuple[float, dict]] = []
|
||||||
|
for row in filtered:
|
||||||
|
score = _cosine_similarity(query_vector, row["vector"])
|
||||||
|
scored.append(
|
||||||
|
(
|
||||||
|
score,
|
||||||
|
{
|
||||||
|
"memory_id": row["memory_id"],
|
||||||
|
"pov_summary": row["pov_summary"],
|
||||||
|
"significance": row["significance"],
|
||||||
|
"score": score,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
scored.sort(key=lambda t: t[0], reverse=True)
|
||||||
|
return [item for _, item in scored[:k]]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["vector_search"]
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.db.connection import open_db
|
||||||
|
from chat.db.migrate import apply_migrations
|
||||||
|
from chat.eventlog.log import append_event
|
||||||
|
from chat.eventlog.projector import project
|
||||||
|
import chat.state.memory # registers memory_written handler
|
||||||
|
import chat.state.embeddings # registers embedding handlers
|
||||||
|
from chat.services.vector_search import vector_search
|
||||||
|
|
||||||
|
|
||||||
|
def _base_memory(**overrides):
|
||||||
|
payload = {
|
||||||
|
"owner_id": "bot_a",
|
||||||
|
"chat_id": "chat_bot_a",
|
||||||
|
"scene_id": 1,
|
||||||
|
"pov_summary": "She laughed at his joke about owls.",
|
||||||
|
"witness_you": 1,
|
||||||
|
"witness_host": 1,
|
||||||
|
"witness_guest": 0,
|
||||||
|
"chat_clock_at": "2026-04-26T10:00:00",
|
||||||
|
"source": "direct",
|
||||||
|
"reliability": 1.0,
|
||||||
|
"significance": 1,
|
||||||
|
"pinned": 0,
|
||||||
|
"auto_pinned": 0,
|
||||||
|
}
|
||||||
|
payload.update(overrides)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _one_hot(dim: int, idx: int) -> list[float]:
|
||||||
|
"""Return a one-hot vector of length ``dim`` with 1.0 at ``idx``."""
|
||||||
|
v = [0.0] * dim
|
||||||
|
v[idx] = 1.0
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_memory_with_embedding(
|
||||||
|
conn,
|
||||||
|
*,
|
||||||
|
owner_id: str,
|
||||||
|
pov_summary: str,
|
||||||
|
vector: list[float],
|
||||||
|
significance: int = 1,
|
||||||
|
witness_you: int = 1,
|
||||||
|
witness_host: int = 1,
|
||||||
|
witness_guest: int = 0,
|
||||||
|
model: str = "test-model",
|
||||||
|
) -> int:
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="memory_written",
|
||||||
|
payload=_base_memory(
|
||||||
|
owner_id=owner_id,
|
||||||
|
pov_summary=pov_summary,
|
||||||
|
significance=significance,
|
||||||
|
witness_you=witness_you,
|
||||||
|
witness_host=witness_host,
|
||||||
|
witness_guest=witness_guest,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
project(conn)
|
||||||
|
memory_id = conn.execute(
|
||||||
|
"SELECT id FROM memories WHERE pov_summary = ?", (pov_summary,)
|
||||||
|
).fetchone()[0]
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="embedding_indexed",
|
||||||
|
payload={
|
||||||
|
"memory_id": memory_id,
|
||||||
|
"vector": vector,
|
||||||
|
"model": model,
|
||||||
|
"dim": len(vector),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
project(conn)
|
||||||
|
return memory_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_search_returns_nearest_neighbors(tmp_path):
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
apply_migrations(db)
|
||||||
|
with open_db(db) as conn:
|
||||||
|
dim = 8
|
||||||
|
ids = []
|
||||||
|
for i in range(5):
|
||||||
|
mid = _seed_memory_with_embedding(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
pov_summary=f"Memory {i}.",
|
||||||
|
vector=_one_hot(dim, i),
|
||||||
|
)
|
||||||
|
ids.append(mid)
|
||||||
|
|
||||||
|
# Query close to memory index 3 (one-hot at position 3, plus tiny noise).
|
||||||
|
query = _one_hot(dim, 3)
|
||||||
|
query[2] = 0.01
|
||||||
|
|
||||||
|
results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="you",
|
||||||
|
query_vector=query,
|
||||||
|
k=3,
|
||||||
|
)
|
||||||
|
assert len(results) == 3
|
||||||
|
# Top-1 must be memory at index 3.
|
||||||
|
assert results[0]["memory_id"] == ids[3]
|
||||||
|
assert results[0]["pov_summary"] == "Memory 3."
|
||||||
|
# Score for the near-perfect match should be very close to 1.0.
|
||||||
|
assert results[0]["score"] > 0.99
|
||||||
|
# Results sorted by score DESC.
|
||||||
|
scores = [r["score"] for r in results]
|
||||||
|
assert scores == sorted(scores, reverse=True)
|
||||||
|
# Second place should be memory index 2 (the small noise component).
|
||||||
|
assert results[1]["memory_id"] == ids[2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_search_respects_witness_filter(tmp_path):
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
apply_migrations(db)
|
||||||
|
with open_db(db) as conn:
|
||||||
|
dim = 4
|
||||||
|
# Memory visible to you=1, host=1, guest=0.
|
||||||
|
_seed_memory_with_embedding(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
pov_summary="Restricted.",
|
||||||
|
vector=_one_hot(dim, 0),
|
||||||
|
witness_you=1,
|
||||||
|
witness_host=1,
|
||||||
|
witness_guest=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Guest sees nothing.
|
||||||
|
guest_results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="guest",
|
||||||
|
query_vector=_one_hot(dim, 0),
|
||||||
|
k=4,
|
||||||
|
)
|
||||||
|
assert guest_results == []
|
||||||
|
|
||||||
|
# Host sees the memory.
|
||||||
|
host_results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="host",
|
||||||
|
query_vector=_one_hot(dim, 0),
|
||||||
|
k=4,
|
||||||
|
)
|
||||||
|
assert len(host_results) == 1
|
||||||
|
assert host_results[0]["pov_summary"] == "Restricted."
|
||||||
|
|
||||||
|
# You also see it.
|
||||||
|
you_results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="you",
|
||||||
|
query_vector=_one_hot(dim, 0),
|
||||||
|
k=4,
|
||||||
|
)
|
||||||
|
assert len(you_results) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_search_respects_owner_filter(tmp_path):
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
apply_migrations(db)
|
||||||
|
with open_db(db) as conn:
|
||||||
|
dim = 4
|
||||||
|
_seed_memory_with_embedding(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
pov_summary="Owner A memory.",
|
||||||
|
vector=_one_hot(dim, 0),
|
||||||
|
)
|
||||||
|
_seed_memory_with_embedding(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_b",
|
||||||
|
pov_summary="Owner B memory.",
|
||||||
|
vector=_one_hot(dim, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
a_results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="you",
|
||||||
|
query_vector=_one_hot(dim, 0),
|
||||||
|
k=10,
|
||||||
|
)
|
||||||
|
assert len(a_results) == 1
|
||||||
|
assert a_results[0]["pov_summary"] == "Owner A memory."
|
||||||
|
|
||||||
|
b_results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_b",
|
||||||
|
witness_role="you",
|
||||||
|
query_vector=_one_hot(dim, 0),
|
||||||
|
k=10,
|
||||||
|
)
|
||||||
|
assert len(b_results) == 1
|
||||||
|
assert b_results[0]["pov_summary"] == "Owner B memory."
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_search_invalid_witness_role_raises(tmp_path):
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
apply_migrations(db)
|
||||||
|
with open_db(db) as conn:
|
||||||
|
with pytest.raises(ValueError, match="witness_role"):
|
||||||
|
vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="invalid",
|
||||||
|
query_vector=[1.0, 0.0, 0.0],
|
||||||
|
k=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_search_empty_when_no_embeddings_indexed(tmp_path):
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
apply_migrations(db)
|
||||||
|
with open_db(db) as conn:
|
||||||
|
# Seed a memory but don't index an embedding for it.
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="memory_written",
|
||||||
|
payload=_base_memory(owner_id="bot_a", pov_summary="No embedding here."),
|
||||||
|
)
|
||||||
|
project(conn)
|
||||||
|
|
||||||
|
results = vector_search(
|
||||||
|
conn,
|
||||||
|
owner_id="bot_a",
|
||||||
|
witness_role="you",
|
||||||
|
query_vector=[1.0, 0.0, 0.0, 0.0],
|
||||||
|
k=4,
|
||||||
|
)
|
||||||
|
assert results == []
|
||||||
Reference in New Issue
Block a user