mirror of https://github.com/microsoft/autogen.git
patch to graph groupchat (#1555)
* patch to #1541 * graph validity test * update docstr
This commit is contained in:
parent
6f8ccf80e7
commit
18a15d78fd
|
@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Union, Tuple
|
|||
from ..code_utils import content_str
|
||||
from .agent import Agent
|
||||
from .conversable_agent import ConversableAgent
|
||||
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed, has_self_loops
|
||||
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -42,12 +42,22 @@ class GroupChat:
|
|||
- "random": the next speaker is selected randomly.
|
||||
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
|
||||
|
||||
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If allow_repeat_speaker is a list of Agents, then only those listed agents are allowed to repeat. If set to False, then no speakers are allowed to repeat. allow_repeat_speaker and allowed_or_disallowed_speaker_transitions are mutually exclusive.
|
||||
- allowed_or_disallowed_speaker_transitions: a dictionary of keys and list as values. The keys are the source agents, and the values are the agents that the key agent can transition to. Default is None, in which case a fully connected allowed_speaker_transitions_dict is assumed. allow_repeat_speaker and allowed_or_disallowed_speaker_transitions are mutually exclusive.
|
||||
- speaker_transitions_type: whether the speaker_transitions_type is a dictionary containing lists of allowed agents or disallowed agents. allowed means the allowed_or_disallowed_speaker_transitions is a dictionary containing lists of allowed agents. If set to disallowed, then the allowed_or_disallowed_speaker_transitions is a dictionary containing lists of disallowed agents. Must be supplied if allowed_or_disallowed_speaker_transitions is not None.
|
||||
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
|
||||
Default is True, in which case all speakers are allowed to speak consecutively.
|
||||
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
|
||||
If set to False, then no speakers are allowed to repeat.
|
||||
`allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive.
|
||||
- allowed_or_disallowed_speaker_transitions: dict.
|
||||
The keys are source agents, and the values are agents that the key agent can/can't transit to,
|
||||
depending on speaker_transitions_type. Default is None, which means all agents can transit to all other agents.
|
||||
`allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive.
|
||||
- speaker_transitions_type: whether the speaker_transitions_type is a dictionary containing lists of allowed agents or disallowed agents.
|
||||
"allowed" means the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of allowed agents.
|
||||
If set to "disallowed", then the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of disallowed agents.
|
||||
Must be supplied if `allowed_or_disallowed_speaker_transitions` is not None.
|
||||
- enable_clear_history: enable possibility to clear history of messages for agents manually by providing
|
||||
"clear history" phrase in user prompt. This is experimental feature.
|
||||
See description of GroupChatManager.clear_agents_history function for more info.
|
||||
See description of `GroupChatManager.clear_agents_history` function for more info.
|
||||
"""
|
||||
|
||||
agents: List[Agent]
|
||||
|
@ -56,9 +66,7 @@ class GroupChat:
|
|||
admin_name: Optional[str] = "Admin"
|
||||
func_call_filter: Optional[bool] = True
|
||||
speaker_selection_method: Optional[str] = "auto"
|
||||
allow_repeat_speaker: Optional[
|
||||
Union[bool, List[Agent]]
|
||||
] = True # It would be set to True if allowed_or_disallowed_speaker_transitions is None
|
||||
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
|
||||
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
|
||||
speaker_transitions_type: Optional[str] = None
|
||||
enable_clear_history: Optional[bool] = False
|
||||
|
@ -70,16 +78,21 @@ class GroupChat:
|
|||
|
||||
def __post_init__(self):
|
||||
# Post init steers clears of the automatically generated __init__ method from dataclass
|
||||
# Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and is_allowed_graph, and lastly checks for validity.
|
||||
|
||||
if self.allow_repeat_speaker is not None and not isinstance(self.allow_repeat_speaker, (bool, list)):
|
||||
raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.")
|
||||
|
||||
# Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and speaker_transitions_type, and lastly checks for validity.
|
||||
|
||||
# Check input
|
||||
if self.speaker_transitions_type is not None:
|
||||
self.speaker_transitions_type = self.speaker_transitions_type.lower()
|
||||
|
||||
assert self.speaker_transitions_type in self._VALID_SPEAKER_TRANSITIONS_TYPE, (
|
||||
f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. "
|
||||
f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). "
|
||||
)
|
||||
if self.speaker_transitions_type not in self._VALID_SPEAKER_TRANSITIONS_TYPE:
|
||||
raise ValueError(
|
||||
f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. "
|
||||
f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). "
|
||||
)
|
||||
|
||||
# If both self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None, set allow_repeat_speaker to True to ensure backward compatibility
|
||||
# Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451541204
|
||||
|
@ -114,7 +127,7 @@ class GroupChat:
|
|||
]
|
||||
|
||||
# If self.allow_repeat_speaker is True, add self loops to all agents
|
||||
if self.allow_repeat_speaker:
|
||||
if self.allow_repeat_speaker is True:
|
||||
for agent in self.agents:
|
||||
self.allowed_speaker_transitions_dict[agent].append(agent)
|
||||
|
||||
|
@ -125,7 +138,7 @@ class GroupChat:
|
|||
|
||||
# Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is not None, using allowed_or_disallowed_speaker_transitions
|
||||
else:
|
||||
# Process based on is_allowed_graph
|
||||
# Process based on speaker_transitions_type
|
||||
if self.speaker_transitions_type == "allowed":
|
||||
self.allowed_speaker_transitions_dict = self.allowed_or_disallowed_speaker_transitions
|
||||
else:
|
||||
|
@ -134,16 +147,10 @@ class GroupChat:
|
|||
self.allowed_or_disallowed_speaker_transitions, self.agents
|
||||
)
|
||||
|
||||
# Inferring self.allow_repeat_speaker from allowed_speaker_transitions_dict using has_self_loops
|
||||
# Finally, self.allow_repeat_speaker shouldn't be None, so it is set from the the graph.
|
||||
if self.allow_repeat_speaker is None:
|
||||
self.allow_repeat_speaker = has_self_loops(self.allowed_speaker_transitions_dict)
|
||||
|
||||
# Check for validity
|
||||
check_graph_validity(
|
||||
allowed_speaker_transitions_dict=self.allowed_speaker_transitions_dict,
|
||||
agents=self.agents,
|
||||
allow_repeat_speaker=self.allow_repeat_speaker,
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -248,12 +255,10 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
|||
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
|
||||
)
|
||||
|
||||
if not isinstance(self.allow_repeat_speaker, (bool, list)):
|
||||
raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.")
|
||||
# If provided a list, make sure the agent is in the list
|
||||
allow_repeat_speaker = (
|
||||
self.allow_repeat_speaker
|
||||
if isinstance(self.allow_repeat_speaker, bool)
|
||||
if isinstance(self.allow_repeat_speaker, bool) or self.allow_repeat_speaker is None
|
||||
else last_speaker in self.allow_repeat_speaker
|
||||
)
|
||||
|
||||
|
@ -268,8 +273,8 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
|||
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
|
||||
logger.warning(
|
||||
f"GroupChat is underpopulated with {n_agents} agents. "
|
||||
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
|
||||
"Or, use direct communication instead."
|
||||
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
|
||||
"or use direct communication, unless repeated speaker is desired."
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -301,7 +306,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
|||
"Please check the function_map of the agents."
|
||||
)
|
||||
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
|
||||
agents = agents if allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
|
||||
agents = [agent for agent in agents if agent != last_speaker] if allow_repeat_speaker is False else agents
|
||||
|
||||
# Filter agents with allowed_speaker_transitions_dict
|
||||
|
||||
|
|
|
@ -4,22 +4,20 @@ import logging
|
|||
from autogen.agentchat.groupchat import Agent
|
||||
|
||||
|
||||
def has_self_loops(allowed_speaker_transitions: dict) -> bool:
|
||||
def has_self_loops(allowed_speaker_transitions: Dict) -> bool:
|
||||
"""
|
||||
Returns True if there are self loops in the allowed_speaker_transitions_dict.
|
||||
Returns True if there are self loops in the allowed_speaker_transitions_Dict.
|
||||
"""
|
||||
return any([key in value for key, value in allowed_speaker_transitions.items()])
|
||||
|
||||
|
||||
def check_graph_validity(
|
||||
allowed_speaker_transitions_dict: dict,
|
||||
allowed_speaker_transitions_dict: Dict,
|
||||
agents: List[Agent],
|
||||
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = True,
|
||||
):
|
||||
"""
|
||||
allowed_speaker_transitions_dict: A dictionary of keys and list as values. The keys are the names of the agents, and the values are the names of the agents that the key agent can transition to.
|
||||
agents: A list of Agents
|
||||
allow_repeat_speaker: A boolean indicating whether the same agent can speak twice in a row.
|
||||
|
||||
Checks for the following:
|
||||
Errors
|
||||
|
@ -40,12 +38,12 @@ def check_graph_validity(
|
|||
raise ValueError("allowed_speaker_transitions_dict must be a dictionary.")
|
||||
|
||||
# All values must be lists of Agent or empty
|
||||
if not all([isinstance(value, list) or value == [] for value in allowed_speaker_transitions_dict.values()]):
|
||||
raise ValueError("allowed_speaker_transitions_dict must be a dictionary of keys and list as values.")
|
||||
if not all([isinstance(value, list) for value in allowed_speaker_transitions_dict.values()]):
|
||||
raise ValueError("allowed_speaker_transitions_dict must be a dictionary with lists as values.")
|
||||
|
||||
# Check 2. Every key exists in agents
|
||||
if not all([key in agents for key in allowed_speaker_transitions_dict.keys()]):
|
||||
raise ValueError("allowed_speaker_transitions_dict has keys not in agents' names.")
|
||||
raise ValueError("allowed_speaker_transitions_dict has keys not in agents.")
|
||||
|
||||
# Check 3. Every value is a list of Agents or empty list (not string).
|
||||
if not all(
|
||||
|
|
|
@ -621,7 +621,6 @@
|
|||
" agents=agents,\n",
|
||||
" messages=[],\n",
|
||||
" max_round=20,\n",
|
||||
" allow_repeat_speaker=None,\n",
|
||||
" allowed_or_disallowed_speaker_transitions=speaker_transitions_dict,\n",
|
||||
" speaker_transitions_type=\"allowed\",\n",
|
||||
")\n",
|
||||
|
|
|
@ -154,6 +154,7 @@ class GroupChatConfig:
|
|||
max_round: Optional[int] = 10
|
||||
admin_name: Optional[str] = "Admin"
|
||||
speaker_selection_method: Optional[str] = "auto"
|
||||
# TODO: match the new group chat default and support transition spec
|
||||
allow_repeat_speaker: Optional[Union[bool, List[AgentConfig]]] = True
|
||||
|
||||
def dict(self):
|
||||
|
|
|
@ -226,16 +226,14 @@ def test_invalid_allow_repeat_speaker():
|
|||
default_auto_reply="This is bob speaking.",
|
||||
)
|
||||
# test invalid allow_repeat_speaker
|
||||
groupchat = autogen.GroupChat(
|
||||
agents=[agent1, agent2],
|
||||
messages=[],
|
||||
max_round=6,
|
||||
speaker_selection_method="round_robin",
|
||||
allow_repeat_speaker={},
|
||||
)
|
||||
with pytest.raises(ValueError) as e:
|
||||
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
|
||||
autogen.GroupChat(
|
||||
agents=[agent1, agent2],
|
||||
messages=[],
|
||||
max_round=6,
|
||||
speaker_selection_method="round_robin",
|
||||
allow_repeat_speaker={},
|
||||
)
|
||||
assert str(e.value) == "GroupChat allow_repeat_speaker should be a bool or a list of Agents.", e.value
|
||||
|
||||
|
||||
|
@ -506,11 +504,18 @@ def test_graph_parameters():
|
|||
agents=agents,
|
||||
messages=[],
|
||||
max_round=3,
|
||||
allow_repeat_speaker=None,
|
||||
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
GroupChat(
|
||||
agents=agents,
|
||||
messages=[],
|
||||
max_round=3,
|
||||
allow_repeat_speaker=False, # should be None
|
||||
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(ValueError):
|
||||
GroupChat(
|
||||
agents=agents,
|
||||
messages=[],
|
||||
|
@ -524,25 +529,12 @@ def test_graph_parameters():
|
|||
agents=agents,
|
||||
messages=[],
|
||||
max_round=3,
|
||||
allow_repeat_speaker=None,
|
||||
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
|
||||
speaker_transitions_type="allowed",
|
||||
)
|
||||
assert "Agent0" in group_chat.agent_names
|
||||
|
||||
|
||||
def test_graph_validity_check():
|
||||
agents = [Agent(name=f"Agent{i}") for i in range(3)]
|
||||
invalid_transitions = {agents[0]: []}
|
||||
with pytest.raises(ValueError):
|
||||
GroupChat(
|
||||
agents=agents,
|
||||
messages=[],
|
||||
allowed_or_disallowed_speaker_transitions=invalid_transitions,
|
||||
speaker_transitions_type="allowed",
|
||||
)
|
||||
|
||||
|
||||
def test_graceful_exit_before_max_round():
|
||||
agent1 = autogen.ConversableAgent(
|
||||
"alice",
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import sys
|
||||
import pytest
|
||||
import logging
|
||||
from autogen.agentchat import Agent
|
||||
|
@ -46,20 +45,10 @@ class TestGraphUtilCheckGraphValidity:
|
|||
with pytest.raises(ValueError):
|
||||
gru.check_graph_validity(invalid_speaker_transitions_dict, agents)
|
||||
|
||||
def test_graph_with_unauthorized_self_loops(self):
|
||||
def test_graph_with_invalid_key(self):
|
||||
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
|
||||
# Creating a subset of agents allowed to have self-loops
|
||||
allowed_repeat_speakers = agents[: len(agents) // 2]
|
||||
|
||||
# Constructing a speaker transitions dictionary with self-loops for all agents
|
||||
# Ensuring at least one agent outside the allowed_repeat_speakers has a self-loop
|
||||
speaker_transitions_dict_with_self_loop = {agent: agent for agent in agents}
|
||||
|
||||
# Testing the function with the constructed speaker transitions dict
|
||||
with pytest.raises(ValueError):
|
||||
gru.check_graph_validity(
|
||||
speaker_transitions_dict_with_self_loop, agents, allow_repeat_speaker=allowed_repeat_speakers
|
||||
)
|
||||
gru.check_graph_validity({1: 1}, agents)
|
||||
|
||||
# Test for Warning 1: Isolated agent nodes
|
||||
def test_isolated_agent_nodes_warning(self, caplog):
|
||||
|
|
Loading…
Reference in New Issue