diff --git a/chat/services/memory_write.py b/chat/services/memory_write.py index 3ca40c5..0a6f9b1 100644 --- a/chat/services/memory_write.py +++ b/chat/services/memory_write.py @@ -76,3 +76,103 @@ def record_turn_memory( ).fetchone() memory_id = row[0] if row else None return event_id, memory_id + + +def _write_one_memory( + conn: Connection, + *, + owner_id: str, + chat_id: str, + narrative_text: str, + witness_you: int, + witness_host: int, + witness_guest: int, + scene_id: int | None, + chat_clock_at: str | None, + source: str, + significance: int, +) -> tuple[int, int | None]: + """Append a single ``memory_written`` event for ``owner_id`` and return + ``(event_id, memory_id)`` for the projected row.""" + payload: dict = { + "owner_id": owner_id, + "chat_id": chat_id, + "pov_summary": narrative_text, + "witness_you": witness_you, + "witness_host": witness_host, + "witness_guest": witness_guest, + "source": source, + "reliability": 1.0, + "significance": significance, + "pinned": 0, + "auto_pinned": 0, + } + if scene_id is not None: + payload["scene_id"] = scene_id + if chat_clock_at is not None: + payload["chat_clock_at"] = chat_clock_at + + event_id = append_and_apply(conn, kind="memory_written", payload=payload) + row = conn.execute( + "SELECT id FROM memories " + "WHERE owner_id = ? AND chat_id = ? " + "ORDER BY id DESC LIMIT 1", + (owner_id, chat_id), + ).fetchone() + memory_id = row[0] if row else None + return event_id, memory_id + + +def record_turn_memory_for_present( + conn: Connection, + *, + chat_id: str, + host_bot_id: str, + guest_bot_id: str | None, + narrative_text: str, + scene_id: int | None = None, + chat_clock_at: str | None = None, + source: str = "direct", + significance: int = 1, +) -> dict[str, tuple[int, int | None]]: + """Write a ``memory_written`` event for each present bot witness. + + Host is always written. Guest is written iff ``guest_bot_id is not + None``. Witness flags are ``[you=1, host=1, guest=1]`` when a guest + is present, ``[you=1, host=1, guest=0]`` otherwise. + + Returns a mapping ``{bot_id: (event_id, memory_id)}`` so callers can + look up the freshly-projected memory id per owner without re-querying + the database. + """ + witness_guest = 1 if guest_bot_id is not None else 0 + + result: dict[str, tuple[int, int | None]] = {} + result[host_bot_id] = _write_one_memory( + conn, + owner_id=host_bot_id, + chat_id=chat_id, + narrative_text=narrative_text, + witness_you=1, + witness_host=1, + witness_guest=witness_guest, + scene_id=scene_id, + chat_clock_at=chat_clock_at, + source=source, + significance=significance, + ) + if guest_bot_id is not None: + result[guest_bot_id] = _write_one_memory( + conn, + owner_id=guest_bot_id, + chat_id=chat_id, + narrative_text=narrative_text, + witness_you=1, + witness_host=1, + witness_guest=1, + scene_id=scene_id, + chat_clock_at=chat_clock_at, + source=source, + significance=significance, + ) + return result diff --git a/tests/test_memory_write.py b/tests/test_memory_write.py index aa87610..00243b0 100644 --- a/tests/test_memory_write.py +++ b/tests/test_memory_write.py @@ -22,7 +22,7 @@ from chat.db.migrate import apply_migrations from chat.eventlog.log import append_event from chat.eventlog.projector import project from chat.llm.mock import MockLLMClient -from chat.services.memory_write import record_turn_memory +from chat.services.memory_write import record_turn_memory, record_turn_memory_for_present import chat.state.entities # noqa: F401 - register handlers import chat.state.memory # noqa: F401 import chat.state.world # noqa: F401 @@ -295,3 +295,152 @@ def test_post_turn_writes_memory_for_host_bot(client, tmp_path): assert w_guest == 0 assert source == "direct" assert sig == 1 + + +# --------------------------------------------------------------------------- +# T41: record_turn_memory_for_present — multi-witness helper. +# --------------------------------------------------------------------------- + + +def _seed_two_bots(db_path: Path) -> None: + """Author host + guest bots and create a two-bot chat.""" + with open_db(db_path) as conn: + for bot_id, name in (("bot_a", "BotA"), ("bot_b", "BotB")): + append_event( + conn, + kind="bot_authored", + payload={ + "id": bot_id, + "name": name, + "persona": "...", + "voice_samples": [], + "traits": [], + "backstory": "", + "initial_relationship_to_you": "", + "kickoff_prose": "", + }, + ) + append_event( + conn, + kind="chat_created", + payload={ + "id": "chat_ab", + "host_bot_id": "bot_a", + "guest_bot_id": "bot_b", + "initial_time": "2026-04-26T20:00:00+00:00", + "narrative_anchor": "Day 1", + "weather": "", + }, + ) + project(conn) + + +def test_record_for_present_no_guest_writes_single_memory_with_witness_1_1_0(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + _seed_minimal(db) + with open_db(db) as conn: + result = record_turn_memory_for_present( + conn, + chat_id="chat_bot_a", + host_bot_id="bot_a", + guest_bot_id=None, + narrative_text="BotA glances out the window.", + scene_id=None, + chat_clock_at="2026-04-26T20:00:00+00:00", + ) + + # Returned dict has only the host key. + assert set(result.keys()) == {"bot_a"} + eid_h, mid_h = result["bot_a"] + assert eid_h > 0 + assert mid_h is not None and mid_h > 0 + + rows = conn.execute( + "SELECT owner_id, witness_you, witness_host, witness_guest " + "FROM memories" + ).fetchall() + assert len(rows) == 1 + owner_id, w_you, w_host, w_guest = rows[0] + assert owner_id == "bot_a" + assert w_you == 1 + assert w_host == 1 + assert w_guest == 0 + + # Exactly one memory_written event was appended. + cur = conn.execute( + "SELECT COUNT(*) FROM event_log WHERE kind = 'memory_written'" + ) + assert cur.fetchone()[0] == 1 + + +def test_record_for_present_with_guest_writes_two_memories_with_witness_1_1_1(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + _seed_two_bots(db) + with open_db(db) as conn: + result = record_turn_memory_for_present( + conn, + chat_id="chat_ab", + host_bot_id="bot_a", + guest_bot_id="bot_b", + narrative_text="BotA and BotB share a glance.", + scene_id=None, + chat_clock_at="2026-04-26T20:00:00+00:00", + ) + + # Returned dict has both keys. + assert set(result.keys()) == {"bot_a", "bot_b"} + eid_h, mid_h = result["bot_a"] + eid_g, mid_g = result["bot_b"] + assert eid_h > 0 and eid_g > 0 + assert mid_h is not None and mid_h > 0 + assert mid_g is not None and mid_g > 0 + # Distinct event ids and memory ids. + assert eid_h != eid_g + assert mid_h != mid_g + + rows = conn.execute( + "SELECT owner_id, witness_you, witness_host, witness_guest " + "FROM memories ORDER BY owner_id" + ).fetchall() + assert len(rows) == 2 + owners = {r[0] for r in rows} + assert owners == {"bot_a", "bot_b"} + # All rows should have witness mask [1, 1, 1]. + for _owner, w_you, w_host, w_guest in rows: + assert w_you == 1 + assert w_host == 1 + assert w_guest == 1 + + # Two memory_written events were appended. + cur = conn.execute( + "SELECT COUNT(*) FROM event_log WHERE kind = 'memory_written'" + ) + assert cur.fetchone()[0] == 2 + + +def test_record_for_present_dict_keys_match(tmp_path): + db = tmp_path / "t.db" + apply_migrations(db) + _seed_two_bots(db) + with open_db(db) as conn: + # No guest: keys == {host_bot_id}. + result_no_guest = record_turn_memory_for_present( + conn, + chat_id="chat_ab", + host_bot_id="bot_a", + guest_bot_id=None, + narrative_text="Just BotA's POV.", + ) + assert set(result_no_guest.keys()) == {"bot_a"} + + # With guest: keys == {host_bot_id, guest_bot_id}. + result_with_guest = record_turn_memory_for_present( + conn, + chat_id="chat_ab", + host_bot_id="bot_a", + guest_bot_id="bot_b", + narrative_text="Both bots witness this.", + ) + assert set(result_with_guest.keys()) == {"bot_a", "bot_b"}