feat: embedding worker drains queue and emits embedding_indexed events (T97.1)

This commit is contained in:
Joseph Doherty
2026-04-27 02:51:36 -04:00
parent 50448b72f8
commit 6674f9475c
2 changed files with 322 additions and 0 deletions
+137
View File
@@ -0,0 +1,137 @@
"""Embedding worker (T97, Phase 4).
Drains a queue of embedding jobs. Each job carries a memory id and the
narrative text to embed; the worker calls
:func:`chat.services.embeddings.generate_embedding` and emits an
``embedding_indexed`` event so the projector lands the vector in the
``embeddings`` table.
Mirrors the :class:`chat.services.background.BackgroundWorker` pattern:
single asyncio task, sentinel-based shutdown, exceptions are caught and
logged so a flaky embedding call doesn't take down the worker. Each job
opens its own SQLite connection via ``conn_factory`` — the request path
and the worker do not share connections.
Featherless concurrency (the 2-conn cap) is respected by virtue of the
single-task design: jobs run strictly serially. Phase 4's pseudo-embedding
path is local and synchronous so this is largely moot, but the pattern
is in place for the Phase 4.5+ real-embedding swap.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from sqlite3 import Connection
from typing import Callable
from chat.eventlog.log import append_and_apply
from chat.services.embeddings import (
DEFAULT_EMBEDDING_DIM,
DEFAULT_EMBEDDING_MODEL,
FALLBACK_EMBEDDING_MODEL,
generate_embedding,
)
log = logging.getLogger(__name__)
@dataclass
class EmbeddingJob:
"""One unit of work for the embedding worker.
``memory_id`` is the row to attach the vector to; ``text`` is the
narrative text to embed (typically ``memories.pov_summary``).
"""
memory_id: int
text: str
class EmbeddingWorker:
"""asyncio.Queue-backed single-worker task for embedding generation.
Started on app startup; ``stop()`` enqueues a sentinel and awaits
the task so any in-flight job has a chance to finish. Pending jobs
after the sentinel are dropped on shutdown.
"""
def __init__(
self,
*,
conn_factory: Callable[[], Connection],
client, # LLMClient | None — unused on the pseudo path.
model: str = DEFAULT_EMBEDDING_MODEL,
dim: int = DEFAULT_EMBEDDING_DIM,
enabled: bool = True,
) -> None:
self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
self._conn_factory = conn_factory
self._client = client
self._model = model
self._dim = dim
self._task: asyncio.Task | None = None
self.enabled = enabled
def enqueue(self, job: EmbeddingJob) -> None:
if not self.enabled:
return
self._queue.put_nowait(job)
async def start(self) -> None:
if self._task is None:
self._task = asyncio.create_task(self._run())
async def stop(self) -> None:
if self._task is None:
return
self._queue.put_nowait(None) # sentinel
await self._task
self._task = None
async def _run(self) -> None:
while True:
job = await self._queue.get()
if job is None:
return
try:
await self._process(job)
except Exception as exc: # noqa: BLE001 — worker must not die
log.warning(
"embedding worker failed for memory_id=%s: %s",
job.memory_id,
exc,
exc_info=True,
)
async def _process(self, job: EmbeddingJob) -> None:
result = await generate_embedding(
self._client,
text=job.text,
model=self._model,
dim=self._dim,
)
if result.model == FALLBACK_EMBEDDING_MODEL:
# Don't index a fallback (zero) vector — the backfill script
# can retry later once a real embedding is available.
log.debug(
"embedding worker skipping fallback result for memory_id=%s",
job.memory_id,
)
return
with self._conn_factory() as conn:
append_and_apply(
conn,
kind="embedding_indexed",
payload={
"memory_id": job.memory_id,
"model": result.model,
"dim": result.dim,
"vector": result.vector,
},
)
__all__ = ["EmbeddingJob", "EmbeddingWorker"]
+185
View File
@@ -0,0 +1,185 @@
"""Embedding worker (T97, Phase 4).
The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
events. Mirrors test_significance.py's BackgroundWorker tests in shape:
seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then
assert on the projected ``embeddings`` table and the underlying event_log.
"""
from __future__ import annotations
from pathlib import Path
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
from chat.services.embeddings import (
DEFAULT_EMBEDDING_MODEL,
EmbeddingResult,
FALLBACK_EMBEDDING_MODEL,
)
# Trigger handler registration for projection.
import chat.state.embeddings # noqa: F401
import chat.state.entities # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
def _seed_memories(db_path: Path, count: int) -> list[int]:
"""Seed ``count`` memory rows for ``bot_a`` and return their ids."""
with open_db(db_path) as conn:
append_event(
conn,
kind="bot_authored",
payload={
"id": "bot_a",
"name": "BotA",
"persona": "...",
"voice_samples": [],
"traits": [],
"backstory": "",
"initial_relationship_to_you": "",
"kickoff_prose": "",
},
)
append_event(
conn,
kind="chat_created",
payload={
"id": "chat_bot_a",
"host_bot_id": "bot_a",
"initial_time": "2026-04-26T20:00:00+00:00",
"narrative_anchor": "Day 1",
"weather": "",
},
)
for i in range(count):
append_event(
conn,
kind="memory_written",
payload={
"owner_id": "bot_a",
"chat_id": "chat_bot_a",
"pov_summary": f"memory text {i}",
"witness_you": 1,
"witness_host": 1,
"witness_guest": 0,
"source": "direct",
"reliability": 1.0,
"significance": 1,
"pinned": 0,
"auto_pinned": 0,
},
)
project(conn)
return [
r[0]
for r in conn.execute(
"SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
).fetchall()
]
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
"""Three jobs in -> three ``embedding_indexed`` events out, all
projected into the ``embeddings`` table."""
db = tmp_path / "t.db"
apply_migrations(db)
memory_ids = _seed_memories(db, count=3)
worker = EmbeddingWorker(
conn_factory=lambda: open_db(db),
client=None, # pseudo path — no client needed
)
await worker.start()
for mid in memory_ids:
worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
await worker.stop()
with open_db(db) as conn:
# Three embedding_indexed events landed.
cur = conn.execute(
"SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
)
assert cur.fetchone()[0] == 3
# Three rows in the embeddings table — one per memory.
cur = conn.execute(
"SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
)
rows = cur.fetchall()
assert len(rows) == 3
for (mid, model, dim), expected_mid in zip(rows, memory_ids):
assert mid == expected_mid
assert model == DEFAULT_EMBEDDING_MODEL
assert dim > 0
async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
"""A fallback EmbeddingResult must NOT produce an embedding_indexed
event — backfill can retry later when a real embedding is available."""
db = tmp_path / "t.db"
apply_migrations(db)
memory_ids = _seed_memories(db, count=1)
async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
return EmbeddingResult(
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
)
# Patch the symbol the worker resolved at import time.
import chat.services.embedding_worker as worker_mod
monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)
worker = EmbeddingWorker(
conn_factory=lambda: open_db(db),
client=None,
)
await worker.start()
worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything"))
await worker.stop()
with open_db(db) as conn:
cur = conn.execute(
"SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
)
assert cur.fetchone()[0] == 0
cur = conn.execute("SELECT COUNT(*) FROM embeddings")
assert cur.fetchone()[0] == 0
async def test_worker_handles_concurrent_jobs_serially(tmp_path):
"""Five jobs queued back-to-back must process in FIFO order — the
single-task design respects the Featherless 2-conn cap (and keeps
event_log ordering deterministic)."""
db = tmp_path / "t.db"
apply_migrations(db)
memory_ids = _seed_memories(db, count=5)
worker = EmbeddingWorker(
conn_factory=lambda: open_db(db),
client=None,
)
await worker.start()
# Enqueue all five before yielding to the loop — exercises the queue
# rather than a one-at-a-time drain.
for mid in memory_ids:
worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
await worker.stop()
with open_db(db) as conn:
# Events landed in enqueue order (FIFO).
cur = conn.execute(
"SELECT json_extract(payload_json, '$.memory_id') "
"FROM event_log WHERE kind = 'embedding_indexed' "
"ORDER BY id"
)
seen = [r[0] for r in cur.fetchall()]
assert seen == memory_ids
# All five embeddings projected.
cur = conn.execute("SELECT COUNT(*) FROM embeddings")
assert cur.fetchone()[0] == 5