diff --git a/chat/services/multi_state_update.py b/chat/services/multi_state_update.py new file mode 100644 index 0000000..f6eeb8e --- /dev/null +++ b/chat/services/multi_state_update.py @@ -0,0 +1,62 @@ +"""Multi-entity state-update coordinator (T40). + +Wraps single-pair compute_state_update to run state updates for ALL +directed pairs of present entities. With 3 present entities (you, host, +guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs. + +Calls run sequentially to respect Featherless's 2-connection cap (the +client-level semaphore would serialize them anyway, but doing it here +keeps the failure surface clean — a hung pair doesn't queue behind +itself). +""" + +from __future__ import annotations + +from chat.llm.client import LLMClient +from chat.services.state_update import StateUpdate, compute_state_update + + +async def compute_state_updates_for_present( + client: LLMClient, + *, + classifier_model: str, + present_ids: list[str], + present_names: dict[str, str], + personas: dict[str, str], + prior_edges: dict[tuple[str, str], dict], + recent_dialogue: list[dict], + timeout_s: float = 30.0, +) -> list[tuple[str, str, StateUpdate]]: + """Run compute_state_update for every directed pair (src != tgt) over + ``present_ids``. Returns list of ``(source_id, target_id, update)`` + tuples in the natural iteration order over ``present_ids x present_ids``. + + A single failing pair falls back to the schema-default StateUpdate + (zero deltas, empty facts) inside ``compute_state_update``; the batch + keeps going. + """ + out: list[tuple[str, str, StateUpdate]] = [] + for src in present_ids: + for tgt in present_ids: + if src == tgt: + continue + edge = prior_edges.get((src, tgt), {}) + update = await compute_state_update( + client, + model=classifier_model, + source_id=src, + target_id=tgt, + source_name=present_names.get(src, src), + source_persona=personas.get(src, "") or "", + target_name=present_names.get(tgt, tgt), + prior_affinity=int(edge.get("affinity", 50)), + prior_trust=int(edge.get("trust", 50)), + prior_summary=edge.get("summary", "") or "", + recent_dialogue=recent_dialogue, + timeout_s=timeout_s, + ) + out.append((src, tgt, update)) + return out + + +__all__ = ["compute_state_updates_for_present"] diff --git a/tests/test_multi_state_update.py b/tests/test_multi_state_update.py new file mode 100644 index 0000000..482fd5b --- /dev/null +++ b/tests/test_multi_state_update.py @@ -0,0 +1,147 @@ +"""Multi-entity state-update coordinator (T40). + +Wraps the single-pair :func:`compute_state_update` to run state updates +for ALL directed pairs of present entities. With 3 present entities +(you, host, guest) that's 6 directed pairs; with 2 (you, host) it's 2. + +Calls run sequentially to respect Featherless's 2-connection cap. +""" + +from __future__ import annotations + +import json + +import pytest + +from chat.llm.mock import MockLLMClient +from chat.services.multi_state_update import compute_state_updates_for_present +from chat.services.state_update import StateUpdate + + +def _canned_update(affinity: int, trust: int, facts: list[str] | None = None) -> str: + return json.dumps( + { + "affinity_delta": affinity, + "trust_delta": trust, + "knowledge_facts": facts or [], + } + ) + + +@pytest.mark.asyncio +async def test_two_entities_returns_two_updates(): + """you + bot_a -> 2 directed pairs (you->bot_a, bot_a->you).""" + canned = [ + _canned_update(2, 1, ["likes coffee"]), # you -> bot_a + _canned_update(1, 0, ["greets warmly"]), # bot_a -> you + ] + mock = MockLLMClient(canned=canned) + + results = await compute_state_updates_for_present( + mock, + classifier_model="x", + present_ids=["you", "bot_a"], + present_names={"you": "Me", "bot_a": "BotA"}, + personas={"you": "", "bot_a": "thoughtful"}, + prior_edges={ + ("you", "bot_a"): {"affinity": 50, "trust": 50, "summary": ""}, + ("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""}, + }, + recent_dialogue=[ + {"speaker": "you", "text": "hi"}, + {"speaker": "BotA", "text": "Hello!"}, + ], + ) + + assert len(results) == 2 + assert results[0][0] == "you" + assert results[0][1] == "bot_a" + assert isinstance(results[0][2], StateUpdate) + assert results[0][2].affinity_delta == 2 + assert results[0][2].trust_delta == 1 + assert results[0][2].knowledge_facts == ["likes coffee"] + + assert results[1][0] == "bot_a" + assert results[1][1] == "you" + assert isinstance(results[1][2], StateUpdate) + assert results[1][2].affinity_delta == 1 + assert results[1][2].trust_delta == 0 + assert results[1][2].knowledge_facts == ["greets warmly"] + + +@pytest.mark.asyncio +async def test_three_entities_returns_six_updates(): + """you + bot_a + bot_b -> 6 directed pairs (no self-pairs).""" + canned = [_canned_update(i, 0) for i in range(6)] + mock = MockLLMClient(canned=canned) + + results = await compute_state_updates_for_present( + mock, + classifier_model="x", + present_ids=["you", "bot_a", "bot_b"], + present_names={"you": "Me", "bot_a": "BotA", "bot_b": "BotB"}, + personas={"you": "", "bot_a": "thoughtful", "bot_b": "cheerful"}, + prior_edges={}, # all default to 50/50/"" + recent_dialogue=[{"speaker": "you", "text": "hello all"}], + ) + + assert len(results) == 6 + + pairs = [(src, tgt) for src, tgt, _ in results] + # No self-pairs. + assert all(src != tgt for src, tgt in pairs) + # All 6 directed combinations present. + expected = { + ("you", "bot_a"), + ("you", "bot_b"), + ("bot_a", "you"), + ("bot_a", "bot_b"), + ("bot_b", "you"), + ("bot_b", "bot_a"), + } + assert set(pairs) == expected + # Every entry is a StateUpdate. + assert all(isinstance(u, StateUpdate) for _, _, u in results) + + +@pytest.mark.asyncio +async def test_failure_in_one_pair_does_not_kill_batch(): + """First pair fails all 3 classify retries -> default; second parses OK.""" + canned = [ + # Pair 1 (you -> bot_a): 3 malformed responses -> default StateUpdate. + "bad", + "still bad", + "nope", + # Pair 2 (bot_a -> you): valid JSON. + _canned_update(3, 2, ["was warm"]), + ] + mock = MockLLMClient(canned=canned) + + results = await compute_state_updates_for_present( + mock, + classifier_model="x", + present_ids=["you", "bot_a"], + present_names={"you": "Me", "bot_a": "BotA"}, + personas={"you": "", "bot_a": "thoughtful"}, + prior_edges={ + ("you", "bot_a"): {"affinity": 60, "trust": 40, "summary": "some prior"}, + ("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""}, + }, + recent_dialogue=[{"speaker": "you", "text": "hi"}], + ) + + assert len(results) == 2 + + # First pair: default (zero-delta) StateUpdate. + src1, tgt1, update1 = results[0] + assert (src1, tgt1) == ("you", "bot_a") + assert update1.affinity_delta == 0 + assert update1.trust_delta == 0 + assert update1.knowledge_facts == [] + + # Second pair: parsed valid JSON. + src2, tgt2, update2 = results[1] + assert (src2, tgt2) == ("bot_a", "you") + assert update2.affinity_delta == 3 + assert update2.trust_delta == 2 + assert update2.knowledge_facts == ["was warm"]