refactor: open_db with check_same_thread parameter (T68)
This commit is contained in:
@@ -5,9 +5,9 @@ from pathlib import Path
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def open_db(path: Path):
|
def open_db(path: Path, *, check_same_thread: bool = True):
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(path)
|
conn = sqlite3.connect(path, check_same_thread=check_same_thread)
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
conn.execute("PRAGMA foreign_keys=ON")
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
try:
|
try:
|
||||||
|
|||||||
+2
-9
@@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import sqlite3
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request
|
from fastapi import APIRouter, Depends, Form, HTTPException, Request
|
||||||
from fastapi.responses import RedirectResponse, HTMLResponse
|
from fastapi.responses import RedirectResponse, HTMLResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from chat.db.connection import open_db
|
||||||
from chat.eventlog.log import append_event
|
from chat.eventlog.log import append_event
|
||||||
from chat.eventlog.projector import project
|
from chat.eventlog.projector import project
|
||||||
from chat.state.entities import list_bots
|
from chat.state.entities import list_bots
|
||||||
@@ -19,15 +19,8 @@ REQUIRED_FIELDS = ("id", "name", "persona", "initial_relationship_to_you", "kick
|
|||||||
def get_conn(request: Request):
|
def get_conn(request: Request):
|
||||||
settings = request.app.state.settings
|
settings = request.app.state.settings
|
||||||
db_path: Path = settings.db_path
|
db_path: Path = settings.db_path
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
with open_db(db_path, check_same_thread=False) as conn:
|
||||||
conn = sqlite3.connect(db_path, check_same_thread=False)
|
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
|
||||||
conn.execute("PRAGMA foreign_keys=ON")
|
|
||||||
try:
|
|
||||||
yield conn
|
yield conn
|
||||||
conn.commit()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def _split_voice_samples(text: str) -> list[str]:
|
def _split_voice_samples(text: str) -> list[str]:
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from chat.db.connection import open_db
|
||||||
|
|
||||||
|
|
||||||
|
def test_open_db_default_uses_check_same_thread_true(tmp_path):
|
||||||
|
"""Default open_db must reject cross-thread use (safe default)."""
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
captured: list[BaseException | None] = []
|
||||||
|
|
||||||
|
with open_db(db) as conn:
|
||||||
|
conn.execute("CREATE TABLE t (x INTEGER)")
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
try:
|
||||||
|
conn.execute("SELECT 1").fetchall()
|
||||||
|
captured.append(None)
|
||||||
|
except BaseException as e: # noqa: BLE001
|
||||||
|
captured.append(e)
|
||||||
|
|
||||||
|
t = threading.Thread(target=worker)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(captured) == 1
|
||||||
|
err = captured[0]
|
||||||
|
assert isinstance(err, sqlite3.ProgrammingError), (
|
||||||
|
f"expected sqlite3.ProgrammingError on cross-thread use, got {err!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_open_db_can_disable_check_same_thread(tmp_path):
|
||||||
|
"""open_db(check_same_thread=False) must allow cross-thread use."""
|
||||||
|
db = tmp_path / "t.db"
|
||||||
|
captured: list[BaseException | None] = []
|
||||||
|
rows: list[object] = []
|
||||||
|
|
||||||
|
with open_db(db, check_same_thread=False) as conn:
|
||||||
|
conn.execute("CREATE TABLE t (x INTEGER)")
|
||||||
|
conn.execute("INSERT INTO t (x) VALUES (42)")
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
try:
|
||||||
|
result = conn.execute("SELECT x FROM t").fetchall()
|
||||||
|
rows.extend(result)
|
||||||
|
captured.append(None)
|
||||||
|
except BaseException as e: # noqa: BLE001
|
||||||
|
captured.append(e)
|
||||||
|
|
||||||
|
t = threading.Thread(target=worker)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert captured == [None], f"worker thread raised: {captured}"
|
||||||
|
assert rows == [(42,)]
|
||||||
Reference in New Issue
Block a user