merge: T83 regenerate.py polish bundle (cancel + DRY + scoped query + warning + ordering)
This commit is contained in:
+181
-110
@@ -68,7 +68,9 @@ Phase 2.5 changes:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from sqlite3 import Connection
|
from sqlite3 import Connection
|
||||||
|
|
||||||
from chat.config import Settings
|
from chat.config import Settings
|
||||||
@@ -79,6 +81,10 @@ from chat.services.interjection import detect_interjection
|
|||||||
from chat.services.memory_write import record_turn_memory_for_present
|
from chat.services.memory_write import record_turn_memory_for_present
|
||||||
from chat.services.multi_state_update import compute_state_updates_for_present
|
from chat.services.multi_state_update import compute_state_updates_for_present
|
||||||
from chat.services.prompt import assemble_narrative_prompt
|
from chat.services.prompt import assemble_narrative_prompt
|
||||||
|
from chat.services.turn_common import (
|
||||||
|
gather_prior_edges,
|
||||||
|
read_recent_dialogue,
|
||||||
|
)
|
||||||
from chat.state.edges import get_edge
|
from chat.state.edges import get_edge
|
||||||
from chat.state.entities import get_bot, get_you
|
from chat.state.entities import get_bot, get_you
|
||||||
from chat.state.events import list_active_events
|
from chat.state.events import list_active_events
|
||||||
@@ -86,6 +92,8 @@ from chat.state.world import active_scene, get_chat
|
|||||||
from chat.web.pubsub import publish
|
from chat.web.pubsub import publish
|
||||||
from chat.web.render import render_turn_html
|
from chat.web.render import render_turn_html
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def regenerate_assistant_turn(
|
async def regenerate_assistant_turn(
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
@@ -104,6 +112,19 @@ async def regenerate_assistant_turn(
|
|||||||
|
|
||||||
Raises :class:`ValueError` when the chat or the assistant_turn event
|
Raises :class:`ValueError` when the chat or the assistant_turn event
|
||||||
cannot be found — the FastAPI route translates this to 404.
|
cannot be found — the FastAPI route translates this to 404.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
**Lifecycle-rollback limitation (T83.4, Phase 4 follow-up).**
|
||||||
|
When the superseded turn already produced lifecycle transitions
|
||||||
|
(``event_started`` / ``event_completed`` / ``event_cancelled``),
|
||||||
|
this function does NOT roll those rows back before re-running
|
||||||
|
``detect_event_transitions`` against the regenerated text. A
|
||||||
|
regenerate-after-completion can therefore double-emit promotion
|
||||||
|
artifacts if the new text re-completes the same event. Phase 3.5
|
||||||
|
only documents the gap and emits a WARNING log naming the
|
||||||
|
affected event_log ids; the actual undo pass is invasive
|
||||||
|
(re-projection / inverse-handler dispatch) and is deferred to
|
||||||
|
Phase 4. See the ``# T83.4`` block below for the warning emit.
|
||||||
"""
|
"""
|
||||||
chat = get_chat(conn, chat_id)
|
chat = get_chat(conn, chat_id)
|
||||||
if chat is None:
|
if chat is None:
|
||||||
@@ -136,6 +157,40 @@ async def regenerate_assistant_turn(
|
|||||||
original_assistant_payload = json.loads(row[0])
|
original_assistant_payload = json.loads(row[0])
|
||||||
original_user_turn_id = original_assistant_payload.get("user_turn_id")
|
original_user_turn_id = original_assistant_payload.get("user_turn_id")
|
||||||
|
|
||||||
|
# T83.4: scan for downstream lifecycle transitions emitted by the
|
||||||
|
# superseded turn — they're not being rolled back (see method
|
||||||
|
# docstring). Heuristic: any ``event_started`` / ``event_completed``
|
||||||
|
# / ``event_cancelled`` event_log row with id strictly greater than
|
||||||
|
# the original assistant_turn's id was emitted as part of (or after)
|
||||||
|
# that turn's processing. Lifecycle events don't carry ``chat_id``
|
||||||
|
# in their payload (their payload references an ``event_id`` FK to
|
||||||
|
# the ``events`` table, which holds chat_id), so we join through
|
||||||
|
# ``events`` to scope to this chat.
|
||||||
|
#
|
||||||
|
# A WARNING log surfaces the affected event ids so operators can
|
||||||
|
# spot double-emit cases until the Phase 4 rollback pass lands.
|
||||||
|
unrolled_lifecycle = conn.execute(
|
||||||
|
"SELECT el.id, el.kind FROM event_log AS el "
|
||||||
|
"JOIN events AS ev "
|
||||||
|
" ON ev.event_id = json_extract(el.payload_json, '$.event_id') "
|
||||||
|
"WHERE el.kind IN ("
|
||||||
|
" 'event_started', 'event_completed', 'event_cancelled'"
|
||||||
|
" ) "
|
||||||
|
" AND ev.chat_id = ? "
|
||||||
|
" AND el.id > ? "
|
||||||
|
"ORDER BY el.id ASC",
|
||||||
|
(chat_id, original_assistant_event_id),
|
||||||
|
).fetchall()
|
||||||
|
if unrolled_lifecycle:
|
||||||
|
_log.warning(
|
||||||
|
"regenerate_assistant_turn: %d lifecycle transition(s) from "
|
||||||
|
"superseded turn %s are NOT being rolled back (Phase 4 "
|
||||||
|
"follow-up). Affected event ids: %s",
|
||||||
|
len(unrolled_lifecycle),
|
||||||
|
original_assistant_event_id,
|
||||||
|
[r[0] for r in unrolled_lifecycle],
|
||||||
|
)
|
||||||
|
|
||||||
# 1a. Look up any sibling interjection beat in the same turn group
|
# 1a. Look up any sibling interjection beat in the same turn group
|
||||||
# (T73.2). The original group is (primary + optional interjection),
|
# (T73.2). The original group is (primary + optional interjection),
|
||||||
# both pinned to the same ``user_turn_id``. The interjection has a
|
# both pinned to the same ``user_turn_id``. The interjection has a
|
||||||
@@ -143,6 +198,13 @@ async def regenerate_assistant_turn(
|
|||||||
# the silent witness (the bot that wasn't the primary addressee).
|
# the silent witness (the bot that wasn't the primary addressee).
|
||||||
# Filter on ``superseded_by IS NULL`` so prior regenerates of this
|
# Filter on ``superseded_by IS NULL`` so prior regenerates of this
|
||||||
# group don't reappear as siblings.
|
# group don't reappear as siblings.
|
||||||
|
#
|
||||||
|
# T83.3: push the chat_id filter into SQL via ``json_extract`` so
|
||||||
|
# the query doesn't scan every assistant_turn row across the whole
|
||||||
|
# database. ``LIMIT 50`` bounds worst-case work even when chat_id
|
||||||
|
# isn't selective (e.g. a single chat with many turns) — we only
|
||||||
|
# need the one matching sibling. Mirrors the SQL pattern in
|
||||||
|
# ``chat.web.meanwhile._last_meanwhile_speaker``.
|
||||||
original_interjection_event_id: int | None = None
|
original_interjection_event_id: int | None = None
|
||||||
original_interjection_payload: dict | None = None
|
original_interjection_payload: dict | None = None
|
||||||
if original_user_turn_id is not None:
|
if original_user_turn_id is not None:
|
||||||
@@ -150,8 +212,11 @@ async def regenerate_assistant_turn(
|
|||||||
"SELECT id, payload_json FROM event_log "
|
"SELECT id, payload_json FROM event_log "
|
||||||
"WHERE kind = 'assistant_turn' "
|
"WHERE kind = 'assistant_turn' "
|
||||||
" AND id != ? "
|
" AND id != ? "
|
||||||
" AND superseded_by IS NULL",
|
" AND superseded_by IS NULL "
|
||||||
(original_assistant_event_id,),
|
" AND json_extract(payload_json, '$.chat_id') = ? "
|
||||||
|
"ORDER BY id DESC "
|
||||||
|
"LIMIT 50",
|
||||||
|
(original_assistant_event_id, chat_id),
|
||||||
)
|
)
|
||||||
for sib_id, sib_payload_json in sibling_cur.fetchall():
|
for sib_id, sib_payload_json in sibling_cur.fetchall():
|
||||||
sib_payload = json.loads(sib_payload_json)
|
sib_payload = json.loads(sib_payload_json)
|
||||||
@@ -208,33 +273,30 @@ async def regenerate_assistant_turn(
|
|||||||
# assistant_turn explicitly (we haven't superseded it yet — that
|
# assistant_turn explicitly (we haven't superseded it yet — that
|
||||||
# update lands at the end so the new event_id is known) and use the
|
# update lands at the end so the new event_id is known) and use the
|
||||||
# standard ``superseded_by IS NULL AND hidden = 0`` filter so any
|
# standard ``superseded_by IS NULL AND hidden = 0`` filter so any
|
||||||
# prior regenerates also drop out.
|
# prior regenerates also drop out. T83.2: shared helper handles the
|
||||||
|
# SQL + filtering; we post-process to map speaker ids to display
|
||||||
|
# names for the prompt.
|
||||||
you_entity = get_you(conn) or {"name": "you", "persona": ""}
|
you_entity = get_you(conn) or {"name": "you", "persona": ""}
|
||||||
you_name = you_entity.get("name", "you")
|
you_name = you_entity.get("name", "you")
|
||||||
cur = conn.execute(
|
raw_recent = read_recent_dialogue(
|
||||||
"SELECT id, kind, payload_json FROM event_log "
|
conn,
|
||||||
"WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') "
|
chat_id,
|
||||||
" AND id != ? "
|
limit=20,
|
||||||
" AND superseded_by IS NULL AND hidden = 0 "
|
exclude_event_id=original_assistant_event_id,
|
||||||
"ORDER BY id DESC LIMIT 20",
|
|
||||||
(original_assistant_event_id,),
|
|
||||||
)
|
)
|
||||||
rows = list(reversed(cur.fetchall()))
|
|
||||||
recent: list[dict] = []
|
recent: list[dict] = []
|
||||||
for _eid, kind, payload_json in rows:
|
for entry in raw_recent:
|
||||||
p = json.loads(payload_json)
|
spk = entry.get("speaker", "bot")
|
||||||
if p.get("chat_id") != chat_id:
|
if spk == "you":
|
||||||
|
recent.append({"speaker": you_name, "text": entry.get("text", "")})
|
||||||
continue
|
continue
|
||||||
if kind in ("user_turn", "user_turn_edit"):
|
if spk == host_bot_id:
|
||||||
recent.append({"speaker": you_name, "text": p.get("prose", "")})
|
|
||||||
else:
|
|
||||||
spk = p.get("speaker_id", "bot")
|
|
||||||
spk_name = host_bot.get("name", "bot")
|
spk_name = host_bot.get("name", "bot")
|
||||||
if spk == host_bot_id:
|
elif guest_bot is not None and spk == guest_bot.get("id"):
|
||||||
spk_name = host_bot.get("name", "bot")
|
spk_name = guest_bot.get("name", "bot")
|
||||||
elif guest_bot is not None and spk == guest_bot.get("id"):
|
else:
|
||||||
spk_name = guest_bot.get("name", "bot")
|
spk_name = host_bot.get("name", "bot")
|
||||||
recent.append({"speaker": spk_name, "text": p.get("text", "")})
|
recent.append({"speaker": spk_name, "text": entry.get("text", "")})
|
||||||
|
|
||||||
# 4. Assemble the narrative prompt. ``recent`` already excludes the
|
# 4. Assemble the narrative prompt. ``recent`` already excludes the
|
||||||
# current user prose, which we pass through ``user_turn_prose``.
|
# current user prose, which we pass through ``user_turn_prose``.
|
||||||
@@ -250,19 +312,37 @@ async def regenerate_assistant_turn(
|
|||||||
guest_id=guest_bot_id,
|
guest_id=guest_bot_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Stream the new narrative.
|
# 5. Stream the new narrative. T83.1: register the streaming Task in
|
||||||
|
# the chat-keyed in-flight registry so POST /chats/<id>/turns/cancel
|
||||||
|
# can call ``.cancel()`` on a mid-regenerate stream. We import the
|
||||||
|
# underscore name from turns.py deliberately — same single-process
|
||||||
|
# registry the cancel route reads, mirrors the meanwhile registration
|
||||||
|
# pattern in chat/web/meanwhile.py.
|
||||||
|
from chat.web.turns import _in_flight_tasks # noqa: PLC0415
|
||||||
|
|
||||||
accumulated: list[str] = []
|
accumulated: list[str] = []
|
||||||
async for chunk in client.stream(
|
|
||||||
messages,
|
async def _stream_primary() -> None:
|
||||||
model=settings.narrative_model,
|
async for chunk in client.stream(
|
||||||
max_tokens=settings.narrative_max_tokens,
|
messages,
|
||||||
temperature=settings.narrative_temperature,
|
model=settings.narrative_model,
|
||||||
):
|
max_tokens=settings.narrative_max_tokens,
|
||||||
accumulated.append(chunk)
|
temperature=settings.narrative_temperature,
|
||||||
await publish(
|
):
|
||||||
chat_id,
|
accumulated.append(chunk)
|
||||||
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
|
await publish(
|
||||||
)
|
chat_id,
|
||||||
|
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_task = asyncio.create_task(_stream_primary())
|
||||||
|
_in_flight_tasks[chat_id] = stream_task
|
||||||
|
try:
|
||||||
|
await stream_task
|
||||||
|
finally:
|
||||||
|
# Always unregister so a subsequent turn / regenerate can register
|
||||||
|
# a fresh task. Mirrors the cleanup in turns.py::post_turn.
|
||||||
|
_in_flight_tasks.pop(chat_id, None)
|
||||||
new_text = "".join(accumulated)
|
new_text = "".join(accumulated)
|
||||||
|
|
||||||
# 6. Append the new assistant_turn event. ``user_turn_id`` points at
|
# 6. Append the new assistant_turn event. ``user_turn_id`` points at
|
||||||
@@ -354,17 +434,8 @@ async def regenerate_assistant_turn(
|
|||||||
present_names[guest_bot_id] = guest_bot.get("name", "bot")
|
present_names[guest_bot_id] = guest_bot.get("name", "bot")
|
||||||
personas[guest_bot_id] = guest_bot.get("persona") or ""
|
personas[guest_bot_id] = guest_bot.get("persona") or ""
|
||||||
|
|
||||||
prior_edges: dict[tuple[str, str], dict] = {}
|
# T83.2: shared helper builds the directed-pair edge dict.
|
||||||
for src in present_ids:
|
prior_edges = gather_prior_edges(conn, present_ids)
|
||||||
for tgt in present_ids:
|
|
||||||
if src == tgt:
|
|
||||||
continue
|
|
||||||
edge = get_edge(conn, src, tgt) or {
|
|
||||||
"affinity": 50,
|
|
||||||
"trust": 50,
|
|
||||||
"summary": "",
|
|
||||||
}
|
|
||||||
prior_edges[(src, tgt)] = edge
|
|
||||||
|
|
||||||
state_updates = await compute_state_updates_for_present(
|
state_updates = await compute_state_updates_for_present(
|
||||||
client,
|
client,
|
||||||
@@ -453,34 +524,27 @@ async def regenerate_assistant_turn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if decision.should_interject:
|
if decision.should_interject:
|
||||||
# Re-read recent so the just-appended primary is in the prompt.
|
# Re-read recent so the just-appended primary is in the
|
||||||
interject_cur = conn.execute(
|
# prompt. T83.2: shared helper + the same id->name mapping
|
||||||
"SELECT id, kind, payload_json FROM event_log "
|
# as the primary read above.
|
||||||
"WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') "
|
raw_interject = read_recent_dialogue(conn, chat_id, limit=20)
|
||||||
" AND superseded_by IS NULL AND hidden = 0 "
|
|
||||||
"ORDER BY id DESC LIMIT 20",
|
|
||||||
)
|
|
||||||
interject_rows = list(reversed(interject_cur.fetchall()))
|
|
||||||
interject_recent: list[dict] = []
|
interject_recent: list[dict] = []
|
||||||
for _eid, kind, payload_json in interject_rows:
|
for entry in raw_interject:
|
||||||
p = json.loads(payload_json)
|
spk = entry.get("speaker", "bot")
|
||||||
if p.get("chat_id") != chat_id:
|
if spk == "you":
|
||||||
|
interject_recent.append(
|
||||||
|
{"speaker": you_name, "text": entry.get("text", "")}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
if kind in ("user_turn", "user_turn_edit"):
|
if spk == host_bot_id:
|
||||||
interject_recent.append(
|
spk_name = host_bot.get("name", "bot")
|
||||||
{"speaker": you_name, "text": p.get("prose", "")}
|
elif spk == guest_bot.get("id"):
|
||||||
)
|
spk_name = guest_bot.get("name", "bot")
|
||||||
else:
|
else:
|
||||||
spk = p.get("speaker_id", "bot")
|
spk_name = "bot"
|
||||||
if spk == host_bot_id:
|
interject_recent.append(
|
||||||
spk_name = host_bot.get("name", "bot")
|
{"speaker": spk_name, "text": entry.get("text", "")}
|
||||||
elif spk == guest_bot.get("id"):
|
)
|
||||||
spk_name = guest_bot.get("name", "bot")
|
|
||||||
else:
|
|
||||||
spk_name = "bot"
|
|
||||||
interject_recent.append(
|
|
||||||
{"speaker": spk_name, "text": p.get("text", "")}
|
|
||||||
)
|
|
||||||
if interject_recent and interject_recent[-1].get("speaker") == you_name:
|
if interject_recent and interject_recent[-1].get("speaker") == you_name:
|
||||||
interject_recent = interject_recent[:-1]
|
interject_recent = interject_recent[:-1]
|
||||||
|
|
||||||
@@ -497,21 +561,32 @@ async def regenerate_assistant_turn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
interject_accumulated: list[str] = []
|
interject_accumulated: list[str] = []
|
||||||
async for chunk in client.stream(
|
|
||||||
interject_messages,
|
async def _stream_interjection() -> None:
|
||||||
model=settings.narrative_model,
|
async for chunk in client.stream(
|
||||||
max_tokens=settings.narrative_max_tokens,
|
interject_messages,
|
||||||
temperature=settings.narrative_temperature,
|
model=settings.narrative_model,
|
||||||
):
|
max_tokens=settings.narrative_max_tokens,
|
||||||
interject_accumulated.append(chunk)
|
temperature=settings.narrative_temperature,
|
||||||
await publish(
|
):
|
||||||
chat_id,
|
interject_accumulated.append(chunk)
|
||||||
{
|
await publish(
|
||||||
"event": "token",
|
chat_id,
|
||||||
"text": chunk,
|
{
|
||||||
"speaker_id": silent_witness_id,
|
"event": "token",
|
||||||
},
|
"text": chunk,
|
||||||
)
|
"speaker_id": silent_witness_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# T83.1: register the interjection sub-stream in the same
|
||||||
|
# in-flight registry so /turns/cancel collapses it too.
|
||||||
|
interject_task = asyncio.create_task(_stream_interjection())
|
||||||
|
_in_flight_tasks[chat_id] = interject_task
|
||||||
|
try:
|
||||||
|
await interject_task
|
||||||
|
finally:
|
||||||
|
_in_flight_tasks.pop(chat_id, None)
|
||||||
interject_text = "".join(interject_accumulated)
|
interject_text = "".join(interject_accumulated)
|
||||||
|
|
||||||
new_interjection_event_id = append_event(
|
new_interjection_event_id = append_event(
|
||||||
@@ -573,17 +648,8 @@ async def regenerate_assistant_turn(
|
|||||||
"text": interject_text,
|
"text": interject_text,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
prior_edges_post: dict[tuple[str, str], dict] = {}
|
# T83.2: shared helper handles the directed-pair edge dict.
|
||||||
for src in present_ids:
|
prior_edges_post = gather_prior_edges(conn, present_ids)
|
||||||
for tgt in present_ids:
|
|
||||||
if src == tgt:
|
|
||||||
continue
|
|
||||||
edge = get_edge(conn, src, tgt) or {
|
|
||||||
"affinity": 50,
|
|
||||||
"trust": 50,
|
|
||||||
"summary": "",
|
|
||||||
}
|
|
||||||
prior_edges_post[(src, tgt)] = edge
|
|
||||||
|
|
||||||
state_updates_post = await compute_state_updates_for_present(
|
state_updates_post = await compute_state_updates_for_present(
|
||||||
client,
|
client,
|
||||||
@@ -620,23 +686,28 @@ async def regenerate_assistant_turn(
|
|||||||
(new_assistant_event_id, original_interjection_event_id),
|
(new_assistant_event_id, original_interjection_event_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 10. Event-lifecycle detection (Phase 3, T61). Mirrors the post_turn
|
# 9a. Event-lifecycle detection (Phase 3, T61). T83.5 cosmetic
|
||||||
# block: classify whether any active events transitioned in the
|
# ordering: mirrors ``chat.web.turns.post_turn``'s 8a block — runs
|
||||||
# regenerated narrative and append the corresponding event_started /
|
# AFTER the interjection branch (and AFTER the post-interjection
|
||||||
|
# state-update + memory passes) so the classifier sees the same
|
||||||
|
# narrative-text input post_turn does. Numbering uses ``9a`` to
|
||||||
|
# match post_turn's ``8a`` shape (the interjection branch is step 9
|
||||||
|
# in regenerate vs step 8 in post_turn; lifecycle is the immediate
|
||||||
|
# follow-on in both). Behaviour identical to the prior ``step 10``
|
||||||
|
# placement — the block was already structurally last in regenerate
|
||||||
|
# because there's no scene-close pass here.
|
||||||
|
#
|
||||||
|
# Classify whether any active events transitioned in the regenerated
|
||||||
|
# narrative and append the corresponding event_started /
|
||||||
# event_completed / event_cancelled. ``promote_completed_event``
|
# event_completed / event_cancelled. ``promote_completed_event``
|
||||||
# runs inline after a completion so promotion artifacts land in the
|
# runs inline after a completion so promotion artifacts land in the
|
||||||
# same regenerate path.
|
# same regenerate path.
|
||||||
#
|
#
|
||||||
# Phase 3.5 follow-up: when a regenerate replaces a turn that had
|
# T83.4 follow-up: when a regenerate replaces a turn that had
|
||||||
# already produced event transitions, those original transitions are
|
# already produced event transitions, those original transitions
|
||||||
# NOT undone here. The superseded ``assistant_turn`` group keeps its
|
# are NOT undone here (Phase 4 work). A WARNING log earlier in this
|
||||||
# prior ``event_started`` / ``event_completed`` events in the log
|
# function names the affected event_log ids — see the T83.4 block
|
||||||
# (they remain projected onto the events table). Phase 3.5 will add
|
# near the function entry.
|
||||||
# an "undo lifecycle" step to roll back the prior transitions before
|
|
||||||
# re-classifying the regenerated text. For v3 we accept that a
|
|
||||||
# regenerate-after-completion will double-emit promotion artifacts
|
|
||||||
# if the new text re-completes the same event — narratively rare,
|
|
||||||
# and a true fix needs the lifecycle-undo pass.
|
|
||||||
new_active_events = list_active_events(conn, chat_id)
|
new_active_events = list_active_events(conn, chat_id)
|
||||||
if new_active_events:
|
if new_active_events:
|
||||||
lifecycle_decision = await detect_event_transitions(
|
lifecycle_decision = await detect_event_transitions(
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
"""Shared helpers for turn flows (T83.2).
|
||||||
|
|
||||||
|
Both ``chat.web.turns.post_turn`` and
|
||||||
|
``chat.services.regenerate.regenerate_assistant_turn`` need to:
|
||||||
|
|
||||||
|
1. Pull a chronological tail of user-side and assistant_turn events for
|
||||||
|
prompt assembly + state-update inputs.
|
||||||
|
2. Build a directed-edge dict over a fixed set of "present" entity ids
|
||||||
|
for the multi-pair state-update pass (with the schema 50/50 default
|
||||||
|
filled in for missing rows).
|
||||||
|
|
||||||
|
Before T83.2 each call site had its own copy of these blocks. The two
|
||||||
|
copies drifted on details (T73.1 added ``user_turn_edit`` handling to
|
||||||
|
turns.py; regenerate.py had a slightly different recent-window query).
|
||||||
|
This module is the single source so a future change to either lands in
|
||||||
|
both flows by construction.
|
||||||
|
|
||||||
|
Note on overlap with ``chat.services.scene_summarize._read_recent_dialogue``:
|
||||||
|
that helper has a ``since_event_id`` clamp (T80.2 thread-detection
|
||||||
|
scope) and intentionally does NOT include ``user_turn_edit`` events —
|
||||||
|
its callers want the *original* prose, not edits. Deduplicating it
|
||||||
|
into here would either (a) require a new flag on the shared helper for
|
||||||
|
``user_turn_edit`` inclusion, or (b) silently change scene_summarize's
|
||||||
|
read shape. Both feel more invasive than the duplication is bad, so
|
||||||
|
that helper is left alone for now.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from sqlite3 import Connection
|
||||||
|
|
||||||
|
from chat.state.edges import get_edge
|
||||||
|
|
||||||
|
|
||||||
|
def read_recent_dialogue(
|
||||||
|
conn: Connection,
|
||||||
|
chat_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
exclude_event_id: int | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Pull the last ``limit`` user-side / assistant_turn events for
|
||||||
|
``chat_id`` as ``[{"speaker": <id-or-"you">, "text": <prose>}]``,
|
||||||
|
chronologically ordered (oldest first).
|
||||||
|
|
||||||
|
Filters: ``superseded_by IS NULL AND hidden = 0`` — regenerated
|
||||||
|
rows drop out so the timeline reflects the current state. Includes
|
||||||
|
``user_turn``, ``user_turn_edit`` (T29 edited prose substitutes for
|
||||||
|
the original — the original is marked superseded above), and
|
||||||
|
``assistant_turn`` rows.
|
||||||
|
|
||||||
|
``exclude_event_id`` is an optional event_log id to skip — used by
|
||||||
|
regenerate to drop the original assistant_turn from its prompt
|
||||||
|
context window before that row has been marked superseded (the
|
||||||
|
supersede UPDATE lands at the end so the new event_id is known).
|
||||||
|
"""
|
||||||
|
if exclude_event_id is None:
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT id, kind, payload_json FROM event_log "
|
||||||
|
"WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') "
|
||||||
|
" AND superseded_by IS NULL AND hidden = 0 "
|
||||||
|
"ORDER BY id DESC LIMIT ?",
|
||||||
|
(limit,),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT id, kind, payload_json FROM event_log "
|
||||||
|
"WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') "
|
||||||
|
" AND id != ? "
|
||||||
|
" AND superseded_by IS NULL AND hidden = 0 "
|
||||||
|
"ORDER BY id DESC LIMIT ?",
|
||||||
|
(exclude_event_id, limit),
|
||||||
|
)
|
||||||
|
rows = list(reversed(cur.fetchall()))
|
||||||
|
out: list[dict] = []
|
||||||
|
for _row_id, kind, payload_json in rows:
|
||||||
|
p = json.loads(payload_json)
|
||||||
|
if p.get("chat_id") != chat_id:
|
||||||
|
continue
|
||||||
|
if kind in ("user_turn", "user_turn_edit"):
|
||||||
|
out.append({"speaker": "you", "text": p.get("prose", "")})
|
||||||
|
else:
|
||||||
|
out.append(
|
||||||
|
{
|
||||||
|
"speaker": p.get("speaker_id", "bot"),
|
||||||
|
"text": p.get("text", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def gather_prior_edges(
|
||||||
|
conn: Connection, present_ids: list[str]
|
||||||
|
) -> dict[tuple[str, str], dict]:
|
||||||
|
"""Build ``{(src, tgt): {affinity, trust, summary}}`` for every
|
||||||
|
directed pair where both ``src`` and ``tgt`` are in ``present_ids``
|
||||||
|
and ``src != tgt``.
|
||||||
|
|
||||||
|
Missing rows fall back to the schema default 50/50 baseline (mirrors
|
||||||
|
the Phase 1 single-pair flow). Used by post_turn and regenerate to
|
||||||
|
seed the multi-pair state-update classifier.
|
||||||
|
"""
|
||||||
|
prior_edges: dict[tuple[str, str], dict] = {}
|
||||||
|
for src in present_ids:
|
||||||
|
for tgt in present_ids:
|
||||||
|
if src == tgt:
|
||||||
|
continue
|
||||||
|
edge = get_edge(conn, src, tgt) or {
|
||||||
|
"affinity": 50,
|
||||||
|
"trust": 50,
|
||||||
|
"summary": "",
|
||||||
|
}
|
||||||
|
prior_edges[(src, tgt)] = edge
|
||||||
|
return prior_edges
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["read_recent_dialogue", "gather_prior_edges"]
|
||||||
+12
-42
@@ -71,6 +71,10 @@ from chat.services.prompt import (
|
|||||||
from chat.services.rewind import compute_rewind_preview, execute_rewind
|
from chat.services.rewind import compute_rewind_preview, execute_rewind
|
||||||
from chat.services.scene_close import detect_scene_close
|
from chat.services.scene_close import detect_scene_close
|
||||||
from chat.services.scene_summarize import apply_scene_close_summary
|
from chat.services.scene_summarize import apply_scene_close_summary
|
||||||
|
from chat.services.turn_common import (
|
||||||
|
gather_prior_edges,
|
||||||
|
read_recent_dialogue,
|
||||||
|
)
|
||||||
from chat.services.turn_parse import ParsedTurn, parse_turn
|
from chat.services.turn_parse import ParsedTurn, parse_turn
|
||||||
from chat.state.edges import get_edge
|
from chat.state.edges import get_edge
|
||||||
from chat.state.entities import get_bot, get_you
|
from chat.state.entities import get_bot, get_you
|
||||||
@@ -113,38 +117,13 @@ def _strip_ooc_for_prompt(parsed: ParsedTurn) -> str:
|
|||||||
def _read_recent_dialogue(conn, chat_id: str, limit: int = 200) -> list[dict]:
|
def _read_recent_dialogue(conn, chat_id: str, limit: int = 200) -> list[dict]:
|
||||||
"""Return user-side and assistant_turn events for ``chat_id``.
|
"""Return user-side and assistant_turn events for ``chat_id``.
|
||||||
|
|
||||||
Includes ``user_turn``, ``user_turn_edit`` (T29 edited prose), and
|
T83.2: thin delegate over
|
||||||
``assistant_turn``. Ordered oldest-first; superseded/hidden rows are
|
:func:`chat.services.turn_common.read_recent_dialogue` so post_turn
|
||||||
skipped so regenerated turns (T29) drop out of the rendered timeline.
|
and regenerate share one implementation. The wrapper survives so
|
||||||
Each entry is shaped ``{"speaker": <id-or-"you">, "text": <prose>}``
|
the chat-detail template and other callers in this module don't all
|
||||||
for the prompt assembler and the chat-detail template.
|
have to update at once.
|
||||||
"""
|
"""
|
||||||
cur = conn.execute(
|
return read_recent_dialogue(conn, chat_id, limit=limit)
|
||||||
"SELECT id, kind, payload_json FROM event_log "
|
|
||||||
"WHERE kind IN ('user_turn', 'user_turn_edit', 'assistant_turn') "
|
|
||||||
" AND superseded_by IS NULL AND hidden = 0 "
|
|
||||||
"ORDER BY id DESC LIMIT ?",
|
|
||||||
(limit,),
|
|
||||||
)
|
|
||||||
rows = cur.fetchall()
|
|
||||||
rows.reverse() # back to chronological order
|
|
||||||
out: list[dict] = []
|
|
||||||
for _row_id, kind, payload_json in rows:
|
|
||||||
p = json.loads(payload_json)
|
|
||||||
if p.get("chat_id") != chat_id:
|
|
||||||
continue
|
|
||||||
if kind in ("user_turn", "user_turn_edit"):
|
|
||||||
# Edited prose substitutes for the original user_turn (the
|
|
||||||
# original is marked superseded_by and filtered above).
|
|
||||||
out.append({"speaker": "you", "text": p.get("prose", "")})
|
|
||||||
else:
|
|
||||||
out.append(
|
|
||||||
{
|
|
||||||
"speaker": p.get("speaker_id", "bot"),
|
|
||||||
"text": p.get("text", ""),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_addressee_id(
|
def _detect_addressee_id(
|
||||||
@@ -211,17 +190,8 @@ def _gather_state_update_inputs(
|
|||||||
present_names[guest_bot["id"]] = guest_bot["name"]
|
present_names[guest_bot["id"]] = guest_bot["name"]
|
||||||
personas[guest_bot["id"]] = guest_bot.get("persona") or ""
|
personas[guest_bot["id"]] = guest_bot.get("persona") or ""
|
||||||
|
|
||||||
prior_edges: dict[tuple[str, str], dict] = {}
|
# T83.2: directed-edge gather is shared with regenerate.py.
|
||||||
for src in present_ids:
|
prior_edges = gather_prior_edges(conn, present_ids)
|
||||||
for tgt in present_ids:
|
|
||||||
if src == tgt:
|
|
||||||
continue
|
|
||||||
edge = get_edge(conn, src, tgt) or {
|
|
||||||
"affinity": 50,
|
|
||||||
"trust": 50,
|
|
||||||
"summary": "",
|
|
||||||
}
|
|
||||||
prior_edges[(src, tgt)] = edge
|
|
||||||
return present_ids, present_names, personas, prior_edges
|
return present_ids, present_names, personas, prior_edges
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -662,3 +662,356 @@ def test_regenerate_drops_interjection_when_classifier_returns_false(
|
|||||||
new_primary_payload = json.loads(cur[0][0])
|
new_primary_payload = json.loads(cur[0][0])
|
||||||
assert new_primary_payload["text"] == "New primary text."
|
assert new_primary_payload["text"] == "New primary text."
|
||||||
assert "interjection_of" not in new_primary_payload
|
assert "interjection_of" not in new_primary_payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_with_prior_lifecycle_logs_warning(tmp_path, monkeypatch, caplog):
|
||||||
|
"""T83.4: when the superseded assistant_turn already produced
|
||||||
|
lifecycle transitions (event_started / event_completed /
|
||||||
|
event_cancelled), regenerate emits a WARNING naming the un-rolled-
|
||||||
|
back transitions. Phase 3.5 documents the gap; the actual rollback
|
||||||
|
is Phase 4 work.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from chat.config import Settings
|
||||||
|
from chat.db.migrate import apply_migrations
|
||||||
|
from chat.eventlog.log import append_and_apply
|
||||||
|
from chat.services.regenerate import regenerate_assistant_turn
|
||||||
|
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
cfg = tmp_path / "config.toml"
|
||||||
|
cfg.write_text('featherless_api_key = "test"\n')
|
||||||
|
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||||
|
monkeypatch.setenv("CHAT_DB_PATH", str(db_path))
|
||||||
|
apply_migrations(db_path)
|
||||||
|
|
||||||
|
_ut_id, at_id = _seed_with_one_turn(db_path)
|
||||||
|
|
||||||
|
# After the assistant_turn lands, simulate that the turn flow
|
||||||
|
# produced an event_completed transition. ``append_and_apply`` is
|
||||||
|
# the standard path so the events projection updates.
|
||||||
|
with open_db(db_path) as conn:
|
||||||
|
append_and_apply(
|
||||||
|
conn,
|
||||||
|
kind="event_planned",
|
||||||
|
payload={
|
||||||
|
"event_id": "evt_x",
|
||||||
|
"chat_id": "chat_bot_a",
|
||||||
|
"kind": "story_event",
|
||||||
|
"props": {},
|
||||||
|
"planned_for": "2026-04-30T18:00:00+00:00",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
append_and_apply(
|
||||||
|
conn,
|
||||||
|
kind="event_started",
|
||||||
|
payload={
|
||||||
|
"event_id": "evt_x",
|
||||||
|
"started_at": "2026-04-30T19:00:00+00:00",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
completed_id = append_and_apply(
|
||||||
|
conn,
|
||||||
|
kind="event_completed",
|
||||||
|
payload={
|
||||||
|
"event_id": "evt_x",
|
||||||
|
"completed_at": "2026-04-30T19:30:00+00:00",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert completed_id is not None
|
||||||
|
|
||||||
|
state_canned = json.dumps(
|
||||||
|
{"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
|
||||||
|
)
|
||||||
|
mock_client = MockLLMClient(
|
||||||
|
canned=["Refreshed reply.", state_canned, state_canned]
|
||||||
|
)
|
||||||
|
settings = Settings(featherless_api_key="test")
|
||||||
|
|
||||||
|
caplog.set_level(logging.WARNING, logger="chat.services.regenerate")
|
||||||
|
|
||||||
|
with open_db(db_path) as conn:
|
||||||
|
asyncio.run(
|
||||||
|
regenerate_assistant_turn(
|
||||||
|
conn,
|
||||||
|
mock_client,
|
||||||
|
settings=settings,
|
||||||
|
chat_id="chat_bot_a",
|
||||||
|
original_assistant_event_id=at_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# The warning records the count and at least one of the affected
|
||||||
|
# event_log ids (event_started + event_completed = at minimum 2).
|
||||||
|
warnings = [
|
||||||
|
r for r in caplog.records if r.levelname == "WARNING"
|
||||||
|
]
|
||||||
|
matching = [w for w in warnings if "lifecycle transition" in w.getMessage()]
|
||||||
|
assert matching, (
|
||||||
|
"expected a WARNING about un-rolled-back lifecycle transitions; "
|
||||||
|
f"got: {[w.getMessage() for w in warnings]}"
|
||||||
|
)
|
||||||
|
msg = matching[0].getMessage()
|
||||||
|
# Reference the original superseded turn's id and the event_completed
|
||||||
|
# row's id.
|
||||||
|
assert str(at_id) in msg
|
||||||
|
assert str(completed_id) in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_sibling_lookup_scoped_to_chat(tmp_path, monkeypatch):
|
||||||
|
"""T83.3: regenerate's sibling-interjection lookup is scoped to the
|
||||||
|
chat being regenerated.
|
||||||
|
|
||||||
|
Setup: TWO chats, each with a primary + interjection turn group whose
|
||||||
|
rows happen to share the same ``user_turn_id`` value (the projector
|
||||||
|
assigns event_log ids monotonically across the whole database, so
|
||||||
|
when each chat is seeded back-to-back the chat A primary lands on a
|
||||||
|
different ``user_turn_id`` than chat B's — but in older versions the
|
||||||
|
sibling query had no chat predicate, so it could in principle latch
|
||||||
|
onto a row from a different chat if ids collided in some unusual
|
||||||
|
flow). We construct the seeding so chat B's interjection has the
|
||||||
|
SAME ``interjection_of`` value as the chat A primary's speaker_id —
|
||||||
|
pre-T83.3 the global query could have picked it up.
|
||||||
|
|
||||||
|
Assert: regenerating the chat A primary leaves chat B's rows
|
||||||
|
untouched (no supersede), and the regenerated chat A turn group's
|
||||||
|
interjection (the only one regenerate should regenerate) has its
|
||||||
|
``regenerated_from`` pointing at the chat A original interjection,
|
||||||
|
not chat B's.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from chat.config import Settings
|
||||||
|
from chat.db.migrate import apply_migrations
|
||||||
|
from chat.services import regenerate as regenerate_module
|
||||||
|
from chat.services.interjection import InterjectionDecision
|
||||||
|
from chat.services.regenerate import regenerate_assistant_turn
|
||||||
|
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
cfg = tmp_path / "config.toml"
|
||||||
|
cfg.write_text('featherless_api_key = "test"\n')
|
||||||
|
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||||
|
monkeypatch.setenv("CHAT_DB_PATH", str(db_path))
|
||||||
|
apply_migrations(db_path)
|
||||||
|
|
||||||
|
# Seed chat A's interjection group.
|
||||||
|
a_ut_id, a_primary_id, a_interjection_id = _seed_with_interjection_group(
|
||||||
|
db_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Seed chat B with the same shape but a different chat_id and bot
|
||||||
|
# ids, then add an interjection group whose ``interjection_of``
|
||||||
|
# points at "bot_a" so a global (unscoped) query could collide.
|
||||||
|
with open_db(db_path) as conn:
|
||||||
|
for bot_id, name in (("bot_c", "BotC"), ("bot_d", "BotD")):
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="bot_authored",
|
||||||
|
payload={
|
||||||
|
"id": bot_id,
|
||||||
|
"name": name,
|
||||||
|
"persona": "",
|
||||||
|
"voice_samples": [],
|
||||||
|
"traits": [],
|
||||||
|
"backstory": "",
|
||||||
|
"initial_relationship_to_you": "",
|
||||||
|
"kickoff_prose": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="chat_created",
|
||||||
|
payload={
|
||||||
|
"id": "chat_other",
|
||||||
|
"host_bot_id": "bot_c",
|
||||||
|
"guest_bot_id": "bot_d",
|
||||||
|
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||||
|
"narrative_anchor": "Day 1",
|
||||||
|
"weather": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
b_ut_id = append_event(
|
||||||
|
conn,
|
||||||
|
kind="user_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_other",
|
||||||
|
"prose": "different chat",
|
||||||
|
"segments": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
b_primary_id = append_event(
|
||||||
|
conn,
|
||||||
|
kind="assistant_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_other",
|
||||||
|
"speaker_id": "bot_c",
|
||||||
|
"text": "Other primary.",
|
||||||
|
"truncated": False,
|
||||||
|
"user_turn_id": b_ut_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# The chat B interjection's ``interjection_of`` references
|
||||||
|
# "bot_a" — the chat A primary's speaker. Pre-T83.3 the global
|
||||||
|
# sibling query could mis-match this row.
|
||||||
|
b_interjection_id = append_event(
|
||||||
|
conn,
|
||||||
|
kind="assistant_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_other",
|
||||||
|
"speaker_id": "bot_d",
|
||||||
|
"text": "Cross-chat noise.",
|
||||||
|
"truncated": False,
|
||||||
|
"user_turn_id": b_ut_id,
|
||||||
|
"interjection_of": "bot_a",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stub the interjection classifier to return True so the regenerate
|
||||||
|
# actively walks the sibling-discovery path.
|
||||||
|
async def _stub_should_interject(*_args, **_kwargs):
|
||||||
|
return InterjectionDecision(should_interject=True, reason="fired")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
regenerate_module, "detect_interjection", _stub_should_interject
|
||||||
|
)
|
||||||
|
|
||||||
|
state_canned = json.dumps(
|
||||||
|
{"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
|
||||||
|
)
|
||||||
|
canned: list[str] = (
|
||||||
|
["New chat A primary."]
|
||||||
|
+ [state_canned] * 6
|
||||||
|
+ ["New chat A interjection."]
|
||||||
|
+ [state_canned] * 6
|
||||||
|
)
|
||||||
|
mock_client = MockLLMClient(canned=list(canned))
|
||||||
|
settings = Settings(featherless_api_key="test")
|
||||||
|
|
||||||
|
with open_db(db_path) as conn:
|
||||||
|
new_text = asyncio.run(
|
||||||
|
regenerate_assistant_turn(
|
||||||
|
conn,
|
||||||
|
mock_client,
|
||||||
|
settings=settings,
|
||||||
|
chat_id="chat_multi",
|
||||||
|
original_assistant_event_id=a_primary_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert new_text == "New chat A primary."
|
||||||
|
|
||||||
|
# Chat B rows are untouched — neither superseded nor referenced.
|
||||||
|
b_primary_super = conn.execute(
|
||||||
|
"SELECT superseded_by FROM event_log WHERE id = ?",
|
||||||
|
(b_primary_id,),
|
||||||
|
).fetchone()[0]
|
||||||
|
b_interjection_super = conn.execute(
|
||||||
|
"SELECT superseded_by FROM event_log WHERE id = ?",
|
||||||
|
(b_interjection_id,),
|
||||||
|
).fetchone()[0]
|
||||||
|
assert b_primary_super is None
|
||||||
|
assert b_interjection_super is None
|
||||||
|
|
||||||
|
# Chat A's regenerated interjection has its ``regenerated_from``
|
||||||
|
# pointing at chat A's original interjection — NOT chat B's.
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT payload_json FROM event_log "
|
||||||
|
"WHERE kind = 'assistant_turn' "
|
||||||
|
" AND id NOT IN (?, ?, ?, ?) "
|
||||||
|
" AND superseded_by IS NULL",
|
||||||
|
(a_primary_id, a_interjection_id, b_primary_id, b_interjection_id),
|
||||||
|
).fetchall()
|
||||||
|
# Two new rows: regenerated primary + regenerated interjection.
|
||||||
|
assert len(cur) == 2
|
||||||
|
payloads = [json.loads(row[0]) for row in cur]
|
||||||
|
# Find the regenerated interjection (carries interjection_of).
|
||||||
|
new_interject_payloads = [
|
||||||
|
p for p in payloads if p.get("interjection_of")
|
||||||
|
]
|
||||||
|
assert len(new_interject_payloads) == 1
|
||||||
|
assert new_interject_payloads[0]["regenerated_from"] == a_interjection_id
|
||||||
|
# Pin chat scope on every new row.
|
||||||
|
for p in payloads:
|
||||||
|
assert p["chat_id"] == "chat_multi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_registers_task_in_in_flight_tasks(tmp_path, monkeypatch):
|
||||||
|
"""T83.1: regenerate's streaming Task is registered in the chat-keyed
|
||||||
|
``_in_flight_tasks`` dict so the /turns/cancel route can cancel a
|
||||||
|
mid-regenerate stream. Mirrors the meanwhile registration pattern
|
||||||
|
pinned by tests/test_meanwhile_turn_flow.py.
|
||||||
|
|
||||||
|
Snapshot pattern: a custom MockLLMClient subclass captures the
|
||||||
|
presence of the chat_id in ``_in_flight_tasks`` at the first stream
|
||||||
|
yield (when the regenerate coroutine is awaiting our generator and
|
||||||
|
the task is alive). Post-flight, the entry must be cleaned up so the
|
||||||
|
next regenerate / turn registers a fresh task.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncIterator, Sequence
|
||||||
|
|
||||||
|
from chat.config import Settings
|
||||||
|
from chat.db.migrate import apply_migrations
|
||||||
|
from chat.llm.client import Message
|
||||||
|
from chat.services.regenerate import regenerate_assistant_turn
|
||||||
|
from chat.web.turns import _in_flight_tasks
|
||||||
|
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
cfg = tmp_path / "config.toml"
|
||||||
|
cfg.write_text('featherless_api_key = "test"\n')
|
||||||
|
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||||
|
monkeypatch.setenv("CHAT_DB_PATH", str(db_path))
|
||||||
|
apply_migrations(db_path)
|
||||||
|
|
||||||
|
_ut_id, at_id = _seed_with_one_turn(db_path)
|
||||||
|
|
||||||
|
in_flight_snapshot: dict = {}
|
||||||
|
|
||||||
|
class _SnapshotMock(MockLLMClient):
|
||||||
|
async def stream(
|
||||||
|
self, messages: Sequence[Message], *, model: str, **params
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
text = self._canned.pop(0)
|
||||||
|
for i, ch in enumerate(text):
|
||||||
|
if i == 0:
|
||||||
|
in_flight_snapshot["present"] = (
|
||||||
|
"chat_bot_a" in _in_flight_tasks
|
||||||
|
)
|
||||||
|
in_flight_snapshot["task"] = _in_flight_tasks.get(
|
||||||
|
"chat_bot_a"
|
||||||
|
)
|
||||||
|
yield ch
|
||||||
|
|
||||||
|
state_canned = json.dumps(
|
||||||
|
{"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
|
||||||
|
)
|
||||||
|
mock_client = _SnapshotMock(
|
||||||
|
canned=["Refreshed reply.", state_canned, state_canned]
|
||||||
|
)
|
||||||
|
|
||||||
|
settings = Settings(featherless_api_key="test")
|
||||||
|
|
||||||
|
# Pre-condition: registry empty for this chat.
|
||||||
|
assert "chat_bot_a" not in _in_flight_tasks
|
||||||
|
|
||||||
|
with open_db(db_path) as conn:
|
||||||
|
new_text = asyncio.run(
|
||||||
|
regenerate_assistant_turn(
|
||||||
|
conn,
|
||||||
|
mock_client,
|
||||||
|
settings=settings,
|
||||||
|
chat_id="chat_bot_a",
|
||||||
|
original_assistant_event_id=at_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert new_text == "Refreshed reply."
|
||||||
|
|
||||||
|
# Mid-flight: the streaming task was present in the registry, and
|
||||||
|
# the captured value was an asyncio.Task.
|
||||||
|
assert in_flight_snapshot.get("present") is True, (
|
||||||
|
"_in_flight_tasks was empty at first yield — regenerate stream "
|
||||||
|
"isn't registering its task"
|
||||||
|
)
|
||||||
|
assert isinstance(in_flight_snapshot.get("task"), asyncio.Task)
|
||||||
|
# Post-flight: the entry has been cleaned up.
|
||||||
|
assert "chat_bot_a" not in _in_flight_tasks
|
||||||
|
|||||||
@@ -0,0 +1,215 @@
|
|||||||
|
"""Shared turn helpers (T83.2).
|
||||||
|
|
||||||
|
``chat.services.turn_common`` extracts two snippets that were duplicated
|
||||||
|
between ``chat.web.turns`` and ``chat.services.regenerate``: the recent
|
||||||
|
user-side / assistant_turn read, and the directed-pair edge gather for
|
||||||
|
the multi-pair state-update pass. These tests pin the helpers' behavior
|
||||||
|
independently of either call site.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from chat.db.connection import open_db
|
||||||
|
from chat.db.migrate import apply_migrations
|
||||||
|
from chat.eventlog.log import append_event
|
||||||
|
from chat.eventlog.projector import project
|
||||||
|
from chat.services.turn_common import gather_prior_edges, read_recent_dialogue
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_basic_chat(db_path):
|
||||||
|
"""Seed bot + chat + a couple of edges + one round of user/assistant
|
||||||
|
turns. Returns ``(user_turn_id, assistant_turn_id)``.
|
||||||
|
"""
|
||||||
|
apply_migrations(db_path)
|
||||||
|
with open_db(db_path) as conn:
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="bot_authored",
|
||||||
|
payload={
|
||||||
|
"id": "bot_a",
|
||||||
|
"name": "BotA",
|
||||||
|
"persona": "thoughtful",
|
||||||
|
"voice_samples": [],
|
||||||
|
"traits": [],
|
||||||
|
"backstory": "",
|
||||||
|
"initial_relationship_to_you": "",
|
||||||
|
"kickoff_prose": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="chat_created",
|
||||||
|
payload={
|
||||||
|
"id": "chat_a",
|
||||||
|
"host_bot_id": "bot_a",
|
||||||
|
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||||
|
"narrative_anchor": "Day 1",
|
||||||
|
"weather": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="edge_update",
|
||||||
|
payload={
|
||||||
|
"source_id": "bot_a",
|
||||||
|
"target_id": "you",
|
||||||
|
"chat_id": "chat_a",
|
||||||
|
"affinity_delta": 7,
|
||||||
|
"trust_delta": 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="edge_update",
|
||||||
|
payload={
|
||||||
|
"source_id": "you",
|
||||||
|
"target_id": "bot_a",
|
||||||
|
"chat_id": "chat_a",
|
||||||
|
"affinity_delta": 2,
|
||||||
|
"trust_delta": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ut_id = append_event(
|
||||||
|
conn,
|
||||||
|
kind="user_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_a",
|
||||||
|
"prose": "hello",
|
||||||
|
"segments": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
at_id = append_event(
|
||||||
|
conn,
|
||||||
|
kind="assistant_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_a",
|
||||||
|
"speaker_id": "bot_a",
|
||||||
|
"text": "Original.",
|
||||||
|
"truncated": False,
|
||||||
|
"user_turn_id": ut_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
project(conn)
|
||||||
|
return ut_id, at_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_recent_dialogue_returns_chronological_pairs(tmp_path):
|
||||||
|
"""``read_recent_dialogue`` returns oldest-first ``{speaker, text}``
|
||||||
|
entries scoped to the requested chat. Speaker is "you" for user-side
|
||||||
|
rows and the assistant_turn's ``speaker_id`` for bot rows.
|
||||||
|
"""
|
||||||
|
db = tmp_path / "test.db"
|
||||||
|
_seed_basic_chat(db)
|
||||||
|
|
||||||
|
with open_db(db) as conn:
|
||||||
|
out = read_recent_dialogue(conn, "chat_a", limit=10)
|
||||||
|
|
||||||
|
assert out == [
|
||||||
|
{"speaker": "you", "text": "hello"},
|
||||||
|
{"speaker": "bot_a", "text": "Original."},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_recent_dialogue_filters_superseded_and_other_chats(tmp_path):
|
||||||
|
"""Superseded rows drop out (regenerate-aware). Rows scoped to a
|
||||||
|
different chat are also filtered. ``exclude_event_id`` excludes a
|
||||||
|
specific row even when it isn't superseded yet (regenerate uses this
|
||||||
|
to drop the original assistant_turn before the supersede UPDATE
|
||||||
|
lands).
|
||||||
|
"""
|
||||||
|
db = tmp_path / "test.db"
|
||||||
|
ut_id, at_id = _seed_basic_chat(db)
|
||||||
|
|
||||||
|
with open_db(db) as conn:
|
||||||
|
# Append a second user/assistant pair.
|
||||||
|
ut_id2 = append_event(
|
||||||
|
conn,
|
||||||
|
kind="user_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_a",
|
||||||
|
"prose": "how are you",
|
||||||
|
"segments": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
at_id2 = append_event(
|
||||||
|
conn,
|
||||||
|
kind="assistant_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "chat_a",
|
||||||
|
"speaker_id": "bot_a",
|
||||||
|
"text": "Second.",
|
||||||
|
"truncated": False,
|
||||||
|
"user_turn_id": ut_id2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# And a row scoped to a different chat — must NOT appear.
|
||||||
|
append_event(
|
||||||
|
conn,
|
||||||
|
kind="user_turn",
|
||||||
|
payload={
|
||||||
|
"chat_id": "other_chat",
|
||||||
|
"prose": "should be filtered",
|
||||||
|
"segments": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Mark the first assistant_turn as superseded — must drop out.
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE event_log SET superseded_by = ? WHERE id = ?",
|
||||||
|
(at_id2, at_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
out = read_recent_dialogue(conn, "chat_a", limit=10)
|
||||||
|
# First (superseded) assistant turn dropped; "other_chat" rows
|
||||||
|
# filtered; first user_turn still present.
|
||||||
|
speakers = [(e["speaker"], e["text"]) for e in out]
|
||||||
|
assert speakers == [
|
||||||
|
("you", "hello"),
|
||||||
|
("you", "how are you"),
|
||||||
|
("bot_a", "Second."),
|
||||||
|
]
|
||||||
|
|
||||||
|
# exclude_event_id drops at_id2 even though it's not superseded.
|
||||||
|
out2 = read_recent_dialogue(
|
||||||
|
conn, "chat_a", limit=10, exclude_event_id=at_id2
|
||||||
|
)
|
||||||
|
speakers2 = [(e["speaker"], e["text"]) for e in out2]
|
||||||
|
assert ("bot_a", "Second.") not in speakers2
|
||||||
|
assert ("you", "how are you") in speakers2
|
||||||
|
|
||||||
|
# Ensure ut_id is still part of the dataset (sanity for the seed).
|
||||||
|
assert ut_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_gather_prior_edges_fills_missing_with_default(tmp_path):
|
||||||
|
"""``gather_prior_edges`` returns one entry per directed pair across
|
||||||
|
``present_ids``. Missing rows fall back to the schema default
|
||||||
|
50/50 baseline; existing rows carry their stored values.
|
||||||
|
"""
|
||||||
|
db = tmp_path / "test.db"
|
||||||
|
_seed_basic_chat(db)
|
||||||
|
|
||||||
|
with open_db(db) as conn:
|
||||||
|
out = gather_prior_edges(conn, ["bot_a", "you"])
|
||||||
|
|
||||||
|
# 2 entities -> 2 directed pairs (a->b and b->a, no self-pairs).
|
||||||
|
assert set(out.keys()) == {("bot_a", "you"), ("you", "bot_a")}
|
||||||
|
bot_to_you = out[("bot_a", "you")]
|
||||||
|
you_to_bot = out[("you", "bot_a")]
|
||||||
|
# Both edges seeded with deltas — they must reflect the projected
|
||||||
|
# affinity/trust (not the default 50/50).
|
||||||
|
assert bot_to_you["affinity"] == 57 # 50 + 7
|
||||||
|
assert bot_to_you["trust"] == 53 # 50 + 3
|
||||||
|
assert you_to_bot["affinity"] == 52
|
||||||
|
assert you_to_bot["trust"] == 51
|
||||||
|
|
||||||
|
# A pair with no row yet falls back to 50/50.
|
||||||
|
with open_db(db) as conn:
|
||||||
|
out_with_missing = gather_prior_edges(
|
||||||
|
conn, ["bot_a", "you", "ghost_bot"]
|
||||||
|
)
|
||||||
|
# 3 entities -> 6 directed pairs.
|
||||||
|
assert len(out_with_missing) == 6
|
||||||
|
fallback = out_with_missing[("bot_a", "ghost_bot")]
|
||||||
|
assert fallback["affinity"] == 50
|
||||||
|
assert fallback["trust"] == 50
|
||||||
|
assert fallback["summary"] == ""
|
||||||
Reference in New Issue
Block a user