From 296e8fddddb01439520e1c44aef4ee4f9f5cb11a Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 27 Apr 2026 02:35:58 -0400 Subject: [PATCH] feat: branching service (branch_from_event + switch + metadata) (T94) --- chat/services/branching.py | 107 ++++++++++++++++++++++++++++++ tests/test_branching.py | 131 +++++++++++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+) create mode 100644 chat/services/branching.py create mode 100644 tests/test_branching.py diff --git a/chat/services/branching.py b/chat/services/branching.py new file mode 100644 index 0000000..abf6ff7 --- /dev/null +++ b/chat/services/branching.py @@ -0,0 +1,107 @@ +"""Branching service (T94, Phase 4). + +Wraps branches state with validation + event emission. Phase 4 ships +the data model and creation/switching APIs; the read-side filter +(event readers consulting is_active) is a Phase 4.5+ follow-up — for +now branches are metadata-only and the existing event readers remain +branch-agnostic. The drawer UI (T98) drives create/switch via these +helpers. +""" + +from __future__ import annotations +from sqlite3 import Connection + +from chat.eventlog.log import append_and_apply +from chat.state.branches import get_branch, list_branches, active_branch # noqa: F401 + + +def branch_from_event( + conn: Connection, + *, + name: str, + origin_event_id: int, + chat_id: str | None = None, +) -> int: + """Create a new named branch forking from origin_event_id. + + Emits a branch_created event. Returns the new branch's row id. + Raises ValueError if name already exists or origin_event_id doesn't + correspond to a real event.""" + if not name or not name.strip(): + raise ValueError("branch name must be non-empty") + + if get_branch(conn, name) is not None: + raise ValueError(f"branch {name!r} already exists") + + # Validate origin_event_id is a real event id (or 0 for the bootstrap case + # which only main uses). + if origin_event_id < 0: + raise ValueError(f"origin_event_id must be >= 0, got {origin_event_id}") + if origin_event_id > 0: + row = conn.execute( + "SELECT 1 FROM event_log WHERE id = ?", (origin_event_id,) + ).fetchone() + if row is None: + raise ValueError( + f"origin_event_id {origin_event_id} does not exist in event_log" + ) + + append_and_apply( + conn, + kind="branch_created", + payload={ + "name": name, + "origin_event_id": origin_event_id, + "head_event_id": origin_event_id, # head starts at origin + "chat_id": chat_id, + }, + ) + + branch = get_branch(conn, name) + if branch is None: + # Should be unreachable if append_and_apply worked. + raise RuntimeError(f"branch {name!r} not found after creation") + return branch["id"] + + +def switch_active_branch(conn: Connection, *, name: str) -> None: + """Make the named branch active. Emits branch_switched.""" + if get_branch(conn, name) is None: + raise ValueError(f"branch {name!r} does not exist") + + append_and_apply( + conn, + kind="branch_switched", + payload={"name": name}, + ) + + +def list_branches_with_metadata( + conn: Connection, chat_id: str | None = None +) -> list[dict]: + """List branches with computed event_count metadata. + + event_count = head_event_id - origin_event_id + 1 (when both are set) + OR head_event_id (when origin is 0, e.g., main branch) + OR 0 (when head <= origin, which is the bootstrap state) + """ + branches = list_branches(conn, chat_id) + enriched = [] + for b in branches: + origin = b["origin_event_id"] + head = b["head_event_id"] + if head < origin: + event_count = 0 + elif origin == 0: + event_count = head + else: + event_count = head - origin + 1 + enriched.append({**b, "event_count": event_count}) + return enriched + + +__all__ = [ + "branch_from_event", + "switch_active_branch", + "list_branches_with_metadata", +] diff --git a/tests/test_branching.py b/tests/test_branching.py new file mode 100644 index 0000000..610bb2e --- /dev/null +++ b/tests/test_branching.py @@ -0,0 +1,131 @@ +"""Tests for the branching service (T94, Phase 4).""" + +from __future__ import annotations + +import pytest + +from chat.db.connection import open_db +from chat.db.migrate import apply_migrations +from chat.eventlog.log import append_and_apply +import chat.state.branches # noqa: F401 registers handlers +from chat.services.branching import ( + branch_from_event, + list_branches_with_metadata, + switch_active_branch, +) +from chat.state.branches import active_branch, get_branch + + +def _seed_event(conn) -> int: + """Append a benign event so we have a real event_log row to fork from. + + ``user_turn`` is a transcript-only kind with no registered projector + handler, so ``append_and_apply`` is a clean no-op on the projector + side regardless of what other handlers are imported by the suite. + """ + return append_and_apply( + conn, + kind="user_turn", + payload={"chat_id": "c1", "text": "hi"}, + ) + + +def test_branch_from_event_creates_branch_via_event(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + seed_id = _seed_event(conn) + + new_id = branch_from_event( + conn, + name="experiment", + origin_event_id=seed_id, + chat_id="c1", + ) + assert isinstance(new_id, int) and new_id > 0 + + b = get_branch(conn, "experiment") + assert b is not None + assert b["id"] == new_id + assert b["origin_event_id"] == seed_id + assert b["head_event_id"] == seed_id + assert b["chat_id"] == "c1" + assert b["is_active"] is False + + # branch_created event landed in event_log + row = conn.execute( + "SELECT COUNT(*) FROM event_log WHERE kind = 'branch_created'" + ).fetchone() + assert row[0] == 1 + + +def test_branch_from_event_duplicate_name_raises(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + seed_id = _seed_event(conn) + branch_from_event(conn, name="dup", origin_event_id=seed_id) + + with pytest.raises(ValueError, match="already exists"): + branch_from_event(conn, name="dup", origin_event_id=seed_id) + + +def test_branch_from_event_invalid_origin_raises(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + with pytest.raises(ValueError, match="does not exist"): + branch_from_event(conn, name="ghost", origin_event_id=99999) + + +def test_switch_active_branch_changes_active(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + seed_id = _seed_event(conn) + branch_from_event(conn, name="experiment", origin_event_id=seed_id) + + switch_active_branch(conn, name="experiment") + active = active_branch(conn) + assert active is not None + assert active["name"] == "experiment" + + # Switch back to main. + switch_active_branch(conn, name="main") + active2 = active_branch(conn) + assert active2 is not None + assert active2["name"] == "main" + + +def test_switch_active_branch_unknown_name_raises(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + with pytest.raises(ValueError, match="does not exist"): + switch_active_branch(conn, name="nope") + + +def test_list_branches_with_metadata_includes_event_count(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + with open_db(db) as conn: + # Seed enough events to cover origin=10 and head=15. + for _ in range(15): + _seed_event(conn) + + # Create the branch at origin=10, then bump its head to 15. + branch_from_event(conn, name="exp", origin_event_id=10) + append_and_apply( + conn, + kind="branch_head_updated", + payload={"name": "exp", "head_event_id": 15}, + ) + + rows = {b["name"]: b for b in list_branches_with_metadata(conn)} + + # main: bootstrap state — origin=0, head=0 — event_count == 0. + assert rows["main"]["event_count"] == 0 + # exp: origin=10, head=15 — event_count == 6 (inclusive). + assert rows["exp"]["origin_event_id"] == 10 + assert rows["exp"]["head_event_id"] == 15 + assert rows["exp"]["event_count"] == 6