merge: T97 memory write hook + embedding worker + backfill + call-site wiring
This commit is contained in:
+15
@@ -16,6 +16,7 @@ from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import read_events
|
||||
from chat.eventlog.projector import apply_event
|
||||
from chat.services.background import BackgroundWorker
|
||||
from chat.services.embedding_worker import EmbeddingWorker
|
||||
from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot
|
||||
|
||||
# Trigger handler registration:
|
||||
@@ -85,9 +86,23 @@ async def lifespan(app: FastAPI):
|
||||
await worker.start()
|
||||
app.state.background_worker = worker
|
||||
|
||||
# T97: separate worker for the async embedding pass. Each
|
||||
# ``memory_written`` enqueues an EmbeddingJob; the worker drains the
|
||||
# queue, calls ``generate_embedding``, and emits ``embedding_indexed``.
|
||||
# Phase 4's pseudo-embedding path is local so the worker doesn't need
|
||||
# an LLM client; we still pass one so the Phase 4.5 swap to a real
|
||||
# model is a one-line change.
|
||||
embedding_worker = EmbeddingWorker(
|
||||
conn_factory=lambda: open_db(settings.db_path),
|
||||
client=_factory(),
|
||||
)
|
||||
await embedding_worker.start()
|
||||
app.state.embedding_worker = embedding_worker
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await embedding_worker.stop()
|
||||
await worker.stop()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Embedding worker (T97, Phase 4).
|
||||
|
||||
Drains a queue of embedding jobs. Each job carries a memory id and the
|
||||
narrative text to embed; the worker calls
|
||||
:func:`chat.services.embeddings.generate_embedding` and emits an
|
||||
``embedding_indexed`` event so the projector lands the vector in the
|
||||
``embeddings`` table.
|
||||
|
||||
Mirrors the :class:`chat.services.background.BackgroundWorker` pattern:
|
||||
single asyncio task, sentinel-based shutdown, exceptions are caught and
|
||||
logged so a flaky embedding call doesn't take down the worker. Each job
|
||||
opens its own SQLite connection via ``conn_factory`` — the request path
|
||||
and the worker do not share connections.
|
||||
|
||||
Featherless concurrency (the 2-conn cap) is respected by virtue of the
|
||||
single-task design: jobs run strictly serially. Phase 4's pseudo-embedding
|
||||
path is local and synchronous so this is largely moot, but the pattern
|
||||
is in place for the Phase 4.5+ real-embedding swap.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from sqlite3 import Connection
|
||||
from typing import Callable
|
||||
|
||||
from chat.eventlog.log import append_and_apply
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_DIM,
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
generate_embedding,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class EmbeddingJob:
    """A single queued request to embed one memory's text.

    The worker passes ``text`` to the embedding generator and reports the
    resulting vector against ``memory_id`` in an ``embedding_indexed``
    event payload.
    """

    # Id of the memory row the generated vector is attached to.
    memory_id: int
    # Narrative text to embed (typically ``memories.pov_summary``).
    text: str
|
||||
|
||||
|
||||
class EmbeddingWorker:
    """Single-task, ``asyncio.Queue``-backed worker for embedding generation.

    One asyncio task pulls :class:`EmbeddingJob` items off the queue and
    handles them strictly one at a time, in FIFO order. A failure while
    handling a job is logged and the loop moves on to the next job.

    ``stop()`` enqueues a ``None`` sentinel and awaits the task, so jobs
    queued ahead of the sentinel still get processed. Whatever is left in
    the queue once the task has exited (jobs enqueued behind the sentinel,
    or a surplus sentinel from an overlapping ``stop()`` call) is
    discarded, so a later ``start()`` begins with an empty queue.
    """

    def __init__(
        self,
        *,
        conn_factory: Callable[[], Connection],
        client,  # LLMClient | None — unused on the pseudo path.
        model: str = DEFAULT_EMBEDDING_MODEL,
        dim: int = DEFAULT_EMBEDDING_DIM,
        enabled: bool = True,
    ) -> None:
        """Configure the worker without starting it.

        Args:
            conn_factory: Called once per processed job; its result is used
                as a context manager yielding the connection the
                ``embedding_indexed`` event is appended through.
            client: Passed straight through to ``generate_embedding``.
            model: Embedding model name forwarded to ``generate_embedding``.
            dim: Vector dimensionality forwarded to ``generate_embedding``.
            enabled: When ``False``, :meth:`enqueue` silently ignores jobs.
        """
        # ``None`` on the queue is the shutdown sentinel.
        self._queue: asyncio.Queue[EmbeddingJob | None] = asyncio.Queue()
        self._conn_factory = conn_factory
        self._client = client
        self._model = model
        self._dim = dim
        self._task: asyncio.Task | None = None
        self.enabled = enabled

    def enqueue(self, job: EmbeddingJob) -> None:
        """Queue ``job`` for processing; a no-op while the worker is disabled."""
        if not self.enabled:
            return
        self._queue.put_nowait(job)

    async def start(self) -> None:
        """Spawn the drain task unless one is already running."""
        if self._task is None:
            self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        """Signal shutdown, wait for the drain task, then empty the queue.

        Does nothing if the worker was never started (or is already
        stopped).
        """
        if self._task is None:
            return
        self._queue.put_nowait(None)  # sentinel
        try:
            await self._task
        finally:
            # Reset even if awaiting the task raised, so the worker is not
            # stuck in a "started" state, and make sure nothing stale
            # (late jobs or a duplicate sentinel) survives into a restart.
            self._task = None
            self._discard_pending()

    def _discard_pending(self) -> int:
        """Empty the queue and return how many real jobs were thrown away."""
        dropped = 0
        while True:
            try:
                leftover = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if leftover is not None:
                dropped += 1
        if dropped:
            log.debug(
                "embedding worker dropped %s pending job(s) on shutdown",
                dropped,
            )
        return dropped

    async def _run(self) -> None:
        """Drain loop: process jobs until the ``None`` sentinel arrives."""
        while True:
            job = await self._queue.get()
            if job is None:
                return
            try:
                await self._process(job)
            except Exception as exc:  # noqa: BLE001 — worker must not die
                log.warning(
                    "embedding worker failed for memory_id=%s: %s",
                    job.memory_id,
                    exc,
                    exc_info=True,
                )

    async def _process(self, job: EmbeddingJob) -> None:
        """Embed one job's text and emit ``embedding_indexed`` for it.

        A result tagged with ``FALLBACK_EMBEDDING_MODEL`` is not indexed;
        no event is emitted for that job.
        """
        result = await generate_embedding(
            self._client,
            text=job.text,
            model=self._model,
            dim=self._dim,
        )
        if result.model == FALLBACK_EMBEDDING_MODEL:
            # Don't index a fallback (zero) vector — the backfill script
            # can retry later once a real embedding is available.
            log.debug(
                "embedding worker skipping fallback result for memory_id=%s",
                job.memory_id,
            )
            return
        with self._conn_factory() as conn:
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": job.memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )
|
||||
|
||||
|
||||
__all__ = ["EmbeddingJob", "EmbeddingWorker"]
|
||||
@@ -13,6 +13,14 @@ Phase 1 simplifications (per plan §11.1, T27 will refine):
|
||||
pass overwrites via a follow-up event.
|
||||
- Witness flags are hard-coded ``[you=1, host=1, guest=0]``. Phase 2 will
|
||||
derive them from ``chat.guest_bot_id`` once a guest can be present.
|
||||
|
||||
T97 (Phase 4): each successful memory write also enqueues an
|
||||
:class:`~chat.services.embedding_worker.EmbeddingJob` on the
|
||||
lifespan-managed embedding worker, so the just-written memory gets a
|
||||
vector indexed out-of-band. The hook is opt-in via the ``app`` kwarg —
|
||||
callers without a FastAPI app handle (e.g. one-off scripts, isolated
|
||||
unit tests) simply don't enqueue, and the backfill script can pick up
|
||||
those rows later.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -20,6 +28,7 @@ from __future__ import annotations
|
||||
from sqlite3 import Connection
|
||||
|
||||
from chat.eventlog.log import append_and_apply
|
||||
from chat.services.embedding_worker import EmbeddingJob
|
||||
|
||||
|
||||
def _write_one_memory(
|
||||
@@ -35,9 +44,16 @@ def _write_one_memory(
|
||||
chat_clock_at: str | None,
|
||||
source: str,
|
||||
significance: int,
|
||||
app=None,
|
||||
) -> tuple[int, int | None]:
|
||||
"""Append a single ``memory_written`` event for ``owner_id`` and return
|
||||
``(event_id, memory_id)`` for the projected row."""
|
||||
``(event_id, memory_id)`` for the projected row.
|
||||
|
||||
When ``app`` is provided and ``app.state.embedding_worker`` exists,
|
||||
enqueue an :class:`EmbeddingJob` for the freshly-projected memory id
|
||||
(T97). Skipped silently if the worker is absent or the projected row
|
||||
can't be located — the backfill script handles missing-vector rows.
|
||||
"""
|
||||
payload: dict = {
|
||||
"owner_id": owner_id,
|
||||
"chat_id": chat_id,
|
||||
@@ -64,6 +80,23 @@ def _write_one_memory(
|
||||
(owner_id, chat_id),
|
||||
).fetchone()
|
||||
memory_id = row[0] if row else None
|
||||
|
||||
# T97: enqueue an embedding job for the just-written memory. The
|
||||
# worker drains the queue out-of-band and emits an
|
||||
# ``embedding_indexed`` event when the vector is ready. ``getattr``
|
||||
# keeps this a no-op for callers without a wired-up app (scripts,
|
||||
# tests) — the backfill script handles those rows.
|
||||
if memory_id is not None and narrative_text and narrative_text.strip():
|
||||
worker = (
|
||||
getattr(app.state, "embedding_worker", None)
|
||||
if app is not None
|
||||
else None
|
||||
)
|
||||
if worker is not None:
|
||||
worker.enqueue(
|
||||
EmbeddingJob(memory_id=memory_id, text=narrative_text)
|
||||
)
|
||||
|
||||
return event_id, memory_id
|
||||
|
||||
|
||||
@@ -79,6 +112,7 @@ def record_turn_memory_for_present(
|
||||
source: str = "direct",
|
||||
significance: int = 1,
|
||||
you_present: bool = True,
|
||||
app=None,
|
||||
) -> dict[str, tuple[int, int | None]]:
|
||||
"""Single entry-point for per-turn memory writes (T84).
|
||||
|
||||
@@ -97,6 +131,9 @@ def record_turn_memory_for_present(
|
||||
with ``you_present=False`` is a programming error and raises
|
||||
:class:`ValueError`.
|
||||
|
||||
When ``app`` is provided, each per-witness write also enqueues an
|
||||
:class:`EmbeddingJob` on ``app.state.embedding_worker`` (T97).
|
||||
|
||||
Returns a mapping ``{bot_id: (event_id, memory_id)}`` so callers can
|
||||
look up the freshly-projected memory id per owner without re-querying
|
||||
the database.
|
||||
@@ -121,6 +158,7 @@ def record_turn_memory_for_present(
|
||||
chat_clock_at=chat_clock_at,
|
||||
source=source,
|
||||
significance=significance,
|
||||
app=app,
|
||||
)
|
||||
if guest_bot_id is not None:
|
||||
result[guest_bot_id] = _write_one_memory(
|
||||
@@ -135,6 +173,7 @@ def record_turn_memory_for_present(
|
||||
chat_clock_at=chat_clock_at,
|
||||
source=source,
|
||||
significance=significance,
|
||||
app=app,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -150,6 +189,7 @@ def record_meanwhile_memory(
|
||||
chat_clock_at: str | None = None,
|
||||
source: str = "direct",
|
||||
significance: int = 1,
|
||||
app=None,
|
||||
) -> dict[str, tuple[int, int | None]]:
|
||||
"""Backward-compat thin wrapper for meanwhile memory writes (T64, T84).
|
||||
|
||||
@@ -169,4 +209,5 @@ def record_meanwhile_memory(
|
||||
source=source,
|
||||
significance=significance,
|
||||
you_present=False,
|
||||
app=app,
|
||||
)
|
||||
|
||||
@@ -103,6 +103,7 @@ async def regenerate_assistant_turn(
|
||||
chat_id: str,
|
||||
original_assistant_event_id: int,
|
||||
edited_user_prose: str | None = None,
|
||||
app=None,
|
||||
) -> str:
|
||||
"""Regenerate the assistant turn linked to ``original_assistant_event_id``.
|
||||
|
||||
@@ -414,6 +415,7 @@ async def regenerate_assistant_turn(
|
||||
narrative_text=new_text,
|
||||
scene_id=scene["id"] if scene else None,
|
||||
chat_clock_at=chat.get("time"),
|
||||
app=app,
|
||||
)
|
||||
|
||||
last_at = chat.get("time")
|
||||
@@ -648,6 +650,7 @@ async def regenerate_assistant_turn(
|
||||
narrative_text=interject_text,
|
||||
scene_id=scene["id"] if scene else None,
|
||||
chat_clock_at=chat.get("time"),
|
||||
app=app,
|
||||
)
|
||||
|
||||
# Re-run the multi-pair state-update with the post-interjection
|
||||
|
||||
@@ -993,6 +993,7 @@ async def skip_elision(
|
||||
chat_id=chat_id,
|
||||
new_time=new_time,
|
||||
landing_state_hint=landing_state_hint,
|
||||
app=request.app,
|
||||
)
|
||||
except ChatNotFoundError as exc:
|
||||
# Missing chat row: typed exception (T81) replaces the prior
|
||||
@@ -1036,6 +1037,7 @@ async def skip_jump(
|
||||
new_time=new_time,
|
||||
notable_prose=notable_prose,
|
||||
reset_activity=reset_flag,
|
||||
app=request.app,
|
||||
)
|
||||
except ChatNotFoundError as exc:
|
||||
# Missing chat row: typed exception (T81) replaces the prior
|
||||
|
||||
@@ -131,6 +131,7 @@ async def process_meanwhile_turn(
|
||||
*,
|
||||
chat_id: str,
|
||||
prose: str,
|
||||
app=None,
|
||||
) -> dict:
|
||||
"""Run one meanwhile turn end-to-end.
|
||||
|
||||
@@ -314,6 +315,7 @@ async def process_meanwhile_turn(
|
||||
narrative_text=text,
|
||||
scene_id=scene_id,
|
||||
chat_clock_at=chat.get("time"),
|
||||
app=app,
|
||||
)
|
||||
|
||||
# 9. Post-turn state-update — exactly 2 directed pairs over the
|
||||
|
||||
@@ -91,6 +91,7 @@ async def process_elision_skip(
|
||||
chat_id: str,
|
||||
new_time: str,
|
||||
landing_state_hint: str = "",
|
||||
app=None,
|
||||
) -> dict:
|
||||
"""Run an elision skip end-to-end.
|
||||
|
||||
@@ -175,6 +176,7 @@ async def process_jump_skip(
|
||||
new_time: str,
|
||||
notable_prose: str = "",
|
||||
reset_activity: bool = False,
|
||||
app=None,
|
||||
) -> dict:
|
||||
"""Run a jump skip end-to-end.
|
||||
|
||||
@@ -254,6 +256,7 @@ async def process_jump_skip(
|
||||
chat_clock_at=new_time,
|
||||
source="synthesized",
|
||||
significance=mem.significance,
|
||||
app=app,
|
||||
)
|
||||
|
||||
narration = await narrate_skip(
|
||||
|
||||
@@ -248,6 +248,7 @@ async def post_turn(
|
||||
settings,
|
||||
chat_id=chat_id,
|
||||
prose=prose,
|
||||
app=request.app,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
@@ -352,6 +353,7 @@ async def post_turn(
|
||||
new_time=new_time,
|
||||
landing_state_hint=getattr(parsed, "landing_state_hint", "")
|
||||
or "",
|
||||
app=request.app,
|
||||
)
|
||||
except ChatNotFoundError as exc:
|
||||
# Defensive: chat existence is checked above, so this only
|
||||
@@ -512,6 +514,7 @@ async def post_turn(
|
||||
narrative_text=primary_text,
|
||||
scene_id=scene["id"] if scene else None,
|
||||
chat_clock_at=chat.get("time"),
|
||||
app=request.app,
|
||||
)
|
||||
|
||||
# 7b. Post-turn state-update pass (Requirements §3.4 / T40). All
|
||||
@@ -746,6 +749,7 @@ async def post_turn(
|
||||
narrative_text=interjection_text,
|
||||
scene_id=scene["id"] if scene else None,
|
||||
chat_clock_at=chat.get("time"),
|
||||
app=request.app,
|
||||
)
|
||||
|
||||
# T74.2: enqueue a significance pass for the interjection
|
||||
@@ -1092,6 +1096,7 @@ async def regenerate_turn(
|
||||
chat_id=chat_id,
|
||||
original_assistant_event_id=event_id,
|
||||
edited_user_prose=edited_prose,
|
||||
app=request.app,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Backfill embeddings for memories that lack them (T97, Phase 4).
|
||||
|
||||
Walks all memories where no row exists in the ``embeddings`` table. For
|
||||
each, calls :func:`chat.services.embeddings.generate_embedding` and emits
|
||||
an ``embedding_indexed`` event so the projector lands the vector.
|
||||
|
||||
Phase 4 ships the deterministic local pseudo-embedding so this script
|
||||
runs synchronously without a network round-trip — the LLMClient argument
|
||||
is not needed on the pseudo path. Phase 4.5+ will need a real client.
|
||||
|
||||
Run from the repo root:
|
||||
.venv/bin/python scripts/backfill_embeddings.py [--limit N] [--dry-run]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from chat.config import load_settings
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_and_apply
|
||||
from chat.services.embeddings import (
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
generate_embedding,
|
||||
)
|
||||
|
||||
# Trigger projector handler registration so ``append_and_apply`` lands
|
||||
# the embedding rows correctly.
|
||||
import chat.state.embeddings # noqa: F401
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse the ``--limit`` / ``--dry-run`` command-line options."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Cap the number of memories backfilled in this run.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the count of memories needing embeddings, then exit.",
    )
    return parser.parse_args()


async def main() -> None:
    """Backfill ``embedding_indexed`` events for memories lacking a vector.

    Selects every memory with no row in ``embeddings`` (ordered by id),
    generates an embedding for each, and appends an ``embedding_indexed``
    event for every non-fallback result. With ``--dry-run`` only the count
    of candidate memories is printed.

    ``--limit`` caps the number of memories actually *indexed*. Rows that
    are skipped (fallback result) never gain an ``embeddings`` row, so
    they stay at the head of the candidate list on every run; capping the
    fetch instead would let them consume the whole budget and stall the
    backfill.
    """
    args = _parse_args()
    # A negative value means "no cap", matching SQLite's LIMIT semantics
    # that the previous SQL-level limit relied on.
    cap = args.limit if args.limit is not None and args.limit >= 0 else None

    settings = load_settings()
    settings.db_path.parent.mkdir(parents=True, exist_ok=True)
    apply_migrations(settings.db_path)

    with open_db(settings.db_path) as conn:
        rows = conn.execute(
            "SELECT m.id, m.pov_summary FROM memories m "
            "LEFT JOIN embeddings e ON e.memory_id = m.id "
            "WHERE e.memory_id IS NULL "
            "ORDER BY m.id"
        ).fetchall()
        print(f"Found {len(rows)} memories needing embeddings.")
        if args.dry_run:
            return

        indexed = 0
        skipped = 0
        for memory_id, text in rows:
            if cap is not None and indexed >= cap:
                break
            body = text or ""
            result = await generate_embedding(
                client=None,  # pseudo path: no client needed
                text=body,
            )
            if result.model == FALLBACK_EMBEDDING_MODEL:
                # Report the actual reason: a fallback result is only
                # attributable to empty text when the text really is blank.
                reason = "empty text" if not body.strip() else "fallback embedding"
                print(f" Skipping memory_id={memory_id} ({reason})")
                skipped += 1
                continue
            append_and_apply(
                conn,
                kind="embedding_indexed",
                payload={
                    "memory_id": memory_id,
                    "model": result.model,
                    "dim": result.dim,
                    "vector": result.vector,
                },
            )
            indexed += 1
            print(f" Indexed memory_id={memory_id}")
        print(f"Done. Indexed {indexed}, skipped {skipped}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Embedding worker (T97, Phase 4).
|
||||
|
||||
The worker drains a queue of EmbeddingJobs and emits ``embedding_indexed``
|
||||
events. Mirrors test_significance.py's BackgroundWorker tests in shape:
|
||||
seed a memory, enqueue jobs, call ``stop()`` to drain via sentinel, then
|
||||
assert on the projected ``embeddings`` table and the underlying event_log.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from chat.db.connection import open_db
|
||||
from chat.db.migrate import apply_migrations
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.services.embedding_worker import EmbeddingJob, EmbeddingWorker
|
||||
from chat.services.embeddings import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
EmbeddingResult,
|
||||
FALLBACK_EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
# Trigger handler registration for projection.
|
||||
import chat.state.embeddings # noqa: F401
|
||||
import chat.state.entities # noqa: F401
|
||||
import chat.state.memory # noqa: F401
|
||||
import chat.state.world # noqa: F401
|
||||
|
||||
|
||||
def _seed_memories(db_path: Path, count: int) -> list[int]:
    """Create ``bot_a``, its chat and ``count`` memories; return the memory ids.

    Events are appended first and projected in one pass at the end; the
    returned ids are read back from the ``memories`` table in id order.
    """
    bot_payload = {
        "id": "bot_a",
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "",
    }
    chat_payload = {
        "id": "chat_bot_a",
        "host_bot_id": "bot_a",
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    # Fields shared by every seeded memory; only ``pov_summary`` varies.
    memory_tail = {
        "witness_you": 1,
        "witness_host": 1,
        "witness_guest": 0,
        "source": "direct",
        "reliability": 1.0,
        "significance": 1,
        "pinned": 0,
        "auto_pinned": 0,
    }

    with open_db(db_path) as conn:
        append_event(conn, kind="bot_authored", payload=bot_payload)
        append_event(conn, kind="chat_created", payload=chat_payload)
        for index in range(count):
            append_event(
                conn,
                kind="memory_written",
                payload={
                    "owner_id": "bot_a",
                    "chat_id": "chat_bot_a",
                    "pov_summary": f"memory text {index}",
                    **memory_tail,
                },
            )
        project(conn)
        found = conn.execute(
            "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id"
        ).fetchall()
        return [record[0] for record in found]
|
||||
|
||||
|
||||
async def test_worker_drains_jobs_and_emits_indexed_events(tmp_path):
    """Enqueue three jobs and expect three ``embedding_indexed`` events,
    each projected as a row in the ``embeddings`` table."""
    db_file = tmp_path / "t.db"
    apply_migrations(db_file)
    seeded_ids = _seed_memories(db_file, count=3)

    # client=None: the pseudo-embedding path needs no LLM client.
    worker = EmbeddingWorker(conn_factory=lambda: open_db(db_file), client=None)
    await worker.start()
    for memory_id in seeded_ids:
        worker.enqueue(EmbeddingJob(memory_id=memory_id, text=f"text-{memory_id}"))
    # stop() waits for the jobs queued ahead of the sentinel.
    await worker.stop()

    with open_db(db_file) as conn:
        # One event per job in the log...
        event_count = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        ).fetchone()[0]
        assert event_count == 3

        # ...and one projected row per memory.
        projected = conn.execute(
            "SELECT memory_id, model, dim FROM embeddings ORDER BY memory_id"
        ).fetchall()
        assert len(projected) == 3
        for expected_id, (row_memory_id, row_model, row_dim) in zip(
            seeded_ids, projected
        ):
            assert row_memory_id == expected_id
            assert row_model == DEFAULT_EMBEDDING_MODEL
            assert row_dim > 0
|
||||
|
||||
|
||||
async def test_worker_skips_fallback_results(tmp_path, monkeypatch):
    """When the generator returns a fallback ``EmbeddingResult`` the worker
    must emit no ``embedding_indexed`` event and project no row, leaving
    the memory for a later backfill."""
    import chat.services.embedding_worker as worker_mod

    db_file = tmp_path / "t.db"
    apply_migrations(db_file)
    (only_id,) = _seed_memories(db_file, count=1)

    async def _fake_generate(client, *, text, model, dim, timeout_s=30.0):
        # Always report the fallback model with an all-zero vector.
        return EmbeddingResult(
            vector=[0.0] * dim, model=FALLBACK_EMBEDDING_MODEL, dim=dim
        )

    # The worker module bound ``generate_embedding`` at import time, so the
    # replacement has to be installed on that module.
    monkeypatch.setattr(worker_mod, "generate_embedding", _fake_generate)

    worker = EmbeddingWorker(conn_factory=lambda: open_db(db_file), client=None)
    await worker.start()
    worker.enqueue(EmbeddingJob(memory_id=only_id, text="anything"))
    await worker.stop()

    with open_db(db_file) as conn:
        logged = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'embedding_indexed'"
        ).fetchone()[0]
        assert logged == 0
        projected = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
        assert projected == 0
|
||||
|
||||
|
||||
async def test_worker_handles_concurrent_jobs_serially(tmp_path):
    """Five jobs enqueued in one burst are handled in FIFO order, so the
    ``embedding_indexed`` events appear in the log in enqueue order and
    all five embeddings are projected."""
    db_file = tmp_path / "t.db"
    apply_migrations(db_file)
    seeded_ids = _seed_memories(db_file, count=5)

    worker = EmbeddingWorker(conn_factory=lambda: open_db(db_file), client=None)
    await worker.start()
    # No await between enqueues: the whole batch is sitting in the queue
    # before the worker task gets a chance to run.
    for memory_id in seeded_ids:
        worker.enqueue(EmbeddingJob(memory_id=memory_id, text=f"text-{memory_id}"))
    await worker.stop()

    with open_db(db_file) as conn:
        # Log order must equal enqueue order.
        logged = conn.execute(
            "SELECT json_extract(payload_json, '$.memory_id') "
            "FROM event_log WHERE kind = 'embedding_indexed' "
            "ORDER BY id"
        ).fetchall()
        assert [record[0] for record in logged] == seeded_ids

        # Every job produced a projected embedding.
        total = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
        assert total == 5
|
||||
@@ -540,3 +540,49 @@ def test_record_turn_memory_you_present_false_requires_guest(tmp_path):
|
||||
narrative_text="invalid",
|
||||
you_present=False,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T97: embedding-worker enqueue hook.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_record_turn_memory_enqueues_embedding_job(tmp_path):
    """With ``app.state.embedding_worker`` present, each per-witness write
    hands the worker an :class:`EmbeddingJob` holding the projected memory
    id and the narrative text. A host + guest turn yields two jobs."""
    from types import SimpleNamespace

    from chat.services.embedding_worker import EmbeddingJob

    db_file = tmp_path / "t.db"
    apply_migrations(db_file)
    _seed_two_bots(db_file)

    captured: list[EmbeddingJob] = []
    # Minimal stand-ins: a worker whose ``enqueue`` records the job, hung
    # off an object shaped like ``app.state``.
    stub_worker = SimpleNamespace(enqueue=captured.append)
    fake_app = SimpleNamespace(state=SimpleNamespace(embedding_worker=stub_worker))

    narrative = "Both bots witness this beat."
    with open_db(db_file) as conn:
        result = record_turn_memory_for_present(
            conn,
            chat_id="chat_ab",
            host_bot_id="bot_a",
            guest_bot_id="bot_b",
            narrative_text=narrative,
            app=fake_app,
        )

    # One job per witness; ids are compared as a set, so enqueue order is
    # not asserted here.
    assert len(captured) == 2
    _, host_memory_id = result["bot_a"]
    _, guest_memory_id = result["bot_b"]
    assert {job.memory_id for job in captured} == {host_memory_id, guest_memory_id}
    assert all(job.text == narrative for job in captured)
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
"""Phase 4 cross-feature integration tests (T97 follow-up).
|
||||
|
||||
Wave 8 / T101 will populate this file with the full Phase 4 retrieval +
|
||||
embedding integration suite. For now this houses a single test pinning
|
||||
the T97.5 wiring: the production turn route plumbs ``app=request.app``
|
||||
all the way through ``record_turn_memory_for_present`` so the embedding
|
||||
worker actually receives jobs in production. Without this fix-up the
|
||||
plumbing added in T97 was dormant — every per-witness write took the
|
||||
no-app branch and silently dropped the embed enqueue.
|
||||
|
||||
The test monkeypatches ``app.state.embedding_worker.enqueue`` to record
|
||||
jobs (rather than draining the worker mid-test) so the assertion is
|
||||
deterministic and free of asyncio-timing flakiness inside FastAPI's
|
||||
TestClient. The bug we're guarding against is "did the call site pass
|
||||
``app`` at all" — the worker's drain path is exercised in
|
||||
:mod:`tests.test_embedding_worker`, so duplicating that here would add
|
||||
no coverage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from chat.app import app
|
||||
from chat.db.connection import open_db
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
from chat.llm.mock import MockLLMClient
|
||||
|
||||
|
||||
def _zero_state() -> str:
|
||||
return json.dumps(
|
||||
{"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
|
||||
)
|
||||
|
||||
|
||||
def _override_llm(canned: list[str]) -> MockLLMClient:
    """Register a ``MockLLMClient`` as the override for ``get_llm_client``.

    The mock is built from a copy of ``canned`` and returned so the caller
    can inspect it afterwards.
    """
    from chat.web.kickoff import get_llm_client

    stub_client = MockLLMClient(canned=list(canned))
    # The override always hands back this one shared mock instance.
    app.dependency_overrides[get_llm_client] = lambda: stub_client
    return stub_client
|
||||
|
||||
|
||||
@pytest.fixture
def app_state_setup(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for the app, pointed at a throwaway config
    file and database under ``tmp_path``; dependency overrides are cleared
    on teardown."""
    config_file = tmp_path / "config.toml"
    config_file.write_text('featherless_api_key = "test"\n')
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))
    db_file = tmp_path / "test.db"
    monkeypatch.setenv("CHAT_DB_PATH", str(db_file))
    with TestClient(app) as client:
        # Disable the background worker so only the request path consumes
        # the canned-response queue. The embedding worker is left running;
        # the test swaps out its ``enqueue`` to capture jobs instead.
        app.state.background_worker.enabled = False
        yield client
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _seed(db_path: Path) -> None:
    """Seed one bot, its chat, a bot->you edge and an activity for each of
    ``you`` and ``bot_a``, then project the events.

    Mirror of ``tests/test_turn_flow.py::_seed``.
    """

    def _activity(entity_id: str, verb: str) -> dict:
        # Both participants share everything except the entity and verb.
        return {
            "entity_id": entity_id,
            "posture": "sitting",
            "action": {
                "verb": verb,
                "interruptible": True,
                "required_attention": "low",
                "expected_duration": "ongoing",
            },
            "attention": "",
            "holding": [],
            "status": {},
        }

    # Events in the exact order they are appended to the log.
    seed_events = [
        (
            "bot_authored",
            {
                "id": "bot_a",
                "name": "BotA",
                "persona": "thoughtful, observant",
                "voice_samples": [],
                "traits": [],
                "backstory": "",
                "initial_relationship_to_you": "",
                "kickoff_prose": "...",
            },
        ),
        (
            "chat_created",
            {
                "id": "chat_bot_a",
                "host_bot_id": "bot_a",
                "initial_time": "2026-04-26T20:00:00+00:00",
                "narrative_anchor": "Day 1",
                "weather": "",
            },
        ),
        (
            "edge_update",
            {
                "source_id": "bot_a",
                "target_id": "you",
                "chat_id": "chat_bot_a",
                "knowledge_facts": ["coworker"],
            },
        ),
        ("activity_change", _activity("you", "talking")),
        ("activity_change", _activity("bot_a", "listening")),
    ]

    with open_db(db_path) as conn:
        for event_kind, event_payload in seed_events:
            append_event(conn, kind=event_kind, payload=event_payload)
        project(conn)
|
||||
|
||||
|
||||
def test_post_turn_embeddings_indexed_via_worker_hook(
    app_state_setup, tmp_path
):
    """Posting a turn must reach ``app.state.embedding_worker.enqueue``.

    This pins the call-site wiring: the route has to pass
    ``app=request.app`` down to the memory write, otherwise no
    :class:`EmbeddingJob` is ever enqueued. ``enqueue`` on the live worker
    is replaced with a recorder, so the assertion does not depend on the
    worker task actually draining anything; the drain path is covered in
    :mod:`tests.test_embedding_worker`.
    """
    db_file = tmp_path / "test.db"
    _seed(db_file)

    parse_reply = json.dumps(
        {"segments": [{"kind": "dialogue", "text": "hello"}]}
    )
    _override_llm([parse_reply, "Hi there.", _zero_state(), _zero_state()])

    recorded: list = []
    live_worker = app.state.embedding_worker
    saved_enqueue = live_worker.enqueue
    live_worker.enqueue = recorded.append  # type: ignore[assignment]
    try:
        response = app_state_setup.post(
            "/chats/chat_bot_a/turns", data={"prose": "hello"}
        )
        assert response.status_code == 204
    finally:
        # Always restore the real method and drop the LLM override.
        live_worker.enqueue = saved_enqueue  # type: ignore[assignment]
        app.dependency_overrides.clear()

    # Single-bot turn: exactly one job, carrying the assistant's text.
    assert len(recorded) == 1
    (job,) = recorded
    assert job.text == "Hi there."

    # The job must point at a memory row that was projected for bot_a.
    with open_db(db_file) as conn:
        found = conn.execute(
            "SELECT id FROM memories WHERE owner_id = ?",
            ("bot_a",),
        ).fetchall()
    bot_memory_ids = [record[0] for record in found]
    assert job.memory_id in bot_memory_ids
|
||||
Reference in New Issue
Block a user