diff --git a/chat/services/regenerate.py b/chat/services/regenerate.py index 1317903..cb6b23c 100644 --- a/chat/services/regenerate.py +++ b/chat/services/regenerate.py @@ -68,6 +68,7 @@ Phase 2.5 changes: from __future__ import annotations +import asyncio import json from sqlite3 import Connection @@ -250,19 +251,37 @@ async def regenerate_assistant_turn( guest_id=guest_bot_id, ) - # 5. Stream the new narrative. + # 5. Stream the new narrative. T83.1: register the streaming Task in + # the chat-keyed in-flight registry so POST /chats//turns/cancel + # can call ``.cancel()`` on a mid-regenerate stream. We import the + # underscore name from turns.py deliberately — same single-process + # registry the cancel route reads, mirrors the meanwhile registration + # pattern in chat/web/meanwhile.py. + from chat.web.turns import _in_flight_tasks # noqa: PLC0415 + accumulated: list[str] = [] - async for chunk in client.stream( - messages, - model=settings.narrative_model, - max_tokens=settings.narrative_max_tokens, - temperature=settings.narrative_temperature, - ): - accumulated.append(chunk) - await publish( - chat_id, - {"event": "token", "text": chunk, "speaker_id": speaker_bot_id}, - ) + + async def _stream_primary() -> None: + async for chunk in client.stream( + messages, + model=settings.narrative_model, + max_tokens=settings.narrative_max_tokens, + temperature=settings.narrative_temperature, + ): + accumulated.append(chunk) + await publish( + chat_id, + {"event": "token", "text": chunk, "speaker_id": speaker_bot_id}, + ) + + stream_task = asyncio.create_task(_stream_primary()) + _in_flight_tasks[chat_id] = stream_task + try: + await stream_task + finally: + # Always unregister so a subsequent turn / regenerate can register + # a fresh task. Mirrors the cleanup in turns.py::post_turn. + _in_flight_tasks.pop(chat_id, None) new_text = "".join(accumulated) # 6. Append the new assistant_turn event. ``user_turn_id`` points at @@ -497,21 +516,32 @@ async def regenerate_assistant_turn( ) interject_accumulated: list[str] = [] - async for chunk in client.stream( - interject_messages, - model=settings.narrative_model, - max_tokens=settings.narrative_max_tokens, - temperature=settings.narrative_temperature, - ): - interject_accumulated.append(chunk) - await publish( - chat_id, - { - "event": "token", - "text": chunk, - "speaker_id": silent_witness_id, - }, - ) + + async def _stream_interjection() -> None: + async for chunk in client.stream( + interject_messages, + model=settings.narrative_model, + max_tokens=settings.narrative_max_tokens, + temperature=settings.narrative_temperature, + ): + interject_accumulated.append(chunk) + await publish( + chat_id, + { + "event": "token", + "text": chunk, + "speaker_id": silent_witness_id, + }, + ) + + # T83.1: register the interjection sub-stream in the same + # in-flight registry so /turns/cancel collapses it too. + interject_task = asyncio.create_task(_stream_interjection()) + _in_flight_tasks[chat_id] = interject_task + try: + await interject_task + finally: + _in_flight_tasks.pop(chat_id, None) interject_text = "".join(interject_accumulated) new_interjection_event_id = append_event( diff --git a/tests/test_regenerate.py b/tests/test_regenerate.py index 7fa22bc..f065980 100644 --- a/tests/test_regenerate.py +++ b/tests/test_regenerate.py @@ -662,3 +662,85 @@ def test_regenerate_drops_interjection_when_classifier_returns_false( new_primary_payload = json.loads(cur[0][0]) assert new_primary_payload["text"] == "New primary text." assert "interjection_of" not in new_primary_payload + + +def test_regenerate_registers_task_in_in_flight_tasks(tmp_path, monkeypatch): + """T83.1: regenerate's streaming Task is registered in the chat-keyed + ``_in_flight_tasks`` dict so the /turns/cancel route can cancel a + mid-regenerate stream. Mirrors the meanwhile registration pattern + pinned by tests/test_meanwhile_turn_flow.py. + + Snapshot pattern: a custom MockLLMClient subclass captures the + presence of the chat_id in ``_in_flight_tasks`` at the first stream + yield (when the regenerate coroutine is awaiting our generator and + the task is alive). Post-flight, the entry must be cleaned up so the + next regenerate / turn registers a fresh task. + """ + import asyncio + from typing import AsyncIterator, Sequence + + from chat.config import Settings + from chat.db.migrate import apply_migrations + from chat.llm.client import Message + from chat.services.regenerate import regenerate_assistant_turn + from chat.web.turns import _in_flight_tasks + + db_path = tmp_path / "test.db" + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + monkeypatch.setenv("CHAT_DB_PATH", str(db_path)) + apply_migrations(db_path) + + _ut_id, at_id = _seed_with_one_turn(db_path) + + in_flight_snapshot: dict = {} + + class _SnapshotMock(MockLLMClient): + async def stream( + self, messages: Sequence[Message], *, model: str, **params + ) -> AsyncIterator[str]: + text = self._canned.pop(0) + for i, ch in enumerate(text): + if i == 0: + in_flight_snapshot["present"] = ( + "chat_bot_a" in _in_flight_tasks + ) + in_flight_snapshot["task"] = _in_flight_tasks.get( + "chat_bot_a" + ) + yield ch + + state_canned = json.dumps( + {"affinity_delta": 0, "trust_delta": 0, "knowledge_facts": []} + ) + mock_client = _SnapshotMock( + canned=["Refreshed reply.", state_canned, state_canned] + ) + + settings = Settings(featherless_api_key="test") + + # Pre-condition: registry empty for this chat. + assert "chat_bot_a" not in _in_flight_tasks + + with open_db(db_path) as conn: + new_text = asyncio.run( + regenerate_assistant_turn( + conn, + mock_client, + settings=settings, + chat_id="chat_bot_a", + original_assistant_event_id=at_id, + ) + ) + assert new_text == "Refreshed reply." + + # Mid-flight: the streaming task was present in the registry, and + # the captured value was an asyncio.Task. + assert in_flight_snapshot.get("present") is True, ( + "_in_flight_tasks was empty at first yield — regenerate stream " + "isn't registering its task" + ) + assert isinstance(in_flight_snapshot.get("task"), asyncio.Task) + # Post-flight: the entry has been cleaned up. + assert "chat_bot_a" not in _in_flight_tasks