Phase 4: vector retrieval, branching, drawer polish #6

Merged
dohertj2 merged 41 commits from phase-4 into main 2026-04-27 04:10:26 -04:00
12 changed files with 717 additions and 1 deletions
Showing only changes of commit 7899c50b6c - Show all commits
+15
View File
@@ -16,6 +16,7 @@ from chat.db.migrate import apply_migrations
from chat.eventlog.log import read_events from chat.eventlog.log import read_events
from chat.eventlog.projector import apply_event from chat.eventlog.projector import apply_event
from chat.services.background import BackgroundWorker from chat.services.background import BackgroundWorker
from chat.services.embedding_worker import EmbeddingWorker
from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot
# Trigger handler registration: # Trigger handler registration:
@@ -85,9 +86,23 @@ async def lifespan(app: FastAPI):
await worker.start() await worker.start()
app.state.background_worker = worker app.state.background_worker = worker
# T97: separate worker for the async embedding pass. Each
# ``memory_written`` enqueues an EmbeddingJob; the worker drains the
# queue, calls ``generate_embedding``, and emits ``embedding_indexed``.
# Phase 4's pseudo-embedding path is local so the worker doesn't need
# an LLM client; we still pass one so the Phase 4.5 swap to a real
# model is a one-line change.
embedding_worker = EmbeddingWorker(
conn_factory=lambda: open_db(settings.db_path),
client=_factory(),
)
await embedding_worker.start()
app.state.embedding_worker = embedding_worker
try: try:
yield yield
finally: finally:
await embedding_worker.stop()
await worker.stop() await worker.stop()
+137
View File
@@ -0,0 +1,137 @@
"""Embedding worker (T97, Phase 4).
Drains a queue of embedding jobs. Each job carries a memory id and the
narrative text to embed; the worker calls
:func:`chat.services.embeddings.generate_embedding` and emits an
``embedding_indexed`` event so the projector lands the vector in the
``embeddings`` table.
Mirrors the :class:`chat.services.background.BackgroundWorker` pattern:
single asyncio task, sentinel-based shutdown, exceptions are caught and
logged so a flaky embedding call doesn't take down the worker. Each job
opens its own SQLite connection via ``conn_factory`` — the request path
and the worker do not share connections.
Featherless concurrency (the 2-conn cap) is respected by virtue of the
single-task design: jobs run strictly serially. Phase 4's pseudo-embedding
path is local and synchronous so this is largely moot, but the pattern
is in place for the Phase 4.5+ real-embedding swap.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from sqlite3 import Connection
from typing import Callable
from chat.eventlog.log import append_and_apply
from chat.services.embeddings import (
DEFAULT_EMBEDDING_DIM,
DEFAULT_EMBEDDING_MODEL,
FALLBACK_EMBEDDING_MODEL,
generate_embedding,
)
log = logging.getLogger(__name__)
@dataclass
class EmbeddingJob:
    """A single queued request to embed the text of one memory.

    The worker feeds ``text`` to the embedding generator and copies
    ``memory_id`` into the ``embedding_indexed`` event it emits, which is
    how the vector gets tied back to its memory row. ``text`` is typically
    the memory's ``pov_summary``.
    """

    # Id of the memory row the resulting vector belongs to.
    memory_id: int
    # Narrative text to embed.
    text: str
class EmbeddingWorker:
    """Single-task, ``asyncio.Queue``-backed worker for embedding generation.

    ``start()`` spawns one task running :meth:`_run`; ``enqueue()`` adds
    :class:`EmbeddingJob` items; ``stop()`` enqueues a ``None`` sentinel
    and awaits the task. Jobs queued ahead of the sentinel are processed
    before the task exits; jobs queued after it are not processed by that
    run.

    Jobs are handled strictly one at a time. A failure while processing a
    job is logged and does not terminate the loop.
    """

    def __init__(
        self,
        *,
        conn_factory: Callable[[], Connection],
        client,  # LLMClient | None — unused on the pseudo path.
        model: str = DEFAULT_EMBEDDING_MODEL,
        dim: int = DEFAULT_EMBEDDING_DIM,
        enabled: bool = True,
    ) -> None:
        """Configure the worker; no task is created until :meth:`start`.

        Args:
            conn_factory: Called once per indexed job to obtain a database
                connection for the ``embedding_indexed`` write.
            client: Passed through unchanged to ``generate_embedding``.
            model: Embedding model name passed to ``generate_embedding``.
            dim: Vector dimension passed to ``generate_embedding``.
            enabled: When false, :meth:`enqueue` silently discards jobs.
        """
        self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
        self._conn_factory = conn_factory
        self._client = client
        self._model = model
        self._dim = dim
        self._task: asyncio.Task | None = None
        self.enabled = enabled

    def enqueue(self, job: EmbeddingJob) -> None:
        """Queue ``job`` for processing; a no-op while ``enabled`` is false."""
        if not self.enabled:
            return
        self._queue.put_nowait(job)

    async def start(self) -> None:
        """Spawn the drain task if one is not already running."""
        if self._task is None:
            self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        """Send the shutdown sentinel and wait for the drain task to exit.

        Does nothing if the worker was never started (or is already
        stopped).
        """
        task = self._task
        if task is None:
            return
        self._queue.put_nowait(None)  # sentinel
        try:
            await task
        finally:
            # Clear the handle even when awaiting raises (e.g. the task or
            # the caller was cancelled). Otherwise ``start()`` would stay a
            # no-op forever and every later ``stop()`` would re-await the
            # dead task.
            self._task = None

    async def _run(self) -> None:
        """Drain the queue until the ``None`` sentinel is received."""
        while True:
            job = await self._queue.get()
            if job is None:
                return
            try:
                await self._process(job)
            except Exception as exc:  # noqa: BLE001 — worker must not die
                log.warning(
                    "embedding worker failed for memory_id=%s: %s",
                    job.memory_id,
                    exc,
                    exc_info=True,
                )

    async def _process(self, job: EmbeddingJob) -> None:
        """Embed one job's text and record an ``embedding_indexed`` event.

        Results whose model is ``FALLBACK_EMBEDDING_MODEL`` are skipped and
        nothing is written for them.
        """
        result = await generate_embedding(
            self._client,
            text=job.text,
            model=self._model,
            dim=self._dim,
        )
        if result.model == FALLBACK_EMBEDDING_MODEL:
            # Don't index a fallback (zero) vector — the backfill script
            # can retry later once a real embedding is available.
            log.debug(
                "embedding worker skipping fallback result for memory_id=%s",
                job.memory_id,
            )
            return
        # NOTE(review): if ``conn_factory`` returns a bare
        # ``sqlite3.Connection``, ``with`` only commits/rolls back and does
        # not close it, so each job would leave a connection open — confirm
        # that the factory's return value closes on exit.
        with self._conn_factory() as conn:
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": job.memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )
__all__ = ["EmbeddingJob", "EmbeddingWorker"]
+42 -1
View File
@@ -13,6 +13,14 @@ Phase 1 simplifications (per plan §11.1, T27 will refine):
pass overwrites via a follow-up event. pass overwrites via a follow-up event.
- Witness flags are hard-coded ``[you=1, host=1, guest=0]``. Phase 2 will - Witness flags are hard-coded ``[you=1, host=1, guest=0]``. Phase 2 will
derive them from ``chat.guest_bot_id`` once a guest can be present. derive them from ``chat.guest_bot_id`` once a guest can be present.
T97 (Phase 4): each successful memory write also enqueues an
:class:`~chat.services.embedding_worker.EmbeddingJob` on the
lifespan-managed embedding worker, so the just-written memory gets a
vector indexed out-of-band. The hook is opt-in via the ``app`` kwarg —
callers without a FastAPI app handle (e.g. one-off scripts, isolated
unit tests) simply don't enqueue, and the backfill script can pick up
those rows later.
""" """
from __future__ import annotations from __future__ import annotations
@@ -20,6 +28,7 @@ from __future__ import annotations
from sqlite3 import Connection from sqlite3 import Connection
from chat.eventlog.log import append_and_apply from chat.eventlog.log import append_and_apply
from chat.services.embedding_worker import EmbeddingJob
def _write_one_memory( def _write_one_memory(
@@ -35,9 +44,16 @@ def _write_one_memory(
chat_clock_at: str | None, chat_clock_at: str | None,
source: str, source: str,
significance: int, significance: int,
app=None,
) -> tuple[int, int | None]: ) -> tuple[int, int | None]:
"""Append a single ``memory_written`` event for ``owner_id`` and return """Append a single ``memory_written`` event for ``owner_id`` and return
``(event_id, memory_id)`` for the projected row.""" ``(event_id, memory_id)`` for the projected row.
When ``app`` is provided and ``app.state.embedding_worker`` exists,
enqueue an :class:`EmbeddingJob` for the freshly-projected memory id
(T97). Skipped silently if the worker is absent or the projected row
can't be located — the backfill script handles missing-vector rows.
"""
payload: dict = { payload: dict = {
"owner_id": owner_id, "owner_id": owner_id,
"chat_id": chat_id, "chat_id": chat_id,
@@ -64,6 +80,23 @@ def _write_one_memory(
(owner_id, chat_id), (owner_id, chat_id),
).fetchone() ).fetchone()
memory_id = row[0] if row else None memory_id = row[0] if row else None
# T97: enqueue an embedding job for the just-written memory. The
# worker drains the queue out-of-band and emits an
# ``embedding_indexed`` event when the vector is ready. ``getattr``
# keeps this a no-op for callers without a wired-up app (scripts,
# tests) — the backfill script handles those rows.
if memory_id is not None and narrative_text and narrative_text.strip():
worker = (
getattr(app.state, "embedding_worker", None)
if app is not None
else None
)
if worker is not None:
worker.enqueue(
EmbeddingJob(memory_id=memory_id, text=narrative_text)
)
return event_id, memory_id return event_id, memory_id
@@ -79,6 +112,7 @@ def record_turn_memory_for_present(
source: str = "direct", source: str = "direct",
significance: int = 1, significance: int = 1,
you_present: bool = True, you_present: bool = True,
app=None,
) -> dict[str, tuple[int, int | None]]: ) -> dict[str, tuple[int, int | None]]:
"""Single entry-point for per-turn memory writes (T84). """Single entry-point for per-turn memory writes (T84).
@@ -97,6 +131,9 @@ def record_turn_memory_for_present(
with ``you_present=False`` is a programming error and raises with ``you_present=False`` is a programming error and raises
:class:`ValueError`. :class:`ValueError`.
When ``app`` is provided, each per-witness write also enqueues an
:class:`EmbeddingJob` on ``app.state.embedding_worker`` (T97).
Returns a mapping ``{bot_id: (event_id, memory_id)}`` so callers can Returns a mapping ``{bot_id: (event_id, memory_id)}`` so callers can
look up the freshly-projected memory id per owner without re-querying look up the freshly-projected memory id per owner without re-querying
the database. the database.
@@ -121,6 +158,7 @@ def record_turn_memory_for_present(
chat_clock_at=chat_clock_at, chat_clock_at=chat_clock_at,
source=source, source=source,
significance=significance, significance=significance,
app=app,
) )
if guest_bot_id is not None: if guest_bot_id is not None:
result[guest_bot_id] = _write_one_memory( result[guest_bot_id] = _write_one_memory(
@@ -135,6 +173,7 @@ def record_turn_memory_for_present(
chat_clock_at=chat_clock_at, chat_clock_at=chat_clock_at,
source=source, source=source,
significance=significance, significance=significance,
app=app,
) )
return result return result
@@ -150,6 +189,7 @@ def record_meanwhile_memory(
chat_clock_at: str | None = None, chat_clock_at: str | None = None,
source: str = "direct", source: str = "direct",
significance: int = 1, significance: int = 1,
app=None,
) -> dict[str, tuple[int, int | None]]: ) -> dict[str, tuple[int, int | None]]:
"""Backward-compat thin wrapper for meanwhile memory writes (T64, T84). """Backward-compat thin wrapper for meanwhile memory writes (T64, T84).
@@ -169,4 +209,5 @@ def record_meanwhile_memory(
source=source, source=source,
significance=significance, significance=significance,
you_present=False, you_present=False,
app=app,
) )
+3
View File
@@ -103,6 +103,7 @@ async def regenerate_assistant_turn(
chat_id: str, chat_id: str,
original_assistant_event_id: int, original_assistant_event_id: int,
edited_user_prose: str | None = None, edited_user_prose: str | None = None,
app=None,
) -> str: ) -> str:
"""Regenerate the assistant turn linked to ``original_assistant_event_id``. """Regenerate the assistant turn linked to ``original_assistant_event_id``.
@@ -414,6 +415,7 @@ async def regenerate_assistant_turn(
narrative_text=new_text, narrative_text=new_text,
scene_id=scene["id"] if scene else None, scene_id=scene["id"] if scene else None,
chat_clock_at=chat.get("time"), chat_clock_at=chat.get("time"),
app=app,
) )
last_at = chat.get("time") last_at = chat.get("time")
@@ -648,6 +650,7 @@ async def regenerate_assistant_turn(
narrative_text=interject_text, narrative_text=interject_text,
scene_id=scene["id"] if scene else None, scene_id=scene["id"] if scene else None,
chat_clock_at=chat.get("time"), chat_clock_at=chat.get("time"),
app=app,
) )
# Re-run the multi-pair state-update with the post-interjection # Re-run the multi-pair state-update with the post-interjection
+2
View File
@@ -993,6 +993,7 @@ async def skip_elision(
chat_id=chat_id, chat_id=chat_id,
new_time=new_time, new_time=new_time,
landing_state_hint=landing_state_hint, landing_state_hint=landing_state_hint,
app=request.app,
) )
except ChatNotFoundError as exc: except ChatNotFoundError as exc:
# Missing chat row: typed exception (T81) replaces the prior # Missing chat row: typed exception (T81) replaces the prior
@@ -1036,6 +1037,7 @@ async def skip_jump(
new_time=new_time, new_time=new_time,
notable_prose=notable_prose, notable_prose=notable_prose,
reset_activity=reset_flag, reset_activity=reset_flag,
app=request.app,
) )
except ChatNotFoundError as exc: except ChatNotFoundError as exc:
# Missing chat row: typed exception (T81) replaces the prior # Missing chat row: typed exception (T81) replaces the prior
+2
View File
@@ -131,6 +131,7 @@ async def process_meanwhile_turn(
*, *,
chat_id: str, chat_id: str,
prose: str, prose: str,
app=None,
) -> dict: ) -> dict:
"""Run one meanwhile turn end-to-end. """Run one meanwhile turn end-to-end.
@@ -314,6 +315,7 @@ async def process_meanwhile_turn(
narrative_text=text, narrative_text=text,
scene_id=scene_id, scene_id=scene_id,
chat_clock_at=chat.get("time"), chat_clock_at=chat.get("time"),
app=app,
) )
# 9. Post-turn state-update — exactly 2 directed pairs over the # 9. Post-turn state-update — exactly 2 directed pairs over the
+3
View File
@@ -91,6 +91,7 @@ async def process_elision_skip(
chat_id: str, chat_id: str,
new_time: str, new_time: str,
landing_state_hint: str = "", landing_state_hint: str = "",
app=None,
) -> dict: ) -> dict:
"""Run an elision skip end-to-end. """Run an elision skip end-to-end.
@@ -175,6 +176,7 @@ async def process_jump_skip(
new_time: str, new_time: str,
notable_prose: str = "", notable_prose: str = "",
reset_activity: bool = False, reset_activity: bool = False,
app=None,
) -> dict: ) -> dict:
"""Run a jump skip end-to-end. """Run a jump skip end-to-end.
@@ -254,6 +256,7 @@ async def process_jump_skip(
chat_clock_at=new_time, chat_clock_at=new_time,
source="synthesized", source="synthesized",
significance=mem.significance, significance=mem.significance,
app=app,
) )
narration = await narrate_skip( narration = await narrate_skip(
+5
View File
@@ -248,6 +248,7 @@ async def post_turn(
settings, settings,
chat_id=chat_id, chat_id=chat_id,
prose=prose, prose=prose,
app=request.app,
) )
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) raise HTTPException(status_code=400, detail=str(exc))
@@ -352,6 +353,7 @@ async def post_turn(
new_time=new_time, new_time=new_time,
landing_state_hint=getattr(parsed, "landing_state_hint", "") landing_state_hint=getattr(parsed, "landing_state_hint", "")
or "", or "",
app=request.app,
) )
except ChatNotFoundError as exc: except ChatNotFoundError as exc:
# Defensive: chat existence is checked above, so this only # Defensive: chat existence is checked above, so this only
@@ -512,6 +514,7 @@ async def post_turn(
narrative_text=primary_text, narrative_text=primary_text,
scene_id=scene["id"] if scene else None, scene_id=scene["id"] if scene else None,
chat_clock_at=chat.get("time"), chat_clock_at=chat.get("time"),
app=request.app,
) )
# 7b. Post-turn state-update pass (Requirements §3.4 / T40). All # 7b. Post-turn state-update pass (Requirements §3.4 / T40). All
@@ -746,6 +749,7 @@ async def post_turn(
narrative_text=interjection_text, narrative_text=interjection_text,
scene_id=scene["id"] if scene else None, scene_id=scene["id"] if scene else None,
chat_clock_at=chat.get("time"), chat_clock_at=chat.get("time"),
app=request.app,
) )
# T74.2: enqueue a significance pass for the interjection # T74.2: enqueue a significance pass for the interjection
@@ -1092,6 +1096,7 @@ async def regenerate_turn(
chat_id=chat_id, chat_id=chat_id,
original_assistant_event_id=event_id, original_assistant_event_id=event_id,
edited_user_prose=edited_prose, edited_user_prose=edited_prose,
app=request.app,
) )
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
+97
View File
@@ -0,0 +1,97 @@
"""Backfill embeddings for memories that lack them (T97, Phase 4).
Walks all memories where no row exists in the ``embeddings`` table. For
each, calls :func:`chat.services.embeddings.generate_embedding` and emits
an ``embedding_indexed`` event so the projector lands the vector.
Phase 4 ships the deterministic local pseudo-embedding, so this script
runs synchronously without a network round-trip — the LLMClient argument
is not needed on the pseudo path. Phase 4.5+ will need a real client.
Run from the repo root:
.venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run]
"""
from __future__ import annotations
import argparse
import asyncio
from chat.config import load_settings
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_and_apply
from chat.services.embeddings import (
FALLBACK_EMBEDDING_MODEL,
generate_embedding,
)
# Trigger projector handler registration so ``append_and_apply`` lands
# the embedding rows correctly.
import chat.state.embeddings # noqa: F401
import chat.state.entities # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
async def main() -> None:
    """Backfill embeddings for memories that have no ``embeddings`` row.

    Parses ``--limit`` / ``--dry-run``, applies migrations, selects every
    memory without a matching ``embeddings`` row (ordered by id), and for
    each one generates an embedding and appends an ``embedding_indexed``
    event. Results whose model is ``FALLBACK_EMBEDDING_MODEL`` are counted
    as skipped and not written. With ``--dry-run`` only the count is
    printed.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Cap the number of memories backfilled in this run.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the count of memories needing embeddings, then exit.",
    )
    args = parser.parse_args()

    settings = load_settings()
    settings.db_path.parent.mkdir(parents=True, exist_ok=True)
    apply_migrations(settings.db_path)

    sql = (
        "SELECT m.id, m.pov_summary FROM memories m "
        "LEFT JOIN embeddings e ON e.memory_id = m.id "
        "WHERE e.memory_id IS NULL "
        "ORDER BY m.id"
    )
    # Bind the limit as a parameter instead of splicing it into the SQL
    # text, so the statement is never assembled from user input.
    params: tuple[int, ...] = ()
    if args.limit is not None:
        sql += " LIMIT ?"
        params = (args.limit,)

    with open_db(settings.db_path) as conn:
        rows = conn.execute(sql, params).fetchall()
        print(f"Found {len(rows)} memories needing embeddings.")
        if args.dry_run:
            return

        indexed = 0
        skipped = 0
        for memory_id, text in rows:
            result = await generate_embedding(
                client=None,  # pseudo path: no client needed
                text=text or "",
            )
            if result.model == FALLBACK_EMBEDDING_MODEL:
                print(f"  Skipping memory_id={memory_id} (empty text)")
                skipped += 1
                continue
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )
            indexed += 1
            print(f"  Indexed memory_id={memory_id}")
        print(f"Done. Indexed {indexed}, skipped {skipped}.")
if __name__ == "__main__":
asyncio.run(main())
+185
View File
@@ -0,0 +1,185 @@
"""Embedding worker (T97, Phase 4).
The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
events. Mirrors test_significance.py's BackgroundWorker tests in shape:
seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then
assert on the projected ``embeddings`` table and the underlying event_log.
"""
from __future__ import annotations
from pathlib import Path
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
from chat.services.embeddings import (
DEFAULT_EMBEDDING_MODEL,
EmbeddingResult,
FALLBACK_EMBEDDING_MODEL,
)
# Trigger handler registration for projection.
import chat.state.embeddings # noqa: F401
import chat.state.entities # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
def _seed_memories(db_path: Path, count: int) -> list[int]:
    """Create ``bot_a``, its chat and ``count`` memories; return memory ids.

    Appends one ``bot_authored`` event, one ``chat_created`` event and
    ``count`` ``memory_written`` events, projects them, then reads back the
    ids of ``bot_a``'s memories in ascending order.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    with open_db(db_path) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)
        for index in range(count):
            memory_payload = {
                "owner_id": "bot_a",
                "chat_id": "chat_bot_a",
                "pov_summary": f"memory text {index}",
                "witness_you": 1,
                "witness_host": 1,
                "witness_guest": 0,
                "source": "direct",
                "reliability": 1.0,
                "significance": 1,
                "pinned": 0,
                "auto_pinned": 0,
            }
            append_event(conn, kind="memory_written", payload=memory_payload)
        project(conn)
        id_rows = conn.execute(
            "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
        ).fetchall()
        return [id_row[0] for id_row in id_rows]
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
    """Enqueue three jobs and expect three ``embedding_indexed`` events,
    each projected as a row in the ``embeddings`` table."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    seeded_ids = _seed_memories(db_path, count=3)

    # Pseudo path — no client needed.
    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db_path),
        client=None,
    )
    await worker.start()
    for memory_id in seeded_ids:
        job = EmbeddingJob(memory_id=memory_id, text=f"text-{memory_id}")
        worker.enqueue(job)
    # stop() waits for the queue to drain up to the sentinel.
    await worker.stop()

    with open_db(db_path) as conn:
        # One event per enqueued job.
        event_count = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        ).fetchone()[0]
        assert event_count == 3

        # One projected row per memory, in memory-id order.
        projected = conn.execute(
            "SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
        ).fetchall()
        assert len(projected) == 3
        for projected_row, expected_id in zip(projected, seeded_ids):
            stored_id, stored_model, stored_dim = projected_row
            assert stored_id == expected_id
            assert stored_model == DEFAULT_EMBEDDING_MODEL
            assert stored_dim > 0
async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
    """A fallback ``EmbeddingResult`` must not yield an ``embedding_indexed``
    event — backfill can retry later when a real embedding is available."""
    import chat.services.embedding_worker as worker_mod

    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    (only_memory_id,) = _seed_memories(db_path, count=1)

    async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
        # Always report the fallback model with an all-zero vector.
        zero_vector = [0.0] * dim
        return EmbeddingResult(
            vector=zero_vector, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )

    # Patch the name as bound inside the worker module, since the worker
    # resolved it at import time.
    monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)

    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db_path),
        client=None,
    )
    await worker.start()
    worker.enqueue(EmbeddingJob(memory_id=only_memory_id, text="anything"))
    await worker.stop()

    with open_db(db_path) as conn:
        event_count = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        ).fetchone()[0]
        assert event_count == 0

        embedding_count = conn.execute(
            "SELECT COUNT(*) FROM embeddings"
        ).fetchone()[0]
        assert embedding_count == 0
async def test_worker_handles_concurrent_jobs_serially(tmp_path):
    """Five jobs queued back-to-back are processed in FIFO order, so the
    ``embedding_indexed`` events land in enqueue order."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    seeded_ids = _seed_memories(db_path, count=5)

    worker = EmbeddingWorker(
        conn_factory=lambda: open_db(db_path),
        client=None,
    )
    await worker.start()
    # No await between enqueues: all five jobs sit in the queue before the
    # worker task gets a chance to run.
    jobs = [
        EmbeddingJob(memory_id=memory_id, text=f"text-{memory_id}")
        for memory_id in seeded_ids
    ]
    for job in jobs:
        worker.enqueue(job)
    await worker.stop()

    with open_db(db_path) as conn:
        # Event order must match enqueue order.
        event_rows = conn.execute(
            "SELECT json_extract(payload_json, '$.memory_id') "
            "FROM event_log WHERE kind = 'embedding_indexed' "
            "ORDER BY id"
        ).fetchall()
        indexed_order = [event_row[0] for event_row in event_rows]
        assert indexed_order == seeded_ids

        # Every job produced a projected embedding.
        embedding_count = conn.execute(
            "SELECT COUNT(*) FROM embeddings"
        ).fetchone()[0]
        assert embedding_count == 5
+46
View File
@@ -540,3 +540,49 @@ def test_record_turn_memory_you_present_false_requires_guest(tmp_path):
narrative_text="invalid", narrative_text="invalid",
you_present=False, you_present=False,
) )
# ---------------------------------------------------------------------------
# T97: embedding-worker enqueue hook.
# ---------------------------------------------------------------------------
def test_record_turn_memory_enqueues_embedding_job(tmp_path):
    """With ``app.state.embedding_worker`` present, each per-witness write
    enqueues an :class:`EmbeddingJob` holding the projected memory id and
    the narrative text. A two-bot turn therefore yields two jobs."""
    from types import SimpleNamespace

    from chat.services.embedding_worker import EmbeddingJob

    db_path = tmp_path / "t.db"
    apply_migrations(db_path)
    _seed_two_bots(db_path)

    captured: list[EmbeddingJob] = []

    class _StubWorker:
        """Stand-in worker that records jobs instead of processing them."""

        def enqueue(self, job: EmbeddingJob) -> None:
            captured.append(job)

    app_state = SimpleNamespace(embedding_worker=_StubWorker())
    fake_app = SimpleNamespace(state=app_state)

    narrative = "Both bots witness this beat."
    with open_db(db_path) as conn:
        result = record_turn_memory_for_present(
            conn,
            chat_id="chat_ab",
            host_bot_id="bot_a",
            guest_bot_id="bot_b",
            narrative_text=narrative,
            app=fake_app,
        )

    # One job per witness. Ids are compared as sets, so the enqueue order
    # itself is not asserted here.
    assert len(captured) == 2
    host_memory_id = result["bot_a"][1]
    guest_memory_id = result["bot_b"][1]
    assert {job.memory_id for job in captured} == {
        host_memory_id,
        guest_memory_id,
    }
    for job in captured:
        assert job.text == narrative
+180
View File
@@ -0,0 +1,180 @@
"""Phase 4 cross-feature integration tests (T97 follow-up).
Wave 8 / T101 will populate this file with the full Phase 4 retrieval +
embedding integration suite. For now this houses a single test pinning
the T97.5 wiring: the production turn route plumbs ``app=request.app``
all the way through ``record_turn_memory_for_present`` so the embedding
worker actually receives jobs in production. Without this fix-up the
plumbing added in T97 was dormant — every per-witness write took the
no-app branch and silently dropped the embed enqueue.
The test monkeypatches ``app.state.embedding_worker.enqueue`` to record
jobs (rather than draining the worker mid-test) so the assertion is
deterministic and free of asyncio-timing flakiness inside FastAPI's
TestClient. The bug we're guarding against is "did the call site pass
``app`` at all" — the worker's drain path is exercised in
:mod:`tests.test_embedding_worker`, so duplicating that here would add
no coverage.
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from chat.app import app
from chat.db.connection import open_db
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
from chat.llm.mock import MockLLMClient
def _zero_state() -> str:
return json.dumps(
{"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
)
def _override_llm(canned: list[str]) -> MockLLMClient:
    """Install a ``MockLLMClient`` as the app's LLM dependency and return it.

    The mock is built from a copy of ``canned``, and the same instance is
    handed out every time the overridden dependency is resolved.
    """
    from chat.web.kickoff import get_llm_client

    mock_client = MockLLMClient(canned=list(canned))

    def _provide_mock() -> MockLLMClient:
        return mock_client

    app.dependency_overrides[get_llm_client] = _provide_mock
    return mock_client
@pytest.fixture
def app_state_setup(tmp_path, monkeypatch):
    """Yield a ``TestClient`` bound to a temp config file and database.

    Points ``CHAT_CONFIG_PATH`` and ``CHAT_DB_PATH`` at files under
    ``tmp_path``, disables the background worker for the duration of the
    test, and clears dependency overrides on teardown.
    """
    config_path = tmp_path / "config.toml"
    config_path.write_text('featherless_api_key = "test"\n')
    db_path = tmp_path / "test.db"
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_path))
    monkeypatch.setenv("CHAT_DB_PATH", str(db_path))

    with TestClient(app) as client:
        # Disable the background worker so that only the request path
        # consumes the canned LLM responses. The embedding worker is left
        # running; tests that care replace its ``enqueue`` themselves.
        app.state.background_worker.enabled = False
        yield client
    app.dependency_overrides.clear()
def _seed(db_path: Path) -> None:
    """Seed one bot, its chat, an edge and two activities, then project.

    Mirror of ``tests/test_turn_flow.py::_seed`` — gives the prompt
    assembler something to render.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "thoughtful, observant",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "...",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    edge_payload = {
        "source_id": "bot_a",
        "target_id": "you",
        "chat_id": "chat_bot_a",
        "knowledge_facts": ["coworker"],
    }
    # (entity id, action verb) for each participant's starting activity.
    starting_activities = [("you", "talking"), ("bot_a", "listening")]

    with open_db(db_path) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)
        append_event(conn, kind="edge_update", payload=edge_payload)
        for entity_id, verb in starting_activities:
            activity_payload = {
                "entity_id": entity_id,
                "posture": "sitting",
                "action": {
                    "verb": verb,
                    "interruptible": True,
                    "required_attention": "low",
                    "expected_duration": "ongoing",
                },
                "attention": "",
                "holding": [],
                "status": {},
            }
            append_event(conn, kind="activity_change", payload=activity_payload)
        project(conn)
def test_post_turn_embeddings_indexed_via_worker_hook(
    app_state_setup, tmp_path
):
    """POST a turn and check that an :class:`EmbeddingJob` reaches
    ``app.state.embedding_worker``.

    This pins the wiring: the route has to pass ``app=request.app`` down
    to ``record_turn_memory_for_present``, otherwise the helper takes its
    ``app is None`` branch and never enqueues. ``enqueue`` on the live
    worker is swapped for a recorder (instead of draining the queue
    mid-request) so the assertion does not depend on asyncio scheduling
    inside the TestClient — the bug is in the wiring, and the wiring is
    what we pin. The drain path is covered separately in
    :mod:`tests.test_embedding_worker`.
    """
    db_path = tmp_path / "test.db"
    _seed(db_path)

    parse_reply = json.dumps(
        {"segments": [{"kind": "dialogue", "text": "hello"}]}
    )
    canned_replies = [parse_reply, "Hi there.", _zero_state(), _zero_state()]
    _override_llm(canned_replies)

    captured: list = []
    worker = app.state.embedding_worker
    original_enqueue = worker.enqueue
    # Record jobs instead of queueing them for the worker loop.
    worker.enqueue = captured.append  # type: ignore[assignment]
    try:
        response = app_state_setup.post(
            "/chats/chat_bot_a/turns", data={"prose": "hello"}
        )
        assert response.status_code == 204
    finally:
        worker.enqueue = original_enqueue  # type: ignore[assignment]
        app.dependency_overrides.clear()

    # Single-bot turn -> one ``memory_written`` -> one EmbeddingJob whose
    # text is the assistant's narrative text.
    assert len(captured) == 1
    (job,) = captured
    assert job.text == "Hi there."

    # The job must point at a memory row that was actually projected.
    with open_db(db_path) as conn:
        id_rows = conn.execute(
            "SELECT id FROM memories WHERE owner_id = ?",
            ("bot_a",),
        ).fetchall()
        memory_ids = [id_row[0] for id_row in id_rows]
    assert job.memory_id in memory_ids