Files
chat/tests/test_multi_state_update.py
T
2026-04-26 15:51:58 -04:00

148 lines
4.8 KiB
Python

"""Multi-entity state-update coordinator (T40).
Wraps the single-pair :func:`compute_state_update` to run state updates
for ALL directed pairs of present entities. With 3 present entities
(you, host, guest) that's 6 directed pairs; with 2 (you, host) it's 2.
Calls run sequentially to respect Featherless's 2-connection cap.
"""
from __future__ import annotations
import json
import pytest
from chat.llm.mock import MockLLMClient
from chat.services.multi_state_update import compute_state_updates_for_present
from chat.services.state_update import StateUpdate
def _canned_update(affinity: int, trust: int, facts: list[str] | None = None) -> str:
return json.dumps(
{
"affinity_delta": affinity,
"trust_delta": trust,
"knowledge_facts": facts or [],
}
)
@pytest.mark.asyncio
async def test_two_entities_returns_two_updates():
    """you + bot_a -> 2 directed pairs (you->bot_a, bot_a->you).

    Expects results as ``(source_id, target_id, StateUpdate)`` triples, with
    the canned classifier responses applied in pair order: the first canned
    response to ``you -> bot_a`` and the second to ``bot_a -> you``.
    """
    canned = [
        _canned_update(2, 1, ["likes coffee"]),  # you -> bot_a
        _canned_update(1, 0, ["greets warmly"]),  # bot_a -> you
    ]
    mock = MockLLMClient(canned=canned)
    results = await compute_state_updates_for_present(
        mock,
        classifier_model="x",
        present_ids=["you", "bot_a"],
        present_names={"you": "Me", "bot_a": "BotA"},
        personas={"you": "", "bot_a": "thoughtful"},
        prior_edges={
            ("you", "bot_a"): {"affinity": 50, "trust": 50, "summary": ""},
            ("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
        },
        recent_dialogue=[
            {"speaker": "you", "text": "hi"},
            {"speaker": "BotA", "text": "Hello!"},
        ],
    )
    assert len(results) == 2

    # Unpack each entry so a result that is not exactly a
    # (source, target, update) triple fails loudly instead of being indexed.
    src1, tgt1, update1 = results[0]
    assert (src1, tgt1) == ("you", "bot_a")
    assert isinstance(update1, StateUpdate)
    assert update1.affinity_delta == 2
    assert update1.trust_delta == 1
    assert update1.knowledge_facts == ["likes coffee"]

    src2, tgt2, update2 = results[1]
    assert (src2, tgt2) == ("bot_a", "you")
    assert isinstance(update2, StateUpdate)
    assert update2.affinity_delta == 1
    assert update2.trust_delta == 0
    assert update2.knowledge_facts == ["greets warmly"]
@pytest.mark.asyncio
async def test_three_entities_returns_six_updates():
    """you + bot_a + bot_b -> 6 directed pairs (no self-pairs).

    Each canned classifier response carries a distinct ``affinity_delta``
    (0..5), so the test can also check that every pair consumed its own
    response exactly once.
    """
    canned = [_canned_update(i, 0) for i in range(6)]
    mock = MockLLMClient(canned=canned)
    results = await compute_state_updates_for_present(
        mock,
        classifier_model="x",
        present_ids=["you", "bot_a", "bot_b"],
        present_names={"you": "Me", "bot_a": "BotA", "bot_b": "BotB"},
        personas={"you": "", "bot_a": "thoughtful", "bot_b": "cheerful"},
        prior_edges={},  # all default to 50/50/""
        recent_dialogue=[{"speaker": "you", "text": "hello all"}],
    )
    assert len(results) == 6
    pairs = [(src, tgt) for src, tgt, _ in results]
    # No self-pairs.
    assert all(src != tgt for src, tgt in pairs)
    # All 6 directed combinations present.
    expected = {
        ("you", "bot_a"),
        ("you", "bot_b"),
        ("bot_a", "you"),
        ("bot_a", "bot_b"),
        ("bot_b", "you"),
        ("bot_b", "bot_a"),
    }
    assert set(pairs) == expected
    # Every entry is a StateUpdate.
    assert all(isinstance(u, StateUpdate) for _, _, u in results)
    # Every canned delta shows up exactly once: no response was dropped,
    # duplicated, or shared between pairs. Order is deliberately not checked.
    # NOTE(review): assumes StateUpdate does not clamp deltas below 5 — confirm.
    assert sorted(u.affinity_delta for _, _, u in results) == list(range(6))
@pytest.mark.asyncio
async def test_failure_in_one_pair_does_not_kill_batch():
    """One unparseable pair falls back to a default; the rest still succeed.

    The first pair (you -> bot_a) is fed three malformed responses, which the
    test expects to exhaust the classify retries and yield a zero-delta
    default StateUpdate. The second pair (bot_a -> you) gets valid JSON and
    must still be parsed, so one bad pair does not abort the batch.
    """
    # you -> bot_a: one malformed reply per classify retry.
    malformed = ["bad", "still bad", "nope"]
    # bot_a -> you: a well-formed reply.
    valid = _canned_update(3, 2, ["was warm"])
    mock = MockLLMClient(canned=[*malformed, valid])

    results = await compute_state_updates_for_present(
        mock,
        classifier_model="x",
        present_ids=["you", "bot_a"],
        present_names={"you": "Me", "bot_a": "BotA"},
        personas={"you": "", "bot_a": "thoughtful"},
        prior_edges={
            ("you", "bot_a"): {"affinity": 60, "trust": 40, "summary": "some prior"},
            ("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
        },
        recent_dialogue=[{"speaker": "you", "text": "hi"}],
    )
    assert len(results) == 2
    (src1, tgt1, update1), (src2, tgt2, update2) = results

    # Failed pair: default update with no deltas and no facts.
    assert (src1, tgt1) == ("you", "bot_a")
    assert (update1.affinity_delta, update1.trust_delta) == (0, 0)
    assert update1.knowledge_facts == []

    # Healthy pair: values taken straight from the valid JSON.
    assert (src2, tgt2) == ("bot_a", "you")
    assert (update2.affinity_delta, update2.trust_delta) == (3, 2)
    assert update2.knowledge_facts == ["was warm"]