feat: embeddings table + projector handlers (pure-Python cosine, T88)
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
-- Embeddings stored as JSON arrays (pure-Python cosine at query time).
-- Phase 4.5+ may swap to sqlite-vec when the host Python supports
-- loadable extensions; the schema is intentionally simple to make that
-- migration straightforward.
--
-- One row per memory: memory_id is both the primary key and a foreign
-- key into memories(id), so a memory can have at most one embedding.
CREATE TABLE embeddings (
    memory_id INTEGER PRIMARY KEY,
    vector_json TEXT NOT NULL, -- JSON array of floats, length = dim
    model TEXT NOT NULL,
    dim INTEGER NOT NULL,
    -- Defaults to the insertion time as reported by SQLite's datetime('now').
    indexed_at TEXT NOT NULL DEFAULT (datetime('now')),
    FOREIGN KEY (memory_id) REFERENCES memories(id)
);

-- NOTE(review): presumably supports filtering or re-indexing by model name;
-- confirm against the queries that actually use it.
CREATE INDEX embeddings_model_idx ON embeddings(model);
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Embeddings projector + readers (T88, Phase 4).
|
||||
|
||||
Embeddings are stored as JSON-serialized float arrays in a regular
|
||||
SQLite table. Cosine similarity is computed in Python at query time
|
||||
(see chat/services/vector_search.py / T92). This deliberately avoids
|
||||
the sqlite-vec extension dependency — the host Python build doesn't
|
||||
support enable_load_extension. Phase 4.5+ may revisit if memory counts
|
||||
grow beyond pure-Python feasibility (~few thousand per query).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
from sqlite3 import Connection
|
||||
|
||||
from chat.eventlog.projector import on
|
||||
from chat.eventlog.log import Event
|
||||
|
||||
|
||||
@on("embedding_indexed")
def _apply_embedding_indexed(conn: Connection, e: Event) -> None:
    """Upsert the embedding row described by an ``embedding_indexed`` event.

    The payload must carry ``memory_id``, ``vector`` and ``model``; ``dim``
    is optional and falls back to the length of ``vector`` when missing or
    falsy. The vector is stored as a JSON array in ``vector_json``.

    ``INSERT OR REPLACE`` keyed on ``memory_id`` means projecting the same
    event again (or a later re-index) overwrites the previous row instead
    of adding a second one. ``indexed_at`` is stamped by SQLite's
    ``datetime('now')`` each time the row is written.
    """
    payload = e.payload
    raw_vector = payload["vector"]

    # Build the bound parameters in column order.
    memory_id = int(payload["memory_id"])
    vector_json = json.dumps(list(raw_vector))
    model = payload["model"]
    dim = int(payload.get("dim") or len(raw_vector))

    statement = (
        "INSERT OR REPLACE INTO embeddings "
        "(memory_id, vector_json, model, dim, indexed_at) "
        "VALUES (?, ?, ?, ?, datetime('now'))"
    )
    conn.execute(statement, (memory_id, vector_json, model, dim))
|
||||
|
||||
|
||||
@on("embedding_deindexed")
def _apply_embedding_deindexed(conn: Connection, e: Event) -> None:
    """Delete the embedding row named by an ``embedding_deindexed`` event.

    Reads ``memory_id`` from the event payload and removes the matching
    row from ``embeddings``. Deleting a row that does not exist is a no-op.
    Used by the reset cascade.
    """
    target_id = int(e.payload["memory_id"])
    conn.execute("DELETE FROM embeddings WHERE memory_id = ?", (target_id,))
|
||||
|
||||
|
||||
def get_embedding(conn: Connection, memory_id: int) -> dict | None:
    """Look up the stored embedding for a single memory.

    Returns a dict with keys ``memory_id``, ``vector`` (the decoded JSON
    array), ``model``, ``dim`` and ``indexed_at``, or ``None`` when no row
    exists for ``memory_id``.
    """
    cursor = conn.execute(
        "SELECT memory_id, vector_json, model, dim, indexed_at "
        "FROM embeddings WHERE memory_id = ?",
        (memory_id,),
    )
    record = cursor.fetchone()
    if record:
        stored_id, vector_json, model, dim, indexed_at = record
        return {
            "memory_id": stored_id,
            "vector": json.loads(vector_json),
            "model": model,
            "dim": dim,
            "indexed_at": indexed_at,
        }
    return None
|
||||
|
||||
|
||||
def list_embeddings_for_owner(conn: Connection, owner_id: str) -> list[dict]:
    """Return all embeddings for memories owned by ``owner_id``.

    Used by vector search at query time (T92). The join carries the
    fields the cosine ranker needs to assemble result rows without a
    second round-trip: the POV summary text, significance, and witness
    flags. The ``memories`` table has no separate ``text`` column —
    ``pov_summary`` is the canonical narrative text per
    ``chat/services/memory_write.py``.

    Each returned dict has the keys ``memory_id``, ``vector`` (decoded
    from ``vector_json``), ``model``, ``dim``, ``pov_summary``,
    ``significance``, ``witness_you``, ``witness_host`` and
    ``witness_guest``. An owner with no indexed memories yields ``[]``.

    Rows are ordered by ``memory_id``. Without an explicit ORDER BY the
    row order of a SELECT is undefined, which would make any tie-breaking
    done by a caller depend on the query plan.
    """
    rows = conn.execute(
        "SELECT e.memory_id, e.vector_json, e.model, e.dim, "
        "       m.pov_summary, m.significance, "
        "       m.witness_you, m.witness_host, m.witness_guest "
        "FROM embeddings e "
        "JOIN memories m ON m.id = e.memory_id "
        "WHERE m.owner_id = ? "
        "ORDER BY e.memory_id",
        (owner_id,),
    ).fetchall()
    return [
        {
            "memory_id": r[0],
            "vector": json.loads(r[1]),
            "model": r[2],
            "dim": r[3],
            "pov_summary": r[4],
            "significance": r[5],
            "witness_you": r[6],
            "witness_host": r[7],
            "witness_guest": r[8],
        }
        for r in rows
    ]
|
||||
|
||||
|
||||
# Public read API; the ``_apply_*`` projector handlers are registered through
# the ``@on`` decorator when this module is imported and are not exported.
__all__ = ["get_embedding", "list_embeddings_for_owner"]
|
||||
@@ -0,0 +1,218 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
import chat.state.memory # registers memory_written handler
|
||||
import chat.state.embeddings # registers embedding handlers
|
||||
from chat.state.embeddings import get_embedding, list_embeddings_for_owner
|
||||
|
||||
|
||||
def _base_memory(**overrides):
    """Build a ``memory_written`` payload for tests.

    Starts from a fixed set of defaults (owner ``bot_a``, witnessed by
    "you" and the host, significance 1, unpinned) and lets keyword
    arguments replace or extend individual fields. A fresh dict is
    returned on every call.
    """
    defaults = dict(
        owner_id="bot_a",
        chat_id="chat_bot_a",
        scene_id=1,
        pov_summary="She laughed at his joke about owls.",
        witness_you=1,
        witness_host=1,
        witness_guest=0,
        chat_clock_at="2026-04-26T10:00:00",
        source="direct",
        reliability=1.0,
        significance=1,
        pinned=0,
        auto_pinned=0,
    )
    return {**defaults, **overrides}
|
||||
|
||||
|
||||
def _vec(n: int = 384, base: float = 0.1) -> list[float]:
    """Build a deterministic test vector of ``n`` floats.

    Element ``i`` is ``base + i * 0.001`` rounded to six decimal places,
    so different ``base`` values give easily distinguishable vectors.
    """
    values: list[float] = []
    for offset in range(n):
        values.append(round(base + offset * 0.001, 6))
    return values
|
||||
|
||||
|
||||
def test_embedding_indexed_inserts_row(tmp_path):
    """Projecting ``embedding_indexed`` stores a row readable via ``get_embedding``."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Create the parent memory row the embedding refers to.
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        memory_row = conn.execute("SELECT id FROM memories").fetchone()
        memory_id = memory_row[0]

        expected_vector = _vec(384, base=0.1)
        indexed_payload = {
            "memory_id": memory_id,
            "vector": expected_vector,
            "model": "test-model",
            "dim": 384,
        }
        append_event(conn, kind="embedding_indexed", payload=indexed_payload)
        project(conn)

        stored = get_embedding(conn, memory_id)
        assert stored is not None
        assert stored["memory_id"] == memory_id
        assert stored["vector"] == expected_vector
        assert stored["model"] == "test-model"
        assert stored["dim"] == 384
        assert stored["indexed_at"] is not None
|
||||
|
||||
|
||||
def test_embedding_deindexed_removes_row(tmp_path):
    """An ``embedding_deindexed`` event deletes a previously indexed row."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        memory_row = conn.execute("SELECT id FROM memories").fetchone()
        memory_id = memory_row[0]

        # Index first so there is something to remove.
        indexed_payload = {
            "memory_id": memory_id,
            "vector": _vec(),
            "model": "test-model",
            "dim": 384,
        }
        append_event(conn, kind="embedding_indexed", payload=indexed_payload)
        project(conn)
        before = get_embedding(conn, memory_id)
        assert before is not None

        deindexed_payload = {"memory_id": memory_id}
        append_event(conn, kind="embedding_deindexed", payload=deindexed_payload)
        project(conn)
        after = get_embedding(conn, memory_id)
        assert after is None
|
||||
|
||||
|
||||
def test_embedding_indexed_replaces_existing(tmp_path):
    """Re-indexing a memory overwrites its vector and leaves a single row."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        append_event(conn, kind="memory_written", payload=_base_memory())
        project(conn)
        memory_row = conn.execute("SELECT id FROM memories").fetchone()
        memory_id = memory_row[0]

        vec_a = _vec(384, base=0.1)
        vec_b = _vec(384, base=0.5)

        # Index twice with different vectors; after each projection the
        # stored vector must be the one most recently indexed.
        for vector in (vec_a, vec_b):
            append_event(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": memory_id,
                    "vector": vector,
                    "model": "test-model",
                    "dim": 384,
                },
            )
            project(conn)
            current = get_embedding(conn, memory_id)
            assert current is not None
            assert current["vector"] == vector

        # The replacement must not have left a duplicate row behind.
        count_row = conn.execute(
            "SELECT COUNT(*) FROM embeddings WHERE memory_id = ?", (memory_id,)
        ).fetchone()
        assert count_row[0] == 1
|
||||
|
||||
|
||||
def test_list_embeddings_for_owner_returns_joined_rows(tmp_path):
    """``list_embeddings_for_owner`` filters by owner and carries joined memory fields."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        # Two memories for bot_a and one for bot_b.
        memory_specs = (
            ("bot_a", "Alpha memory.", 2),
            ("bot_a", "Beta memory.", 3),
            ("bot_b", "Gamma memory.", 1),
        )
        for owner, summary, significance in memory_specs:
            append_event(
                conn,
                kind="memory_written",
                payload=_base_memory(
                    owner_id=owner,
                    pov_summary=summary,
                    significance=significance,
                ),
            )
        project(conn)

        memory_rows = conn.execute(
            "SELECT id, owner_id FROM memories ORDER BY id"
        ).fetchall()
        # Give every memory its own distinct vector.
        for position, (mid, _owner) in enumerate(memory_rows, start=1):
            indexed_payload = {
                "memory_id": mid,
                "vector": _vec(384, base=0.1 * position),
                "model": "test-model",
                "dim": 384,
            }
            append_event(conn, kind="embedding_indexed", payload=indexed_payload)
        project(conn)

        bot_a_rows = list_embeddings_for_owner(conn, "bot_a")
        assert len(bot_a_rows) == 2
        assert {row["pov_summary"] for row in bot_a_rows} == {
            "Alpha memory.",
            "Beta memory.",
        }
        assert {row["significance"] for row in bot_a_rows} == {2, 3}
        for row in bot_a_rows:
            assert row["model"] == "test-model"
            assert row["dim"] == 384
            assert isinstance(row["vector"], list)
            assert len(row["vector"]) == 384
            assert row["witness_you"] == 1
            assert row["witness_host"] == 1
            assert row["witness_guest"] == 0

        bot_b_rows = list_embeddings_for_owner(conn, "bot_b")
        assert len(bot_b_rows) == 1
        only_row = bot_b_rows[0]
        assert only_row["pov_summary"] == "Gamma memory."
|
||||
|
||||
|
||||
def test_get_embedding_returns_none_when_missing(tmp_path):
    """``get_embedding`` returns ``None`` for a memory id with no embedding row."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    with open_db(db_path) as conn:
        missing = get_embedding(conn, 999)
        assert missing is None
|
||||
Reference in New Issue
Block a user