diff --git a/chat/services/vector_search.py b/chat/services/vector_search.py new file mode 100644 index 0000000..60b179d --- /dev/null +++ b/chat/services/vector_search.py @@ -0,0 +1,79 @@ +"""Vector search service (T92, Phase 4). + +Pure-Python cosine similarity over the embeddings table. Phase 4 ships +this without sqlite-vec because the host Python build doesn't support +loadable extensions. For single-user scale (< few thousand memories +per owner), iterating in Python is sub-millisecond. + +Phase 4.5+ may swap to sqlite-vec when the host Python supports +enable_load_extension; the public API stays stable. +""" + +from __future__ import annotations +import math +from sqlite3 import Connection + +from chat.state.embeddings import list_embeddings_for_owner + + +_VALID_WITNESS_ROLES = {"you", "host", "guest"} + + +def _cosine_similarity(a: list[float], b: list[float]) -> float: + """Cosine similarity. Assumes both vectors are non-zero.""" + if len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) or 1.0 + norm_b = math.sqrt(sum(x * x for x in b)) or 1.0 + return dot / (norm_a * norm_b) + + +def vector_search( + conn: Connection, + *, + owner_id: str, + witness_role: str, # "you" | "host" | "guest" + query_vector: list[float], + k: int = 4, +) -> list[dict]: + """Return top-K memories by cosine similarity to query_vector, + witness-filtered for the viewer's POV. Returns rows with + {memory_id, pov_summary, significance, score} sorted by score + DESC. Empty list if no embeddings indexed for this owner. + """ + if witness_role not in _VALID_WITNESS_ROLES: + raise ValueError( + f"witness_role must be one of {_VALID_WITNESS_ROLES}, got {witness_role!r}" + ) + + rows = list_embeddings_for_owner(conn, owner_id) + if not rows: + return [] + + # Witness-filter by the requesting role. + witness_key = f"witness_{witness_role}" + filtered = [r for r in rows if r.get(witness_key) == 1] + if not filtered: + return [] + + scored: list[tuple[float, dict]] = [] + for row in filtered: + score = _cosine_similarity(query_vector, row["vector"]) + scored.append( + ( + score, + { + "memory_id": row["memory_id"], + "pov_summary": row["pov_summary"], + "significance": row["significance"], + "score": score, + }, + ) + ) + + scored.sort(key=lambda t: t[0], reverse=True) + return [item for _, item in scored[:k]] + + +__all__ = ["vector_search"] diff --git a/tests/test_vector_search.py b/tests/test_vector_search.py new file mode 100644 index 0000000..7801e80 --- /dev/null +++ b/tests/test_vector_search.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import pytest + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +import chat.state.memory # registers memory_written handler +import chat.state.embeddings # registers embedding handlers +from chat.services.vector_search import vector_search + + +def _base_memory(**overrides): + payload = { + "owner_id": "bot_a", + "chat_id": "chat_bot_a", + "scene_id": 1, + "pov_summary": "She laughed at his joke about owls.", + "witness_you": 1, + "witness_host": 1, + "witness_guest": 0, + "chat_clock_at": "2026-04-26T10:00:00", + "source": "direct", + "reliability": 1.0, + "significance": 1, + "pinned": 0, + "auto_pinned": 0, + } + payload.update(overrides) + return payload + + +def _one_hot(dim: int, idx: int) -> list[float]: + """Return a one-hot vector of length ``dim`` with 1.0 at ``idx``.""" + v = [0.0] * dim + v[idx] = 1.0 + return v + + +def _seed_memory_with_embedding( + conn, + *, + owner_id: str, + pov_summary: str, + vector: list[float], + significance: int = 1, + witness_you: int = 1, + witness_host: int = 1, + witness_guest: int = 0, + model: str = "test-model", +) -> int: + append_event( + conn, + kind="memory_written", + payload=_base_memory( + owner_id=owner_id, + pov_summary=pov_summary, + significance=significance, + witness_you=witness_you, + witness_host=witness_host, + witness_guest=witness_guest, + ), + ) + project(conn) + memory_id = conn.execute( + "SELECT id FROM memories WHERE pov_summary = ?", (pov_summary,) + ).fetchone()[0] + append_event( + conn, + kind="embedding_indexed", + payload={ + "memory_id": memory_id, + "vector": vector, + "model": model, + "dim": len(vector), + }, + ) + project(conn) + return memory_id + + +def test_vector_search_returns_nearest_neighbors(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + dim = 8 + ids = [] + for i in range(5): + mid = _seed_memory_with_embedding( + conn, + owner_id="bot_a", + pov_summary=f"Memory {i}.", + vector=_one_hot(dim, i), + ) + ids.append(mid) + + # Query close to memory index 3 (one-hot at position 3, plus tiny noise). + query = _one_hot(dim, 3) + query[2] = 0.01 + + results = vector_search( + conn, + owner_id="bot_a", + witness_role="you", + query_vector=query, + k=3, + ) + assert len(results) == 3 + # Top-1 must be memory at index 3. + assert results[0]["memory_id"] == ids[3] + assert results[0]["pov_summary"] == "Memory 3." + # Score for the near-perfect match should be very close to 1.0. + assert results[0]["score"] > 0.99 + # Results sorted by score DESC. + scores = [r["score"] for r in results] + assert scores == sorted(scores, reverse=True) + # Second place should be memory index 2 (the small noise component). + assert results[1]["memory_id"] == ids[2] + + +def test_vector_search_respects_witness_filter(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + dim = 4 + # Memory visible to you=1, host=1, guest=0. + _seed_memory_with_embedding( + conn, + owner_id="bot_a", + pov_summary="Restricted.", + vector=_one_hot(dim, 0), + witness_you=1, + witness_host=1, + witness_guest=0, + ) + + # Guest sees nothing. + guest_results = vector_search( + conn, + owner_id="bot_a", + witness_role="guest", + query_vector=_one_hot(dim, 0), + k=4, + ) + assert guest_results == [] + + # Host sees the memory. + host_results = vector_search( + conn, + owner_id="bot_a", + witness_role="host", + query_vector=_one_hot(dim, 0), + k=4, + ) + assert len(host_results) == 1 + assert host_results[0]["pov_summary"] == "Restricted." + + # You also see it. + you_results = vector_search( + conn, + owner_id="bot_a", + witness_role="you", + query_vector=_one_hot(dim, 0), + k=4, + ) + assert len(you_results) == 1 + + +def test_vector_search_respects_owner_filter(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + dim = 4 + _seed_memory_with_embedding( + conn, + owner_id="bot_a", + pov_summary="Owner A memory.", + vector=_one_hot(dim, 0), + ) + _seed_memory_with_embedding( + conn, + owner_id="bot_b", + pov_summary="Owner B memory.", + vector=_one_hot(dim, 0), + ) + + a_results = vector_search( + conn, + owner_id="bot_a", + witness_role="you", + query_vector=_one_hot(dim, 0), + k=10, + ) + assert len(a_results) == 1 + assert a_results[0]["pov_summary"] == "Owner A memory." + + b_results = vector_search( + conn, + owner_id="bot_b", + witness_role="you", + query_vector=_one_hot(dim, 0), + k=10, + ) + assert len(b_results) == 1 + assert b_results[0]["pov_summary"] == "Owner B memory." + + +def test_vector_search_invalid_witness_role_raises(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + with pytest.raises(ValueError, match="witness_role"): + vector_search( + conn, + owner_id="bot_a", + witness_role="invalid", + query_vector=[1.0, 0.0, 0.0], + k=4, + ) + + +def test_vector_search_empty_when_no_embeddings_indexed(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + # Seed a memory but don't index an embedding for it. + append_event( + conn, + kind="memory_written", + payload=_base_memory(owner_id="bot_a", pov_summary="No embedding here."), + ) + project(conn) + + results = vector_search( + conn, + owner_id="bot_a", + witness_role="you", + query_vector=[1.0, 0.0, 0.0, 0.0], + k=4, + ) + assert results == []