feat: generate_embedding routes non-default models through client.embed (T112.3)
When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now calls client.embed(text, model=model) and wraps the returned vector in an EmbeddingResult tagged with the requested model. On any exception (NotImplementedError from providers without an embeddings endpoint, transient network errors, etc.), the existing T107 warning fires and the function falls back to the zero-vector sentinel — callers detect model == 'fallback' and skip indexing. Adds: - MockLLMClient accepts a canned_embeddings queue mirroring the existing canned pattern. embed() pops from the front; empty queue raises IndexError so misconfigured tests fail loudly. - Settings.embedding_model defaults to "pseudo-sha256-384" so existing zero-config installs keep Phase 4 behavior. The app lifespan now passes this through to EmbeddingWorker.model. The public signature of generate_embedding is unchanged: (client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
This commit is contained in:
@@ -94,9 +94,15 @@ async def lifespan(app: FastAPI):
|
||||
# Phase 4's pseudo-embedding path is local so the worker doesn't need
|
||||
# an LLM client; we still pass one so the Phase 4.5 swap to a real
|
||||
# model is a one-line change.
|
||||
# T112 (Phase 4.5): the embedding model is now configurable via
|
||||
# ``Settings.embedding_model``. Default ``"pseudo-sha256-384"``
|
||||
# keeps the local-only path; swapping to a real model routes
|
||||
# through ``client.embed(...)`` and falls back to a zero vector
|
||||
# plus warning if the provider doesn't support embeddings.
|
||||
embedding_worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(settings.db_path),
|
||||
client=_factory(),
|
||||
model=settings.embedding_model,
|
||||
)
|
||||
await embedding_worker.start()
|
||||
app.state.embedding_worker = embedding_worker
|
||||
|
||||
@@ -39,6 +39,14 @@ class Settings(BaseModel):
|
||||
data_dir: Path = REPO_ROOT / "data"
|
||||
bind_host: str = "127.0.0.1"
|
||||
bind_port: int = 8000
|
||||
# T112 (Phase 4.5): embedding model identifier. Default is the
|
||||
# deterministic local pseudo (semantically meaningless but keeps the
|
||||
# vector pipeline structurally valid). Swap to a real model name
|
||||
# (e.g. "bge-small-en-v1.5") once the LLMClient implementation
|
||||
# supports embed() — currently FeatherlessClient does NOT, so a
|
||||
# non-default value will trigger the zero-vector fallback path
|
||||
# plus a T107 warning until a different provider is wired in.
|
||||
embedding_model: str = "pseudo-sha256-384"
|
||||
|
||||
def load_settings() -> Settings:
|
||||
config_path = Path(os.environ.get("CHAT_CONFIG_PATH", DEFAULT_CONFIG))
|
||||
|
||||
+21
-1
@@ -4,8 +4,23 @@ from .client import Message
|
||||
|
||||
|
||||
class MockLLMClient:
|
||||
def __init__(self, canned: list[str]):
|
||||
"""In-memory LLMClient for tests.
|
||||
|
||||
``canned`` feeds ``generate``/``stream`` (one entry per call, popped
|
||||
from the front). ``canned_embeddings`` (T112, Phase 4.5) feeds
|
||||
``embed`` the same way — each call pops the next vector. An empty
|
||||
queue raises ``IndexError`` so misconfigured tests fail loudly
|
||||
rather than returning ``None`` or hanging.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
canned: list[str],
|
||||
*,
|
||||
canned_embeddings: list[list[float]] | None = None,
|
||||
):
|
||||
self._canned = list(canned)
|
||||
self._canned_embeddings: list[list[float]] = list(canned_embeddings or [])
|
||||
|
||||
async def generate(self, messages: Sequence[Message], *, model: str, **params) -> str:
|
||||
return self._canned.pop(0)
|
||||
@@ -14,3 +29,8 @@ class MockLLMClient:
|
||||
text = self._canned.pop(0)
|
||||
for ch in text:
|
||||
yield ch
|
||||
|
||||
async def embed(self, text: str, *, model: str) -> list[float]:
|
||||
# Mirrors the canned-queue pattern; empty queue raises so
|
||||
# misconfigured tests surface clearly instead of returning None.
|
||||
return self._canned_embeddings.pop(0)
|
||||
|
||||
+21
-13
@@ -95,19 +95,27 @@ async def generate_embedding(
|
||||
# Pure-local pseudo path — no LLMClient call.
|
||||
return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
|
||||
|
||||
# Future: real embedding via client.embed(...). Phase 4.5 work.
|
||||
# For Phase 4, any non-default model falls through to fallback —
|
||||
# warn so misconfigured callers (e.g., a real-model swap that isn't
|
||||
# wired up yet) don't silently degrade to a zero vector.
|
||||
_log.warning(
|
||||
"generate_embedding: non-default model %r returned fallback "
|
||||
"(model client.embed() not yet implemented in Phase 4.5+); "
|
||||
"downstream search will degrade silently. Configure a supported model.",
|
||||
model,
|
||||
)
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
# T112 (Phase 4.5): non-default model — route through the client's
|
||||
# ``embed()`` method. On any failure (including ``NotImplementedError``
|
||||
# from providers that don't expose embeddings, e.g. Featherless today),
|
||||
# fall back to the zero vector and re-fire the T107 warning so
|
||||
# misconfigured callers see the issue in logs rather than silently
|
||||
# producing useless cosine results.
|
||||
try:
|
||||
vector = await client.embed(text, model=model)
|
||||
return EmbeddingResult(vector=list(vector), model=model, dim=len(vector))
|
||||
except Exception as exc: # noqa: BLE001 — any failure must degrade gracefully
|
||||
_log.warning(
|
||||
"generate_embedding: non-default model %r returned fallback "
|
||||
"(client.embed() raised %s: %s); "
|
||||
"downstream search will degrade silently. Configure a supported model.",
|
||||
model,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
Reference in New Issue
Block a user