148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
"""Multi-entity state-update coordinator (T40).
|
|
|
|
Wraps the single-pair :func:`compute_state_update` to run state updates
|
|
for ALL directed pairs of present entities. With 3 present entities
|
|
(you, host, guest) that's 6 directed pairs; with 2 (you, host) it's 2.
|
|
|
|
Calls run sequentially to respect Featherless's 2-connection cap.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
import pytest
|
|
|
|
from chat.llm.mock import MockLLMClient
|
|
from chat.services.multi_state_update import compute_state_updates_for_present
|
|
from chat.services.state_update import StateUpdate
|
|
|
|
|
|
def _canned_update(affinity: int, trust: int, facts: list[str] | None = None) -> str:
    """Serialize a canned classifier reply as a JSON string.

    The payload carries ``affinity_delta``, ``trust_delta`` and
    ``knowledge_facts``; a missing (or empty) ``facts`` is emitted as ``[]``.
    """
    payload = {
        "affinity_delta": affinity,
        "trust_delta": trust,
        "knowledge_facts": facts if facts else [],
    }
    return json.dumps(payload)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_two_entities_returns_two_updates():
    """you + bot_a -> one update per direction: (you->bot_a), (bot_a->you)."""
    # One row per expected result, in the order the results are checked:
    # (source, target, affinity_delta, trust_delta, knowledge_facts).
    expected = [
        ("you", "bot_a", 2, 1, ["likes coffee"]),
        ("bot_a", "you", 1, 0, ["greets warmly"]),
    ]
    # The canned replies are built from the same rows, in the same order.
    llm = MockLLMClient(
        canned=[_canned_update(aff, tr, facts) for _, _, aff, tr, facts in expected]
    )

    results = await compute_state_updates_for_present(
        llm,
        classifier_model="x",
        present_ids=["you", "bot_a"],
        present_names={"you": "Me", "bot_a": "BotA"},
        personas={"you": "", "bot_a": "thoughtful"},
        prior_edges={
            ("you", "bot_a"): {"affinity": 50, "trust": 50, "summary": ""},
            ("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
        },
        recent_dialogue=[
            {"speaker": "you", "text": "hi"},
            {"speaker": "BotA", "text": "Hello!"},
        ],
    )

    assert len(results) == 2
    for entry, (source, target, affinity, trust, facts) in zip(results, expected):
        assert entry[0] == source
        assert entry[1] == target
        update = entry[2]
        assert isinstance(update, StateUpdate)
        assert update.affinity_delta == affinity
        assert update.trust_delta == trust
        assert update.knowledge_facts == facts
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_three_entities_returns_six_updates():
    """Three present entities yield all 6 directed pairs and no self-pairs."""
    entity_ids = ("you", "bot_a", "bot_b")
    # Every ordered (source, target) combination of distinct entities.
    expected = {
        (source, target)
        for source in entity_ids
        for target in entity_ids
        if source != target
    }
    llm = MockLLMClient(canned=[_canned_update(n, 0) for n in range(6)])

    results = await compute_state_updates_for_present(
        llm,
        classifier_model="x",
        present_ids=list(entity_ids),
        present_names={"you": "Me", "bot_a": "BotA", "bot_b": "BotB"},
        personas={"you": "", "bot_a": "thoughtful", "bot_b": "cheerful"},
        prior_edges={},  # all default to 50/50/""
        recent_dialogue=[{"speaker": "you", "text": "hello all"}],
    )

    assert len(results) == 6

    seen = [(src, tgt) for src, tgt, _ in results]
    # No entity is ever paired with itself.
    assert not any(src == tgt for src, tgt in seen)
    # Exactly the 6 directed combinations appear.
    assert set(seen) == expected
    # Each result carries a StateUpdate.
    assert all(isinstance(update, StateUpdate) for _, _, update in results)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_failure_in_one_pair_does_not_kill_batch():
    """First pair fails all 3 classify retries -> default; second parses OK."""
    # Pair 1 (you -> bot_a) receives three malformed replies in a row.
    malformed = ["bad", "still bad", "nope"]
    # Pair 2 (bot_a -> you) receives valid JSON.
    valid = _canned_update(3, 2, ["was warm"])
    llm = MockLLMClient(canned=[*malformed, valid])

    results = await compute_state_updates_for_present(
        llm,
        classifier_model="x",
        present_ids=["you", "bot_a"],
        present_names={"you": "Me", "bot_a": "BotA"},
        personas={"you": "", "bot_a": "thoughtful"},
        prior_edges={
            ("you", "bot_a"): {"affinity": 60, "trust": 40, "summary": "some prior"},
            ("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
        },
        recent_dialogue=[{"speaker": "you", "text": "hi"}],
    )

    assert len(results) == 2
    (failed_src, failed_tgt, fallback), (ok_src, ok_tgt, parsed) = results

    # The failed pair falls back to a zero-delta StateUpdate.
    assert (failed_src, failed_tgt) == ("you", "bot_a")
    assert (fallback.affinity_delta, fallback.trust_delta) == (0, 0)
    assert fallback.knowledge_facts == []

    # The following pair is still processed and parsed from the valid JSON.
    assert (ok_src, ok_tgt) == ("bot_a", "you")
    assert (parsed.affinity_delta, parsed.trust_delta) == (3, 2)
    assert parsed.knowledge_facts == ["was warm"]
|