"""Embedding worker (T97, Phase 4). The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed`` events. Mirrors test_significance.py's BackgroundWorker tests in shape: seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then assert on the projected ``embeddings`` table and the underlying event_log. """ from __future__ import annotations from pathlib import Path from chat.db.connection import open_db from chat.db.migrate import apply_migrations from chat.eventlog.log import append_event from chat.eventlog.projector import project from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker from chat.services.embeddings import ( DEFAULT_EMBEDDING_MODEL, EmbeddingResult, FALLBACK_EMBEDDING_MODEL, ) # Trigger handler registration for projection. import chat.state.embeddings # noqa: F401 import chat.state.entities # noqa: F401 import chat.state.memory # noqa: F401 import chat.state.world # noqa: F401 def _seed_memories(db_path: Path, count: int) -> list[int]: """Seed ``count`` memory rows for ``bot_a`` and return their ids.""" with open_db(db_path) as conn: append_event( conn, kind="bot_authored", payload={ "id": "bot_a", "name": "BotA", "persona": "...", "voice_samples": [], "traits": [], "backstory": "", "initial_relationship_to_you": "", "kickoff_prose": "", }, ) append_event( conn, kind="chat_created", payload={ "id": "chat_bot_a", "host_bot_id": "bot_a", "initial_time": "2026-04-26T20:00:00+00:00", "narrative_anchor": "Day 1", "weather": "", }, ) for i in range(count): append_event( conn, kind="memory_written", payload={ "owner_id": "bot_a", "chat_id": "chat_bot_a", "pov_summary": f"memory text {i}", "witness_you": 1, "witness_host": 1, "witness_guest": 0, "source": "direct", "reliability": 1.0, "significance": 1, "pinned": 0, "auto_pinned": 0, }, ) project(conn) return [ r[0] for r in conn.execute( "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id" ).fetchall() ] async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path): """Three jobs in -> three ``embedding_indexed`` events out, all projected into the ``embeddings`` table.""" db = tmp_path / "t.db" apply_migrations(db) memory_ids = _seed_memories(db, count=3) worker = EmbeddingWorker( conn_factory=lambda: open_db(db), client=None, # pseudo path — no client needed ) await worker.start() for mid in memory_ids: worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}")) await worker.stop() with open_db(db) as conn: # Three embedding_indexed events landed. cur = conn.execute( "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'" ) assert cur.fetchone()[0] == 3 # Three rows in the embeddings table — one per memory. cur = conn.execute( "SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id" ) rows = cur.fetchall() assert len(rows) == 3 for (mid, model, dim), expected_mid in zip(rows, memory_ids): assert mid == expected_mid assert model == DEFAULT_EMBEDDING_MODEL assert dim > 0 async def test_worker_skips_fallback_results(tmp_path, monkeypatch): """A fallback EmbeddingResult must NOT produce an embedding_indexed event — backfill can retry later when a real embedding is available.""" db = tmp_path / "t.db" apply_migrations(db) memory_ids = _seed_memories(db, count=1) async def _fake_generate(client, *, text, model, dim, timeout_s=30.0): return EmbeddingResult( vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim ) # Patch the symbol the worker resolved at import time. import chat.services.embedding_worker as worker_mod monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate) worker = EmbeddingWorker( conn_factory=lambda: open_db(db), client=None, ) await worker.start() worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything")) await worker.stop() with open_db(db) as conn: cur = conn.execute( "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'" ) assert cur.fetchone()[0] == 0 cur = conn.execute("SELECT COUNT(*) FROM embeddings") assert cur.fetchone()[0] == 0 async def test_worker_handles_concurrent_jobs_serially(tmp_path): """Five jobs queued back-to-back must process in FIFO order — the single-task design respects the Featherless 2-conn cap (and keeps event_log ordering deterministic).""" db = tmp_path / "t.db" apply_migrations(db) memory_ids = _seed_memories(db, count=5) worker = EmbeddingWorker( conn_factory=lambda: open_db(db), client=None, ) await worker.start() # Enqueue all five before yielding to the loop — exercises the queue # rather than a one-at-a-time drain. for mid in memory_ids: worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}")) await worker.stop() with open_db(db) as conn: # Events landed in enqueue order (FIFO). cur = conn.execute( "SELECT json_extract(payload_json, '$.memory_id') " "FROM event_log WHERE kind = 'embedding_indexed' " "ORDER BY id" ) seen = [r[0] for r in cur.fetchall()] assert seen == memory_ids # All five embeddings projected. cur = conn.execute("SELECT COUNT(*) FROM embeddings") assert cur.fetchone()[0] == 5