219 lines
6.6 KiB
Python
219 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import append_event
|
|
from chat.eventlog.projector import project
|
|
import chat.state.memory # registers memory_written handler
|
|
import chat.state.embeddings # registers embedding handlers
|
|
from chat.state.embeddings import get_embedding, list_embeddings_for_owner
|
|
|
|
|
|
def _base_memory(**overrides):
|
|
payload = {
|
|
"owner_id": "bot_a",
|
|
"chat_id": "chat_bot_a",
|
|
"scene_id": 1,
|
|
"pov_summary": "She laughed at his joke about owls.",
|
|
"witness_you": 1,
|
|
"witness_host": 1,
|
|
"witness_guest": 0,
|
|
"chat_clock_at": "2026-04-26T10:00:00",
|
|
"source": "direct",
|
|
"reliability": 1.0,
|
|
"significance": 1,
|
|
"pinned": 0,
|
|
"auto_pinned": 0,
|
|
}
|
|
payload.update(overrides)
|
|
return payload
|
|
|
|
|
|
def _vec(n: int = 384, base: float = 0.1) -> list[float]:
|
|
"""Return a length-n float vector with predictable values for assertions."""
|
|
return [round(base + i * 0.001, 6) for i in range(n)]
|
|
|
|
|
|
def test_embedding_indexed_inserts_row(tmp_path):
    """Projecting an ``embedding_indexed`` event makes the embedding retrievable."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Seed a single memory so there is an id to attach the embedding to.
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        memory_row = conn.execute("SELECT id FROM memories").fetchone()
        memory_id = memory_row[0]

        expected_vector = _vec(384, base=0.1)
        index_payload = {
            "memory_id": memory_id,
            "vector": expected_vector,
            "model": "test-model",
            "dim": 384,
        }
        append_event(conn, kind="embedding_indexed", payload=index_payload)
        project(conn)

        stored = get_embedding(conn, memory_id)
        assert stored is not None
        # Every field from the event payload must come back unchanged.
        assert stored["memory_id"] == memory_id
        assert stored["vector"] == expected_vector
        assert stored["model"] == "test-model"
        assert stored["dim"] == 384
        assert stored["indexed_at"] is not None
|
|
|
|
|
|
def test_embedding_deindexed_removes_row(tmp_path):
    """An ``embedding_deindexed`` event makes ``get_embedding`` return ``None`` again."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        memory_row = conn.execute("SELECT id FROM memories").fetchone()
        memory_id = memory_row[0]

        # First index the memory and confirm the embedding is present.
        index_payload = {
            "memory_id": memory_id,
            "vector": _vec(),
            "model": "test-model",
            "dim": 384,
        }
        append_event(conn, kind="embedding_indexed", payload=index_payload)
        project(conn)
        before = get_embedding(conn, memory_id)
        assert before is not None

        # Then de-index it and confirm the lookup no longer finds anything.
        deindex_payload = {"memory_id": memory_id}
        append_event(conn, kind="embedding_deindexed", payload=deindex_payload)
        project(conn)
        after = get_embedding(conn, memory_id)
        assert after is None
|
|
|
|
|
|
def test_embedding_indexed_replaces_existing(tmp_path):
    """Indexing the same memory twice keeps one row holding the latest vector."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        memory_row = conn.execute("SELECT id FROM memories").fetchone()
        memory_id = memory_row[0]

        vec_a = _vec(384, base=0.1)
        vec_b = _vec(384, base=0.5)

        # Index with vec_a, then again with vec_b; after each projection
        # the stored vector must be the one most recently indexed.
        for expected_vector in (vec_a, vec_b):
            append_event(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": memory_id,
                    "vector": expected_vector,
                    "model": "test-model",
                    "dim": 384,
                },
            )
            project(conn)
            current = get_embedding(conn, memory_id)
            assert current is not None
            assert current["vector"] == expected_vector

        # The second index must not have added a second row for this memory.
        count_row = conn.execute(
            "SELECT COUNT(*) FROM embeddings WHERE memory_id = ?", (memory_id,)
        ).fetchone()
        assert count_row[0] == 1
|
|
|
|
|
|
def test_list_embeddings_for_owner_returns_joined_rows(tmp_path):
    """``list_embeddings_for_owner`` returns only the owner's rows, with memory fields."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Two memories belong to bot_a and one to bot_b.
        memory_specs = (
            ("bot_a", "Alpha memory.", 2),
            ("bot_a", "Beta memory.", 3),
            ("bot_b", "Gamma memory.", 1),
        )
        for owner, summary, significance in memory_specs:
            append_event(
                conn,
                kind="memory_written",
                payload=_base_memory(
                    owner_id=owner,
                    pov_summary=summary,
                    significance=significance,
                ),
            )
        project(conn)

        memory_rows = conn.execute(
            "SELECT id, owner_id FROM memories ORDER BY id"
        ).fetchall()
        # Give every memory its own embedding, each with a different base value.
        for ordinal, (mid, _owner) in enumerate(memory_rows, start=1):
            append_event(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": mid,
                    "vector": _vec(384, base=0.1 * ordinal),
                    "model": "test-model",
                    "dim": 384,
                },
            )
        project(conn)

        a_rows = list_embeddings_for_owner(conn, "bot_a")
        assert len(a_rows) == 2
        found_summaries = {row["pov_summary"] for row in a_rows}
        assert found_summaries == {"Alpha memory.", "Beta memory."}
        found_significances = {row["significance"] for row in a_rows}
        assert found_significances == {2, 3}
        for row in a_rows:
            # Embedding columns.
            assert row["model"] == "test-model"
            assert row["dim"] == 384
            assert isinstance(row["vector"], list)
            assert len(row["vector"]) == 384
            # Witness flags carried over from the memory payload defaults.
            assert row["witness_you"] == 1
            assert row["witness_host"] == 1
            assert row["witness_guest"] == 0

        b_rows = list_embeddings_for_owner(conn, "bot_b")
        assert len(b_rows) == 1
        only_b = b_rows[0]
        assert only_b["pov_summary"] == "Gamma memory."
|
|
|
|
|
|
def test_get_embedding_returns_none_when_missing(tmp_path):
    """``get_embedding`` returns ``None`` when no events have been appended."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Nothing was written, so looking up memory id 999 must find no row.
        missing = get_embedding(conn, 999)
        assert missing is None
|