diff --git a/chat/services/embedding_worker.py b/chat/services/embedding_worker.py
new file mode 100644
index 0000000..80f87d8
--- /dev/null
+++ b/chat/services/embedding_worker.py
@@ -0,0 +1,180 @@
+"""Embedding worker (T97, Phase 4).
+
+Drains a queue of embedding jobs. Each job carries a memory id and the
+narrative text to embed; the worker calls
+:func:`chat.services.embeddings.generate_embedding` and emits an
+``embedding_indexed`` event so the projector lands the vector in the
+``embeddings`` table.
+
+Mirrors the :class:`chat.services.background.BackgroundWorker` pattern:
+single asyncio task, sentinel-based shutdown, exceptions are caught and
+logged so a flaky embedding call doesn't take down the worker. Each job
+opens its own SQLite connection via ``conn_factory`` — the request path
+and the worker do not share connections.
+
+Featherless concurrency (the 2-conn cap) is respected by virtue of the
+single-task design: jobs run strictly serially. Phase 4's pseudo-embedding
+path is local and synchronous so this is largely moot, but the pattern
+is in place for the Phase 4.5+ real-embedding swap.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from dataclasses import dataclass
+from sqlite3 import Connection
+from typing import Callable
+
+from chat.eventlog.log import append_and_apply
+from chat.services.embeddings import (
+    DEFAULT_EMBEDDING_DIM,
+    DEFAULT_EMBEDDING_MODEL,
+    FALLBACK_EMBEDDING_MODEL,
+    generate_embedding,
+)
+
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class EmbeddingJob:
+    """One unit of work for the embedding worker.
+
+    ``memory_id`` is the row to attach the vector to; ``text`` is the
+    narrative text to embed (typically ``memories.pov_summary``).
+    """
+
+    memory_id: int
+    text: str
+
+
+class EmbeddingWorker:
+    """asyncio.Queue-backed single-worker task for embedding generation.
+
+    Started on app startup; ``stop()`` enqueues a sentinel and awaits
+    the task so any in-flight job has a chance to finish. Jobs queued
+    ahead of the sentinel are still processed; anything left in the
+    queue once the task has exited is dropped on shutdown, so a later
+    ``start()`` never inherits stale jobs or sentinels.
+    """
+
+    def __init__(
+        self,
+        *,
+        conn_factory: Callable[[], Connection],
+        client,  # LLMClient | None — unused on the pseudo path.
+        model: str = DEFAULT_EMBEDDING_MODEL,
+        dim: int = DEFAULT_EMBEDDING_DIM,
+        enabled: bool = True,
+    ) -> None:
+        """Configure the worker; no task exists until ``start()``.
+
+        ``conn_factory`` is called once per indexed job and its result
+        is used as a context manager. ``client``, ``model`` and ``dim``
+        are passed straight through to ``generate_embedding``. With
+        ``enabled=False``, ``enqueue()`` silently discards jobs.
+        """
+        # ``None`` is the shutdown sentinel, hence the optional item type.
+        self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
+        self._conn_factory = conn_factory
+        self._client = client
+        self._model = model
+        self._dim = dim
+        self._task: asyncio.Task | None = None
+        self.enabled = enabled
+
+    def enqueue(self, job: EmbeddingJob) -> None:
+        """Queue ``job``; silently ignored when the worker is disabled."""
+        if not self.enabled:
+            return
+        self._queue.put_nowait(job)
+
+    async def start(self) -> None:
+        """Spawn the drain task if one is not already running."""
+        if self._task is None:
+            self._task = asyncio.create_task(self._run())
+
+    async def stop(self) -> None:
+        """Shut the worker down after it drains jobs queued so far.
+
+        A no-op if the worker was never started. Safe to call more than
+        once, including concurrently, and leaves the worker restartable
+        even if the task ended abnormally (the task's exception or
+        cancellation still propagates to the caller).
+        """
+        task = self._task
+        if task is None:
+            return
+        if not task.done():
+            # A sentinel queued for an already-finished task would never
+            # be consumed and would instantly stop the next ``start()``.
+            self._queue.put_nowait(None)
+        try:
+            await task
+        finally:
+            # Only the call that still owns this task resets state, so an
+            # overlapping ``stop()`` cannot clobber a freshly started one.
+            if self._task is task:
+                self._task = None
+                self._drop_pending()
+
+    def _drop_pending(self) -> None:
+        """Discard everything still queued (late jobs, spare sentinels)."""
+        dropped = 0
+        while not self._queue.empty():
+            self._queue.get_nowait()
+            dropped += 1
+        if dropped:
+            log.debug("embedding worker dropped %d queued item(s)", dropped)
+
+    async def _run(self) -> None:
+        """Process jobs one at a time until the ``None`` sentinel arrives."""
+        while True:
+            job = await self._queue.get()
+            if job is None:
+                return
+            try:
+                await self._process(job)
+            except Exception as exc:  # noqa: BLE001 — worker must not die
+                log.warning(
+                    "embedding worker failed for memory_id=%s: %s",
+                    job.memory_id,
+                    exc,
+                    exc_info=True,
+                )
+
+    async def _process(self, job: EmbeddingJob) -> None:
+        """Embed one job and append an ``embedding_indexed`` event."""
+        result = await generate_embedding(
+            self._client,
+            text=job.text,
+            model=self._model,
+            dim=self._dim,
+        )
+        if result.model == FALLBACK_EMBEDDING_MODEL:
+            # Don't index a fallback (zero) vector — the backfill script
+            # can retry later once a real embedding is available.
+            log.debug(
+                "embedding worker skipping fallback result for memory_id=%s",
+                job.memory_id,
+            )
+            return
+        # NOTE(review): a bare ``sqlite3.Connection`` used with ``with``
+        # commits or rolls back but does not close — confirm that what
+        # ``conn_factory`` returns closes the connection on exit.
+        with self._conn_factory() as conn:
+            append_and_apply(
+                conn,
+                kind="embedding_indexed",
+                payload={
+                    "memory_id": job.memory_id,
+                    "model": result.model,
+                    "dim": result.dim,
+                    "vector": result.vector,
+                },
+            )
+
+
+__all__ = ["EmbeddingJob", "EmbeddingWorker"]
diff --git a/tests/test_embedding_worker.py b/tests/test_embedding_worker.py
new file mode 100644
index 0000000..f7d9416
--- /dev/null
+++ b/tests/test_embedding_worker.py
@@ -0,0 +1,209 @@
+"""Embedding worker (T97, Phase 4).
+
+The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
+events. Mirrors test_significance.py's BackgroundWorker tests in shape:
+seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then
+assert on the projected ``embeddings`` table and the underlying event_log.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from pathlib import Path
+
+from chat.db.connection import open_db
+from chat.db.migrate import apply_migrations
+from chat.eventlog.log import append_event
+from chat.eventlog.projector import project
+from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
+from chat.services.embeddings import (
+    DEFAULT_EMBEDDING_MODEL,
+    EmbeddingResult,
+    FALLBACK_EMBEDDING_MODEL,
+)
+
+# Trigger handler registration for projection.
+import chat.state.embeddings  # noqa: F401
+import chat.state.entities  # noqa: F401
+import chat.state.memory  # noqa: F401
+import chat.state.world  # noqa: F401
+
+
+def _seed_memories(db_path: Path, count: int) -> list[int]:
+    """Seed ``count`` memory rows for ``bot_a`` and return their ids."""
+    with open_db(db_path) as conn:
+        append_event(
+            conn,
+            kind="bot_authored",
+            payload={
+                "id": "bot_a",
+                "name": "BotA",
+                "persona": "...",
+                "voice_samples": [],
+                "traits": [],
+                "backstory": "",
+                "initial_relationship_to_you": "",
+                "kickoff_prose": "",
+            },
+        )
+        append_event(
+            conn,
+            kind="chat_created",
+            payload={
+                "id": "chat_bot_a",
+                "host_bot_id": "bot_a",
+                "initial_time": "2026-04-26T20:00:00+00:00",
+                "narrative_anchor": "Day 1",
+                "weather": "",
+            },
+        )
+        for i in range(count):
+            append_event(
+                conn,
+                kind="memory_written",
+                payload={
+                    "owner_id": "bot_a",
+                    "chat_id": "chat_bot_a",
+                    "pov_summary": f"memory text {i}",
+                    "witness_you": 1,
+                    "witness_host": 1,
+                    "witness_guest": 0,
+                    "source": "direct",
+                    "reliability": 1.0,
+                    "significance": 1,
+                    "pinned": 0,
+                    "auto_pinned": 0,
+                },
+            )
+        project(conn)
+        return [
+            r[0]
+            for r in conn.execute(
+                "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
+            ).fetchall()
+        ]
+
+
+async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
+    """Three jobs in -> three ``embedding_indexed`` events out, all
+    projected into the ``embeddings`` table."""
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    memory_ids = _seed_memories(db, count=3)
+
+    worker = EmbeddingWorker(
+        conn_factory=lambda: open_db(db),
+        client=None,  # pseudo path — no client needed
+    )
+    await worker.start()
+    for mid in memory_ids:
+        worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
+    await worker.stop()
+
+    with open_db(db) as conn:
+        # Three embedding_indexed events landed.
+        cur = conn.execute(
+            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
+        )
+        assert cur.fetchone()[0] == 3
+        # Three rows in the embeddings table — one per memory.
+        cur = conn.execute(
+            "SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
+        )
+        rows = cur.fetchall()
+        assert len(rows) == 3
+        for (mid, model, dim), expected_mid in zip(rows, memory_ids):
+            assert mid == expected_mid
+            assert model == DEFAULT_EMBEDDING_MODEL
+            assert dim > 0
+
+
+async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
+    """A fallback EmbeddingResult must NOT produce an embedding_indexed
+    event — backfill can retry later when a real embedding is available."""
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    memory_ids = _seed_memories(db, count=1)
+
+    async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
+        return EmbeddingResult(
+            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
+        )
+
+    # Patch the symbol the worker resolved at import time.
+    import chat.services.embedding_worker as worker_mod
+
+    monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)
+
+    worker = EmbeddingWorker(
+        conn_factory=lambda: open_db(db),
+        client=None,
+    )
+    await worker.start()
+    worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything"))
+    await worker.stop()
+
+    with open_db(db) as conn:
+        cur = conn.execute(
+            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
+        )
+        assert cur.fetchone()[0] == 0
+        cur = conn.execute("SELECT COUNT(*) FROM embeddings")
+        assert cur.fetchone()[0] == 0
+
+
+async def test_worker_handles_concurrent_jobs_serially(tmp_path):
+    """Five jobs queued back-to-back must process in FIFO order — the
+    single-task design respects the Featherless 2-conn cap (and keeps
+    event_log ordering deterministic)."""
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    memory_ids = _seed_memories(db, count=5)
+
+    worker = EmbeddingWorker(
+        conn_factory=lambda: open_db(db),
+        client=None,
+    )
+    await worker.start()
+    # Enqueue all five before yielding to the loop — exercises the queue
+    # rather than a one-at-a-time drain.
+    for mid in memory_ids:
+        worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
+    await worker.stop()
+
+    with open_db(db) as conn:
+        # Events landed in enqueue order (FIFO).
+        cur = conn.execute(
+            "SELECT json_extract(payload_json, '$.memory_id') "
+            "FROM event_log WHERE kind = 'embedding_indexed' "
+            "ORDER BY id"
+        )
+        seen = [r[0] for r in cur.fetchall()]
+        assert seen == memory_ids
+
+        # All five embeddings projected.
+        cur = conn.execute("SELECT COUNT(*) FROM embeddings")
+        assert cur.fetchone()[0] == 5
+
+
+async def test_worker_restarts_cleanly_after_overlapping_stops(tmp_path):
+    """Two overlapping ``stop()`` calls each enqueue a sentinel; the spare
+    one must not linger and kill the worker on its next ``start()``."""
+    db = tmp_path / "t.db"
+    apply_migrations(db)
+    memory_ids = _seed_memories(db, count=1)
+
+    worker = EmbeddingWorker(
+        conn_factory=lambda: open_db(db),
+        client=None,
+    )
+    await worker.start()
+    await asyncio.gather(worker.stop(), worker.stop())
+
+    await worker.start()
+    worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="restarted"))
+    await worker.stop()
+
+    with open_db(db) as conn:
+        cur = conn.execute("SELECT COUNT(*) FROM embeddings")
+        assert cur.fetchone()[0] == 1