From e8d24a08753a3b2316b66a21209debd2767285b4 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 13:17:07 -0400 Subject: [PATCH] feat: post-turn state-update pass per present entity --- chat/eventlog/log.py | 39 ++++++ chat/services/state_update.py | 144 +++++++++++++++++++++ chat/web/turns.py | 93 ++++++++++++- tests/test_state_update.py | 237 ++++++++++++++++++++++++++++++++++ tests/test_turn_flow.py | 16 ++- 5 files changed, 523 insertions(+), 6 deletions(-) create mode 100644 chat/services/state_update.py create mode 100644 tests/test_state_update.py diff --git a/chat/eventlog/log.py b/chat/eventlog/log.py index 74919fb..ad228da 100644 --- a/chat/eventlog/log.py +++ b/chat/eventlog/log.py @@ -24,6 +24,45 @@ def append_event(conn: Connection, *, kind: str, payload: dict[str, Any], branch return cur.lastrowid +def append_and_apply( + conn: Connection, + *, + kind: str, + payload: dict[str, Any], + branch_id: int = 1, +) -> int: + """Append an event AND immediately apply just that event's handler. + + Calling :func:`chat.eventlog.projector.project` after an append + re-runs every prior event, which is fine for idempotent inserts but + catastrophic for delta-shaped events like ``edge_update`` whose + handler is *not* replay-safe (each pass would re-add the same + ``affinity_delta``). This helper runs only the brand-new event + through the registered handler, leaving prior state untouched. + + No-ops cleanly when ``kind`` has no registered handler — useful for + transcript-only events like ``user_turn`` / ``assistant_turn`` where + callers may swap ``append_event`` for ``append_and_apply`` without + side effects. + """ + # Local import to avoid a circular dependency at module import: the + # projector imports from .log to define ``Event``. + from chat.eventlog.projector import apply_event + + eid = append_event(conn, kind=kind, payload=payload, branch_id=branch_id) + event = Event( + id=eid, + branch_id=branch_id, + ts="", + kind=kind, + payload=payload, + superseded_by=None, + hidden=False, + ) + apply_event(conn, event) + return eid + + def read_events(conn: Connection, branch_id: int = 1, after_id: int = 0) -> Iterator[Event]: cur = conn.execute( "SELECT id, branch_id, ts, kind, payload_json, superseded_by, hidden " diff --git a/chat/services/state_update.py b/chat/services/state_update.py new file mode 100644 index 0000000..c9408fc --- /dev/null +++ b/chat/services/state_update.py @@ -0,0 +1,144 @@ +"""Post-turn state-update pass. + +Per Requirements §3.4, after every utterance we run a classifier on each +present entity (silent witnesses included) to extract directed-edge +deltas — what changed in *source*'s view of *target*. The classifier +returns three signals: + +- ``affinity_delta`` — signed change in how warmly source feels (typical + range -3..+3; the edge handler clamps the running total to 0..100). +- ``trust_delta`` — signed change in source's trust of target (same + shape). +- ``knowledge_facts`` — concrete things source learned about target + during this exchange. Stored verbatim and appended to ``edge.knowledge``. + +The wrapper deliberately uses :func:`chat.llm.classify.classify` with a +``default=StateUpdate()`` so a flapping classifier never blocks the turn +flow — at worst the edge sits unchanged and the next turn tries again +(§3.3 "graceful degradation" rule). +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + +from chat.llm.classify import classify +from chat.llm.client import LLMClient + + +class StateUpdate(BaseModel): + """One directed-edge update from a single classifier call. + + Defaults are deliberately a no-op (zero deltas, empty facts) so a + failing classifier produces a benign event rather than a disruption. + """ + + affinity_delta: int = 0 + trust_delta: int = 0 + knowledge_facts: list[str] = Field(default_factory=list) + + +_SYSTEM_PROMPT = ( + "You are reading a recent slice of dialogue from a roleplay scene. " + "You assess how SOURCE's view of TARGET shifted based on what was " + "said — including silent witnessing (SOURCE may not have spoken).\n\n" + "Output a JSON object with exactly three fields:\n" + "- affinity_delta: signed integer in [-3, 3]. How much warmer " + "(positive) or cooler (negative) SOURCE now feels toward TARGET.\n" + "- trust_delta: signed integer in [-3, 3]. How much more (positive) " + "or less (negative) SOURCE now trusts TARGET.\n" + "- knowledge_facts: list of short strings. New, concrete facts " + "SOURCE learned about TARGET in this exchange. Use TARGET's actual " + "stated content; do not infer or interpret. Empty list is fine.\n\n" + "Be conservative. Most turns produce small deltas (-1, 0, +1). " + "Reserve +/-2 or +/-3 for moments that materially shift the " + "relationship. Knowledge_facts should be specific things stated in " + "dialogue (e.g. \"works at the bakery\"), not interpretations " + "(\"seems lonely\")." +) + + +def _format_dialogue(recent_dialogue: list[dict]) -> str: + """Render the recent-dialogue slice as plain ``Speaker: text`` lines.""" + if not recent_dialogue: + return "(no dialogue yet)" + lines = [] + for turn in recent_dialogue: + speaker = turn.get("speaker", "?") + text = turn.get("text", "") + lines.append(f"{speaker}: {text}") + return "\n".join(lines) + + +def _build_user_prompt( + *, + source_name: str, + source_persona: str, + target_name: str, + prior_affinity: int, + prior_trust: int, + prior_summary: str, + recent_dialogue: list[dict], +) -> str: + return ( + f"SOURCE: {source_name}\n" + f"SOURCE_PERSONA: {source_persona or '(none)'}\n" + f"TARGET: {target_name}\n" + f"PRIOR_AFFINITY (0-100): {prior_affinity}\n" + f"PRIOR_TRUST (0-100): {prior_trust}\n" + f"PRIOR_SUMMARY: {prior_summary or '(none)'}\n\n" + f"RECENT_DIALOGUE:\n{_format_dialogue(recent_dialogue)}\n\n" + "How did SOURCE's view of TARGET shift? Respond with JSON only." + ) + + +async def compute_state_update( + client: LLMClient, + *, + model: str, + source_id: str, + target_id: str, + source_name: str, + source_persona: str, + target_name: str, + prior_affinity: int, + prior_trust: int, + prior_summary: str, + recent_dialogue: list[dict], + timeout_s: float = 10.0, +) -> StateUpdate: + """Run a classifier pass and return the directed-edge update. + + On classifier failure (after retry) returns the schema default — a + no-op ``StateUpdate`` — so the turn flow can keep moving. The + ``source_id`` / ``target_id`` arguments are accepted for symmetry + with the caller (T20's POST flow uses them when emitting the + ``edge_update`` event); they're not currently embedded in the + prompt because the classifier reasons about names, not opaque ids. + """ + # ``source_id``/``target_id`` are kept on the signature even though + # the prompt only quotes the names: callers in turns.py thread the + # ids straight from this function's args into the appended event. + del source_id, target_id # silence unused-arg lint cleanly + + user_prompt = _build_user_prompt( + source_name=source_name, + source_persona=source_persona, + target_name=target_name, + prior_affinity=prior_affinity, + prior_trust=prior_trust, + prior_summary=prior_summary, + recent_dialogue=recent_dialogue, + ) + + return await classify( + client, + model=model, + system=_SYSTEM_PROMPT, + user=user_prompt, + schema=StateUpdate, + default=StateUpdate(), + timeout_s=timeout_s, + ) + + diff --git a/chat/web/turns.py b/chat/web/turns.py index e2c0082..5128ce3 100644 --- a/chat/web/turns.py +++ b/chat/web/turns.py @@ -15,9 +15,12 @@ The turn flow strings together the pieces built in T17 (turn parser), T18 channel as a ``token`` event so any subscribed browser tab sees them arrive in real time. 6. On stream complete, append an ``assistant_turn`` event with the full - text and ``truncated=False``. Also publish a ``turn_html`` event with a - ready-to-swap HTML fragment so HTMX's SSE extension can append it to - the timeline without a page reload. + text and ``truncated=False``. Then run a post-turn state-update pass + (Requirements §3.4): one classifier call per directed edge between + present entities, each producing an ``edge_update`` event with + affinity/trust/knowledge deltas. Finally publish a ``turn_html`` + event with a ready-to-swap HTML fragment so HTMX's SSE extension can + append it to the timeline without a page reload. 7. Return ``204 No Content`` — the SSE channel is the real conveyor of state, not the POST response body. @@ -35,11 +38,13 @@ import json from fastapi import APIRouter, Depends, Form, HTTPException, Request from fastapi.responses import Response -from chat.eventlog.log import append_event +from chat.eventlog.log import append_and_apply, append_event from chat.services.prompt import assemble_narrative_prompt +from chat.services.state_update import compute_state_update from chat.services.turn_parse import ParsedTurn, parse_turn +from chat.state.edges import get_edge +from chat.state.entities import get_bot, get_you from chat.state.world import get_chat -from chat.state.entities import get_bot from chat.web.bots import get_conn from chat.web.kickoff import get_llm_client from chat.web.pubsub import publish @@ -214,6 +219,84 @@ async def post_turn( }, ) + # 6b. Post-turn state-update pass (Requirements §3.4). For Phase 1 + # the only present entities are ``you`` and ``host_bot`` so we run + # two classifier calls — one per directed edge — and append the + # resulting ``edge_update`` events. The recent-dialogue slice is + # re-read here so the pass sees the just-appended assistant turn. + # We use ``append_and_apply`` (vs append + project) because the + # edge_update handler is *not* replay-safe: re-projecting prior + # events would re-apply their deltas on top of the live row. + recent_for_update = _read_recent_dialogue(conn, chat_id, limit=10) + you_entity = get_you(conn) or {"name": "you", "persona": ""} + last_at = chat.get("time") + + edge_b2y = get_edge(conn, host_bot["id"], "you") or { + "affinity": 50, + "trust": 50, + "summary": "", + } + update_b2y = await compute_state_update( + client, + model=settings.classifier_model, + source_id=host_bot["id"], + target_id="you", + source_name=host_bot["name"], + source_persona=host_bot.get("persona", ""), + target_name=you_entity.get("name", "you"), + prior_affinity=edge_b2y["affinity"], + prior_trust=edge_b2y["trust"], + prior_summary=edge_b2y.get("summary", "") or "", + recent_dialogue=recent_for_update, + ) + append_and_apply( + conn, + kind="edge_update", + payload={ + "source_id": host_bot["id"], + "target_id": "you", + "chat_id": chat_id, + "affinity_delta": update_b2y.affinity_delta, + "trust_delta": update_b2y.trust_delta, + "knowledge_facts": update_b2y.knowledge_facts, + "last_interaction_at": last_at, + "last_interaction_chat_id": chat_id, + }, + ) + + edge_y2b = get_edge(conn, "you", host_bot["id"]) or { + "affinity": 50, + "trust": 50, + "summary": "", + } + update_y2b = await compute_state_update( + client, + model=settings.classifier_model, + source_id="you", + target_id=host_bot["id"], + source_name=you_entity.get("name", "you"), + source_persona=you_entity.get("persona", "") or "", + target_name=host_bot["name"], + prior_affinity=edge_y2b["affinity"], + prior_trust=edge_y2b["trust"], + prior_summary=edge_y2b.get("summary", "") or "", + recent_dialogue=recent_for_update, + ) + append_and_apply( + conn, + kind="edge_update", + payload={ + "source_id": "you", + "target_id": host_bot["id"], + "chat_id": chat_id, + "affinity_delta": update_y2b.affinity_delta, + "trust_delta": update_y2b.trust_delta, + "knowledge_facts": update_y2b.knowledge_facts, + "last_interaction_at": last_at, + "last_interaction_chat_id": chat_id, + }, + ) + # 7. Broadcast a JSON completion event (for JS consumers) and an HTML # fragment event (for HTMX SSE swap-into-timeline). await publish( diff --git a/tests/test_state_update.py b/tests/test_state_update.py new file mode 100644 index 0000000..e7c19b3 --- /dev/null +++ b/tests/test_state_update.py @@ -0,0 +1,237 @@ +"""Post-turn state-update pass (T20). + +Per Requirements §3.4, after each utterance we run a classifier on every +present entity (silent witnesses too) to extract directed-edge deltas +(``affinity_delta``, ``trust_delta``, ``knowledge_facts``). The deltas +land as ``edge_update`` events and project into the ``edges`` table. + +These tests cover: +- The unit-level :func:`compute_state_update` happy path: classifier + returns valid JSON, the wrapper returns a populated ``StateUpdate``. +- The unit-level fallback path: classifier fails twice, the wrapper + returns a no-op ``StateUpdate`` (zeros + empty facts) per §3.3. +- The integration path: a successful POST appends two ``edge_update`` + events (one per direction) after the ``assistant_turn`` and the edge + projections reflect the deltas. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from chat.app import app +from chat.db.connection import open_db +from chat.eventlog.log import append_event +from chat.eventlog.projector import project +from chat.llm.mock import MockLLMClient +from chat.services.state_update import StateUpdate, compute_state_update + + +@pytest.mark.asyncio +async def test_compute_state_update_parses_classifier_output(): + canned = json.dumps( + {"affinity_delta": 2, "trust_delta": 1, "knowledge_facts": ["likes coffee"]} + ) + mock = MockLLMClient(canned=[canned]) + result = await compute_state_update( + mock, + model="x", + source_id="bot_a", + target_id="you", + source_name="BotA", + source_persona="thoughtful", + target_name="Me", + prior_affinity=50, + prior_trust=50, + prior_summary="", + recent_dialogue=[ + {"speaker": "you", "text": "hi"}, + {"speaker": "BotA", "text": "Hello!"}, + ], + ) + assert isinstance(result, StateUpdate) + assert result.affinity_delta == 2 + assert result.trust_delta == 1 + assert result.knowledge_facts == ["likes coffee"] + + +@pytest.mark.asyncio +async def test_compute_state_update_returns_default_on_failure(): + """Two malformed classifier responses -> default StateUpdate (zeros).""" + mock = MockLLMClient(canned=["nope", "still nope"]) + result = await compute_state_update( + mock, + model="x", + source_id="bot_a", + target_id="you", + source_name="BotA", + source_persona="", + target_name="Me", + prior_affinity=50, + prior_trust=50, + prior_summary="", + recent_dialogue=[], + ) + assert result.affinity_delta == 0 + assert result.trust_delta == 0 + assert result.knowledge_facts == [] + + +# --- integration test -------------------------------------------------------- + + +@pytest.fixture +def client(tmp_path, monkeypatch): + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + db = tmp_path / "test.db" + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + + canned_parse = json.dumps( + {"segments": [{"kind": "dialogue", "text": "hello"}]} + ) + canned_response = "Hi there." + canned_state_b2y = json.dumps( + {"affinity_delta": 2, "trust_delta": 1, "knowledge_facts": ["greets warmly"]} + ) + canned_state_y2b = json.dumps( + {"affinity_delta": 3, "trust_delta": 0, "knowledge_facts": []} + ) + + from chat.web.kickoff import get_llm_client + + mock = MockLLMClient( + canned=[canned_parse, canned_response, canned_state_b2y, canned_state_y2b] + ) + app.dependency_overrides[get_llm_client] = lambda: mock + + with TestClient(app) as c: + c.mock_llm = mock # type: ignore[attr-defined] + yield c + + app.dependency_overrides.clear() + + +def _seed(db_path: Path) -> None: + """Author a bot, create a chat, and seed enough state for prompt assembly.""" + with open_db(db_path) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": "bot_a", + "name": "BotA", + "persona": "thoughtful, observant", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "...", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_bot_a", + "host_bot_id": "bot_a", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + append_event( + conn, + kind="edge_update", + payload={ + "source_id": "bot_a", + "target_id": "you", + "chat_id": "chat_bot_a", + "knowledge_facts": ["coworker"], + }, + ) + append_event( + conn, + kind="activity_change", + payload={ + "entity_id": "you", + "posture": "sitting", + "action": { + "verb": "talking", + "interruptible": True, + "required_attention": "low", + "expected_duration": "ongoing", + }, + "attention": "", + "holding": [], + "status": {}, + }, + ) + append_event( + conn, + kind="activity_change", + payload={ + "entity_id": "bot_a", + "posture": "sitting", + "action": { + "verb": "listening", + "interruptible": True, + "required_attention": "low", + "expected_duration": "ongoing", + }, + "attention": "", + "holding": [], + "status": {}, + }, + ) + project(conn) + + +def test_post_turn_appends_edge_updates_and_applies_deltas(client, tmp_path): + """After a turn, edge_update events fire for both directions and project.""" + db_path = tmp_path / "test.db" + _seed(db_path) + + response = client.post("/chats/chat_bot_a/turns", data={"prose": "hello"}) + assert response.status_code == 204 + + with open_db(db_path) as conn: + # Two new edge_update events should land *after* the assistant_turn. + cur = conn.execute( + "SELECT kind, payload_json FROM event_log " + "WHERE kind = 'edge_update' " + "AND id > (SELECT MAX(id) FROM event_log WHERE kind = 'assistant_turn') " + "ORDER BY id" + ) + rows = cur.fetchall() + assert len(rows) == 2 + kinds = [r[0] for r in rows] + assert kinds == ["edge_update", "edge_update"] + + # Inspect the two payloads — one per direction. + payloads = [json.loads(r[1]) for r in rows] + directions = {(p["source_id"], p["target_id"]) for p in payloads} + assert ("bot_a", "you") in directions + assert ("you", "bot_a") in directions + + # Edge bot_a -> you: seeded affinity=50, plus delta 2 -> 52. + from chat.state.edges import get_edge + + edge_b2y = get_edge(conn, "bot_a", "you") + assert edge_b2y is not None + assert edge_b2y["affinity"] == 52 + assert edge_b2y["trust"] == 51 + # Existing fact preserved, new fact appended. + assert "coworker" in edge_b2y["knowledge"] + assert "greets warmly" in edge_b2y["knowledge"] + + # Edge you -> bot_a: defaults (50/50) plus delta +3 affinity -> 53. + edge_y2b = get_edge(conn, "you", "bot_a") + assert edge_y2b is not None + assert edge_y2b["affinity"] == 53 + assert edge_y2b["trust"] == 50 diff --git a/tests/test_turn_flow.py b/tests/test_turn_flow.py index b404fbb..1f7228b 100644 --- a/tests/test_turn_flow.py +++ b/tests/test_turn_flow.py @@ -36,11 +36,25 @@ def client(tmp_path, monkeypatch): {"segments": [{"kind": "dialogue", "text": "hello"}]} ) canned_response = "Hi there." + # Two state-update classifier calls fire after the assistant_turn + # (one per directed edge: bot->you, you->bot). We feed them benign + # zero-delta JSON so the existing assertions about ``user_turn`` / + # ``assistant_turn`` are unaffected. + canned_state_update = json.dumps( + {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []} + ) # Import here so env vars are visible to the dependency lookup. from chat.web.kickoff import get_llm_client - mock = MockLLMClient(canned=[canned_parse, canned_response]) + mock = MockLLMClient( + canned=[ + canned_parse, + canned_response, + canned_state_update, + canned_state_update, + ] + ) app.dependency_overrides[get_llm_client] = lambda: mock with TestClient(app) as c: