feat: backfill_embeddings --re-embed-all flag for model swaps (T112.4)
Adds two new flags to the backfill script: * --re-embed-all walks **every** memory (not just those without an existing embeddings row) and re-emits embedding_indexed events. The projector is INSERT OR REPLACE, so re-emitting an event for an existing memory replaces the prior vector. Use this when swapping embedding models — the default mode still keeps the Phase 4 gap-fill behavior. * --model M overrides Settings.embedding_model for this run. The script also gains a small _build_client helper that returns None for the pseudo path (no client needed) and a FeatherlessClient otherwise; tests monkeypatch this to inject a Mock with canned embeddings. Adds tests/test_backfill_embeddings.py with three integration tests: re-embed-all walks every memory, default mode skips existing rows, and --model overrides the configured model end-to-end.
This commit is contained in:
@@ -8,8 +8,21 @@ Phase 4 ships the deterministic local pseudo-embedding so this script
|
||||
runs synchronously without a network round-trip — the LLMClient argument
|
||||
is not needed on the pseudo path. Phase 4.5+ will need a real client.
|
||||
|
||||
T112 (Phase 4.5) adds two flags:
|
||||
|
||||
* ``--re-embed-all`` walks **every** memory regardless of whether it
|
||||
already has an ``embeddings`` row. Useful when swapping embedding
|
||||
models — the projector is INSERT OR REPLACE, so re-emitting an event
|
||||
for an existing memory replaces the prior vector. Without this flag,
|
||||
the script keeps the Phase 4 behavior of only filling in gaps.
|
||||
* ``--model M`` overrides ``Settings.embedding_model`` for this run.
|
||||
Defaults to the configured model (which itself defaults to
|
||||
``"pseudo-sha256-384"``).
|
||||
|
||||
Run from the repo root:
|
||||
.venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run]
|
||||
.venv/bin/python scripts/backfill_embeddings.py --re-embed-all
|
||||
.venv/bin/python scripts/backfill_embeddings.py --re-embed-all --model bge-small-en-v1.5
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -17,11 +30,12 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from chat.config import load_settings
|
||||
from chat.config import Settings, load_settings
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_and_apply
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
generate_embedding,
|
||||
)
|
||||
@@ -34,6 +48,24 @@ import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
|
||||
def _build_client(settings: Settings):
|
||||
"""Construct an LLMClient for the backfill run.
|
||||
|
||||
Default-model runs (the pseudo path) don't need a client, so we
|
||||
return ``None`` and ``generate_embedding`` skips the call. Non-default
|
||||
models route through the real client; injectable via monkeypatch in
|
||||
tests.
|
||||
"""
|
||||
if settings.embedding_model == DEFAULT_EMBEDDING_MODEL:
|
||||
return None
|
||||
from chat.llm.featherless import FeatherlessClient
|
||||
|
||||
return FeatherlessClient(
|
||||
api_key=settings.featherless_api_key,
|
||||
base_url=settings.featherless_base_url,
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
@@ -47,23 +79,51 @@ async def main() -> None:
|
||||
action="store_true",
|
||||
help="Print the count of memories needing embeddings, then exit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--re-embed-all",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Walk every memory (not just those without an embeddings row) "
|
||||
"and re-emit embedding_indexed events. Use this when swapping "
|
||||
"embedding models so the existing rows get replaced."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Embedding model identifier. Overrides Settings.embedding_model "
|
||||
"for this run; default uses the configured model."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
settings = load_settings()
|
||||
settings.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
apply_migrations(settings.db_path)
|
||||
|
||||
model = args.model or settings.embedding_model
|
||||
# Override the settings instance so ``_build_client`` sees the
|
||||
# effective model when deciding whether to construct a real client.
|
||||
settings = settings.model_copy(update={"embedding_model": model})
|
||||
client = _build_client(settings)
|
||||
|
||||
with open_db(settings.db_path) as conn:
|
||||
sql = (
|
||||
"SELECT m.id, m.pov_summary FROM memories m "
|
||||
"LEFT JOIN embeddings e ON e.memory_id = m.id "
|
||||
"WHERE e.memory_id IS NULL "
|
||||
"ORDER BY m.id"
|
||||
)
|
||||
if args.re_embed_all:
|
||||
sql = "SELECT m.id, m.pov_summary FROM memories m ORDER BY m.id"
|
||||
else:
|
||||
sql = (
|
||||
"SELECT m.id, m.pov_summary FROM memories m "
|
||||
"LEFT JOIN embeddings e ON e.memory_id = m.id "
|
||||
"WHERE e.memory_id IS NULL "
|
||||
"ORDER BY m.id"
|
||||
)
|
||||
if args.limit is not None:
|
||||
sql += f" LIMIT {int(args.limit)}"
|
||||
rows = conn.execute(sql).fetchall()
|
||||
print(f"Found {len(rows)} memories needing embeddings.")
|
||||
mode = "re-embedding" if args.re_embed_all else "needing embeddings"
|
||||
print(f"Found {len(rows)} memories {mode} (model={model}).")
|
||||
if args.dry_run:
|
||||
return
|
||||
|
||||
@@ -71,11 +131,12 @@ async def main() -> None:
|
||||
skipped = 0
|
||||
for memory_id, text in rows:
|
||||
result = await generate_embedding(
|
||||
client=None, # pseudo path: no client needed
|
||||
client=client,
|
||||
text=text or "",
|
||||
model=model,
|
||||
)
|
||||
if result.model == FALLBACK_EMBEDDING_MODEL:
|
||||
print(f" Skipping memory_id={memory_id} (empty text)")
|
||||
print(f" Skipping memory_id={memory_id} (empty text or fallback)")
|
||||
skipped += 1
|
||||
continue
|
||||
append_and_apply(
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
"""Tests for the backfill_embeddings script (T112, Phase 4.5).
|
||||
|
||||
Phase 4 shipped a backfill that walked memories *without* an embedding
|
||||
row and produced a vector for each (deterministic pseudo path). T112
|
||||
adds a ``--re-embed-all`` flag that walks **every** memory regardless
|
||||
of whether it already has an embeddings row, so operators can swap
|
||||
embedding models and have the existing rows replaced (the
|
||||
``embedding_indexed`` projector is INSERT OR REPLACE).
|
||||
|
||||
These tests exercise the script's ``main()`` directly via asyncio —
|
||||
shell-out via subprocess would also work but importing keeps the
|
||||
fixture surface small and the failure mode clearer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_and_apply, append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.embeddings import DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
# Trigger handler registration for projection.
|
||||
import chat.state.embeddings # noqa: F401
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
import scripts.backfill_embeddings as backfill
|
||||
|
||||
|
||||
def _seed(db_path: Path, count: int) -> list[int]:
|
||||
"""Seed ``count`` memory rows for ``bot_a``; return their ids."""
|
||||
with open_db(db_path) as conn:
|
||||
append_event(
|
||||
conn,
|
||||
kind="bot_authored",
|
||||
payload={
|
||||
"id": "bot_a",
|
||||
"name": "BotA",
|
||||
"persona": "...",
|
||||
"voice_samples": [],
|
||||
"traits": [],
|
||||
"backstory": "",
|
||||
"initial_relationship_to_you": "",
|
||||
"kickoff_prose": "",
|
||||
},
|
||||
)
|
||||
append_event(
|
||||
conn,
|
||||
kind="chat_created",
|
||||
payload={
|
||||
"id": "chat_bot_a",
|
||||
"host_bot_id": "bot_a",
|
||||
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||
"narrative_anchor": "Day 1",
|
||||
"weather": "",
|
||||
},
|
||||
)
|
||||
for i in range(count):
|
||||
append_event(
|
||||
conn,
|
||||
kind="memory_written",
|
||||
payload={
|
||||
"owner_id": "bot_a",
|
||||
"chat_id": "chat_bot_a",
|
||||
"pov_summary": f"memory text {i}",
|
||||
"witness_you": 1,
|
||||
"witness_host": 1,
|
||||
"witness_guest": 0,
|
||||
"source": "direct",
|
||||
"reliability": 1.0,
|
||||
"significance": 1,
|
||||
"pinned": 0,
|
||||
"auto_pinned": 0,
|
||||
},
|
||||
)
|
||||
project(conn)
|
||||
return [
|
||||
r[0]
|
||||
for r in conn.execute(
|
||||
"SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
|
||||
def _seed_embedding(db_path: Path, memory_id: int, model: str = "stale-model") -> None:
|
||||
"""Insert a stale ``embedding_indexed`` event so the row already
|
||||
exists in ``embeddings`` (and the default backfill would skip it)."""
|
||||
with open_db(db_path) as conn:
|
||||
append_and_apply(
|
||||
conn,
|
||||
kind="embedding_indexed",
|
||||
payload={
|
||||
"memory_id": memory_id,
|
||||
"model": model,
|
||||
"dim": 3,
|
||||
"vector": [0.0, 0.0, 0.0],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_embed_all_walks_every_memory(tmp_path, monkeypatch, capsys):
|
||||
"""``--re-embed-all`` re-embeds memories that already have rows in
|
||||
``embeddings`` (default mode skips them). After the run, every
|
||||
memory should have an updated embedding tagged with the configured
|
||||
model (the projector replaces stale rows in place)."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed(db, count=3)
|
||||
# Pre-seed stale embeddings on two of the three memories so the
|
||||
# default path would skip them and only ``--re-embed-all`` covers
|
||||
# everything.
|
||||
_seed_embedding(db, memory_ids[0])
|
||||
_seed_embedding(db, memory_ids[1])
|
||||
|
||||
cfg = tmp_path / "config.toml"
|
||||
cfg.write_text(
|
||||
f'featherless_api_key = "x"\n'
|
||||
f'db_path = "{db}"\n'
|
||||
f'data_dir = "{tmp_path}"\n'
|
||||
)
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||
monkeypatch.setenv("CHAT_DB_PATH", str(db))
|
||||
|
||||
with patch("sys.argv", ["backfill_embeddings.py", "--re-embed-all"]):
|
||||
await backfill.main()
|
||||
|
||||
# All three memories now have a fresh embedding tagged with the
|
||||
# default pseudo model (replacing the stale rows).
|
||||
with open_db(db) as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT memory_id, model FROM embeddings ORDER BY memory_id"
|
||||
).fetchall()
|
||||
assert len(rows) == 3
|
||||
for mid, model in rows:
|
||||
assert mid in memory_ids
|
||||
assert model == DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_backfill_only_walks_missing(tmp_path, monkeypatch):
|
||||
"""Without ``--re-embed-all``, the script keeps the Phase 4
|
||||
behavior — memories with an existing embedding row are left
|
||||
alone (their stale-model tag survives)."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed(db, count=2)
|
||||
_seed_embedding(db, memory_ids[0], model="stale-model")
|
||||
# memory_ids[1] has no embedding yet.
|
||||
|
||||
cfg = tmp_path / "config.toml"
|
||||
cfg.write_text(
|
||||
f'featherless_api_key = "x"\n'
|
||||
f'db_path = "{db}"\n'
|
||||
f'data_dir = "{tmp_path}"\n'
|
||||
)
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||
monkeypatch.setenv("CHAT_DB_PATH", str(db))
|
||||
|
||||
with patch("sys.argv", ["backfill_embeddings.py"]):
|
||||
await backfill.main()
|
||||
|
||||
with open_db(db) as conn:
|
||||
rows = dict(
|
||||
conn.execute(
|
||||
"SELECT memory_id, model FROM embeddings ORDER BY memory_id"
|
||||
).fetchall()
|
||||
)
|
||||
# Stale row preserved; only the missing one was filled.
|
||||
assert rows[memory_ids[0]] == "stale-model"
|
||||
assert rows[memory_ids[1]] == DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_embed_all_respects_model_arg(tmp_path, monkeypatch):
|
||||
"""The ``--model`` flag overrides ``Settings.embedding_model``.
|
||||
With a non-default model and a client that returns canned vectors,
|
||||
every memory is re-embedded with the supplied model tag."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed(db, count=2)
|
||||
_seed_embedding(db, memory_ids[0])
|
||||
|
||||
cfg = tmp_path / "config.toml"
|
||||
cfg.write_text(
|
||||
f'featherless_api_key = "x"\n'
|
||||
f'db_path = "{db}"\n'
|
||||
f'data_dir = "{tmp_path}"\n'
|
||||
)
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||
monkeypatch.setenv("CHAT_DB_PATH", str(db))
|
||||
|
||||
# Patch the client factory the script uses to produce a Mock with
|
||||
# canned embeddings — one per memory.
|
||||
from chat.llm.mock import MockLLMClient
|
||||
|
||||
canned_vec = [0.1] * 384
|
||||
|
||||
def _factory(_settings):
|
||||
return MockLLMClient(
|
||||
canned=[],
|
||||
canned_embeddings=[list(canned_vec) for _ in memory_ids],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(backfill, "_build_client", _factory)
|
||||
|
||||
with patch(
|
||||
"sys.argv",
|
||||
[
|
||||
"backfill_embeddings.py",
|
||||
"--re-embed-all",
|
||||
"--model",
|
||||
"bge-small-en-v1.5",
|
||||
],
|
||||
):
|
||||
await backfill.main()
|
||||
|
||||
with open_db(db) as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT memory_id, model FROM embeddings ORDER BY memory_id"
|
||||
).fetchall()
|
||||
assert len(rows) == 2
|
||||
for _, model in rows:
|
||||
assert model == "bge-small-en-v1.5"
|
||||
Reference in New Issue
Block a user