"""Tests for the backfill_embeddings script (T112, Phase 4.5). Phase 4 shipped a backfill that walked memories *without* an embedding row and produced a vector for each (deterministic pseudo path). T112 adds a ``--re-embed-all`` flag that walks **every** memory regardless of whether it already has an embeddings row, so operators can swap embedding models and have the existing rows replaced (the ``embedding_indexed`` projector is INSERT OR REPLACE). These tests exercise the script's ``main()`` directly via asyncio — shell-out via subprocess would also work but importing keeps the fixture surface small and the failure mode clearer. """ from __future__ import annotations from pathlib import Path from unittest.mock import patch import pytest from chat.db.connection import open_db from chat.db.migrate import apply_migrations from chat.eventlog.log import append_and_apply, append_event from chat.eventlog.projector import project from chat.services.embeddings import DEFAULT_EMBEDDING_MODEL # Trigger handler registration for projection. 
import chat.state.embeddings  # noqa: F401
import chat.state.entities  # noqa: F401
import chat.state.memory  # noqa: F401
import chat.state.world  # noqa: F401

import scripts.backfill_embeddings as backfill


def _seed(db_path: Path, count: int) -> list[int]:
    """Seed ``count`` memory rows for ``bot_a``; return their ids.

    Appends one ``bot_authored`` event, one ``chat_created`` event and
    ``count`` ``memory_written`` events, runs the projector, then reads the
    resulting ``memories.id`` values back in ascending order.
    """
    with open_db(db_path) as conn:
        append_event(
            conn,
            kind="bot_authored",
            payload={
                "id": "bot_a",
                "name": "BotA",
                "persona": "...",
                "voice_samples": [],
                "traits": [],
                "backstory": "",
                "initial_relationship_to_you": "",
                "kickoff_prose": "",
            },
        )
        append_event(
            conn,
            kind="chat_created",
            payload={
                "id": "chat_bot_a",
                "host_bot_id": "bot_a",
                "initial_time": "2026-04-26T20:00:00+00:00",
                "narrative_anchor": "Day 1",
                "weather": "",
            },
        )
        for i in range(count):
            append_event(
                conn,
                kind="memory_written",
                payload={
                    "owner_id": "bot_a",
                    "chat_id": "chat_bot_a",
                    "pov_summary": f"memory text {i}",
                    "witness_you": 1,
                    "witness_host": 1,
                    "witness_guest": 0,
                    "source": "direct",
                    "reliability": 1.0,
                    "significance": 1,
                    "pinned": 0,
                    "auto_pinned": 0,
                },
            )
        # The events above are only appended; projecting is what fills the
        # ``memories`` table queried below.
        project(conn)
        return [
            r[0]
            for r in conn.execute(
                "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
            ).fetchall()
        ]


def _seed_embedding(db_path: Path, memory_id: int, model: str = "stale-model") -> None:
    """Insert a stale ``embedding_indexed`` event for ``memory_id``.

    The event is appended and applied immediately, so the row already exists
    in ``embeddings`` (and the default backfill would skip it). The vector is
    a 3-dimensional zero placeholder tagged with ``model``.
    """
    with open_db(db_path) as conn:
        append_and_apply(
            conn,
            kind="embedding_indexed",
            payload={
                "memory_id": memory_id,
                "model": model,
                "dim": 3,
                "vector": [0.0, 0.0, 0.0],
            },
        )


def _write_config(tmp_path: Path, db: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Write a minimal ``config.toml`` and point the environment at it.

    Sets ``CHAT_CONFIG_PATH`` and ``CHAT_DB_PATH`` via ``monkeypatch`` so the
    changes are undone when the test finishes. Returns the config path.
    """
    cfg = tmp_path / "config.toml"
    # Paths go into double-quoted TOML strings, where a backslash starts an
    # escape sequence; POSIX-style separators keep the file valid on Windows
    # and are identical to ``str(path)`` everywhere else.
    cfg.write_text(
        'featherless_api_key = "x"\n'
        f'db_path = "{db.as_posix()}"\n'
        f'data_dir = "{tmp_path.as_posix()}"\n'
    )
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
    monkeypatch.setenv("CHAT_DB_PATH", str(db))
    return cfg


async def _run_backfill(*cli_args: str) -> None:
    """Run ``backfill.main()`` with ``sys.argv`` patched to ``cli_args``."""
    with patch("sys.argv", ["backfill_embeddings.py", *cli_args]):
        await backfill.main()


def _embedding_models(db: Path) -> list[tuple[int, str]]:
    """Return ``(memory_id, model)`` for every ``embeddings`` row, by id."""
    with open_db(db) as conn:
        return [
            (row[0], row[1])
            for row in conn.execute(
                "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
            ).fetchall()
        ]


@pytest.mark.asyncio
async def test_re_embed_all_walks_every_memory(tmp_path, monkeypatch):
    """``--re-embed-all`` re-embeds memories that already have rows in
    ``embeddings`` (default mode skips them).

    After the run, every memory should have an updated embedding tagged
    with the configured model (the projector replaces stale rows in place).
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=3)
    # Pre-seed stale embeddings on two of the three memories so the
    # default path would skip them and only ``--re-embed-all`` covers
    # everything.
    _seed_embedding(db, memory_ids[0])
    _seed_embedding(db, memory_ids[1])
    _write_config(tmp_path, db, monkeypatch)

    await _run_backfill("--re-embed-all")

    # All three memories now have a fresh embedding tagged with the
    # default pseudo model (replacing the stale rows).
    rows = _embedding_models(db)
    assert [mid for mid, _ in rows] == memory_ids
    assert all(model == DEFAULT_EMBEDDING_MODEL for _, model in rows)


@pytest.mark.asyncio
async def test_default_backfill_only_walks_missing(tmp_path, monkeypatch):
    """Without ``--re-embed-all``, the script keeps the Phase 4 behavior —
    memories with an existing embedding row are left alone (their
    stale-model tag survives)."""
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=2)
    _seed_embedding(db, memory_ids[0], model="stale-model")
    # memory_ids[1] has no embedding yet.
    _write_config(tmp_path, db, monkeypatch)

    await _run_backfill()

    models_by_id = dict(_embedding_models(db))
    # Stale row preserved; only the missing one was filled.
    assert models_by_id[memory_ids[0]] == "stale-model"
    assert models_by_id[memory_ids[1]] == DEFAULT_EMBEDDING_MODEL


@pytest.mark.asyncio
async def test_re_embed_all_respects_model_arg(tmp_path, monkeypatch):
    """The ``--model`` flag overrides ``Settings.embedding_model``.

    With a non-default model and a client that returns canned vectors,
    every memory is re-embedded with the supplied model tag.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=2)
    _seed_embedding(db, memory_ids[0])
    _write_config(tmp_path, db, monkeypatch)

    # Patch the client factory the script uses to produce a Mock with
    # canned embeddings — one per memory.
    from chat.llm.mock import MockLLMClient

    canned_vec = [0.1] * 384

    def _factory(_settings):
        return MockLLMClient(
            canned=[],
            canned_embeddings=[list(canned_vec) for _ in memory_ids],
        )

    monkeypatch.setattr(backfill, "_build_client", _factory)

    await _run_backfill("--re-embed-all", "--model", "bge-small-en-v1.5")

    rows = _embedding_models(db)
    assert [mid for mid, _ in rows] == memory_ids
    assert all(model == "bge-small-en-v1.5" for _, model in rows)