61 lines
2.1 KiB
Python
61 lines
2.1 KiB
Python
"""In-process per-chat broadcast channel.

Each ``chat_id`` has a list of subscriber ``asyncio.Queue`` instances. T16
provides only the registry and fan-out mechanism; T19+ will publish events
(turn appends, streamed tokens, drawer updates, scene close, edge updates)
through this channel so that all browser tabs viewing a chat stay in sync.

The registry is process-local, which is appropriate for a single-user local
server: subscribers living in a different process will not see events
published here.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from typing import Any
|
|
|
|
# Subscription registry: chat_id -> queues handed out for that chat, in the
# order they subscribed. ``defaultdict(list)`` means the per-chat list is
# created on first append, so writers need no key-existence check.
_subscribers: defaultdict[str, list[asyncio.Queue]] = defaultdict(list)

# Serialises access to ``_subscribers`` across coroutines.
_lock: asyncio.Lock = asyncio.Lock()
|
|
|
|
|
|
async def subscribe(chat_id: str) -> asyncio.Queue:
    """Open a new subscription to ``chat_id``'s broadcast channel.

    A brand-new ``asyncio.Queue`` (default ``maxsize``, i.e. unbounded) is
    registered under ``chat_id``. Every event published to that chat is put
    on it for as long as it stays registered.

    The caller owns cleanup: pass the returned queue to :func:`unsubscribe`
    when done (typically on client disconnect). Otherwise the queue remains
    in the registry indefinitely.

    Returns:
        The newly registered queue.
    """
    async with _lock:
        inbox: asyncio.Queue = asyncio.Queue()
        _subscribers[chat_id].append(inbox)
        return inbox
|
|
|
|
|
|
async def unsubscribe(chat_id: str, queue: asyncio.Queue) -> None:
    """Detach ``queue`` from ``chat_id``'s channel.

    An unknown ``chat_id`` or an unregistered ``queue`` is silently ignored,
    so repeated calls are harmless. Once a chat has no queues left, its key is
    dropped from the registry.
    """
    async with _lock:
        # ``.get`` (not ``[]``) so the defaultdict does not materialise a key.
        registered = _subscribers.get(chat_id)
        if registered is None:
            return
        if queue in registered:
            registered.remove(queue)
        if not registered:
            del _subscribers[chat_id]
|
|
|
|
|
|
async def publish(chat_id: str, event: dict[str, Any]) -> None:
    """Deliver ``event`` to every queue currently subscribed to ``chat_id``.

    No copy is made: each subscriber receives the identical ``event`` object,
    so it must be treated as immutable once published. The subscriber list is
    snapshotted while holding the lock, and the lock is released before any
    queue is written to. Queues are unbounded for v1.
    """
    async with _lock:
        targets = tuple(_subscribers.get(chat_id, ()))
    for target in targets:
        await target.put(event)
|
|
|
|
|
|
def subscriber_count(chat_id: str) -> int:
    """Test helper: number of queues subscribed to ``chat_id`` (0 if none).

    Reads the registry without taking the lock. It never creates an entry for
    an unknown chat.
    """
    if chat_id not in _subscribers:
        return 0
    return len(_subscribers[chat_id])
|