150 lines
5.3 KiB
Python
150 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
|
from chat.config import load_settings
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import read_events
|
|
from chat.eventlog.projector import apply_event
|
|
from chat.services.background import BackgroundWorker
|
|
from chat.services.embedding_worker import EmbeddingWorker
|
|
from chat.services.snapshot import latest_snapshot_path, restore_from_snapshot
|
|
|
|
# Trigger handler registration:
|
|
import chat.state.entities # noqa: F401
|
|
import chat.state.edges # noqa: F401
|
|
import chat.state.manual_edit # noqa: F401
|
|
import chat.state.memory # noqa: F401
|
|
import chat.state.world # noqa: F401
|
|
|
|
from chat.web.bots import router as bots_router
|
|
from chat.web.chat import router as chat_router
|
|
from chat.web.drawer import router as drawer_router
|
|
from chat.web.kickoff import router as kickoff_router
|
|
from chat.web.middleware import FirstRunRedirectMiddleware
|
|
from chat.web.nav import router as nav_router
|
|
from chat.web.settings import router as settings_router
|
|
from chat.web.sse import router as sse_router
|
|
from chat.web.turns import router as turns_router
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown around the serving period.

    Startup, in order:

    1. Load settings, create the database's parent directory if missing,
       and apply migrations.
    2. If a periodic snapshot exists, restore from it and replay the events
       recorded after the snapshot's last event id.
    3. Publish the settings on ``app.state.settings``.
    4. Configure the process-wide Featherless concurrency cap.
    5. Start the background worker and the embedding worker, publishing
       them as ``app.state.background_worker`` and
       ``app.state.embedding_worker``.

    Shutdown stops the embedding worker first, then the background worker.
    The cleanup is nested so that the background worker is stopped even
    when the embedding worker fails to start or fails to stop.
    """
    settings = load_settings()
    settings.db_path.parent.mkdir(parents=True, exist_ok=True)
    apply_migrations(settings.db_path)

    # T31 cold-load fast-path: if a periodic snapshot exists, restore
    # projected tables from it and replay only events past its
    # ``last_event_id``. Migrations already ran above, so any new tables
    # introduced after the snapshot was taken are present and empty —
    # the replay-forward step refills them from the event log.
    snapshot_path = latest_snapshot_path(settings.data_dir, kind="periodic")
    if snapshot_path is not None:
        with open_db(settings.db_path) as conn:
            last_event_id = restore_from_snapshot(conn, snapshot_path)
            for event in read_events(
                conn, branch_id=1, after_id=last_event_id
            ):
                apply_event(conn, event)
        log.info(
            "cold-load restored from %s, replayed events past id %d",
            snapshot_path,
            last_event_id,
        )

    app.state.settings = settings

    # Cap concurrent Featherless connections to the account's limit
    # (free / lower paid tiers cap at 2). Shared across all
    # FeatherlessClient instances in the process.
    from chat.llm.featherless import FeatherlessClient

    FeatherlessClient.configure_concurrency(settings.featherless_max_concurrent)

    # Background worker for the async significance pass (T22). Each job
    # constructs a fresh FeatherlessClient via the factory; tests can
    # disable enqueue by toggling ``app.state.background_worker.enabled``.
    def _factory():
        return FeatherlessClient(
            api_key=settings.featherless_api_key,
            base_url=settings.featherless_base_url,
        )

    worker = BackgroundWorker(settings, llm_client_factory=_factory)
    await worker.start()
    try:
        app.state.background_worker = worker

        # T97: separate worker for the async embedding pass. Each
        # ``memory_written`` enqueues an EmbeddingJob; the worker drains the
        # queue, calls ``generate_embedding``, and emits ``embedding_indexed``.
        # Phase 4's pseudo-embedding path is local so the worker doesn't need
        # an LLM client; we still pass one so the Phase 4.5 swap to a real
        # model is a one-line change.
        embedding_worker = EmbeddingWorker(
            conn_factory=lambda: open_db(settings.db_path),
            client=_factory(),
        )
        await embedding_worker.start()
        try:
            app.state.embedding_worker = embedding_worker
            # The application serves requests while suspended here.
            yield
        finally:
            await embedding_worker.stop()
    finally:
        # Runs even if the embedding worker failed to start or to stop,
        # so the background worker is never left running.
        await worker.stop()
|
|
|
|
|
|
# Directory holding this module; static assets and templates sit beside it.
_PACKAGE_DIR = Path(__file__).resolve().parent

# The ASGI application. ``lifespan`` drives startup and shutdown.
app = FastAPI(title="chat", lifespan=lifespan)
app.add_middleware(FirstRunRedirectMiddleware)

# Static files are served from the package's ``static`` directory.
STATIC_DIR = _PACKAGE_DIR / "static"
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# Template environment used by the HTTP error handler.
ERROR_TEMPLATES = Jinja2Templates(directory=str(_PACKAGE_DIR / "templates"))
|
|
|
|
|
|
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render a friendly HTML page for 404/500; JSON for everything else.

    For status codes 404 and 500 the ``errors.html`` template is rendered
    with the status code and the exception detail (falling back to a
    generic message when the detail is empty). Every other status code
    gets a JSON body of the form ``{"detail": ...}``.

    Headers carried by the exception (for example ``Allow`` or
    ``WWW-Authenticate``) are forwarded on both kinds of response, so
    overriding the framework's default handler does not drop them.
    """
    # ``headers`` may be absent or None on the exception; both are fine.
    headers = getattr(exc, "headers", None)

    if exc.status_code in (404, 500):
        return ERROR_TEMPLATES.TemplateResponse(
            request,
            "errors.html",
            {
                "status_code": exc.status_code,
                "detail": exc.detail or "Something went wrong.",
                "active_nav": "chats",
            },
            status_code=exc.status_code,
            headers=headers,
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=headers,
    )
|
|
|
|
|
|
# Attach each feature router to the application, in this fixed order.
for _router in (
    bots_router,
    kickoff_router,
    settings_router,
    nav_router,
    chat_router,
    drawer_router,
    sse_router,
    turns_router,
):
    app.include_router(_router)
# Do not leave the loop variable behind in the module namespace.
del _router
|
|
|
|
|
|
@app.get("/health")
def health():
    # Health-check endpoint: returns a constant payload and does not
    # inspect the database or the workers.
    return dict(status="ok")
|