feat: async significance pass with auto-pin on score 3

This commit is contained in:
Joseph Doherty
2026-04-26 13:27:25 -04:00
parent a45dabb6ae
commit eb4cdf9cbb
9 changed files with 588 additions and 7 deletions
+173
View File
@@ -0,0 +1,173 @@
"""Async background worker for post-turn jobs (T22).
The turn flow records a ``memory_written`` event synchronously on the
request path so the timeline updates immediately. Significance scoring is
a separate classifier round-trip that we don't want to block on, so the
turn handler enqueues a :class:`SignificanceJob` here and the worker
drains the queue out-of-band.
A single :class:`BackgroundWorker` is started/stopped via FastAPI lifespan
in :mod:`chat.app`. The worker owns its own ``asyncio.Queue`` and runs
exactly one task that pulls jobs off the queue, calls
:func:`chat.services.significance.compute_significance`, and writes
``memory_significance_set`` (and on score 3, ``memory_pin_changed``)
events. Each job opens its own DB connection — workers and request
handlers don't share connections.
Failures inside ``_process`` are logged and swallowed: a flaky classifier
shouldn't take down the worker. Tests can disable enqueue() by setting
``BackgroundWorker.enabled = False`` (e.g. in the existing turn-flow
fixture, which doesn't have a usable LLM key for the lifespan-managed
factory).
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Callable
from chat.config import Settings
from chat.db.connection import open_db
from chat.eventlog.log import append_and_apply
from chat.llm.client import LLMClient
from chat.services.significance import compute_significance
log = logging.getLogger(__name__)
@dataclass
class SignificanceJob:
"""One unit of work for the background worker.
``host_bot_id`` is the memory's owner — used both for the auto-pin
soft cap query and as the eventual scope for the soft-cap eviction.
"""
memory_id: int
narrative_text: str
prior_dialogue: list[dict]
host_bot_id: str
class BackgroundWorker:
"""asyncio.Queue-backed single-worker task.
Started on app startup; ``stop()`` enqueues a sentinel and awaits the
task so any in-flight job has a chance to finish. Pending jobs after
the sentinel are dropped on shutdown — Phase 1 simplification.
"""
def __init__(
self,
settings: Settings,
llm_client_factory: Callable[[], LLMClient],
*,
enabled: bool = True,
) -> None:
self._settings = settings
self._llm_client_factory = llm_client_factory
self._queue: asyncio.Queue[SignificanceJob | None] = asyncio.Queue()
self._task: asyncio.Task | None = None
self.enabled = enabled
async def start(self) -> None:
if self._task is not None:
return
self._task = asyncio.create_task(self._run())
async def stop(self) -> None:
if self._task is None:
return
await self._queue.put(None) # sentinel
await self._task
self._task = None
def enqueue(self, job: SignificanceJob) -> None:
if not self.enabled:
return
self._queue.put_nowait(job)
async def _run(self) -> None:
while True:
job = await self._queue.get()
if job is None:
return
try:
await self._process(job)
except Exception as exc: # noqa: BLE001 — worker must not die
log.exception("significance job failed: %s", exc)
async def _process(self, job: SignificanceJob) -> None:
client = self._llm_client_factory()
score = await compute_significance(
client,
model=self._settings.classifier_model,
narrative_text=job.narrative_text,
prior_dialogue=job.prior_dialogue,
)
with open_db(self._settings.db_path) as conn:
append_and_apply(
conn,
kind="memory_significance_set",
payload={
"memory_id": job.memory_id,
"significance": score,
},
)
if score >= 3:
_auto_pin_with_cap(
conn,
owner_id=job.host_bot_id,
memory_id=job.memory_id,
)
def _auto_pin_with_cap(
conn,
*,
owner_id: str,
memory_id: int,
cap: int = 8,
) -> None:
"""Auto-pin ``memory_id`` and evict the oldest auto-pin if over ``cap``.
Per §8.5: pivotal turns are auto-pinned, with a soft cap of 8 pins per
bot. When the cap is exceeded the oldest auto-pin is unpinned (manual
pins are never auto-evicted — we filter on ``auto_pinned = 1``).
"""
append_and_apply(
conn,
kind="memory_pin_changed",
payload={
"memory_id": memory_id,
"pinned": 1,
"auto_pinned": 1,
},
)
cur = conn.execute(
"SELECT COUNT(*) FROM memories WHERE owner_id = ? AND pinned = 1",
(owner_id,),
)
count = cur.fetchone()[0]
if count <= cap:
return
cur = conn.execute(
"SELECT id FROM memories "
"WHERE owner_id = ? AND pinned = 1 AND auto_pinned = 1 AND id != ? "
"ORDER BY created_at ASC, id ASC LIMIT 1",
(owner_id, memory_id),
)
row = cur.fetchone()
if row is None:
return
append_and_apply(
conn,
kind="memory_pin_changed",
payload={
"memory_id": row[0],
"pinned": 0,
"auto_pinned": 0,
},
)
+20 -3
View File
@@ -32,13 +32,22 @@ def record_turn_memory(
chat_clock_at: str | None = None,
source: str = "direct",
significance: int = 1,
) -> int:
) -> tuple[int, int | None]:
"""Append a ``memory_written`` event for the host bot's POV of this turn.
Uses :func:`chat.eventlog.log.append_and_apply` (not raw
:func:`append_event`) so the new memory row is projected immediately
without re-running prior non-idempotent handlers (e.g. ``edge_update``
deltas). Returns the new event id.
deltas).
Returns ``(event_id, memory_id)``. ``event_id`` is the row id of the
just-appended ``memory_written`` event in ``event_log``. ``memory_id``
is the autoincrement PK of the corresponding ``memories`` row — these
are *different* numbers (event_log and memories use independent
rowid sequences) so callers needing to update significance or pin
state must use ``memory_id``. Falls back to ``None`` if the projected
row can't be located, which shouldn't happen but keeps the return
shape stable.
"""
payload: dict = {
"owner_id": host_bot_id,
@@ -58,4 +67,12 @@ def record_turn_memory(
if chat_clock_at is not None:
payload["chat_clock_at"] = chat_clock_at
return append_and_apply(conn, kind="memory_written", payload=payload)
event_id = append_and_apply(conn, kind="memory_written", payload=payload)
row = conn.execute(
"SELECT id FROM memories "
"WHERE owner_id = ? AND chat_id = ? "
"ORDER BY id DESC LIMIT 1",
(host_bot_id, chat_id),
).fetchone()
memory_id = row[0] if row else None
return event_id, memory_id
+75
View File
@@ -0,0 +1,75 @@
"""Turn-level significance scorer (T22).
Per Requirements §11.1, each turn is scored on a 0-3 scale:
- 0 = Routine: small talk, ordinary action.
- 1 = Notable: a specific detail or beat worth remembering.
- 2 = Significant: a scene-level moment, real disagreement, confided secret.
- 3 = Pivotal: a relationship-altering event (first kiss, betrayal, "I love
you").
The scorer is conservative: pivotal (3) requires a clear signal because the
auto-pin rule (§8.5) gives those memories permanent shelf space. The
classifier returns a strict-JSON ``SignificanceVerdict``; a malformed or
refusal-shaped response falls back to ``score=1`` (Notable) — a safe
middle-of-the-road default that won't trigger auto-pin.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
from chat.llm.classify import classify
from chat.llm.client import LLMClient
class SignificanceVerdict(BaseModel):
score: int = Field(ge=0, le=3)
reason: str = ""
_SYSTEM = """You score the significance of a roleplay turn 0-3:
0 = Routine: small talk, ordinary action.
1 = Notable: a specific detail or beat worth remembering.
2 = Significant: a scene-level moment, real disagreement, confided secret.
3 = Pivotal: a relationship-altering event (first kiss, betrayal, "I love you").
Be conservative — pivotal (3) requires a clear signal. Reply with JSON: {"score": int 0-3, "reason": str}."""
async def compute_significance(
client: LLMClient,
*,
model: str,
narrative_text: str,
prior_dialogue: list[dict],
timeout_s: float = 10.0,
) -> int:
"""Score the significance of ``narrative_text`` (the just-written turn).
``prior_dialogue`` is a list of ``{"speaker", "text"}`` dicts ordered
oldest-first; the last 6 entries are stitched into the user prompt as
context so the classifier can recognize escalation. Returns an int in
``[0, 3]`` — clamped defensively in case the classifier slips a value
past the schema validator.
"""
user_prompt = "PRIOR DIALOGUE:\n"
for turn in prior_dialogue[-6:]:
speaker = turn.get("speaker", "?")
text = turn.get("text", "")
user_prompt += f"{speaker}: {text}\n"
user_prompt += (
f"\nNEW TURN:\n{narrative_text}\n\n"
"Score the significance of the NEW TURN."
)
result = await classify(
client,
model=model,
system=_SYSTEM,
user=user_prompt,
schema=SignificanceVerdict,
default=SignificanceVerdict(score=1, reason="fallback"),
timeout_s=timeout_s,
)
return max(0, min(3, result.score))