From aac6f0511747f015261298e4dabbbe257a61979c Mon Sep 17 00:00:00 2001 From: Aamir <48929123+heyitsaamir@users.noreply.github.com> Date: Thu, 8 Aug 2024 18:14:33 -0700 Subject: [PATCH] Support async nested chats (#3309) * Allow async nested chats in agent chat * Fix pre-comit * Minor fix * Fix * Address feedback * Update * Fix build error --------- Co-authored-by: Qingyun Wu --- autogen/agentchat/conversable_agent.py | 77 +++++++-- test/agentchat/test_nested.py | 230 +++++++++++++++++++++++++ 2 files changed, 297 insertions(+), 10 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index a088c49108..9254ef57de 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -377,9 +377,9 @@ class ConversableAgent(LLMAgent): f["reply_func"] = new_reply_func @staticmethod - def _summary_from_nested_chats( + def _get_chats_to_run( chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any - ) -> Tuple[bool, str]: + ) -> List[Dict[str, Any]]: """A simple chat reply function. This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue. @@ -406,22 +406,59 @@ class ConversableAgent(LLMAgent): if message: current_c["message"] = message chat_to_run.append(current_c) + return chat_to_run + + @staticmethod + def _summary_from_nested_chats( + chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any + ) -> Tuple[bool, Union[str, None]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) if not chat_to_run: return True, None res = initiate_chats(chat_to_run) return True, res[-1].summary + @staticmethod + async def _a_summary_from_nested_chats( + chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any + ) -> Tuple[bool, Union[str, None]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) + if not chat_to_run: + return True, None + res = await a_initiate_chats(chat_to_run) + index_of_last_chat = chat_to_run[-1]["chat_id"] + return True, res[index_of_last_chat].summary + def register_nested_chats( self, chat_queue: List[Dict[str, Any]], trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats", position: int = 2, + use_async: Union[bool, None] = None, **kwargs, ) -> None: """Register a nested chat reply function. Args: - chat_queue (list): a list of chat objects to be initiated. + chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them. trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details. reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. @@ -436,15 +473,33 @@ class ConversableAgent(LLMAgent): ) -> Tuple[bool, Union[str, Dict, None]]: ``` position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. + use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync. kwargs: Ref to `register_reply` for details. """ - if reply_func_from_nested_chats == "summary_from_nested_chats": - reply_func_from_nested_chats = self._summary_from_nested_chats - if not callable(reply_func_from_nested_chats): - raise ValueError("reply_func_from_nested_chats must be a callable") + if use_async: + for chat in chat_queue: + if chat.get("chat_id") is None: + raise ValueError("chat_id is required for async nested chats") - def wrapped_reply_func(recipient, messages=None, sender=None, config=None): - return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + if use_async: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._a_summary_from_nested_chats + if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction( + reply_func_from_nested_chats + ): + raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine") + + async def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + + else: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._summary_from_nested_chats + if not callable(reply_func_from_nested_chats): + raise ValueError("reply_func_from_nested_chats must be a callable") + + def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats) @@ -454,7 +509,9 @@ class ConversableAgent(LLMAgent): position, kwargs.get("config"), kwargs.get("reset_config"), - ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"), + ignore_async_in_sync_chat=( + not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat") + ), ) @property diff --git a/test/agentchat/test_nested.py b/test/agentchat/test_nested.py index ee8da793fd..04fc84b5b3 100755 --- a/test/agentchat/test_nested.py +++ b/test/agentchat/test_nested.py @@ -2,10 +2,12 @@ import os import sys +from typing import List import pytest import autogen +from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability sys.path.append(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) @@ -13,6 +15,23 @@ from conftest import reason, skip_openai # noqa: E402 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 +class MockAgentReplies(AgentCapability): + def __init__(self, mock_messages: List[str]): + self.mock_messages = mock_messages + self.mock_message_index = 0 + + def add_to_agent(self, agent: autogen.ConversableAgent): + def mock_reply(recipient, messages, sender, config): + if self.mock_message_index < len(self.mock_messages): + reply_msg = self.mock_messages[self.mock_message_index] + self.mock_message_index += 1 + return [True, reply_msg] + else: + raise ValueError(f"No more mock messages available for {sender.name} to reply to {recipient.name}") + + agent.register_reply([autogen.Agent, None], mock_reply, position=2) + + @pytest.mark.skipif(skip_openai, reason=reason) def test_nested(): config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC) @@ -142,5 +161,216 @@ def test_nested(): ) +def test_sync_nested_chat(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant", + ) + user = autogen.UserProxyAgent( + "User", + human_input_mode="NEVER", + is_termination_msg=is_termination, + ) + assistant.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], trigger=user + ) + chat_result = user.initiate_chat(assistant, message="Start chat") + assert len(chat_result.chat_history) == 2 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "FINAL_RESULT"] + + +@pytest.mark.asyncio +async def test_async_nested_chat(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant", + ) + user = autogen.UserProxyAgent( + "User", + human_input_mode="NEVER", + is_termination_msg=is_termination, + ) + assistant.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}], + trigger=user, + use_async=True, + ) + chat_result = await user.a_initiate_chat(assistant, message="Start chat") + assert len(chat_result.chat_history) == 2 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "FINAL_RESULT"] + + +@pytest.mark.asyncio +async def test_async_nested_chat_chat_id_validation(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant", + ) + user = autogen.UserProxyAgent( + "User", + human_input_mode="NEVER", + is_termination_msg=is_termination, + ) + with pytest.raises(ValueError, match="chat_id is required for async nested chats"): + assistant.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], + trigger=user, + use_async=True, + ) + + +def test_sync_nested_chat_in_group(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant_In_Group_1", + ) + MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant) + assistant2 = autogen.AssistantAgent( + "Assistant_In_Group_2", + ) + user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination) + group = autogen.GroupChat( + agents=[assistant, assistant2, user], + messages=[], + speaker_selection_method="round_robin", + ) + group_manager = autogen.GroupChatManager(groupchat=group) + assistant2.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], + trigger=group_manager, + ) + + chat_result = user.initiate_chat(group_manager, message="Start chat", summary_method="last_msg") + assert len(chat_result.chat_history) == 3 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"] + + +@pytest.mark.asyncio +async def test_async_nested_chat_in_group(): + def is_termination(msg): + if isinstance(msg, str) and msg == "FINAL_RESULT": + return True + elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT": + return True + return False + + inner_assistant = autogen.AssistantAgent( + "Inner-assistant", + is_termination_msg=is_termination, + ) + MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant) + + inner_assistant_2 = autogen.AssistantAgent( + "Inner-assistant-2", + ) + MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent( + inner_assistant_2 + ) + + assistant = autogen.AssistantAgent( + "Assistant_In_Group_1", + ) + MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant) + assistant2 = autogen.AssistantAgent( + "Assistant_In_Group_2", + ) + user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination) + group = autogen.GroupChat( + agents=[assistant, assistant2, user], + messages=[], + speaker_selection_method="round_robin", + ) + group_manager = autogen.GroupChatManager(groupchat=group) + assistant2.register_nested_chats( + [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}], + trigger=group_manager, + use_async=True, + ) + + chat_result = await user.a_initiate_chat(group_manager, message="Start chat", summary_method="last_msg") + assert len(chat_result.chat_history) == 3 + chat_messages = [msg["content"] for msg in chat_result.chat_history] + assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"] + + if __name__ == "__main__": test_nested()