feat: multi-entity state-update coordinator
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
"""Multi-entity state-update coordinator (T40).
|
||||
|
||||
Wraps single-pair compute_state_update to run state updates for ALL
|
||||
directed pairs of present entities. With 3 present entities (you, host,
|
||||
guest) that's 6 directed pairs. With 2 present (you, host) it's 2 pairs.
|
||||
|
||||
Calls run sequentially to respect Featherless's 2-connection cap (the
|
||||
client-level semaphore would serialize them anyway, but doing it here
|
||||
keeps the failure surface clean — a hung pair doesn't queue behind
|
||||
itself).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from chat.llm.client import LLMClient
|
||||
from chat.services.state_update import StateUpdate, compute_state_update
|
||||
|
||||
|
||||
async def compute_state_updates_for_present(
|
||||
client: LLMClient,
|
||||
*,
|
||||
classifier_model: str,
|
||||
present_ids: list[str],
|
||||
present_names: dict[str, str],
|
||||
personas: dict[str, str],
|
||||
prior_edges: dict[tuple[str, str], dict],
|
||||
recent_dialogue: list[dict],
|
||||
timeout_s: float = 30.0,
|
||||
) -> list[tuple[str, str, StateUpdate]]:
|
||||
"""Run compute_state_update for every directed pair (src != tgt) over
|
||||
``present_ids``. Returns list of ``(source_id, target_id, update)``
|
||||
tuples in the natural iteration order over ``present_ids x present_ids``.
|
||||
|
||||
A single failing pair falls back to the schema-default StateUpdate
|
||||
(zero deltas, empty facts) inside ``compute_state_update``; the batch
|
||||
keeps going.
|
||||
"""
|
||||
out: list[tuple[str, str, StateUpdate]] = []
|
||||
for src in present_ids:
|
||||
for tgt in present_ids:
|
||||
if src == tgt:
|
||||
continue
|
||||
edge = prior_edges.get((src, tgt), {})
|
||||
update = await compute_state_update(
|
||||
client,
|
||||
model=classifier_model,
|
||||
source_id=src,
|
||||
target_id=tgt,
|
||||
source_name=present_names.get(src, src),
|
||||
source_persona=personas.get(src, "") or "",
|
||||
target_name=present_names.get(tgt, tgt),
|
||||
prior_affinity=int(edge.get("affinity", 50)),
|
||||
prior_trust=int(edge.get("trust", 50)),
|
||||
prior_summary=edge.get("summary", "") or "",
|
||||
recent_dialogue=recent_dialogue,
|
||||
timeout_s=timeout_s,
|
||||
)
|
||||
out.append((src, tgt, update))
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["compute_state_updates_for_present"]
|
||||
@@ -0,0 +1,147 @@
|
||||
"""Multi-entity state-update coordinator (T40).
|
||||
|
||||
Wraps the single-pair :func:`compute_state_update` to run state updates
|
||||
for ALL directed pairs of present entities. With 3 present entities
|
||||
(you, host, guest) that's 6 directed pairs; with 2 (you, host) it's 2.
|
||||
|
||||
Calls run sequentially to respect Featherless's 2-connection cap.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from chat.llm.mock import MockLLMClient
|
||||
from chat.services.multi_state_update import compute_state_updates_for_present
|
||||
from chat.services.state_update import StateUpdate
|
||||
|
||||
|
||||
def _canned_update(affinity: int, trust: int, facts: list[str] | None = None) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"affinity_delta": affinity,
|
||||
"trust_delta": trust,
|
||||
"knowledge_facts": facts or [],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_entities_returns_two_updates():
|
||||
"""you + bot_a -> 2 directed pairs (you->bot_a, bot_a->you)."""
|
||||
canned = [
|
||||
_canned_update(2, 1, ["likes coffee"]), # you -> bot_a
|
||||
_canned_update(1, 0, ["greets warmly"]), # bot_a -> you
|
||||
]
|
||||
mock = MockLLMClient(canned=canned)
|
||||
|
||||
results = await compute_state_updates_for_present(
|
||||
mock,
|
||||
classifier_model="x",
|
||||
present_ids=["you", "bot_a"],
|
||||
present_names={"you": "Me", "bot_a": "BotA"},
|
||||
personas={"you": "", "bot_a": "thoughtful"},
|
||||
prior_edges={
|
||||
("you", "bot_a"): {"affinity": 50, "trust": 50, "summary": ""},
|
||||
("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
|
||||
},
|
||||
recent_dialogue=[
|
||||
{"speaker": "you", "text": "hi"},
|
||||
{"speaker": "BotA", "text": "Hello!"},
|
||||
],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0][0] == "you"
|
||||
assert results[0][1] == "bot_a"
|
||||
assert isinstance(results[0][2], StateUpdate)
|
||||
assert results[0][2].affinity_delta == 2
|
||||
assert results[0][2].trust_delta == 1
|
||||
assert results[0][2].knowledge_facts == ["likes coffee"]
|
||||
|
||||
assert results[1][0] == "bot_a"
|
||||
assert results[1][1] == "you"
|
||||
assert isinstance(results[1][2], StateUpdate)
|
||||
assert results[1][2].affinity_delta == 1
|
||||
assert results[1][2].trust_delta == 0
|
||||
assert results[1][2].knowledge_facts == ["greets warmly"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_three_entities_returns_six_updates():
|
||||
"""you + bot_a + bot_b -> 6 directed pairs (no self-pairs)."""
|
||||
canned = [_canned_update(i, 0) for i in range(6)]
|
||||
mock = MockLLMClient(canned=canned)
|
||||
|
||||
results = await compute_state_updates_for_present(
|
||||
mock,
|
||||
classifier_model="x",
|
||||
present_ids=["you", "bot_a", "bot_b"],
|
||||
present_names={"you": "Me", "bot_a": "BotA", "bot_b": "BotB"},
|
||||
personas={"you": "", "bot_a": "thoughtful", "bot_b": "cheerful"},
|
||||
prior_edges={}, # all default to 50/50/""
|
||||
recent_dialogue=[{"speaker": "you", "text": "hello all"}],
|
||||
)
|
||||
|
||||
assert len(results) == 6
|
||||
|
||||
pairs = [(src, tgt) for src, tgt, _ in results]
|
||||
# No self-pairs.
|
||||
assert all(src != tgt for src, tgt in pairs)
|
||||
# All 6 directed combinations present.
|
||||
expected = {
|
||||
("you", "bot_a"),
|
||||
("you", "bot_b"),
|
||||
("bot_a", "you"),
|
||||
("bot_a", "bot_b"),
|
||||
("bot_b", "you"),
|
||||
("bot_b", "bot_a"),
|
||||
}
|
||||
assert set(pairs) == expected
|
||||
# Every entry is a StateUpdate.
|
||||
assert all(isinstance(u, StateUpdate) for _, _, u in results)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_in_one_pair_does_not_kill_batch():
|
||||
"""First pair fails all 3 classify retries -> default; second parses OK."""
|
||||
canned = [
|
||||
# Pair 1 (you -> bot_a): 3 malformed responses -> default StateUpdate.
|
||||
"bad",
|
||||
"still bad",
|
||||
"nope",
|
||||
# Pair 2 (bot_a -> you): valid JSON.
|
||||
_canned_update(3, 2, ["was warm"]),
|
||||
]
|
||||
mock = MockLLMClient(canned=canned)
|
||||
|
||||
results = await compute_state_updates_for_present(
|
||||
mock,
|
||||
classifier_model="x",
|
||||
present_ids=["you", "bot_a"],
|
||||
present_names={"you": "Me", "bot_a": "BotA"},
|
||||
personas={"you": "", "bot_a": "thoughtful"},
|
||||
prior_edges={
|
||||
("you", "bot_a"): {"affinity": 60, "trust": 40, "summary": "some prior"},
|
||||
("bot_a", "you"): {"affinity": 50, "trust": 50, "summary": ""},
|
||||
},
|
||||
recent_dialogue=[{"speaker": "you", "text": "hi"}],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
# First pair: default (zero-delta) StateUpdate.
|
||||
src1, tgt1, update1 = results[0]
|
||||
assert (src1, tgt1) == ("you", "bot_a")
|
||||
assert update1.affinity_delta == 0
|
||||
assert update1.trust_delta == 0
|
||||
assert update1.knowledge_facts == []
|
||||
|
||||
# Second pair: parsed valid JSON.
|
||||
src2, tgt2, update2 = results[1]
|
||||
assert (src2, tgt2) == ("bot_a", "you")
|
||||
assert update2.affinity_delta == 3
|
||||
assert update2.trust_delta == 2
|
||||
assert update2.knowledge_facts == ["was warm"]
|
||||
Reference in New Issue
Block a user