Files
chat/chat/services/embeddings.py
T

109 lines
3.7 KiB
Python

"""Embedding generation service (T91, Phase 4).
Wraps the embedding API call. For Phase 4's first cut we ship a
deterministic local pseudo-embedding (hash-derived) so the vector
retrieval pipeline can land without an external embedding endpoint
or heavy local dependency. Phase 4.5+ swaps to a real model — the
EmbeddingResult shape stays the same, only the generator changes.
"""
from __future__ import annotations
import hashlib
import math
import struct
from pydantic import BaseModel
from chat.llm.client import LLMClient
# Default number of float components in a generated embedding.
DEFAULT_EMBEDDING_DIM = 384
# Model tag for the hash-derived pseudo-embedding; generate_embedding only
# computes a vector locally when ``model`` equals this value.
# NOTE(review): the "384" suffix mirrors DEFAULT_EMBEDDING_DIM, but this tag
# is returned even when a caller requests a different ``dim`` — confirm that
# downstream stores do not rely on the tag alone to infer dimensionality.
DEFAULT_EMBEDDING_MODEL = "pseudo-sha256-384"
# Sentinel model tag on the zero vectors generate_embedding returns when it
# cannot produce a real embedding; callers check for it and skip indexing.
FALLBACK_EMBEDDING_MODEL = "fallback"
class EmbeddingResult(BaseModel):
    """An embedding vector together with the metadata needed to interpret it.

    When ``model`` equals ``FALLBACK_EMBEDDING_MODEL`` the vector is a
    zero-filled sentinel that callers are expected to skip when indexing.
    """

    vector: list[float]  # embedding components
    model: str  # tag of the generator that produced ``vector``
    dim: int  # requested dimensionality (not validated against len(vector))
def _pseudo_embed(text: str, dim: int = DEFAULT_EMBEDDING_DIM) -> list[float]:
    """Deterministic pseudo-embedding for Phase 4 first cut.

    The UTF-8 bytes of ``text`` act as a seed that is expanded in counter
    mode: block ``i`` is ``SHA-256(seed || i)``, with ``i`` encoded as a
    4-byte big-endian integer. Each block depends only on the seed and its
    counter (blocks are not chained). This yields ``dim * 4`` bytes of
    distinct hash output per input rather than naively repeating the
    32-byte digest (which would collapse the vector onto only 8 unique
    floats and make distinct inputs cosine-similar).

    Bytes are unpacked as little-endian int32s and rescaled to [-1, 1), so
    we sidestep the float32 NaN/denormal values that ``struct.unpack('f')``
    would otherwise produce on raw hash bytes. The result is
    unit-normalized so cosine similarity reduces to a dot product.

    NOT semantically meaningful — just consistent for testing the
    pipeline. Phase 4.5 should swap to a real embedding model.

    Args:
        text: Input text; only its UTF-8 encoding influences the result.
        dim: Number of output components. ``0`` yields an empty list.

    Returns:
        ``dim`` floats with unit L2 norm (empty when ``dim == 0``).

    Raises:
        struct.error: if ``dim`` is negative (the unpack format is invalid).
    """
    needed = dim * 4  # 4 bytes per int32
    seed = text.encode("utf-8")
    digest_size = hashlib.sha256().digest_size  # 32 bytes for SHA-256
    # Ceil-divide so a trailing partial block is still generated; surplus
    # bytes are trimmed by the slice. Non-positive ``needed`` gives 0 blocks.
    n_blocks = -(-needed // digest_size)
    stream = b"".join(
        hashlib.sha256(seed + counter.to_bytes(4, "big")).digest()
        for counter in range(n_blocks)
    )[:needed]
    ints = struct.unpack(f"<{dim}i", stream)
    # Map int32 to [-1, 1) — the exact bound doesn't matter since we
    # normalize, but it keeps values numerically tame. The arithmetic here
    # is deliberately unchanged: altering it would change every vector
    # this function produces.
    raw = [x / 2147483648.0 for x in ints]
    norm = math.sqrt(sum(x * x for x in raw)) or 1.0
    return [x / norm for x in raw]
async def generate_embedding(
    client: LLMClient,
    *,
    text: str,
    model: str = DEFAULT_EMBEDDING_MODEL,
    dim: int = DEFAULT_EMBEDDING_DIM,
    timeout_s: float = 30.0,
) -> EmbeddingResult:
    """Generate an embedding for the given text.

    Phase 4 default uses a deterministic local pseudo-embedding. If the
    LLMClient grows an ``embed(...)`` method in Phase 4.5, this wrapper
    will route to it when ``model != DEFAULT_EMBEDDING_MODEL``.

    Returns a zero vector tagged ``model=FALLBACK_EMBEDDING_MODEL``
    (callers detect the sentinel and skip indexing) when:

    * ``text`` is empty or whitespace-only;
    * ``dim`` is not positive — a zero-length vector is not indexable, and
      a negative ``dim`` would otherwise make the pseudo path raise;
    * ``model`` is anything other than the pseudo model (no real embedding
      backend exists yet).

    The pseudo path is pure local computation and makes no client call.

    Args:
        client: LLM client; unused in Phase 4, reserved for the real
            embedding call.
        text: Text to embed.
        model: Embedding model tag.
        dim: Requested vector dimensionality.
        timeout_s: Reserved for the real embedding call; currently unused.

    Returns:
        An EmbeddingResult whose vector has ``max(dim, 0)`` components.
    """
    if dim <= 0 or not text or not text.strip():
        # Nothing indexable: empty input or an unusable dimensionality.
        # ``[0.0] * dim`` is ``[]`` for non-positive dim.
        return EmbeddingResult(
            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )
    if model == DEFAULT_EMBEDDING_MODEL:
        # Pure-local pseudo path — no LLMClient call.
        return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
    # Future: real embedding via client.embed(...). Phase 4.5 work.
    # For Phase 4, any non-default model falls through to fallback.
    return EmbeddingResult(
        vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
    )
# Public API of this module; the helper ``_pseudo_embed`` is not exported.
__all__ = [
    "DEFAULT_EMBEDDING_DIM",
    "DEFAULT_EMBEDDING_MODEL",
    "FALLBACK_EMBEDDING_MODEL",
    "EmbeddingResult",
    "generate_embedding",
]