from __future__ import annotations import sqlite3 import threading from chat.db.connection import open_db def test_open_db_default_uses_check_same_thread_true(tmp_path): """Default open_db must reject cross-thread use (safe default).""" db = tmp_path / "t.db" captured: list[BaseException | None] = [] with open_db(db) as conn: conn.execute("CREATE TABLE t (x INTEGER)") def worker(): try: conn.execute("SELECT 1").fetchall() captured.append(None) except BaseException as e: # noqa: BLE001 captured.append(e) t = threading.Thread(target=worker) t.start() t.join() assert len(captured) == 1 err = captured[0] assert isinstance(err, sqlite3.ProgrammingError), ( f"expected sqlite3.ProgrammingError on cross-thread use, got {err!r}" ) def test_open_db_can_disable_check_same_thread(tmp_path): """open_db(check_same_thread=False) must allow cross-thread use.""" db = tmp_path / "t.db" captured: list[BaseException | None] = [] rows: list[object] = [] with open_db(db, check_same_thread=False) as conn: conn.execute("CREATE TABLE t (x INTEGER)") conn.execute("INSERT INTO t (x) VALUES (42)") def worker(): try: result = conn.execute("SELECT x FROM t").fetchall() rows.extend(result) captured.append(None) except BaseException as e: # noqa: BLE001 captured.append(e) t = threading.Thread(target=worker) t.start() t.join() assert captured == [None], f"worker thread raised: {captured}" assert rows == [(42,)]