mirror of https://github.com/microsoft/autogen.git
Allow user to pass in a customized speaker selection method (#1791)
* init PR * update * update code check * update * update * update * update * Test the ability to have agents a,u,t,o,g,e,n speak in turn. * update * update * update * Evidence that groupchat not terminating because of the TERMINATE substring. * Raising NoEligibleSpeakerException allows graceful exit before max turns * update * To confirm with author that custom function is meant to override graph constraints * Confirmed the expected test behaviour with author * Update autogen/agentchat/groupchat.py * update * update --------- Co-authored-by: Joshua Kim <Joshua@spectdata.com> Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
parent
d711bd8e5d
commit
c37227bd04
|
@ -3,7 +3,7 @@ import random
|
|||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
from typing import Dict, List, Optional, Union, Tuple, Callable
|
||||
|
||||
|
||||
from ..code_utils import content_str
|
||||
|
@ -42,7 +42,16 @@ class GroupChat:
|
|||
- "manual": the next speaker is selected manually by user input.
|
||||
- "random": the next speaker is selected randomly.
|
||||
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
|
||||
|
||||
- a customized speaker selection function (Callable): the function will be called to select the next speaker.
|
||||
The function should take the last speaker and the group chat as input and return one of the following:
|
||||
1. an `Agent` class, it must be one of the agents in the group chat.
|
||||
2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
|
||||
3. None, which would terminate the conversation gracefully.
|
||||
```python
|
||||
def custom_speaker_selection_func(
|
||||
last_speaker: Agent, groupchat: GroupChat
|
||||
) -> Union[Agent, str, None]:
|
||||
```
|
||||
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
|
||||
Default is True, in which case all speakers are allowed to speak consecutively.
|
||||
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
|
||||
|
@ -67,7 +76,7 @@ class GroupChat:
|
|||
max_round: Optional[int] = 10
|
||||
admin_name: Optional[str] = "Admin"
|
||||
func_call_filter: Optional[bool] = True
|
||||
speaker_selection_method: Optional[str] = "auto"
|
||||
speaker_selection_method: Optional[Union[str, Callable]] = "auto"
|
||||
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
|
||||
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
|
||||
speaker_transitions_type: Optional[str] = None
|
||||
|
@ -277,11 +286,36 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
|||
return random.choice(agents)
|
||||
|
||||
def _prepare_and_select_agents(
|
||||
self, last_speaker: Agent
|
||||
self,
|
||||
last_speaker: Agent,
|
||||
) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]:
|
||||
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
|
||||
# If self.speaker_selection_method is a callable, call it to get the next speaker.
|
||||
# If self.speaker_selection_method is a string, return it.
|
||||
speaker_selection_method = self.speaker_selection_method
|
||||
if isinstance(self.speaker_selection_method, Callable):
|
||||
selected_agent = self.speaker_selection_method(last_speaker, self)
|
||||
if selected_agent is None:
|
||||
raise NoEligibleSpeakerException(
|
||||
"Custom speaker selection function returned None. Terminating conversation."
|
||||
)
|
||||
elif isinstance(selected_agent, Agent):
|
||||
if selected_agent in self.agents:
|
||||
return selected_agent, self.agents, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat."
|
||||
)
|
||||
elif isinstance(selected_agent, str):
|
||||
# If returned a string, assume it is a speaker selection method
|
||||
speaker_selection_method = selected_agent
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str."
|
||||
)
|
||||
|
||||
if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
|
||||
raise ValueError(
|
||||
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
|
||||
f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. "
|
||||
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
|
||||
)
|
||||
|
||||
|
@ -300,7 +334,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
|||
f"GroupChat is underpopulated with {n_agents} agents. "
|
||||
"Please add more agents to the GroupChat or use direct communication instead."
|
||||
)
|
||||
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
|
||||
elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
|
||||
logger.warning(
|
||||
f"GroupChat is underpopulated with {n_agents} agents. "
|
||||
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
|
||||
|
@ -366,11 +400,11 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
|||
|
||||
# Use the selected speaker selection method
|
||||
select_speaker_messages = None
|
||||
if self.speaker_selection_method.lower() == "manual":
|
||||
if speaker_selection_method.lower() == "manual":
|
||||
selected_agent = self.manual_select_speaker(graph_eligible_agents)
|
||||
elif self.speaker_selection_method.lower() == "round_robin":
|
||||
elif speaker_selection_method.lower() == "round_robin":
|
||||
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
|
||||
elif self.speaker_selection_method.lower() == "random":
|
||||
elif speaker_selection_method.lower() == "random":
|
||||
selected_agent = self.random_select_speaker(graph_eligible_agents)
|
||||
else:
|
||||
selected_agent = None
|
||||
|
|
|
@ -383,6 +383,7 @@
|
|||
"source": [
|
||||
"# load model here\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"config = config_list_custom[0]\n",
|
||||
"device = config.get(\"device\", \"cpu\")\n",
|
||||
"loaded_model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n",
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -637,8 +637,6 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"creator = FigureCreator(name=\"Figure Creator~\", llm_config=gpt4_llm_config)\n",
|
||||
"\n",
|
||||
"user_proxy = autogen.UserProxyAgent(\n",
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
#!/usr/bin/env python3 -m pytest
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from autogen import AgentNameConflict
|
||||
from autogen import AgentNameConflict, Agent, GroupChat
|
||||
import pytest
|
||||
from unittest import mock
|
||||
import builtins
|
||||
import autogen
|
||||
import json
|
||||
import sys
|
||||
from autogen import Agent, GroupChat
|
||||
|
||||
|
||||
def test_func_call_groupchat():
|
||||
|
@ -663,7 +662,7 @@ def test_graceful_exit_before_max_round():
|
|||
max_consecutive_auto_reply=10,
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is sam speaking. TERMINATE",
|
||||
default_auto_reply="This is sam speaking.",
|
||||
)
|
||||
|
||||
# This speaker_transitions limits the transition to be only from agent1 to agent2, and from agent2 to agent3 and end.
|
||||
|
@ -682,7 +681,7 @@ def test_graceful_exit_before_max_round():
|
|||
|
||||
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False, is_termination_msg=None)
|
||||
|
||||
agent1.initiate_chat(group_chat_manager, message="'None' is_termination_msg function.")
|
||||
agent1.initiate_chat(group_chat_manager, message="")
|
||||
|
||||
# Note that 3 is much lower than 10 (max_round), so the conversation should end before 10 rounds.
|
||||
assert len(groupchat.messages) == 3
|
||||
|
@ -1007,6 +1006,184 @@ def test_nested_teams_chat():
|
|||
assert reply["content"] == team2_msg["content"]
|
||||
|
||||
|
||||
def test_custom_speaker_selection():
|
||||
a1 = autogen.UserProxyAgent(
|
||||
name="a1",
|
||||
default_auto_reply="This is a1 speaking.",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
)
|
||||
|
||||
a2 = autogen.UserProxyAgent(
|
||||
name="a2",
|
||||
default_auto_reply="This is a2 speaking.",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
)
|
||||
|
||||
a3 = autogen.UserProxyAgent(
|
||||
name="a3",
|
||||
default_auto_reply="TERMINATE",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
)
|
||||
|
||||
def custom_speaker_selection_func(last_speaker: Agent, groupchat: GroupChat) -> Agent:
|
||||
"""Define a customized speaker selection function.
|
||||
A recommended way is to define a transition for each speaker using the groupchat allowed_or_disallowed_speaker_transitions parameter.
|
||||
"""
|
||||
if last_speaker is a1:
|
||||
return a2
|
||||
elif last_speaker is a2:
|
||||
return a3
|
||||
|
||||
groupchat = autogen.GroupChat(
|
||||
agents=[a1, a2, a3],
|
||||
messages=[],
|
||||
max_round=20,
|
||||
speaker_selection_method=custom_speaker_selection_func,
|
||||
)
|
||||
manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
|
||||
result = a1.initiate_chat(manager, message="Hello, this is a1 speaking.")
|
||||
assert len(result.chat_history) == 3
|
||||
|
||||
|
||||
def test_custom_speaker_selection_with_transition_graph():
|
||||
"""
|
||||
In this test, although speaker_selection_method is defined, the speaker transitions are also defined.
|
||||
There are 26 agents here, a to z.
|
||||
The speaker transitions are defined such that the agents can transition to the next alphabet.
|
||||
In addition, because we want the transition order to be a,u,t,o,g,e,n, we also define the speaker transitions for these agents.
|
||||
The speaker_selection_method is defined to return the next agent in the expected sequence.
|
||||
"""
|
||||
|
||||
# For loop that creates UserProxyAgent with names from a to z
|
||||
agents = [
|
||||
autogen.UserProxyAgent(
|
||||
name=chr(97 + i),
|
||||
default_auto_reply=f"My name is {chr(97 + i)}",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
)
|
||||
for i in range(26)
|
||||
]
|
||||
|
||||
# Initiate allowed speaker transitions
|
||||
allowed_or_disallowed_speaker_transitions = {}
|
||||
|
||||
# Each agent can transition to the next alphabet as a baseline
|
||||
# Key is Agent, value is a list of Agents that the key Agent can transition to
|
||||
for i in range(25):
|
||||
allowed_or_disallowed_speaker_transitions[agents[i]] = [agents[i + 1]]
|
||||
|
||||
# The test is to make sure that the agent sequence is a,u,t,o,g,e,n, so we need to add those transitions
|
||||
expected_sequence = ["a", "u", "t", "o", "g", "e", "n"]
|
||||
current_agent = None
|
||||
previous_agent = None
|
||||
|
||||
for char in expected_sequence:
|
||||
# convert char to i so that we can use chr(97+i)
|
||||
current_agent = agents[ord(char) - 97]
|
||||
if previous_agent is not None:
|
||||
# Add transition
|
||||
allowed_or_disallowed_speaker_transitions[previous_agent].append(current_agent)
|
||||
previous_agent = current_agent
|
||||
|
||||
def custom_speaker_selection_func(last_speaker: Agent, groupchat: GroupChat) -> Optional[Agent]:
|
||||
"""
|
||||
Define a customized speaker selection function.
|
||||
"""
|
||||
expected_sequence = ["a", "u", "t", "o", "g", "e", "n"]
|
||||
|
||||
last_speaker_char = last_speaker.name
|
||||
# Find the index of last_speaker_char in the expected_sequence
|
||||
last_speaker_index = expected_sequence.index(last_speaker_char)
|
||||
# Return the next agent in the expected sequence
|
||||
if last_speaker_index == len(expected_sequence) - 1:
|
||||
return None # terminate the conversation
|
||||
else:
|
||||
next_agent = agents[ord(expected_sequence[last_speaker_index + 1]) - 97]
|
||||
return next_agent
|
||||
|
||||
groupchat = autogen.GroupChat(
|
||||
agents=agents,
|
||||
messages=[],
|
||||
max_round=20,
|
||||
speaker_selection_method=custom_speaker_selection_func,
|
||||
allowed_or_disallowed_speaker_transitions=allowed_or_disallowed_speaker_transitions,
|
||||
speaker_transitions_type="allowed",
|
||||
)
|
||||
manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
|
||||
results = agents[0].initiate_chat(manager, message="My name is a")
|
||||
actual_sequence = []
|
||||
|
||||
# Append to actual_sequence using results.chat_history[idx]['content'][-1]
|
||||
for idx in range(len(results.chat_history)):
|
||||
actual_sequence.append(results.chat_history[idx]["content"][-1]) # append the last character of the content
|
||||
|
||||
assert expected_sequence == actual_sequence
|
||||
|
||||
|
||||
def test_custom_speaker_selection_overrides_transition_graph():
|
||||
"""
|
||||
In this test, team A engineer can transition to team A executor and team B engineer, but team B engineer cannot transition to team A executor.
|
||||
The expected behaviour is that the custom speaker selection function will override the constraints of the graph.
|
||||
"""
|
||||
|
||||
# For loop that creates UserProxyAgent with names from a to z
|
||||
agents = [
|
||||
autogen.UserProxyAgent(
|
||||
name="teamA_engineer",
|
||||
default_auto_reply="My name is teamA_engineer",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
),
|
||||
autogen.UserProxyAgent(
|
||||
name="teamA_executor",
|
||||
default_auto_reply="My name is teamA_executor",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
),
|
||||
autogen.UserProxyAgent(
|
||||
name="teamB_engineer",
|
||||
default_auto_reply="My name is teamB_engineer",
|
||||
human_input_mode="NEVER",
|
||||
code_execution_config={},
|
||||
),
|
||||
]
|
||||
|
||||
allowed_or_disallowed_speaker_transitions = {}
|
||||
|
||||
# teamA_engineer can transition to teamA_executor and teamB_engineer
|
||||
# teamB_engineer can transition to no one
|
||||
allowed_or_disallowed_speaker_transitions[agents[0]] = [agents[1], agents[2]]
|
||||
|
||||
def custom_speaker_selection_func(last_speaker: Agent, groupchat: GroupChat) -> Optional[Agent]:
|
||||
if last_speaker.name == "teamA_engineer":
|
||||
return agents[2] # Goto teamB_engineer
|
||||
elif last_speaker.name == "teamB_engineer":
|
||||
return agents[1] # Goto teamA_executor and contradict the graph
|
||||
|
||||
groupchat = autogen.GroupChat(
|
||||
agents=agents,
|
||||
messages=[],
|
||||
max_round=20,
|
||||
speaker_selection_method=custom_speaker_selection_func,
|
||||
allowed_or_disallowed_speaker_transitions=allowed_or_disallowed_speaker_transitions,
|
||||
speaker_transitions_type="allowed",
|
||||
)
|
||||
manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
results = agents[0].initiate_chat(manager, message="My name is teamA_engineer")
|
||||
|
||||
speakers = []
|
||||
for idx in range(len(results.chat_history)):
|
||||
speakers.append(results.chat_history[idx].get("name"))
|
||||
|
||||
assert "teamA_executor" in speakers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_func_call_groupchat()
|
||||
# test_broadcast()
|
||||
|
@ -1017,7 +1194,9 @@ if __name__ == "__main__":
|
|||
# test_agent_mentions()
|
||||
# test_termination()
|
||||
# test_next_agent()
|
||||
test_send_intros()
|
||||
# test_send_intros()
|
||||
# test_invalid_allow_repeat_speaker()
|
||||
# test_graceful_exit_before_max_round()
|
||||
# test_clear_agents_history()
|
||||
test_custom_speaker_selection_overrides_transition_graph()
|
||||
# pass
|
||||
|
|
|
@ -22,6 +22,7 @@ Links to notebook examples:
|
|||
- Automated Task Solving with Coding & Planning Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_planning.ipynb)
|
||||
- Automated Task Solving with transition paths specified in a graph - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_graph_modelling_language_using_select_speaker.ipynb)
|
||||
- Running a group chat as an inner-monolgue via the SocietyOfMindAgent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_society_of_mind.ipynb)
|
||||
- Running a group chat with custom speaker selection function - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_customized.ipynb)
|
||||
|
||||
1. **Sequential Multi-Agent Chats**
|
||||
- Solving Multiple Tasks in a Sequence of Chats Initiated by a Single Agent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_multi_task_chats.ipynb)
|
||||
|
|
Loading…
Reference in New Issue