175 lines
5.6 KiB
Python
175 lines
5.6 KiB
Python
"""Tests for the per-chat SSE channel and pub/sub registry (T16).
|
|
|
|
Covers:
|
|
- 404 from the SSE endpoint when the chat doesn't exist.
|
|
- The first event yielded on connect is the ``snapshot`` event with correct
|
|
framing (``event: snapshot\\ndata: {...}\\n\\n``) and includes the chat id.
|
|
- ``publish`` fans out to every subscriber of a chat.
|
|
- Events published to one chat don't leak to subscribers of another chat.
|
|
|
|
Notes:
|
|
- Starlette's TestClient buffers the entire ASGI response before returning,
|
|
so it cannot read from an indefinitely-open SSE stream. The framing test
|
|
drives the route handler directly and pulls frames from the
|
|
``StreamingResponse``'s body iterator. The 404 path is still validated via
|
|
TestClient since it returns synchronously without entering the stream.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from fastapi import Request
|
|
from fastapi.testclient import TestClient
|
|
|
|
from chat.app import app
|
|
from chat.db.connection import open_db
|
|
from chat.eventlog.log import append_event
|
|
from chat.eventlog.projector import project
|
|
|
|
|
|
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` wired to a throwaway config file and database.

    A minimal ``config.toml`` is written under pytest's per-test ``tmp_path``
    and both it and ``tmp_path / "test.db"`` are exported through the
    ``CHAT_CONFIG_PATH`` / ``CHAT_DB_PATH`` environment variables before the
    client context is entered.
    """
    config_file = tmp_path / "config.toml"
    config_file.write_text('featherless_api_key = "test"\n')
    db_file = tmp_path / "test.db"

    monkeypatch.setenv("CHAT_CONFIG_PATH", str(config_file))
    monkeypatch.setenv("CHAT_DB_PATH", str(db_file))

    # Entering the TestClient as a context manager runs the app's lifespan
    # handlers for the duration of the test.
    with TestClient(app) as test_client:
        yield test_client
|
|
|
|
|
|
def _seed_chat(
    db_path: Path, bot_id: str = "bot_a", chat_id: str = "chat_bot_a"
) -> None:
    """Seed ``db_path`` with one bot and one chat hosted by it, then project.

    Two events are appended in order: ``bot_authored`` for ``bot_id``, then
    ``chat_created`` for ``chat_id`` with ``host_bot_id`` set to ``bot_id``.
    ``project`` is then run over the same connection (presumably to
    materialize the state the web routes read -- confirm against
    ``chat.eventlog.projector``).
    """
    bot_payload = {
        "id": bot_id,
        "name": "BotA",
        "persona": "...",
        "voice_samples": [],
        "traits": [],
        "backstory": "",
        "initial_relationship_to_you": "",
        "kickoff_prose": "...",
    }
    chat_payload = {
        "id": chat_id,
        "host_bot_id": bot_id,
        "initial_time": "2026-04-26T20:00:00+00:00",
        "narrative_anchor": "Day 1",
        "weather": "",
    }
    # Order matters: the chat references the bot by id.
    seed_events = (
        ("bot_authored", bot_payload),
        ("chat_created", chat_payload),
    )

    with open_db(db_path) as conn:
        for kind, payload in seed_events:
            append_event(conn, kind=kind, payload=payload)
        project(conn)
|
|
|
|
|
|
def test_sse_endpoint_404_for_missing_chat(client):
    """Requesting the event stream of an unknown chat id yields HTTP 404."""
    missing_chat_url = "/chats/nonexistent/events"
    assert client.get(missing_chat_url).status_code == 404
|
|
|
|
|
|
async def test_sse_streams_initial_snapshot(client, tmp_path):
    """The first SSE frame is a correctly framed ``snapshot`` event."""
    import contextlib

    _seed_chat(tmp_path / "test.db")

    # Drive the route directly. The TestClient can't iterate an open stream
    # because Starlette's transport waits for the response to complete.
    from chat.web.bots import get_conn
    from chat.web.sse import chat_events

    # ``receive`` blocks until teardown sets this event, mimicking a client
    # that keeps the stream open; it then reports an ``http.disconnect``.
    disconnect = asyncio.Event()

    async def receive():
        await disconnect.wait()
        return {"type": "http.disconnect"}

    scope = {
        "type": "http",
        "method": "GET",
        "path": "/chats/chat_bot_a/events",
        "headers": [],
        "query_string": b"",
        "app": client.app,
    }
    request = Request(scope, receive=receive)

    # Drive the ``get_conn`` generator dependency by hand: the first ``next``
    # yields the connection, the second resumes it past the ``yield`` so its
    # cleanup code runs.
    conn_gen = get_conn(request)
    conn = next(conn_gen)
    body_iter = None
    try:
        response = await chat_events("chat_bot_a", request, conn=conn)
        assert response.media_type == "text/event-stream"

        body_iter = response.body_iterator
        first = await asyncio.wait_for(body_iter.__anext__(), timeout=2.0)
        # StreamingResponse accepts str, bytes or memoryview chunks and only
        # encodes at send time, so the raw iterator may yield any of them.
        text = first if isinstance(first, str) else bytes(first).decode()

        # Framing: starts with the event line, ends with the blank-line
        # terminator that delimits SSE frames.
        assert text.startswith("event: snapshot\n")
        assert "\ndata: " in text
        assert text.endswith("\n\n")

        # Payload is JSON containing the chat id.
        data_line = next(
            line for line in text.splitlines() if line.startswith("data: ")
        )
        payload = json.loads(data_line[len("data: "):])
        assert payload["chat_id"] == "chat_bot_a"
    finally:
        # Tear down even when an assertion above failed, so the stream
        # generator's ``finally`` (unsubscribe) runs now rather than at loop
        # shutdown and no subscriber leaks into later tests.
        disconnect.set()
        try:
            if body_iter is not None:
                # ``aclose`` raises RuntimeError only if the generator is
                # already running or ignored GeneratorExit; teardown stays
                # best-effort in that case, as before.
                with contextlib.suppress(RuntimeError):
                    await asyncio.wait_for(body_iter.aclose(), timeout=2.0)
        finally:
            # Release the connection only after the stream is closed.
            with contextlib.suppress(StopIteration):
                next(conn_gen)
|
|
|
|
|
|
async def test_pubsub_publish_fanout():
    """Every active subscriber of a chat receives published events."""
    import contextlib

    from chat.web.pubsub import publish, subscribe, unsubscribe

    async with contextlib.AsyncExitStack() as cleanup:
        # Register each unsubscribe as soon as its subscribe succeeds, so a
        # failure part-way through (or in another unsubscribe) never leaves
        # a stale subscriber in the shared registry for later tests.
        queues = []
        for _ in range(2):
            q = await subscribe("chat_x")
            cleanup.push_async_callback(unsubscribe, "chat_x", q)
            queues.append(q)

        await publish("chat_x", {"event": "test", "value": 1})
        for q in queues:
            received = await asyncio.wait_for(q.get(), timeout=1.0)
            assert received == {"event": "test", "value": 1}
|
|
|
|
|
|
async def test_pubsub_isolated_per_chat():
    """Events published to one chat don't leak to a different chat."""
    import contextlib

    from chat.web.pubsub import publish, subscribe, unsubscribe

    async with contextlib.AsyncExitStack() as cleanup:
        # Register each unsubscribe right after its subscribe succeeds so a
        # failing second subscribe (or unsubscribe) can't leak the first.
        q_a = await subscribe("chat_a")
        cleanup.push_async_callback(unsubscribe, "chat_a", q_a)
        q_b = await subscribe("chat_b")
        cleanup.push_async_callback(unsubscribe, "chat_b", q_b)

        await publish("chat_a", {"event": "for_a"})
        e = await asyncio.wait_for(q_a.get(), timeout=1.0)
        assert e["event"] == "for_a"

        # A short timed wait (rather than an instantaneous emptiness check)
        # also catches a leak in case ``publish`` delivers asynchronously.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(q_b.get(), timeout=0.1)
|