feat: regenerate registers stream task in _in_flight_tasks (T83.1)
Both the primary stream and the interjection sub-stream in ``regenerate_assistant_turn`` are now wrapped in ``asyncio.create_task`` and registered in the chat-keyed ``_in_flight_tasks`` registry that the ``/turns/cancel`` route reads. Without this, hitting Stop while a regenerate was streaming was a no-op. This mirrors the meanwhile registration pattern in chat/web/meanwhile.py (snapshot-tested by tests/test_meanwhile_turn_flow.py). Test added: test_regenerate_registers_task_in_in_flight_tasks uses a custom MockLLMClient subclass to capture ``"chat_bot_a" in _in_flight_tasks`` at the first stream yield, then asserts that the entry is removed once the regenerate completes (post-flight cleanup).
This commit is contained in:
+57
-27
@@ -68,6 +68,7 @@ Phase 2.5 changes:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from sqlite3 import Connection
|
||||
|
||||
@@ -250,19 +251,37 @@ async def regenerate_assistant_turn(
|
||||
guest_id=guest_bot_id,
|
||||
)
|
||||
|
||||
# 5. Stream the new narrative.
|
||||
# 5. Stream the new narrative. T83.1: register the streaming Task in
|
||||
# the chat-keyed in-flight registry so POST /chats/<id>/turns/cancel
|
||||
# can call ``.cancel()`` on a mid-regenerate stream. We import the
|
||||
# underscore name from turns.py deliberately — same single-process
|
||||
# registry the cancel route reads, mirrors the meanwhile registration
|
||||
# pattern in chat/web/meanwhile.py.
|
||||
from chat.web.turns import _in_flight_tasks # noqa: PLC0415
|
||||
|
||||
accumulated: list[str] = []
|
||||
async for chunk in client.stream(
|
||||
messages,
|
||||
model=settings.narrative_model,
|
||||
max_tokens=settings.narrative_max_tokens,
|
||||
temperature=settings.narrative_temperature,
|
||||
):
|
||||
accumulated.append(chunk)
|
||||
await publish(
|
||||
chat_id,
|
||||
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
|
||||
)
|
||||
|
||||
async def _stream_primary() -> None:
|
||||
async for chunk in client.stream(
|
||||
messages,
|
||||
model=settings.narrative_model,
|
||||
max_tokens=settings.narrative_max_tokens,
|
||||
temperature=settings.narrative_temperature,
|
||||
):
|
||||
accumulated.append(chunk)
|
||||
await publish(
|
||||
chat_id,
|
||||
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
|
||||
)
|
||||
|
||||
stream_task = asyncio.create_task(_stream_primary())
|
||||
_in_flight_tasks[chat_id] = stream_task
|
||||
try:
|
||||
await stream_task
|
||||
finally:
|
||||
# Always unregister so a subsequent turn / regenerate can register
|
||||
# a fresh task. Mirrors the cleanup in turns.py::post_turn.
|
||||
_in_flight_tasks.pop(chat_id, None)
|
||||
new_text = "".join(accumulated)
|
||||
|
||||
# 6. Append the new assistant_turn event. ``user_turn_id`` points at
|
||||
@@ -497,21 +516,32 @@ async def regenerate_assistant_turn(
|
||||
)
|
||||
|
||||
interject_accumulated: list[str] = []
|
||||
async for chunk in client.stream(
|
||||
interject_messages,
|
||||
model=settings.narrative_model,
|
||||
max_tokens=settings.narrative_max_tokens,
|
||||
temperature=settings.narrative_temperature,
|
||||
):
|
||||
interject_accumulated.append(chunk)
|
||||
await publish(
|
||||
chat_id,
|
||||
{
|
||||
"event": "token",
|
||||
"text": chunk,
|
||||
"speaker_id": silent_witness_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def _stream_interjection() -> None:
|
||||
async for chunk in client.stream(
|
||||
interject_messages,
|
||||
model=settings.narrative_model,
|
||||
max_tokens=settings.narrative_max_tokens,
|
||||
temperature=settings.narrative_temperature,
|
||||
):
|
||||
interject_accumulated.append(chunk)
|
||||
await publish(
|
||||
chat_id,
|
||||
{
|
||||
"event": "token",
|
||||
"text": chunk,
|
||||
"speaker_id": silent_witness_id,
|
||||
},
|
||||
)
|
||||
|
||||
# T83.1: register the interjection sub-stream in the same
|
||||
# in-flight registry so /turns/cancel collapses it too.
|
||||
interject_task = asyncio.create_task(_stream_interjection())
|
||||
_in_flight_tasks[chat_id] = interject_task
|
||||
try:
|
||||
await interject_task
|
||||
finally:
|
||||
_in_flight_tasks.pop(chat_id, None)
|
||||
interject_text = "".join(interject_accumulated)
|
||||
|
||||
new_interjection_event_id = append_event(
|
||||
|
||||
Reference in New Issue
Block a user