feat: multi-entity state-update coordinator
This commit is contained in:
@@ -0,0 +1,62 @@
|
|||||||
|
"""Multi-entity state-update coordinator (T40).
|
||||||
|
|
||||||
|
Wraps single-pair compute_state_update to run state updates for ALL
|
||||||
|
directed pairs of present entities. With 3 present entities (you, host,
|
||||||
|
guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs.
|
||||||
|
|
||||||
|
Calls run sequentially to respect Featherless's 2-connection cap (the
|
||||||
|
client-level semaphore would serialize them anyway, but doing it here
|
||||||
|
keeps the failure surface clean — a hung pair doesn't queue behind
|
||||||
|
itself).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from chat.llm.client import LLMClient
|
||||||
|
from chat.services.state_update import StateUpdate, compute_state_update
|
||||||
|
|
||||||
|
|
||||||
|
async def compute_state_updates_for_present(
|
||||||
|
client: LLMClient,
|
||||||
|
*,
|
||||||
|
classifier_model: str,
|
||||||
|
present_ids: list[str],
|
||||||
|
present_names: dict[str, str],
|
||||||
|
personas: dict[str, str],
|
||||||
|
prior_edges: dict[tuple[str, str], dict],
|
||||||
|
recent_dialogue: list[dict],
|
||||||
|
timeout_s: float = 30.0,
|
||||||
|
) -> list[tuple[str, str, StateUpdate]]:
|
||||||
|
"""Run compute_state_update for every directed pair (src != tgt) over
|
||||||
|
``present_ids``. Returns list of ``(source_id, target_id, update)``
|
||||||
|
tuples in the natural iteration order over ``present_ids x present_ids``.
|
||||||
|
|
||||||
|
A single failing pair falls back to the schema-default StateUpdate
|
||||||
|
(zero deltas, empty facts) inside ``compute_state_update``; the batch
|
||||||
|
keeps going.
|
||||||
|
"""
|
||||||
|
out: list[tuple[str, str, StateUpdate]] = []
|
||||||
|
for src in present_ids:
|
||||||
|
for tgt in present_ids:
|
||||||
|
if src == tgt:
|
||||||
|
continue
|
||||||
|
edge = prior_edges.get((src, tgt), {})
|
||||||
|
update = await compute_state_update(
|
||||||
|
client,
|
||||||
|
model=classifier_model,
|
||||||
|
source_id=src,
|
||||||
|
target_id=tgt,
|
||||||
|
source_name=present_names.get(src, src),
|
||||||
|
source_persona=personas.get(src, "") or "",
|
||||||
|
target_name=present_names.get(tgt, tgt),
|
||||||
|
prior_affinity=int(edge.get("affinity", 50)),
|
||||||
|
prior_trust=int(edge.get("trust", 50)),
|
||||||
|
prior_summary=edge.get("summary", "") or "",
|
||||||
|
recent_dialogue=recent_dialogue,
|
||||||
|
timeout_s=timeout_s,
|
||||||
|
)
|
||||||
|
out.append((src, tgt, update))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["compute_state_updates_for_present"]
|
||||||
@@ -0,0 +1,147 @@
|
|||||||
|
"""Multi-entity state-update coordinator (T40).
|
||||||
|
|
||||||
|
Wraps the single-pair :func:`compute_state_update` to run state updates
|
||||||
|
for ALL directed pairs of present entities. With 3 present entities
|
||||||
|
(you, host, guest) that's 6 directed pairs; with 2 (you, host) it's 2.
|
||||||
|
|
||||||
|
Calls run sequentially to respect Featherless's 2-connection cap.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from chat.llm.mock import MockLLMClient
|
||||||
|
from chat.services.multi_state_update import compute_state_updates_for_present
|
||||||
|
from chat.services.state_update import StateUpdate
|
||||||
|
|
||||||
|
|
||||||
|
def _canned_update(affinity: int, trust: int, facts: list[str] | None = None) -> str:
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"affinity_delta": affinity,
|
||||||
|
"trust_delta": trust,
|
||||||
|
"knowledge_facts": facts or [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_two_entities_returns_two_updates():
|
||||||
|
"""you + bot_a -> 2 directed pairs (you->bot_a, bot_a->you)."""
|
||||||
|
canned = [
|
||||||
|
_canned_update(2, 1, ["likes coffee"]), # you -> bot_a
|
||||||
|
_canned_update(1, 0, ["greets warmly"]), # bot_a -> you
|
||||||
|
]
|
||||||
|
mock = MockLLMClient(canned=canned)
|
||||||
|
|
||||||
|
results = await compute_state_updates_for_present(
|
||||||
|
mock,
|
||||||
|
classifier_model="x",
|
||||||
|
present_ids=["you", "bot_a"],
|
||||||
|
present_names={"you": "Me", "bot_a": "BotA"},
|
||||||
|
personas={"you": "", "bot_a": "thoughtful"},
|
||||||
|
prior_edges={
|
||||||
|
("you", "bot_a"): {"affinity": 50, "trust": 50, "summary": ""},
|
||||||
|
("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
|
||||||
|
},
|
||||||
|
recent_dialogue=[
|
||||||
|
{"speaker": "you", "text": "hi"},
|
||||||
|
{"speaker": "BotA", "text": "Hello!"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0][0] == "you"
|
||||||
|
assert results[0][1] == "bot_a"
|
||||||
|
assert isinstance(results[0][2], StateUpdate)
|
||||||
|
assert results[0][2].affinity_delta == 2
|
||||||
|
assert results[0][2].trust_delta == 1
|
||||||
|
assert results[0][2].knowledge_facts == ["likes coffee"]
|
||||||
|
|
||||||
|
assert results[1][0] == "bot_a"
|
||||||
|
assert results[1][1] == "you"
|
||||||
|
assert isinstance(results[1][2], StateUpdate)
|
||||||
|
assert results[1][2].affinity_delta == 1
|
||||||
|
assert results[1][2].trust_delta == 0
|
||||||
|
assert results[1][2].knowledge_facts == ["greets warmly"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_three_entities_returns_six_updates():
|
||||||
|
"""you + bot_a + bot_b -> 6 directed pairs (no self-pairs)."""
|
||||||
|
canned = [_canned_update(i, 0) for i in range(6)]
|
||||||
|
mock = MockLLMClient(canned=canned)
|
||||||
|
|
||||||
|
results = await compute_state_updates_for_present(
|
||||||
|
mock,
|
||||||
|
classifier_model="x",
|
||||||
|
present_ids=["you", "bot_a", "bot_b"],
|
||||||
|
present_names={"you": "Me", "bot_a": "BotA", "bot_b": "BotB"},
|
||||||
|
personas={"you": "", "bot_a": "thoughtful", "bot_b": "cheerful"},
|
||||||
|
prior_edges={}, # all default to 50/50/""
|
||||||
|
recent_dialogue=[{"speaker": "you", "text": "hello all"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 6
|
||||||
|
|
||||||
|
pairs = [(src, tgt) for src, tgt, _ in results]
|
||||||
|
# No self-pairs.
|
||||||
|
assert all(src != tgt for src, tgt in pairs)
|
||||||
|
# All 6 directed combinations present.
|
||||||
|
expected = {
|
||||||
|
("you", "bot_a"),
|
||||||
|
("you", "bot_b"),
|
||||||
|
("bot_a", "you"),
|
||||||
|
("bot_a", "bot_b"),
|
||||||
|
("bot_b", "you"),
|
||||||
|
("bot_b", "bot_a"),
|
||||||
|
}
|
||||||
|
assert set(pairs) == expected
|
||||||
|
# Every entry is a StateUpdate.
|
||||||
|
assert all(isinstance(u, StateUpdate) for _, _, u in results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failure_in_one_pair_does_not_kill_batch():
|
||||||
|
"""First pair fails all 3 classify retries -> default; second parses OK."""
|
||||||
|
canned = [
|
||||||
|
# Pair 1 (you -> bot_a): 3 malformed responses -> default StateUpdate.
|
||||||
|
"bad",
|
||||||
|
"still bad",
|
||||||
|
"nope",
|
||||||
|
# Pair 2 (bot_a -> you): valid JSON.
|
||||||
|
_canned_update(3, 2, ["was warm"]),
|
||||||
|
]
|
||||||
|
mock = MockLLMClient(canned=canned)
|
||||||
|
|
||||||
|
results = await compute_state_updates_for_present(
|
||||||
|
mock,
|
||||||
|
classifier_model="x",
|
||||||
|
present_ids=["you", "bot_a"],
|
||||||
|
present_names={"you": "Me", "bot_a": "BotA"},
|
||||||
|
personas={"you": "", "bot_a": "thoughtful"},
|
||||||
|
prior_edges={
|
||||||
|
("you", "bot_a"): {"affinity": 60, "trust": 40, "summary": "some prior"},
|
||||||
|
("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
|
||||||
|
},
|
||||||
|
recent_dialogue=[{"speaker": "you", "text": "hi"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
# First pair: default (zero-delta) StateUpdate.
|
||||||
|
src1, tgt1, update1 = results[0]
|
||||||
|
assert (src1, tgt1) == ("you", "bot_a")
|
||||||
|
assert update1.affinity_delta == 0
|
||||||
|
assert update1.trust_delta == 0
|
||||||
|
assert update1.knowledge_facts == []
|
||||||
|
|
||||||
|
# Second pair: parsed valid JSON.
|
||||||
|
src2, tgt2, update2 = results[1]
|
||||||
|
assert (src2, tgt2) == ("bot_a", "you")
|
||||||
|
assert update2.affinity_delta == 3
|
||||||
|
assert update2.trust_delta == 2
|
||||||
|
assert update2.knowledge_facts == ["was warm"]
|
||||||
Reference in New Issue
Block a user