diff --git a/chat/app.py b/chat/app.py index a2033cf..02b936d 100644 --- a/chat/app.py +++ b/chat/app.py @@ -19,6 +19,7 @@ from chat.web.chat import router as chat_router from chat.web.kickoff import router as kickoff_router from chat.web.nav import router as nav_router from chat.web.settings import router as settings_router +from chat.web.sse import router as sse_router @asynccontextmanager @@ -40,6 +41,7 @@ app.include_router(kickoff_router) app.include_router(settings_router) app.include_router(nav_router) app.include_router(chat_router) +app.include_router(sse_router) @app.get("/health") diff --git a/chat/web/pubsub.py b/chat/web/pubsub.py new file mode 100644 index 0000000..533c975 --- /dev/null +++ b/chat/web/pubsub.py @@ -0,0 +1,60 @@ +"""In-process per-chat broadcast channel. + +Each ``chat_id`` has a list of subscriber ``asyncio.Queue`` instances. T16 +provides only the registry and fan-out mechanism; T19+ will publish events +(turn appends, streamed tokens, drawer updates, scene close, edge updates) +through this channel so all browser tabs viewing a chat stay in sync. + +The registry is process-local: appropriate for a single-user local server. +""" + +from __future__ import annotations + +import asyncio +from collections import defaultdict +from typing import Any + +# {chat_id: [queue, queue, ...]} +_subscribers: dict[str, list[asyncio.Queue]] = defaultdict(list) +_lock = asyncio.Lock() + + +async def subscribe(chat_id: str) -> asyncio.Queue: + """Subscribe to a chat's broadcast channel. + + Returns a fresh ``asyncio.Queue`` that will receive every event published + to ``chat_id`` while the subscription is active. Callers must invoke + :func:`unsubscribe` when finished (typically on client disconnect) to + avoid leaking queues into the registry. + """ + queue: asyncio.Queue = asyncio.Queue() + async with _lock: + _subscribers[chat_id].append(queue) + return queue + + +async def unsubscribe(chat_id: str, queue: asyncio.Queue) -> None: + """Remove ``queue`` from the registry; remove the chat key if empty.""" + async with _lock: + if chat_id in _subscribers: + if queue in _subscribers[chat_id]: + _subscribers[chat_id].remove(queue) + if not _subscribers[chat_id]: + del _subscribers[chat_id] + + +async def publish(chat_id: str, event: dict[str, Any]) -> None: + """Fan-out ``event`` to every subscriber of ``chat_id``. + + The same dict reference is enqueued to all subscribers. Callers should + treat published events as immutable. Queues are unbounded for v1. + """ + async with _lock: + queues = list(_subscribers.get(chat_id, [])) + for q in queues: + await q.put(event) + + +def subscriber_count(chat_id: str) -> int: + """Test helper. Returns the number of active subscribers for a chat.""" + return len(_subscribers.get(chat_id, [])) diff --git a/chat/web/sse.py b/chat/web/sse.py new file mode 100644 index 0000000..4e324c1 --- /dev/null +++ b/chat/web/sse.py @@ -0,0 +1,72 @@ +"""Server-Sent Events endpoint for per-chat live updates. + +Each browser tab on ``/chats/`` opens an SSE connection here. On connect: + +1. We verify the chat exists (404 otherwise). +2. We subscribe to the chat's pub/sub channel. +3. We emit a ``snapshot`` event with the current state. T16 only provides a + stub payload (``{"chat_id": , "ready": true}``) so the client can + confirm the channel is live; T19+ will populate it with real state. +4. We loop, awaiting events from the queue and yielding them as SSE frames. + A 15-second keepalive comment is emitted on idle to defeat intermediary + timeouts. +5. When the client disconnects we unsubscribe so the registry doesn't leak. +""" + +from __future__ import annotations + +import asyncio +import json + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse + +from chat.state.world import get_chat +from chat.web.bots import get_conn +from chat.web.pubsub import subscribe, unsubscribe + +router = APIRouter() + +# Heartbeat cadence. Long enough to avoid chattiness; short enough that most +# HTTP intermediaries won't close an idle connection. +_KEEPALIVE_SECONDS = 15.0 + + +def _format_sse(event: str, data: dict) -> bytes: + """Format a single SSE frame: ``event: \\ndata: \\n\\n``.""" + payload = json.dumps(data) + return f"event: {event}\ndata: {payload}\n\n".encode("utf-8") + + +@router.get("/chats/{chat_id}/events") +async def chat_events(chat_id: str, request: Request, conn=Depends(get_conn)): + chat = get_chat(conn, chat_id) + if chat is None: + raise HTTPException(status_code=404, detail="chat not found") + + async def stream(): + queue = await subscribe(chat_id) + try: + # Initial snapshot — T19 will fill in real state. + yield _format_sse("snapshot", {"chat_id": chat_id, "ready": True}) + while True: + if await request.is_disconnected(): + break + try: + event = await asyncio.wait_for( + queue.get(), timeout=_KEEPALIVE_SECONDS + ) + except asyncio.TimeoutError: + # SSE comment line (per spec, lines starting with ":" are + # ignored by the client) — keeps the connection warm. + yield b": keepalive\n\n" + continue + # Allow publishers to set the SSE event name via "event" key; + # default to "message" if omitted. + event = dict(event) # don't mutate the published dict + kind = event.pop("event", "message") + yield _format_sse(kind, event) + finally: + await unsubscribe(chat_id, queue) + + return StreamingResponse(stream(), media_type="text/event-stream") diff --git a/tests/test_sse.py b/tests/test_sse.py new file mode 100644 index 0000000..188af23 --- /dev/null +++ b/tests/test_sse.py @@ -0,0 +1,174 @@ +"""Tests for the per-chat SSE channel and pub/sub registry (T16). + +Covers: +- 404 from the SSE endpoint when the chat doesn't exist. +- The first event yielded on connect is the ``snapshot`` event with correct + framing (``event: snapshot\\ndata: {...}\\n\\n``) and includes the chat id. +- ``publish`` fans out to every subscriber of a chat. +- Events published to one chat don't leak to subscribers of another chat. + +Notes: +- Starlette's TestClient buffers the entire ASGI response before returning, + so it cannot read from an indefinitely-open SSE stream. The framing test + drives the route handler directly and pulls frames from the + ``StreamingResponse``'s body iterator. The 404 path is still validated via + TestClient since it returns synchronously without entering the stream. +""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path + +import pytest +from fastapi import Request +from fastapi.testclient import TestClient + +from chat.app import app +from chat.db.connection import open_db +from chat.eventlog.log import append_event +from chat.eventlog.projector import project + + +@pytest.fixture +def client(tmp_path, monkeypatch): + cfg = tmp_path / "config.toml" + cfg.write_text('featherless_api_key = "test"\n') + monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg)) + db = tmp_path / "test.db" + monkeypatch.setenv("CHAT_DB_PATH", str(db)) + with TestClient(app) as c: + yield c + + +def _seed_chat( + db_path: Path, bot_id: str = "bot_a", chat_id: str = "chat_bot_a" +) -> None: + with open_db(db_path) as conn: + append_event( + conn, + kind="bot_authored", + payload={ + "id": bot_id, + "name": "BotA", + "persona": "...", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "...", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": chat_id, + "host_bot_id": bot_id, + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + project(conn) + + +def test_sse_endpoint_404_for_missing_chat(client): + response = client.get("/chats/nonexistent/events") + assert response.status_code == 404 + + +async def test_sse_streams_initial_snapshot(client, tmp_path): + """The first SSE frame is a correctly framed ``snapshot`` event.""" + _seed_chat(tmp_path / "test.db") + + # Drive the route directly. The TestClient can't iterate an open stream + # because Starlette's transport waits for the response to complete. + from chat.web.bots import get_conn + from chat.web.sse import chat_events + + disconnect = asyncio.Event() + + async def receive(): + await disconnect.wait() + return {"type": "http.disconnect"} + + scope = { + "type": "http", + "method": "GET", + "path": "/chats/chat_bot_a/events", + "headers": [], + "query_string": b"", + "app": client.app, + } + request = Request(scope, receive=receive) + + conn_gen = get_conn(request) + conn = next(conn_gen) + try: + response = await chat_events("chat_bot_a", request, conn=conn) + assert response.media_type == "text/event-stream" + + body_iter = response.body_iterator + first = await asyncio.wait_for(body_iter.__anext__(), timeout=2.0) + text = first.decode() + + # Framing: starts with the event line, ends with the blank-line + # terminator that delimits SSE frames. + assert text.startswith("event: snapshot\n") + assert "\ndata: " in text + assert text.endswith("\n\n") + + # Payload is JSON containing the chat id. + data_line = next( + line for line in text.splitlines() if line.startswith("data: ") + ) + payload = json.loads(data_line[len("data: "):]) + assert payload["chat_id"] == "chat_bot_a" + + # Cleanly tear down the generator's ``finally`` (unsubscribe). + disconnect.set() + try: + await asyncio.wait_for(body_iter.aclose(), timeout=2.0) + except (StopAsyncIteration, RuntimeError): + pass + finally: + try: + next(conn_gen) + except StopIteration: + pass + + +async def test_pubsub_publish_fanout(): + """Every active subscriber of a chat receives published events.""" + from chat.web.pubsub import publish, subscribe, unsubscribe + + q1 = await subscribe("chat_x") + q2 = await subscribe("chat_x") + try: + await publish("chat_x", {"event": "test", "value": 1}) + e1 = await asyncio.wait_for(q1.get(), timeout=1.0) + e2 = await asyncio.wait_for(q2.get(), timeout=1.0) + assert e1 == {"event": "test", "value": 1} + assert e2 == {"event": "test", "value": 1} + finally: + await unsubscribe("chat_x", q1) + await unsubscribe("chat_x", q2) + + +async def test_pubsub_isolated_per_chat(): + """Events published to one chat don't leak to a different chat.""" + from chat.web.pubsub import publish, subscribe, unsubscribe + + q_a = await subscribe("chat_a") + q_b = await subscribe("chat_b") + try: + await publish("chat_a", {"event": "for_a"}) + e = await asyncio.wait_for(q_a.get(), timeout=1.0) + assert e["event"] == "for_a" + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(q_b.get(), timeout=0.1) + finally: + await unsubscribe("chat_a", q_a) + await unsubscribe("chat_b", q_b)