feat: regenerate registers stream task in _in_flight_tasks (T83.1)
Both the primary stream and the interjection sub-stream in ``regenerate_assistant_turn`` are now wrapped in ``asyncio.create_task`` and registered in the chat-keyed ``_in_flight_tasks`` registry that the ``/turns/cancel`` route reads. Previously neither stream was registered, so hitting Stop during a regenerate was a no-op. This mirrors the meanwhile registration pattern in chat/web/meanwhile.py (snapshot-tested by tests/test_meanwhile_turn_flow.py). Test added: ``test_regenerate_registers_task_in_in_flight_tasks`` uses a custom MockLLMClient subclass to record whether ``"chat_bot_a"`` is in ``_in_flight_tasks`` at the first stream yield, then asserts that the entry has been removed once the regenerate completes.
This commit is contained in:
+57
-27
@@ -68,6 +68,7 @@ Phase 2.5 changes:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from sqlite3 import Connection
|
from sqlite3 import Connection
|
||||||
|
|
||||||
@@ -250,19 +251,37 @@ async def regenerate_assistant_turn(
|
|||||||
guest_id=guest_bot_id,
|
guest_id=guest_bot_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Stream the new narrative.
|
# 5. Stream the new narrative. T83.1: register the streaming Task in
|
||||||
|
# the chat-keyed in-flight registry so POST /chats/<id>/turns/cancel
|
||||||
|
# can call ``.cancel()`` on a mid-regenerate stream. We import the
|
||||||
|
# underscore name from turns.py deliberately — same single-process
|
||||||
|
# registry the cancel route reads, mirrors the meanwhile registration
|
||||||
|
# pattern in chat/web/meanwhile.py.
|
||||||
|
from chat.web.turns import _in_flight_tasks # noqa: PLC0415
|
||||||
|
|
||||||
accumulated: list[str] = []
|
accumulated: list[str] = []
|
||||||
async for chunk in client.stream(
|
|
||||||
messages,
|
async def _stream_primary() -> None:
|
||||||
model=settings.narrative_model,
|
async for chunk in client.stream(
|
||||||
max_tokens=settings.narrative_max_tokens,
|
messages,
|
||||||
temperature=settings.narrative_temperature,
|
model=settings.narrative_model,
|
||||||
):
|
max_tokens=settings.narrative_max_tokens,
|
||||||
accumulated.append(chunk)
|
temperature=settings.narrative_temperature,
|
||||||
await publish(
|
):
|
||||||
chat_id,
|
accumulated.append(chunk)
|
||||||
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
|
await publish(
|
||||||
)
|
chat_id,
|
||||||
|
{"event": "token", "text": chunk, "speaker_id": speaker_bot_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_task = asyncio.create_task(_stream_primary())
|
||||||
|
_in_flight_tasks[chat_id] = stream_task
|
||||||
|
try:
|
||||||
|
await stream_task
|
||||||
|
finally:
|
||||||
|
# Always unregister so a subsequent turn / regenerate can register
|
||||||
|
# a fresh task. Mirrors the cleanup in turns.py::post_turn.
|
||||||
|
_in_flight_tasks.pop(chat_id, None)
|
||||||
new_text = "".join(accumulated)
|
new_text = "".join(accumulated)
|
||||||
|
|
||||||
# 6. Append the new assistant_turn event. ``user_turn_id`` points at
|
# 6. Append the new assistant_turn event. ``user_turn_id`` points at
|
||||||
@@ -497,21 +516,32 @@ async def regenerate_assistant_turn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
interject_accumulated: list[str] = []
|
interject_accumulated: list[str] = []
|
||||||
async for chunk in client.stream(
|
|
||||||
interject_messages,
|
async def _stream_interjection() -> None:
|
||||||
model=settings.narrative_model,
|
async for chunk in client.stream(
|
||||||
max_tokens=settings.narrative_max_tokens,
|
interject_messages,
|
||||||
temperature=settings.narrative_temperature,
|
model=settings.narrative_model,
|
||||||
):
|
max_tokens=settings.narrative_max_tokens,
|
||||||
interject_accumulated.append(chunk)
|
temperature=settings.narrative_temperature,
|
||||||
await publish(
|
):
|
||||||
chat_id,
|
interject_accumulated.append(chunk)
|
||||||
{
|
await publish(
|
||||||
"event": "token",
|
chat_id,
|
||||||
"text": chunk,
|
{
|
||||||
"speaker_id": silent_witness_id,
|
"event": "token",
|
||||||
},
|
"text": chunk,
|
||||||
)
|
"speaker_id": silent_witness_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# T83.1: register the interjection sub-stream in the same
|
||||||
|
# in-flight registry so /turns/cancel collapses it too.
|
||||||
|
interject_task = asyncio.create_task(_stream_interjection())
|
||||||
|
_in_flight_tasks[chat_id] = interject_task
|
||||||
|
try:
|
||||||
|
await interject_task
|
||||||
|
finally:
|
||||||
|
_in_flight_tasks.pop(chat_id, None)
|
||||||
interject_text = "".join(interject_accumulated)
|
interject_text = "".join(interject_accumulated)
|
||||||
|
|
||||||
new_interjection_event_id = append_event(
|
new_interjection_event_id = append_event(
|
||||||
|
|||||||
@@ -662,3 +662,85 @@ def test_regenerate_drops_interjection_when_classifier_returns_false(
|
|||||||
new_primary_payload = json.loads(cur[0][0])
|
new_primary_payload = json.loads(cur[0][0])
|
||||||
assert new_primary_payload["text"] == "New primary text."
|
assert new_primary_payload["text"] == "New primary text."
|
||||||
assert "interjection_of" not in new_primary_payload
|
assert "interjection_of" not in new_primary_payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_registers_task_in_in_flight_tasks(tmp_path, monkeypatch):
    """T83.1: regenerate exposes its streaming Task via ``_in_flight_tasks``.

    A ``MockLLMClient`` subclass records, on the first character it yields
    from ``stream``, whether the chat id is present in ``_in_flight_tasks``
    and which object is stored under it. Once ``regenerate_assistant_turn``
    has returned, the test checks that the recorded object was an
    ``asyncio.Task`` and that the registry entry for the chat is gone
    again, so a later turn / regenerate can register a fresh task.
    """
    import asyncio
    from typing import AsyncIterator, Sequence

    from chat.config import Settings
    from chat.db.migrate import apply_migrations
    from chat.llm.client import Message
    from chat.services.regenerate import regenerate_assistant_turn
    from chat.web.turns import _in_flight_tasks

    chat_key = "chat_bot_a"

    # Point the app at a throwaway config and database under tmp_path.
    database_file = tmp_path / "test.db"
    config_file = tmp_path / "config.toml"
    config_file.write_text('featherless_api_key = "test"\n')
    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))
    monkeypatch.setenv("CHAT_DB_PATH", str(database_file))
    apply_migrations(database_file)

    _, assistant_turn_id = _seed_with_one_turn(database_file)

    # Filled in by the mock while the stream is being consumed.
    observed: dict = {}

    class _RegistryProbe(MockLLMClient):
        """Mock client that inspects the registry as a stream starts."""

        async def stream(
            self, messages: Sequence[Message], *, model: str, **params
        ) -> AsyncIterator[str]:
            reply = self._canned.pop(0)
            for position, char in enumerate(reply):
                if position == 0:
                    # First yield of this stream: record the registry state.
                    observed.update(
                        present=chat_key in _in_flight_tasks,
                        task=_in_flight_tasks.get(chat_key),
                    )
                yield char

    neutral_state = json.dumps(
        {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []}
    )
    probe_client = _RegistryProbe(
        canned=["Refreshed reply.", neutral_state, neutral_state]
    )

    settings = Settings(featherless_api_key="test")

    # Pre-condition: nothing is registered for this chat yet.
    assert chat_key not in _in_flight_tasks

    with open_db(database_file) as conn:
        regenerated_text = asyncio.run(
            regenerate_assistant_turn(
                conn,
                probe_client,
                settings=settings,
                chat_id=chat_key,
                original_assistant_event_id=assistant_turn_id,
            )
        )
    assert regenerated_text == "Refreshed reply."

    # Mid-flight: the chat id was in the registry and mapped to a Task.
    assert observed.get("present") is True, (
        "_in_flight_tasks was empty at first yield — "
        "regenerate stream isn't registering its task"
    )
    assert isinstance(observed.get("task"), asyncio.Task)

    # Post-flight: the registry entry has been removed.
    assert chat_key not in _in_flight_tasks
|
||||||
|
|||||||
Reference in New Issue
Block a user