132 lines
4.2 KiB
Python
132 lines
4.2 KiB
Python
"""Tests for the branching service (T94, Phase 4)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from chat.db.connection import open_db
|
|
from chat.db.migrate import apply_migrations
|
|
from chat.eventlog.log import append_and_apply
|
|
import chat.state.branches # noqa: F401 registers handlers
|
|
from chat.services.branching import (
|
|
branch_from_event,
|
|
list_branches_with_metadata,
|
|
switch_active_branch,
|
|
)
|
|
from chat.state.branches import active_branch, get_branch
|
|
|
|
|
|
def _seed_event(conn) -> int:
    """Append one harmless ``user_turn`` event and return its event_log id.

    Tests need a genuine ``event_log`` row to use as a branch origin.
    ``user_turn`` is a transcript-only kind with no registered projector
    handler, so ``append_and_apply`` leaves projected state untouched no
    matter which handler modules the rest of the suite has imported.
    """
    seed_payload = {"chat_id": "c1", "text": "hi"}
    return append_and_apply(conn, kind="user_turn", payload=seed_payload)
|
|
|
|
|
|
def test_branch_from_event_creates_branch_via_event(tmp_path):
    """Forking from a seeded event records a new, inactive branch for chat c1."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)

    with open_db(db_path) as conn:
        origin = _seed_event(conn)
        branch_id = branch_from_event(
            conn, name="experiment", origin_event_id=origin, chat_id="c1"
        )
        assert isinstance(branch_id, int)
        assert branch_id > 0

        branch = get_branch(conn, "experiment")
        assert branch is not None
        assert branch["id"] == branch_id
        # A freshly created branch has its head pinned to the origin event.
        assert branch["origin_event_id"] == origin
        assert branch["head_event_id"] == origin
        assert branch["chat_id"] == "c1"
        assert branch["is_active"] is False

        # Exactly one branch_created event must have been written to event_log.
        created_count = conn.execute(
            "SELECT COUNT(*) FROM event_log WHERE kind = 'branch_created'"
        ).fetchone()[0]
        assert created_count == 1
|
|
|
|
|
|
def test_branch_from_event_duplicate_name_raises(tmp_path):
    """Creating a second branch under an existing name raises ValueError."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)

    with open_db(db_path) as conn:
        origin = _seed_event(conn)
        branch_kwargs = {"name": "dup", "origin_event_id": origin}

        # The first creation succeeds; the identical second one must fail.
        branch_from_event(conn, **branch_kwargs)
        with pytest.raises(ValueError, match="already exists"):
            branch_from_event(conn, **branch_kwargs)
|
|
|
|
|
|
def test_branch_from_event_invalid_origin_raises(tmp_path):
    """Forking from an event id that is not in event_log raises ValueError."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)

    # Nothing is seeded in this test, so this id is expected not to exist.
    missing_event_id = 99999
    with open_db(db_path) as conn, pytest.raises(ValueError, match="does not exist"):
        branch_from_event(conn, name="ghost", origin_event_id=missing_event_id)
|
|
|
|
|
|
def test_switch_active_branch_changes_active(tmp_path):
    """Switching activates the named branch; switching back restores main."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)

    with open_db(db_path) as conn:
        origin = _seed_event(conn)
        branch_from_event(conn, name="experiment", origin_event_id=origin)

        # Round-trip: activate the new branch, then return to main.
        for target in ("experiment", "main"):
            switch_active_branch(conn, name=target)
            current = active_branch(conn)
            assert current is not None
            assert current["name"] == target
|
|
|
|
|
|
def test_switch_active_branch_unknown_name_raises(tmp_path):
    """Activating a branch name that was never created raises ValueError."""
    db_path = tmp_path / "t.db"
    apply_migrations(db_path)

    with open_db(db_path) as conn, pytest.raises(ValueError, match="does not exist"):
        switch_active_branch(conn, name="nope")
|
|
|
|
|
|
def test_list_branches_with_metadata_includes_event_count(tmp_path):
    """event_count reflects the inclusive origin..head span of each branch.

    Branch endpoints come from the ids actually returned by the seeding
    helper instead of hard-coded literals, so the test does not silently
    depend on event_log ids starting at 1 in a freshly migrated database.
    """
    db = tmp_path / "t.db"
    apply_migrations(db)
    with open_db(db) as conn:
        seeded_ids = [_seed_event(conn) for _ in range(15)]
        # The 10th and 15th seeded events bound an inclusive span of six.
        origin_id, head_id = seeded_ids[9], seeded_ids[14]

        # Create the branch at the origin, then advance its head.
        branch_from_event(conn, name="exp", origin_event_id=origin_id)
        append_and_apply(
            conn,
            kind="branch_head_updated",
            payload={"name": "exp", "head_event_id": head_id},
        )

        rows = {b["name"]: b for b in list_branches_with_metadata(conn)}

        # main: bootstrap state — origin=0, head=0 — event_count == 0.
        assert rows["main"]["event_count"] == 0
        # exp: six seeded events from origin to head, inclusive.
        assert rows["exp"]["origin_event_id"] == origin_id
        assert rows["exp"]["head_event_id"] == head_id
        assert rows["exp"]["event_count"] == 6
|