Support async nested chats (#3309)

* Allow async nested chats in agent chat

* Fix pre-comit

* Minor fix

* Fix

* Address feedback

* Update

* Fix build error

---------

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
Aamir 2024-08-08 18:14:33 -07:00 committed by GitHub
parent 4dab28c769
commit aac6f05117
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 297 additions and 10 deletions

View File

@ -377,9 +377,9 @@ class ConversableAgent(LLMAgent):
f["reply_func"] = new_reply_func f["reply_func"] = new_reply_func
@staticmethod @staticmethod
def _summary_from_nested_chats( def _get_chats_to_run(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, str]: ) -> List[Dict[str, Any]]:
"""A simple chat reply function. """A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue. chat_queue.
@ -406,22 +406,59 @@ class ConversableAgent(LLMAgent):
if message: if message:
current_c["message"] = message current_c["message"] = message
chat_to_run.append(current_c) chat_to_run.append(current_c)
return chat_to_run
@staticmethod
def _summary_from_nested_chats(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, Union[str, None]]:
"""A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
Returns:
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""
chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run: if not chat_to_run:
return True, None return True, None
res = initiate_chats(chat_to_run) res = initiate_chats(chat_to_run)
return True, res[-1].summary return True, res[-1].summary
@staticmethod
async def _a_summary_from_nested_chats(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, Union[str, None]]:
"""A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
Returns:
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""
chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = await a_initiate_chats(chat_to_run)
index_of_last_chat = chat_to_run[-1]["chat_id"]
return True, res[index_of_last_chat].summary
def register_nested_chats( def register_nested_chats(
self, self,
chat_queue: List[Dict[str, Any]], chat_queue: List[Dict[str, Any]],
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats", reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
position: int = 2, position: int = 2,
use_async: Union[bool, None] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Register a nested chat reply function. """Register a nested chat reply function.
Args: Args:
chat_queue (list): a list of chat objects to be initiated. chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them.
trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details. trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
@ -436,15 +473,33 @@ class ConversableAgent(LLMAgent):
) -> Tuple[bool, Union[str, Dict, None]]: ) -> Tuple[bool, Union[str, Dict, None]]:
``` ```
position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync.
kwargs: Ref to `register_reply` for details. kwargs: Ref to `register_reply` for details.
""" """
if reply_func_from_nested_chats == "summary_from_nested_chats": if use_async:
reply_func_from_nested_chats = self._summary_from_nested_chats for chat in chat_queue:
if not callable(reply_func_from_nested_chats): if chat.get("chat_id") is None:
raise ValueError("reply_func_from_nested_chats must be a callable") raise ValueError("chat_id is required for async nested chats")
def wrapped_reply_func(recipient, messages=None, sender=None, config=None): if use_async:
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) if reply_func_from_nested_chats == "summary_from_nested_chats":
reply_func_from_nested_chats = self._a_summary_from_nested_chats
if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
reply_func_from_nested_chats
):
raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")
async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
else:
if reply_func_from_nested_chats == "summary_from_nested_chats":
reply_func_from_nested_chats = self._summary_from_nested_chats
if not callable(reply_func_from_nested_chats):
raise ValueError("reply_func_from_nested_chats must be a callable")
def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats) functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
@ -454,7 +509,9 @@ class ConversableAgent(LLMAgent):
position, position,
kwargs.get("config"), kwargs.get("config"),
kwargs.get("reset_config"), kwargs.get("reset_config"),
ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"), ignore_async_in_sync_chat=(
not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
),
) )
@property @property

View File

@ -2,10 +2,12 @@
import os import os
import sys import sys
from typing import List
import pytest import pytest
import autogen import autogen
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
@ -13,6 +15,23 @@ from conftest import reason, skip_openai # noqa: E402
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
class MockAgentReplies(AgentCapability):
def __init__(self, mock_messages: List[str]):
self.mock_messages = mock_messages
self.mock_message_index = 0
def add_to_agent(self, agent: autogen.ConversableAgent):
def mock_reply(recipient, messages, sender, config):
if self.mock_message_index < len(self.mock_messages):
reply_msg = self.mock_messages[self.mock_message_index]
self.mock_message_index += 1
return [True, reply_msg]
else:
raise ValueError(f"No more mock messages available for {sender.name} to reply to {recipient.name}")
agent.register_reply([autogen.Agent, None], mock_reply, position=2)
@pytest.mark.skipif(skip_openai, reason=reason) @pytest.mark.skipif(skip_openai, reason=reason)
def test_nested(): def test_nested():
config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC) config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC)
@ -142,5 +161,216 @@ def test_nested():
) )
def test_sync_nested_chat():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False
inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)
assistant = autogen.AssistantAgent(
"Assistant",
)
user = autogen.UserProxyAgent(
"User",
human_input_mode="NEVER",
is_termination_msg=is_termination,
)
assistant.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], trigger=user
)
chat_result = user.initiate_chat(assistant, message="Start chat")
assert len(chat_result.chat_history) == 2
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "FINAL_RESULT"]
@pytest.mark.asyncio
async def test_async_nested_chat():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False
inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)
assistant = autogen.AssistantAgent(
"Assistant",
)
user = autogen.UserProxyAgent(
"User",
human_input_mode="NEVER",
is_termination_msg=is_termination,
)
assistant.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
trigger=user,
use_async=True,
)
chat_result = await user.a_initiate_chat(assistant, message="Start chat")
assert len(chat_result.chat_history) == 2
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "FINAL_RESULT"]
@pytest.mark.asyncio
async def test_async_nested_chat_chat_id_validation():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False
inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)
assistant = autogen.AssistantAgent(
"Assistant",
)
user = autogen.UserProxyAgent(
"User",
human_input_mode="NEVER",
is_termination_msg=is_termination,
)
with pytest.raises(ValueError, match="chat_id is required for async nested chats"):
assistant.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
trigger=user,
use_async=True,
)
def test_sync_nested_chat_in_group():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False
inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)
assistant = autogen.AssistantAgent(
"Assistant_In_Group_1",
)
MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
assistant2 = autogen.AssistantAgent(
"Assistant_In_Group_2",
)
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
group = autogen.GroupChat(
agents=[assistant, assistant2, user],
messages=[],
speaker_selection_method="round_robin",
)
group_manager = autogen.GroupChatManager(groupchat=group)
assistant2.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
trigger=group_manager,
)
chat_result = user.initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
assert len(chat_result.chat_history) == 3
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]
@pytest.mark.asyncio
async def test_async_nested_chat_in_group():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False
inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)
assistant = autogen.AssistantAgent(
"Assistant_In_Group_1",
)
MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
assistant2 = autogen.AssistantAgent(
"Assistant_In_Group_2",
)
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
group = autogen.GroupChat(
agents=[assistant, assistant2, user],
messages=[],
speaker_selection_method="round_robin",
)
group_manager = autogen.GroupChatManager(groupchat=group)
assistant2.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
trigger=group_manager,
use_async=True,
)
chat_result = await user.a_initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
assert len(chat_result.chat_history) == 3
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]
if __name__ == "__main__": if __name__ == "__main__":
test_nested() test_nested()