133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
from __future__ import annotations
|
|
from pathlib import Path
|
|
from fastapi import APIRouter, Depends, Form, HTTPException, Request
|
|
from fastapi.responses import RedirectResponse, HTMLResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
|
|
from chat.db.connection import open_db
|
|
from chat.eventlog.log import append_event
|
|
from chat.eventlog.projector import project
|
|
from chat.state.entities import list_bots
|
|
|
|
# Template loader rooted at the "templates" directory inside the parent of
# the directory containing this module (parents[1] == parent.parent).
TEMPLATES = Jinja2Templates(
    directory=str(Path(__file__).resolve().parents[1] / "templates")
)

# Router holding the bot list / authoring / reset routes defined below.
router = APIRouter()

# Form fields that bot_create rejects when missing or whitespace-only.
REQUIRED_FIELDS = (
    "id",
    "name",
    "persona",
    "initial_relationship_to_you",
    "kickoff_prose",
)
|
|
|
|
|
|
def get_conn(request: Request):
    """FastAPI dependency that yields a DB connection for one request.

    The database path comes from ``request.app.state.settings.db_path``.
    The connection is opened through ``open_db`` used as a context manager.
    Its cleanup (presumably closing the connection; confirm in open_db)
    runs after the request that depended on it has finished with it.
    """
    db_path: Path = request.app.state.settings.db_path
    # check_same_thread=False: presumably forwarded to sqlite3 so the
    # connection can be used from a thread other than the one that opened
    # it (FastAPI may run this sync dependency and the async handlers on
    # different threads) -- confirm against open_db.
    with open_db(db_path, check_same_thread=False) as connection:
        yield connection
|
|
|
|
|
|
def _split_voice_samples(text: str) -> list[str]:
|
|
if not text or not text.strip():
|
|
return []
|
|
# Split on a line containing only "---" (with optional surrounding whitespace).
|
|
parts: list[str] = []
|
|
buf: list[str] = []
|
|
for line in text.splitlines():
|
|
if line.strip() == "---":
|
|
if buf:
|
|
parts.append("\n".join(buf).strip())
|
|
buf = []
|
|
continue
|
|
buf.append(line)
|
|
if buf:
|
|
parts.append("\n".join(buf).strip())
|
|
return [p for p in parts if p]
|
|
|
|
|
|
def _split_traits(text: str) -> list[str]:
|
|
if not text or not text.strip():
|
|
return []
|
|
items: list[str] = []
|
|
for line in text.splitlines():
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
if "," in line:
|
|
items.extend(p.strip() for p in line.split(","))
|
|
else:
|
|
items.append(line)
|
|
return [t for t in items if t]
|
|
|
|
|
|
@router.get("/bots", response_class=HTMLResponse)
async def bots_list(request: Request, conn=Depends(get_conn)):
    """Render ``bot_list.html`` with every bot returned by ``list_bots``."""
    context = {"bots": list_bots(conn), "active_nav": "bots"}
    return TEMPLATES.TemplateResponse(request, "bot_list.html", context)
|
|
|
|
|
|
@router.get("/bots/new", response_class=HTMLResponse)
async def bot_form(request: Request):
    """Render a blank bot-authoring form (no prefilled values, no error)."""
    context = {"values": {}, "error": None, "active_nav": "bots"}
    return TEMPLATES.TemplateResponse(request, "bot_form.html", context)
|
|
|
|
|
|
@router.post("/bots/new")
async def bot_create(
    request: Request,
    id: str = Form(""),
    name: str = Form(""),
    persona: str = Form(""),
    voice_samples: str = Form(""),
    traits: str = Form(""),
    backstory: str = Form(""),
    initial_relationship_to_you: str = Form(""),
    kickoff_prose: str = Form(""),
    conn=Depends(get_conn),
):
    """Create a bot from the authoring form and redirect to its kickoff page.

    Every field named in ``REQUIRED_FIELDS`` must be non-blank. Text fields
    are whitespace-stripped. ``voice_samples`` is split on ``---`` lines and
    ``traits`` on commas/newlines. The payload is appended to the event log
    as a ``bot_authored`` event, and then ``project(conn)`` is run.

    Raises:
        HTTPException: 400 listing the required fields that are blank.

    Returns:
        A 303 redirect to ``/bots/<id>/kickoff``.
    """
    from urllib.parse import quote

    values = {
        "id": id,
        "name": name,
        "persona": persona,
        "voice_samples": voice_samples,
        "traits": traits,
        "backstory": backstory,
        "initial_relationship_to_you": initial_relationship_to_you,
        "kickoff_prose": kickoff_prose,
    }
    missing = [f for f in REQUIRED_FIELDS if not values[f].strip()]
    if missing:
        # NOTE(review): bot_form.html is rendered with "values"/"error"
        # context keys; re-rendering it here with `values` would let the
        # user correct the input. Not done because it changes the 400 body.
        raise HTTPException(status_code=400, detail=f"missing required: {', '.join(missing)}")

    payload = {
        "id": id.strip(),
        "name": name.strip(),
        "persona": persona.strip(),
        "voice_samples": _split_voice_samples(voice_samples),
        "traits": _split_traits(traits),
        "backstory": backstory.strip(),
        "initial_relationship_to_you": initial_relationship_to_you.strip(),
        "kickoff_prose": kickoff_prose.strip(),
    }
    append_event(conn, kind="bot_authored", payload=payload)
    project(conn)

    # The id is free-form user input. RedirectResponse's own quoting leaves
    # "?" and "#" unescaped, so an id containing them would spill into the
    # query string / fragment. Percent-encode it fully so it stays inside the
    # path segment; plain ids (letters, digits, "-", "_", ".") are unchanged.
    url_id = quote(payload["id"], safe="")
    return RedirectResponse(url=f"/bots/{url_id}/kickoff", status_code=303)
|
|
|
|
|
|
@router.post("/bots/{bot_id}/reset")
async def reset_bot_route(
    bot_id: str,
    request: Request,
    confirm_name: str = Form(""),
    conn=Depends(get_conn),
):
    """Reset a bot (guarded by ``confirm_name``), then redirect to ``/bots``.

    A ``ValueError`` from ``reset_bot`` becomes an HTTP error carrying the
    same message. The status is 404 when the message contains "not found"
    (case-insensitive), and 400 otherwise.
    """
    # Imported at call time, presumably to avoid an import cycle -- confirm.
    from chat.services.reset import reset_bot

    try:
        reset_bot(conn, bot_id, confirm_name=confirm_name)
    except ValueError as err:
        detail = str(err)
        status = 404 if "not found" in detail.lower() else 400
        raise HTTPException(status_code=status, detail=detail)
    return RedirectResponse(url="/bots", status_code=303)
|