Files
chat/chat/services/embeddings.py
T
Joseph Doherty e0a28abbcd feat: generate_embedding routes non-default models through client.embed (T112.3)
When model != DEFAULT_EMBEDDING_MODEL, generate_embedding now
calls client.embed(text, model=model) and wraps the returned
vector in an EmbeddingResult tagged with the requested model.
On any exception (NotImplementedError from providers without an
embeddings endpoint, transient network errors, etc.), the existing
T107 warning fires and the function falls back to the zero-vector
sentinel — callers detect model == 'fallback' and skip indexing.

Adds:
- MockLLMClient accepts a canned_embeddings queue mirroring
  the existing canned pattern. embed() pops from the front;
  empty queue raises IndexError so misconfigured tests fail
  loudly.
- Settings.embedding_model defaults to "pseudo-sha256-384"
  so existing zero-config installs keep Phase 4 behavior. The app
  lifespan now passes this through to EmbeddingWorker.model.

The public signature of generate_embedding is unchanged:
(client, *, text, model=DEFAULT_EMBEDDING_MODEL, dim=..., timeout_s=...).
2026-04-27 05:50:29 -04:00

128 lines
4.5 KiB
Python

"""Embedding generation service (T91, Phase 4).
Wraps the embedding API call. For Phase 4's first cut we ship a
deterministic local pseudo-embedding (hash-derived) so the vector
retrieval pipeline can land without an external embedding endpoint
or heavy local dependency. That pseudo-embedding remains the default;
since Phase 4.5 (T112) any other model name is routed through the
LLM client's ``embed()`` method instead. The EmbeddingResult shape
stays the same on both paths — only the generator changes.
"""
from __future__ import annotations

import asyncio
import hashlib
import logging
import math
import struct

from pydantic import BaseModel

from chat.llm.client import LLMClient
# Module-level logger; used to warn when a non-default model falls back.
_log = logging.getLogger(__name__)
# Default vector length for the local pseudo-embedding and for the
# zero-vector fallback.
DEFAULT_EMBEDDING_DIM = 384
# Model name that selects the pure-local, hash-derived pseudo-embedding
# path (no LLMClient call is made for this model).
DEFAULT_EMBEDDING_MODEL = "pseudo-sha256-384"
# Sentinel model name attached to zero-vector results returned for empty
# input or when the client's embed() call fails.
FALLBACK_EMBEDDING_MODEL = "fallback"
class EmbeddingResult(BaseModel):
    """An embedding vector together with the metadata describing it.

    Returned by ``generate_embedding`` on every path (local pseudo
    embedding, client-provided embedding, and zero-vector fallback), so
    callers can handle all three uniformly.
    """

    # The embedding components.
    vector: list[float]
    # Name of the model that produced ``vector``. The value "fallback"
    # marks a zero-vector placeholder rather than a real embedding.
    model: str
    # Number of components reported for ``vector``.
    dim: int
def _pseudo_embed(text: str, dim: int = DEFAULT_EMBEDDING_DIM) -> list[float]:
    """Return a deterministic, unit-length pseudo-embedding of ``text``.

    The UTF-8 bytes of ``text`` are expanded counter-mode style: block
    ``i`` is ``SHA-256(text_bytes + i.to_bytes(4, "big"))``, and blocks
    are concatenated until ``dim * 4`` bytes are available. This gives
    ``dim * 4`` bytes of fresh hash output per input rather than naively
    repeating the 32-byte digest (which would collapse the vector onto
    only 8 unique values and make distinct inputs cosine-similar).

    The bytes are unpacked as little-endian int32s and rescaled to
    [-1, 1) so we sidestep the float32 NaN/denormal values that
    ``struct.unpack 'f'`` would otherwise produce on raw hash bytes.
    The result is unit-normalized so cosine similarity reduces to a dot
    product.

    NOT semantically meaningful — just consistent for testing the
    pipeline. Real embeddings come from a real model.

    Args:
        text: Input to embed; the same text always yields the same vector.
        dim: Number of components to produce; expected to be non-negative.

    Returns:
        A list of ``dim`` floats with Euclidean norm 1 (an empty list
        when ``dim`` is 0).
    """
    needed = dim * 4  # 4 bytes per int32
    seed = text.encode("utf-8")

    # Work out the number of digests once (ceiling division) instead of
    # re-summing the length of every chunk on each loop iteration.
    block_size = hashlib.sha256().digest_size
    n_blocks = -(-needed // block_size)
    stream = b"".join(
        hashlib.sha256(seed + counter.to_bytes(4, "big")).digest()
        for counter in range(n_blocks)
    )

    ints = struct.unpack(f"<{dim}i", stream[:needed])
    # Map int32 to roughly [-1, 1] — exact bound doesn't matter since we
    # normalize, but it keeps values numerically tame.
    raw = [x / 2147483648.0 for x in ints]
    # ``or 1.0`` guards the division when the norm is zero (e.g. dim == 0).
    norm = math.sqrt(sum(x * x for x in raw)) or 1.0
    return [x / norm for x in raw]
async def generate_embedding(
    client: LLMClient,
    *,
    text: str,
    model: str = DEFAULT_EMBEDDING_MODEL,
    dim: int = DEFAULT_EMBEDDING_DIM,
    timeout_s: float = 30.0,
) -> EmbeddingResult:
    """Generate an embedding for the given text.

    Three outcomes are possible:

    * Empty or whitespace-only ``text`` returns a zero vector tagged
      ``model="fallback"`` so callers don't index empty rows.
    * ``model == DEFAULT_EMBEDDING_MODEL`` uses the deterministic local
      pseudo-embedding. No client call is made, so failure is
      structurally impossible — it's pure local computation.
    * Any other model is routed through ``client.embed(text, model=model)``
      (T112, Phase 4.5). On any failure — including
      ``NotImplementedError`` from providers without an embeddings
      endpoint, or the call exceeding ``timeout_s`` — a warning is
      logged and a zero vector with ``model="fallback"`` is returned
      (callers detect the sentinel and skip indexing).

    Args:
        client: LLM client; only used for non-default models.
        text: Text to embed.
        model: Embedding model name.
        dim: Length of the pseudo-embedding and of the fallback zero
            vector. Ignored for client-provided vectors, whose ``dim``
            is the length of the returned vector.
        timeout_s: Maximum seconds to wait for ``client.embed``;
            ``None`` disables the limit.

    Returns:
        An ``EmbeddingResult`` whose ``model`` is the requested model on
        success or ``"fallback"`` otherwise.
    """
    if not text or not text.strip():
        # Empty input — return fallback so caller doesn't index empty rows.
        return EmbeddingResult(
            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )

    if model == DEFAULT_EMBEDDING_MODEL:
        # Pure-local pseudo path — no LLMClient call.
        return EmbeddingResult(
            vector=_pseudo_embed(text, dim), model=model, dim=dim
        )

    # T112 (Phase 4.5): non-default model — route through the client's
    # ``embed()`` method. Everything that can fail (the call itself, the
    # timeout, materializing the vector, model validation) stays inside
    # the ``try`` so any failure degrades to the fallback sentinel and
    # misconfigured callers see the issue in logs rather than silently
    # producing useless cosine results.
    try:
        # Enforce ``timeout_s`` so a hung provider cannot block the caller
        # indefinitely; a timeout is handled like any other failure below.
        raw = await asyncio.wait_for(
            client.embed(text, model=model), timeout=timeout_s
        )
        # Materialize first so ``dim`` is measured on the list we store,
        # which also works when the client returns a non-sized iterable.
        values = list(raw)
        return EmbeddingResult(vector=values, model=model, dim=len(values))
    except Exception as exc:  # noqa: BLE001 — any failure must degrade gracefully
        _log.warning(
            "generate_embedding: non-default model %r returned fallback "
            "(client.embed() raised %s: %s); "
            "downstream search will degrade silently. Configure a supported model.",
            model,
            type(exc).__name__,
            exc,
        )
        return EmbeddingResult(
            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )
# Public API of this module; the `_pseudo_embed` helper and `_log` are
# intentionally not exported.
__all__ = [
"DEFAULT_EMBEDDING_DIM",
"DEFAULT_EMBEDDING_MODEL",
"FALLBACK_EMBEDDING_MODEL",
"EmbeddingResult",
"generate_embedding",
]