9b7a6d459f
Adds two new flags to the backfill script:

* --re-embed-all walks **every** memory (not just those without an existing embeddings row) and re-emits embedding_indexed events. The projector is INSERT OR REPLACE, so re-emitting an event for an existing memory replaces the prior vector. Use this when swapping embedding models — the default mode still keeps the Phase 4 gap-fill behavior.
* --model M overrides Settings.embedding_model for this run.

The script also gains a small _build_client helper that returns None for the pseudo path (no client needed) and a FeatherlessClient otherwise; tests monkeypatch this to inject a Mock with canned embeddings.

Adds tests/test_backfill_embeddings.py with three integration tests: re-embed-all walks every memory, default mode skips existing rows, and --model overrides the configured model end-to-end.
159 lines
5.5 KiB
Python
159 lines
5.5 KiB
Python
"""Backfill embeddings for memories that lack them (T97, Phase 4).
|
|
|
|
Walks all memories where no row exists in the ``embeddings`` table. For
|
|
each, calls :func:`chat.services.embeddings.generate_embedding` and emits
|
|
an ``embedding_indexed`` event so the projector lands the vector.
|
|
|
|
Phase 4 ships the deterministic local pseudo-embedding so this script
|
|
runs synchronously without a network round-trip — the LLMClient argument
|
|
is not needed on the pseudo path. Phase 4.5+ will need a real client.
|
|
|
|
T112 (Phase 4.5) adds two flags:
|
|
|
|
* ``--re-embed-all`` walks **every** memory regardless of whether it
|
|
already has an ``embeddings`` row. Useful when swapping embedding
|
|
models — the projector is INSERT OR REPLACE, so re-emitting an event
|
|
for an existing memory replaces the prior vector. Without this flag,
|
|
the script keeps the Phase 4 behavior of only filling in gaps.
|
|
* ``--model M`` overrides ``Settings.embedding_model`` for this run.
|
|
Defaults to the configured model (which itself defaults to
|
|
``"pseudo-sha256-384"``).
|
|
|
|
Run from the repo root:
|
|
.venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run]
|
|
.venv/bin/python scripts/backfill_embeddings.py --re-embed-all
|
|
.venv/bin/python scripts/backfill_embeddings.py --re-embed-all --model bge-small-en-v1.5
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
|
|
from chat.config import Settings, load_settings
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import append_and_apply
|
|
from chat.services.embeddings import (
|
|
DEFAULT_EMBEDDING_MODEL,
|
|
FALLBACK_EMBEDDING_MODEL,
|
|
generate_embedding,
|
|
)
|
|
|
|
# Trigger projector handler registration so ``append_and_apply`` lands
|
|
# the embedding rows correctly.
|
|
import chat.state.embeddings # noqa: F401
|
|
import chat.state.entities # noqa: F401
|
|
import chat.state.memory # noqa: F401
|
|
import chat.state.world # noqa: F401
|
|
|
|
|
|
def _build_client(settings: Settings):
    """Return the LLM client for this backfill run, or ``None`` if none is needed.

    The returned value is passed straight through as the ``client`` argument
    of ``generate_embedding``. When ``settings.embedding_model`` is
    ``DEFAULT_EMBEDDING_MODEL`` (the pseudo-embedding path) no client is
    constructed and ``None`` comes back. Any other model gets a
    ``FeatherlessClient`` wired from the Featherless API key and base URL in
    ``settings``. Tests replace this function via monkeypatch to inject a
    stub client.
    """
    wants_real_client = settings.embedding_model != DEFAULT_EMBEDDING_MODEL
    if not wants_real_client:
        return None

    # Deferred import: the pseudo path returns above and never loads the
    # real client module.
    from chat.llm.featherless import FeatherlessClient

    api_key = settings.featherless_api_key
    base_url = settings.featherless_base_url
    return FeatherlessClient(api_key=api_key, base_url=base_url)
|
|
|
|
|
|
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the backfill command-line flags.

    Args:
        argv: Arguments to parse. ``None`` (what :func:`main` passes) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``limit``, ``dry_run``, ``re_embed_all`` and ``model``.

    Exits through ``parser.error`` (status 2) when ``--limit`` is negative.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Cap the number of memories backfilled in this run.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the count of memories needing embeddings, then exit.",
    )
    parser.add_argument(
        "--re-embed-all",
        action="store_true",
        help=(
            "Walk every memory (not just those without an embeddings row) "
            "and re-emit embedding_indexed events. Use this when swapping "
            "embedding models so the existing rows get replaced."
        ),
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help=(
            "Embedding model identifier. Overrides Settings.embedding_model "
            "for this run; default uses the configured model."
        ),
    )
    args = parser.parse_args(argv)
    if args.limit is not None and args.limit < 0:
        # SQLite reads a negative LIMIT as "no upper bound", so without this
        # check a typo like ``--limit -1`` would silently walk every memory.
        parser.error("--limit must be a non-negative integer")
    return args


def _select_memories(conn, *, re_embed_all: bool, limit: int | None) -> list:
    """Fetch the ``(memory_id, pov_summary)`` rows to embed, ordered by id.

    Args:
        conn: Open DB connection from ``open_db`` (presumably a
            ``sqlite3.Connection`` — confirm against ``chat.db.connection``).
        re_embed_all: When true, select every memory. Otherwise select only
            memories with no matching row in ``embeddings``.
        limit: Optional maximum number of rows. ``None`` means no cap.

    Returns:
        The list produced by ``fetchall()`` on the query.
    """
    if re_embed_all:
        sql = "SELECT m.id, m.pov_summary FROM memories m ORDER BY m.id"
    else:
        sql = (
            "SELECT m.id, m.pov_summary FROM memories m "
            "LEFT JOIN embeddings e ON e.memory_id = m.id "
            "WHERE e.memory_id IS NULL "
            "ORDER BY m.id"
        )
    params: tuple[int, ...] = ()
    if limit is not None:
        # Bind the cap as a parameter rather than formatting it into the SQL.
        sql += " LIMIT ?"
        params = (int(limit),)
    return conn.execute(sql, params).fetchall()


async def main() -> None:
    """Backfill (or, with ``--re-embed-all``, regenerate) memory embeddings.

    Steps:
    - Ensure the DB directory exists and apply migrations.
    - Select the target memories.
    - For each memory, call :func:`generate_embedding` and emit an
      ``embedding_indexed`` event through :func:`append_and_apply`.

    Results whose model is ``FALLBACK_EMBEDDING_MODEL`` are counted as
    skipped and no event is emitted. With ``--dry-run`` only the count line
    is printed, and no client is built.
    """
    args = _parse_args()

    settings = load_settings()
    settings.db_path.parent.mkdir(parents=True, exist_ok=True)
    apply_migrations(settings.db_path)

    model = args.model or settings.embedding_model
    # Override the settings instance so ``_build_client`` sees the
    # effective model when deciding whether to construct a real client.
    settings = settings.model_copy(update={"embedding_model": model})

    with open_db(settings.db_path) as conn:
        rows = _select_memories(
            conn,
            re_embed_all=args.re_embed_all,
            limit=args.limit,
        )
        mode = "re-embedding" if args.re_embed_all else "needing embeddings"
        print(f"Found {len(rows)} memories {mode} (model={model}).")
        if args.dry_run:
            return

        # Built only after the dry-run exit: a dry run never calls
        # generate_embedding, so it should not construct a (real) client.
        client = _build_client(settings)

        indexed = 0
        skipped = 0
        for memory_id, text in rows:
            result = await generate_embedding(
                client=client,
                text=text or "",
                model=model,
            )
            if result.model == FALLBACK_EMBEDDING_MODEL:
                # NOTE(review): under --re-embed-all a skipped memory keeps
                # whatever embeddings row it already had (possibly produced
                # by the previous model); nothing here replaces or removes it.
                print(f" Skipping memory_id={memory_id} (empty text or fallback)")
                skipped += 1
                continue
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )
            indexed += 1
            print(f" Indexed memory_id={memory_id}")
        print(f"Done. Indexed {indexed}, skipped {skipped}.")


if __name__ == "__main__":
    asyncio.run(main())
|