Files
chat/tests/test_embeddings.py
Joseph Doherty e0a28abbcd feat: generate_embedding routes non-default models through client.embed (T112.3)
When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now
calls client.embed(text, model=model) and wraps the returned
vector in an EmbeddingResult tagged with the requested model.
On any exception (NotImplementedError from providers without an
embeddings endpoint, transient network errors, etc.), the existing
T107 warning fires and the function falls back to the zero-vector
sentinel — callers detect model == 'fallback' and skip indexing.

Adds:
- MockLLMClient accepts a canned_embeddings queue mirroring
  the existing canned pattern. embed() pops from the front;
  empty queue raises IndexError so misconfigured tests fail
  loudly.
- Settings.embedding_model defaults to "pseudo-sha256-384"
  so existing zero-config installs keep Phase 4 behavior. The app
  lifespan now passes this through to EmbeddingWorker.model.

The public signature of generate_embedding is unchanged:
(client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
2026-04-27 05:50:29 -04:00

171 lines
6.7 KiB
Python

"""Tests for the embedding generation service (T91, Phase 4).
Phase 4's first cut ships a deterministic local pseudo-embedding so the
vector retrieval pipeline can land without an external embeddings API
or a heavy local model dependency. These tests pin the contract:
* the result has the right shape (vector length, ``dim`` metadata),
* the default ``model`` string is reported back unchanged,
* output is byte-identical for the same input (deterministic),
* distinct inputs produce distinct vectors (so cosine actually
discriminates),
* empty / whitespace-only input collapses to the ``"fallback"`` sentinel
with a zero vector — callers detect this and skip indexing,
* the vector is unit-normalized so cosine similarity behaves.
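The pseudo path doesn't touch the LLMClient, so we pass an empty
``MockLLMClient`` — any accidental call into it would raise
``IndexError`` and surface as a regression. The T112 tests at the end
of this module cover the non-default path, which routes through
``client.embed``, and therefore build their own clients.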
"""
from __future__ import annotations
import logging
import math
import pytest
from chat.llm.mock import MockLLMClient
from chat.services.embeddings import (
DEFAULT_EMBEDDING_DIM,
DEFAULT_EMBEDDING_MODEL,
FALLBACK_EMBEDDING_MODEL,
EmbeddingResult,
generate_embedding,
)
def _client() -> MockLLMClient:
    """Build a ``MockLLMClient`` with nothing queued in ``canned``.

    The default pseudo-embedding path is not supposed to call the client
    at all; leaving the queue empty means a stray call fails loudly
    instead of quietly succeeding.
    """
    empty_client = MockLLMClient(canned=[])
    return empty_client
@pytest.mark.asyncio
async def test_generate_embedding_returns_vector_of_correct_dim():
    """The default path yields an ``EmbeddingResult`` holding 384 floats."""
    embedding = await generate_embedding(_client(), text="hello")
    assert isinstance(embedding, EmbeddingResult)

    components = embedding.vector
    assert isinstance(components, list)

    # Pin the exported constant and the literal size separately so a
    # change to either one is caught.
    assert len(components) == DEFAULT_EMBEDDING_DIM
    assert DEFAULT_EMBEDDING_DIM == 384
    assert embedding.dim == 384

    non_floats = [c for c in components if not isinstance(c, float)]
    assert not non_floats
@pytest.mark.asyncio
async def test_generate_embedding_returns_correct_model_metadata():
    """The default model name is reported back on the result unchanged."""
    embedding = await generate_embedding(_client(), text="hello")
    reported_model = embedding.model
    assert reported_model == DEFAULT_EMBEDDING_MODEL
    # Also pin the literal value of the default model identifier.
    assert DEFAULT_EMBEDDING_MODEL == "pseudo-sha256-384"
@pytest.mark.asyncio
async def test_generate_embedding_is_deterministic():
    """Embedding the same text twice produces equal vectors."""
    sample = "hello world"
    first = await generate_embedding(_client(), text=sample)
    second = await generate_embedding(_client(), text=sample)
    assert first.vector == second.vector
@pytest.mark.asyncio
async def test_generate_embedding_distinct_text_produces_distinct_vectors():
    """Different texts map to different, not-nearly-parallel vectors."""
    first = await generate_embedding(_client(), text="hello world")
    second = await generate_embedding(
        _client(), text="totally different content"
    )
    left = first.vector
    right = second.vector
    assert left != right

    # Both vectors are unit-normalized, so the plain dot product is the
    # cosine similarity; it only needs to stay clear of 1.0.
    similarity = sum(p * q for p, q in zip(left, right))
    assert similarity < 0.99
@pytest.mark.asyncio
async def test_generate_embedding_empty_text_returns_fallback():
    """Empty and whitespace-only text collapse to the zero-vector sentinel."""
    blank_inputs = ("", " ", "\n\t")
    for blank in blank_inputs:
        outcome = await generate_embedding(_client(), text=blank)

        # The sentinel model name is what callers key on to skip indexing.
        assert outcome.model == FALLBACK_EMBEDDING_MODEL
        assert FALLBACK_EMBEDDING_MODEL == "fallback"

        assert outcome.dim == DEFAULT_EMBEDDING_DIM
        assert len(outcome.vector) == DEFAULT_EMBEDDING_DIM

        nonzero = [c for c in outcome.vector if not c == 0.0]
        assert not nonzero
@pytest.mark.asyncio
async def test_generate_embedding_unit_normalized():
    """The squared L2 norm of a default-path vector is 1 within 1e-6."""
    embedding = await generate_embedding(_client(), text="some non-empty text")
    squared_norm = sum(
        component * component for component in embedding.vector
    )
    assert math.isclose(squared_norm, 1.0, abs_tol=1e-6)
@pytest.mark.asyncio
async def test_generate_embedding_non_default_model_logs_warning(caplog):
    """T107: a non-default model that ends in the fallback must warn.

    Requesting ``"real-model"`` through the empty mock client is expected
    to produce the zero-vector ``"fallback"`` sentinel. Without a log line
    that degradation would be silent (zero vector → useless cosine), so a
    warning naming the requested model has to be emitted.

    NOTE(review): the T112 tests in this module show non-default models
    being routed through ``client.embed``; presumably the empty mock
    raises there and the fallback comes from the exception branch —
    confirm against ``generate_embedding``.
    """
    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")

    outcome = await generate_embedding(
        _client(), text="hello", model="real-model"
    )

    # The caller-visible result is the fallback sentinel with a zero vector.
    assert outcome.model == FALLBACK_EMBEDDING_MODEL
    assert FALLBACK_EMBEDDING_MODEL == "fallback"
    assert all(component == 0.0 for component in outcome.vector)

    warning_records = [
        record for record in caplog.records
        if record.levelno == logging.WARNING
    ]

    def _some_warning_mentions(fragment):
        # True when at least one captured warning contains ``fragment``.
        return any(
            fragment in record.getMessage() for record in warning_records
        )

    # A warning fired, and the offending model is named in the logs.
    assert _some_warning_mentions("non-default model")
    assert _some_warning_mentions("real-model")
@pytest.mark.asyncio
async def test_generate_embedding_default_model_does_not_warn(caplog):
    """T107: the default-model path emits no WARNING-level records."""
    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")

    await generate_embedding(_client(), text="hello")

    warning_records = list(
        filter(lambda record: record.levelno == logging.WARNING, caplog.records)
    )
    assert warning_records == []
@pytest.mark.asyncio
async def test_embed_routes_to_client_when_non_default_model():
    """T112: a non-default ``model`` is served by ``client.embed``.

    The mock is primed with one canned embedding; the result must carry
    that exact vector, be tagged with the requested model (NOT the
    fallback sentinel), and report ``dim`` as the vector's length.
    """
    requested_model = "bge-small-en-v1.5"
    queued_vector = [0.1, 0.2, 0.3, 0.4]
    mock = MockLLMClient(canned=[], canned_embeddings=[queued_vector])

    embedding = await generate_embedding(
        mock, text="hello world", model=requested_model
    )

    assert embedding.vector == queued_vector
    assert embedding.model == requested_model
    assert embedding.dim == len(queued_vector)
@pytest.mark.asyncio
async def test_embed_falls_back_on_client_failure(caplog):
    """T112: a raising ``client.embed`` degrades to the zero-vector fallback.

    The stand-in client raises ``NotImplementedError`` from ``embed`` (as a
    provider without an embeddings endpoint would). ``generate_embedding``
    must then return the ``"fallback"`` sentinel with a zero vector and
    log a warning that names the requested model, so callers can detect
    the sentinel and skip indexing.
    """
    requested_model = "bge-small-en-v1.5"

    # Minimal client double: only ``embed`` may be reached, and it fails.
    class _FailingClient:
        async def generate(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("generate must not be called")

        def stream(self, messages, *, model, **params):  # pragma: no cover
            raise AssertionError("stream must not be called")

        async def embed(self, text, *, model):
            raise NotImplementedError("provider does not expose embeddings")

    caplog.set_level(logging.WARNING, logger="chat.services.embeddings")

    outcome = await generate_embedding(
        _FailingClient(), text="hello", model=requested_model
    )

    assert outcome.model == FALLBACK_EMBEDDING_MODEL
    assert FALLBACK_EMBEDDING_MODEL == "fallback"
    assert len(outcome.vector) == DEFAULT_EMBEDDING_DIM
    assert all(component == 0.0 for component in outcome.vector)

    # The existing T107 warning is reused by the exception branch; it has
    # to mention the model that could not be served.
    warning_records = [
        record for record in caplog.records
        if record.levelno == logging.WARNING
    ]
    assert any(
        "bge-small-en-v1.5" in record.getMessage()
        for record in warning_records
    )