diff --git a/chat/app.py b/chat/app.py index c9daf90..9e2c74b 100644 --- a/chat/app.py +++ b/chat/app.py @@ -16,6 +16,7 @@ from chat.db.migrate import apply_migrations from chat.eventlog.log import read_events from chat.eventlog.projector import apply_event from chat.services.background import BackgroundWorker +from chat.services.embedding_worker import EmbeddingWorker from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot # Trigger handler registration: @@ -85,9 +86,23 @@ async def lifespan(app: FastAPI): await worker.start() app.state.background_worker = worker + # T97: separate worker for the async embedding pass. Each + # ``memory_written`` enqueues an EmbeddingJob; the worker drains the + # queue, calls ``generate_embedding``, and emits ``embedding_indexed``. + # Phase 4's pseudo-embedding path is local so the worker doesn't need + # an LLM client; we still pass one so the Phase 4.5 swap to a real + # model is a one-line change. + embedding_worker = EmbeddingWorker( + conn_factory=lambda: open_db(settings.db_path), + client=_factory(), + ) + await embedding_worker.start() + app.state.embedding_worker = embedding_worker + try: yield finally: + await embedding_worker.stop() await worker.stop() diff --git a/chat/services/embedding_worker.py b/chat/services/embedding_worker.py new file mode 100644 index 0000000..80f87d8 --- /dev/null +++ b/chat/services/embedding_worker.py @@ -0,0 +1,137 @@ +"""Embedding worker (T97, Phase 4). + +Drains a queue of embedding jobs. Each job carries a memory id and the +narrative text to embed; the worker calls +:func:`chat.services.embeddings.generate_embedding` and emits an +``embedding_indexed`` event so the projector lands the vector in the +``embeddings`` table. + +Mirrors the :class:`chat.services.background.BackgroundWorker` pattern: +single asyncio task, sentinel-based shutdown, exceptions are caught and +logged so a flaky embedding call doesn't take down the worker. 
log = logging.getLogger(__name__)


@dataclass
class EmbeddingJob:
    """One unit of work for the embedding worker.

    Attributes:
        memory_id: Row the vector is attached to; copied verbatim into the
            ``embedding_indexed`` event payload.
        text: Narrative text to embed (typically ``memories.pov_summary``).
    """

    memory_id: int
    text: str


class EmbeddingWorker:
    """asyncio.Queue-backed single-worker task for embedding generation.

    ``start()`` spawns one task running :meth:`_run`; ``stop()`` enqueues a
    ``None`` sentinel and awaits that task. The queue is FIFO, so jobs
    enqueued before the sentinel are processed before the task exits. Jobs
    enqueued after it stay queued and are only processed if the worker is
    started again.

    Jobs run strictly one at a time (a single consumer task), which is how
    the Featherless 2-connection cap is respected. Every job that gets
    indexed obtains its own connection from ``conn_factory``, so the worker
    shares no connection with the request path.
    """

    def __init__(
        self,
        *,
        conn_factory: Callable[[], Connection],
        client,  # LLMClient | None — unused on the pseudo path.
        model: str = DEFAULT_EMBEDDING_MODEL,
        dim: int = DEFAULT_EMBEDDING_DIM,
        enabled: bool = True,
    ) -> None:
        """Configure the worker; no task is created until ``start()``.

        Args:
            conn_factory: Zero-argument callable invoked once per indexed
                job. Its result is used as a context manager
                (``with conn_factory() as conn``).
            client: Forwarded untouched to ``generate_embedding``.
            model: Embedding model name forwarded to ``generate_embedding``.
            dim: Vector dimension forwarded to ``generate_embedding``.
            enabled: When false, ``enqueue`` silently discards jobs.
        """
        self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
        self._conn_factory = conn_factory
        self._client = client
        self._model = model
        self._dim = dim
        self._task: asyncio.Task | None = None
        self.enabled = enabled

    def enqueue(self, job: EmbeddingJob) -> None:
        """Queue ``job`` for processing; a no-op while ``enabled`` is false."""
        if not self.enabled:
            return
        self._queue.put_nowait(job)

    async def start(self) -> None:
        """Spawn the consumer task unless a live one already exists.

        A task that has already finished (for example because it was
        cancelled) is replaced, so the worker can always be brought back
        instead of staying dead behind a stale handle.
        """
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        """Drain up to the sentinel, then wait for the consumer task to exit.

        Does nothing if the worker was never started. Whatever awaiting the
        task raises (e.g. ``asyncio.CancelledError`` for a cancelled task)
        still propagates to the caller, but the task handle is always
        cleared so a later ``start()`` works.
        """
        task = self._task
        if task is None:
            return
        try:
            if not task.done():
                # Only a running consumer can swallow the sentinel; queueing
                # one for a dead task would make the next run exit at once.
                self._queue.put_nowait(None)  # sentinel
            await task
        finally:
            # Guard against clobbering a task installed by a newer start().
            if self._task is task:
                self._task = None

    async def _run(self) -> None:
        """Consume jobs until the ``None`` sentinel arrives.

        Per-job failures are logged and swallowed so one bad embedding call
        cannot kill the loop. Cancellation is not an ``Exception`` subclass
        and therefore still ends the task.
        """
        while True:
            job = await self._queue.get()
            if job is None:
                return
            try:
                await self._process(job)
            except Exception as exc:  # noqa: BLE001 — worker must not die
                log.warning(
                    "embedding worker failed for memory_id=%s: %s",
                    job.memory_id,
                    exc,
                    exc_info=True,
                )

    async def _process(self, job: EmbeddingJob) -> None:
        """Embed ``job.text`` and record it as an ``embedding_indexed`` event.

        Results tagged with ``FALLBACK_EMBEDDING_MODEL`` are skipped, so no
        connection is opened and no event is written for them.
        """
        result = await generate_embedding(
            self._client,
            text=job.text,
            model=self._model,
            dim=self._dim,
        )
        if result.model == FALLBACK_EMBEDDING_MODEL:
            # Don't index a fallback (zero) vector — the backfill script
            # can retry later once a real embedding is available.
            log.debug(
                "embedding worker skipping fallback result for memory_id=%s",
                job.memory_id,
            )
            return
        # NOTE(review): a bare ``sqlite3.Connection`` used as a context
        # manager only commits/rolls back and does not close. Confirm that
        # whatever ``conn_factory`` returns releases the connection on exit.
        # The write below is a plain synchronous call (not awaited), so it
        # runs on the event-loop thread.
        with self._conn_factory() as conn:
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": job.memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )


__all__ = ["EmbeddingJob", "EmbeddingWorker"]
Skipped silently if the worker is absent or the projected row + can't be located — the backfill script handles missing-vector rows. + """ payload: dict = { "owner_id": owner_id, "chat_id": chat_id, @@ -64,6 +80,23 @@ def _write_one_memory( (owner_id, chat_id), ).fetchone() memory_id = row[0] if row else None + + # T97: enqueue an embedding job for the just-written memory. The + # worker drains the queue out-of-band and emits an + # ``embedding_indexed`` event when the vector is ready. ``getattr`` + # keeps this a no-op for callers without a wired-up app (scripts, + # tests) — the backfill script handles those rows. + if memory_id is not None and narrative_text and narrative_text.strip(): + worker = ( + getattr(app.state, "embedding_worker", None) + if app is not None + else None + ) + if worker is not None: + worker.enqueue( + EmbeddingJob(memory_id=memory_id, text=narrative_text) + ) + return event_id, memory_id @@ -79,6 +112,7 @@ def record_turn_memory_for_present( source: str = "direct", significance: int = 1, you_present: bool = True, + app=None, ) -> dict[str, tuple[int, int | None]]: """Single entry-point for per-turn memory writes (T84). @@ -97,6 +131,9 @@ def record_turn_memory_for_present( with ``you_present=False`` is a programming error and raises :class:`ValueError`. + When ``app`` is provided, each per-witness write also enqueues an + :class:`EmbeddingJob` on ``app.state.embedding_worker`` (T97). + Returns a mapping ``{bot_id: (event_id, memory_id)}`` so callers can look up the freshly-projected memory id per owner without re-querying the database. 
@@ -121,6 +158,7 @@ def record_turn_memory_for_present( chat_clock_at=chat_clock_at, source=source, significance=significance, + app=app, ) if guest_bot_id is not None: result[guest_bot_id] = _write_one_memory( @@ -135,6 +173,7 @@ def record_turn_memory_for_present( chat_clock_at=chat_clock_at, source=source, significance=significance, + app=app, ) return result @@ -150,6 +189,7 @@ def record_meanwhile_memory( chat_clock_at: str | None = None, source: str = "direct", significance: int = 1, + app=None, ) -> dict[str, tuple[int, int | None]]: """Backward-compat thin wrapper for meanwhile memory writes (T64, T84). @@ -169,4 +209,5 @@ def record_meanwhile_memory( source=source, significance=significance, you_present=False, + app=app, ) diff --git a/chat/services/regenerate.py b/chat/services/regenerate.py index 0678a76..6442bb2 100644 --- a/chat/services/regenerate.py +++ b/chat/services/regenerate.py @@ -103,6 +103,7 @@ async def regenerate_assistant_turn( chat_id: str, original_assistant_event_id: int, edited_user_prose: str | None = None, + app=None, ) -> str: """Regenerate the assistant turn linked to ``original_assistant_event_id``. 
@@ -414,6 +415,7 @@ async def regenerate_assistant_turn( narrative_text=new_text, scene_id=scene["id"] if scene else None, chat_clock_at=chat.get("time"), + app=app, ) last_at = chat.get("time") @@ -648,6 +650,7 @@ async def regenerate_assistant_turn( narrative_text=interject_text, scene_id=scene["id"] if scene else None, chat_clock_at=chat.get("time"), + app=app, ) # Re-run the multi-pair state-update with the post-interjection diff --git a/chat/web/drawer.py b/chat/web/drawer.py index bcfdc0d..97f03cf 100644 --- a/chat/web/drawer.py +++ b/chat/web/drawer.py @@ -993,6 +993,7 @@ async def skip_elision( chat_id=chat_id, new_time=new_time, landing_state_hint=landing_state_hint, + app=request.app, ) except ChatNotFoundError as exc: # Missing chat row: typed exception (T81) replaces the prior @@ -1036,6 +1037,7 @@ async def skip_jump( new_time=new_time, notable_prose=notable_prose, reset_activity=reset_flag, + app=request.app, ) except ChatNotFoundError as exc: # Missing chat row: typed exception (T81) replaces the prior diff --git a/chat/web/meanwhile.py b/chat/web/meanwhile.py index 5c46b3e..52a91bc 100644 --- a/chat/web/meanwhile.py +++ b/chat/web/meanwhile.py @@ -131,6 +131,7 @@ async def process_meanwhile_turn( *, chat_id: str, prose: str, + app=None, ) -> dict: """Run one meanwhile turn end-to-end. @@ -314,6 +315,7 @@ async def process_meanwhile_turn( narrative_text=text, scene_id=scene_id, chat_clock_at=chat.get("time"), + app=app, ) # 9. Post-turn state-update — exactly 2 directed pairs over the diff --git a/chat/web/skip.py b/chat/web/skip.py index b6aa179..fd241df 100644 --- a/chat/web/skip.py +++ b/chat/web/skip.py @@ -91,6 +91,7 @@ async def process_elision_skip( chat_id: str, new_time: str, landing_state_hint: str = "", + app=None, ) -> dict: """Run an elision skip end-to-end. @@ -175,6 +176,7 @@ async def process_jump_skip( new_time: str, notable_prose: str = "", reset_activity: bool = False, + app=None, ) -> dict: """Run a jump skip end-to-end. 
@@ -254,6 +256,7 @@ async def process_jump_skip( chat_clock_at=new_time, source="synthesized", significance=mem.significance, + app=app, ) narration = await narrate_skip( diff --git a/chat/web/turns.py b/chat/web/turns.py index 94f46d4..97ef4a6 100644 --- a/chat/web/turns.py +++ b/chat/web/turns.py @@ -248,6 +248,7 @@ async def post_turn( settings, chat_id=chat_id, prose=prose, + app=request.app, ) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) @@ -352,6 +353,7 @@ async def post_turn( new_time=new_time, landing_state_hint=getattr(parsed, "landing_state_hint", "") or "", + app=request.app, ) except ChatNotFoundError as exc: # Defensive: chat existence is checked above, so this only @@ -512,6 +514,7 @@ async def post_turn( narrative_text=primary_text, scene_id=scene["id"] if scene else None, chat_clock_at=chat.get("time"), + app=request.app, ) # 7b. Post-turn state-update pass (Requirements §3.4 / T40). All @@ -746,6 +749,7 @@ async def post_turn( narrative_text=interjection_text, scene_id=scene["id"] if scene else None, chat_clock_at=chat.get("time"), + app=request.app, ) # T74.2: enqueue a significance pass for the interjection @@ -1092,6 +1096,7 @@ async def regenerate_turn( chat_id=chat_id, original_assistant_event_id=event_id, edited_user_prose=edited_prose, + app=request.app, ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/scripts/backfill_embeddings.py b/scripts/backfill_embeddings.py new file mode 100644 index 0000000..f5c15bb --- /dev/null +++ b/scripts/backfill_embeddings.py @@ -0,0 +1,97 @@ +"""Backfill embeddings for memories that lack them (T97, Phase 4). + +Walks all memories where no row exists in the ``embeddings`` table. For +each, calls :func:`chat.services.embeddings.generate_embedding` and emits +an ``embedding_indexed`` event so the projector lands the vector. 
async def main() -> None:
    """Backfill ``embedding_indexed`` events for memories lacking a vector.

    Parses ``--limit`` / ``--dry-run``, selects every ``memories`` row with
    no matching ``embeddings`` row (ordered by id), and for each one calls
    ``generate_embedding`` and appends an ``embedding_indexed`` event.
    Results tagged ``FALLBACK_EMBEDDING_MODEL`` are counted as skipped and
    not written.

    The parent directory creation and ``apply_migrations`` run before the
    ``--dry-run`` flag is checked, so a dry run still performs them.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Cap the number of memories backfilled in this run.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the count of memories needing embeddings, then exit.",
    )
    args = parser.parse_args()

    settings = load_settings()
    settings.db_path.parent.mkdir(parents=True, exist_ok=True)
    apply_migrations(settings.db_path)

    with open_db(settings.db_path) as conn:
        sql = (
            "SELECT m.id, m.pov_summary FROM memories m "
            "LEFT JOIN embeddings e ON e.memory_id = m.id "
            "WHERE e.memory_id IS NULL "
            "ORDER BY m.id"
        )
        params: tuple[int, ...] = ()
        if args.limit is not None:
            # Bind the cap instead of formatting it into the SQL text.
            # argparse already guarantees an int; as before, SQLite treats
            # a negative LIMIT as "no upper bound".
            sql += " LIMIT ?"
            params = (args.limit,)
        rows = conn.execute(sql, params).fetchall()
        print(f"Found {len(rows)} memories needing embeddings.")
        if args.dry_run:
            return

        indexed = 0
        skipped = 0
        for memory_id, text in rows:
            result = await generate_embedding(
                client=None,  # pseudo path: no client needed
                # A NULL pov_summary is embedded as the empty string.
                text=text or "",
            )
            if result.model == FALLBACK_EMBEDDING_MODEL:
                # NOTE(review): the message assumes a fallback result means
                # empty text — confirm that is the only cause.
                print(f" Skipping memory_id={memory_id} (empty text)")
                skipped += 1
                continue
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )
            indexed += 1
            print(f" Indexed memory_id={memory_id}")
        print(f"Done. Indexed {indexed}, skipped {skipped}.")


if __name__ == "__main__":
    asyncio.run(main())
import chat.state.embeddings  # noqa: F401
import chat.state.entities  # noqa: F401
import chat.state.memory  # noqa: F401
import chat.state.world  # noqa: F401


def _seed_memories(db_path: Path, count: int) -> list[int]:
    """Seed ``count`` memory rows for ``bot_a`` and return their ids.

    Appends one ``bot_authored`` event, one ``chat_created`` event and
    ``count`` ``memory_written`` events, calls ``project(conn)``, then reads
    the ids back from ``memories`` in ascending order.
    """
    with open_db(db_path) as conn:
        append_event(
            conn,
            kind="bot_authored",
            payload={
                "id": "bot_a",
                "name": "BotA",
                "persona": "...",
                "voice_samples": [],
                "traits": [],
                "backstory": "",
                "initial_relationship_to_you": "",
                "kickoff_prose": "",
            },
        )
        append_event(
            conn,
            kind="chat_created",
            payload={
                "id": "chat_bot_a",
                "host_bot_id": "bot_a",
                "initial_time": "2026-04-26T20:00:00+00:00",
                "narrative_anchor": "Day 1",
                "weather": "",
            },
        )
        for i in range(count):
            append_event(
                conn,
                kind="memory_written",
                payload={
                    "owner_id": "bot_a",
                    "chat_id": "chat_bot_a",
                    "pov_summary": f"memory text {i}",
                    "witness_you": 1,
                    "witness_host": 1,
                    "witness_guest": 0,
                    "source": "direct",
                    "reliability": 1.0,
                    "significance": 1,
                    "pinned": 0,
                    "auto_pinned": 0,
                },
            )
        # presumably applies the events appended above to the projected
        # tables read below — confirm against chat.eventlog.projector
        project(conn)
        return [
            r[0]
            for r in conn.execute(
                "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
            ).fetchall()
        ]


# NOTE(review): the async tests below carry no explicit asyncio marker —
# presumably pytest is configured to run coroutine tests automatically;
# confirm in the project's pytest settings.
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
    """Three jobs in -> three ``embedding_indexed`` events out, all
    projected into the ``embeddings`` table."""
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed_memories(db, count=3)

    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db),
        client=None,  # pseudo path — no client needed
    )
    await worker.start()
    for mid in memory_ids:
        worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
    # stop() queues the sentinel behind the three jobs and awaits the task,
    # so all jobs have been processed once it returns.
    await worker.stop()

    with open_db(db) as conn:
        # Three embedding_indexed events landed.
        cur = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        )
        assert cur.fetchone()[0] == 3
        # Three rows in the embeddings table — one per memory.
        cur = conn.execute(
            "SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
        )
        rows = cur.fetchall()
        assert len(rows) == 3
        for (mid, model, dim), expected_mid in zip(rows, memory_ids):
            assert mid == expected_mid
            assert model == DEFAULT_EMBEDDING_MODEL
            assert dim > 0


async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
    """A fallback EmbeddingResult must NOT produce an embedding_indexed
    event — backfill can retry later when a real embedding is available."""
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed_memories(db, count=1)

    # Stand-in that always reports the fallback model with a zero vector.
    async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
        return EmbeddingResult(
            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )

    # Patch the symbol the worker resolved at import time.
    import chat.services.embedding_worker as worker_mod

    monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)

    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db),
        client=None,
    )
    await worker.start()
    worker.enqueue(EmbeddingJob(memory_id=memory_ids[0], text="anything"))
    await worker.stop()

    with open_db(db) as conn:
        cur = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        )
        assert cur.fetchone()[0] == 0
        cur = conn.execute("SELECT COUNT(*) FROM embeddings")
        assert cur.fetchone()[0] == 0


async def test_worker_handles_concurrent_jobs_serially(tmp_path):
    """Five jobs queued back-to-back must process in FIFO order — the
    single-task design respects the Featherless 2-conn cap (and keeps
    event_log ordering deterministic)."""
    db = tmp_path / "t.db"
    apply_migrations(db)
    memory_ids = _seed_memories(db, count=5)

    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db),
        client=None,
    )
    await worker.start()
    # Enqueue all five before yielding to the loop — exercises the queue
    # rather than a one-at-a-time drain.
    for mid in memory_ids:
        worker.enqueue(EmbeddingJob(memory_id=mid, text=f"text-{mid}"))
    await worker.stop()

    with open_db(db) as conn:
        # Events landed in enqueue order (FIFO).
        cur = conn.execute(
            "SELECT json_extract(payload_json, '$.memory_id') "
            "FROM event_log WHERE kind = 'embedding_indexed' "
            "ORDER BY id"
        )
        seen = [r[0] for r in cur.fetchall()]
        assert seen == memory_ids

        # All five embeddings projected.
        cur = conn.execute("SELECT COUNT(*) FROM embeddings")
        assert cur.fetchone()[0] == 5


# ---------------------------------------------------------------------------
# T97: embedding-worker enqueue hook.
# ---------------------------------------------------------------------------


def test_record_turn_memory_enqueues_embedding_job(tmp_path):
    """When ``app.state.embedding_worker`` is wired, every per-witness
    write enqueues an :class:`EmbeddingJob` carrying the freshly-projected
    memory id and the narrative text. Two-bot turn -> two jobs."""
    from types import SimpleNamespace

    from chat.services.embedding_worker import EmbeddingJob

    db = tmp_path / "t.db"
    apply_migrations(db)
    _seed_two_bots(db)

    captured: list[EmbeddingJob] = []

    # Minimal stand-in exposing only the ``enqueue`` method the hook calls.
    class _StubWorker:
        def enqueue(self, job: EmbeddingJob) -> None:
            captured.append(job)

    fake_app = SimpleNamespace(
        state=SimpleNamespace(embedding_worker=_StubWorker())
    )

    with open_db(db) as conn:
        result = record_turn_memory_for_present(
            conn,
            chat_id="chat_ab",
            host_bot_id="bot_a",
            guest_bot_id="bot_b",
            narrative_text="Both bots witness this beat.",
            app=fake_app,
        )

    # One job per witness. The ids are compared as sets, so this pins that
    # both witnesses were enqueued but does not pin the enqueue order.
    assert len(captured) == 2
    expected_ids = {result["bot_a"][1], result["bot_b"][1]}
    assert {job.memory_id for job in captured} == expected_ids
    for job in captured:
        assert job.text == "Both bots witness this beat."
"""Phase 4 cross-feature integration tests (T97 follow-up).

Wave 8 / T101 will populate this file with the full Phase 4 retrieval +
embedding integration suite. For now this houses a single test pinning
the T97.5 wiring: the production turn route plumbs ``app=request.app``
all the way through ``record_turn_memory_for_present`` so the embedding
worker actually receives jobs in production. Without this fix-up the
plumbing added in T97 was dormant — every per-witness write took the
no-app branch and silently dropped the embed enqueue.

The test monkeypatches ``app.state.embedding_worker.enqueue`` to record
jobs (rather than draining the worker mid-test) so the assertion is
deterministic and free of asyncio-timing flakiness inside FastAPI's
TestClient. The bug we're guarding against is "did the call site pass
``app`` at all" — the worker's drain path is exercised in
:mod:`tests.test_embedding_worker`, so duplicating that here would add
no coverage.
"""

from __future__ import annotations

import json
from pathlib import Path

import pytest
from fastapi.testclient import TestClient

from chat.app import app
from chat.db.connection import open_db
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.llm.mock import MockLLMClient


def _zero_state() -> str:
    """Return a JSON string with zero affinity/trust deltas and no
    knowledge facts, used as a canned LLM response."""
    return json.dumps(
        {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
    )


def _override_llm(canned: list[str]) -> MockLLMClient:
    """Install a ``MockLLMClient`` seeded with a copy of ``canned`` as the
    ``get_llm_client`` dependency override and return it."""
    from chat.web.kickoff import get_llm_client

    mock = MockLLMClient(canned=list(canned))
    app.dependency_overrides[get_llm_client] = lambda: mock
    return mock


@pytest.fixture
def app_state_setup(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for ``app`` backed by a per-test config file
    and database path (set via ``CHAT_CONFIG_PATH`` / ``CHAT_DB_PATH``).

    Dependency overrides are cleared on teardown.
    """
    cfg = tmp_path / "config.toml"
    cfg.write_text('featherless_api_key = "test"\n')
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
    db = tmp_path / "test.db"
    monkeypatch.setenv("CHAT_DB_PATH", str(db))
    with TestClient(app) as c:
        # The background worker is disabled so the canned-response queue
        # is consumed only by the request path. The embedding worker
        # stays "started" but its loop won't observe the captured
        # enqueues — we replace ``enqueue`` on the worker instance below.
        app.state.background_worker.enabled = False
        yield c
    app.dependency_overrides.clear()


def _seed(db_path: Path) -> None:
    """Mirror of ``tests/test_turn_flow.py::_seed`` — single bot + chat +
    edge + activities so the prompt assembler has something to render.
    """
    with open_db(db_path) as conn:
        append_event(
            conn,
            kind="bot_authored",
            payload={
                "id": "bot_a",
                "name": "BotA",
                "persona": "thoughtful, observant",
                "voice_samples": [],
                "traits": [],
                "backstory": "",
                "initial_relationship_to_you": "",
                "kickoff_prose": "...",
            },
        )
        append_event(
            conn,
            kind="chat_created",
            payload={
                "id": "chat_bot_a",
                "host_bot_id": "bot_a",
                "initial_time": "2026-04-26T20:00:00+00:00",
                "narrative_anchor": "Day 1",
                "weather": "",
            },
        )
        append_event(
            conn,
            kind="edge_update",
            payload={
                "source_id": "bot_a",
                "target_id": "you",
                "chat_id": "chat_bot_a",
                "knowledge_facts": ["coworker"],
            },
        )
        # One activity_change per entity; only the action verb differs.
        for entity_id, verb in [("you", "talking"), ("bot_a", "listening")]:
            append_event(
                conn,
                kind="activity_change",
                payload={
                    "entity_id": entity_id,
                    "posture": "sitting",
                    "action": {
                        "verb": verb,
                        "interruptible": True,
                        "required_attention": "low",
                        "expected_duration": "ongoing",
                    },
                    "attention": "",
                    "holding": [],
                    "status": {},
                },
            )
        project(conn)


def test_post_turn_embeddings_indexed_via_worker_hook(
    app_state_setup, tmp_path
):
    """POST a turn; the route must pass ``app=request.app`` into
    ``record_turn_memory_for_present`` so the per-witness write enqueues
    an :class:`EmbeddingJob` on ``app.state.embedding_worker``.

    Without the T97.5 wiring this test fails: the call site previously
    omitted ``app=`` and the helper's ``app is None`` branch silently
    skipped every enqueue. We monkeypatch ``enqueue`` on the live
    embedding worker (rather than draining the queue mid-request) so the
    assertion does not depend on asyncio scheduling inside the
    TestClient — the bug is in the wiring, and the wiring is what we
    pin. The drain path is covered separately in
    :mod:`tests.test_embedding_worker`.
    """
    _seed(tmp_path / "test.db")

    canned_parse = json.dumps(
        {"segments": [{"kind": "dialogue", "text": "hello"}]}
    )
    _override_llm(
        [canned_parse, "Hi there.", _zero_state(), _zero_state()]
    )

    captured: list = []
    worker = app.state.embedding_worker
    original_enqueue = worker.enqueue
    # Record jobs in ``captured`` instead of queueing them on the worker.
    worker.enqueue = captured.append  # type: ignore[assignment]
    try:
        response = app_state_setup.post(
            "/chats/chat_bot_a/turns", data={"prose": "hello"}
        )
        assert response.status_code == 204
    finally:
        # Always put the real method back, even if the request fails.
        worker.enqueue = original_enqueue  # type: ignore[assignment]
        app.dependency_overrides.clear()

    # Single-bot turn -> one ``memory_written`` -> one EmbeddingJob.
    # The job's ``memory_id`` should match the freshly-projected memory
    # row, and its ``text`` should carry the assistant's narrative text.
    assert len(captured) == 1
    job = captured[0]
    assert job.text == "Hi there."

    with open_db(tmp_path / "test.db") as conn:
        memory_ids = [
            r[0]
            for r in conn.execute(
                "SELECT id FROM memories WHERE owner_id = ?",
                ("bot_a",),
            ).fetchall()
        ]
    assert job.memory_id in memory_ids