patch to graph groupchat (#1555)

* patch to #1541

* graph validity test

* update docstr
This commit is contained in:
Chi Wang 2024-02-05 23:18:13 -08:00 committed by GitHub
parent 6f8ccf80e7
commit 18a15d78fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 73 deletions

View File

@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Union, Tuple
from ..code_utils import content_str
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed, has_self_loops
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
logger = logging.getLogger(__name__)
@ -42,12 +42,22 @@ class GroupChat:
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If allow_repeat_speaker is a list of Agents, then only those listed agents are allowed to repeat. If set to False, then no speakers are allowed to repeat. allow_repeat_speaker and allowed_or_disallowed_speaker_transitions are mutually exclusive.
- allowed_or_disallowed_speaker_transitions: a dictionary of keys and list as values. The keys are the source agents, and the values are the agents that the key agent can transition to. Default is None, in which case a fully connected allowed_speaker_transitions_dict is assumed. allow_repeat_speaker and allowed_or_disallowed_speaker_transitions are mutually exclusive.
- speaker_transitions_type: whether the speaker_transitions_type is a dictionary containing lists of allowed agents or disallowed agents. allowed means the allowed_or_disallowed_speaker_transitions is a dictionary containing lists of allowed agents. If set to disallowed, then the allowed_or_disallowed_speaker_transitions is a dictionary containing lists of disallowed agents. Must be supplied if allowed_or_disallowed_speaker_transitions is not None.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
Default is True, in which case all speakers are allowed to speak consecutively.
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
If set to False, then no speakers are allowed to repeat.
`allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive.
- allowed_or_disallowed_speaker_transitions: dict.
The keys are source agents, and the values are agents that the key agent can/can't transit to,
depending on speaker_transitions_type. Default is None, which means all agents can transit to all other agents.
`allow_repeat_speaker` and `allowed_or_disallowed_speaker_transitions` are mutually exclusive.
- speaker_transitions_type: whether the speaker_transitions_type is a dictionary containing lists of allowed agents or disallowed agents.
"allowed" means the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of allowed agents.
If set to "disallowed", then the `allowed_or_disallowed_speaker_transitions` is a dictionary containing lists of disallowed agents.
Must be supplied if `allowed_or_disallowed_speaker_transitions` is not None.
- enable_clear_history: enable possibility to clear history of messages for agents manually by providing
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
See description of `GroupChatManager.clear_agents_history` function for more info.
"""
agents: List[Agent]
@ -56,9 +66,7 @@ class GroupChat:
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Optional[str] = "auto"
allow_repeat_speaker: Optional[
Union[bool, List[Agent]]
] = True # It would be set to True if allowed_or_disallowed_speaker_transitions is None
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Optional[str] = None
enable_clear_history: Optional[bool] = False
@ -70,16 +78,21 @@ class GroupChat:
def __post_init__(self):
# Post init steers clears of the automatically generated __init__ method from dataclass
# Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and is_allowed_graph, and lastly checks for validity.
if self.allow_repeat_speaker is not None and not isinstance(self.allow_repeat_speaker, (bool, list)):
raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.")
# Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and speaker_transitions_type, and lastly checks for validity.
# Check input
if self.speaker_transitions_type is not None:
self.speaker_transitions_type = self.speaker_transitions_type.lower()
assert self.speaker_transitions_type in self._VALID_SPEAKER_TRANSITIONS_TYPE, (
f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. "
f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). "
)
if self.speaker_transitions_type not in self._VALID_SPEAKER_TRANSITIONS_TYPE:
raise ValueError(
f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. "
f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). "
)
# If both self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None, set allow_repeat_speaker to True to ensure backward compatibility
# Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451541204
@ -114,7 +127,7 @@ class GroupChat:
]
# If self.allow_repeat_speaker is True, add self loops to all agents
if self.allow_repeat_speaker:
if self.allow_repeat_speaker is True:
for agent in self.agents:
self.allowed_speaker_transitions_dict[agent].append(agent)
@ -125,7 +138,7 @@ class GroupChat:
# Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is not None, using allowed_or_disallowed_speaker_transitions
else:
# Process based on is_allowed_graph
# Process based on speaker_transitions_type
if self.speaker_transitions_type == "allowed":
self.allowed_speaker_transitions_dict = self.allowed_or_disallowed_speaker_transitions
else:
@ -134,16 +147,10 @@ class GroupChat:
self.allowed_or_disallowed_speaker_transitions, self.agents
)
# Inferring self.allow_repeat_speaker from allowed_speaker_transitions_dict using has_self_loops
# Finally, self.allow_repeat_speaker shouldn't be None, so it is set from the the graph.
if self.allow_repeat_speaker is None:
self.allow_repeat_speaker = has_self_loops(self.allowed_speaker_transitions_dict)
# Check for validity
check_graph_validity(
allowed_speaker_transitions_dict=self.allowed_speaker_transitions_dict,
agents=self.agents,
allow_repeat_speaker=self.allow_repeat_speaker,
)
@property
@ -248,12 +255,10 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)
if not isinstance(self.allow_repeat_speaker, (bool, list)):
raise ValueError("GroupChat allow_repeat_speaker should be a bool or a list of Agents.")
# If provided a list, make sure the agent is in the list
allow_repeat_speaker = (
self.allow_repeat_speaker
if isinstance(self.allow_repeat_speaker, bool)
if isinstance(self.allow_repeat_speaker, bool) or self.allow_repeat_speaker is None
else last_speaker in self.allow_repeat_speaker
)
@ -268,8 +273,8 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
"Or, use direct communication instead."
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
"or use direct communication, unless repeated speaker is desired."
)
if (
@ -301,7 +306,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
"Please check the function_map of the agents."
)
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = agents if allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
agents = [agent for agent in agents if agent != last_speaker] if allow_repeat_speaker is False else agents
# Filter agents with allowed_speaker_transitions_dict

View File

@ -4,22 +4,20 @@ import logging
from autogen.agentchat.groupchat import Agent
def has_self_loops(allowed_speaker_transitions: dict) -> bool:
def has_self_loops(allowed_speaker_transitions: Dict) -> bool:
"""
Returns True if there are self loops in the allowed_speaker_transitions_dict.
Returns True if there are self loops in the allowed_speaker_transitions_Dict.
"""
return any([key in value for key, value in allowed_speaker_transitions.items()])
def check_graph_validity(
allowed_speaker_transitions_dict: dict,
allowed_speaker_transitions_dict: Dict,
agents: List[Agent],
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = True,
):
"""
allowed_speaker_transitions_dict: A dictionary of keys and list as values. The keys are the names of the agents, and the values are the names of the agents that the key agent can transition to.
agents: A list of Agents
allow_repeat_speaker: A boolean indicating whether the same agent can speak twice in a row.
Checks for the following:
Errors
@ -40,12 +38,12 @@ def check_graph_validity(
raise ValueError("allowed_speaker_transitions_dict must be a dictionary.")
# All values must be lists of Agent or empty
if not all([isinstance(value, list) or value == [] for value in allowed_speaker_transitions_dict.values()]):
raise ValueError("allowed_speaker_transitions_dict must be a dictionary of keys and list as values.")
if not all([isinstance(value, list) for value in allowed_speaker_transitions_dict.values()]):
raise ValueError("allowed_speaker_transitions_dict must be a dictionary with lists as values.")
# Check 2. Every key exists in agents
if not all([key in agents for key in allowed_speaker_transitions_dict.keys()]):
raise ValueError("allowed_speaker_transitions_dict has keys not in agents' names.")
raise ValueError("allowed_speaker_transitions_dict has keys not in agents.")
# Check 3. Every value is a list of Agents or empty list (not string).
if not all(

View File

@ -621,7 +621,6 @@
" agents=agents,\n",
" messages=[],\n",
" max_round=20,\n",
" allow_repeat_speaker=None,\n",
" allowed_or_disallowed_speaker_transitions=speaker_transitions_dict,\n",
" speaker_transitions_type=\"allowed\",\n",
")\n",

View File

@ -154,6 +154,7 @@ class GroupChatConfig:
max_round: Optional[int] = 10
admin_name: Optional[str] = "Admin"
speaker_selection_method: Optional[str] = "auto"
# TODO: match the new group chat default and support transition spec
allow_repeat_speaker: Optional[Union[bool, List[AgentConfig]]] = True
def dict(self):

View File

@ -226,16 +226,14 @@ def test_invalid_allow_repeat_speaker():
default_auto_reply="This is bob speaking.",
)
# test invalid allow_repeat_speaker
groupchat = autogen.GroupChat(
agents=[agent1, agent2],
messages=[],
max_round=6,
speaker_selection_method="round_robin",
allow_repeat_speaker={},
)
with pytest.raises(ValueError) as e:
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
autogen.GroupChat(
agents=[agent1, agent2],
messages=[],
max_round=6,
speaker_selection_method="round_robin",
allow_repeat_speaker={},
)
assert str(e.value) == "GroupChat allow_repeat_speaker should be a bool or a list of Agents.", e.value
@ -506,11 +504,18 @@ def test_graph_parameters():
agents=agents,
messages=[],
max_round=3,
allow_repeat_speaker=None,
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
)
with pytest.raises(ValueError):
GroupChat(
agents=agents,
messages=[],
max_round=3,
allow_repeat_speaker=False, # should be None
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
)
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
GroupChat(
agents=agents,
messages=[],
@ -524,25 +529,12 @@ def test_graph_parameters():
agents=agents,
messages=[],
max_round=3,
allow_repeat_speaker=None,
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
speaker_transitions_type="allowed",
)
assert "Agent0" in group_chat.agent_names
def test_graph_validity_check():
agents = [Agent(name=f"Agent{i}") for i in range(3)]
invalid_transitions = {agents[0]: []}
with pytest.raises(ValueError):
GroupChat(
agents=agents,
messages=[],
allowed_or_disallowed_speaker_transitions=invalid_transitions,
speaker_transitions_type="allowed",
)
def test_graceful_exit_before_max_round():
agent1 = autogen.ConversableAgent(
"alice",

View File

@ -1,4 +1,3 @@
import sys
import pytest
import logging
from autogen.agentchat import Agent
@ -46,20 +45,10 @@ class TestGraphUtilCheckGraphValidity:
with pytest.raises(ValueError):
gru.check_graph_validity(invalid_speaker_transitions_dict, agents)
def test_graph_with_unauthorized_self_loops(self):
def test_graph_with_invalid_key(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
# Creating a subset of agents allowed to have self-loops
allowed_repeat_speakers = agents[: len(agents) // 2]
# Constructing a speaker transitions dictionary with self-loops for all agents
# Ensuring at least one agent outside the allowed_repeat_speakers has a self-loop
speaker_transitions_dict_with_self_loop = {agent: agent for agent in agents}
# Testing the function with the constructed speaker transitions dict
with pytest.raises(ValueError):
gru.check_graph_validity(
speaker_transitions_dict_with_self_loop, agents, allow_repeat_speaker=allowed_repeat_speakers
)
gru.check_graph_validity({1: 1}, agents)
# Test for Warning 1: Isolated agent nodes
def test_isolated_agent_nodes_warning(self, caplog):