feat: per-chat SSE channel and pub/sub
This commit is contained in:
@@ -19,6 +19,7 @@ from chat.web.chat import router as chat_router
|
||||
from chat.web.kickoff import router as kickoff_router
|
||||
from chat.web.nav import router as nav_router
|
||||
from chat.web.settings import router as settings_router
|
||||
from chat.web.sse import router as sse_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -40,6 +41,7 @@ app.include_router(kickoff_router)
|
||||
app.include_router(settings_router)
|
||||
app.include_router(nav_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(sse_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""In-process per-chat broadcast channel.
|
||||
|
||||
Each ``chat_id`` has a list of subscriber ``asyncio.Queue`` instances. T16
|
||||
provides only the registry and fan-out mechanism; T19+ will publish events
|
||||
(turn appends, streamed tokens, drawer updates, scene close, edge updates)
|
||||
through this channel so all browser tabs viewing a chat stay in sync.
|
||||
|
||||
The registry is process-local: appropriate for a single-user local server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
# {chat_id: [queue, queue, ...]}
|
||||
_subscribers: dict[str, list[asyncio.Queue]] = defaultdict(list)
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def subscribe(chat_id: str) -> asyncio.Queue:
|
||||
"""Subscribe to a chat's broadcast channel.
|
||||
|
||||
Returns a fresh ``asyncio.Queue`` that will receive every event published
|
||||
to ``chat_id`` while the subscription is active. Callers must invoke
|
||||
:func:`unsubscribe` when finished (typically on client disconnect) to
|
||||
avoid leaking queues into the registry.
|
||||
"""
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
async with _lock:
|
||||
_subscribers[chat_id].append(queue)
|
||||
return queue
|
||||
|
||||
|
||||
async def unsubscribe(chat_id: str, queue: asyncio.Queue) -> None:
|
||||
"""Remove ``queue`` from the registry; remove the chat key if empty."""
|
||||
async with _lock:
|
||||
if chat_id in _subscribers:
|
||||
if queue in _subscribers[chat_id]:
|
||||
_subscribers[chat_id].remove(queue)
|
||||
if not _subscribers[chat_id]:
|
||||
del _subscribers[chat_id]
|
||||
|
||||
|
||||
async def publish(chat_id: str, event: dict[str, Any]) -> None:
|
||||
"""Fan-out ``event`` to every subscriber of ``chat_id``.
|
||||
|
||||
The same dict reference is enqueued to all subscribers. Callers should
|
||||
treat published events as immutable. Queues are unbounded for v1.
|
||||
"""
|
||||
async with _lock:
|
||||
queues = list(_subscribers.get(chat_id, []))
|
||||
for q in queues:
|
||||
await q.put(event)
|
||||
|
||||
|
||||
def subscriber_count(chat_id: str) -> int:
|
||||
"""Test helper. Returns the number of active subscribers for a chat."""
|
||||
return len(_subscribers.get(chat_id, []))
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Server-Sent Events endpoint for per-chat live updates.
|
||||
|
||||
Each browser tab on ``/chats/<id>`` opens an SSE connection here. On connect:
|
||||
|
||||
1. We verify the chat exists (404 otherwise).
|
||||
2. We subscribe to the chat's pub/sub channel.
|
||||
3. We emit a ``snapshot`` event with the current state. T16 only provides a
|
||||
stub payload (``{"chat_id": <id>, "ready": true}``) so the client can
|
||||
confirm the channel is live; T19+ will populate it with real state.
|
||||
4. We loop, awaiting events from the queue and yielding them as SSE frames.
|
||||
A 15-second keepalive comment is emitted on idle to defeat intermediary
|
||||
timeouts.
|
||||
5. When the client disconnects we unsubscribe so the registry doesn't leak.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from chat.state.world import get_chat
|
||||
from chat.web.bots import get_conn
|
||||
from chat.web.pubsub import subscribe, unsubscribe
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Heartbeat cadence. Long enough to avoid chattiness; short enough that most
|
||||
# HTTP intermediaries won't close an idle connection.
|
||||
_KEEPALIVE_SECONDS = 15.0
|
||||
|
||||
|
||||
def _format_sse(event: str, data: dict) -> bytes:
|
||||
"""Format a single SSE frame: ``event: <name>\\ndata: <json>\\n\\n``."""
|
||||
payload = json.dumps(data)
|
||||
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}/events")
|
||||
async def chat_events(chat_id: str, request: Request, conn=Depends(get_conn)):
|
||||
chat = get_chat(conn, chat_id)
|
||||
if chat is None:
|
||||
raise HTTPException(status_code=404, detail="chat not found")
|
||||
|
||||
async def stream():
|
||||
queue = await subscribe(chat_id)
|
||||
try:
|
||||
# Initial snapshot — T19 will fill in real state.
|
||||
yield _format_sse("snapshot", {"chat_id": chat_id, "ready": True})
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
try:
|
||||
event = await asyncio.wait_for(
|
||||
queue.get(), timeout=_KEEPALIVE_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# SSE comment line (per spec, lines starting with ":" are
|
||||
# ignored by the client) — keeps the connection warm.
|
||||
yield b": keepalive\n\n"
|
||||
continue
|
||||
# Allow publishers to set the SSE event name via "event" key;
|
||||
# default to "message" if omitted.
|
||||
event = dict(event) # don't mutate the published dict
|
||||
kind = event.pop("event", "message")
|
||||
yield _format_sse(kind, event)
|
||||
finally:
|
||||
await unsubscribe(chat_id, queue)
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
@@ -0,0 +1,174 @@
|
||||
"""Tests for the per-chat SSE channel and pub/sub registry (T16).
|
||||
|
||||
Covers:
|
||||
- 404 from the SSE endpoint when the chat doesn't exist.
|
||||
- The first event yielded on connect is the ``snapshot`` event with correct
|
||||
framing (``event: snapshot\\ndata: {...}\\n\\n``) and includes the chat id.
|
||||
- ``publish`` fans out to every subscriber of a chat.
|
||||
- Events published to one chat don't leak to subscribers of another chat.
|
||||
|
||||
Notes:
|
||||
- Starlette's TestClient buffers the entire ASGI response before returning,
|
||||
so it cannot read from an indefinitely-open SSE stream. The framing test
|
||||
drives the route handler directly and pulls frames from the
|
||||
``StreamingResponse``'s body iterator. The 404 path is still validated via
|
||||
TestClient since it returns synchronously without entering the stream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from chat.app import app
|
||||
from chat.db.connection import open_db
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.toml"
|
||||
cfg.write_text('featherless_api_key = "test"\n')
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||
db = tmp_path / "test.db"
|
||||
monkeypatch.setenv("CHAT_DB_PATH", str(db))
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
|
||||
def _seed_chat(
|
||||
db_path: Path, bot_id: str = "bot_a", chat_id: str = "chat_bot_a"
|
||||
) -> None:
|
||||
with open_db(db_path) as conn:
|
||||
append_event(
|
||||
conn,
|
||||
kind="bot_authored",
|
||||
payload={
|
||||
"id": bot_id,
|
||||
"name": "BotA",
|
||||
"persona": "...",
|
||||
"voice_samples": [],
|
||||
"traits": [],
|
||||
"backstory": "",
|
||||
"initial_relationship_to_you": "",
|
||||
"kickoff_prose": "...",
|
||||
},
|
||||
)
|
||||
append_event(
|
||||
conn,
|
||||
kind="chat_created",
|
||||
payload={
|
||||
"id": chat_id,
|
||||
"host_bot_id": bot_id,
|
||||
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||
"narrative_anchor": "Day 1",
|
||||
"weather": "",
|
||||
},
|
||||
)
|
||||
project(conn)
|
||||
|
||||
|
||||
def test_sse_endpoint_404_for_missing_chat(client):
|
||||
response = client.get("/chats/nonexistent/events")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
async def test_sse_streams_initial_snapshot(client, tmp_path):
|
||||
"""The first SSE frame is a correctly framed ``snapshot`` event."""
|
||||
_seed_chat(tmp_path / "test.db")
|
||||
|
||||
# Drive the route directly. The TestClient can't iterate an open stream
|
||||
# because Starlette's transport waits for the response to complete.
|
||||
from chat.web.bots import get_conn
|
||||
from chat.web.sse import chat_events
|
||||
|
||||
disconnect = asyncio.Event()
|
||||
|
||||
async def receive():
|
||||
await disconnect.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/chats/chat_bot_a/events",
|
||||
"headers": [],
|
||||
"query_string": b"",
|
||||
"app": client.app,
|
||||
}
|
||||
request = Request(scope, receive=receive)
|
||||
|
||||
conn_gen = get_conn(request)
|
||||
conn = next(conn_gen)
|
||||
try:
|
||||
response = await chat_events("chat_bot_a", request, conn=conn)
|
||||
assert response.media_type == "text/event-stream"
|
||||
|
||||
body_iter = response.body_iterator
|
||||
first = await asyncio.wait_for(body_iter.__anext__(), timeout=2.0)
|
||||
text = first.decode()
|
||||
|
||||
# Framing: starts with the event line, ends with the blank-line
|
||||
# terminator that delimits SSE frames.
|
||||
assert text.startswith("event: snapshot\n")
|
||||
assert "\ndata: " in text
|
||||
assert text.endswith("\n\n")
|
||||
|
||||
# Payload is JSON containing the chat id.
|
||||
data_line = next(
|
||||
line for line in text.splitlines() if line.startswith("data: ")
|
||||
)
|
||||
payload = json.loads(data_line[len("data: "):])
|
||||
assert payload["chat_id"] == "chat_bot_a"
|
||||
|
||||
# Cleanly tear down the generator's ``finally`` (unsubscribe).
|
||||
disconnect.set()
|
||||
try:
|
||||
await asyncio.wait_for(body_iter.aclose(), timeout=2.0)
|
||||
except (StopAsyncIteration, RuntimeError):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
next(conn_gen)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
async def test_pubsub_publish_fanout():
|
||||
"""Every active subscriber of a chat receives published events."""
|
||||
from chat.web.pubsub import publish, subscribe, unsubscribe
|
||||
|
||||
q1 = await subscribe("chat_x")
|
||||
q2 = await subscribe("chat_x")
|
||||
try:
|
||||
await publish("chat_x", {"event": "test", "value": 1})
|
||||
e1 = await asyncio.wait_for(q1.get(), timeout=1.0)
|
||||
e2 = await asyncio.wait_for(q2.get(), timeout=1.0)
|
||||
assert e1 == {"event": "test", "value": 1}
|
||||
assert e2 == {"event": "test", "value": 1}
|
||||
finally:
|
||||
await unsubscribe("chat_x", q1)
|
||||
await unsubscribe("chat_x", q2)
|
||||
|
||||
|
||||
async def test_pubsub_isolated_per_chat():
|
||||
"""Events published to one chat don't leak to a different chat."""
|
||||
from chat.web.pubsub import publish, subscribe, unsubscribe
|
||||
|
||||
q_a = await subscribe("chat_a")
|
||||
q_b = await subscribe("chat_b")
|
||||
try:
|
||||
await publish("chat_a", {"event": "for_a"})
|
||||
e = await asyncio.wait_for(q_a.get(), timeout=1.0)
|
||||
assert e["event"] == "for_a"
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(q_b.get(), timeout=0.1)
|
||||
finally:
|
||||
await unsubscribe("chat_a", q_a)
|
||||
await unsubscribe("chat_b", q_b)
|
||||
Reference in New Issue
Block a user