feat: generate_embedding routes non-default models through client.embed (T112.3)
When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now calls client.embed(text, model=model) and wraps the returned vector in an EmbeddingResult tagged with the requested model. On any exception (NotImplementedError from providers without an embeddings endpoint, transient network errors, etc.), the existing T107 warning fires and the function falls back to the zero-vector sentinel — callers detect model == 'fallback' and skip indexing. Adds: - MockLLMClient accepts a canned_embeddings queue mirroring the existing canned pattern. embed() pops from the front; empty queue raises IndexError so misconfigured tests fail loudly. - Settings.embedding_model defaults to "pseudo-sha256-384" so existing zero-config installs keep Phase 4 behavior. The app lifespan now passes this through to EmbeddingWorker.model. The public signature of generate_embedding is unchanged: (client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
This commit is contained in:
@@ -94,9 +94,15 @@ async def lifespan(app: FastAPI):
|
||||
# Phase 4's pseudo-embedding path is local so the worker doesn't need
|
||||
# an LLM client; we still pass one so the Phase 4.5 swap to a real
|
||||
# model is a one-line change.
|
||||
# T112 (Phase 4.5): the embedding model is now configurable via
|
||||
# ``Settings.embedding_model``. Default ``"pseudo-sha256-384"``
|
||||
# keeps the local-only path; swapping to a real model routes
|
||||
# through ``client.embed(...)`` and falls back to a zero vector
|
||||
# plus warning if the provider doesn't support embeddings.
|
||||
embedding_worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(settings.db_path),
|
||||
client=_factory(),
|
||||
model=settings.embedding_model,
|
||||
)
|
||||
await embedding_worker.start()
|
||||
app.state.embedding_worker = embedding_worker
|
||||
|
||||
@@ -39,6 +39,14 @@ class Settings(BaseModel):
|
||||
data_dir: Path = REPO_ROOT / "data"
|
||||
bind_host: str = "127.0.0.1"
|
||||
bind_port: int = 8000
|
||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||
# deterministic local pseudo (semantically meaningless but keeps the
|
||||
# vector pipeline structurally valid). Swap to a real model name
|
||||
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
|
||||
# supports embed() — currently FeatherlessClient does NOT, so a
|
||||
# non-default value will trigger the zero-vector fallback path
|
||||
# plus a T107 warning until a different provider is wired in.
|
||||
embedding_model: str = "pseudo-sha256-384"
|
||||
|
||||
def load_settings() -> Settings:
|
||||
config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG))
|
||||
|
||||
+21
-1
@@ -4,8 +4,23 @@ from .client import Message
|
||||
|
||||
|
||||
class MockLLMClient:
|
||||
def __init__(self, canned: list[str]):
|
||||
"""In-memory LLMClient for tests.
|
||||
|
||||
``canned`` feeds ``generate``/``stream`` (one entry per call, popped
|
||||
from the front). ``canned_embeddings`` (T112, Phase 4.5) feeds
|
||||
``embed`` the same way — each call pops the next vector. An empty
|
||||
queue raises ``IndexError`` so misconfigured tests fail loudly
|
||||
rather than returning ``None`` or hanging.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
canned: list[str],
|
||||
*,
|
||||
canned_embeddings: list[list[float]] | None = None,
|
||||
):
|
||||
self._canned = list(canned)
|
||||
self._canned_embeddings: list[list[float]] = list(canned_embeddings or [])
|
||||
|
||||
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
||||
return self._canned.pop(0)
|
||||
@@ -14,3 +29,8 @@ class MockLLMClient:
|
||||
text = self._canned.pop(0)
|
||||
for ch in text:
|
||||
yield ch
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
# Mirrors the canned-queue pattern; empty queue raises so
|
||||
# misconfigured tests surface clearly instead of returning None.
|
||||
return self._canned_embeddings.pop(0)
|
||||
|
||||
+21
-13
@@ -95,19 +95,27 @@ async def generate_embedding(
|
||||
# Pure-local pseudo path — no LLMClient call.
|
||||
return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
|
||||
|
||||
# Future: real embedding via client.embed(...). Phase 4.5 work.
|
||||
# For Phase 4, any non-default model falls through to fallback —
|
||||
# warn so misconfigured callers (e.g., a real-model swap that isn't
|
||||
# wired up yet) don't silently degrade to a zero vector.
|
||||
_log.warning(
|
||||
"generate_embedding: non-default model %r returned fallback "
|
||||
"(model client.embed() not yet implemented in Phase 4.5+); "
|
||||
"downstream search will degrade silently. Configure a supported model.",
|
||||
model,
|
||||
)
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
# T112 (Phase 4.5): non-default model — route through the client's
|
||||
# ``embed()`` method. On any failure (including ``NotImplementedError``
|
||||
# from providers that don't expose embeddings, e.g. Featherless today),
|
||||
# fall back to the zero vector and re-fire the T107 warning so
|
||||
# misconfigured callers see the issue in logs rather than silently
|
||||
# producing useless cosine results.
|
||||
try:
|
||||
vector = await client.embed(text, model=model)
|
||||
return EmbeddingResult(vector=list(vector), model=model, dim=len(vector))
|
||||
except Exception as exc: # noqa: BLE001 — any failure must degrade gracefully
|
||||
_log.warning(
|
||||
"generate_embedding: non-default model %r returned fallback "
|
||||
"(client.embed() raised %s: %s); "
|
||||
"downstream search will degrade silently. Configure a supported model.",
|
||||
model,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -24,3 +24,25 @@ def test_chat_db_path_env_overrides_default(tmp_path, monkeypatch):
|
||||
(tmp_path / "config.toml").write_text('featherless_api_key = "x"\n')
|
||||
s = load_settings()
|
||||
assert s.db_path == tmp_path / "alt.db"
|
||||
|
||||
|
||||
def test_embedding_model_defaults_to_pseudo(tmp_path, monkeypatch):
|
||||
"""T112: ``embedding_model`` defaults to the deterministic pseudo
|
||||
so existing zero-config installs keep the Phase 4 behavior."""
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(tmp_path / "config.toml"))
|
||||
(tmp_path / "config.toml").write_text('featherless_api_key = "x"\n')
|
||||
s = load_settings()
|
||||
assert s.embedding_model == "pseudo-sha256-384"
|
||||
|
||||
|
||||
def test_embedding_model_overridable_via_toml(tmp_path, monkeypatch):
|
||||
"""T112: operators swap the embedding model by editing config.toml.
|
||||
The new value flows through to the embedding worker at startup."""
|
||||
cfg = tmp_path / "config.toml"
|
||||
cfg.write_text(
|
||||
'featherless_api_key = "x"\n'
|
||||
'embedding_model = "bge-small-en-v1.5"\n'
|
||||
)
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||
s = load_settings()
|
||||
assert s.embedding_model == "bge-small-en-v1.5"
|
||||
|
||||
@@ -120,3 +120,51 @@ async def test_generate_embedding_default_model_does_not_warn(caplog):
|
||||
await generate_embedding(_client(), text="hello")
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert warnings == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_routes_to_client_when_non_default_model():
|
||||
"""T112: when a non-default ``model`` is requested, generate_embedding
|
||||
routes through ``client.embed(text, model=...)`` and wraps the
|
||||
returned vector in an EmbeddingResult tagged with the requested
|
||||
model (NOT the fallback sentinel)."""
|
||||
canned = [0.1, 0.2, 0.3, 0.4]
|
||||
client = MockLLMClient(canned=[], canned_embeddings=[canned])
|
||||
|
||||
result = await generate_embedding(
|
||||
client, text="hello world", model="bge-small-en-v1.5"
|
||||
)
|
||||
assert result.vector == canned
|
||||
assert result.model == "bge-small-en-v1.5"
|
||||
assert result.dim == len(canned)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_falls_back_on_client_failure(caplog):
|
||||
"""T112: when ``client.embed`` raises (e.g. NotImplementedError on
|
||||
Featherless, or a transient network error), generate_embedding logs
|
||||
the existing T107 warning and returns the zero-vector fallback so
|
||||
callers detect the sentinel and skip indexing."""
|
||||
|
||||
class _FailingClient:
|
||||
async def generate(self, messages, *, model, **params): # pragma: no cover
|
||||
raise AssertionError("generate must not be called")
|
||||
|
||||
def stream(self, messages, *, model, **params): # pragma: no cover
|
||||
raise AssertionError("stream must not be called")
|
||||
|
||||
async def embed(self, text, *, model):
|
||||
raise NotImplementedError("provider does not expose embeddings")
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="chat.services.embeddings")
|
||||
result = await generate_embedding(
|
||||
_FailingClient(), text="hello", model="bge-small-en-v1.5"
|
||||
)
|
||||
|
||||
assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback"
|
||||
assert len(result.vector) == DEFAULT_EMBEDDING_DIM
|
||||
assert all(x == 0.0 for x in result.vector)
|
||||
|
||||
# Existing T107 warning fires (re-used from the new exception branch).
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert any("bge-small-en-v1.5" in r.getMessage() for r in warnings)
|
||||
|
||||
@@ -19,3 +19,28 @@ async def test_mock_streams_tokens():
|
||||
async for chunk in client.stream(msgs, model="any"):
|
||||
chunks.append(chunk)
|
||||
assert "".join(chunks) == "abcd"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_llm_client_embed_pops_canned():
|
||||
"""T112: MockLLMClient.embed() pops a canned vector from the front
|
||||
of ``canned_embeddings`` (mirrors the existing ``canned`` queue
|
||||
pattern for generate/stream)."""
|
||||
v1 = [0.1, 0.2, 0.3]
|
||||
v2 = [0.4, 0.5, 0.6]
|
||||
client = MockLLMClient(canned=[], canned_embeddings=[v1, v2])
|
||||
|
||||
out1 = await client.embed("first", model="bge-small-en-v1.5")
|
||||
out2 = await client.embed("second", model="bge-small-en-v1.5")
|
||||
assert out1 == v1
|
||||
assert out2 == v2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_llm_client_embed_empty_queue_raises():
|
||||
"""When the canned_embeddings queue is empty, ``embed`` must raise
|
||||
a clear failure (IndexError) so misconfigured tests don't silently
|
||||
return None or hang."""
|
||||
client = MockLLMClient(canned=[])
|
||||
with pytest.raises(IndexError):
|
||||
await client.embed("text", model="any")
|
||||
|
||||
Reference in New Issue
Block a user