feat: generate_embedding routes non-default models through client.embed (T112.3)
When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now calls client.embed(text, model=model) and wraps the returned vector in an EmbeddingResult tagged with the requested model. On any exception (NotImplementedError from providers without an embeddings endpoint, transient network errors, etc.), the existing T107 warning fires and the function falls back to the zero-vector sentinel — callers detect model == 'fallback' and skip indexing. Adds: - MockLLMClient accepts a canned_embeddings queue mirroring the existing canned pattern. embed() pops from the front; empty queue raises IndexError so misconfigured tests fail loudly. - Settings.embedding_model defaults to "pseudo-sha256-384" so existing zero-config installs keep Phase 4 behavior. The app lifespan now passes this through to EmbeddingWorker.model. The public signature of generate_embedding is unchanged: (client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
This commit is contained in:
@@ -120,3 +120,51 @@ async def test_generate_embedding_default_model_does_not_warn(caplog):
|
||||
await generate_embedding(_client(), text="hello")
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert warnings == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_routes_to_client_when_non_default_model():
|
||||
"""T112: when a non-default ``model`` is requested, generate_embedding
|
||||
routes through ``client.embed(text, model=...)`` and wraps the
|
||||
returned vector in an EmbeddingResult tagged with the requested
|
||||
model (NOT the fallback sentinel)."""
|
||||
canned = [0.1, 0.2, 0.3, 0.4]
|
||||
client = MockLLMClient(canned=[], canned_embeddings=[canned])
|
||||
|
||||
result = await generate_embedding(
|
||||
client, text="hello world", model="bge-small-en-v1.5"
|
||||
)
|
||||
assert result.vector == canned
|
||||
assert result.model == "bge-small-en-v1.5"
|
||||
assert result.dim == len(canned)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_falls_back_on_client_failure(caplog):
|
||||
"""T112: when ``client.embed`` raises (e.g. NotImplementedError on
|
||||
Featherless, or a transient network error), generate_embedding logs
|
||||
the existing T107 warning and returns the zero-vector fallback so
|
||||
callers detect the sentinel and skip indexing."""
|
||||
|
||||
class _FailingClient:
|
||||
async def generate(self, messages, *, model, **params): # pragma: no cover
|
||||
raise AssertionError("generate must not be called")
|
||||
|
||||
def stream(self, messages, *, model, **params): # pragma: no cover
|
||||
raise AssertionError("stream must not be called")
|
||||
|
||||
async def embed(self, text, *, model):
|
||||
raise NotImplementedError("provider does not expose embeddings")
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="chat.services.embeddings")
|
||||
result = await generate_embedding(
|
||||
_FailingClient(), text="hello", model="bge-small-en-v1.5"
|
||||
)
|
||||
|
||||
assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback"
|
||||
assert len(result.vector) == DEFAULT_EMBEDDING_DIM
|
||||
assert all(x == 0.0 for x in result.vector)
|
||||
|
||||
# Existing T107 warning fires (re-used from the new exception branch).
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert any("bge-small-en-v1.5" in r.getMessage() for r in warnings)
|
||||
|
||||
Reference in New Issue
Block a user