feat: embedding worker drains queue and emits embedding_indexed events (T97.1)
This commit is contained in:
@@ -0,0 +1,137 @@
|
||||
"""Embedding worker (T97, Phase 4).
|
||||
|
||||
Drains a queue of embedding jobs. Each job carries a memory id and the
|
||||
narrative text to embed; the worker calls
|
||||
:func:`chat.services.embeddings.generate_embedding` and emits an
|
||||
``embedding_indexed`` event so the projector lands the vector in the
|
||||
``embeddings`` table.
|
||||
|
||||
Mirrors the :class:`chat.services.background.BackgroundWorker` pattern:
|
||||
single asyncio task, sentinel-based shutdown, exceptions are caught and
|
||||
logged so a flaky embedding call doesn't take down the worker. Each job
|
||||
opens its own SQLite connection via ``conn_factory`` — the request path
|
||||
and the worker do not share connections.
|
||||
|
||||
Featherless concurrency (the 2-conn cap) is respected by virtue of the
|
||||
single-task design: jobs run strictly serially. Phase 4's pseudo-embedding
|
||||
path is local and synchronous so this is largely moot, but the pattern
|
||||
is in place for the Phase 4.5+ real-embedding swap.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from sqlite3 import Connection
|
||||
from typing import Callable
|
||||
|
||||
from chat.eventlog.log import append_and_apply
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_DIM,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
generate_embedding,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingJob:
|
||||
"""One unit of work for the embedding worker.
|
||||
|
||||
``memory_id`` is the row to attach the vector to; ``text`` is the
|
||||
narrative text to embed (typically ``memories.pov_summary``).
|
||||
"""
|
||||
|
||||
memory_id: int
|
||||
text: str
|
||||
|
||||
|
||||
class EmbeddingWorker:
|
||||
"""asyncio.Queue-backed single-worker task for embedding generation.
|
||||
|
||||
Started on app startup; ``stop()`` enqueues a sentinel and awaits
|
||||
the task so any in-flight job has a chance to finish. Pending jobs
|
||||
after the sentinel are dropped on shutdown.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
conn_factory: Callable[[], Connection],
|
||||
client, # LLMClient | None — unused on the pseudo path.
|
||||
model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
dim: int = DEFAULT_EMBEDDING_DIM,
|
||||
enabled: bool = True,
|
||||
) -> None:
|
||||
self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
|
||||
self._conn_factory = conn_factory
|
||||
self._client = client
|
||||
self._model = model
|
||||
self._dim = dim
|
||||
self._task: asyncio.Task | None = None
|
||||
self.enabled = enabled
|
||||
|
||||
def enqueue(self, job: EmbeddingJob) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
self._queue.put_nowait(job)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._run())
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._task is None:
|
||||
return
|
||||
self._queue.put_nowait(None) # sentinel
|
||||
await self._task
|
||||
self._task = None
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
job = await self._queue.get()
|
||||
if job is None:
|
||||
return
|
||||
try:
|
||||
await self._process(job)
|
||||
except Exception as exc: # noqa: BLE001 — worker must not die
|
||||
log.warning(
|
||||
"embedding worker failed for memory_id=%s: %s",
|
||||
job.memory_id,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _process(self, job: EmbeddingJob) -> None:
|
||||
result = await generate_embedding(
|
||||
self._client,
|
||||
text=job.text,
|
||||
model=self._model,
|
||||
dim=self._dim,
|
||||
)
|
||||
if result.model == FALLBACK_EMBEDDING_MODEL:
|
||||
# Don't index a fallback (zero) vector — the backfill script
|
||||
# can retry later once a real embedding is available.
|
||||
log.debug(
|
||||
"embedding worker skipping fallback result for memory_id=%s",
|
||||
job.memory_id,
|
||||
)
|
||||
return
|
||||
with self._conn_factory() as conn:
|
||||
append_and_apply(
|
||||
conn,
|
||||
kind="embedding_indexed",
|
||||
payload={
|
||||
"memory_id": job.memory_id,
|
||||
"model": result.model,
|
||||
"dim": result.dim,
|
||||
"vector": result.vector,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["EmbeddingJob", "EmbeddingWorker"]
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Embedding worker (T97, Phase 4).
|
||||
|
||||
The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
|
||||
events. Mirrors test_significance.py's BackgroundWorker tests in shape:
|
||||
seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then
|
||||
assert on the projected ``embeddings`` table and the underlying event_log.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
EmbeddingResult,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
# Trigger handler registration for projection.
|
||||
import chat.state.embeddings # noqa: F401
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
|
||||
def _seed_memories(db_path: Path, count: int) -> list[int]:
|
||||
"""Seed ``count`` memory rows for ``bot_a`` and return their ids."""
|
||||
with open_db(db_path) as conn:
|
||||
append_event(
|
||||
conn,
|
||||
kind="bot_authored",
|
||||
payload={
|
||||
"id": "bot_a",
|
||||
"name": "BotA",
|
||||
"persona": "...",
|
||||
"voice_samples": [],
|
||||
"traits": [],
|
||||
"backstory": "",
|
||||
"initial_relationship_to_you": "",
|
||||
"kickoff_prose": "",
|
||||
},
|
||||
)
|
||||
append_event(
|
||||
conn,
|
||||
kind="chat_created",
|
||||
payload={
|
||||
"id": "chat_bot_a",
|
||||
"host_bot_id": "bot_a",
|
||||
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||
"narrative_anchor": "Day 1",
|
||||
"weather": "",
|
||||
},
|
||||
)
|
||||
for i in range(count):
|
||||
append_event(
|
||||
conn,
|
||||
kind="memory_written",
|
||||
payload={
|
||||
"owner_id": "bot_a",
|
||||
"chat_id": "chat_bot_a",
|
||||
"pov_summary": f"memory text {i}",
|
||||
"witness_you": 1,
|
||||
"witness_host": 1,
|
||||
"witness_guest": 0,
|
||||
"source": "direct",
|
||||
"reliability": 1.0,
|
||||
"significance": 1,
|
||||
"pinned": 0,
|
||||
"auto_pinned": 0,
|
||||
},
|
||||
)
|
||||
project(conn)
|
||||
return [
|
||||
r[0]
|
||||
for r in conn.execute(
|
||||
"SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
|
||||
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
|
||||
"""Three jobs in -> three ``embedding_indexed`` events out, all
|
||||
projected into the ``embeddings`` table."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed_memories(db, count=3)
|
||||
|
||||
worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(db),
|
||||
client=None, # pseudo path — no client needed
|
||||
)
|
||||
await worker.start()
|
||||
for mid in memory_ids:
|
||||
worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
|
||||
await worker.stop()
|
||||
|
||||
with open_db(db) as conn:
|
||||
# Three embedding_indexed events landed.
|
||||
cur = conn.execute(
|
||||
"SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
|
||||
)
|
||||
assert cur.fetchone()[0] == 3
|
||||
# Three rows in the embeddings table — one per memory.
|
||||
cur = conn.execute(
|
||||
"SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
assert len(rows) == 3
|
||||
for (mid, model, dim), expected_mid in zip(rows, memory_ids):
|
||||
assert mid == expected_mid
|
||||
assert model == DEFAULT_EMBEDDING_MODEL
|
||||
assert dim > 0
|
||||
|
||||
|
||||
async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
|
||||
"""A fallback EmbeddingResult must NOT produce an embedding_indexed
|
||||
event — backfill can retry later when a real embedding is available."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed_memories(db, count=1)
|
||||
|
||||
async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
|
||||
# Patch the symbol the worker resolved at import time.
|
||||
import chat.services.embedding_worker as worker_mod
|
||||
|
||||
monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)
|
||||
|
||||
worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(db),
|
||||
client=None,
|
||||
)
|
||||
await worker.start()
|
||||
worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything"))
|
||||
await worker.stop()
|
||||
|
||||
with open_db(db) as conn:
|
||||
cur = conn.execute(
|
||||
"SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
|
||||
)
|
||||
assert cur.fetchone()[0] == 0
|
||||
cur = conn.execute("SELECT COUNT(*) FROM embeddings")
|
||||
assert cur.fetchone()[0] == 0
|
||||
|
||||
|
||||
async def test_worker_handles_concurrent_jobs_serially(tmp_path):
|
||||
"""Five jobs queued back-to-back must process in FIFO order — the
|
||||
single-task design respects the Featherless 2-conn cap (and keeps
|
||||
event_log ordering deterministic)."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed_memories(db, count=5)
|
||||
|
||||
worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(db),
|
||||
client=None,
|
||||
)
|
||||
await worker.start()
|
||||
# Enqueue all five before yielding to the loop — exercises the queue
|
||||
# rather than a one-at-a-time drain.
|
||||
for mid in memory_ids:
|
||||
worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
|
||||
await worker.stop()
|
||||
|
||||
with open_db(db) as conn:
|
||||
# Events landed in enqueue order (FIFO).
|
||||
cur = conn.execute(
|
||||
"SELECT json_extract(payload_json, '$.memory_id') "
|
||||
"FROM event_log WHERE kind = 'embedding_indexed' "
|
||||
"ORDER BY id"
|
||||
)
|
||||
seen = [r[0] for r in cur.fetchall()]
|
||||
assert seen == memory_ids
|
||||
|
||||
# All five embeddings projected.
|
||||
cur = conn.execute("SELECT COUNT(*) FROM embeddings")
|
||||
assert cur.fetchone()[0] == 5
|
||||
Reference in New Issue
Block a user