merge: T92 pure-Python cosine vector search service
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
"""Vector search service (T92, Phase 4).
|
||||
|
||||
Pure-Python cosine similarity over the embeddings table. Phase 4 ships
|
||||
this without sqlite-vec because the host Python build doesn't support
|
||||
loadable extensions. For single-user scale (< few thousand memories
|
||||
per owner), iterating in Python is sub-millisecond.
|
||||
|
||||
Phase 4.5+ may swap to sqlite-vec when the host Python supports
|
||||
enable_load_extension; the public API stays stable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import math
|
||||
from sqlite3 import Connection
|
||||
|
||||
from chat.state.embeddings import list_embeddings_for_owner
|
||||
|
||||
|
||||
_VALID_WITNESS_ROLES = {"you", "host", "guest"}
|
||||
|
||||
|
||||
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
"""Cosine similarity. Assumes both vectors are non-zero."""
|
||||
if len(a) != len(b):
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a)) or 1.0
|
||||
norm_b = math.sqrt(sum(x * x for x in b)) or 1.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
def vector_search(
|
||||
conn: Connection,
|
||||
*,
|
||||
owner_id: str,
|
||||
witness_role: str, # "you" | "host" | "guest"
|
||||
query_vector: list[float],
|
||||
k: int = 4,
|
||||
) -> list[dict]:
|
||||
"""Return top-K memories by cosine similarity to query_vector,
|
||||
witness-filtered for the viewer's POV. Returns rows with
|
||||
{memory_id, pov_summary, significance, score} sorted by score
|
||||
DESC. Empty list if no embeddings indexed for this owner.
|
||||
"""
|
||||
if witness_role not in _VALID_WITNESS_ROLES:
|
||||
raise ValueError(
|
||||
f"witness_role must be one of {_VALID_WITNESS_ROLES}, got {witness_role!r}"
|
||||
)
|
||||
|
||||
rows = list_embeddings_for_owner(conn, owner_id)
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Witness-filter by the requesting role.
|
||||
witness_key = f"witness_{witness_role}"
|
||||
filtered = [r for r in rows if r.get(witness_key) == 1]
|
||||
if not filtered:
|
||||
return []
|
||||
|
||||
scored: list[tuple[float, dict]] = []
|
||||
for row in filtered:
|
||||
score = _cosine_similarity(query_vector, row["vector"])
|
||||
scored.append(
|
||||
(
|
||||
score,
|
||||
{
|
||||
"memory_id": row["memory_id"],
|
||||
"pov_summary": row["pov_summary"],
|
||||
"significance": row["significance"],
|
||||
"score": score,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
scored.sort(key=lambda t: t[0], reverse=True)
|
||||
return [item for _, item in scored[:k]]
|
||||
|
||||
|
||||
__all__ = ["vector_search"]
|
||||
@@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
import chat.state.memory # registers memory_written handler
|
||||
import chat.state.embeddings # registers embedding handlers
|
||||
from chat.services.vector_search import vector_search
|
||||
|
||||
|
||||
def _base_memory(**overrides):
|
||||
payload = {
|
||||
"owner_id": "bot_a",
|
||||
"chat_id": "chat_bot_a",
|
||||
"scene_id": 1,
|
||||
"pov_summary": "She laughed at his joke about owls.",
|
||||
"witness_you": 1,
|
||||
"witness_host": 1,
|
||||
"witness_guest": 0,
|
||||
"chat_clock_at": "2026-04-26T10:00:00",
|
||||
"source": "direct",
|
||||
"reliability": 1.0,
|
||||
"significance": 1,
|
||||
"pinned": 0,
|
||||
"auto_pinned": 0,
|
||||
}
|
||||
payload.update(overrides)
|
||||
return payload
|
||||
|
||||
|
||||
def _one_hot(dim: int, idx: int) -> list[float]:
|
||||
"""Return a one-hot vector of length ``dim`` with 1.0 at ``idx``."""
|
||||
v = [0.0] * dim
|
||||
v[idx] = 1.0
|
||||
return v
|
||||
|
||||
|
||||
def _seed_memory_with_embedding(
|
||||
conn,
|
||||
*,
|
||||
owner_id: str,
|
||||
pov_summary: str,
|
||||
vector: list[float],
|
||||
significance: int = 1,
|
||||
witness_you: int = 1,
|
||||
witness_host: int = 1,
|
||||
witness_guest: int = 0,
|
||||
model: str = "test-model",
|
||||
) -> int:
|
||||
append_event(
|
||||
conn,
|
||||
kind="memory_written",
|
||||
payload=_base_memory(
|
||||
owner_id=owner_id,
|
||||
pov_summary=pov_summary,
|
||||
significance=significance,
|
||||
witness_you=witness_you,
|
||||
witness_host=witness_host,
|
||||
witness_guest=witness_guest,
|
||||
),
|
||||
)
|
||||
project(conn)
|
||||
memory_id = conn.execute(
|
||||
"SELECT id FROM memories WHERE pov_summary = ?", (pov_summary,)
|
||||
).fetchone()[0]
|
||||
append_event(
|
||||
conn,
|
||||
kind="embedding_indexed",
|
||||
payload={
|
||||
"memory_id": memory_id,
|
||||
"vector": vector,
|
||||
"model": model,
|
||||
"dim": len(vector),
|
||||
},
|
||||
)
|
||||
project(conn)
|
||||
return memory_id
|
||||
|
||||
|
||||
def test_vector_search_returns_nearest_neighbors(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
dim = 8
|
||||
ids = []
|
||||
for i in range(5):
|
||||
mid = _seed_memory_with_embedding(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
pov_summary=f"Memory {i}.",
|
||||
vector=_one_hot(dim, i),
|
||||
)
|
||||
ids.append(mid)
|
||||
|
||||
# Query close to memory index 3 (one-hot at position 3, plus tiny noise).
|
||||
query = _one_hot(dim, 3)
|
||||
query[2] = 0.01
|
||||
|
||||
results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="you",
|
||||
query_vector=query,
|
||||
k=3,
|
||||
)
|
||||
assert len(results) == 3
|
||||
# Top-1 must be memory at index 3.
|
||||
assert results[0]["memory_id"] == ids[3]
|
||||
assert results[0]["pov_summary"] == "Memory 3."
|
||||
# Score for the near-perfect match should be very close to 1.0.
|
||||
assert results[0]["score"] > 0.99
|
||||
# Results sorted by score DESC.
|
||||
scores = [r["score"] for r in results]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
# Second place should be memory index 2 (the small noise component).
|
||||
assert results[1]["memory_id"] == ids[2]
|
||||
|
||||
|
||||
def test_vector_search_respects_witness_filter(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
dim = 4
|
||||
# Memory visible to you=1, host=1, guest=0.
|
||||
_seed_memory_with_embedding(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
pov_summary="Restricted.",
|
||||
vector=_one_hot(dim, 0),
|
||||
witness_you=1,
|
||||
witness_host=1,
|
||||
witness_guest=0,
|
||||
)
|
||||
|
||||
# Guest sees nothing.
|
||||
guest_results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="guest",
|
||||
query_vector=_one_hot(dim, 0),
|
||||
k=4,
|
||||
)
|
||||
assert guest_results == []
|
||||
|
||||
# Host sees the memory.
|
||||
host_results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="host",
|
||||
query_vector=_one_hot(dim, 0),
|
||||
k=4,
|
||||
)
|
||||
assert len(host_results) == 1
|
||||
assert host_results[0]["pov_summary"] == "Restricted."
|
||||
|
||||
# You also see it.
|
||||
you_results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="you",
|
||||
query_vector=_one_hot(dim, 0),
|
||||
k=4,
|
||||
)
|
||||
assert len(you_results) == 1
|
||||
|
||||
|
||||
def test_vector_search_respects_owner_filter(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
dim = 4
|
||||
_seed_memory_with_embedding(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
pov_summary="Owner A memory.",
|
||||
vector=_one_hot(dim, 0),
|
||||
)
|
||||
_seed_memory_with_embedding(
|
||||
conn,
|
||||
owner_id="bot_b",
|
||||
pov_summary="Owner B memory.",
|
||||
vector=_one_hot(dim, 0),
|
||||
)
|
||||
|
||||
a_results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="you",
|
||||
query_vector=_one_hot(dim, 0),
|
||||
k=10,
|
||||
)
|
||||
assert len(a_results) == 1
|
||||
assert a_results[0]["pov_summary"] == "Owner A memory."
|
||||
|
||||
b_results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_b",
|
||||
witness_role="you",
|
||||
query_vector=_one_hot(dim, 0),
|
||||
k=10,
|
||||
)
|
||||
assert len(b_results) == 1
|
||||
assert b_results[0]["pov_summary"] == "Owner B memory."
|
||||
|
||||
|
||||
def test_vector_search_invalid_witness_role_raises(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
with pytest.raises(ValueError, match="witness_role"):
|
||||
vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="invalid",
|
||||
query_vector=[1.0, 0.0, 0.0],
|
||||
k=4,
|
||||
)
|
||||
|
||||
|
||||
def test_vector_search_empty_when_no_embeddings_indexed(tmp_path):
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
with open_db(db) as conn:
|
||||
# Seed a memory but don't index an embedding for it.
|
||||
append_event(
|
||||
conn,
|
||||
kind="memory_written",
|
||||
payload=_base_memory(owner_id="bot_a", pov_summary="No embedding here."),
|
||||
)
|
||||
project(conn)
|
||||
|
||||
results = vector_search(
|
||||
conn,
|
||||
owner_id="bot_a",
|
||||
witness_role="you",
|
||||
query_vector=[1.0, 0.0, 0.0, 0.0],
|
||||
k=4,
|
||||
)
|
||||
assert results == []
|
||||
Reference in New Issue
Block a user