feat: per-chat SSE channel and pub/sub

This commit is contained in:
Joseph Doherty
2026-04-26 12:49:41 -04:00
parent e79f4d8d22
commit 656c2558cb
4 changed files with 308 additions and 0 deletions
+2
View File
@@ -19,6 +19,7 @@ from chat.web.chat import router as chat_router
from chat.web.kickoff import router as kickoff_router
from chat.web.nav import router as nav_router
from chat.web.settings import router as settings_router
from chat.web.sse import router as sse_router
@asynccontextmanager
@@ -40,6 +41,7 @@ app.include_router(kickoff_router)
app.include_router(settings_router)
app.include_router(nav_router)
app.include_router(chat_router)
app.include_router(sse_router)
@app.get("/health")
+60
View File
@@ -0,0 +1,60 @@
"""In-process per-chat broadcast channel.
Each ``chat_id`` has a list of subscriber ``asyncio.Queue`` instances. T16
provides only the registry and fan-out mechanism; T19+ will publish events
(turn appends, streamed tokens, drawer updates, scene close, edge updates)
through this channel so all browser tabs viewing a chat stay in sync.
The registry is process-local: appropriate for a single-user local server.
"""
from __future__ import annotations
import asyncio
from collections import defaultdict
from typing import Any
# Registry of live subscriptions, keyed by chat id:
# {chat_id: [queue, queue, ...]}
# ``defaultdict(list)`` lets ``subscribe`` append without creating the key first.
_subscribers: dict[str, list[asyncio.Queue]] = defaultdict(list)
# Held by ``subscribe``, ``unsubscribe`` and ``publish`` while they read or
# modify the registry. ``subscriber_count`` reads it without taking the lock.
_lock = asyncio.Lock()
async def subscribe(chat_id: str) -> asyncio.Queue:
    """Register a new listener on the broadcast channel of ``chat_id``.

    A brand-new unbounded ``asyncio.Queue`` is created, added to the chat's
    subscriber list under the registry lock, and handed back to the caller.
    Every event later passed to :func:`publish` for the same ``chat_id`` is
    put on that queue until the caller hands it to :func:`unsubscribe`.
    Skipping the :func:`unsubscribe` call leaves the queue in the registry.
    """
    inbox: asyncio.Queue = asyncio.Queue()
    async with _lock:
        listeners = _subscribers[chat_id]
        listeners.append(inbox)
    return inbox
async def unsubscribe(chat_id: str, queue: asyncio.Queue) -> None:
    """Drop ``queue`` from the subscribers of ``chat_id``.

    Unknown chat ids and queues that are not registered are ignored. When the
    chat is left with no subscribers its key is deleted from the registry.
    """
    async with _lock:
        listeners = _subscribers.get(chat_id)
        if listeners is None:
            # Nothing registered for this chat; also avoids creating the key.
            return
        if queue in listeners:
            listeners.remove(queue)
        if not listeners:
            del _subscribers[chat_id]
async def publish(chat_id: str, event: dict[str, Any]) -> None:
    """Deliver ``event`` to every queue currently subscribed to ``chat_id``.

    The subscriber list is copied while holding the registry lock and the
    puts happen after the lock is released. All subscribers receive the very
    same dict object, so publishers and consumers should not mutate it.
    Queues are unbounded for v1.
    """
    async with _lock:
        recipients = tuple(_subscribers.get(chat_id, ()))
    for inbox in recipients:
        await inbox.put(event)
def subscriber_count(chat_id: str) -> int:
    """Test helper: how many queues are registered for ``chat_id`` right now.

    Reads the registry without taking the lock and without creating a key
    for chats that have no subscribers.
    """
    active = _subscribers.get(chat_id)
    return 0 if active is None else len(active)
+72
View File
@@ -0,0 +1,72 @@
"""Server-Sent Events endpoint for per-chat live updates.
Each browser tab on ``/chats/<id>`` opens an SSE connection here. On connect:
1. We verify the chat exists (404 otherwise).
2. We subscribe to the chat's pub/sub channel.
3. We emit a ``snapshot`` event with the current state. T16 only provides a
stub payload (``{"chat_id": <id>, "ready": true}``) so the client can
confirm the channel is live; T19+ will populate it with real state.
4. We loop, awaiting events from the queue and yielding them as SSE frames.
A 15-second keepalive comment is emitted on idle to defeat intermediary
timeouts.
5. When the client disconnects we unsubscribe so the registry doesn't leak.
"""
from __future__ import annotations
import asyncio
import json
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from chat.state.world import get_chat
from chat.web.bots import get_conn
from chat.web.pubsub import subscribe, unsubscribe
# Router that carries the SSE route defined in this module.
router = APIRouter()
# Heartbeat cadence, in seconds: how long the stream waits for a published
# event before emitting a keepalive comment. Long enough to avoid chattiness;
# short enough that most HTTP intermediaries won't close an idle connection.
_KEEPALIVE_SECONDS = 15.0
def _format_sse(event: str, data: dict) -> bytes:
"""Format a single SSE frame: ``event: <name>\\ndata: <json>\\n\\n``."""
payload = json.dumps(data)
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
@router.get("/chats/{chat_id}/events")
async def chat_events(chat_id: str, request: Request, conn=Depends(get_conn)):
    """Open a Server-Sent Events stream for one chat.

    Raises ``HTTPException`` (404) when ``get_chat`` returns ``None`` for
    ``chat_id``. Otherwise returns a ``StreamingResponse`` with media type
    ``text/event-stream`` whose body:

    * subscribes to the chat's pub/sub channel on first iteration,
    * starts with a stub ``snapshot`` frame,
    * then relays each published event, using its ``"event"`` key as the SSE
      event name (``"message"`` when absent),
    * emits a ``: keepalive`` comment whenever no event arrives within
      ``_KEEPALIVE_SECONDS``,
    * unsubscribes when the generator finishes or is closed.
    """
    if get_chat(conn, chat_id) is None:
        raise HTTPException(status_code=404, detail="chat not found")

    async def stream():
        inbox = await subscribe(chat_id)
        try:
            # Stub snapshot so the client can confirm the channel is live;
            # T19 will fill in real state.
            yield _format_sse("snapshot", {"chat_id": chat_id, "ready": True})

            # Disconnection is polled once per loop turn, i.e. after every
            # relayed event or keepalive.
            while not await request.is_disconnected():
                try:
                    published = await asyncio.wait_for(
                        inbox.get(), timeout=_KEEPALIVE_SECONDS
                    )
                except asyncio.TimeoutError:
                    # Lines starting with ":" are SSE comments and ignored by
                    # the client; sending one keeps the connection warm.
                    yield b": keepalive\n\n"
                else:
                    # Work on a copy: the published dict is shared by every
                    # subscriber and must not be mutated here.
                    fields = dict(published)
                    name = fields.pop("event", "message")
                    yield _format_sse(name, fields)
        finally:
            await unsubscribe(chat_id, inbox)

    return StreamingResponse(stream(), media_type="text/event-stream")
+174
View File
@@ -0,0 +1,174 @@
"""Tests for the per-chat SSE channel and pub/sub registry (T16).
Covers:
- 404 from the SSE endpoint when the chat doesn't exist.
- The first event yielded on connect is the ``snapshot`` event with correct
framing (``event: snapshot\\ndata: {...}\\n\\n``) and includes the chat id.
- ``publish`` fans out to every subscriber of a chat.
- Events published to one chat don't leak to subscribers of another chat.
Notes:
- Starlette's TestClient buffers the entire ASGI response before returning,
so it cannot read from an indefinitely-open SSE stream. The framing test
drives the route handler directly and pulls frames from the
``StreamingResponse``'s body iterator. The 404 path is still validated via
TestClient since it returns synchronously without entering the stream.
"""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
import pytest
from fastapi import Request
from fastapi.testclient import TestClient
from chat.app import app
from chat.db.connection import open_db
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for the app, pointed at throwaway files.

    A minimal ``config.toml`` is written under ``tmp_path`` and exposed via
    ``CHAT_CONFIG_PATH``; ``CHAT_DB_PATH`` is set to ``tmp_path / "test.db"``.
    Both variables are set before the client context is entered.
    """
    config_file = tmp_path / "config.toml"
    config_file.write_text('featherless_api_key = "test"\n')
    database_file = tmp_path / "test.db"

    env_overrides = (
        ("CHAT_CONFIG_PATH", config_file),
        ("CHAT_DB_PATH", database_file),
    )
    for variable, target in env_overrides:
        monkeypatch.setenv(variable, str(target))

    with TestClient(app) as test_client:
        yield test_client
def _seed_chat(
    db_path: Path, bot_id: str = "bot_a", chat_id: str = "chat_bot_a"
) -> None:
    """Seed the database at ``db_path`` with one bot and one chat.

    Appends a ``bot_authored`` event for ``bot_id`` followed by a
    ``chat_created`` event for ``chat_id`` hosted by that bot, then calls
    ``project`` on the same connection.
    """
    bot_payload = {
        "id": bot_id,
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "...",
    }
    chat_payload = {
        "id": chat_id,
        "host_bot_id": bot_id,
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    # Order matters: the bot event is appended before the chat that names it
    # as host.
    seed_events = (
        ("bot_authored", bot_payload),
        ("chat_created", chat_payload),
    )
    with open_db(db_path) as conn:
        for event_kind, event_payload in seed_events:
            append_event(conn, kind=event_kind, payload=event_payload)
        project(conn)
def test_sse_endpoint_404_for_missing_chat(client):
    """Requesting the event stream of an unknown chat yields HTTP 404."""
    status = client.get("/chats/nonexistent/events").status_code
    assert status == 404
async def test_sse_streams_initial_snapshot(client, tmp_path):
    """The first SSE frame is a correctly framed ``snapshot`` event.

    The route handler is called directly with a hand-built ``Request`` whose
    ``receive`` callable blocks until the test signals a disconnect. The
    first chunk pulled from the response's body iterator is then checked for
    SSE framing and for a JSON payload carrying the chat id.

    Teardown of the stream runs in a ``finally`` block so that the route's
    own ``finally`` (which unsubscribes from the pub/sub registry) executes
    even when an assertion fails; otherwise a failed run would leave a queue
    registered in the process-wide registry for later tests.
    """
    _seed_chat(tmp_path / "test.db")

    # Drive the route directly. The TestClient can't iterate an open stream
    # because Starlette's transport waits for the response to complete.
    from chat.web.bots import get_conn
    from chat.web.sse import chat_events

    disconnect = asyncio.Event()

    async def receive():
        # Stay silent until the test decides the "client" has gone away.
        await disconnect.wait()
        return {"type": "http.disconnect"}

    scope = {
        "type": "http",
        "method": "GET",
        "path": "/chats/chat_bot_a/events",
        "headers": [],
        "query_string": b"",
        "app": client.app,
    }
    request = Request(scope, receive=receive)

    # ``get_conn`` is a generator dependency: the first ``next`` yields the
    # connection, the second (in the outer ``finally``) runs its cleanup.
    conn_gen = get_conn(request)
    conn = next(conn_gen)
    try:
        response = await chat_events("chat_bot_a", request, conn=conn)
        assert response.media_type == "text/event-stream"
        body_iter = response.body_iterator
        try:
            first = await asyncio.wait_for(body_iter.__anext__(), timeout=2.0)
            text = first.decode()

            # Framing: starts with the event line, ends with the blank-line
            # terminator that delimits SSE frames.
            assert text.startswith("event: snapshot\n")
            assert "\ndata: " in text
            assert text.endswith("\n\n")

            # Payload is JSON containing the chat id.
            data_line = next(
                line for line in text.splitlines() if line.startswith("data: ")
            )
            payload = json.loads(data_line[len("data: "):])
            assert payload["chat_id"] == "chat_bot_a"
        finally:
            # Always tear down the generator so its ``finally`` (unsubscribe)
            # runs, whether or not the assertions above passed.
            disconnect.set()
            try:
                await asyncio.wait_for(body_iter.aclose(), timeout=2.0)
            except (StopAsyncIteration, RuntimeError):
                pass
    finally:
        try:
            next(conn_gen)
        except StopIteration:
            pass
async def test_pubsub_publish_fanout():
    """A single publish reaches each queue subscribed to the same chat."""
    from chat.web.pubsub import publish, subscribe, unsubscribe

    first_inbox = await subscribe("chat_x")
    second_inbox = await subscribe("chat_x")
    inboxes = (first_inbox, second_inbox)
    try:
        await publish("chat_x", {"event": "test", "value": 1})

        # Drain both queues first, then compare, mirroring get-get-assert-assert.
        received = [
            await asyncio.wait_for(inbox.get(), timeout=1.0) for inbox in inboxes
        ]
        for delivered in received:
            assert delivered == {"event": "test", "value": 1}
    finally:
        for inbox in inboxes:
            await unsubscribe("chat_x", inbox)
async def test_pubsub_isolated_per_chat():
    """An event published to ``chat_a`` never shows up on ``chat_b``'s queue."""
    from chat.web.pubsub import publish, subscribe, unsubscribe

    inbox_a = await subscribe("chat_a")
    inbox_b = await subscribe("chat_b")
    registrations = (("chat_a", inbox_a), ("chat_b", inbox_b))
    try:
        await publish("chat_a", {"event": "for_a"})

        delivered = await asyncio.wait_for(inbox_a.get(), timeout=1.0)
        assert delivered["event"] == "for_a"

        # The other chat's queue must stay empty: the short wait has to time out.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(inbox_b.get(), timeout=0.1)
    finally:
        for chat_key, inbox in registrations:
            await unsubscribe(chat_key, inbox)