feat: regenerate registers stream task in _in_flight_tasks (T83.1)

Both the primary stream and the interjection sub-stream in
``regenerate_assistant_turn`` are now wrapped in ``asyncio.create_task``
and registered in the chat-keyed ``_in_flight_tasks`` registry that the
``/turns/cancel`` route reads. Without this, hitting Stop during a
mid-regenerate stream was a no-op.

Mirrors the meanwhile registration pattern in chat/web/meanwhile.py
(snapshot-tested by tests/test_meanwhile_turn_flow.py).

Test added: test_regenerate_registers_task_in_in_flight_tasks captures
``"chat_bot_a" in _in_flight_tasks`` at the first stream yield via a
custom MockLLMClient subclass and asserts post-flight cleanup.
This commit is contained in:
Joseph Doherty
2026-04-26 22:11:23 -04:00
parent 9e7c16de40
commit f2fd30c5a9
2 changed files with 139 additions and 27 deletions
+57 -27
View File
@@ -68,6 +68,7 @@ Phase 2.5 changes:
from __future__ import annotations
import asyncio
import json
from sqlite3 import Connection
@@ -250,19 +251,37 @@ async def regenerate_assistant_turn(
guest_id=guest_bot_id,
)
# 5. Stream the new narrative.
# 5. Stream the new narrative. T83.1: register the streaming Task in
# the chat-keyed in-flight registry so POST /chats/<id>/turns/cancel
# can call ``.cancel()`` on a mid-regenerate stream. We import the
# underscore name from turns.py deliberately — same single-process
# registry the cancel route reads, mirrors the meanwhile registration
# pattern in chat/web/meanwhile.py.
from chat.web.turns import _in_flight_tasks # noqa: PLC0415
accumulated: list[str] = []
async for chunk in client.stream(
messages,
model=settings.narrative_model,
max_tokens=settings.narrative_max_tokens,
temperature=settings.narrative_temperature,
):
accumulated.append(chunk)
await publish(
chat_id,
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
)
async def _stream_primary() -> None:
async for chunk in client.stream(
messages,
model=settings.narrative_model,
max_tokens=settings.narrative_max_tokens,
temperature=settings.narrative_temperature,
):
accumulated.append(chunk)
await publish(
chat_id,
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
)
stream_task = asyncio.create_task(_stream_primary())
_in_flight_tasks[chat_id] = stream_task
try:
await stream_task
finally:
# Always unregister so a subsequent turn / regenerate can register
# a fresh task. Mirrors the cleanup in turns.py::post_turn.
_in_flight_tasks.pop(chat_id, None)
new_text = "".join(accumulated)
# 6. Append the new assistant_turn event. ``user_turn_id`` points at
@@ -497,21 +516,32 @@ async def regenerate_assistant_turn(
)
interject_accumulated: list[str] = []
async for chunk in client.stream(
interject_messages,
model=settings.narrative_model,
max_tokens=settings.narrative_max_tokens,
temperature=settings.narrative_temperature,
):
interject_accumulated.append(chunk)
await publish(
chat_id,
{
"event": "token",
"text": chunk,
"speaker_id": silent_witness_id,
},
)
async def _stream_interjection() -> None:
async for chunk in client.stream(
interject_messages,
model=settings.narrative_model,
max_tokens=settings.narrative_max_tokens,
temperature=settings.narrative_temperature,
):
interject_accumulated.append(chunk)
await publish(
chat_id,
{
"event": "token",
"text": chunk,
"speaker_id": silent_witness_id,
},
)
# T83.1: register the interjection sub-stream in the same
# in-flight registry so /turns/cancel collapses it too.
interject_task = asyncio.create_task(_stream_interjection())
_in_flight_tasks[chat_id] = interject_task
try:
await interject_task
finally:
_in_flight_tasks.pop(chat_id, None)
interject_text = "".join(interject_accumulated)
new_interjection_event_id = append_event(
+82
View File
@@ -662,3 +662,85 @@ def test_regenerate_drops_interjection_when_classifier_returns_false(
new_primary_payload = json.loads(cur[0][0])
assert new_primary_payload["text"] == "New primary text."
assert "interjection_of" not in new_primary_payload
def test_regenerate_registers_task_in_in_flight_tasks(tmp_path, monkeypatch):
    """T83.1: regenerate exposes its streaming Task via ``_in_flight_tasks``.

    While ``regenerate_assistant_turn`` is streaming, the chat id must be a
    key in the chat-keyed ``_in_flight_tasks`` registry and the stored value
    must be an ``asyncio.Task``; once the call returns, the key must be gone
    again so a later turn / regenerate can register a fresh task.

    How it is observed: a ``MockLLMClient`` subclass overrides ``stream``
    and, right before yielding the first character of a canned reply,
    records whether the chat id is in the registry and which object is
    stored under it. The assertions after the run inspect that record.
    Same snapshot approach as tests/test_meanwhile_turn_flow.py.
    """
    import asyncio
    from typing import AsyncIterator, Sequence

    from chat.config import Settings
    from chat.db.migrate import apply_migrations
    from chat.llm.client import Message
    from chat.services.regenerate import regenerate_assistant_turn
    from chat.web.turns import _in_flight_tasks

    chat_key = "chat_bot_a"

    # Point the app at a throwaway config + database under tmp_path.
    database_file = tmp_path / "test.db"
    config_file = tmp_path / "config.toml"
    config_file.write_text('featherless_api_key = "test"\n')
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))
    monkeypatch.setenv("CHAT_DB_PATH", str(database_file))
    apply_migrations(database_file)
    _, assistant_event_id = _seed_with_one_turn(database_file)

    # Filled in by the mock below at the first yielded character.
    observed: dict = {}

    class _SnapshotMock(MockLLMClient):
        async def stream(
            self, messages: Sequence[Message], *, model: str, **params
        ) -> AsyncIterator[str]:
            canned_text = self._canned.pop(0)
            for position, char in enumerate(canned_text):
                if position == 0:
                    # Record the registry state before the first chunk is
                    # handed back to the caller.
                    observed.update(
                        {
                            "present": chat_key in _in_flight_tasks,
                            "task": _in_flight_tasks.get(chat_key),
                        }
                    )
                yield char

    state_reply = json.dumps(
        {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
    )
    llm = _SnapshotMock(canned=["Refreshed reply.", state_reply, state_reply])
    app_settings = Settings(featherless_api_key="test")

    # Pre-condition: nothing registered for this chat yet.
    assert chat_key not in _in_flight_tasks

    with open_db(database_file) as conn:
        regenerated = asyncio.run(
            regenerate_assistant_turn(
                conn,
                llm,
                settings=app_settings,
                chat_id=chat_key,
                original_assistant_event_id=assistant_event_id,
            )
        )

    assert regenerated == "Refreshed reply."

    # Mid-flight: the chat id was registered and mapped to an asyncio.Task.
    assert observed.get("present") is True, (
        "_in_flight_tasks was empty at first yield — regenerate stream "
        "isn't registering its task"
    )
    assert isinstance(observed.get("task"), asyncio.Task)

    # Post-flight: the registry entry has been removed again.
    assert chat_key not in _in_flight_tasks