Files
2026-04-26 12:49:41 -04:00

175 lines
5.6 KiB
Python

"""Tests for the per-chat SSE channel and pub/sub registry (T16).
Covers:
- 404 from the SSE endpoint when the chat doesn't exist.
- The first event yielded on connect is the ``snapshot`` event with correct
framing (``event: snapshot\\ndata: {...}\\n\\n``) and includes the chat id.
- ``publish`` fans out to every subscriber of a chat.
- Events published to one chat don't leak to subscribers of another chat.
Notes:
- Starlette's TestClient buffers the entire ASGI response before returning,
so it cannot read from an indefinitely-open SSE stream. The framing test
drives the route handler directly and pulls frames from the
``StreamingResponse``'s body iterator. The 404 path is still validated via
TestClient since it returns synchronously without entering the stream.
"""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
import pytest
from fastapi import Request
from fastapi.testclient import TestClient
from chat.app import app
from chat.db.connection import open_db
from chat.eventlog.log import append_event
from chat.eventlog.projector import project
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for ``app`` with per-test config and DB paths.

    Writes a minimal ``config.toml`` under ``tmp_path`` and sets the
    ``CHAT_CONFIG_PATH`` and ``CHAT_DB_PATH`` environment variables to files
    inside ``tmp_path``. The database file itself is not created here; only
    its path is exported. The client is used as a context manager, so it is
    exited when the fixture is torn down.
    """
    config_file = tmp_path / "config.toml"
    config_file.write_text('featherless_api_key = "test"\n')
    database_file = tmp_path / "test.db"

    # Export the config path first, then the database path.
    env_overrides = (
        ("CHAT_CONFIG_PATH", config_file),
        ("CHAT_DB_PATH", database_file),
    )
    for variable, target in env_overrides:
        monkeypatch.setenv(variable, str(target))

    with TestClient(app) as test_client:
        yield test_client
def _seed_chat(
    db_path: Path, bot_id: str = "bot_a", chat_id: str = "chat_bot_a"
) -> None:
    """Seed the database at ``db_path`` with one bot and one chat.

    Appends a ``bot_authored`` event followed by a ``chat_created`` event
    whose ``host_bot_id`` refers to that bot, then calls ``project`` on the
    same connection.

    Args:
        db_path: Path passed to ``open_db``.
        bot_id: Value used for the bot's ``id`` and the chat's ``host_bot_id``.
        chat_id: Value used for the chat's ``id``.
    """
    bot_payload = {
        "id": bot_id,
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "...",
    }
    chat_payload = {
        "id": chat_id,
        "host_bot_id": bot_id,
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    # The bot event is appended before the chat event that references it.
    seed_events = (
        ("bot_authored", bot_payload),
        ("chat_created", chat_payload),
    )
    with open_db(db_path) as conn:
        for event_kind, event_payload in seed_events:
            append_event(conn, kind=event_kind, payload=event_payload)
        project(conn)
def test_sse_endpoint_404_for_missing_chat(client):
    """GET on the events endpoint of an unseeded chat id returns HTTP 404."""
    missing_chat_url = "/chats/nonexistent/events"
    assert client.get(missing_chat_url).status_code == 404
async def test_sse_streams_initial_snapshot(client, tmp_path):
    """The first SSE frame is a correctly framed ``snapshot`` event.

    Calls the ``chat_events`` route handler directly with a hand-built
    ``Request`` and reads exactly one frame from the response's body
    iterator. Both the stream and the connection generator are torn down in
    ``finally`` blocks, so a failing assertion or a timeout on the first
    frame does not leave the stream generator open.
    """
    _seed_chat(tmp_path / "test.db")

    # Drive the route directly. The TestClient can't iterate an open stream
    # because Starlette's transport waits for the response to complete.
    from chat.web.bots import get_conn
    from chat.web.sse import chat_events

    disconnect = asyncio.Event()

    async def receive():
        # Block until the test signals teardown, then report a disconnect.
        await disconnect.wait()
        return {"type": "http.disconnect"}

    scope = {
        "type": "http",
        "method": "GET",
        "path": "/chats/chat_bot_a/events",
        "headers": [],
        "query_string": b"",
        "app": client.app,
    }
    request = Request(scope, receive=receive)
    conn_gen = get_conn(request)
    conn = next(conn_gen)
    # Stays None if the handler raises before a response exists, so the
    # teardown below knows there is no stream to close.
    body_iter = None
    try:
        response = await chat_events("chat_bot_a", request, conn=conn)
        assert response.media_type == "text/event-stream"
        body_iter = response.body_iterator
        first = await asyncio.wait_for(body_iter.__anext__(), timeout=2.0)
        text = first.decode()

        # Framing: starts with the event line, ends with the blank-line
        # terminator that delimits SSE frames.
        assert text.startswith("event: snapshot\n")
        assert "\ndata: " in text
        assert text.endswith("\n\n")

        # Payload is JSON containing the chat id.
        data_line = next(
            line for line in text.splitlines() if line.startswith("data: ")
        )
        payload = json.loads(data_line[len("data: "):])
        assert payload["chat_id"] == "chat_bot_a"
    finally:
        # Cleanly tear down the generator's ``finally`` (unsubscribe). This
        # runs on failure too, so a broken assertion above cannot leave the
        # stream open while the connection is being released.
        disconnect.set()
        try:
            if body_iter is not None:
                try:
                    await asyncio.wait_for(body_iter.aclose(), timeout=2.0)
                except (StopAsyncIteration, RuntimeError):
                    # Best-effort close: the generator may already be done.
                    pass
        finally:
            # Resume the connection generator past its ``yield`` so its own
            # cleanup runs even if closing the stream raised.
            try:
                next(conn_gen)
            except StopIteration:
                pass
async def test_pubsub_publish_fanout():
    """Every active subscriber of a chat receives published events.

    Two queues are subscribed to the same chat id, one event is published,
    and each queue must deliver an equal event within one second.
    """
    from chat.web.pubsub import publish, subscribe, unsubscribe

    topic = "chat_x"
    expected = {"event": "test", "value": 1}

    first_queue = await subscribe(topic)
    second_queue = await subscribe(topic)
    queues = (first_queue, second_queue)
    try:
        # Publish a separate dict so ``expected`` is only used for comparison.
        await publish(topic, dict(expected))

        # Drain both queues before asserting on either.
        received = [
            await asyncio.wait_for(queue.get(), timeout=1.0) for queue in queues
        ]
        for event in received:
            assert event == expected
    finally:
        for queue in queues:
            await unsubscribe(topic, queue)
async def test_pubsub_isolated_per_chat():
    """Events published to one chat don't leak to a different chat.

    Subscribes one queue to ``chat_a`` and one to ``chat_b``, publishes only
    to ``chat_a``, and checks that the ``chat_b`` queue yields nothing within
    the short timeout.
    """
    from chat.web.pubsub import publish, subscribe, unsubscribe

    listener_a = await subscribe("chat_a")
    listener_b = await subscribe("chat_b")
    try:
        await publish("chat_a", {"event": "for_a"})

        # The subscriber of the target chat gets the event.
        delivered = await asyncio.wait_for(listener_a.get(), timeout=1.0)
        assert delivered["event"] == "for_a"

        # The other chat's queue must stay empty: ``get`` has to time out.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(listener_b.get(), timeout=0.1)
    finally:
        await unsubscribe("chat_a", listener_a)
        await unsubscribe("chat_b", listener_b)