138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
"""Embedding worker (T97, Phase 4).
|
|
|
|
Drains a queue of embedding jobs. Each job carries a memory id and the
|
|
narrative text to embed; the worker calls
|
|
:func:`chat.services.embeddings.generate_embedding` and emits an
|
|
``embedding_indexed`` event so the projector lands the vector in the
|
|
``embeddings`` table.
|
|
|
|
Mirrors the :class:`chat.services.background.BackgroundWorker` pattern:
|
|
single asyncio task, sentinel-based shutdown, exceptions are caught and
|
|
logged so a flaky embedding call doesn't take down the worker. Each job
|
|
opens its own SQLite connection via ``conn_factory`` — the request path
|
|
and the worker do not share connections.
|
|
|
|
Featherless concurrency (the 2-conn cap) is respected by virtue of the
|
|
single-task design: jobs run strictly serially. Phase 4's pseudo-embedding
|
|
path is local and synchronous so this is largely moot, but the pattern
|
|
is in place for the Phase 4.5+ real-embedding swap.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from sqlite3 import Connection
|
|
from typing import Callable
|
|
|
|
from chat.eventlog.log import append_and_apply
|
|
from chat.services.embeddings import (
|
|
DEFAULT_EMBEDDING_DIM,
|
|
DEFAULT_EMBEDDING_MODEL,
|
|
FALLBACK_EMBEDDING_MODEL,
|
|
generate_embedding,
|
|
)
|
|
|
|
|
|
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class EmbeddingJob:
    """A single queued request to embed one memory's text.

    Attributes:
        memory_id: The row the resulting vector is attached to.
        text: The narrative text to embed (typically
            ``memories.pov_summary``).
    """

    memory_id: int
    text: str
|
|
|
|
|
|
class EmbeddingWorker:
    """asyncio.Queue-backed single-worker task for embedding generation.

    Started on app startup; ``stop()`` enqueues a sentinel and awaits
    the task so any in-flight job has a chance to finish. Jobs still
    queued behind the sentinel are not processed by the stopping task.

    Args:
        conn_factory: Zero-argument callable returning a SQLite
            connection. It is called once per indexed job, and the
            connection it returns is closed when that job finishes.
        client: Passed through to ``generate_embedding``.
        model: Embedding model name forwarded to ``generate_embedding``.
        dim: Embedding dimension forwarded to ``generate_embedding``.
        enabled: When false, :meth:`enqueue` silently drops jobs.
    """

    def __init__(
        self,
        *,
        conn_factory: Callable[[], Connection],
        client,  # LLMClient | None — unused on the pseudo path.
        model: str = DEFAULT_EMBEDDING_MODEL,
        dim: int = DEFAULT_EMBEDDING_DIM,
        enabled: bool = True,
    ) -> None:
        # ``None`` on the queue is the shutdown sentinel (see ``stop``).
        self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
        self._conn_factory = conn_factory
        self._client = client
        self._model = model
        self._dim = dim
        # Handle of the running consumer task; ``None`` while stopped.
        self._task: asyncio.Task | None = None
        self.enabled = enabled

    def enqueue(self, job: EmbeddingJob) -> None:
        """Queue ``job`` for processing; a no-op when the worker is disabled."""
        if not self.enabled:
            return
        self._queue.put_nowait(job)

    async def start(self) -> None:
        """Spawn the consumer task if one is not already running."""
        if self._task is None:
            self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        """Ask the consumer task to exit and wait until it has done so.

        Does nothing when the worker was never started. The task handle
        is cleared even if awaiting the task raises (for example when
        the caller is cancelled), so a later ``start()`` can spawn a
        fresh task instead of being blocked by a stale handle.
        """
        if self._task is None:
            return
        self._queue.put_nowait(None)  # sentinel
        try:
            await self._task
        finally:
            self._task = None

    async def _run(self) -> None:
        """Consume jobs one at a time until the ``None`` sentinel arrives.

        Any ``Exception`` raised while processing a job is logged with
        its traceback and the loop moves on to the next job.
        """
        while True:
            job = await self._queue.get()
            if job is None:
                return
            try:
                await self._process(job)
            except Exception as exc:  # noqa: BLE001 — worker must not die
                log.warning(
                    "embedding worker failed for memory_id=%s: %s",
                    job.memory_id,
                    exc,
                    exc_info=True,
                )

    async def _process(self, job: EmbeddingJob) -> None:
        """Embed ``job.text`` and record an ``embedding_indexed`` event.

        Results produced by the fallback model are skipped rather than
        indexed. Otherwise a connection is obtained from
        ``conn_factory``, the event is appended inside a transaction,
        and the connection is closed.
        """
        result = await generate_embedding(
            self._client,
            text=job.text,
            model=self._model,
            dim=self._dim,
        )
        if result.model == FALLBACK_EMBEDDING_MODEL:
            # Don't index a fallback (zero) vector — the backfill script
            # can retry later once a real embedding is available.
            log.debug(
                "embedding worker skipping fallback result for memory_id=%s",
                job.memory_id,
            )
            return

        conn = self._conn_factory()
        try:
            # sqlite3's connection context manager only commits (or rolls
            # back on error); it never closes the connection.
            with conn:
                append_and_apply(
                    conn,
                    kind="embedding_indexed",
                    payload={
                        "memory_id": job.memory_id,
                        "model": result.model,
                        "dim": result.dim,
                        "vector": result.vector,
                    },
                )
        finally:
            # Each job gets its own connection, so release it here
            # instead of leaking one connection per processed job.
            conn.close()
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["EmbeddingJob", "EmbeddingWorker"]
|