Files
chat/chat/app.py
T
2026-04-27 02:51:44 -04:00

150 lines
5.3 KiB
Python

from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException as StarletteHTTPException
from chat.config import load_settings
from chat.db.connection import open_db
from chat.db.migrate import apply_migrations
from chat.eventlog.log import read_events
from chat.eventlog.projector import apply_event
from chat.services.background import BackgroundWorker
from chat.services.embedding_worker import EmbeddingWorker
from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot
# Trigger handler registration:
import chat.state.entities # noqa: F401
import chat.state.edges # noqa: F401
import chat.state.manual_edit # noqa: F401
import chat.state.memory # noqa: F401
import chat.state.world # noqa: F401
from chat.web.bots import router as bots_router
from chat.web.chat import router as chat_router
from chat.web.drawer import router as drawer_router
from chat.web.kickoff import router as kickoff_router
from chat.web.middleware import FirstRunRedirectMiddleware
from chat.web.nav import router as nav_router
from chat.web.settings import router as settings_router
from chat.web.sse import router as sse_router
from chat.web.turns import router as turns_router
# Module-level logger; ``lifespan`` uses it to report cold-load restores.
log = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare application state on startup and stop the workers on shutdown.

    Startup, in order:

    1. Load settings, create the database's parent directory, and apply
       migrations.
    2. If a periodic snapshot exists, restore from it and replay the events
       recorded after the snapshot's ``last_event_id``.
    3. Publish the settings on ``app.state`` and configure the
       ``FeatherlessClient`` concurrency cap.
    4. Start the background worker and the embedding worker, publishing
       each on ``app.state``.

    Shutdown stops the embedding worker first, then the background worker.
    A worker that was started is always stopped, even when a later startup
    step fails or the other worker's ``stop`` raises.
    """
    settings = load_settings()
    settings.db_path.parent.mkdir(parents=True, exist_ok=True)
    apply_migrations(settings.db_path)

    # T31 cold-load fast-path: if a periodic snapshot exists, restore
    # projected tables from it and replay only events past its
    # ``last_event_id``. Migrations already ran above, so any new tables
    # introduced after the snapshot was taken are present and empty —
    # the replay-forward step refills them from the event log.
    snapshot_path = latest_snapshot_path(settings.data_dir, kind="periodic")
    if snapshot_path is not None:
        with open_db(settings.db_path) as conn:
            last_event_id = restore_from_snapshot(conn, snapshot_path)
            for event in read_events(
                conn, branch_id=1, after_id=last_event_id
            ):
                apply_event(conn, event)
        log.info(
            "cold-load restored from %s, replayed events past id %d",
            snapshot_path,
            last_event_id,
        )

    app.state.settings = settings

    # Cap concurrent Featherless connections to the account's limit
    # (free / lower paid tiers cap at 2). Shared across all
    # FeatherlessClient instances in the process.
    from chat.llm.featherless import FeatherlessClient

    FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)

    # Background worker for the async significance pass (T22). Each job
    # constructs a fresh FeatherlessClient via the factory; tests can
    # disable enqueue by toggling ``app.state.background_worker.enabled``.
    def _factory():
        return FeatherlessClient(
            api_key=settings.featherless_api_key,
            base_url=settings.featherless_base_url,
        )

    worker = BackgroundWorker(settings, llm_client_factory=_factory)
    await worker.start()
    try:
        app.state.background_worker = worker

        # T97: separate worker for the async embedding pass. Each
        # ``memory_written`` enqueues an EmbeddingJob; the worker drains the
        # queue, calls ``generate_embedding``, and emits ``embedding_indexed``.
        # Phase 4's pseudo-embedding path is local so the worker doesn't need
        # an LLM client; we still pass one so the Phase 4.5 swap to a real
        # model is a one-line change.
        embedding_worker = EmbeddingWorker(
            conn_factory=lambda: open_db(settings.db_path),
            client=_factory(),
        )
        await embedding_worker.start()
        try:
            app.state.embedding_worker = embedding_worker
            yield
        finally:
            await embedding_worker.stop()
    finally:
        # Outer ``finally`` so the background worker is stopped even when
        # the embedding worker fails to start or fails to stop.
        await worker.stop()
# The ASGI application; ``lifespan`` drives its startup and shutdown.
app = FastAPI(lifespan=lifespan, title="chat")
app.add_middleware(FirstRunRedirectMiddleware)

# Static assets are served from the ``static`` directory next to this module.
STATIC_DIR = Path(__file__).resolve().parent.joinpath("static")
app.mount(
    "/static",
    StaticFiles(directory=str(STATIC_DIR)),
    name="static",
)

# Templates are loaded from the ``templates`` directory next to this module.
ERROR_TEMPLATES = Jinja2Templates(
    directory=str(Path(__file__).resolve().parent.joinpath("templates")),
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render a friendly HTML page for 404/500; JSON for everything else.

    Any headers carried by the exception (for example ``Allow`` on a 405 or
    ``WWW-Authenticate`` on a 401) are forwarded on the response, on both the
    HTML and the JSON path.
    """
    # ``getattr`` keeps this safe for exception objects without ``headers``.
    headers = getattr(exc, "headers", None)

    if exc.status_code in (404, 500):
        return ERROR_TEMPLATES.TemplateResponse(
            request,
            "errors.html",
            {
                "status_code": exc.status_code,
                # Fall back to a generic message when no detail was given.
                "detail": exc.detail or "Something went wrong.",
                "active_nav": "chats",
            },
            status_code=exc.status_code,
            headers=headers,
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=headers,
    )
# Register every router on the app. The tuple order is the registration
# order, so keep it as-is.
for _router in (
    bots_router,
    kickoff_router,
    settings_router,
    nav_router,
    chat_router,
    drawer_router,
    sse_router,
    turns_router,
):
    app.include_router(_router)
@app.get("/health")
def health():
    # Health endpoint: always answers with a fixed "ok" payload.
    # Documented with comments rather than a docstring because FastAPI
    # would publish a docstring as the endpoint description in OpenAPI.
    payload = dict(status="ok")
    return payload