Files
chat/tests/test_backfill_embeddings.py
Joseph Doherty 9b7a6d459f feat: backfill_embeddings --re-embed-all flag for model swaps (T112.4)
Adds two new flags to the backfill script:

* --re-embed-all walks **every** memory (not just those without
  an existing embeddings row) and re-emits embedding_indexed
  events. The projector is INSERT OR REPLACE, so re-emitting an event
  for an existing memory replaces the prior vector. Use this when
  swapping embedding models — the default mode still keeps the Phase
  4 gap-fill behavior.
* --model M overrides Settings.embedding_model for this run.

The script also gains a small _build_client helper that returns
None for the pseudo path (no client needed) and a FeatherlessClient
otherwise; tests monkeypatch this to inject a Mock with canned
embeddings.

Adds tests/test_backfill_embeddings.py with three integration
tests: re-embed-all walks every memory, default mode skips existing
rows, and --model overrides the configured model end-to-end.
2026-04-27 06:02:23 -04:00

232 lines
7.6 KiB
Python

"""Tests for the backfill_embeddings script (T112, Phase 4.5).
Phase 4 shipped a backfill that walked memories *without* an embedding
row and produced a vector for each (deterministic pseudo path). T112
adds a ``--re-embed-all`` flag that walks **every** memory regardless
of whether it already has an embeddings row, so operators can swap
embedding models and have the existing rows replaced (the
``embedding_indexed`` projector is INSERT OR REPLACE).
These tests exercise the script's ``main()`` directly via asyncio —
shell-out via subprocess would also work but importing keeps the
fixture surface small and the failure mode clearer.
"""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_and_apply, append_event
from chat.eventlog.projector import project
from chat.services.embeddings import DEFAULT_EMBEDDING_MODEL
# Trigger handler registration for projection.
import chat.state.embeddings # noqa: F401
import chat.state.entities # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
import scripts.backfill_embeddings as backfill
def _seed(db_path: Path, count: int) -> list[int]:
    """Create ``bot_a``, its chat, and ``count`` memories; return the memory ids.

    Appends one ``bot_authored`` event, one ``chat_created`` event and
    ``count`` ``memory_written`` events to the database at ``db_path``,
    runs ``project`` on the connection, then reads back the ids of
    ``bot_a``'s rows in ``memories`` in ascending order.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }

    def _memory_payload(index: int) -> dict:
        # Only the summary text varies between the seeded memories.
        return {
            "owner_id": "bot_a",
            "chat_id": "chat_bot_a",
            "pov_summary": f"memory text {index}",
            "witness_you": 1,
            "witness_host": 1,
            "witness_guest": 0,
            "source": "direct",
            "reliability": 1.0,
            "significance": 1,
            "pinned": 0,
            "auto_pinned": 0,
        }

    with open_db(db_path) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)
        for index in range(count):
            append_event(
                conn,
                kind="memory_written",
                payload=_memory_payload(index),
            )
        # The events above are only appended; project them so the
        # ``memories`` rows exist before the ids are read back.
        project(conn)
        id_rows = conn.execute(
            "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
        ).fetchall()
        return [row[0] for row in id_rows]
def _seed_embedding(db_path: Path, memory_id: int, model: str = "stale-model") -> None:
    """Give ``memory_id`` a pre-existing, stale row in ``embeddings``.

    Appends and applies a single ``embedding_indexed`` event whose
    payload is a three-element zero vector tagged with ``model``. The
    tests use this to set up a memory that the default backfill is
    expected to skip.
    """
    stale_payload = {
        "memory_id": memory_id,
        "model": model,
        "dim": 3,
        "vector": [0.0, 0.0, 0.0],
    }
    with open_db(db_path) as conn:
        append_and_apply(conn, kind="embedding_indexed", payload=stale_payload)
@pytest.mark.asyncio
async def test_re_embed_all_walks_every_memory(tmp_path, monkeypatch, capsys):
    """``--re-embed-all`` re-embeds memories that already have a row.

    Two of three seeded memories start with a stale ``embeddings`` row
    (which the default mode would skip). After running the script with
    ``--re-embed-all``, there must be exactly one row per seeded memory
    and every row must carry ``DEFAULT_EMBEDDING_MODEL``, i.e. the stale
    rows were replaced in place.
    """
    # NOTE(review): ``capsys`` is requested but never read — presumably
    # only to keep the script's stdout out of the report; confirm or drop.
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=3)

    # Pre-seed stale embeddings on two of the three memories so the
    # default path would skip them and only ``--re-embed-all`` covers
    # everything.
    _seed_embedding(db, memory_ids[0])
    _seed_embedding(db, memory_ids[1])

    cfg = tmp_path / "config.toml"
    # TOML basic strings treat backslashes as escape sequences, so a raw
    # Windows path would make the config unparseable. ``as_posix()`` is
    # identical to ``str()`` on POSIX and yields forward slashes elsewhere.
    cfg.write_text(
        'featherless_api_key = "x"\n'
        f'db_path = "{db.as_posix()}"\n'
        f'data_dir = "{tmp_path.as_posix()}"\n',
        encoding="utf-8",
    )
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
    monkeypatch.setenv("CHAT_DB_PATH", str(db))

    with patch("sys.argv", ["backfill_embeddings.py", "--re-embed-all"]):
        await backfill.main()

    with open_db(db) as conn:
        rows = conn.execute(
            "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
        ).fetchall()

    # Exactly one row per seeded memory (both lists are ordered by id),
    # and every row — including the two that started stale — now carries
    # the default model tag.
    assert [mid for mid, _ in rows] == memory_ids
    assert {model for _, model in rows} == {DEFAULT_EMBEDDING_MODEL}
@pytest.mark.asyncio
async def test_default_backfill_only_walks_missing(tmp_path, monkeypatch):
    """Without ``--re-embed-all`` existing embedding rows are left alone.

    One of two seeded memories starts with a ``stale-model`` row. After
    a default run that row must still be tagged ``stale-model`` and the
    other memory must have gained a row tagged
    ``DEFAULT_EMBEDDING_MODEL``.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=2)
    # Only memory_ids[0] gets a pre-existing row; memory_ids[1] has no
    # embedding yet.
    _seed_embedding(db, memory_ids[0], model="stale-model")

    cfg = tmp_path / "config.toml"
    # TOML basic strings treat backslashes as escape sequences, so a raw
    # Windows path would make the config unparseable. ``as_posix()`` is
    # identical to ``str()`` on POSIX and yields forward slashes elsewhere.
    cfg.write_text(
        'featherless_api_key = "x"\n'
        f'db_path = "{db.as_posix()}"\n'
        f'data_dir = "{tmp_path.as_posix()}"\n',
        encoding="utf-8",
    )
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
    monkeypatch.setenv("CHAT_DB_PATH", str(db))

    with patch("sys.argv", ["backfill_embeddings.py"]):
        await backfill.main()

    with open_db(db) as conn:
        rows = dict(
            conn.execute(
                "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
            ).fetchall()
        )

    # Check coverage first so a missing row fails as a readable
    # assertion instead of a KeyError on the lookups below.
    assert set(rows) == set(memory_ids)
    # Stale row preserved; only the missing one was filled.
    assert rows[memory_ids[0]] == "stale-model"
    assert rows[memory_ids[1]] == DEFAULT_EMBEDDING_MODEL
@pytest.mark.asyncio
async def test_re_embed_all_respects_model_arg(tmp_path, monkeypatch):
    """``--model`` overrides the configured embedding model.

    The script's ``_build_client`` factory is replaced with one that
    returns a ``MockLLMClient`` holding one canned 384-element vector
    per seeded memory. After running with ``--re-embed-all --model
    bge-small-en-v1.5``, every seeded memory must have exactly one row
    tagged with that model, including the one that started stale.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=2)
    _seed_embedding(db, memory_ids[0])

    cfg = tmp_path / "config.toml"
    # TOML basic strings treat backslashes as escape sequences, so a raw
    # Windows path would make the config unparseable. ``as_posix()`` is
    # identical to ``str()`` on POSIX and yields forward slashes elsewhere.
    cfg.write_text(
        'featherless_api_key = "x"\n'
        f'db_path = "{db.as_posix()}"\n'
        f'data_dir = "{tmp_path.as_posix()}"\n',
        encoding="utf-8",
    )
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
    monkeypatch.setenv("CHAT_DB_PATH", str(db))

    # Patch the client factory the script uses to produce a Mock with
    # canned embeddings — one per memory.
    from chat.llm.mock import MockLLMClient

    canned_vec = [0.1] * 384

    def _factory(_settings):
        return MockLLMClient(
            canned=[],
            # Fresh list per memory so no two entries share one object.
            canned_embeddings=[list(canned_vec) for _ in memory_ids],
        )

    monkeypatch.setattr(backfill, "_build_client", _factory)

    argv = [
        "backfill_embeddings.py",
        "--re-embed-all",
        "--model",
        "bge-small-en-v1.5",
    ]
    with patch("sys.argv", argv):
        await backfill.main()

    with open_db(db) as conn:
        rows = conn.execute(
            "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
        ).fetchall()

    # One row per seeded memory (both lists are ordered by id), each
    # tagged with the model passed on the command line.
    assert [mid for mid, _ in rows] == memory_ids
    assert {model for _, model in rows} == {"bge-small-en-v1.5"}