merge: T112 real embedding model swap (Protocol + Mock + routing + backfill)
This commit is contained in:
@@ -94,9 +94,15 @@ async def lifespan(app: FastAPI):
|
|||||||
# Phase 4's pseudo-embedding path is local so the worker doesn't need
|
# Phase 4's pseudo-embedding path is local so the worker doesn't need
|
||||||
# an LLM client; we still pass one so the Phase 4.5 swap to a real
|
# an LLM client; we still pass one so the Phase 4.5 swap to a real
|
||||||
# model is a one-line change.
|
# model is a one-line change.
|
||||||
|
# T112 (Phase 4.5): the embedding model is now configurable via
|
||||||
|
# ``Settings.embedding_model``. Default ``"pseudo-sha256-384"``
|
||||||
|
# keeps the local-only path; swapping to a real model routes
|
||||||
|
# through ``client.embed(...)`` and falls back to a zero vector
|
||||||
|
# plus warning if the provider doesn't support embeddings.
|
||||||
embedding_worker = EmbeddingWorker(
|
embedding_worker = EmbeddingWorker(
|
||||||
conn_factory=lambda: open_db(settings.db_path),
|
conn_factory=lambda: open_db(settings.db_path),
|
||||||
client=_factory(),
|
client=_factory(),
|
||||||
|
model=settings.embedding_model,
|
||||||
)
|
)
|
||||||
await embedding_worker.start()
|
await embedding_worker.start()
|
||||||
app.state.embedding_worker = embedding_worker
|
app.state.embedding_worker = embedding_worker
|
||||||
|
|||||||
@@ -39,6 +39,14 @@ class Settings(BaseModel):
|
|||||||
data_dir: Path = REPO_ROOT / "data"
|
data_dir: Path = REPO_ROOT / "data"
|
||||||
bind_host: str = "127.0.0.1"
|
bind_host: str = "127.0.0.1"
|
||||||
bind_port: int = 8000
|
bind_port: int = 8000
|
||||||
|
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||||
|
# deterministic local pseudo-embedding (semantically meaningless but keeps the
|
||||||
|
# vector pipeline structurally valid). Swap to a real model name
|
||||||
|
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
|
||||||
|
# supports embed() — currently FeatherlessClient does NOT, so a
|
||||||
|
# non-default value will trigger the zero-vector fallback path
|
||||||
|
# plus a T107 warning until a different provider is wired in.
|
||||||
|
embedding_model: str = "pseudo-sha256-384"
|
||||||
|
|
||||||
def load_settings() -> Settings:
|
def load_settings() -> Settings:
|
||||||
config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG))
|
config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG))
|
||||||
|
|||||||
@@ -12,3 +12,11 @@ class Message:
|
|||||||
class LLMClient(Protocol):
    """Structural interface shared by every LLM backend.

    The Protocol is structural: an implementation only has to provide
    methods with these names and signatures. Callers that never use
    ``embed`` are unaffected by its presence.
    """

    async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
        """Produce a single string response for ``messages`` using ``model``."""
        ...

    def stream(self, messages: Sequence[Message], *, model: str, **params) -> AsyncIterator[str]:
        """Return an async iterator of string chunks for ``messages``."""
        ...

    async def embed(self, text: str, *, model: str) -> list[float]:
        """Return an embedding vector for ``text`` (T112, Phase 4.5 seam).

        Implementations either call a provider's ``/v1/embeddings``
        endpoint or, when the provider doesn't expose embeddings (e.g.
        Featherless today), raise ``NotImplementedError`` so
        ``generate_embedding`` can catch it and degrade to the
        zero-vector fallback.
        """
        ...
|
||||||
|
|||||||
@@ -53,3 +53,26 @@ class FeatherlessClient:
|
|||||||
delta = chunk.choices[0].delta.content or ""
|
delta = chunk.choices[0].delta.content or ""
|
||||||
if delta:
|
if delta:
|
||||||
yield delta
|
yield delta
|
||||||
|
|
||||||
|
async def embed(self, text: str, *, model: str) -> list[float]:
    """Reject embedding requests: Featherless has no embeddings endpoint.

    T112 (Phase 4.5) added ``embed()`` to the LLMClient Protocol ahead
    of a real-embedding swap. At the time of writing, Featherless's
    OpenAI-compatible surface does NOT expose ``/v1/embeddings``, so
    this method raises instead of issuing a request that would 404.

    :func:`chat.services.embeddings.generate_embedding` catches the
    error and degrades to the zero-vector fallback (with the T107
    warning): a misconfigured caller is loud in the logs while the
    request path keeps working.

    If Featherless ships embeddings, replace the body with a
    ``self._client.embeddings.create(model=..., input=...)`` call
    guarded by ``self._sem()`` (mirrors ``generate``/``stream``).

    Raises:
        NotImplementedError: always; ``text`` and ``model`` are unused.
    """
    reason = (
        "Featherless does not expose /v1/embeddings; "
        "configure a different embedding provider or stick with "
        "the default pseudo-sha256-384 model."
    )
    raise NotImplementedError(reason)
|
||||||
|
|||||||
+21
-1
@@ -4,8 +4,23 @@ from .client import Message
|
|||||||
|
|
||||||
|
|
||||||
class MockLLMClient:
|
class MockLLMClient:
|
||||||
def __init__(self, canned: list[str]):
|
"""In-memory LLMClient for tests.
|
||||||
|
|
||||||
|
``canned`` feeds ``generate``/``stream`` (one entry per call, popped
|
||||||
|
from the front). ``canned_embeddings`` (T112, Phase 4.5) feeds
|
||||||
|
``embed`` the same way — each call pops the next vector. An empty
|
||||||
|
queue raises ``IndexError`` so misconfigured tests fail loudly
|
||||||
|
rather than returning ``None`` or hanging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
canned: list[str],
|
||||||
|
*,
|
||||||
|
canned_embeddings: list[list[float]] | None = None,
|
||||||
|
):
|
||||||
self._canned = list(canned)
|
self._canned = list(canned)
|
||||||
|
self._canned_embeddings: list[list[float]] = list(canned_embeddings or [])
|
||||||
|
|
||||||
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
||||||
return self._canned.pop(0)
|
return self._canned.pop(0)
|
||||||
@@ -14,3 +29,8 @@ class MockLLMClient:
|
|||||||
text = self._canned.pop(0)
|
text = self._canned.pop(0)
|
||||||
for ch in text:
|
for ch in text:
|
||||||
yield ch
|
yield ch
|
||||||
|
|
||||||
|
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||||
|
# Mirrors the canned-queue pattern; empty queue raises so
|
||||||
|
# misconfigured tests surface clearly instead of returning None.
|
||||||
|
return self._canned_embeddings.pop(0)
|
||||||
|
|||||||
+21
-13
@@ -95,19 +95,27 @@ async def generate_embedding(
|
|||||||
# Pure-local pseudo path — no LLMClient call.
|
# Pure-local pseudo path — no LLMClient call.
|
||||||
return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
|
return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
|
||||||
|
|
||||||
# Future: real embedding via client.embed(...). Phase 4.5 work.
|
# T112 (Phase 4.5): non-default model — route through the client's
|
||||||
# For Phase 4, any non-default model falls through to fallback —
|
# ``embed()`` method. On any failure (including ``NotImplementedError``
|
||||||
# warn so misconfigured callers (e.g., a real-model swap that isn't
|
# from providers that don't expose embeddings, e.g. Featherless today),
|
||||||
# wired up yet) don't silently degrade to a zero vector.
|
# fall back to the zero vector and re-fire the T107 warning so
|
||||||
_log.warning(
|
# misconfigured callers see the issue in logs rather than silently
|
||||||
"generate_embedding: non-default model %r returned fallback "
|
# producing useless cosine results.
|
||||||
"(model client.embed() not yet implemented in Phase 4.5+); "
|
try:
|
||||||
"downstream search will degrade silently. Configure a supported model.",
|
vector = await client.embed(text, model=model)
|
||||||
model,
|
return EmbeddingResult(vector=list(vector), model=model, dim=len(vector))
|
||||||
)
|
except Exception as exc: # noqa: BLE001 — any failure must degrade gracefully
|
||||||
return EmbeddingResult(
|
_log.warning(
|
||||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
"generate_embedding: non-default model %r returned fallback "
|
||||||
)
|
"(client.embed() raised %s: %s); "
|
||||||
|
"downstream search will degrade silently. Configure a supported model.",
|
||||||
|
model,
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return EmbeddingResult(
|
||||||
|
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -8,8 +8,21 @@ Phase 4 ships the deterministic local pseudo-embedding so this script
|
|||||||
runs synchronously without a network round-trip — the LLMClient argument
|
runs synchronously without a network round-trip — the LLMClient argument
|
||||||
is not needed on the pseudo path. Phase 4.5+ will need a real client.
|
is not needed on the pseudo path. Phase 4.5+ will need a real client.
|
||||||
|
|
||||||
|
T112 (Phase 4.5) adds two flags:
|
||||||
|
|
||||||
|
* ``--re-embed-all`` walks **every** memory regardless of whether it
|
||||||
|
already has an ``embeddings`` row. Useful when swapping embedding
|
||||||
|
models — the projector is INSERT OR REPLACE, so re-emitting an event
|
||||||
|
for an existing memory replaces the prior vector. Without this flag,
|
||||||
|
the script keeps the Phase 4 behavior of only filling in gaps.
|
||||||
|
* ``--model M`` overrides ``Settings.embedding_model`` for this run.
|
||||||
|
Defaults to the configured model (which itself defaults to
|
||||||
|
``"pseudo-sha256-384"``).
|
||||||
|
|
||||||
Run from the repo root:
|
Run from the repo root:
|
||||||
.venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run]
|
.venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run]
|
||||||
|
.venv/bin/python scripts/backfill_embeddings.py --re-embed-all
|
||||||
|
.venv/bin/python scripts/backfill_embeddings.py --re-embed-all --model bge-small-en-v1.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -17,11 +30,12 @@ from __future__ import annotations
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from chat.config import load_settings
|
from chat.config import Settings, load_settings
|
||||||
from chat.db.connection import open_db
|
from chat.db.connection import open_db
|
||||||
from chat.db.migrate import apply_migrations
|
from chat.db.migrate import apply_migrations
|
||||||
from chat.eventlog.log import append_and_apply
|
from chat.eventlog.log import append_and_apply
|
||||||
from chat.services.embeddings import (
|
from chat.services.embeddings import (
|
||||||
|
DEFAULT_EMBEDDING_MODEL,
|
||||||
FALLBACK_EMBEDDING_MODEL,
|
FALLBACK_EMBEDDING_MODEL,
|
||||||
generate_embedding,
|
generate_embedding,
|
||||||
)
|
)
|
||||||
@@ -34,6 +48,24 @@ import chat.state.memory # noqa: F401
|
|||||||
import chat.state.world # noqa: F401
|
import chat.state.world # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def _build_client(settings: Settings):
    """Pick the LLM client (if any) for this backfill run.

    Returns ``None`` when ``settings.embedding_model`` is the default
    pseudo model, since that path never calls a client. For any other
    model a ``FeatherlessClient`` is built from the configured API key
    and base URL. Tests replace this function via monkeypatch to
    inject a mock client.
    """
    uses_pseudo_path = settings.embedding_model == DEFAULT_EMBEDDING_MODEL
    if uses_pseudo_path:
        return None

    # Imported lazily so pseudo-path runs never load the Featherless module.
    from chat.llm.featherless import FeatherlessClient

    client_kwargs = {
        "api_key": settings.featherless_api_key,
        "base_url": settings.featherless_base_url,
    }
    return FeatherlessClient(**client_kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -47,23 +79,51 @@ async def main() -> None:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Print the count of memories needing embeddings, then exit.",
|
help="Print the count of memories needing embeddings, then exit.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--re-embed-all",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Walk every memory (not just those without an embeddings row) "
|
||||||
|
"and re-emit embedding_indexed events. Use this when swapping "
|
||||||
|
"embedding models so the existing rows get replaced."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Embedding model identifier. Overrides Settings.embedding_model "
|
||||||
|
"for this run; default uses the configured model."
|
||||||
|
),
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
settings = load_settings()
|
settings = load_settings()
|
||||||
settings.db_path.parent.mkdir(parents=True, exist_ok=True)
|
settings.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
apply_migrations(settings.db_path)
|
apply_migrations(settings.db_path)
|
||||||
|
|
||||||
|
model = args.model or settings.embedding_model
|
||||||
|
# Override the settings instance so ``_build_client`` sees the
|
||||||
|
# effective model when deciding whether to construct a real client.
|
||||||
|
settings = settings.model_copy(update={"embedding_model": model})
|
||||||
|
client = _build_client(settings)
|
||||||
|
|
||||||
with open_db(settings.db_path) as conn:
|
with open_db(settings.db_path) as conn:
|
||||||
sql = (
|
if args.re_embed_all:
|
||||||
"SELECT m.id, m.pov_summary FROM memories m "
|
sql = "SELECT m.id, m.pov_summary FROM memories m ORDER BY m.id"
|
||||||
"LEFT JOIN embeddings e ON e.memory_id = m.id "
|
else:
|
||||||
"WHERE e.memory_id IS NULL "
|
sql = (
|
||||||
"ORDER BY m.id"
|
"SELECT m.id, m.pov_summary FROM memories m "
|
||||||
)
|
"LEFT JOIN embeddings e ON e.memory_id = m.id "
|
||||||
|
"WHERE e.memory_id IS NULL "
|
||||||
|
"ORDER BY m.id"
|
||||||
|
)
|
||||||
if args.limit is not None:
|
if args.limit is not None:
|
||||||
sql += f" LIMIT {int(args.limit)}"
|
sql += f" LIMIT {int(args.limit)}"
|
||||||
rows = conn.execute(sql).fetchall()
|
rows = conn.execute(sql).fetchall()
|
||||||
print(f"Found {len(rows)} memories needing embeddings.")
|
mode = "re-embedding" if args.re_embed_all else "needing embeddings"
|
||||||
|
print(f"Found {len(rows)} memories {mode} (model={model}).")
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -71,11 +131,12 @@ async def main() -> None:
|
|||||||
skipped = 0
|
skipped = 0
|
||||||
for memory_id, text in rows:
|
for memory_id, text in rows:
|
||||||
result = await generate_embedding(
|
result = await generate_embedding(
|
||||||
client=None, # pseudo path: no client needed
|
client=client,
|
||||||
text=text or "",
|
text=text or "",
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
if result.model == FALLBACK_EMBEDDING_MODEL:
|
if result.model == FALLBACK_EMBEDDING_MODEL:
|
||||||
print(f" Skipping memory_id={memory_id} (empty text)")
|
print(f" Skipping memory_id={memory_id} (empty text or fallback)")
|
||||||
skipped += 1
|
skipped += 1
|
||||||
continue
|
continue
|
||||||
append_and_apply(
|
append_and_apply(
|
||||||
|
|||||||
@@ -0,0 +1,231 @@
|
|||||||
|
"""Tests for the backfill_embeddings script (T112, Phase 4.5).
|
||||||
|
|
||||||
|
Phase 4 shipped a backfill that walked memories *without* an embedding
|
||||||
|
row and produced a vector for each (deterministic pseudo path). T112
|
||||||
|
adds a ``--re-embed-all`` flag that walks **every** memory regardless
|
||||||
|
of whether it already has an embeddings row, so operators can swap
|
||||||
|
embedding models and have the existing rows replaced (the
|
||||||
|
``embedding_indexed`` projector is INSERT OR REPLACE).
|
||||||
|
|
||||||
|
These tests exercise the script's ``main()`` directly via asyncio —
|
||||||
|
shell-out via subprocess would also work but importing keeps the
|
||||||
|
fixture surface small and the failure mode clearer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.db.connection import open_db
|
||||||
|
from chat.db.migrate import apply_migrations
|
||||||
|
from chat.eventlog.log import append_and_apply, append_event
|
||||||
|
from chat.eventlog.projector import project
|
||||||
|
from chat.services.embeddings import DEFAULT_EMBEDDING_MODEL
|
||||||
|
|
||||||
|
# Trigger handler registration for projection.
|
||||||
|
import chat.state.embeddings # noqa: F401
|
||||||
|
import chat.state.entities # noqa: F401
|
||||||
|
import chat.state.memory # noqa: F401
|
||||||
|
import chat.state.world # noqa: F401
|
||||||
|
|
||||||
|
import scripts.backfill_embeddings as backfill
|
||||||
|
|
||||||
|
|
||||||
|
def _seed(db_path: Path, count: int) -> list[int]:
    """Create ``bot_a``, its chat, and ``count`` memories; return the memory ids.

    Events are appended first and projected in one pass afterwards; the
    ids are then read back from ``memories`` in ascending order.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    with open_db(db_path) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)
        for index in range(count):
            memory_payload = {
                "owner_id": "bot_a",
                "chat_id": "chat_bot_a",
                "pov_summary": f"memory text {index}",
                "witness_you": 1,
                "witness_host": 1,
                "witness_guest": 0,
                "source": "direct",
                "reliability": 1.0,
                "significance": 1,
                "pinned": 0,
                "auto_pinned": 0,
            }
            append_event(conn, kind="memory_written", payload=memory_payload)
        project(conn)
        id_rows = conn.execute(
            "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
        ).fetchall()
        return [row[0] for row in id_rows]
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_embedding(db_path: Path, memory_id: int, model: str = "stale-model") -> None:
    """Give ``memory_id`` a pre-existing 3-dim zero embedding tagged ``model``.

    The row is written by appending and applying an ``embedding_indexed``
    event, so the default (gap-filling) backfill would skip this memory.
    """
    stale_payload = {
        "memory_id": memory_id,
        "model": model,
        "dim": 3,
        "vector": [0.0, 0.0, 0.0],
    }
    with open_db(db_path) as conn:
        append_and_apply(conn, kind="embedding_indexed", payload=stale_payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_re_embed_all_walks_every_memory(tmp_path, monkeypatch, capsys):
    """Check that ``--re-embed-all`` replaces pre-existing embedding rows.

    Three memories are seeded and two of them get a stale embedding row
    up front. After running the script with ``--re-embed-all`` (and no
    ``--model``), all three rows must carry the default model tag.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=3)

    # Stale rows on the first two memories: the gap-filling default mode
    # would leave them alone, so only ``--re-embed-all`` reaches them.
    for stale_id in memory_ids[:2]:
        _seed_embedding(db, stale_id)

    cfg = tmp_path / "config.toml"
    config_lines = [
        'featherless_api_key = "x"',
        f'db_path = "{db}"',
        f'data_dir = "{tmp_path}"',
    ]
    cfg.write_text("".join(line + "\n" for line in config_lines))
    for env_name, env_value in (("CHAT_CONFIG_PATH", cfg), ("CHAT_DB_PATH", db)):
        monkeypatch.setenv(env_name, str(env_value))

    argv = ["backfill_embeddings.py", "--re-embed-all"]
    with patch("sys.argv", argv):
        await backfill.main()

    # Every memory must now be tagged with the default pseudo model.
    with open_db(db) as conn:
        stored = conn.execute(
            "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
        ).fetchall()
        assert len(stored) == 3
        for stored_id, stored_model in stored:
            assert stored_id in memory_ids
            assert stored_model == DEFAULT_EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_default_backfill_only_walks_missing(tmp_path, monkeypatch):
    """Check that a run without ``--re-embed-all`` only fills gaps.

    One of two memories already has a ``stale-model`` embedding. After
    the run that tag must be unchanged, and the other memory must have
    gained a default-model embedding.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=2)
    already_embedded, missing = memory_ids[0], memory_ids[1]
    _seed_embedding(db, already_embedded, model="stale-model")

    cfg = tmp_path / "config.toml"
    config_lines = [
        'featherless_api_key = "x"',
        f'db_path = "{db}"',
        f'data_dir = "{tmp_path}"',
    ]
    cfg.write_text("".join(line + "\n" for line in config_lines))
    for env_name, env_value in (("CHAT_CONFIG_PATH", cfg), ("CHAT_DB_PATH", db)):
        monkeypatch.setenv(env_name, str(env_value))

    argv = ["backfill_embeddings.py"]
    with patch("sys.argv", argv):
        await backfill.main()

    with open_db(db) as conn:
        model_by_memory = {
            stored_id: stored_model
            for stored_id, stored_model in conn.execute(
                "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
            ).fetchall()
        }
        # The stale row survives; only the gap was filled.
        assert model_by_memory[already_embedded] == "stale-model"
        assert model_by_memory[missing] == DEFAULT_EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_re_embed_all_respects_model_arg(tmp_path, monkeypatch):
    """Check that ``--model`` decides the tag written by ``--re-embed-all``.

    ``_build_client`` is replaced with a factory returning a mock that
    serves one canned 384-dim vector per memory. After the run, both
    embedding rows must be tagged ``bge-small-en-v1.5``.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed(db, count=2)
    _seed_embedding(db, memory_ids[0])

    cfg = tmp_path / "config.toml"
    config_lines = [
        'featherless_api_key = "x"',
        f'db_path = "{db}"',
        f'data_dir = "{tmp_path}"',
    ]
    cfg.write_text("".join(line + "\n" for line in config_lines))
    for env_name, env_value in (("CHAT_CONFIG_PATH", cfg), ("CHAT_DB_PATH", db)):
        monkeypatch.setenv(env_name, str(env_value))

    # Swap the script's client factory for one that yields a mock with
    # exactly one canned embedding per seeded memory.
    from chat.llm.mock import MockLLMClient

    canned_vec = [0.1] * 384

    def _mock_client_factory(_settings):
        per_memory_vectors = [list(canned_vec) for _ in memory_ids]
        return MockLLMClient(canned=[], canned_embeddings=per_memory_vectors)

    monkeypatch.setattr(backfill, "_build_client", _mock_client_factory)

    argv = [
        "backfill_embeddings.py",
        "--re-embed-all",
        "--model",
        "bge-small-en-v1.5",
    ]
    with patch("sys.argv", argv):
        await backfill.main()

    with open_db(db) as conn:
        stored = conn.execute(
            "SELECT memory_id, model FROM embeddings ORDER BY memory_id"
        ).fetchall()
        assert len(stored) == 2
        for _, stored_model in stored:
            assert stored_model == "bge-small-en-v1.5"
|
||||||
@@ -24,3 +24,25 @@ def test_chat_db_path_env_overrides_default(tmp_path, monkeypatch):
|
|||||||
(tmp_path / "config.toml").write_text('featherless_api_key = "x"\n')
|
(tmp_path / "config.toml").write_text('featherless_api_key = "x"\n')
|
||||||
s = load_settings()
|
s = load_settings()
|
||||||
assert s.db_path == tmp_path / "alt.db"
|
assert s.db_path == tmp_path / "alt.db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_model_defaults_to_pseudo(tmp_path, monkeypatch):
    """T112: a config without ``embedding_model`` loads the pseudo default.

    Zero-config installs therefore keep the Phase 4 behavior.
    """
    config_file = tmp_path / "config.toml"
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))
    config_file.write_text('featherless_api_key = "x"\n')
    loaded = load_settings()
    assert loaded.embedding_model == "pseudo-sha256-384"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_model_overridable_via_toml(tmp_path, monkeypatch):
    """T112: an ``embedding_model`` entry in config.toml overrides the default."""
    cfg = tmp_path / "config.toml"
    toml_lines = (
        'featherless_api_key = "x"\n',
        'embedding_model = "bge-small-en-v1.5"\n',
    )
    cfg.write_text("".join(toml_lines))
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
    loaded = load_settings()
    assert loaded.embedding_model == "bge-small-en-v1.5"
|
||||||
|
|||||||
@@ -120,3 +120,51 @@ async def test_generate_embedding_default_model_does_not_warn(caplog):
|
|||||||
await generate_embedding(_client(), text="hello")
|
await generate_embedding(_client(), text="hello")
|
||||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||||
assert warnings == []
|
assert warnings == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_embed_routes_to_client_when_non_default_model():
    """T112: a non-default ``model`` is served by ``client.embed``.

    The result must carry the client's vector unchanged, the requested
    model name (not the fallback sentinel), and a ``dim`` equal to the
    vector's length.
    """
    canned = [0.1, 0.2, 0.3, 0.4]
    requested_model = "bge-small-en-v1.5"
    client = MockLLMClient(canned=[], canned_embeddings=[canned])

    result = await generate_embedding(client, text="hello world", model=requested_model)

    assert result.vector == canned
    assert result.model == "bge-small-en-v1.5"
    assert result.dim == len(canned)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_embed_falls_back_on_client_failure(caplog):
    """T112: a raising ``client.embed`` yields the zero-vector fallback.

    The stub client raises ``NotImplementedError`` from ``embed``. The
    result must be tagged with the fallback sentinel, hold
    ``DEFAULT_EMBEDDING_DIM`` zeros, and a warning naming the requested
    model must be logged.
    """

    class _FailingClient:
        async def generate(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("generate must not be called")

        def stream(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("stream must not be called")

        async def embed(self, text, *, model):
            raise NotImplementedError("provider does not expose embeddings")

    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")
    failing_client = _FailingClient()
    result = await generate_embedding(
        failing_client, text="hello", model="bge-small-en-v1.5"
    )

    assert result.model == FALLBACK_EMBEDDING_MODEL
    assert FALLBACK_EMBEDDING_MODEL == "fallback"
    assert len(result.vector) == DEFAULT_EMBEDDING_DIM
    assert not any(component != 0.0 for component in result.vector)

    # The warning must mention the model that was requested.
    warning_messages = [
        record.getMessage()
        for record in caplog.records
        if record.levelno == logging.WARNING
    ]
    assert any("bge-small-en-v1.5" in message for message in warning_messages)
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
"""Tests for FeatherlessClient (Phase 4.5+).
|
||||||
|
|
||||||
|
Phase 4.5 adds an ``embed()`` method to the LLMClient Protocol (T112).
|
||||||
|
Featherless does not expose an OpenAI-compatible ``/v1/embeddings``
|
||||||
|
endpoint, so its implementation deliberately raises
|
||||||
|
``NotImplementedError`` to surface the gap clearly. The
|
||||||
|
``generate_embedding`` wrapper catches this and degrades to the
|
||||||
|
zero-vector fallback (the existing T107 warning path).
|
||||||
|
|
||||||
|
If/when Featherless ships embeddings, swap the body for a real call to
|
||||||
|
``/v1/embeddings`` and update this test to mock the HTTP layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.llm.featherless import FeatherlessClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_featherless_embed_raises_not_implemented():
    """``FeatherlessClient.embed`` must raise ``NotImplementedError``.

    The error text must mention embeddings so an operator reading the
    log can tell why a real-model swap fell back.
    """
    client = FeatherlessClient(api_key="test-key")
    with pytest.raises(NotImplementedError) as excinfo:
        await client.embed("hello world", model="bge-small-en-v1.5")
    error_text = str(excinfo.value).lower()
    assert "embeddings" in error_text
|
||||||
@@ -19,3 +19,28 @@ async def test_mock_streams_tokens():
|
|||||||
async for chunk in client.stream(msgs, model="any"):
|
async for chunk in client.stream(msgs, model="any"):
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
assert "".join(chunks) == "abcd"
|
assert "".join(chunks) == "abcd"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mock_llm_client_embed_pops_canned():
    """T112: ``MockLLMClient.embed`` serves canned vectors in FIFO order."""
    v1 = [0.1, 0.2, 0.3]
    v2 = [0.4, 0.5, 0.6]
    client = MockLLMClient(canned=[], canned_embeddings=[v1, v2])

    served = [
        await client.embed("first", model="bge-small-en-v1.5"),
        await client.embed("second", model="bge-small-en-v1.5"),
    ]
    assert served[0] == v1
    assert served[1] == v2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mock_llm_client_embed_empty_queue_raises():
    """``embed`` on a mock with no canned vectors must raise ``IndexError``."""
    empty_client = MockLLMClient(canned=[])
    with pytest.raises(IndexError):
        await empty_client.embed("text", model="any")
|
||||||
|
|||||||
Reference in New Issue
Block a user