Files
chat/tests/test_embeddings_state.py
T

219 lines
6.6 KiB
Python

from __future__ import annotations
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
import chat.state.memory # registers memory_written handler
import chat.state.embeddings # registers embedding handlers
from chat.state.embeddings import get_embedding, list_embeddings_for_owner
def _base_memory(**overrides):
    """Return a fresh ``memory_written`` payload with ``overrides`` applied.

    The defaults describe one memory owned by ``bot_a`` in ``chat_bot_a``.
    A keyword argument replaces the default of the same name (keeping its
    position); keys that are not among the defaults are appended after them.
    Each call builds a new dict, so callers may mutate the result freely.
    """
    defaults = {
        "owner_id": "bot_a",
        "chat_id": "chat_bot_a",
        "scene_id": 1,
        "pov_summary": "She laughed at his joke about owls.",
        "witness_you": 1,
        "witness_host": 1,
        "witness_guest": 0,
        "chat_clock_at": "2026-04-26T10:00:00",
        "source": "direct",
        "reliability": 1.0,
        "significance": 1,
        "pinned": 0,
        "auto_pinned": 0,
    }
    return {**defaults, **overrides}
def _vec(n: int = 384, base: float = 0.1) -> list[float]:
    """Build an ``n``-element vector ``[base, base + 0.001, base + 0.002, ...]``.

    Every element is rounded to six decimal places so tests can compare
    whole vectors with exact equality.
    """
    step = 0.001
    return [round(base + idx * step, 6) for idx in range(n)]
def test_embedding_indexed_inserts_row(tmp_path):
    """An ``embedding_indexed`` event stores the vector and its metadata."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        mem_id = conn.execute("SELECT id FROM memories").fetchone()[0]

        expected_vector = _vec(384, base=0.1)
        index_payload = {
            "memory_id": mem_id,
            "vector": expected_vector,
            "model": "test-model",
            "dim": 384,
        }
        append_event(conn, kind="embedding_indexed", payload=index_payload)
        project(conn)

        stored = get_embedding(conn, mem_id)
        assert stored is not None
        assert stored["memory_id"] == mem_id
        assert stored["vector"] == expected_vector
        assert stored["model"] == "test-model"
        assert stored["dim"] == 384
        # The projector is expected to stamp when the row was indexed.
        assert stored["indexed_at"] is not None
def test_embedding_deindexed_removes_row(tmp_path):
    """An ``embedding_deindexed`` event deletes a previously indexed embedding."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        mem_id = conn.execute("SELECT id FROM memories").fetchone()[0]

        index_payload = {
            "memory_id": mem_id,
            "vector": _vec(),
            "model": "test-model",
            "dim": 384,
        }
        append_event(conn, kind="embedding_indexed", payload=index_payload)
        project(conn)
        # Precondition: the embedding exists before it is deindexed.
        assert get_embedding(conn, mem_id) is not None

        deindex_payload = {"memory_id": mem_id}
        append_event(conn, kind="embedding_deindexed", payload=deindex_payload)
        project(conn)
        assert get_embedding(conn, mem_id) is None
def test_embedding_indexed_replaces_existing(tmp_path):
    """Re-indexing a memory overwrites its embedding instead of adding a row."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        mem_id = conn.execute("SELECT id FROM memories").fetchone()[0]

        first_vector = _vec(384, base=0.1)
        second_vector = _vec(384, base=0.5)
        # Index twice; after each round the stored vector must be the latest.
        for current_vector in (first_vector, second_vector):
            append_event(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": mem_id,
                    "vector": current_vector,
                    "model": "test-model",
                    "dim": 384,
                },
            )
            project(conn)
            stored = get_embedding(conn, mem_id)
            assert stored is not None
            assert stored["vector"] == current_vector

        # Still exactly one row for this memory.
        (row_count,) = conn.execute(
            "SELECT COUNT(*) FROM embeddings WHERE memory_id = ?", (mem_id,)
        ).fetchone()[0:1]
        assert row_count == 1
def test_list_embeddings_for_owner_returns_joined_rows(tmp_path):
    """``list_embeddings_for_owner`` returns joined memory+embedding rows per owner.

    Each returned row must combine the memory's columns (summary, significance,
    witness flags) with that memory's own embedding (model, dim, vector), and
    rows belonging to other owners must not be included.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        # Two memories for bot_a, one for bot_b.
        append_event(
            conn,
            kind="memory_written",
            payload=_base_memory(
                owner_id="bot_a",
                pov_summary="Alpha memory.",
                significance=2,
            ),
        )
        append_event(
            conn,
            kind="memory_written",
            payload=_base_memory(
                owner_id="bot_a",
                pov_summary="Beta memory.",
                significance=3,
            ),
        )
        append_event(
            conn,
            kind="memory_written",
            payload=_base_memory(
                owner_id="bot_b",
                pov_summary="Gamma memory.",
                significance=1,
            ),
        )
        project(conn)
        rows = conn.execute(
            "SELECT id, owner_id FROM memories ORDER BY id"
        ).fetchall()
        # Index every memory with a distinct vector and remember which vector
        # went to which memory, so the joined rows can be checked against it.
        vector_by_id = {}
        for i, (mid, _owner) in enumerate(rows):
            vector = _vec(384, base=0.1 * (i + 1))
            vector_by_id[mid] = vector
            append_event(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": mid,
                    "vector": vector,
                    "model": "test-model",
                    "dim": 384,
                },
            )
        project(conn)

        a_rows = list_embeddings_for_owner(conn, "bot_a")
        assert len(a_rows) == 2
        summaries = {r["pov_summary"] for r in a_rows}
        assert summaries == {"Alpha memory.", "Beta memory."}
        sigs = {r["significance"] for r in a_rows}
        assert sigs == {2, 3}
        for r in a_rows:
            assert r["model"] == "test-model"
            assert r["dim"] == 384
            assert isinstance(r["vector"], list)
            assert len(r["vector"]) == 384
            assert r["witness_you"] == 1
            assert r["witness_host"] == 1
            assert r["witness_guest"] == 0
        # bot_a's rows must carry exactly the embeddings indexed for bot_a's
        # memories; a join that picked up bot_b's vector would fail here.
        expected_a_vectors = sorted(
            vector_by_id[mid] for mid, owner in rows if owner == "bot_a"
        )
        assert sorted(r["vector"] for r in a_rows) == expected_a_vectors

        b_rows = list_embeddings_for_owner(conn, "bot_b")
        assert len(b_rows) == 1
        assert b_rows[0]["pov_summary"] == "Gamma memory."
        (b_id,) = [mid for mid, owner in rows if owner == "bot_b"]
        assert b_rows[0]["vector"] == vector_by_id[b_id]
def test_get_embedding_returns_none_when_missing(tmp_path):
    """``get_embedding`` yields None for an id with no stored embedding."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Freshly migrated database: nothing has been indexed yet.
        lookup = get_embedding(conn, 999)
        assert lookup is None