mirror of https://github.com/microsoft/autogen.git
Make auto reply method pluggable (#1177)
* Make auto reply method pluggable * allow richer trigger types * test list
This commit is contained in:
parent
2208dfb79e
commit
a603e6dddc
|
@ -2,12 +2,13 @@ from .agent import Agent
|
|||
from .responsive_agent import ResponsiveAgent
|
||||
from .assistant_agent import AssistantAgent
|
||||
from .user_proxy_agent import UserProxyAgent
|
||||
from .groupchat import GroupChatManager
|
||||
from .groupchat import GroupChat, GroupChatManager
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"ResponsiveAgent",
|
||||
"AssistantAgent",
|
||||
"UserProxyAgent",
|
||||
"GroupChat",
|
||||
"GroupChatManager",
|
||||
]
|
||||
|
|
|
@ -165,7 +165,7 @@ class MathUserProxyAgent(UserProxyAgent):
|
|||
default_auto_reply=default_auto_reply,
|
||||
**kwargs,
|
||||
)
|
||||
self.register_auto_reply(Agent, self._generate_math_reply, 1)
|
||||
self.register_auto_reply(Agent, MathUserProxyAgent._generate_math_reply, 1)
|
||||
# fixed var
|
||||
self._max_invalid_q_per_step = max_invalid_q_per_step
|
||||
|
||||
|
|
|
@ -1,26 +1,63 @@
|
|||
from dataclasses import dataclass
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Union
|
||||
from .agent import Agent
|
||||
from .responsive_agent import ResponsiveAgent
|
||||
|
||||
|
||||
class GroupChatManager(ResponsiveAgent):
|
||||
"""(WIP) A chat manager agent that can manage a group chat of multiple agents."""
|
||||
@dataclass
|
||||
class GroupChat:
|
||||
"""A group chat class that contains a list of agents and the maximum number of rounds."""
|
||||
|
||||
agents: List[Agent]
|
||||
max_round: int
|
||||
messages: List[Dict]
|
||||
max_round: int = 10
|
||||
|
||||
@property
|
||||
def agent_names(self) -> List[str]:
|
||||
"""Return the names of the agents in the group chat."""
|
||||
return [agent.name for agent in self.agents]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the group chat."""
|
||||
self.messages.clear()
|
||||
|
||||
def agent_by_name(self, name: str) -> Agent:
|
||||
"""Find the next speaker based on the message."""
|
||||
return self.agents[self.agent_names.index(name)]
|
||||
|
||||
def next_agent(self, agent: Agent) -> Agent:
|
||||
"""Return the next agent in the list."""
|
||||
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]
|
||||
|
||||
def select_speaker_msg(self):
|
||||
"""Return the message for selecting the next speaker."""
|
||||
return f"""You are in a role play game. The following roles are available:
|
||||
{self._participant_roles()}. Read the following conversation.
|
||||
Then select the next role from {self.agent_names} to play. Only return the role."""
|
||||
|
||||
def select_speaker(self, last_speaker: Agent, selctor: ResponsiveAgent):
|
||||
"""Select the next speaker."""
|
||||
selctor.update_system_message(self.select_speaker_msg())
|
||||
final, name = selctor.generate_oai_reply(self.messages)
|
||||
if not final:
|
||||
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
|
||||
return self.next_agent(last_speaker)
|
||||
try:
|
||||
return self.agent_by_name(name)
|
||||
except ValueError:
|
||||
return self.next_agent(last_speaker)
|
||||
|
||||
def _participant_roles(self):
|
||||
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
|
||||
|
||||
def _select_speaker_msg(self):
|
||||
return f"""You are in a role play game. The following roles are available:
|
||||
{self._participant_roles()}. Read the following conversation.
|
||||
Then select the next role from {self._agent_names} to play. Only return the role."""
|
||||
|
||||
class GroupChatManager(ResponsiveAgent):
|
||||
"""(WIP) A chat manager agent that can manage a group chat of multiple agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_round: Optional[int] = 10,
|
||||
groupchat: GroupChat,
|
||||
name: Optional[str] = "chat_manager",
|
||||
# unlimited consecutive auto reply by default
|
||||
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
|
||||
|
@ -33,56 +70,35 @@ Then select the next role from {self._agent_names} to play. Only return the role
|
|||
name=name,
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
human_input_mode=human_input_mode,
|
||||
system_message=system_message,
|
||||
**kwargs,
|
||||
)
|
||||
self.register_auto_reply(Agent, self._generate_reply_for_participant)
|
||||
self.max_round = max_round
|
||||
self._agent_names = []
|
||||
self._messages = []
|
||||
self.register_auto_reply(Agent, GroupChatManager.run_chat, context=groupchat, reset_context=GroupChat.reset)
|
||||
# self._random = random.Random(seed)
|
||||
|
||||
def _generate_reply_for_participant(
|
||||
def run_chat(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
context: Optional[GroupChat] = None,
|
||||
) -> Union[str, Dict, None]:
|
||||
self._agent_names = [agent.name for agent in self.agents]
|
||||
"""Run a group chat."""
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
speaker = sender
|
||||
for i in range(self.max_round):
|
||||
for i in range(context.max_round):
|
||||
# set the name to speaker's name if the role is not function
|
||||
if message["role"] != "function":
|
||||
message["name"] = speaker.name
|
||||
self._messages.append(message)
|
||||
context.messages.append(message)
|
||||
# broadcast the message to all agents except the speaker
|
||||
for agent in self.agents:
|
||||
for agent in context.agents:
|
||||
if agent != speaker:
|
||||
self.send(message, agent, request_reply=False)
|
||||
if i != self.max_round - 1:
|
||||
if i != context.max_round - 1:
|
||||
# speaker selection msg from an agent
|
||||
speaker = self._select_speaker(speaker)
|
||||
speaker = context.select_speaker(speaker, self)
|
||||
speaker.send(speaker.generate_reply(sender=self), self, request_reply=False)
|
||||
message = self.last_message(speaker)
|
||||
return True, None
|
||||
|
||||
def _select_speaker(self, last_speaker: Agent):
|
||||
"""Select the next speaker."""
|
||||
self.update_system_message(self._select_speaker_msg())
|
||||
final, name = self._generate_oai_reply(self._messages)
|
||||
if not final:
|
||||
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
|
||||
return self.agents[(self._agent_names.index(last_speaker.name) + 1) % len(self._agent_names)]
|
||||
try:
|
||||
return self.agent_by_name(name)
|
||||
except ValueError:
|
||||
return self.agents[(self._agent_names.index(last_speaker.name) + 1) % len(self._agent_names)]
|
||||
|
||||
def agent_by_name(self, name: str) -> Agent:
|
||||
"""Find the next speaker based on the message."""
|
||||
return self.agents[self._agent_names.index(name)]
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self._messages.clear()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from collections import defaultdict
|
||||
import copy
|
||||
import json
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from flaml.autogen import oai
|
||||
from .agent import Agent
|
||||
from flaml.autogen.code_utils import DEFAULT_MODEL, UNKNOWN, execute_code, extract_code, infer_lang
|
||||
|
@ -108,26 +109,64 @@ class ResponsiveAgent(Agent):
|
|||
self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply)
|
||||
self._function_map = {} if function_map is None else function_map
|
||||
self._default_auto_reply = default_auto_reply
|
||||
self._class_specific_reply = []
|
||||
self._reply_func_list = []
|
||||
self.reply_at_receive = defaultdict(bool)
|
||||
self.register_auto_reply(Agent, self._generate_oai_reply)
|
||||
self.register_auto_reply(Agent, self._generate_code_execution_reply)
|
||||
self.register_auto_reply(Agent, self._generate_function_call_reply)
|
||||
self.register_auto_reply(Agent, self._check_termination_and_human_reply)
|
||||
self.register_auto_reply(Agent, ResponsiveAgent.generate_oai_reply)
|
||||
self.register_auto_reply(Agent, ResponsiveAgent.generate_code_execution_reply)
|
||||
self.register_auto_reply(Agent, ResponsiveAgent.generate_function_call_reply)
|
||||
self.register_auto_reply(Agent, ResponsiveAgent.check_termination_and_human_reply)
|
||||
|
||||
def register_auto_reply(self, class_type, reply_func: Callable, position: int = 0):
|
||||
"""Register a class-specific reply function.
|
||||
def register_auto_reply(
|
||||
self,
|
||||
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
|
||||
reply_func: Callable,
|
||||
position: Optional[int] = 0,
|
||||
context: Optional[Any] = None,
|
||||
reset_context: Optional[Callable] = None,
|
||||
):
|
||||
"""Register a reply function.
|
||||
|
||||
The class-specific reply function will be called when the sender is an instance of the class_type.
|
||||
The reply function will be called when the trigger matches the sender.
|
||||
The function registered later will be checked earlier by default.
|
||||
To change the order, set the position to a positive integer.
|
||||
|
||||
Args:
|
||||
class_type (Class): the class type.
|
||||
trigger (Agent class, str, Agent instance, callable, or list): the trigger.
|
||||
- If a class is provided, the reply function will be called when the sender is an instance of the class.
|
||||
- If a string is provided, the reply function will be called when the sender's name matches the string.
|
||||
- If an agent instance is provided, the reply function will be called when the sender is the agent instance.
|
||||
- If a callable is provided, the reply function will be called when the callable returns True.
|
||||
- If a list is provided, the reply function will be called when any of the triggers in the list is activated.
|
||||
reply_func (Callable): the reply function.
|
||||
The function takes a recipient agent, a list of messages, a sender agent and a context as input and returns a reply message.
|
||||
```python
|
||||
def reply_func(
|
||||
recipient: ResponsiveAgent,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
context: Optional[Any] = None,
|
||||
) -> Union[str, Dict, None]:
|
||||
```
|
||||
position (int): the position of the reply function in the reply function list.
|
||||
The function registered later will be checked earlier by default.
|
||||
To change the order, set the position to a positive integer.
|
||||
context (Any): the context to be passed to the reply function.
|
||||
When an agent is reset, the context will be reset to the original value.
|
||||
reset_context (Callable): the function to reset the context.
|
||||
The function returns None. Signature: ```def reset_context(context: Any)```
|
||||
"""
|
||||
self._class_specific_reply.insert(position, (class_type, reply_func))
|
||||
if not isinstance(trigger, (type, str, Agent, Callable, list)):
|
||||
raise ValueError("trigger must be a class, a string, an agent, a callable or a list.")
|
||||
self._reply_func_list.insert(
|
||||
position,
|
||||
{
|
||||
"trigger": trigger,
|
||||
"reply_func": reply_func,
|
||||
"context": copy.copy(context),
|
||||
"init_context": context,
|
||||
"reset_context": reset_context,
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def system_message(self):
|
||||
|
@ -362,6 +401,11 @@ class ResponsiveAgent(Agent):
|
|||
self.clear_history()
|
||||
self.reset_consecutive_auto_reply_counter()
|
||||
self.stop_reply_at_receive()
|
||||
for reply_func_tuple in self._reply_func_list:
|
||||
if reply_func_tuple["reset_context"] is not None:
|
||||
reply_func_tuple["reset_context"](reply_func_tuple["context"])
|
||||
else:
|
||||
reply_func_tuple["context"] = copy.copy(reply_func_tuple["init_context"])
|
||||
|
||||
def stop_reply_at_receive(self, sender: Optional[Agent] = None):
|
||||
"""Reset the reply_at_receive of the sender."""
|
||||
|
@ -388,28 +432,34 @@ class ResponsiveAgent(Agent):
|
|||
else:
|
||||
self._oai_messages[agent].clear()
|
||||
|
||||
def _generate_oai_reply(
|
||||
def generate_oai_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
context: Optional[Any] = None,
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
if self.llm_config is False:
|
||||
"""Generate a reply using autogen.oai."""
|
||||
llm_config = self.llm_config if context is None else context
|
||||
if llm_config is False:
|
||||
return False, None
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
|
||||
# TODO: #1143 handle token limit exceeded error
|
||||
response = oai.ChatCompletion.create(
|
||||
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **self.llm_config
|
||||
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **llm_config
|
||||
)
|
||||
return True, oai.ChatCompletion.extract_text_or_function_call(response)[0]
|
||||
|
||||
def _generate_code_execution_reply(
|
||||
def generate_code_execution_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
context: Optional[Any] = None,
|
||||
):
|
||||
if self._code_execution_config is False:
|
||||
"""Generate a reply using code execution."""
|
||||
code_execution_config = context if context is not None else self._code_execution_config
|
||||
if code_execution_config is False:
|
||||
return False, None
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
|
@ -426,11 +476,15 @@ class ResponsiveAgent(Agent):
|
|||
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
|
||||
return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}"
|
||||
|
||||
def _generate_function_call_reply(
|
||||
def generate_function_call_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
context: Optional[Any] = None,
|
||||
):
|
||||
"""Generate a reply using function call."""
|
||||
if context is None:
|
||||
context = self
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
|
@ -439,11 +493,15 @@ class ResponsiveAgent(Agent):
|
|||
return True, func_return
|
||||
return False, None
|
||||
|
||||
def _check_termination_and_human_reply(
|
||||
def check_termination_and_human_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
context: Optional[Any] = None,
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
"""Check if the conversation should be terminated, and if human reply is provided."""
|
||||
if context is None:
|
||||
context = self
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
|
@ -538,15 +596,32 @@ class ResponsiveAgent(Agent):
|
|||
"""
|
||||
assert messages is not None or sender is not None, "Either messages or sender must be provided."
|
||||
if sender is not None:
|
||||
for class_specifc_reply in self._class_specific_reply:
|
||||
if isinstance(sender, class_specifc_reply[0]) and (
|
||||
not exclude or class_specifc_reply[1] not in exclude
|
||||
):
|
||||
final, reply = class_specifc_reply[1](messages, sender)
|
||||
for reply_func_tuple in self._reply_func_list:
|
||||
if exclude and reply_func_tuple["reply_func"] in exclude:
|
||||
continue
|
||||
if self._match_trigger(reply_func_tuple["trigger"], sender):
|
||||
final, reply = reply_func_tuple["reply_func"](
|
||||
self, messages=messages, sender=sender, context=reply_func_tuple["context"]
|
||||
)
|
||||
if final:
|
||||
return reply
|
||||
return self._default_auto_reply
|
||||
|
||||
def _match_trigger(self, trigger, sender):
|
||||
"""Check if the sender matches the trigger."""
|
||||
if isinstance(trigger, str):
|
||||
return trigger == sender.name
|
||||
elif isinstance(trigger, type):
|
||||
return isinstance(sender, trigger)
|
||||
elif isinstance(trigger, Agent):
|
||||
return trigger == sender
|
||||
elif isinstance(trigger, Callable):
|
||||
return trigger(sender)
|
||||
elif isinstance(trigger, list):
|
||||
return any(self._match_trigger(t, sender) for t in trigger)
|
||||
else:
|
||||
raise ValueError(f"Unsupported trigger type: {type(trigger)}")
|
||||
|
||||
def get_human_input(self, prompt: str) -> str:
|
||||
"""Get human input.
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ class UserProxyAgent(ResponsiveAgent):
|
|||
UserProxyAgent is a subclass of ResponsiveAgent configured with `human_input_mode` to ALWAYS
|
||||
and `llm_config` to False. By default, the agent will prompt for human input every time a message is received.
|
||||
Code execution is enabled by default. LLM-based auto reply is disabled by default.
|
||||
To modify auto reply, register a method with `register_class_specific_reply`.
|
||||
To modify auto reply, register a method with (`register_auto_reply`)[responsive_agent#register_auto_reply].
|
||||
The method should have a similar signature with `_generate_oai_reply` method.
|
||||
To modify the way to get human input, override `get_human_input` method.
|
||||
To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`,
|
||||
|
|
|
@ -137,7 +137,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import defaultdict\n",
|
||||
"from typing import Dict, List, Optional, Union\n",
|
||||
"from typing import Any, Dict, List, Optional, Union\n",
|
||||
"\n",
|
||||
"sys_msg = \"\"\"You are an AI-powered chess board agent.\n",
|
||||
"You translate user's natural language input into legal UCI moves.\n",
|
||||
|
@ -164,7 +164,7 @@
|
|||
" llm_config={\"temperature\": 0.0, \"config_list\": config_list_gpt4},\n",
|
||||
" max_consecutive_auto_reply=10,\n",
|
||||
" )\n",
|
||||
" self.register_auto_reply(autogen.ResponsiveAgent, self._generate_board_reply)\n",
|
||||
" self.register_auto_reply(autogen.ResponsiveAgent, BoardAgent._generate_board_reply)\n",
|
||||
" self._board = board\n",
|
||||
" self._correct_move_messages = defaultdict(list)\n",
|
||||
"\n",
|
||||
|
@ -172,6 +172,7 @@
|
|||
" self,\n",
|
||||
" messages: Optional[List[Dict]] = None,\n",
|
||||
" sender: Optional[autogen.Agent] = None,\n",
|
||||
" context: Optional[Any] = None,\n",
|
||||
" ) -> Union[str, Dict, None]:\n",
|
||||
" # Filter for messages that do not contain error.\n",
|
||||
" if messages is None:\n",
|
||||
|
@ -179,7 +180,7 @@
|
|||
" message = messages[-1]\n",
|
||||
" assert message.get(\"role\") == \"user\"\n",
|
||||
" # extract a UCI move from player's message\n",
|
||||
" reply = self.generate_reply(self._correct_move_messages[sender] + [message], sender, exclude=[self._generate_board_reply])\n",
|
||||
" reply = self.generate_reply(self._correct_move_messages[sender] + [message], sender, exclude=[BoardAgent._generate_board_reply])\n",
|
||||
" if isinstance(reply, str):\n",
|
||||
" uci_move = reply\n",
|
||||
" else:\n",
|
||||
|
@ -242,8 +243,8 @@
|
|||
" max_consecutive_auto_reply=max_turns,\n",
|
||||
" **kwargs,\n",
|
||||
" )\n",
|
||||
" self.register_auto_reply(BoardAgent, self._generate_reply_for_board)\n",
|
||||
" self.register_auto_reply(ChessPlayerAgent, self._generate_reply_for_player)\n",
|
||||
" self.register_auto_reply(BoardAgent, ChessPlayerAgent._generate_reply_for_board)\n",
|
||||
" self.register_auto_reply(ChessPlayerAgent, ChessPlayerAgent._generate_reply_for_player)\n",
|
||||
" self._board_agent = board_agent\n",
|
||||
" self.update_max_consecutive_auto_reply(self._board_agent.max_consecutive_auto_reply(), self._board_agent)\n",
|
||||
"\n",
|
||||
|
@ -251,6 +252,7 @@
|
|||
" self,\n",
|
||||
" messages: Optional[List[Dict]] = None,\n",
|
||||
" sender: Optional[autogen.Agent] = None,\n",
|
||||
" context: Optional[Any] = None,\n",
|
||||
" ) -> Union[str, Dict, None]:\n",
|
||||
" if messages is None:\n",
|
||||
" messages = self._oai_messages[sender]\n",
|
||||
|
@ -260,7 +262,7 @@
|
|||
" if last_message[\"content\"].startswith(\"Error\"):\n",
|
||||
" # try again\n",
|
||||
" last_message[\"role\"] = \"system\"\n",
|
||||
" return True, self.generate_reply(messages + board_state_msg, sender, exclude=[self._generate_reply_for_board])\n",
|
||||
" return True, self.generate_reply(messages + board_state_msg, sender, exclude=[ChessPlayerAgent._generate_reply_for_board])\n",
|
||||
" else:\n",
|
||||
" return True, None\n",
|
||||
"\n",
|
||||
|
@ -268,13 +270,14 @@
|
|||
" self,\n",
|
||||
" messages: Optional[List[Dict]] = None,\n",
|
||||
" sender: Optional[autogen.Agent] = None,\n",
|
||||
" context: Optional[Any] = None,\n",
|
||||
" ) -> Union[str, Dict, None]:\n",
|
||||
" if messages is None:\n",
|
||||
" messages = self._oai_messages[sender]\n",
|
||||
" # add a system message about the current state of the board.\n",
|
||||
" board_state_msg = [{\"role\": \"system\", \"content\": f\"Current board:\\n{self._board_agent._board}\"}]\n",
|
||||
" # propose a reply which will be sent to the board agent for verification.\n",
|
||||
" message = self.generate_reply(messages + board_state_msg, sender, exclude=[self._generate_reply_for_player])\n",
|
||||
" message = self.generate_reply(messages + board_state_msg, sender, exclude=[ChessPlayerAgent._generate_reply_for_player])\n",
|
||||
" if message is None:\n",
|
||||
" return True, None\n",
|
||||
" # converse with the board until a legal move is made or max allowed retries.\n",
|
||||
|
@ -467,7 +470,13 @@
|
|||
"g1f3. \n",
|
||||
"Aiming to control the center of the board. Your move.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mPlayer black\u001b[0m (to BoardAgent):\n",
|
||||
"\n",
|
||||
"g8f6. \n",
|
||||
|
|
|
@ -124,7 +124,6 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"llm_config = {\"config_list\": config_list_gpt4}\n",
|
||||
"group_chat_manager = autogen.GroupChatManager(max_round=4, llm_config=llm_config)\n",
|
||||
"human = autogen.UserProxyAgent(\n",
|
||||
" name=\"Human\",\n",
|
||||
" system_message=\"A human admin.\",\n",
|
||||
|
@ -138,8 +137,8 @@
|
|||
" system_message=\"Code reviewer. Prevent code execution if unsafe or not well documented. Suggest changes. Otherwise, approve and return the final code to execute.\",\n",
|
||||
" llm_config=llm_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"group_chat_manager.agents = [human, alice, bob]"
|
||||
"groupchat = autogen.GroupChat(agents=[human, alice, bob], messages=[], max_round=4)\n",
|
||||
"manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -416,7 +415,13 @@
|
|||
"\n",
|
||||
"Always use this script carefully because web-scraping isn't always reliable or legal on all web pages. Always ensure you have express permission or that the website's terms and conditions don't forbid this kind of usage.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
|
||||
"\u001b[31m\n",
|
||||
|
@ -454,7 +459,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"human.initiate_chat(group_chat_manager, message=\"find a latest paper about generative agents\")"
|
||||
"human.initiate_chat(manager, message=\"find a latest paper about generative agents\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -2,7 +2,6 @@ from flaml import autogen
|
|||
|
||||
|
||||
def test_chat_manager():
|
||||
group_chat_manager = autogen.GroupChatManager(max_round=2, llm_config=False)
|
||||
agent1 = autogen.ResponsiveAgent(
|
||||
"alice",
|
||||
max_consecutive_auto_reply=2,
|
||||
|
@ -17,17 +16,52 @@ def test_chat_manager():
|
|||
llm_config=False,
|
||||
default_auto_reply="This is bob speaking.",
|
||||
)
|
||||
group_chat_manager.agents = [agent1, agent2]
|
||||
groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=2)
|
||||
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
agent1.initiate_chat(group_chat_manager, message="hello")
|
||||
|
||||
assert len(agent1.chat_messages[group_chat_manager]) == 2
|
||||
assert len(groupchat.messages) == 2
|
||||
|
||||
group_chat_manager.reset()
|
||||
assert len(groupchat.messages) == 0
|
||||
agent1.reset()
|
||||
agent2.reset()
|
||||
agent2.initiate_chat(group_chat_manager, message="hello")
|
||||
assert len(groupchat.messages) == 2
|
||||
|
||||
|
||||
def test_plugin():
|
||||
# Give another Agent class ability to manage group chat
|
||||
agent1 = autogen.ResponsiveAgent(
|
||||
"alice",
|
||||
max_consecutive_auto_reply=2,
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is alice sepaking.",
|
||||
)
|
||||
agent2 = autogen.ResponsiveAgent(
|
||||
"bob",
|
||||
max_consecutive_auto_reply=2,
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is bob speaking.",
|
||||
)
|
||||
groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=2)
|
||||
group_chat_manager = autogen.ResponsiveAgent(name="deputy_manager", llm_config=False)
|
||||
group_chat_manager.register_auto_reply(
|
||||
autogen.Agent,
|
||||
reply_func=autogen.GroupChatManager.run_chat,
|
||||
context=groupchat,
|
||||
reset_context=autogen.GroupChat.reset,
|
||||
)
|
||||
agent1.initiate_chat(group_chat_manager, message="hello")
|
||||
|
||||
assert len(agent1.chat_messages[group_chat_manager]) == 2
|
||||
assert len(groupchat.messages) == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_broadcast()
|
||||
test_chat_manager()
|
||||
# test_chat_manager()
|
||||
test_plugin()
|
||||
|
|
|
@ -2,6 +2,44 @@ import pytest
|
|||
from flaml.autogen.agentchat import ResponsiveAgent
|
||||
|
||||
|
||||
def test_trigger():
|
||||
agent = ResponsiveAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
|
||||
agent1 = ResponsiveAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")
|
||||
agent.register_auto_reply(agent1, lambda recipient, messages, sender, context: (True, "hello"))
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello"
|
||||
agent.register_auto_reply("a1", lambda recipient, messages, sender, context: (True, "hello a1"))
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello a1"
|
||||
agent.register_auto_reply(
|
||||
ResponsiveAgent, lambda recipient, messages, sender, context: (True, "hello responsive agent")
|
||||
)
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello responsive agent"
|
||||
agent.register_auto_reply(
|
||||
lambda sender: sender.name.startswith("a"), lambda recipient, messages, sender, context: (True, "hello a")
|
||||
)
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello a"
|
||||
agent.register_auto_reply(
|
||||
lambda sender: sender.name.startswith("b"), lambda recipient, messages, sender, context: (True, "hello b")
|
||||
)
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello a"
|
||||
agent.register_auto_reply(
|
||||
["agent2", agent1], lambda recipient, messages, sender, context: (True, "hello agent2 or agent1")
|
||||
)
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"
|
||||
agent.register_auto_reply(
|
||||
["agent2", "agent3"], lambda recipient, messages, sender, context: (True, "hello agent2 or agent3")
|
||||
)
|
||||
agent1.initiate_chat(agent, message="hi")
|
||||
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"
|
||||
pytest.raises(ValueError, agent.register_auto_reply, 1, lambda recipient, messages, sender, context: (True, "hi"))
|
||||
pytest.raises(ValueError, agent._match_trigger, 1, agent1)
|
||||
|
||||
|
||||
def test_context():
|
||||
agent = ResponsiveAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
|
||||
agent1 = ResponsiveAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")
|
||||
|
@ -117,6 +155,7 @@ def test_responsive_agent():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_context()
|
||||
test_trigger()
|
||||
# test_context()
|
||||
# test_max_consecutive_auto_reply()
|
||||
# test_responsive_agent(pytest.monkeypatch)
|
||||
|
|
Loading…
Reference in New Issue