Graph group chat (#857)

* Move contrib-openai.yml

* Moved groupgroupchat

* From #753

* Removed local test references

* Added ignore=test/agentchat/contrib

* Trying to pass contrib-openai tests

* More specific in unit testing.

* Update .github/workflows/contrib-tests.yml

Co-authored-by: Li Jiang <lijiang1@microsoft.com>

* Remove coverage as it is included in test dependencies

* Improved docstring with overview of GraphGroupChat

* Iterate on feedback

* Precommit pass

* user just use pip install pyautogen[graphs]

* Pass precommit

* Pas precommit

* Graph utils an test completed

* Added inversion tests

* Added inversion util

* allow_repeat_speaker can be a list of Agents

* Remove unnessary imports

* Expect ValueError with 1 and 0 agents

* Check that main passes all tests

* Check main

* Pytest all in main

* All done

* pre-commit changes

* noqa E402

* precommit pass

* Removed bin

* Removed old unit test

* Test test_graph_utils

* minor cleanup

* restore tests

* Correct documentation

* Special case of only one agent remaining.

* Improved pytest

* precommit pass

* Delete OAI_CONFIG_LIST_sample copy

* Returns a filtered list for auto to work

* Rename var speaker_order_dict

* To write test cases

* Added check for a list of Agents to repeat

* precommit pass

* Update documentation

* Extract names in allow_repeat_speaker

* Post review changes

* hange "pull_request_target" into "pull_request" temporarily.

* 3 return values from main

* pre-commit changes

* PC edits

* docstr changes

* PC edits

* Rest of changes from main

* Update autogen/agentchat/groupchat.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Remove unnecessary script files from tracking

* Non empty scripts files from main

* Revert changes in script files to match main branch

* Removed link from website as notebook is removed.

* test/test_graph_utils.py is tested as part of L52 of build.yml

* GroupChat ValueError check

* docstr update

* More clarification in docstr

* Update autogen/agentchat/groupchat.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update autogen/agentchat/groupchat.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update autogen/agentchat/groupchat.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update autogen/agentchat/groupchat.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* 1.add commit to line138 in groupchat.py;2.fix bug if random choice [];3.return selected_agent if len(graph_eligible_agents) is 1;4.replace all speaker_order to speaker_transitions;5.format

* fix graph_modelling notebook in the last cell

* fix failure in test_groupchat.py

* fix agent out of group to initiate a chat like SocietyOfMind

* add a warning rule in graph_utils to check duplicates in any lists

* refactor allowed_or_disallowed_speaker_transitions to Dict[Agent, List[Agent]] and modify the tests and notebook

* delete Rule 4 in graph_utils and related test case. Add a test to resolve 993fd006e9 (r1460726831)

* fix as the final comments

* modify setup option from graphs to graph and add texts in optional-dependencies.md

* Update autogen/graph_utils.py

---------

Co-authored-by: Li Jiang <lijiang1@microsoft.com>
Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
Co-authored-by: Yishen Sun <freedeaths@FREEDEATHS-XPS>
Co-authored-by: freedeaths <register917@gmail.com>
This commit is contained in:
Joshua Kim 2024-02-06 14:13:18 +11:00 committed by GitHub
parent feed806489
commit c603ca434e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 805 additions and 808 deletions

View File

@ -2,16 +2,27 @@ import logging
import random
import re
import sys
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple
from ..code_utils import content_str
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed, has_self_loops
logger = logging.getLogger(__name__)
class NoEligibleSpeakerException(Exception):
"""Exception raised for early termination of a GroupChat."""
def __init__(self, message="No eligible speakers."):
self.message = message
super().__init__(self.message)
@dataclass
class GroupChat:
"""(In preview) A group chat class that contains the following data fields:
@ -30,7 +41,10 @@ class GroupChat:
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If allow_repeat_speaker is a list of Agents, then only those listed agents are allowed to repeat. If set to False, then no speakers are allowed to repeat.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If allow_repeat_speaker is a list of Agents, then only those listed agents are allowed to repeat. If set to False, then no speakers are allowed to repeat. allow_repeat_speaker and allowed_or_disallowed_speaker_transitions are mutually exclusive.
- allowed_or_disallowed_speaker_transitions: a dictionary of keys and list as values. The keys are the source agents, and the values are the agents that the key agent can transition to. Default is None, in which case a fully connected allowed_speaker_transitions_dict is assumed. allow_repeat_speaker and allowed_or_disallowed_speaker_transitions are mutually exclusive.
- speaker_transitions_type: whether the speaker_transitions_type is a dictionary containing lists of allowed agents or disallowed agents. allowed means the allowed_or_disallowed_speaker_transitions is a dictionary containing lists of allowed agents. If set to disallowed, then the allowed_or_disallowed_speaker_transitions is a dictionary containing lists of disallowed agents. Must be supplied if allowed_or_disallowed_speaker_transitions is not None.
- enable_clear_history: enable possibility to clear history of messages for agents manually by providing
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
@ -42,10 +56,95 @@ class GroupChat:
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Optional[str] = "auto"
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = True
allow_repeat_speaker: Optional[
Union[bool, List[Agent]]
] = True # It would be set to True if allowed_or_disallowed_speaker_transitions is None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Optional[str] = None
enable_clear_history: Optional[bool] = False
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
_VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]
allowed_speaker_transitions_dict: Dict = field(init=False)
def __post_init__(self):
# Post init steers clears of the automatically generated __init__ method from dataclass
# Here, we create allowed_speaker_transitions_dict from the supplied allowed_or_disallowed_speaker_transitions and is_allowed_graph, and lastly checks for validity.
# Check input
if self.speaker_transitions_type is not None:
self.speaker_transitions_type = self.speaker_transitions_type.lower()
assert self.speaker_transitions_type in self._VALID_SPEAKER_TRANSITIONS_TYPE, (
f"GroupChat speaker_transitions_type is set to '{self.speaker_transitions_type}'. "
f"It should be one of {self._VALID_SPEAKER_TRANSITIONS_TYPE} (case insensitive). "
)
# If both self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None, set allow_repeat_speaker to True to ensure backward compatibility
# Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451541204
if self.allowed_or_disallowed_speaker_transitions is None and self.allow_repeat_speaker is None:
self.allow_repeat_speaker = True
# self.allowed_or_disallowed_speaker_transitions and self.allow_repeat_speaker are mutually exclusive parameters.
# Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451266661
if self.allowed_or_disallowed_speaker_transitions is not None and self.allow_repeat_speaker is not None:
raise ValueError(
"Don't provide both allowed_or_disallowed_speaker_transitions and allow_repeat_speaker in group chat. "
"Please set one of them to None."
)
# Asks the user to specify whether the speaker_transitions_type is allowed or disallowed if speaker_transitions_type is supplied
# Discussed in https://github.com/microsoft/autogen/pull/857#discussion_r1451259524
if self.allowed_or_disallowed_speaker_transitions is not None and self.speaker_transitions_type is None:
raise ValueError(
"GroupChat allowed_or_disallowed_speaker_transitions is not None, but speaker_transitions_type is None. "
"Please set speaker_transitions_type to either 'allowed' or 'disallowed'."
)
# Inferring self.allowed_speaker_transitions_dict
# Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is None, using allow_repeat_speaker
if self.allowed_or_disallowed_speaker_transitions is None:
self.allowed_speaker_transitions_dict = {}
# Create a fully connected allowed_speaker_transitions_dict not including self loops
for agent in self.agents:
self.allowed_speaker_transitions_dict[agent] = [
other_agent for other_agent in self.agents if other_agent != agent
]
# If self.allow_repeat_speaker is True, add self loops to all agents
if self.allow_repeat_speaker:
for agent in self.agents:
self.allowed_speaker_transitions_dict[agent].append(agent)
# Else if self.allow_repeat_speaker is a list of Agents, add self loops to the agents in the list
elif isinstance(self.allow_repeat_speaker, list):
for agent in self.allow_repeat_speaker:
self.allowed_speaker_transitions_dict[agent].append(agent)
# Create self.allowed_speaker_transitions_dict if allowed_or_disallowed_speaker_transitions is not None, using allowed_or_disallowed_speaker_transitions
else:
# Process based on is_allowed_graph
if self.speaker_transitions_type == "allowed":
self.allowed_speaker_transitions_dict = self.allowed_or_disallowed_speaker_transitions
else:
# Logic for processing disallowed allowed_or_disallowed_speaker_transitions to allowed_speaker_transitions_dict
self.allowed_speaker_transitions_dict = invert_disallowed_to_allowed(
self.allowed_or_disallowed_speaker_transitions, self.agents
)
# Inferring self.allow_repeat_speaker from allowed_speaker_transitions_dict using has_self_loops
# Finally, self.allow_repeat_speaker shouldn't be None, so it is set from the the graph.
if self.allow_repeat_speaker is None:
self.allow_repeat_speaker = has_self_loops(self.allowed_speaker_transitions_dict)
# Check for validity
check_graph_validity(
allowed_speaker_transitions_dict=self.allowed_speaker_transitions_dict,
agents=self.agents,
allow_repeat_speaker=self.allow_repeat_speaker,
)
@property
def agent_names(self) -> List[str]:
@ -134,6 +233,12 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None
def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
"""Randomly select the next speaker."""
if agents is None:
agents = self.agents
return random.choice(agents)
def _prepare_and_select_agents(
self, last_speaker: Agent
) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]:
@ -198,13 +303,40 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = agents if allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
# Filter agents with allowed_speaker_transitions_dict
is_last_speaker_in_group = last_speaker in self.agents
# this condition means last_speaker is a sink in the graph, then no agents are eligible
if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group:
raise NoEligibleSpeakerException(
f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict."
)
# last_speaker is not in the group, so all agents are eligible
elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group:
graph_eligible_agents = []
else:
# Extract agent names from the list of agents
graph_eligible_agents = [
agent for agent in agents if agent in self.allowed_speaker_transitions_dict[last_speaker]
]
# If there is only one eligible agent, just return it to avoid the speaker selection prompt
if len(graph_eligible_agents) == 1:
return graph_eligible_agents[0], graph_eligible_agents, None
# If there are no eligible agents, return None, which means all agents will be taken into consideration in the next step
if len(graph_eligible_agents) == 0:
graph_eligible_agents = None
# Use the selected speaker selection method
select_speaker_messages = None
if self.speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(agents)
selected_agent = self.manual_select_speaker(graph_eligible_agents)
elif self.speaker_selection_method.lower() == "round_robin":
selected_agent = self.next_agent(last_speaker, agents)
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
elif self.speaker_selection_method.lower() == "random":
selected_agent = random.choice(agents)
selected_agent = self.random_select_speaker(graph_eligible_agents)
else:
selected_agent = None
select_speaker_messages = self.messages.copy()
@ -214,11 +346,11 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
if select_speaker_messages[-1].get("tool_calls", False):
select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None)
select_speaker_messages = select_speaker_messages + [
{"role": "system", "content": self.select_speaker_prompt(agents)}
{"role": "system", "content": self.select_speaker_prompt(graph_eligible_agents)}
]
return selected_agent, agents, select_speaker_messages
return selected_agent, graph_eligible_agents, select_speaker_messages
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
"""Select the next speaker."""
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
@ -228,7 +360,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
final, name = selector.generate_oai_reply(messages)
return self._finalize_speaker(last_speaker, final, name, agents)
async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
"""Select the next speaker."""
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
@ -238,7 +370,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
final, name = await selector.a_generate_oai_reply(messages)
return self._finalize_speaker(last_speaker, final, name, agents)
def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: List[Agent]) -> Agent:
def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[List[Agent]]) -> Agent:
if not final:
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)
@ -272,7 +404,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
roles.append(f"{agent.name}: {agent.description}".strip())
return "\n".join(roles)
def _mentioned_agents(self, message_content: Union[str, List], agents: List[Agent]) -> Dict:
def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[List[Agent]]) -> Dict:
"""Counts the number of times each agent is mentioned in the provided message content.
Args:
@ -282,6 +414,9 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
Returns:
Dict: a counter for mentioned agents.
"""
if agents is None:
agents = self.agents
# Cast message content to str
if isinstance(message_content, dict):
message_content = message_content["content"]
@ -387,6 +522,10 @@ class GroupChatManager(ConversableAgent):
else:
# admin agent is not found in the participants
raise
except NoEligibleSpeakerException:
# No eligible speaker, terminate the conversation
break
if reply is None:
# no reply is generated, exit the chat
break

138
autogen/graph_utils.py Normal file
View File

@ -0,0 +1,138 @@
from typing import Dict, List, Optional, Union
import logging
from autogen.agentchat.groupchat import Agent
def has_self_loops(allowed_speaker_transitions: dict) -> bool:
"""
Returns True if there are self loops in the allowed_speaker_transitions_dict.
"""
return any([key in value for key, value in allowed_speaker_transitions.items()])
def check_graph_validity(
allowed_speaker_transitions_dict: dict,
agents: List[Agent],
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = True,
):
"""
allowed_speaker_transitions_dict: A dictionary of keys and list as values. The keys are the names of the agents, and the values are the names of the agents that the key agent can transition to.
agents: A list of Agents
allow_repeat_speaker: A boolean indicating whether the same agent can speak twice in a row.
Checks for the following:
Errors
1. The dictionary must have a structure of keys and list as values
2. Every key exists in agents.
3. Every value is a list of Agents (not string).
Warnings
1. Warning if there are isolated agent nodes
2. Warning if the set of agents in allowed_speaker_transitions do not match agents
3. Warning if there are duplicated agents in any values of `allowed_speaker_transitions_dict`
"""
### Errors
# Check 1. The dictionary must have a structure of keys and list as values
if not isinstance(allowed_speaker_transitions_dict, dict):
raise ValueError("allowed_speaker_transitions_dict must be a dictionary.")
# All values must be lists of Agent or empty
if not all([isinstance(value, list) or value == [] for value in allowed_speaker_transitions_dict.values()]):
raise ValueError("allowed_speaker_transitions_dict must be a dictionary of keys and list as values.")
# Check 2. Every key exists in agents
if not all([key in agents for key in allowed_speaker_transitions_dict.keys()]):
raise ValueError("allowed_speaker_transitions_dict has keys not in agents' names.")
# Check 3. Every value is a list of Agents or empty list (not string).
if not all(
[all([isinstance(agent, Agent) for agent in value]) for value in allowed_speaker_transitions_dict.values()]
):
raise ValueError("allowed_speaker_transitions_dict has values that are not lists of Agents.")
# Warnings
# Warning 1. Warning if there are isolated agent nodes, there are not incoming nor outgoing edges
# Concat keys if len(value) is positive
has_outgoing_edge = []
for key, agent_list in allowed_speaker_transitions_dict.items():
if len(agent_list) > 0:
has_outgoing_edge.append(key)
no_outgoing_edges = [agent for agent in agents if agent not in has_outgoing_edge]
# allowed_speaker_transitions_dict.values() is a list of list of Agents
# values_all_agents is a list of all agents in allowed_speaker_transitions_dict.values()
has_incoming_edge = []
for agent_list in allowed_speaker_transitions_dict.values():
if len(agent_list) > 0:
has_incoming_edge.extend(agent_list)
no_incoming_edges = [agent for agent in agents if agent not in has_incoming_edge]
isolated_agents = set(no_incoming_edges).intersection(set(no_outgoing_edges))
if len(isolated_agents) > 0:
logging.warning(
f"""Warning: There are isolated agent nodes, there are not incoming nor outgoing edges. Isolated agents: {[agent.name for agent in isolated_agents]}"""
)
# Warning 2. Warning if the set of agents in allowed_speaker_transitions do not match agents
# Get set of agents
agents_in_allowed_speaker_transitions = set(has_incoming_edge).union(set(has_outgoing_edge))
full_anti_join = set(agents_in_allowed_speaker_transitions).symmetric_difference(set(agents))
if len(full_anti_join) > 0:
logging.warning(
f"""Warning: The set of agents in allowed_speaker_transitions do not match agents. Offending agents: {[agent.name for agent in full_anti_join]}"""
)
# Warning 3. Warning if there are duplicated agents in any values of `allowed_speaker_transitions_dict`
for key, values in allowed_speaker_transitions_dict.items():
duplicates = [item for item in values if values.count(item) > 1]
unique_duplicates = list(set(duplicates))
if unique_duplicates:
logging.warning(
f"Agent '{key.name}' has duplicate elements: {[agent.name for agent in unique_duplicates]}. Please remove duplicates manually."
)
def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agents: List[Agent]) -> dict:
"""
Start with a fully connected allowed_speaker_transitions_dict of all agents. Remove edges from the fully connected allowed_speaker_transitions_dict according to the disallowed_speaker_transitions_dict to form the allowed_speaker_transitions_dict.
"""
# Create a fully connected allowed_speaker_transitions_dict of all agents
allowed_speaker_transitions_dict = {agent: [other_agent for other_agent in agents] for agent in agents}
# Remove edges from allowed_speaker_transitions_dict according to the disallowed_speaker_transitions_dict
for key, value in disallowed_speaker_transitions_dict.items():
allowed_speaker_transitions_dict[key] = [
agent for agent in allowed_speaker_transitions_dict[key] if agent not in value
]
return allowed_speaker_transitions_dict
def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: List[Agent]):
"""
Visualize the speaker_transitions_dict using networkx.
"""
try:
import networkx as nx
import matplotlib.pyplot as plt
except ImportError as e:
logging.fatal("Failed to import networkx or matplotlib. Try running 'pip install autogen[graphs]'")
raise e
G = nx.DiGraph()
# Add nodes
G.add_nodes_from([agent.name for agent in agents])
# Add edges
for key, value in speaker_transitions_dict.items():
for agent in value:
G.add_edge(key.name, agent.name)
# Visualize
nx.draw(G, with_labels=True, font_weight="bold")
plt.show()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -51,7 +51,7 @@ setuptools.setup(
"autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
"graphs": ["networkx~=3.2.1", "matplotlib~=3.8.1"],
"graph": ["networkx", "matplotlib"],
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
"redis": ["redis"],
},

View File

@ -3,6 +3,8 @@ from unittest import mock
import builtins
import autogen
import json
import sys
from autogen import Agent, GroupChat
def test_func_call_groupchat():
@ -199,27 +201,11 @@ def _test_n_agents_less_than_3(method):
"This is bob speaking.",
] * 3
# test one agent
groupchat = autogen.GroupChat(
agents=[agent1],
messages=[],
max_round=6,
speaker_selection_method="round_robin",
allow_repeat_speaker=False,
)
with pytest.raises(ValueError):
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
# test zero agent
groupchat = autogen.GroupChat(
agents=[],
messages=[],
max_round=6,
speaker_selection_method="round_robin",
allow_repeat_speaker=False,
)
with pytest.raises(ValueError):
groupchat = autogen.GroupChat(
agents=[], messages=[], max_round=6, speaker_selection_method="round_robin", allow_repeat_speaker=False
)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
agent1.initiate_chat(group_chat_manager, message="This is alice speaking.")
@ -504,6 +490,104 @@ def test_selection_helpers():
groupchat.manual_select_speaker()
def test_init_default_parameters():
agents = [Agent(name=f"Agent{i}") for i in range(3)]
group_chat = GroupChat(agents=agents, messages=[], max_round=3)
for agent in agents:
assert set([a.name for a in group_chat.allowed_speaker_transitions_dict[agent]]) == set(
[a.name for a in agents]
)
def test_graph_parameters():
agents = [Agent(name=f"Agent{i}") for i in range(3)]
with pytest.raises(ValueError):
GroupChat(
agents=agents,
messages=[],
max_round=3,
allow_repeat_speaker=None,
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
)
with pytest.raises(AssertionError):
GroupChat(
agents=agents,
messages=[],
max_round=3,
allow_repeat_speaker=None,
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
speaker_transitions_type="a",
)
group_chat = GroupChat(
agents=agents,
messages=[],
max_round=3,
allow_repeat_speaker=None,
allowed_or_disallowed_speaker_transitions={agents[0]: [agents[1]], agents[1]: [agents[2]]},
speaker_transitions_type="allowed",
)
assert "Agent0" in group_chat.agent_names
def test_graph_validity_check():
agents = [Agent(name=f"Agent{i}") for i in range(3)]
invalid_transitions = {agents[0]: []}
with pytest.raises(ValueError):
GroupChat(
agents=agents,
messages=[],
allowed_or_disallowed_speaker_transitions=invalid_transitions,
speaker_transitions_type="allowed",
)
def test_graceful_exit_before_max_round():
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
)
agent3 = autogen.ConversableAgent(
"sam",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is sam speaking. TERMINATE",
)
# This speaker_transitions limits the transition to be only from agent1 to agent2, and from agent2 to agent3 and end.
allowed_or_disallowed_speaker_transitions = {agent1: [agent2], agent2: [agent3]}
# Test empty is_termination_msg function
groupchat = autogen.GroupChat(
agents=[agent1, agent2, agent3],
messages=[],
speaker_selection_method="round_robin",
max_round=10,
allow_repeat_speaker=None,
allowed_or_disallowed_speaker_transitions=allowed_or_disallowed_speaker_transitions,
speaker_transitions_type="allowed",
)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False, is_termination_msg=None)
agent1.initiate_chat(group_chat_manager, message="'None' is_termination_msg function.")
# Note that 3 is much lower than 10 (max_round), so the conversation should end before 10 rounds.
assert len(groupchat.messages) == 3
def test_clear_agents_history():
agent1 = autogen.ConversableAgent(
"alice",
@ -604,4 +688,5 @@ if __name__ == "__main__":
# test_termination()
# test_next_agent()
# test_invalid_allow_repeat_speaker()
# test_graceful_exit_before_max_round()
test_clear_agents_history()

165
test/test_graph_utils.py Normal file
View File

@ -0,0 +1,165 @@
import sys
import pytest
import logging
from autogen.agentchat import Agent
import autogen.graph_utils as gru
class TestHelpers:
def test_has_self_loops(self):
# Setup test data
agents = [Agent(name=f"Agent{i}") for i in range(3)]
allowed_speaker_transitions = {
agents[0]: [agents[1], agents[2]],
agents[1]: [agents[2]],
agents[2]: [agents[0]],
}
allowed_speaker_transitions_with_self_loops = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[1], agents[2]],
agents[2]: [agents[0]],
}
# Testing
assert not gru.has_self_loops(allowed_speaker_transitions)
assert gru.has_self_loops(allowed_speaker_transitions_with_self_loops)
class TestGraphUtilCheckGraphValidity:
def test_valid_structure(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
valid_speaker_transitions_dict = {agent: [other_agent for other_agent in agents] for agent in agents}
gru.check_graph_validity(allowed_speaker_transitions_dict=valid_speaker_transitions_dict, agents=agents)
def test_graph_with_invalid_structure(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
unseen_agent = Agent("unseen_agent")
invalid_speaker_transitions_dict = {unseen_agent: ["stranger"]}
with pytest.raises(ValueError):
gru.check_graph_validity(invalid_speaker_transitions_dict, agents)
def test_graph_with_invalid_string(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
invalid_speaker_transitions_dict = {
agent: ["agent1"] for agent in agents
} # 'agent1' is a string, not an Agent. Therefore raises an error.
with pytest.raises(ValueError):
gru.check_graph_validity(invalid_speaker_transitions_dict, agents)
def test_graph_with_unauthorized_self_loops(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
# Creating a subset of agents allowed to have self-loops
allowed_repeat_speakers = agents[: len(agents) // 2]
# Constructing a speaker transitions dictionary with self-loops for all agents
# Ensuring at least one agent outside the allowed_repeat_speakers has a self-loop
speaker_transitions_dict_with_self_loop = {agent: agent for agent in agents}
# Testing the function with the constructed speaker transitions dict
with pytest.raises(ValueError):
gru.check_graph_validity(
speaker_transitions_dict_with_self_loop, agents, allow_repeat_speaker=allowed_repeat_speakers
)
# Test for Warning 1: Isolated agent nodes
def test_isolated_agent_nodes_warning(self, caplog):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
# Create a speaker_transitions_dict where at least one agent is isolated
speaker_transitions_dict_with_isolation = {agents[0]: [agents[0], agents[1]], agents[1]: [agents[0]]}
# Add an isolated agent
speaker_transitions_dict_with_isolation[agents[2]] = []
with caplog.at_level(logging.WARNING):
gru.check_graph_validity(
allowed_speaker_transitions_dict=speaker_transitions_dict_with_isolation, agents=agents
)
assert "isolated" in caplog.text
# Test for Warning 2: Warning if the set of agents in allowed_speaker_transitions do not match agents
def test_warning_for_mismatch_in_agents(self, caplog):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
# Test with missing agents in allowed_speaker_transitions_dict
unknown_agent_dict = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[0], agents[1], agents[2]],
agents[2]: [agents[0], agents[1], agents[2], Agent("unknown_agent")],
}
with caplog.at_level(logging.WARNING):
gru.check_graph_validity(allowed_speaker_transitions_dict=unknown_agent_dict, agents=agents)
assert "allowed_speaker_transitions do not match agents" in caplog.text
# Test for Warning 3: Warning if there is duplicated agents in allowed_speaker_transitions_dict
def test_warning_for_duplicate_agents(self, caplog):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
# Construct an `allowed_speaker_transitions_dict` with duplicated agents
duplicate_agents_dict = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[0], agents[1], agents[2], agents[1]],
agents[2]: [agents[0], agents[1], agents[2], agents[0], agents[2]],
}
with caplog.at_level(logging.WARNING):
gru.check_graph_validity(allowed_speaker_transitions_dict=duplicate_agents_dict, agents=agents)
assert "duplicate" in caplog.text
class TestGraphUtilInvertDisallowedToAllowed:
def test_basic_functionality(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
disallowed_graph = {agents[0]: [agents[1]], agents[1]: [agents[0], agents[2]], agents[2]: []}
expected_allowed_graph = {
agents[0]: [agents[0], agents[2]],
agents[1]: [agents[1]],
agents[2]: [agents[0], agents[1], agents[2]],
}
# Compare names of agents
inverted = gru.invert_disallowed_to_allowed(disallowed_graph, agents)
assert inverted == expected_allowed_graph
def test_empty_disallowed_graph(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
disallowed_graph = {}
expected_allowed_graph = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[0], agents[1], agents[2]],
agents[2]: [agents[0], agents[1], agents[2]],
}
# Compare names of agents
inverted = gru.invert_disallowed_to_allowed(disallowed_graph, agents)
assert inverted == expected_allowed_graph
def test_fully_disallowed_graph(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
disallowed_graph = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[0], agents[1], agents[2]],
agents[2]: [agents[0], agents[1], agents[2]],
}
expected_allowed_graph = {agents[0]: [], agents[1]: [], agents[2]: []}
# Compare names of agents
inverted = gru.invert_disallowed_to_allowed(disallowed_graph, agents)
assert inverted == expected_allowed_graph
def test_disallowed_graph_with_nonexistent_agent(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
disallowed_graph = {agents[0]: [Agent("nonexistent_agent")]}
# In this case, the function should ignore the nonexistent agent and proceed with the inversion
expected_allowed_graph = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[0], agents[1], agents[2]],
agents[2]: [agents[0], agents[1], agents[2]],
}
# Compare names of agents
inverted = gru.invert_disallowed_to_allowed(disallowed_graph, agents)
assert inverted == expected_allowed_graph

View File

@ -20,7 +20,6 @@ Links to notebook examples:
- Automated Data Visualization by Group Chat (with 3 group member agents and 1 manager agent) - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_vis.ipynb)
- Automated Complex Task Solving by Group Chat (with 6 group member agents and 1 manager agent) - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_research.ipynb)
- Automated Task Solving with Coding & Planning Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_planning.ipynb)
- Automated Task Solving with agents divided into 2 groups - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_hierarchy_flow_using_select_speaker.ipynb)
- Automated Task Solving with transition paths specified in a graph - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_graph_modelling_language_using_select_speaker.ipynb)
- Running a group chat as an inner-monolgue via the SocietyOfMindAgent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_society_of_mind.ipynb)

View File

@ -313,12 +313,17 @@ On the one hand, one can achieve fully autonomous conversations after an initial
#### Static and dynamic conversations
By adopting the conversation-driven control with both programming language and natural language, AutoGen inherently allows dynamic conversation. Dynamic conversation allows the agent topology to change depending on the actual flow of conversation under different input problem instances, while the flow of a static conversation always follows a pre-defined topology. The dynamic conversation pattern is useful in complex applications where the patterns of interaction cannot be predetermined in advance. AutoGen provides two general approaches to achieving dynamic conversation:
AutoGen, by integrating conversation-driven control utilizing both programming and natural language, inherently supports dynamic conversations. This dynamic nature allows the agent topology to adapt based on the actual conversation flow under varying input problem scenarios. Conversely, static conversations adhere to a predefined topology. Dynamic conversations are particularly beneficial in complex settings where interaction patterns cannot be predetermined.
- Registered auto-reply. With the pluggable auto-reply function, one can choose to invoke conversations with other agents depending on the content of the current message and context. A working system demonstrating this type of dynamic conversation can be found in this code example, demonstrating a [dynamic group chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat.ipynb). In the system, we register an auto-reply function in the group chat manager, which lets LLM decide who the next speaker will be in a group chat setting.
1. Registered auto-reply
With the pluggable auto-reply function, one can choose to invoke conversations with other agents depending on the content of the current message and context. For example:
- Hierarchical chat like in [OptiGuide](https://github.com/microsoft/optiguide).
- [Dynamic Group Chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat.ipynb) which is a special form of hierarchical chat. In the system, we register a reply function in the group chat manager, which broadcasts messages and decides who the next speaker will be in a group chat setting.
- [Finite state machine (FSM) based group chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_graph_modelling_language_using_select_speaker.ipynb) which is a special form of dynamic group chat. In this approach, a directed transition matrix is fed into group chat. Users can specify legal transitions or specify disallowed transitions.
- Nested chat like in [conversational chess](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_chess.ipynb).
- LLM-based function call. In this approach, LLM decides whether or not to call a particular function depending on the conversation status in each inference call.
By messaging additional agents in the called functions, the LLM can drive dynamic multi-agent conversation. A working system showcasing this type of dynamic conversation can be found in the [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant would automatically resort to an expert using function calls.
2. LLM-Based Function Call
Another approach involves LLM-based function calls, where LLM decides if a specific function should be invoked based on the conversation's status during each inference. This approach enables dynamic multi-agent conversations, as seen in scenarios like [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant automatically seeks expertise via function calls.
### LLM Caching

View File

@ -107,3 +107,14 @@ pip install "pyautogen[mathchat]<0.2"
Example notebooks:
[Using MathChat to Solve Math Problems](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_MathChat.ipynb)
## Graph
To use a graph in `GroupChat`, particularly for graph visualization, please install AutoGen with the [graph] option.
```bash
pip install "pyautogen[graph]"
```
Example notebook: [Graph Modeling Language with using select_speaker](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_graph_modelling_language_using_select_speaker.ipynb)