186 lines
6.2 KiB
Python
186 lines
6.2 KiB
Python
"""Embedding worker tests (T97, Phase 4).

The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
events. These tests mirror the shape of test_significance.py's
BackgroundWorker tests: seed a memory, enqueue jobs, call ``stop()`` to
drain the queue via the sentinel, then assert on the projected
``embeddings`` table and the underlying event_log.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import append_event
|
|
from chat.eventlog.projector import project
|
|
from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
|
|
from chat.services.embeddings import (
|
|
DEFAULT_EMBEDDING_MODEL,
|
|
EmbeddingResult,
|
|
FALLBACK_EMBEDDING_MODEL,
|
|
)
|
|
|
|
# Trigger handler registration for projection.
|
|
import chat.state.embeddings # noqa: F401
|
|
import chat.state.entities # noqa: F401
|
|
import chat.state.memory # noqa: F401
|
|
import chat.state.world # noqa: F401
|
|
|
|
|
|
def _seed_memories(db_path: Path, count: int) -> list[int]:
    """Seed one bot, one chat and ``count`` memories; return the memory ids.

    Appends, in order: a ``bot_authored`` event for ``bot_a``, a
    ``chat_created`` event for ``chat_bot_a`` hosted by ``bot_a``, and
    ``count`` ``memory_written`` events owned by ``bot_a`` whose summaries
    are ``"memory text 0"``, ``"memory text 1"``, ... The event log is then
    projected and the resulting ``memories`` rows are read back.

    Args:
        db_path: Database to seed (callers run ``apply_migrations`` on it
            first).
        count: Number of ``memory_written`` events to append.

    Returns:
        Ids of every ``memories`` row owned by ``bot_a``, ascending.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }

    with open_db(db_path) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)
        for n in range(count):
            append_event(
                conn,
                kind="memory_written",
                payload={
                    "owner_id": "bot_a",
                    "chat_id": "chat_bot_a",
                    "pov_summary": f"memory text {n}",
                    "witness_you": 1,
                    "witness_host": 1,
                    "witness_guest": 0,
                    "source": "direct",
                    "reliability": 1.0,
                    "significance": 1,
                    "pinned": 0,
                    "auto_pinned": 0,
                },
            )
        # Project the freshly appended events so the memories table is
        # populated before we query it.
        project(conn)
        id_rows = conn.execute(
            "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
        ).fetchall()
        return [row[0] for row in id_rows]
|
|
|
|
|
|
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
    """Each queued job yields one ``embedding_indexed`` event and one row.

    Three memories are seeded and one job per memory is enqueued. After
    ``stop()`` drains the queue, the event log must hold exactly three
    indexed events. The ``embeddings`` projection must hold one row per
    memory, tagged with the default model and a positive dimension.
    """
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    seeded = _seed_memories(db_path, count=3)

    embedder = EmbeddingWorker(
        conn_factory=lambda: open_db(db_path),
        client=None,  # pseudo path, so no client is needed
    )
    await embedder.start()
    jobs = [EmbeddingJob(memory_id=m, text=f"text-{m}") for m in seeded]
    for job in jobs:
        embedder.enqueue(job)
    await embedder.stop()

    with open_db(db_path) as conn:
        # One embedding_indexed event per job.
        indexed_events = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        ).fetchone()[0]
        assert indexed_events == 3

        # One projected embeddings row per seeded memory.
        projected = conn.execute(
            "SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
        ).fetchall()
        assert len(projected) == 3
        for expected_id, (memory_id, model, dim) in zip(seeded, projected):
            assert memory_id == expected_id
            assert model == DEFAULT_EMBEDDING_MODEL
            assert dim > 0
|
|
|
|
|
|
async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
    """A fallback EmbeddingResult must NOT produce an embedding_indexed
    event — backfill can retry later when a real embedding is available.

    The stubbed ``generate_embedding`` records every invocation, and the
    test asserts it was actually called. Without that check, the
    "no events / no rows" assertions would also pass if the job never
    reached the embedding step. Examples: the worker dropped the job, or
    the stub's keyword-only signature did not match the worker's call and
    the resulting error was swallowed.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed_memories(db, count=1)

    calls: list[str] = []

    async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
        calls.append(text)
        return EmbeddingResult(
            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )

    # Patch the symbol the worker resolved at import time.
    import chat.services.embedding_worker as worker_mod

    monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)

    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db),
        client=None,
    )
    await worker.start()
    worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything"))
    await worker.stop()

    # Guard against a vacuous pass: the job must have reached the stub, so
    # the empty results below are due to the fallback being skipped.
    assert calls, "worker never invoked generate_embedding for the job"

    with open_db(db) as conn:
        cur = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        )
        assert cur.fetchone()[0] == 0
        cur = conn.execute("SELECT COUNT(*) FROM embeddings")
        assert cur.fetchone()[0] == 0
|
|
|
|
|
|
async def test_worker_handles_concurrent_jobs_serially(tmp_path):
    """Jobs queued back-to-back are processed one at a time, in FIFO order.

    The single-task design respects the Featherless 2-conn cap and keeps
    ``event_log`` ordering deterministic. All five jobs are enqueued before
    the worker gets to run, so the test exercises draining a populated
    queue rather than a one-at-a-time hand-off.
    """
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    seeded = _seed_memories(db_path, count=5)

    embedder = EmbeddingWorker(
        conn_factory=lambda: open_db(db_path),
        client=None,
    )
    await embedder.start()
    # enqueue() is called without awaiting, so nothing yields to the loop
    # until stop() — every job is already queued when draining begins.
    for mid in seeded:
        embedder.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
    await embedder.stop()

    with open_db(db_path) as conn:
        order_sql = (
            "SELECT json_extract(payload_json, '$.memory_id') "
            "FROM event_log WHERE kind = 'embedding_indexed' "
            "ORDER BY id"
        )
        indexed_order = [row[0] for row in conn.execute(order_sql).fetchall()]
        # Indexed events appear in the same order the jobs were enqueued.
        assert indexed_order == seeded

        # Every job produced a projected embeddings row.
        embedding_count = conn.execute(
            "SELECT COUNT(*) FROM embeddings"
        ).fetchone()[0]
        assert embedding_count == 5
|