From 994728b5ed76571f2f7c4c73197cc2e0f1b3ff74 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 17:05:29 -0400 Subject: [PATCH] refactor: open_db with check_same_thread parameter (T68) --- chat/db/connection.py | 4 +-- chat/web/bots.py | 11 ++----- tests/test_open_db_threading.py | 57 +++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 tests/test_open_db_threading.py diff --git a/chat/db/connection.py b/chat/db/connection.py index ad21a01..f293aca 100644 --- a/chat/db/connection.py +++ b/chat/db/connection.py @@ -5,9 +5,9 @@ from pathlib import Path @contextmanager -def open_db(path: Path): +def open_db(path: Path, *, check_same_thread: bool = True): path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(path) + conn = sqlite3.connect(path, check_same_thread=check_same_thread) conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA foreign_keys=ON") try: diff --git a/chat/web/bots.py b/chat/web/bots.py index d9666f3..027130b 100644 --- a/chat/web/bots.py +++ b/chat/web/bots.py @@ -1,10 +1,10 @@ from __future__ import annotations -import sqlite3 from pathlib import Path from fastapi import APIRouter, Depends, Form, HTTPException, Request from fastapi.responses import RedirectResponse, HTMLResponse from fastapi.templating import Jinja2Templates +from chat.db.connection import open_db from chat.eventlog.log import append_event from chat.eventlog.projector import project from chat.state.entities import list_bots @@ -19,15 +19,8 @@ REQUIRED_FIELDS = ("id", "name", "persona", "initial_relationship_to_you", "kick def get_conn(request: Request): settings = request.app.state.settings db_path: Path = settings.db_path - db_path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(db_path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA foreign_keys=ON") - try: + with open_db(db_path, check_same_thread=False) as conn: yield conn - conn.commit() - finally: - conn.close() def _split_voice_samples(text: str) -> list[str]: diff --git a/tests/test_open_db_threading.py b/tests/test_open_db_threading.py new file mode 100644 index 0000000..c7756f9 --- /dev/null +++ b/tests/test_open_db_threading.py @@ -0,0 +1,57 @@ +from __future__ import annotations +import sqlite3 +import threading + +from chat.db.connection import open_db + + +def test_open_db_default_uses_check_same_thread_true(tmp_path): + """Default open_db must reject cross-thread use (safe default).""" + db = tmp_path / "t.db" + captured: list[BaseException | None] = [] + + with open_db(db) as conn: + conn.execute("CREATE TABLE t (x INTEGER)") + + def worker(): + try: + conn.execute("SELECT 1").fetchall() + captured.append(None) + except BaseException as e: # noqa: BLE001 + captured.append(e) + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert len(captured) == 1 + err = captured[0] + assert isinstance(err, sqlite3.ProgrammingError), ( + f"expected sqlite3.ProgrammingError on cross-thread use, got {err!r}" + ) + + +def test_open_db_can_disable_check_same_thread(tmp_path): + """open_db(check_same_thread=False) must allow cross-thread use.""" + db = tmp_path / "t.db" + captured: list[BaseException | None] = [] + rows: list[object] = [] + + with open_db(db, check_same_thread=False) as conn: + conn.execute("CREATE TABLE t (x INTEGER)") + conn.execute("INSERT INTO t (x) VALUES (42)") + + def worker(): + try: + result = conn.execute("SELECT x FROM t").fetchall() + rows.extend(result) + captured.append(None) + except BaseException as e: # noqa: BLE001 + captured.append(e) + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert captured == [None], f"worker thread raised: {captured}" + assert rows == [(42,)]