e0a28abbcd
When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now calls client.embed(text, model=model) and wraps the returned vector in an EmbeddingResult tagged with the requested model. On any exception (NotImplementedError from providers without an embeddings endpoint, transient network errors, etc.), the existing T107 warning fires and the function falls back to the zero-vector sentinel; callers detect model == 'fallback' and skip indexing.

Adds:
- MockLLMClient accepts a canned_embeddings queue, mirroring the existing canned pattern. embed() pops from the front; an empty queue raises IndexError so misconfigured tests fail loudly.
- Settings.embedding_model defaults to "pseudo-sha256-384" so existing zero-config installs keep Phase 4 behavior. The app lifespan now passes this setting through to EmbeddingWorker.model.

The public signature of generate_embedding is unchanged: (client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
171 lines
6.7 KiB
Python
171 lines
6.7 KiB
Python
"""Tests for the embedding generation service (T91, Phase 4).
|
|
|
|
Phase 4's first cut ships a deterministic local pseudo-embedding so the
|
|
vector retrieval pipeline can land without an external embeddings API
|
|
or a heavy local model dependency. These tests pin the contract:
|
|
|
|
* the result has the right shape (vector length, ``dim`` metadata),
|
|
* the default ``model`` string is reported back unchanged,
|
|
* output is byte-identical for the same input (deterministic),
|
|
* distinct inputs produce distinct vectors (so cosine actually
|
|
discriminates),
|
|
* empty / whitespace-only input collapses to the ``"fallback"`` sentinel
|
|
with a zero vector — callers detect this and skip indexing,
|
|
* the vector is unit-normalized so cosine similarity behaves.
|
|
|
|
The pseudo path doesn't touch the LLMClient, so we pass an empty
|
|
``MockLLMClient`` — any accidental call into it would raise
|
|
``IndexError`` and surface as a regression.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import math
|
|
|
|
import pytest
|
|
|
|
from chat.llm.mock import MockLLMClient
|
|
from chat.services.embeddings import (
|
|
DEFAULT_EMBEDDING_DIM,
|
|
DEFAULT_EMBEDDING_MODEL,
|
|
FALLBACK_EMBEDDING_MODEL,
|
|
EmbeddingResult,
|
|
generate_embedding,
|
|
)
|
|
|
|
|
|
def _client() -> MockLLMClient:
    """Return a MockLLMClient with empty ``generate`` and ``embed`` queues.

    The default (pseudo) embedding path never calls the client, so both
    canned queues are left empty. The embeddings queue is passed explicitly
    so the intent is visible at the call site. An ``embed`` call on an
    empty queue raises ``IndexError``.

    Note: ``generate_embedding`` catches exceptions raised by
    ``client.embed`` and degrades to the ``"fallback"`` sentinel. An
    accidental ``embed`` call from that guarded branch therefore shows up
    as a fallback result, which the pseudo-path assertions reject, rather
    than as a raised ``IndexError``.
    """
    return MockLLMClient(canned=[], canned_embeddings=[])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_returns_vector_of_correct_dim():
    """The default model yields a 384-long ``list`` of floats with matching ``dim``."""
    embedding = await generate_embedding(_client(), text="hello")
    assert isinstance(embedding, EmbeddingResult)

    vec = embedding.vector
    assert isinstance(vec, list)
    assert len(vec) == DEFAULT_EMBEDDING_DIM == 384
    assert embedding.dim == 384
    assert all(isinstance(component, float) for component in vec)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_returns_correct_model_metadata():
    """The default model identifier is echoed back on the result."""
    embedding = await generate_embedding(_client(), text="hello")
    reported_model = embedding.model
    assert reported_model == DEFAULT_EMBEDDING_MODEL == "pseudo-sha256-384"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_is_deterministic():
    """Embedding the same text twice yields identical vectors."""
    text = "hello world"
    first = await generate_embedding(_client(), text=text)
    second = await generate_embedding(_client(), text=text)
    assert first.vector == second.vector
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_distinct_text_produces_distinct_vectors():
    """Different inputs must map to different, non-collinear vectors."""
    left = await generate_embedding(_client(), text="hello world")
    right = await generate_embedding(_client(), text="totally different content")
    assert left.vector != right.vector

    # Both vectors are unit-normalized, so the plain dot product *is* the
    # cosine similarity; it must stay clearly below 1.0 to discriminate.
    dot = sum(p * q for p, q in zip(left.vector, right.vector))
    assert dot < 0.99
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "empty",
    ["", " ", "\n\t"],
    ids=["empty", "space", "newline-tab"],
)
async def test_generate_embedding_empty_text_returns_fallback(empty):
    """Empty / whitespace-only text collapses to the zero-vector fallback.

    Parametrized rather than looping inside one test, so each input is
    reported separately. A failure names the offending string instead of
    stopping at the first bad case.
    """
    result = await generate_embedding(_client(), text=empty)
    assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback"
    assert result.dim == DEFAULT_EMBEDDING_DIM
    assert len(result.vector) == DEFAULT_EMBEDDING_DIM
    assert all(x == 0.0 for x in result.vector)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_unit_normalized():
    """Non-empty input yields a vector whose squared L2 norm is 1."""
    embedding = await generate_embedding(_client(), text="some non-empty text")
    squared_norm = sum(component * component for component in embedding.vector)
    assert math.isclose(squared_norm, 1.0, abs_tol=1e-6)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_non_default_model_logs_warning(caplog):
    """T107/T112: a non-default model whose ``embed`` call fails must warn.

    Since T112, a non-default ``model`` is routed through ``client.embed``.
    Here the mock has no canned embeddings, so ``embed`` raises
    ``IndexError``. ``generate_embedding`` must catch that, return the
    zero-vector ``"fallback"`` sentinel, and log a warning naming the
    requested model. Without the warning, a misconfigured model would
    silently degrade (zero vector → useless cosine).
    """
    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")
    # Explicitly empty embeddings queue: the embed() call this test relies
    # on fails, which exercises the exception -> fallback branch.
    client = MockLLMClient(canned=[], canned_embeddings=[])
    result = await generate_embedding(client, text="hello", model="real-model")

    # The failed embed degrades to the fallback sentinel.
    assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback"
    assert all(x == 0.0 for x in result.vector)

    # Warning fired and names the offending model.
    messages = [
        r.getMessage() for r in caplog.records if r.levelno == logging.WARNING
    ]
    assert any("non-default model" in m for m in messages)
    assert any("real-model" in m for m in messages)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_embedding_default_model_does_not_warn(caplog):
    """T107: requesting the default model must not emit any WARNING record."""
    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")
    await generate_embedding(_client(), text="hello")
    assert not any(rec.levelno == logging.WARNING for rec in caplog.records)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_embed_routes_to_client_when_non_default_model():
    """T112: non-default models are embedded via ``client.embed``.

    The vector returned by ``client.embed(text, model=...)`` must come back
    unchanged, wrapped in an EmbeddingResult tagged with the requested
    model (NOT the fallback sentinel).
    """
    model_name = "bge-small-en-v1.5"
    expected = [0.1, 0.2, 0.3, 0.4]
    client = MockLLMClient(canned=[], canned_embeddings=[expected])

    result = await generate_embedding(client, text="hello world", model=model_name)

    assert result.vector == expected
    assert result.model == model_name
    assert result.dim == len(expected)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_embed_falls_back_on_client_failure(caplog):
    """T112: a failing ``client.embed`` degrades to the zero-vector fallback.

    When ``client.embed`` raises (e.g. NotImplementedError on Featherless,
    or a transient network error), generate_embedding logs the existing
    T107 warning. It then returns the zero-vector fallback so callers
    detect the sentinel and skip indexing.

    The client records its ``embed`` calls so the test also proves the
    fallback came from the exception branch. Without that check, the test
    would pass against an implementation that never calls the client at
    all (the earlier T107 behavior for non-default models).
    """

    class _FailingClient:
        def __init__(self):
            # (text, model) pairs passed to embed(), in call order.
            self.embed_calls = []

        async def generate(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("generate must not be called")

        def stream(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("stream must not be called")

        async def embed(self, text, *, model):
            self.embed_calls.append((text, model))
            raise NotImplementedError("provider does not expose embeddings")

    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")
    client = _FailingClient()
    result = await generate_embedding(
        client, text="hello", model="bge-small-en-v1.5"
    )

    # The client really was consulted: once, with the requested model.
    assert client.embed_calls == [("hello", "bge-small-en-v1.5")]

    assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback"
    assert len(result.vector) == DEFAULT_EMBEDDING_DIM
    assert all(x == 0.0 for x in result.vector)

    # Existing T107 warning fires (re-used from the new exception branch).
    warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
    assert any("bge-small-en-v1.5" in r.getMessage() for r in warnings)
|