feat: generate_embedding routes non-default models through client.embed (T112.3)

When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now
calls client.embed(text, model=model) and wraps the returned
vector in an EmbeddingResult tagged with the requested model.
The result's dim is the length of the returned vector, not the
dim argument.
On any exception (NotImplementedError from providers without an
embeddings endpoint, transient network errors, etc.), the existing
T107 warning fires and the function falls back to the zero-vector
sentinel — callers detect model == 'fallback' and skip indexing.

Adds:
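- MockLLMClient accepts a canned_embeddings queue mirroring
  the existing canned pattern. embed() pops from the front;
  an empty queue raises IndexError so misconfigured tests fail
  loudly when they call embed() directly. (When embed() is
  reached through generate_embedding, the IndexError is caught
  by the broad except and surfaces as the 'fallback' result
  plus a warning instead.)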
- Settings.embedding_model defaults to "pseudo-sha256-384"
  so existing zero-config installs keep Phase 4 behavior. The app
  lifespan now passes this through to EmbeddingWorker.model.

The public signature of generate_embedding is unchanged:
(client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
This commit is contained in:
Joseph Doherty
2026-04-27 05:50:29 -04:00
parent ac6e74ab4c
commit e0a28abbcd
7 changed files with 151 additions and 14 deletions
+6
View File
@@ -94,9 +94,15 @@ async def lifespan(app: FastAPI):
# Phase 4's pseudo-embedding path is local so the worker doesn't need # Phase 4's pseudo-embedding path is local so the worker doesn't need
# an LLM client; we still pass one so the Phase 4.5 swap to a real # an LLM client; we still pass one so the Phase 4.5 swap to a real
# model is a one-line change. # model is a one-line change.
# T112 (Phase 4.5): the embedding model is now configurable via
# ``Settings.embedding_model``. Default ``"pseudo-sha256-384"``
# keeps the local-only path; swapping to a real model routes
# through ``client.embed(...)`` and falls back to a zero vector
# plus warning if the provider doesn't support embeddings.
embedding_worker = EmbeddingWorker( embedding_worker = EmbeddingWorker(
conn_factory=lambda: open_db(settings.db_path), conn_factory=lambda: open_db(settings.db_path),
client=_factory(), client=_factory(),
model=settings.embedding_model,
) )
await embedding_worker.start() await embedding_worker.start()
app.state.embedding_worker = embedding_worker app.state.embedding_worker = embedding_worker
+8
View File
@@ -39,6 +39,14 @@ class Settings(BaseModel):
data_dir: Path = REPO_ROOT / "data" data_dir: Path = REPO_ROOT / "data"
bind_host: str = "127.0.0.1" bind_host: str = "127.0.0.1"
bind_port: int = 8000 bind_port: int = 8000
# T112 (Phase 4.5): embedding model identifier. Default is the
# deterministic local pseudo (semantically meaningless but keeps the
# vector pipeline structurally valid). Swap to a real model name
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
# supports embed() — currently FeatherlessClient does NOT, so a
# non-default value will trigger the zero-vector fallback path
# plus a T107 warning until a different provider is wired in.
embedding_model: str = "pseudo-sha256-384"
def load_settings() -> Settings: def load_settings() -> Settings:
config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG)) config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG))
+21 -1
View File
@@ -4,8 +4,23 @@ from .client import Message
class MockLLMClient: class MockLLMClient:
def __init__(self, canned: list[str]): """In-memory LLMClient for tests.
``canned`` feeds ``generate``/``stream`` (one entry per call, popped
from the front). ``canned_embeddings`` (T112, Phase 4.5) feeds
``embed`` the same way — each call pops the next vector. An empty
queue raises ``IndexError`` so misconfigured tests fail loudly
rather than returning ``None`` or hanging.
"""
def __init__(
self,
canned: list[str],
*,
canned_embeddings: list[list[float]] | None = None,
):
self._canned = list(canned) self._canned = list(canned)
self._canned_embeddings: list[list[float]] = list(canned_embeddings or [])
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str: async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
return self._canned.pop(0) return self._canned.pop(0)
@@ -14,3 +29,8 @@ class MockLLMClient:
text = self._canned.pop(0) text = self._canned.pop(0)
for ch in text: for ch in text:
yield ch yield ch
async def embed(self, text: str, *, model: str) -> list[float]:
    """Return the next canned embedding vector.

    ``text`` and ``model`` are accepted but not inspected. Vectors are
    consumed from the front of the queue, one per call. Once the queue
    is empty, ``list.pop`` raises ``IndexError``, so a misconfigured
    test surfaces clearly instead of receiving ``None``.
    """
    pending = self._canned_embeddings
    next_vector = pending.pop(0)
    return next_vector
+21 -13
View File
@@ -95,19 +95,27 @@ async def generate_embedding(
# Pure-local pseudo path — no LLMClient call. # Pure-local pseudo path — no LLMClient call.
return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim) return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
# Future: real embedding via client.embed(...). Phase 4.5 work. # T112 (Phase 4.5): non-default model — route through the client's
# For Phase 4, any non-default model falls through to fallback — # ``embed()`` method. On any failure (including ``NotImplementedError``
# warn so misconfigured callers (e.g., a real-model swap that isn't # from providers that don't expose embeddings, e.g. Featherless today),
# wired up yet) don't silently degrade to a zero vector. # fall back to the zero vector and re-fire the T107 warning so
_log.warning( # misconfigured callers see the issue in logs rather than silently
"generate_embedding: non-default model %r returned fallback " # producing useless cosine results.
"(model client.embed() not yet implemented in Phase 4.5+); " try:
"downstream search will degrade silently. Configure a supported model.", vector = await client.embed(text, model=model)
model, return EmbeddingResult(vector=list(vector), model=model, dim=len(vector))
) except Exception as exc: # noqa: BLE001 — any failure must degrade gracefully
return EmbeddingResult( _log.warning(
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim "generate_embedding: non-default model %r returned fallback "
) "(client.embed() raised %s: %s); "
"downstream search will degrade silently. Configure a supported model.",
model,
type(exc).__name__,
exc,
)
return EmbeddingResult(
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
)
__all__ = [ __all__ = [
+22
View File
@@ -24,3 +24,25 @@ def test_chat_db_path_env_overrides_default(tmp_path, monkeypatch):
(tmp_path / "config.toml").write_text('featherless_api_key = "x"\n') (tmp_path / "config.toml").write_text('featherless_api_key = "x"\n')
s = load_settings() s = load_settings()
assert s.db_path == tmp_path / "alt.db" assert s.db_path == tmp_path / "alt.db"
def test_embedding_model_defaults_to_pseudo(tmp_path, monkeypatch):
    """T112: a config without ``embedding_model`` yields the pseudo default.

    Existing zero-config installs must keep the Phase 4 behavior, so the
    loaded settings are expected to carry the deterministic pseudo model
    identifier when the TOML file does not mention the key.
    """
    config_file = tmp_path / "config.toml"
    # The env var is set before the file exists; that is fine because the
    # file is written before ``load_settings`` is called.
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))
    config_file.write_text('featherless_api_key = "x"\n')

    settings = load_settings()

    assert settings.embedding_model == "pseudo-sha256-384"
def test_embedding_model_overridable_via_toml(tmp_path, monkeypatch):
    """T112: an ``embedding_model`` entry in config.toml overrides the default.

    Operators swap the embedding model by editing config.toml. This test
    checks only that the value is reflected on the loaded settings object.
    """
    config_file = tmp_path / "config.toml"
    toml_lines = [
        'featherless_api_key = "x"\n',
        'embedding_model = "bge-small-en-v1.5"\n',
    ]
    config_file.write_text("".join(toml_lines))
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))

    settings = load_settings()

    assert settings.embedding_model == "bge-small-en-v1.5"
+48
View File
@@ -120,3 +120,51 @@ async def test_generate_embedding_default_model_does_not_warn(caplog):
await generate_embedding(_client(), text="hello") await generate_embedding(_client(), text="hello")
warnings = [r for r in caplog.records if r.levelno == logging.WARNING] warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
assert warnings == [] assert warnings == []
@pytest.mark.asyncio
async def test_embed_routes_to_client_when_non_default_model():
    """T112: a non-default ``model`` is served by ``client.embed``.

    The vector handed back by the client must come through unchanged,
    and the result must carry the requested model name (not the
    fallback sentinel) with ``dim`` equal to the vector's length.
    """
    expected_vector = [0.1, 0.2, 0.3, 0.4]
    mock_client = MockLLMClient(canned=[], canned_embeddings=[expected_vector])

    outcome = await generate_embedding(
        mock_client, text="hello world", model="bge-small-en-v1.5"
    )

    assert outcome.model == "bge-small-en-v1.5"
    assert outcome.vector == expected_vector
    assert outcome.dim == len(expected_vector)
@pytest.mark.asyncio
async def test_embed_falls_back_on_client_failure(caplog):
    """T112: a raising ``client.embed`` degrades to the zero-vector fallback.

    When the client raises (here ``NotImplementedError``, as a provider
    without an embeddings endpoint would), ``generate_embedding`` must log
    a warning naming the requested model and return the fallback result,
    so callers can detect the sentinel and skip indexing.
    """

    class _FailingClient:
        """Client stub whose ``embed`` always raises."""

        async def generate(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("generate must not be called")

        def stream(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("stream must not be called")

        async def embed(self, text, *, model):
            raise NotImplementedError("provider does not expose embeddings")

    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")

    outcome = await generate_embedding(
        _FailingClient(), text="hello", model="bge-small-en-v1.5"
    )

    # Fallback sentinel: tagged "fallback", default dimension, all zeros.
    assert outcome.model == FALLBACK_EMBEDDING_MODEL == "fallback"
    assert len(outcome.vector) == DEFAULT_EMBEDDING_DIM
    assert all(component == 0.0 for component in outcome.vector)

    # The warning must mention the model that was requested.
    warning_messages = [
        record.getMessage()
        for record in caplog.records
        if record.levelno == logging.WARNING
    ]
    assert any("bge-small-en-v1.5" in message for message in warning_messages)
+25
View File
@@ -19,3 +19,28 @@ async def test_mock_streams_tokens():
async for chunk in client.stream(msgs, model="any"): async for chunk in client.stream(msgs, model="any"):
chunks.append(chunk) chunks.append(chunk)
assert "".join(chunks) == "abcd" assert "".join(chunks) == "abcd"
@pytest.mark.asyncio
async def test_mock_llm_client_embed_pops_canned():
    """T112: ``MockLLMClient.embed`` hands out canned vectors in order.

    Each call consumes the vector at the front of ``canned_embeddings``,
    mirroring how ``canned`` feeds generate/stream.
    """
    first_vec = [0.1, 0.2, 0.3]
    second_vec = [0.4, 0.5, 0.6]
    mock = MockLLMClient(canned=[], canned_embeddings=[first_vec, second_vec])

    # The list literal evaluates left to right, so the calls happen in order.
    results = [
        await mock.embed("first", model="bge-small-en-v1.5"),
        await mock.embed("second", model="bge-small-en-v1.5"),
    ]

    assert results == [first_vec, second_vec]
@pytest.mark.asyncio
async def test_mock_llm_client_embed_empty_queue_raises():
    """An exhausted ``canned_embeddings`` queue makes ``embed`` raise.

    With no canned vectors configured, the call must fail with
    ``IndexError`` so a misconfigured test does not silently get ``None``
    or hang.
    """
    empty_client = MockLLMClient(canned=[])

    with pytest.raises(IndexError):
        await empty_client.embed("text", model="any")