From 64a07aa87fa19cc388bf65824995bb746cac1a4d Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 02:51:40 -0400 Subject: [PATCH] feat: memory_write enqueues embedding job after each memory_written (T97.2) --- chat/services/memory_write.py | 43 +++++++++++++++++++++++++++++++- tests/test_memory_write.py | 46 +++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/chat/services/memory_write.py b/chat/services/memory_write.py index d60c3d9..5e89eb3 100644 --- a/chat/services/memory_write.py +++ b/chat/services/memory_write.py @@ -13,6 +13,14 @@ Phase 1 simplifications (per plan §11.1, T27 will refine): pass overwrites via a follow-up event. - Witness flags are hard-coded ``[you=1, host=1, guest=0]``. Phase 2 will derive them from ``chat.guest_bot_id`` once a guest can be present. + +T97 (Phase 4): each successful memory write also enqueues an +:class:`~chat.services.embedding_worker.EmbeddingJob` on the +lifespan-managed embedding worker, so the just-written memory gets a +vector indexed out-of-band. The hook is opt-in via the ``app`` kwarg — +callers without a FastAPI app handle (e.g. one-off scripts, isolated +unit tests) simply don't enqueue, and the backfill script can pick up +those rows later. """ from __future__ import annotations @@ -20,6 +28,7 @@ from __future__ import annotations from sqlite3 import Connection from chat.eventlog.log import append_and_apply +from chat.services.embedding_worker import EmbeddingJob def _write_one_memory( @@ -35,9 +44,16 @@ def _write_one_memory( chat_clock_at: str | None, source: str, significance: int, + app=None, ) -> tuple[int, int | None]: """Append a single ``memory_written`` event for ``owner_id`` and return - ``(event_id, memory_id)`` for the projected row.""" + ``(event_id, memory_id)`` for the projected row. + + When ``app`` is provided and ``app.state.embedding_worker`` exists, + enqueue an :class:`EmbeddingJob` for the freshly-projected memory id + (T97). Skipped silently if the worker is absent or the projected row + can't be located — the backfill script handles missing-vector rows. + """ payload: dict = { "owner_id": owner_id, "chat_id": chat_id, @@ -64,6 +80,23 @@ def _write_one_memory( (owner_id, chat_id), ).fetchone() memory_id = row[0] if row else None + + # T97: enqueue an embedding job for the just-written memory. The + # worker drains the queue out-of-band and emits an + # ``embedding_indexed`` event when the vector is ready. ``getattr`` + # keeps this a no-op for callers without a wired-up app (scripts, + # tests) — the backfill script handles those rows. + if memory_id is not None and narrative_text and narrative_text.strip(): + worker = ( + getattr(app.state, "embedding_worker", None) + if app is not None + else None + ) + if worker is not None: + worker.enqueue( + EmbeddingJob(memory_id=memory_id, text=narrative_text) + ) + return event_id, memory_id @@ -79,6 +112,7 @@ def record_turn_memory_for_present( source: str = "direct", significance: int = 1, you_present: bool = True, + app=None, ) -> dict[str, tuple[int, int | None]]: """Single entry-point for per-turn memory writes (T84). @@ -97,6 +131,9 @@ def record_turn_memory_for_present( with ``you_present=False`` is a programming error and raises :class:`ValueError`. + When ``app`` is provided, each per-witness write also enqueues an + :class:`EmbeddingJob` on ``app.state.embedding_worker`` (T97). + Returns a mapping ``{bot_id: (event_id, memory_id)}`` so callers can look up the freshly-projected memory id per owner without re-querying the database. @@ -121,6 +158,7 @@ def record_turn_memory_for_present( chat_clock_at=chat_clock_at, source=source, significance=significance, + app=app, ) if guest_bot_id is not None: result[guest_bot_id] = _write_one_memory( @@ -135,6 +173,7 @@ def record_turn_memory_for_present( chat_clock_at=chat_clock_at, source=source, significance=significance, + app=app, ) return result @@ -150,6 +189,7 @@ def record_meanwhile_memory( chat_clock_at: str | None = None, source: str = "direct", significance: int = 1, + app=None, ) -> dict[str, tuple[int, int | None]]: """Backward-compat thin wrapper for meanwhile memory writes (T64, T84). @@ -169,4 +209,5 @@ def record_meanwhile_memory( source=source, significance=significance, you_present=False, + app=app, ) diff --git a/tests/test_memory_write.py b/tests/test_memory_write.py index 8c5253a..3c135a5 100644 --- a/tests/test_memory_write.py +++ b/tests/test_memory_write.py @@ -540,3 +540,49 @@ def test_record_turn_memory_you_present_false_requires_guest(tmp_path): narrative_text="invalid", you_present=False, ) + + +# --------------------------------------------------------------------------- +# T97: embedding-worker enqueue hook. +# --------------------------------------------------------------------------- + + +def test_record_turn_memory_enqueues_embedding_job(tmp_path): + """When ``app.state.embedding_worker`` is wired, every per-witness + write enqueues an :class:`EmbeddingJob` carrying the freshly-projected + memory id and the narrative text. Two-bot turn -> two jobs.""" + from types import SimpleNamespace + + from chat.services.embedding_worker import EmbeddingJob + + db = tmp_path / "t.db" + apply_migrations(db) + _seed_two_bots(db) + + captured: list[EmbeddingJob] = [] + + class _StubWorker: + def enqueue(self, job: EmbeddingJob) -> None: + captured.append(job) + + fake_app = SimpleNamespace( + state=SimpleNamespace(embedding_worker=_StubWorker()) + ) + + with open_db(db) as conn: + result = record_turn_memory_for_present( + conn, + chat_id="chat_ab", + host_bot_id="bot_a", + guest_bot_id="bot_b", + narrative_text="Both bots witness this beat.", + app=fake_app, + ) + + # One job per witness — host first, then guest (matches result dict + # insertion order in record_turn_memory_for_present). + assert len(captured) == 2 + expected_ids = {result["bot_a"][1], result["bot_b"][1]} + assert {job.memory_id for job in captured} == expected_ids + for job in captured: + assert job.text == "Both bots witness this beat."