58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
from __future__ import annotations
|
|
import sqlite3
|
|
import threading
|
|
|
|
from chat.db.connection import open_db
|
|
|
|
|
|
def test_open_db_default_uses_check_same_thread_true(tmp_path):
    """Default open_db must reject cross-thread use (safe default).

    A connection opened on the test thread is used from a worker thread;
    the worker's ``execute`` must fail with sqlite3's same-thread check.

    The message is checked in addition to the exception type because
    sqlite3 also raises ``ProgrammingError`` for operations on a closed
    connection; the type alone would let a closed (rather than
    thread-guarded) connection satisfy this test.
    """
    db = tmp_path / "t.db"
    # Outcome reported by the worker: None on success, else the exception.
    captured: list[BaseException | None] = []

    with open_db(db) as conn:
        conn.execute("CREATE TABLE t (x INTEGER)")

        def worker():
            try:
                conn.execute("SELECT 1").fetchall()
                captured.append(None)
            except BaseException as e:  # noqa: BLE001
                # Record instead of raising: an exception inside a thread
                # is not propagated to the test's thread.
                captured.append(e)

        t = threading.Thread(target=worker)
        t.start()
        # Bounded wait so a wedged worker fails the test rather than
        # hanging the whole suite.
        t.join(timeout=10)
        assert not t.is_alive(), "worker thread did not finish"

    assert len(captured) == 1
    err = captured[0]
    assert isinstance(err, sqlite3.ProgrammingError), (
        f"expected sqlite3.ProgrammingError on cross-thread use, got {err!r}"
    )
    # Distinguish the same-thread guard from e.g. a closed-connection error.
    assert "thread" in str(err), (
        f"expected the same-thread check to fire, got {err!r}"
    )
|
|
|
|
|
|
def test_open_db_can_disable_check_same_thread(tmp_path):
    """open_db(check_same_thread=False) must allow cross-thread use.

    A row inserted on the test thread must be readable through the same
    connection from a worker thread, and the worker must not raise.
    """
    db = tmp_path / "t.db"
    # Outcome reported by the worker: None on success, else the exception.
    captured: list[BaseException | None] = []
    # Rows fetched by the worker thread.
    rows: list[object] = []

    with open_db(db, check_same_thread=False) as conn:
        conn.execute("CREATE TABLE t (x INTEGER)")
        conn.execute("INSERT INTO t (x) VALUES (42)")

        def worker():
            try:
                result = conn.execute("SELECT x FROM t").fetchall()
                rows.extend(result)
                captured.append(None)
            except BaseException as e:  # noqa: BLE001
                # Record instead of raising: an exception inside a thread
                # is not propagated to the test's thread.
                captured.append(e)

        t = threading.Thread(target=worker)
        t.start()
        # Bounded wait so a wedged worker fails the test rather than
        # hanging the whole suite.
        t.join(timeout=10)
        assert not t.is_alive(), "worker thread did not finish"

    assert captured == [None], f"worker thread raised: {captured}"
    # NOTE(review): assumes open_db keeps sqlite3's default tuple rows
    # (no custom row_factory) — confirm against chat.db.connection.
    assert rows == [(42,)]
|