merge: T91 embedding generation service (pseudo-embedding)
This commit is contained in:
@@ -0,0 +1,108 @@
|
||||
"""Embedding generation service (T91, Phase 4).
|
||||
|
||||
Wraps the embedding API call. For Phase 4's first cut we ship a
|
||||
deterministic local pseudo-embedding (hash-derived) so the vector
|
||||
retrieval pipeline can land without an external embedding endpoint
|
||||
or heavy local dependency. Phase 4.5+ swaps to a real model — the
|
||||
EmbeddingResult shape stays the same, only the generator changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
import struct
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from chat.llm.client import LLMClient
|
||||
|
||||
|
||||
DEFAULT_EMBEDDING_DIM = 384
|
||||
DEFAULT_EMBEDDING_MODEL = "pseudo-sha256-384"
|
||||
FALLBACK_EMBEDDING_MODEL = "fallback"
|
||||
|
||||
|
||||
class EmbeddingResult(BaseModel):
|
||||
vector: list[float]
|
||||
model: str
|
||||
dim: int
|
||||
|
||||
|
||||
def _pseudo_embed(text: str, dim: int = DEFAULT_EMBEDDING_DIM) -> list[float]:
|
||||
"""Deterministic pseudo-embedding for Phase 4 first cut.
|
||||
|
||||
Hashes the text with SHA-256, then expands by re-hashing each
|
||||
successive block with the previous block + a counter — this gives
|
||||
``dim * 4`` bytes of fresh entropy per input rather than naively
|
||||
repeating the 32-byte digest (which would collapse the vector onto
|
||||
only 8 unique floats and make distinct inputs cosine-similar).
|
||||
|
||||
Bytes are unpacked as little-endian int32s and rescaled to [-1, 1]
|
||||
so we sidestep the float32 NaN/denormal values that ``struct.unpack
|
||||
'f'`` would otherwise produce on raw hash bytes. The result is
|
||||
unit-normalized so cosine similarity reduces to a dot product.
|
||||
|
||||
NOT semantically meaningful — just consistent for testing the
|
||||
pipeline. Phase 4.5 should swap to a real embedding model.
|
||||
"""
|
||||
needed = dim * 4 # 4 bytes per int32
|
||||
seed = text.encode("utf-8")
|
||||
chunks: list[bytes] = []
|
||||
counter = 0
|
||||
while sum(len(c) for c in chunks) < needed:
|
||||
block = hashlib.sha256(seed + counter.to_bytes(4, "big")).digest()
|
||||
chunks.append(block)
|
||||
counter += 1
|
||||
full = b"".join(chunks)[:needed]
|
||||
ints = struct.unpack(f"<{dim}i", full)
|
||||
# Map int32 to roughly [-1, 1] — exact bound doesn't matter since we
|
||||
# normalize, but keeps values numerically tame.
|
||||
raw = [x / 2147483648.0 for x in ints]
|
||||
norm = math.sqrt(sum(x * x for x in raw)) or 1.0
|
||||
return [x / norm for x in raw]
|
||||
|
||||
|
||||
async def generate_embedding(
|
||||
client: LLMClient,
|
||||
*,
|
||||
text: str,
|
||||
model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
dim: int = DEFAULT_EMBEDDING_DIM,
|
||||
timeout_s: float = 30.0,
|
||||
) -> EmbeddingResult:
|
||||
"""Generate an embedding for the given text.
|
||||
|
||||
Phase 4 default uses a deterministic local pseudo-embedding. If
|
||||
the LLMClient grows an ``embed(...)`` method in Phase 4.5, this
|
||||
wrapper will route to it when ``model != "pseudo-sha256-384"``.
|
||||
|
||||
Falls back to a zero vector with ``model="fallback"`` on any
|
||||
failure (callers detect the sentinel and skip indexing). For the
|
||||
pseudo path, failure is structurally impossible — it's pure local
|
||||
computation.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
# Empty input — return fallback so caller doesn't index empty rows.
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
|
||||
if model == DEFAULT_EMBEDDING_MODEL:
|
||||
# Pure-local pseudo path — no LLMClient call.
|
||||
return EmbeddingResult(vector=_pseudo_embed(text, dim), model=model, dim=dim)
|
||||
|
||||
# Future: real embedding via client.embed(...). Phase 4.5 work.
|
||||
# For Phase 4, any non-default model falls through to fallback.
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_EMBEDDING_DIM",
|
||||
"DEFAULT_EMBEDDING_MODEL",
|
||||
"FALLBACK_EMBEDDING_MODEL",
|
||||
"EmbeddingResult",
|
||||
"generate_embedding",
|
||||
]
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Tests for the embedding generation service (T91, Phase 4).
|
||||
|
||||
Phase 4's first cut ships a deterministic local pseudo-embedding so the
|
||||
vector retrieval pipeline can land without an external embeddings API
|
||||
or a heavy local model dependency. These tests pin the contract:
|
||||
|
||||
* the result has the right shape (vector length, ``dim`` metadata),
|
||||
* the default ``model`` string is reported back unchanged,
|
||||
* output is byte-identical for the same input (deterministic),
|
||||
* distinct inputs produce distinct vectors (so cosine actually
|
||||
discriminates),
|
||||
* empty / whitespace-only input collapses to the ``"fallback"`` sentinel
|
||||
with a zero vector — callers detect this and skip indexing,
|
||||
* the vector is unit-normalized so cosine similarity behaves.
|
||||
|
||||
The pseudo path doesn't touch the LLMClient, so we pass an empty
|
||||
``MockLLMClient`` — any accidental call into it would raise
|
||||
``IndexError`` and surface as a regression.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
|
||||
from chat.llm.mock import MockLLMClient
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_DIM,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
EmbeddingResult,
|
||||
generate_embedding,
|
||||
)
|
||||
|
||||
|
||||
def _client() -> MockLLMClient:
|
||||
# Pseudo path never calls the client — empty canned list ensures any
|
||||
# accidental call raises and surfaces the regression loudly.
|
||||
return MockLLMClient(canned=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_returns_vector_of_correct_dim():
|
||||
result = await generate_embedding(_client(), text="hello")
|
||||
assert isinstance(result, EmbeddingResult)
|
||||
assert isinstance(result.vector, list)
|
||||
assert len(result.vector) == DEFAULT_EMBEDDING_DIM == 384
|
||||
assert result.dim == 384
|
||||
assert all(isinstance(x, float) for x in result.vector)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_returns_correct_model_metadata():
|
||||
result = await generate_embedding(_client(), text="hello")
|
||||
assert result.model == DEFAULT_EMBEDDING_MODEL == "pseudo-sha256-384"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_is_deterministic():
|
||||
a = await generate_embedding(_client(), text="hello world")
|
||||
b = await generate_embedding(_client(), text="hello world")
|
||||
assert a.vector == b.vector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_distinct_text_produces_distinct_vectors():
|
||||
a = await generate_embedding(_client(), text="hello world")
|
||||
b = await generate_embedding(_client(), text="totally different content")
|
||||
assert a.vector != b.vector
|
||||
# Sanity-check cosine similarity — both vectors are unit-normalized,
|
||||
# so this reduces to a plain dot product.
|
||||
cosine = sum(x * y for x, y in zip(a.vector, b.vector))
|
||||
assert cosine < 0.99
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_empty_text_returns_fallback():
|
||||
for empty in ("", " ", "\n\t"):
|
||||
result = await generate_embedding(_client(), text=empty)
|
||||
assert result.model == FALLBACK_EMBEDDING_MODEL == "fallback"
|
||||
assert result.dim == DEFAULT_EMBEDDING_DIM
|
||||
assert len(result.vector) == DEFAULT_EMBEDDING_DIM
|
||||
assert all(x == 0.0 for x in result.vector)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_unit_normalized():
|
||||
result = await generate_embedding(_client(), text="some non-empty text")
|
||||
norm_sq = sum(x * x for x in result.vector)
|
||||
assert math.isclose(norm_sq, 1.0, abs_tol=1e-6)
|
||||
Reference in New Issue
Block a user