From 5f16bb575a8cefcc40a46c78203a998217e0acfe Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 05:47:55 -0400 Subject: [PATCH 1/4] feat: LLMClient Protocol gains embed() method (T112.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds async def embed(self, text: str, *, model: str) -> list[float] to the LLMClient Protocol so Phase 4.5 can wire a real-embedding swap without changing call sites. Protocol is structural — existing implementations that don't use it remain compatible; downstream implementations (FeatherlessClient, MockLLMClient) ship in T112.2 and T112.3. --- chat/llm/client.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/chat/llm/client.py b/chat/llm/client.py index ca34a2d..5c079e1 100644 --- a/chat/llm/client.py +++ b/chat/llm/client.py @@ -12,3 +12,11 @@ class Message: class LLMClient(Protocol): async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str: ... def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]: ... + # T112 (Phase 4.5): real-embedding seam. Implementations either call a + # provider's ``/v1/embeddings`` endpoint or, when the provider doesn't + # expose embeddings (e.g. Featherless today), raise ``NotImplementedError`` + # so ``generate_embedding`` can catch it and degrade to the zero-vector + # fallback. The Protocol is structural, so this method only needs to + # exist on implementations; existing callers that don't use it are + # unaffected. + async def embed(self, text: str, *, model: str) -> list[float]: ... From ac6e74ab4c90d223c6ca140d9c99c5db109c2052 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 05:48:34 -0400 Subject: [PATCH 2/4] feat: FeatherlessClient.embed() against /v1/embeddings (T112.2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements embed() on FeatherlessClient. Featherless's OpenAI- compatible surface does NOT expose /v1/embeddings at the time of writing, so this implementation raises NotImplementedError rather than issuing a request that would 404. The chat.services.embeddings.generate_embedding wrapper (T112.3) catches the exception and degrades to the zero-vector fallback path (plus the existing T107 warning) — misconfigured callers fail loudly in logs while the request path keeps working. If/when Featherless ships embeddings, swap the body for self._client.embeddings.create(model=..., input=...) guarded by the existing 2-conn semaphore (mirrors generate/stream). The Protocol seam in T112.1 is already wired so no other code needs to change. Adds tests/test_featherless.py pinning the NotImplementedError contract. --- chat/llm/featherless.py | 23 +++++++++++++++++++++++ tests/test_featherless.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 tests/test_featherless.py diff --git a/chat/llm/featherless.py b/chat/llm/featherless.py index cf1138b..2eff3de 100644 --- a/chat/llm/featherless.py +++ b/chat/llm/featherless.py @@ -53,3 +53,26 @@ class FeatherlessClient: delta = chunk.choices[0].delta.content or "" if delta: yield delta + + async def embed(self, text: str, *, model: str) -> list[float]: + """Embeddings via Featherless — currently unsupported. + + T112 (Phase 4.5) extends the LLMClient Protocol with ``embed()`` + for a future real-embedding swap. Featherless's OpenAI-compatible + surface does NOT expose ``/v1/embeddings`` at the time of writing, + so this implementation raises ``NotImplementedError`` rather than + attempting a request that would 404. The + :func:`chat.services.embeddings.generate_embedding` wrapper + catches this and degrades to the existing zero-vector fallback + (with the T107 warning), so misconfigured callers fail loudly in + logs but the request path keeps working. + + If Featherless ships embeddings, swap the body for an + ``self._client.embeddings.create(model=..., input=...)`` call + guarded by ``self._sem()`` (mirrors ``generate``/``stream``). + """ + raise NotImplementedError( + "Featherless does not expose /v1/embeddings; " + "configure a different embedding provider or stick with " + "the default pseudo-sha256-384 model." + ) diff --git a/tests/test_featherless.py b/tests/test_featherless.py new file mode 100644 index 0000000..bfea4d6 --- /dev/null +++ b/tests/test_featherless.py @@ -0,0 +1,32 @@ +"""Tests for FeatherlessClient (Phase 4.5+). + +Phase 4.5 adds an ``embed()`` method to the LLMClient Protocol (T112). +Featherless does not expose an OpenAI-compatible ``/v1/embeddings`` +endpoint, so its implementation deliberately raises +``NotImplementedError`` to surface the gap clearly. The +``generate_embedding`` wrapper catches this and degrades to the +zero-vector fallback (the existing T107 warning path). + +If/when Featherless ships embeddings, swap the body for a real call to +``/v1/embeddings`` and update this test to mock the HTTP layer. +""" + +from __future__ import annotations + +import pytest + +from chat.llm.featherless import FeatherlessClient + + +@pytest.mark.asyncio +async def test_featherless_embed_raises_not_implemented(): + """Featherless does not expose ``/v1/embeddings`` — embed() must + raise ``NotImplementedError`` so callers (``generate_embedding``) + can degrade to the fallback zero vector + warning rather than + silently producing useless output.""" + client = FeatherlessClient(api_key="test-key") + with pytest.raises(NotImplementedError) as excinfo: + await client.embed("hello world", model="bge-small-en-v1.5") + # Message should hint at the cause so operators see why their + # real-model swap fell back. + assert "embeddings" in str(excinfo.value).lower() From e0a28abbcd778106fd875d7ab7dc752bf4bf59db Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 05:50:29 -0400 Subject: [PATCH 3/4] feat: generate_embedding routes non-default models through client.embed (T112.3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now calls client.embed(text, model=model) and wraps the returned vector in an EmbeddingResult tagged with the requested model. On any exception (NotImplementedError from providers without an embeddings endpoint, transient network errors, etc.), the existing T107 warning fires and the function falls back to the zero-vector sentinel — callers detect model == 'fallback' and skip indexing. Adds: - MockLLMClient accepts a canned_embeddings queue mirroring the existing canned pattern. embed() pops from the front; empty queue raises IndexError so misconfigured tests fail loudly. - Settings.embedding_model defaults to "pseudo-sha256-384" so existing zero-config installs keep Phase 4 behavior. The app lifespan now passes this through to EmbeddingWorker.model. The public signature of generate_embedding is unchanged: (client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...). --- chat/app.py | 6 +++++ chat/config.py | 8 +++++++ chat/llm/mock.py | 22 ++++++++++++++++- chat/services/embeddings.py | 34 ++++++++++++++++---------- tests/test_config.py | 22 +++++++++++++++++ tests/test_embeddings.py | 48 +++++++++++++++++++++++++++++++++++++ tests/test_llm_mock.py | 25 +++++++++++++++++++ 7 files changed, 151 insertions(+), 14 deletions(-) diff --git a/chat/app.py b/chat/app.py index 80b0553..7241cd0 100644 --- a/chat/app.py +++ b/chat/app.py @@ -94,9 +94,15 @@ async def lifespan(app: FastAPI): # Phase 4's pseudo-embedding path is local so the worker doesn't need # an LLM client; we still pass one so the Phase 4.5 swap to a real # model is a one-line change. + # T112 (Phase 4.5): the embedding model is now configurable via + # ``Settings.embedding_model``. Default ``"pseudo-sha256-384"`` + # keeps the local-only path; swapping to a real model routes + # through ``client.embed(...)`` and falls back to a zero vector + # plus warning if the provider doesn't support embeddings. embedding_worker = EmbeddingWorker( conn_factory=lambda: open_db(settings.db_path), client=_factory(), + model=settings.embedding_model, ) await embedding_worker.start() app.state.embedding_worker = embedding_worker diff --git a/chat/config.py b/chat/config.py index 8eb19b6..d10dea4 100644 --- a/chat/config.py +++ b/chat/config.py @@ -39,6 +39,14 @@ class Settings(BaseModel): data_dir: Path = REPO_ROOT / "data" bind_host: str = "127.0.0.1" bind_port: int = 8000 + # T112 (Phase 4.5): embedding model identifier. Default is the + # deterministic local pseudo (semantically meaningless but keeps the + # vector pipeline structurally valid). Swap to a real model name + # (e.g. "bge-small-en-v1.5") once the LLMClient implementation + # supports embed() — currently FeatherlessClient does NOT, so a + # non-default value will trigger the zero-vector fallback path + # plus a T107 warning until a different provider is wired in. + embedding_model: str = "pseudo-sha256-384" def load_settings() -> Settings: config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG)) diff --git a/chat/llm/mock.py b/chat/llm/mock.py index 75ab786..5afc1ef 100644 --- a/chat/llm/mock.py +++ b/chat/llm/mock.py @@ -4,8 +4,23 @@ from .client import Message class MockLLMClient: - def __init__(self, canned: list[str]): + """In-memory LLMClient for tests. + + ``canned`` feeds ``generate``/``stream`` (one entry per call, popped + from the front). ``canned_embeddings`` (T112, Phase 4.5) feeds + ``embed`` the same way — each call pops the next vector. An empty + queue raises ``IndexError`` so misconfigured tests fail loudly + rather than returning ``None`` or hanging. + """ + + def __init__( + self, + canned: list[str], + *, + canned_embeddings: list[list[float]] | None = None, + ): self._canned = list(canned) + self._canned_embeddings: list[list[float]] = list(canned_embeddings or []) async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str: return self._canned.pop(0) @@ -14,3 +29,8 @@ class MockLLMClient: text = self._canned.pop(0) for ch in text: yield ch + + async def embed(self, text: str, *, model: str) -> list[float]: + # Mirrors the canned-queue pattern; empty queue raises so + # misconfigured tests surface clearly instead of returning None. + return self._canned_embeddings.pop(0) diff --git a/chat/services/embeddings.py b/chat/services/embeddings.py index 44002ea..e38fde4 100644 --- a/chat/services/embeddings.py +++ b/chat/services/embeddings.py @@ -95,19 +95,27 @@ async def generate_embedding( # Pure-local pseudo path — no LLMClient call. return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim) - # Future: real embedding via client.embed(...). Phase 4.5 work. - # For Phase 4, any non-default model falls through to fallback — - # warn so misconfigured callers (e.g., a real-model swap that isn't - # wired up yet) don't silently degrade to a zero vector. - _log.warning( - "generate_embedding: non-default model %r returned fallback " - "(model client.embed() not yet implemented in Phase 4.5+); " - "downstream search will degrade silently. Configure a supported model.", - model, - ) - return EmbeddingResult( - vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim - ) + # T112 (Phase 4.5): non-default model — route through the client's + # ``embed()`` method. On any failure (including ``NotImplementedError`` + # from providers that don't expose embeddings, e.g. Featherless today), + # fall back to the zero vector and re-fire the T107 warning so + # misconfigured callers see the issue in logs rather than silently + # producing useless cosine results. + try: + vector = await client.embed(text, model=model) + return EmbeddingResult(vector=list(vector), model=model, dim=len(vector)) + except Exception as exc: # noqa: BLE001 — any failure must degrade gracefully + _log.warning( + "generate_embedding: non-default model %r returned fallback " + "(client.embed() raised %s: %s); " + "downstream search will degrade silently. Configure a supported model.", + model, + type(exc).__name__, + exc, + ) + return EmbeddingResult( + vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim + ) __all__ = [ diff --git a/tests/test_config.py b/tests/test_config.py index abffd57..bb723bd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -24,3 +24,25 @@ def test_chat_db_path_env_overrides_default(tmp_path, monkeypatch): (tmp_path / "config.toml").write_text('featherless_api_key = "x"\n') s = load_settings() assert s.db_path == tmp_path / "alt.db" + + +def test_embedding_model_defaults_to_pseudo(tmp_path, monkeypatch): + """T112: ``embedding_model`` defaults to the deterministic pseudo + so existing zero-config installs keep the Phase 4 behavior.""" + monkeypatch.setenv("CHAT_CONFIG_PATH", str(tmp_path / "config.toml")) + (tmp_path / "config.toml").write_text('featherless_api_key = "x"\n') + s = load_settings() + assert s.embedding_model == "pseudo-sha256-384" + + +def test_embedding_model_overridable_via_toml(tmp_path, monkeypatch): + """T112: operators swap the embedding model by editing config.toml. + The new value flows through to the embedding worker at startup.""" + cfg = tmp_path / "config.toml" + cfg.write_text( + 'featherless_api_key = "x"\n' + 'embedding_model = "bge-small-en-v1.5"\n' + ) + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + s = load_settings() + assert s.embedding_model == "bge-small-en-v1.5" diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 4d1dc4b..9b0084a 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -120,3 +120,51 @@ async def test_generate_embedding_default_model_does_not_warn(caplog): await generate_embedding(_client(), text="hello") warnings = [r for r in caplog.records if r.levelno == logging.WARNING] assert warnings == [] + + +@pytest.mark.asyncio +async def test_embed_routes_to_client_when_non_default_model(): + """T112: when a non-default ``model`` is requested, generate_embedding + routes through ``client.embed(text, model=...)`` and wraps the + returned vector in an EmbeddingResult tagged with the requested + model (NOT the fallback sentinel).""" + canned = [0.1, 0.2, 0.3, 0.4] + client = MockLLMClient(canned=[], canned_embeddings=[canned]) + + result = await generate_embedding( + client, text="hello world", model="bge-small-en-v1.5" + ) + assert result.vector == canned + assert result.model == "bge-small-en-v1.5" + assert result.dim == len(canned) + + +@pytest.mark.asyncio +async def test_embed_falls_back_on_client_failure(caplog): + """T112: when ``client.embed`` raises (e.g. NotImplementedError on + Featherless, or a transient network error), generate_embedding logs + the existing T107 warning and returns the zero-vector fallback so + callers detect the sentinel and skip indexing.""" + + class _FailingClient: + async def generate(self, messages, *, model, **params): # pragma: no cover + raise AssertionError("generate must not be called") + + def stream(self, messages, *, model, **params): # pragma: no cover + raise AssertionError("stream must not be called") + + async def embed(self, text, *, model): + raise NotImplementedError("provider does not expose embeddings") + + caplog.set_level(logging.WARNING, logger="chat.services.embeddings") + result = await generate_embedding( + _FailingClient(), text="hello", model="bge-small-en-v1.5" + ) + + assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback" + assert len(result.vector) == DEFAULT_EMBEDDING_DIM + assert all(x == 0.0 for x in result.vector) + + # Existing T107 warning fires (re-used from the new exception branch). + warnings = [r for r in caplog.records if r.levelno == logging.WARNING] + assert any("bge-small-en-v1.5" in r.getMessage() for r in warnings) diff --git a/tests/test_llm_mock.py b/tests/test_llm_mock.py index d56a783..556e6cd 100644 --- a/tests/test_llm_mock.py +++ b/tests/test_llm_mock.py @@ -19,3 +19,28 @@ async def test_mock_streams_tokens(): async for chunk in client.stream(msgs, model="any"): chunks.append(chunk) assert "".join(chunks) == "abcd" + + +@pytest.mark.asyncio +async def test_mock_llm_client_embed_pops_canned(): + """T112: MockLLMClient.embed() pops a canned vector from the front + of ``canned_embeddings`` (mirrors the existing ``canned`` queue + pattern for generate/stream).""" + v1 = [0.1, 0.2, 0.3] + v2 = [0.4, 0.5, 0.6] + client = MockLLMClient(canned=[], canned_embeddings=[v1, v2]) + + out1 = await client.embed("first", model="bge-small-en-v1.5") + out2 = await client.embed("second", model="bge-small-en-v1.5") + assert out1 == v1 + assert out2 == v2 + + +@pytest.mark.asyncio +async def test_mock_llm_client_embed_empty_queue_raises(): + """When the canned_embeddings queue is empty, ``embed`` must raise + a clear failure (IndexError) so misconfigured tests don't silently + return None or hang.""" + client = MockLLMClient(canned=[]) + with pytest.raises(IndexError): + await client.embed("text", model="any") From 9b7a6d459f168dd1863f7d6482802af3829adbae Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 06:02:23 -0400 Subject: [PATCH 4/4] feat: backfill_embeddings --re-embed-all flag for model swaps (T112.4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two new flags to the backfill script: * --re-embed-all walks **every** memory (not just those without an existing embeddings row) and re-emits embedding_indexed events. The projector is INSERT OR REPLACE, so re-emitting an event for an existing memory replaces the prior vector. Use this when swapping embedding models — the default mode still keeps the Phase 4 gap-fill behavior. * --model M overrides Settings.embedding_model for this run. The script also gains a small _build_client helper that returns None for the pseudo path (no client needed) and a FeatherlessClient otherwise; tests monkeypatch this to inject a Mock with canned embeddings. Adds tests/test_backfill_embeddings.py with three integration tests: re-embed-all walks every memory, default mode skips existing rows, and --model overrides the configured model end-to-end. --- scripts/backfill_embeddings.py | 81 +++++++++-- tests/test_backfill_embeddings.py | 231 ++++++++++++++++++++++++++++++ 2 files changed, 302 insertions(+), 10 deletions(-) create mode 100644 tests/test_backfill_embeddings.py diff --git a/scripts/backfill_embeddings.py b/scripts/backfill_embeddings.py index f5c15bb..e823d2b 100644 --- a/scripts/backfill_embeddings.py +++ b/scripts/backfill_embeddings.py @@ -8,8 +8,21 @@ Phase 4 ships the deterministic local pseudo-embedding so this script runs synchronously without a network round-trip — the LLMClient argument is not needed on the pseudo path. Phase 4.5+ will need a real client. +T112 (Phase 4.5) adds two flags: + +* ``--re-embed-all`` walks **every** memory regardless of whether it + already has an ``embeddings`` row. Useful when swapping embedding + models — the projector is INSERT OR REPLACE, so re-emitting an event + for an existing memory replaces the prior vector. Without this flag, + the script keeps the Phase 4 behavior of only filling in gaps. +* ``--model M`` overrides ``Settings.embedding_model`` for this run. + Defaults to the configured model (which itself defaults to + ``"pseudo-sha256-384"``). + Run from the repo root: .venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run] + .venv/bin/python scripts/backfill_embeddings.py --re-embed-all + .venv/bin/python scripts/backfill_embeddings.py --re-embed-all --model bge-small-en-v1.5 """ from __future__ import annotations @@ -17,11 +30,12 @@ from __future__ import annotations import argparse import asyncio -from chat.config import load_settings +from chat.config import Settings, load_settings from chat.db.connection import open_db from chat.db.migrate import apply_migrations from chat.eventlog.log import append_and_apply from chat.services.embeddings import ( + DEFAULT_EMBEDDING_MODEL, FALLBACK_EMBEDDING_MODEL, generate_embedding, ) @@ -34,6 +48,24 @@ import chat.state.memory # noqa: F401 import chat.state.world # noqa: F401 +def _build_client(settings: Settings): + """Construct an LLMClient for the backfill run. + + Default-model runs (the pseudo path) don't need a client, so we + return ``None`` and ``generate_embedding`` skips the call. Non-default + models route through the real client; injectable via monkeypatch in + tests. + """ + if settings.embedding_model == DEFAULT_EMBEDDING_MODEL: + return None + from chat.llm.featherless import FeatherlessClient + + return FeatherlessClient( + api_key=settings.featherless_api_key, + base_url=settings.featherless_base_url, + ) + + async def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -47,23 +79,51 @@ async def main() -> None: action="store_true", help="Print the count of memories needing embeddings, then exit.", ) + parser.add_argument( + "--re-embed-all", + action="store_true", + help=( + "Walk every memory (not just those without an embeddings row) " + "and re-emit embedding_indexed events. Use this when swapping " + "embedding models so the existing rows get replaced." + ), + ) + parser.add_argument( + "--model", + type=str, + default=None, + help=( + "Embedding model identifier. Overrides Settings.embedding_model " + "for this run; default uses the configured model." + ), + ) args = parser.parse_args() settings = load_settings() settings.db_path.parent.mkdir(parents=True, exist_ok=True) apply_migrations(settings.db_path) + model = args.model or settings.embedding_model + # Override the settings instance so ``_build_client`` sees the + # effective model when deciding whether to construct a real client. + settings = settings.model_copy(update={"embedding_model": model}) + client = _build_client(settings) + with open_db(settings.db_path) as conn: - sql = ( - "SELECT m.id, m.pov_summary FROM memories m " - "LEFT JOIN embeddings e ON e.memory_id = m.id " - "WHERE e.memory_id IS NULL " - "ORDER BY m.id" - ) + if args.re_embed_all: + sql = "SELECT m.id, m.pov_summary FROM memories m ORDER BY m.id" + else: + sql = ( + "SELECT m.id, m.pov_summary FROM memories m " + "LEFT JOIN embeddings e ON e.memory_id = m.id " + "WHERE e.memory_id IS NULL " + "ORDER BY m.id" + ) if args.limit is not None: sql += f" LIMIT {int(args.limit)}" rows = conn.execute(sql).fetchall() - print(f"Found {len(rows)} memories needing embeddings.") + mode = "re-embedding" if args.re_embed_all else "needing embeddings" + print(f"Found {len(rows)} memories {mode} (model={model}).") if args.dry_run: return @@ -71,11 +131,12 @@ async def main() -> None: skipped = 0 for memory_id, text in rows: result = await generate_embedding( - client=None, # pseudo path: no client needed + client=client, text=text or "", + model=model, ) if result.model == FALLBACK_EMBEDDING_MODEL: - print(f" Skipping memory_id={memory_id} (empty text)") + print(f" Skipping memory_id={memory_id} (empty text or fallback)") skipped += 1 continue append_and_apply( diff --git a/tests/test_backfill_embeddings.py b/tests/test_backfill_embeddings.py new file mode 100644 index 0000000..d0f33b3 --- /dev/null +++ b/tests/test_backfill_embeddings.py @@ -0,0 +1,231 @@ +"""Tests for the backfill_embeddings script (T112, Phase 4.5). + +Phase 4 shipped a backfill that walked memories *without* an embedding +row and produced a vector for each (deterministic pseudo path). T112 +adds a ``--re-embed-all`` flag that walks **every** memory regardless +of whether it already has an embeddings row, so operators can swap +embedding models and have the existing rows replaced (the +``embedding_indexed`` projector is INSERT OR REPLACE). + +These tests exercise the script's ``main()`` directly via asyncio — +shell-out via subprocess would also work but importing keeps the +fixture surface small and the failure mode clearer. +""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_and_apply, append_event +from chat.eventlog.projector import project +from chat.services.embeddings import DEFAULT_EMBEDDING_MODEL + +# Trigger handler registration for projection. +import chat.state.embeddings # noqa: F401 +import chat.state.entities # noqa: F401 +import chat.state.memory # noqa: F401 +import chat.state.world # noqa: F401 + +import scripts.backfill_embeddings as backfill + + +def _seed(db_path: Path, count: int) -> list[int]: + """Seed ``count`` memory rows for ``bot_a``; return their ids.""" + with open_db(db_path) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "...", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_bot_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + for i in range(count): + append_event( + conn, + kind="memory_written", + payload={ + "owner_id": "bot_a", + "chat_id": "chat_bot_a", + "pov_summary": f"memory text {i}", + "witness_you": 1, + "witness_host": 1, + "witness_guest": 0, + "source": "direct", + "reliability": 1.0, + "significance": 1, + "pinned": 0, + "auto_pinned": 0, + }, + ) + project(conn) + return [ + r[0] + for r in conn.execute( + "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id" + ).fetchall() + ] + + +def _seed_embedding(db_path: Path, memory_id: int, model: str = "stale-model") -> None: + """Insert a stale ``embedding_indexed`` event so the row already + exists in ``embeddings`` (and the default backfill would skip it).""" + with open_db(db_path) as conn: + append_and_apply( + conn, + kind="embedding_indexed", + payload={ + "memory_id": memory_id, + "model": model, + "dim": 3, + "vector": [0.0, 0.0, 0.0], + }, + ) + + +@pytest.mark.asyncio +async def test_re_embed_all_walks_every_memory(tmp_path, monkeypatch, capsys): + """``--re-embed-all`` re-embeds memories that already have rows in + ``embeddings`` (default mode skips them). After the run, every + memory should have an updated embedding tagged with the configured + model (the projector replaces stale rows in place).""" + db = tmp_path / "t.db" + apply_migrations(db) + memory_ids = _seed(db, count=3) + # Pre-seed stale embeddings on two of the three memories so the + # default path would skip them and only ``--re-embed-all`` covers + # everything. + _seed_embedding(db, memory_ids[0]) + _seed_embedding(db, memory_ids[1]) + + cfg = tmp_path / "config.toml" + cfg.write_text( + f'featherless_api_key = "x"\n' + f'db_path = "{db}"\n' + f'data_dir = "{tmp_path}"\n' + ) + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + + with patch("sys.argv", ["backfill_embeddings.py", "--re-embed-all"]): + await backfill.main() + + # All three memories now have a fresh embedding tagged with the + # default pseudo model (replacing the stale rows). + with open_db(db) as conn: + rows = conn.execute( + "SELECT memory_id, model FROM embeddings ORDER BY memory_id" + ).fetchall() + assert len(rows) == 3 + for mid, model in rows: + assert mid in memory_ids + assert model == DEFAULT_EMBEDDING_MODEL + + +@pytest.mark.asyncio +async def test_default_backfill_only_walks_missing(tmp_path, monkeypatch): + """Without ``--re-embed-all``, the script keeps the Phase 4 + behavior — memories with an existing embedding row are left + alone (their stale-model tag survives).""" + db = tmp_path / "t.db" + apply_migrations(db) + memory_ids = _seed(db, count=2) + _seed_embedding(db, memory_ids[0], model="stale-model") + # memory_ids[1] has no embedding yet. + + cfg = tmp_path / "config.toml" + cfg.write_text( + f'featherless_api_key = "x"\n' + f'db_path = "{db}"\n' + f'data_dir = "{tmp_path}"\n' + ) + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + + with patch("sys.argv", ["backfill_embeddings.py"]): + await backfill.main() + + with open_db(db) as conn: + rows = dict( + conn.execute( + "SELECT memory_id, model FROM embeddings ORDER BY memory_id" + ).fetchall() + ) + # Stale row preserved; only the missing one was filled. + assert rows[memory_ids[0]] == "stale-model" + assert rows[memory_ids[1]] == DEFAULT_EMBEDDING_MODEL + + +@pytest.mark.asyncio +async def test_re_embed_all_respects_model_arg(tmp_path, monkeypatch): + """The ``--model`` flag overrides ``Settings.embedding_model``. + With a non-default model and a client that returns canned vectors, + every memory is re-embedded with the supplied model tag.""" + db = tmp_path / "t.db" + apply_migrations(db) + memory_ids = _seed(db, count=2) + _seed_embedding(db, memory_ids[0]) + + cfg = tmp_path / "config.toml" + cfg.write_text( + f'featherless_api_key = "x"\n' + f'db_path = "{db}"\n' + f'data_dir = "{tmp_path}"\n' + ) + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + + # Patch the client factory the script uses to produce a Mock with + # canned embeddings — one per memory. + from chat.llm.mock import MockLLMClient + + canned_vec = [0.1] * 384 + + def _factory(_settings): + return MockLLMClient( + canned=[], + canned_embeddings=[list(canned_vec) for _ in memory_ids], + ) + + monkeypatch.setattr(backfill, "_build_client", _factory) + + with patch( + "sys.argv", + [ + "backfill_embeddings.py", + "--re-embed-all", + "--model", + "bge-small-en-v1.5", + ], + ): + await backfill.main() + + with open_db(db) as conn: + rows = conn.execute( + "SELECT memory_id, model FROM embeddings ORDER BY memory_id" + ).fetchall() + assert len(rows) == 2 + for _, model in rows: + assert model == "bge-small-en-v1.5"