feat: per-chat SSE channel and pub/sub
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
"""Tests for the per-chat SSE channel and pub/sub registry (T16).
|
||||
|
||||
Covers:
|
||||
- 404 from the SSE endpoint when the chat doesn't exist.
|
||||
- The first event yielded on connect is the ``snapshot`` event with correct
|
||||
framing (``event: snapshot\\ndata: {...}\\n\\n``) and includes the chat id.
|
||||
- ``publish`` fans out to every subscriber of a chat.
|
||||
- Events published to one chat don't leak to subscribers of another chat.
|
||||
|
||||
Notes:
|
||||
- Starlette's TestClient buffers the entire ASGI response before returning,
|
||||
so it cannot read from an indefinitely-open SSE stream. The framing test
|
||||
drives the route handler directly and pulls frames from the
|
||||
``StreamingResponse``'s body iterator. The 404 path is still validated via
|
||||
TestClient since it returns synchronously without entering the stream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from chat.app import app
|
||||
from chat.db.connection import open_db
|
||||
from chat.eventlog.log import append_event
|
||||
from chat.eventlog.projector import project
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.toml"
|
||||
cfg.write_text('featherless_api_key = "test"\n')
|
||||
monkeypatch.setenv("CHAT_CONFIG_PATH", str(cfg))
|
||||
db = tmp_path / "test.db"
|
||||
monkeypatch.setenv("CHAT_DB_PATH", str(db))
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
|
||||
def _seed_chat(
|
||||
db_path: Path, bot_id: str = "bot_a", chat_id: str = "chat_bot_a"
|
||||
) -> None:
|
||||
with open_db(db_path) as conn:
|
||||
append_event(
|
||||
conn,
|
||||
kind="bot_authored",
|
||||
payload={
|
||||
"id": bot_id,
|
||||
"name": "BotA",
|
||||
"persona": "...",
|
||||
"voice_samples": [],
|
||||
"traits": [],
|
||||
"backstory": "",
|
||||
"initial_relationship_to_you": "",
|
||||
"kickoff_prose": "...",
|
||||
},
|
||||
)
|
||||
append_event(
|
||||
conn,
|
||||
kind="chat_created",
|
||||
payload={
|
||||
"id": chat_id,
|
||||
"host_bot_id": bot_id,
|
||||
"initial_time": "2026-04-26T20:00:00+00:00",
|
||||
"narrative_anchor": "Day 1",
|
||||
"weather": "",
|
||||
},
|
||||
)
|
||||
project(conn)
|
||||
|
||||
|
||||
def test_sse_endpoint_404_for_missing_chat(client):
|
||||
response = client.get("/chats/nonexistent/events")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
async def test_sse_streams_initial_snapshot(client, tmp_path):
|
||||
"""The first SSE frame is a correctly framed ``snapshot`` event."""
|
||||
_seed_chat(tmp_path / "test.db")
|
||||
|
||||
# Drive the route directly. The TestClient can't iterate an open stream
|
||||
# because Starlette's transport waits for the response to complete.
|
||||
from chat.web.bots import get_conn
|
||||
from chat.web.sse import chat_events
|
||||
|
||||
disconnect = asyncio.Event()
|
||||
|
||||
async def receive():
|
||||
await disconnect.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/chats/chat_bot_a/events",
|
||||
"headers": [],
|
||||
"query_string": b"",
|
||||
"app": client.app,
|
||||
}
|
||||
request = Request(scope, receive=receive)
|
||||
|
||||
conn_gen = get_conn(request)
|
||||
conn = next(conn_gen)
|
||||
try:
|
||||
response = await chat_events("chat_bot_a", request, conn=conn)
|
||||
assert response.media_type == "text/event-stream"
|
||||
|
||||
body_iter = response.body_iterator
|
||||
first = await asyncio.wait_for(body_iter.__anext__(), timeout=2.0)
|
||||
text = first.decode()
|
||||
|
||||
# Framing: starts with the event line, ends with the blank-line
|
||||
# terminator that delimits SSE frames.
|
||||
assert text.startswith("event: snapshot\n")
|
||||
assert "\ndata: " in text
|
||||
assert text.endswith("\n\n")
|
||||
|
||||
# Payload is JSON containing the chat id.
|
||||
data_line = next(
|
||||
line for line in text.splitlines() if line.startswith("data: ")
|
||||
)
|
||||
payload = json.loads(data_line[len("data: "):])
|
||||
assert payload["chat_id"] == "chat_bot_a"
|
||||
|
||||
# Cleanly tear down the generator's ``finally`` (unsubscribe).
|
||||
disconnect.set()
|
||||
try:
|
||||
await asyncio.wait_for(body_iter.aclose(), timeout=2.0)
|
||||
except (StopAsyncIteration, RuntimeError):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
next(conn_gen)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
async def test_pubsub_publish_fanout():
|
||||
"""Every active subscriber of a chat receives published events."""
|
||||
from chat.web.pubsub import publish, subscribe, unsubscribe
|
||||
|
||||
q1 = await subscribe("chat_x")
|
||||
q2 = await subscribe("chat_x")
|
||||
try:
|
||||
await publish("chat_x", {"event": "test", "value": 1})
|
||||
e1 = await asyncio.wait_for(q1.get(), timeout=1.0)
|
||||
e2 = await asyncio.wait_for(q2.get(), timeout=1.0)
|
||||
assert e1 == {"event": "test", "value": 1}
|
||||
assert e2 == {"event": "test", "value": 1}
|
||||
finally:
|
||||
await unsubscribe("chat_x", q1)
|
||||
await unsubscribe("chat_x", q2)
|
||||
|
||||
|
||||
async def test_pubsub_isolated_per_chat():
|
||||
"""Events published to one chat don't leak to a different chat."""
|
||||
from chat.web.pubsub import publish, subscribe, unsubscribe
|
||||
|
||||
q_a = await subscribe("chat_a")
|
||||
q_b = await subscribe("chat_b")
|
||||
try:
|
||||
await publish("chat_a", {"event": "for_a"})
|
||||
e = await asyncio.wait_for(q_a.get(), timeout=1.0)
|
||||
assert e["event"] == "for_a"
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(q_b.get(), timeout=0.1)
|
||||
finally:
|
||||
await unsubscribe("chat_a", q_a)
|
||||
await unsubscribe("chat_b", q_b)
|
||||
Reference in New Issue
Block a user