feat: embedding worker drains queue and emits embedding_indexed events (T97.1)
This commit is contained in:
@@ -0,0 +1,185 @@
|
||||
"""Embedding worker (T97, Phase 4).
|
||||
|
||||
The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
|
||||
events. Mirrors test_significance.py's BackgroundWorker tests in shape:
|
||||
seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then
|
||||
assert on the projected ``embeddings`` table and the underlying event_log.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
EmbeddingResult,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
# Trigger handler registration for projection.
|
||||
import chat.state.embeddings # noqa: F401
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
|
||||
def _seed_memories(db_path: Path, count: int) -> list[int]:
|
||||
"""Seed ``count`` memory rows for ``bot_a`` and return their ids."""
|
||||
with open_db(db_path) as conn:
|
||||
append_event(
|
||||
conn,
|
||||
kind="bot_authored",
|
||||
payload={
|
||||
"id": "bot_a",
|
||||
"name": "BotA",
|
||||
"persona": "...",
|
||||
"voice_samples": [],
|
||||
"traits": [],
|
||||
"backstory": "",
|
||||
"initial_relationship_to_you": "",
|
||||
"kickoff_prose": "",
|
||||
},
|
||||
)
|
||||
append_event(
|
||||
conn,
|
||||
kind="chat_created",
|
||||
payload={
|
||||
"id": "chat_bot_a",
|
||||
"host_bot_id": "bot_a",
|
||||
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||
"narrative_anchor": "Day 1",
|
||||
"weather": "",
|
||||
},
|
||||
)
|
||||
for i in range(count):
|
||||
append_event(
|
||||
conn,
|
||||
kind="memory_written",
|
||||
payload={
|
||||
"owner_id": "bot_a",
|
||||
"chat_id": "chat_bot_a",
|
||||
"pov_summary": f"memory text {i}",
|
||||
"witness_you": 1,
|
||||
"witness_host": 1,
|
||||
"witness_guest": 0,
|
||||
"source": "direct",
|
||||
"reliability": 1.0,
|
||||
"significance": 1,
|
||||
"pinned": 0,
|
||||
"auto_pinned": 0,
|
||||
},
|
||||
)
|
||||
project(conn)
|
||||
return [
|
||||
r[0]
|
||||
for r in conn.execute(
|
||||
"SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
|
||||
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
|
||||
"""Three jobs in -> three ``embedding_indexed`` events out, all
|
||||
projected into the ``embeddings`` table."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed_memories(db, count=3)
|
||||
|
||||
worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(db),
|
||||
client=None, # pseudo path — no client needed
|
||||
)
|
||||
await worker.start()
|
||||
for mid in memory_ids:
|
||||
worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
|
||||
await worker.stop()
|
||||
|
||||
with open_db(db) as conn:
|
||||
# Three embedding_indexed events landed.
|
||||
cur = conn.execute(
|
||||
"SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
|
||||
)
|
||||
assert cur.fetchone()[0] == 3
|
||||
# Three rows in the embeddings table — one per memory.
|
||||
cur = conn.execute(
|
||||
"SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
assert len(rows) == 3
|
||||
for (mid, model, dim), expected_mid in zip(rows, memory_ids):
|
||||
assert mid == expected_mid
|
||||
assert model == DEFAULT_EMBEDDING_MODEL
|
||||
assert dim > 0
|
||||
|
||||
|
||||
async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
|
||||
"""A fallback EmbeddingResult must NOT produce an embedding_indexed
|
||||
event — backfill can retry later when a real embedding is available."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed_memories(db, count=1)
|
||||
|
||||
async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
|
||||
return EmbeddingResult(
|
||||
vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
|
||||
)
|
||||
|
||||
# Patch the symbol the worker resolved at import time.
|
||||
import chat.services.embedding_worker as worker_mod
|
||||
|
||||
monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)
|
||||
|
||||
worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(db),
|
||||
client=None,
|
||||
)
|
||||
await worker.start()
|
||||
worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything"))
|
||||
await worker.stop()
|
||||
|
||||
with open_db(db) as conn:
|
||||
cur = conn.execute(
|
||||
"SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
|
||||
)
|
||||
assert cur.fetchone()[0] == 0
|
||||
cur = conn.execute("SELECT COUNT(*) FROM embeddings")
|
||||
assert cur.fetchone()[0] == 0
|
||||
|
||||
|
||||
async def test_worker_handles_concurrent_jobs_serially(tmp_path):
|
||||
"""Five jobs queued back-to-back must process in FIFO order — the
|
||||
single-task design respects the Featherless 2-conn cap (and keeps
|
||||
event_log ordering deterministic)."""
|
||||
db = tmp_path / "t.db"
|
||||
apply_migrations(db)
|
||||
memory_ids = _seed_memories(db, count=5)
|
||||
|
||||
worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(db),
|
||||
client=None,
|
||||
)
|
||||
await worker.start()
|
||||
# Enqueue all five before yielding to the loop — exercises the queue
|
||||
# rather than a one-at-a-time drain.
|
||||
for mid in memory_ids:
|
||||
worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
|
||||
await worker.stop()
|
||||
|
||||
with open_db(db) as conn:
|
||||
# Events landed in enqueue order (FIFO).
|
||||
cur = conn.execute(
|
||||
"SELECT json_extract(payload_json, '$.memory_id') "
|
||||
"FROM event_log WHERE kind = 'embedding_indexed' "
|
||||
"ORDER BY id"
|
||||
)
|
||||
seen = [r[0] for r in cur.fetchall()]
|
||||
assert seen == memory_ids
|
||||
|
||||
# All five embeddings projected.
|
||||
cur = conn.execute("SELECT COUNT(*) FROM embeddings")
|
||||
assert cur.fetchone()[0] == 5
|
||||
Reference in New Issue
Block a user