refactor: open_db with check_same_thread parameter (T68)

This commit is contained in:
Joseph Doherty
2026-04-26 17:05:29 -04:00
parent e05f28e9d5
commit 994728b5ed
3 changed files with 61 additions and 11 deletions
+2 -2
View File
@@ -5,9 +5,9 @@ from pathlib import Path
@contextmanager @contextmanager
def open_db(path: Path): def open_db(path: Path, *, check_same_thread: bool = True):
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(path) conn = sqlite3.connect(path, check_same_thread=check_same_thread)
conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON") conn.execute("PRAGMA foreign_keys=ON")
try: try:
+2 -9
View File
@@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import sqlite3
from pathlib import Path from pathlib import Path
from fastapi import APIRouter, Depends, Form, HTTPException, Request from fastapi import APIRouter, Depends, Form, HTTPException, Request
from fastapi.responses import RedirectResponse, HTMLResponse from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from chat.db.connection import open_db
from chat.eventlog.log import append_event from chat.eventlog.log import append_event
from chat.eventlog.projector import project from chat.eventlog.projector import project
from chat.state.entities import list_bots from chat.state.entities import list_bots
@@ -19,15 +19,8 @@ REQUIRED_FIELDS = ("id", "name", "persona", "initial_relationship_to_you", "kick
def get_conn(request: Request): def get_conn(request: Request):
settings = request.app.state.settings settings = request.app.state.settings
db_path: Path = settings.db_path db_path: Path = settings.db_path
db_path.parent.mkdir(parents=True, exist_ok=True) with open_db(db_path, check_same_thread=False) as conn:
conn = sqlite3.connect(db_path, check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
try:
yield conn yield conn
conn.commit()
finally:
conn.close()
def _split_voice_samples(text: str) -> list[str]: def _split_voice_samples(text: str) -> list[str]:
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
import sqlite3
import threading
from chat.db.connection import open_db
def test_open_db_default_uses_check_same_thread_true(tmp_path):
"""Default open_db must reject cross-thread use (safe default)."""
db = tmp_path / "t.db"
captured: list[BaseException | None] = []
with open_db(db) as conn:
conn.execute("CREATE TABLE t (x INTEGER)")
def worker():
try:
conn.execute("SELECT 1").fetchall()
captured.append(None)
except BaseException as e: # noqa: BLE001
captured.append(e)
t = threading.Thread(target=worker)
t.start()
t.join()
assert len(captured) == 1
err = captured[0]
assert isinstance(err, sqlite3.ProgrammingError), (
f"expected sqlite3.ProgrammingError on cross-thread use, got {err!r}"
)
def test_open_db_can_disable_check_same_thread(tmp_path):
"""open_db(check_same_thread=False) must allow cross-thread use."""
db = tmp_path / "t.db"
captured: list[BaseException | None] = []
rows: list[object] = []
with open_db(db, check_same_thread=False) as conn:
conn.execute("CREATE TABLE t (x INTEGER)")
conn.execute("INSERT INTO t (x) VALUES (42)")
def worker():
try:
result = conn.execute("SELECT x FROM t").fetchall()
rows.extend(result)
captured.append(None)
except BaseException as e: # noqa: BLE001
captured.append(e)
t = threading.Thread(target=worker)
t.start()
t.join()
assert captured == [None], f"worker thread raised: {captured}"
assert rows == [(42,)]