diff --git a/chat/app.py b/chat/app.py index 8417d1c..6233d58 100644 --- a/chat/app.py +++ b/chat/app.py @@ -7,6 +7,7 @@ from fastapi.staticfiles import StaticFiles from chat.config import load_settings from chat.db.migrate import apply_migrations +from chat.services.background import BackgroundWorker # Trigger handler registration: import chat.state.entities # noqa: F401 @@ -29,7 +30,26 @@ async def lifespan(app: FastAPI): settings.db_path.parent.mkdir(parents=True, exist_ok=True) apply_migrations(settings.db_path) app.state.settings = settings - yield + + # Background worker for the async significance pass (T22). Each job + # constructs a fresh FeatherlessClient via the factory; tests can + # disable enqueue by toggling ``app.state.background_worker.enabled``. + def _factory(): + from chat.llm.featherless import FeatherlessClient + + return FeatherlessClient( + api_key=settings.featherless_api_key, + base_url=settings.featherless_base_url, + ) + + worker = BackgroundWorker(settings, llm_client_factory=_factory) + await worker.start() + app.state.background_worker = worker + + try: + yield + finally: + await worker.stop() app = FastAPI(title="chat", lifespan=lifespan) diff --git a/chat/services/background.py b/chat/services/background.py new file mode 100644 index 0000000..0ad7f68 --- /dev/null +++ b/chat/services/background.py @@ -0,0 +1,173 @@ +"""Async background worker for post-turn jobs (T22). + +The turn flow records a ``memory_written`` event synchronously on the +request path so the timeline updates immediately. Significance scoring is +a separate classifier round-trip that we don't want to block on, so the +turn handler enqueues a :class:`SignificanceJob` here and the worker +drains the queue out-of-band. + +A single :class:`BackgroundWorker` is started/stopped via FastAPI lifespan +in :mod:`chat.app`. The worker owns its own ``asyncio.Queue`` and runs +exactly one task that pulls jobs off the queue, calls +:func:`chat.services.significance.compute_significance`, and writes +``memory_significance_set`` (and on score 3, ``memory_pin_changed``) +events. Each job opens its own DB connection — workers and request +handlers don't share connections. + +Failures inside ``_process`` are logged and swallowed: a flaky classifier +shouldn't take down the worker. Tests can disable enqueue() by setting +``BackgroundWorker.enabled = False`` (e.g. in the existing turn-flow +fixture, which doesn't have a usable LLM key for the lifespan-managed +factory). +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from typing import Callable + +from chat.config import Settings +from chat.db.connection import open_db +from chat.eventlog.log import append_and_apply +from chat.llm.client import LLMClient +from chat.services.significance import compute_significance + +log = logging.getLogger(__name__) + + +@dataclass +class SignificanceJob: + """One unit of work for the background worker. + + ``host_bot_id`` is the memory's owner — used both for the auto-pin + soft cap query and as the eventual scope for the soft-cap eviction. + """ + + memory_id: int + narrative_text: str + prior_dialogue: list[dict] + host_bot_id: str + + +class BackgroundWorker: + """asyncio.Queue-backed single-worker task. + + Started on app startup; ``stop()`` enqueues a sentinel and awaits the + task so any in-flight job has a chance to finish. Pending jobs after + the sentinel are dropped on shutdown — Phase 1 simplification. + """ + + def __init__( + self, + settings: Settings, + llm_client_factory: Callable[[], LLMClient], + *, + enabled: bool = True, + ) -> None: + self._settings = settings + self._llm_client_factory = llm_client_factory + self._queue: asyncio.Queue[SignificanceJob | None] = asyncio.Queue() + self._task: asyncio.Task | None = None + self.enabled = enabled + + async def start(self) -> None: + if self._task is not None: + return + self._task = asyncio.create_task(self._run()) + + async def stop(self) -> None: + if self._task is None: + return + await self._queue.put(None) # sentinel + await self._task + self._task = None + + def enqueue(self, job: SignificanceJob) -> None: + if not self.enabled: + return + self._queue.put_nowait(job) + + async def _run(self) -> None: + while True: + job = await self._queue.get() + if job is None: + return + try: + await self._process(job) + except Exception as exc: # noqa: BLE001 — worker must not die + log.exception("significance job failed: %s", exc) + + async def _process(self, job: SignificanceJob) -> None: + client = self._llm_client_factory() + score = await compute_significance( + client, + model=self._settings.classifier_model, + narrative_text=job.narrative_text, + prior_dialogue=job.prior_dialogue, + ) + with open_db(self._settings.db_path) as conn: + append_and_apply( + conn, + kind="memory_significance_set", + payload={ + "memory_id": job.memory_id, + "significance": score, + }, + ) + if score >= 3: + _auto_pin_with_cap( + conn, + owner_id=job.host_bot_id, + memory_id=job.memory_id, + ) + + +def _auto_pin_with_cap( + conn, + *, + owner_id: str, + memory_id: int, + cap: int = 8, +) -> None: + """Auto-pin ``memory_id`` and evict the oldest auto-pin if over ``cap``. + + Per §8.5: pivotal turns are auto-pinned, with a soft cap of 8 pins per + bot. When the cap is exceeded the oldest auto-pin is unpinned (manual + pins are never auto-evicted — we filter on ``auto_pinned = 1``). + """ + append_and_apply( + conn, + kind="memory_pin_changed", + payload={ + "memory_id": memory_id, + "pinned": 1, + "auto_pinned": 1, + }, + ) + cur = conn.execute( + "SELECT COUNT(*) FROM memories WHERE owner_id = ? AND pinned = 1", + (owner_id,), + ) + count = cur.fetchone()[0] + if count <= cap: + return + cur = conn.execute( + "SELECT id FROM memories " + "WHERE owner_id = ? AND pinned = 1 AND auto_pinned = 1 AND id != ? " + "ORDER BY created_at ASC, id ASC LIMIT 1", + (owner_id, memory_id), + ) + row = cur.fetchone() + if row is None: + return + append_and_apply( + conn, + kind="memory_pin_changed", + payload={ + "memory_id": row[0], + "pinned": 0, + "auto_pinned": 0, + }, + ) diff --git a/chat/services/memory_write.py b/chat/services/memory_write.py index ee27d3d..3ca40c5 100644 --- a/chat/services/memory_write.py +++ b/chat/services/memory_write.py @@ -32,13 +32,22 @@ def record_turn_memory( chat_clock_at: str | None = None, source: str = "direct", significance: int = 1, -) -> int: +) -> tuple[int, int | None]: """Append a ``memory_written`` event for the host bot's POV of this turn. Uses :func:`chat.eventlog.log.append_and_apply` (not raw :func:`append_event`) so the new memory row is projected immediately without re-running prior non-idempotent handlers (e.g. ``edge_update`` - deltas). Returns the new event id. + deltas). + + Returns ``(event_id, memory_id)``. ``event_id`` is the row id of the + just-appended ``memory_written`` event in ``event_log``. ``memory_id`` + is the autoincrement PK of the corresponding ``memories`` row — these + are *different* numbers (event_log and memories use independent + rowid sequences) so callers needing to update significance or pin + state must use ``memory_id``. Falls back to ``None`` if the projected + row can't be located, which shouldn't happen but keeps the return + shape stable. """ payload: dict = { "owner_id": host_bot_id, @@ -58,4 +67,12 @@ def record_turn_memory( if chat_clock_at is not None: payload["chat_clock_at"] = chat_clock_at - return append_and_apply(conn, kind="memory_written", payload=payload) + event_id = append_and_apply(conn, kind="memory_written", payload=payload) + row = conn.execute( + "SELECT id FROM memories " + "WHERE owner_id = ? AND chat_id = ? " + "ORDER BY id DESC LIMIT 1", + (host_bot_id, chat_id), + ).fetchone() + memory_id = row[0] if row else None + return event_id, memory_id diff --git a/chat/services/significance.py b/chat/services/significance.py new file mode 100644 index 0000000..eb3791c --- /dev/null +++ b/chat/services/significance.py @@ -0,0 +1,75 @@ +"""Turn-level significance scorer (T22). + +Per Requirements §11.1, each turn is scored on a 0-3 scale: + +- 0 = Routine: small talk, ordinary action. +- 1 = Notable: a specific detail or beat worth remembering. +- 2 = Significant: a scene-level moment, real disagreement, confided secret. +- 3 = Pivotal: a relationship-altering event (first kiss, betrayal, "I love + you"). + +The scorer is conservative: pivotal (3) requires a clear signal because the +auto-pin rule (§8.5) gives those memories permanent shelf space. The +classifier returns a strict-JSON ``SignificanceVerdict``; a malformed or +refusal-shaped response falls back to ``score=1`` (Notable) — a safe +middle-of-the-road default that won't trigger auto-pin. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + +from chat.llm.classify import classify +from chat.llm.client import LLMClient + + +class SignificanceVerdict(BaseModel): + score: int = Field(ge=0, le=3) + reason: str = "" + + +_SYSTEM = """You score the significance of a roleplay turn 0-3: +0 = Routine: small talk, ordinary action. +1 = Notable: a specific detail or beat worth remembering. +2 = Significant: a scene-level moment, real disagreement, confided secret. +3 = Pivotal: a relationship-altering event (first kiss, betrayal, "I love you"). + +Be conservative — pivotal (3) requires a clear signal. Reply with JSON: {"score": int 0-3, "reason": str}.""" + + +async def compute_significance( + client: LLMClient, + *, + model: str, + narrative_text: str, + prior_dialogue: list[dict], + timeout_s: float = 10.0, +) -> int: + """Score the significance of ``narrative_text`` (the just-written turn). + + ``prior_dialogue`` is a list of ``{"speaker", "text"}`` dicts ordered + oldest-first; the last 6 entries are stitched into the user prompt as + context so the classifier can recognize escalation. Returns an int in + ``[0, 3]`` — clamped defensively in case the classifier slips a value + past the schema validator. + """ + user_prompt = "PRIOR DIALOGUE:\n" + for turn in prior_dialogue[-6:]: + speaker = turn.get("speaker", "?") + text = turn.get("text", "") + user_prompt += f"{speaker}: {text}\n" + user_prompt += ( + f"\nNEW TURN:\n{narrative_text}\n\n" + "Score the significance of the NEW TURN." + ) + + result = await classify( + client, + model=model, + system=_SYSTEM, + user=user_prompt, + schema=SignificanceVerdict, + default=SignificanceVerdict(score=1, reason="fallback"), + timeout_s=timeout_s, + ) + return max(0, min(3, result.score)) diff --git a/chat/state/memory.py b/chat/state/memory.py index b2655b0..f0420ca 100644 --- a/chat/state/memory.py +++ b/chat/state/memory.py @@ -38,6 +38,35 @@ def _apply_memory_written(conn: Connection, e: Event) -> None: ) +@on("memory_significance_set") +def _apply_memory_significance_set(conn: Connection, e: Event) -> None: + """Update an existing memory's significance score (T22). + + Emitted by the async significance worker after it scores the turn. + """ + p = e.payload + conn.execute( + "UPDATE memories SET significance = ? WHERE id = ?", + (int(p["significance"]), int(p["memory_id"])), + ) + + +@on("memory_pin_changed") +def _apply_memory_pin_changed(conn: Connection, e: Event) -> None: + """Toggle a memory's pin state (T22, §8.5). + + Used both for auto-pinning a pivotal turn and for evicting the oldest + auto-pin when the per-owner soft cap is exceeded. Manual pins use the + same handler; the ``auto_pinned`` flag distinguishes them so the + eviction query can leave manual pins alone. + """ + p = e.payload + conn.execute( + "UPDATE memories SET pinned = ?, auto_pinned = ? WHERE id = ?", + (int(p["pinned"]), int(p["auto_pinned"]), int(p["memory_id"])), + ) + + def get_memory(conn: Connection, memory_id: int) -> dict | None: row = conn.execute( "SELECT * FROM memories WHERE id = ?", (memory_id,) diff --git a/chat/web/turns.py b/chat/web/turns.py index ba678a6..b5cca0b 100644 --- a/chat/web/turns.py +++ b/chat/web/turns.py @@ -39,6 +39,7 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request from fastapi.responses import Response from chat.eventlog.log import append_and_apply, append_event +from chat.services.background import SignificanceJob from chat.services.memory_write import record_turn_memory from chat.services.prompt import assemble_narrative_prompt from chat.services.state_update import compute_state_update @@ -226,7 +227,7 @@ async def post_turn( # narrative text (T27 will rewrite at scene close). Significance # defaults to 1; T22's async classifier pass will overwrite it. scene = active_scene(conn, chat_id) - record_turn_memory( + _event_id, memory_id = record_turn_memory( conn, chat_id=chat_id, host_bot_id=host_bot["id"], @@ -313,6 +314,23 @@ async def post_turn( }, ) + # 6c. Enqueue the async significance pass (Plan §11.1, T22). The + # worker scores the just-written memory 0-3, updates significance, + # and auto-pins on score 3 with the §8.5 soft-cap eviction rule. + # Enqueued before the broadcast so it's outstanding by the time the + # client sees ``turn_html`` — but the worker is async, so the user + # never blocks on it. + worker = getattr(request.app.state, "background_worker", None) + if worker is not None and memory_id is not None: + worker.enqueue( + SignificanceJob( + memory_id=memory_id, + narrative_text=full_text, + prior_dialogue=recent_for_update, + host_bot_id=host_bot["id"], + ) + ) + # 7. Broadcast a JSON completion event (for JS consumers) and an HTML # fragment event (for HTMX SSE swap-into-timeline). await publish( diff --git a/tests/test_memory_write.py b/tests/test_memory_write.py index f25fb4f..feaf6f5 100644 --- a/tests/test_memory_write.py +++ b/tests/test_memory_write.py @@ -64,7 +64,7 @@ def test_record_turn_memory_writes_event_and_projects(tmp_path): apply_migrations(db) _seed_minimal(db) with open_db(db) as conn: - eid = record_turn_memory( + eid, mid = record_turn_memory( conn, chat_id="chat_bot_a", host_bot_id="bot_a", @@ -73,6 +73,7 @@ def test_record_turn_memory_writes_event_and_projects(tmp_path): chat_clock_at="2026-04-26T20:00:00+00:00", ) assert eid > 0 + assert mid is not None and mid > 0 rows = conn.execute( "SELECT id, owner_id, chat_id, pov_summary, " @@ -110,13 +111,14 @@ def test_record_turn_memory_omits_optional_fields(tmp_path): _seed_minimal(db) with open_db(db) as conn: # Call without scene_id/chat_clock_at — should default to None. - eid = record_turn_memory( + eid, mid = record_turn_memory( conn, chat_id="chat_bot_a", host_bot_id="bot_a", narrative_text="A simple memory.", ) assert eid > 0 + assert mid is not None and mid > 0 row = conn.execute( "SELECT scene_id, chat_clock_at, source, reliability, " @@ -168,6 +170,11 @@ def client(tmp_path, monkeypatch): app.dependency_overrides[get_llm_client] = lambda: mock with TestClient(app) as c: + # Disable the lifespan-managed background worker — it would try + # to call Featherless with the test API key. The unit tests in + # test_significance.py exercise the worker directly with a mock + # factory; here we only care about the synchronous turn flow. + app.state.background_worker.enabled = False c.mock_llm = mock # type: ignore[attr-defined] yield c diff --git a/tests/test_significance.py b/tests/test_significance.py new file mode 100644 index 0000000..e5a538c --- /dev/null +++ b/tests/test_significance.py @@ -0,0 +1,237 @@ +"""Async significance pass with auto-pin on score 3 (T22). + +After ``assistant_turn`` lands the turn flow enqueues a SignificanceJob on +a background asyncio worker. The worker calls a classifier (per §11.1, +score 0-3) and writes a ``memory_significance_set`` event. On score 3 the +memory is auto-pinned and a soft cap of 8 pins per owner is enforced — +when the cap is exceeded the oldest auto-pin (excluding the just-pinned +row) is unpinned via another ``memory_pin_changed`` event. +""" + +from __future__ import annotations + +import asyncio +import json + +import pytest + +from chat.config import load_settings +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +from chat.llm.mock import MockLLMClient +from chat.services.background import BackgroundWorker, SignificanceJob +from chat.services.significance import compute_significance + +# Trigger handler registration for projection. +import chat.state.entities # noqa: F401 +import chat.state.memory # noqa: F401 +import chat.state.world # noqa: F401 + + +async def test_compute_significance_parses_score(): + canned = json.dumps({"score": 2, "reason": "notable"}) + mock = MockLLMClient(canned=[canned]) + score = await compute_significance( + mock, + model="x", + narrative_text="...", + prior_dialogue=[], + ) + assert score == 2 + + +async def test_compute_significance_default_on_failure(): + # Both attempts return non-JSON text; the classify wrapper falls back + # to the SignificanceVerdict default (score=1, "fallback"). + mock = MockLLMClient(canned=["nope", "still nope"]) + score = await compute_significance( + mock, + model="x", + narrative_text="...", + prior_dialogue=[], + ) + assert score == 1 + + +async def test_background_worker_processes_job_and_updates_significance( + tmp_path, monkeypatch +): + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + db = tmp_path / "test.db" + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + apply_migrations(db) + settings = load_settings() + + # Seed bot, chat, memory. + with open_db(db) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "...", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_bot_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + append_event( + conn, + kind="memory_written", + payload={ + "owner_id": "bot_a", + "chat_id": "chat_bot_a", + "pov_summary": "Some scene", + "witness_you": 1, + "witness_host": 1, + "witness_guest": 0, + "source": "direct", + "reliability": 1.0, + "significance": 1, + "pinned": 0, + "auto_pinned": 0, + }, + ) + project(conn) + memory_id = conn.execute( + "SELECT id FROM memories WHERE owner_id = 'bot_a'" + ).fetchone()[0] + + # Worker with mock LLM that returns score=3 (pivotal). + canned = [json.dumps({"score": 3, "reason": "pivotal"})] + factory = lambda: MockLLMClient(canned=list(canned)) + worker = BackgroundWorker(settings, llm_client_factory=factory) + await worker.start() + worker.enqueue( + SignificanceJob( + memory_id=memory_id, + narrative_text="...", + prior_dialogue=[], + host_bot_id="bot_a", + ) + ) + # Drain via stop sentinel — guarantees the prior job completed. + await worker.stop() + + # Verify significance updated AND memory auto-pinned. + with open_db(db) as conn: + row = conn.execute( + "SELECT significance, pinned, auto_pinned FROM memories " + "WHERE id = ?", + (memory_id,), + ).fetchone() + assert row[0] == 3 + assert row[1] == 1 # pinned + assert row[2] == 1 # auto_pinned + + +async def test_auto_pin_evicts_oldest_when_over_cap(tmp_path, monkeypatch): + """Pin 9 memories with score 3; verify only 8 are pinned at the end.""" + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + db = tmp_path / "test.db" + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + apply_migrations(db) + settings = load_settings() + + with open_db(db) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "...", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_bot_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + for i in range(9): + append_event( + conn, + kind="memory_written", + payload={ + "owner_id": "bot_a", + "chat_id": "chat_bot_a", + "pov_summary": f"memory {i}", + "witness_you": 1, + "witness_host": 1, + "witness_guest": 0, + "source": "direct", + "reliability": 1.0, + "significance": 1, + "pinned": 0, + "auto_pinned": 0, + }, + ) + project(conn) + memory_ids = [ + r[0] + for r in conn.execute( + "SELECT id FROM memories WHERE owner_id = 'bot_a' ORDER BY id" + ).fetchall() + ] + + # Each job runs through its own MockLLMClient with one canned response. + factory = lambda: MockLLMClient( + canned=[json.dumps({"score": 3, "reason": "pivotal"})] + ) + worker = BackgroundWorker(settings, llm_client_factory=factory) + await worker.start() + for mid in memory_ids: + worker.enqueue( + SignificanceJob( + memory_id=mid, + narrative_text="...", + prior_dialogue=[], + host_bot_id="bot_a", + ) + ) + await worker.stop() + + with open_db(db) as conn: + pinned_count = conn.execute( + "SELECT COUNT(*) FROM memories " + "WHERE owner_id = 'bot_a' AND pinned = 1" + ).fetchone()[0] + assert pinned_count == 8 + + # The oldest should have been evicted. + first_id = memory_ids[0] + first_pinned = conn.execute( + "SELECT pinned FROM memories WHERE id = ?", (first_id,) + ).fetchone()[0] + assert first_pinned == 0 diff --git a/tests/test_turn_flow.py b/tests/test_turn_flow.py index 1f7228b..44c3bfb 100644 --- a/tests/test_turn_flow.py +++ b/tests/test_turn_flow.py @@ -58,6 +58,11 @@ def client(tmp_path, monkeypatch): app.dependency_overrides[get_llm_client] = lambda: mock with TestClient(app) as c: + # Disable the lifespan-managed background worker — it would + # otherwise try to score significance through Featherless with + # a fake test API key. Worker behavior is exercised directly in + # tests/test_significance.py with a mock LLM factory. + app.state.background_worker.enabled = False c.mock_llm = mock # type: ignore[attr-defined] yield c