mirror of https://github.com/microsoft/autogen.git
Support async nested chats (#3309)
* Allow async nested chats in agent chat
* Fix pre-commit
* Minor fix
* Fix
* Address feedback
* Update
* Fix build error

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
parent 4dab28c769
commit aac6f05117
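A minimal usage sketch of the feature introduced by this commit, based on the diff below: the nested chat is registered with use_async=True, every entry in the chat queue carries a chat_id, and the outer chat is started with a_initiate_chat. The agent names and llm_config here are illustrative placeholders, not part of the change.

```python
import asyncio

import autogen

# Illustrative config; replace with a real config_list / llm_config.
llm_config = {"config_list": [{"model": "gpt-4o-mini", "api_key": "sk-..."}]}

writer = autogen.AssistantAgent("Writer", llm_config=llm_config)
reviewer = autogen.AssistantAgent("Reviewer", llm_config=llm_config)
assistant = autogen.AssistantAgent("Assistant", llm_config=llm_config)
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", code_execution_config=False)

# With use_async=True each nested chat must carry a chat_id; the nested chats
# are then run via a_initiate_chats instead of the synchronous initiate_chats.
assistant.register_nested_chats(
    [{"sender": writer, "recipient": reviewer, "summary_method": "last_msg", "chat_id": 1}],
    trigger=user,
    use_async=True,
)


async def main():
    result = await user.a_initiate_chat(assistant, message="Draft a short announcement.")
    print(result.summary)


asyncio.run(main())
```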
@@ -377,9 +377,9 @@ class ConversableAgent(LLMAgent):
                 f["reply_func"] = new_reply_func

     @staticmethod
-    def _summary_from_nested_chats(
+    def _get_chats_to_run(
         chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
-    ) -> Tuple[bool, str]:
+    ) -> List[Dict[str, Any]]:
         """A simple chat reply function.
         This function initiate one or a sequence of chats between the "recipient" and the agents in the
         chat_queue.
@@ -406,22 +406,59 @@ class ConversableAgent(LLMAgent):
             if message:
                 current_c["message"] = message
                 chat_to_run.append(current_c)
+        return chat_to_run
+
+    @staticmethod
+    def _summary_from_nested_chats(
+        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> Tuple[bool, Union[str, None]]:
+        """A simple chat reply function.
+        This function initiate one or a sequence of chats between the "recipient" and the agents in the
+        chat_queue.
+
+        It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+        Returns:
+            Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+        """
+        chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
         if not chat_to_run:
             return True, None
         res = initiate_chats(chat_to_run)
         return True, res[-1].summary
+
+    @staticmethod
+    async def _a_summary_from_nested_chats(
+        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> Tuple[bool, Union[str, None]]:
+        """A simple chat reply function.
+        This function initiate one or a sequence of chats between the "recipient" and the agents in the
+        chat_queue.
+
+        It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+        Returns:
+            Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+        """
+        chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
+        if not chat_to_run:
+            return True, None
+        res = await a_initiate_chats(chat_to_run)
+        index_of_last_chat = chat_to_run[-1]["chat_id"]
+        return True, res[index_of_last_chat].summary

     def register_nested_chats(
         self,
         chat_queue: List[Dict[str, Any]],
         trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
         reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
         position: int = 2,
+        use_async: Union[bool, None] = None,
         **kwargs,
     ) -> None:
         """Register a nested chat reply function.
         Args:
-            chat_queue (list): a list of chat objects to be initiated.
+            chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them.
             trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
             reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
                 The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
@@ -436,15 +473,33 @@ class ConversableAgent(LLMAgent):
                 ) -> Tuple[bool, Union[str, Dict, None]]:
                 ```
             position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
+            use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync.
             kwargs: Ref to `register_reply` for details.
         """
-        if reply_func_from_nested_chats == "summary_from_nested_chats":
-            reply_func_from_nested_chats = self._summary_from_nested_chats
-        if not callable(reply_func_from_nested_chats):
-            raise ValueError("reply_func_from_nested_chats must be a callable")
+        if use_async:
+            for chat in chat_queue:
+                if chat.get("chat_id") is None:
+                    raise ValueError("chat_id is required for async nested chats")

-        def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
-            return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+        if use_async:
+            if reply_func_from_nested_chats == "summary_from_nested_chats":
+                reply_func_from_nested_chats = self._a_summary_from_nested_chats
+            if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
+                reply_func_from_nested_chats
+            ):
+                raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")
+
+            async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+                return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+        else:
+            if reply_func_from_nested_chats == "summary_from_nested_chats":
+                reply_func_from_nested_chats = self._summary_from_nested_chats
+            if not callable(reply_func_from_nested_chats):
+                raise ValueError("reply_func_from_nested_chats must be a callable")
+
+            def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+                return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)

         functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
@@ -454,7 +509,9 @@ class ConversableAgent(LLMAgent):
             position,
             kwargs.get("config"),
             kwargs.get("reset_config"),
-            ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
+            ignore_async_in_sync_chat=(
+                not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
+            ),
         )

     @property
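When use_async=True, a custom reply_func_from_nested_chats must be a coroutine (the code above checks inspect.iscoroutinefunction) taking (chat_queue, recipient, messages, sender, config) and returning a (done, reply) tuple. A sketch of such a function, assuming it is acceptable to delegate to the _a_summary_from_nested_chats helper added above; the function name and the summary prefix are purely illustrative.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

import autogen


async def my_async_nested_reply(
    chat_queue: List[Dict[str, Any]],
    recipient: autogen.ConversableAgent,
    messages: Optional[List[Dict]] = None,
    sender: Optional[autogen.Agent] = None,
    config: Optional[Any] = None,
) -> Tuple[bool, Union[str, None]]:
    # Delegate to the built-in async helper introduced in this commit,
    # then post-process the summary (the prefix is just an example).
    done, summary = await autogen.ConversableAgent._a_summary_from_nested_chats(
        chat_queue, recipient, messages, sender, config
    )
    return done, f"[nested] {summary}" if summary is not None else summary


# Registered the same way as the default reply function, e.g.:
# assistant.register_nested_chats(
#     chat_queue,
#     trigger=user,
#     reply_func_from_nested_chats=my_async_nested_reply,
#     use_async=True,
# )
```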
@@ -2,10 +2,12 @@

 import os
 import sys
+from typing import List

 import pytest

 import autogen
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability

 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
@@ -13,6 +15,23 @@ from conftest import reason, skip_openai  # noqa: E402
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST  # noqa: E402


+class MockAgentReplies(AgentCapability):
+    def __init__(self, mock_messages: List[str]):
+        self.mock_messages = mock_messages
+        self.mock_message_index = 0
+
+    def add_to_agent(self, agent: autogen.ConversableAgent):
+        def mock_reply(recipient, messages, sender, config):
+            if self.mock_message_index < len(self.mock_messages):
+                reply_msg = self.mock_messages[self.mock_message_index]
+                self.mock_message_index += 1
+                return [True, reply_msg]
+            else:
+                raise ValueError(f"No more mock messages available for {sender.name} to reply to {recipient.name}")
+
+        agent.register_reply([autogen.Agent, None], mock_reply, position=2)
+
+
 @pytest.mark.skipif(skip_openai, reason=reason)
 def test_nested():
     config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC)
@@ -142,5 +161,216 @@ def test_nested():
     )


+def test_sync_nested_chat():
+    def is_termination(msg):
+        if isinstance(msg, str) and msg == "FINAL_RESULT":
+            return True
+        elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
+            return True
+        return False
+
+    inner_assistant = autogen.AssistantAgent(
+        "Inner-assistant",
+        is_termination_msg=is_termination,
+    )
+    MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
+
+    inner_assistant_2 = autogen.AssistantAgent(
+        "Inner-assistant-2",
+    )
+    MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
+        inner_assistant_2
+    )
+
+    assistant = autogen.AssistantAgent(
+        "Assistant",
+    )
+    user = autogen.UserProxyAgent(
+        "User",
+        human_input_mode="NEVER",
+        is_termination_msg=is_termination,
+    )
+    assistant.register_nested_chats(
+        [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], trigger=user
+    )
+    chat_result = user.initiate_chat(assistant, message="Start chat")
+    assert len(chat_result.chat_history) == 2
+    chat_messages = [msg["content"] for msg in chat_result.chat_history]
+    assert chat_messages == ["Start chat", "FINAL_RESULT"]
+
+
+@pytest.mark.asyncio
+async def test_async_nested_chat():
+    def is_termination(msg):
+        if isinstance(msg, str) and msg == "FINAL_RESULT":
+            return True
+        elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
+            return True
+        return False
+
+    inner_assistant = autogen.AssistantAgent(
+        "Inner-assistant",
+        is_termination_msg=is_termination,
+    )
+    MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
+
+    inner_assistant_2 = autogen.AssistantAgent(
+        "Inner-assistant-2",
+    )
+    MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
+        inner_assistant_2
+    )
+
+    assistant = autogen.AssistantAgent(
+        "Assistant",
+    )
+    user = autogen.UserProxyAgent(
+        "User",
+        human_input_mode="NEVER",
+        is_termination_msg=is_termination,
+    )
+    assistant.register_nested_chats(
+        [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
+        trigger=user,
+        use_async=True,
+    )
+    chat_result = await user.a_initiate_chat(assistant, message="Start chat")
+    assert len(chat_result.chat_history) == 2
+    chat_messages = [msg["content"] for msg in chat_result.chat_history]
+    assert chat_messages == ["Start chat", "FINAL_RESULT"]
+
+
+@pytest.mark.asyncio
+async def test_async_nested_chat_chat_id_validation():
+    def is_termination(msg):
+        if isinstance(msg, str) and msg == "FINAL_RESULT":
+            return True
+        elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
+            return True
+        return False
+
+    inner_assistant = autogen.AssistantAgent(
+        "Inner-assistant",
+        is_termination_msg=is_termination,
+    )
+    MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
+
+    inner_assistant_2 = autogen.AssistantAgent(
+        "Inner-assistant-2",
+    )
+    MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
+        inner_assistant_2
+    )
+
+    assistant = autogen.AssistantAgent(
+        "Assistant",
+    )
+    user = autogen.UserProxyAgent(
+        "User",
+        human_input_mode="NEVER",
+        is_termination_msg=is_termination,
+    )
+    with pytest.raises(ValueError, match="chat_id is required for async nested chats"):
+        assistant.register_nested_chats(
+            [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
+            trigger=user,
+            use_async=True,
+        )
+
+
+def test_sync_nested_chat_in_group():
+    def is_termination(msg):
+        if isinstance(msg, str) and msg == "FINAL_RESULT":
+            return True
+        elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
+            return True
+        return False
+
+    inner_assistant = autogen.AssistantAgent(
+        "Inner-assistant",
+        is_termination_msg=is_termination,
+    )
+    MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
+
+    inner_assistant_2 = autogen.AssistantAgent(
+        "Inner-assistant-2",
+    )
+    MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
+        inner_assistant_2
+    )
+
+    assistant = autogen.AssistantAgent(
+        "Assistant_In_Group_1",
+    )
+    MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
+    assistant2 = autogen.AssistantAgent(
+        "Assistant_In_Group_2",
+    )
+    user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
+    group = autogen.GroupChat(
+        agents=[assistant, assistant2, user],
+        messages=[],
+        speaker_selection_method="round_robin",
+    )
+    group_manager = autogen.GroupChatManager(groupchat=group)
+    assistant2.register_nested_chats(
+        [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
+        trigger=group_manager,
+    )
+
+    chat_result = user.initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
+    assert len(chat_result.chat_history) == 3
+    chat_messages = [msg["content"] for msg in chat_result.chat_history]
+    assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]
+
+
+@pytest.mark.asyncio
+async def test_async_nested_chat_in_group():
+    def is_termination(msg):
+        if isinstance(msg, str) and msg == "FINAL_RESULT":
+            return True
+        elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
+            return True
+        return False
+
+    inner_assistant = autogen.AssistantAgent(
+        "Inner-assistant",
+        is_termination_msg=is_termination,
+    )
+    MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
+
+    inner_assistant_2 = autogen.AssistantAgent(
+        "Inner-assistant-2",
+    )
+    MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
+        inner_assistant_2
+    )
+
+    assistant = autogen.AssistantAgent(
+        "Assistant_In_Group_1",
+    )
+    MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
+    assistant2 = autogen.AssistantAgent(
+        "Assistant_In_Group_2",
+    )
+    user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
+    group = autogen.GroupChat(
+        agents=[assistant, assistant2, user],
+        messages=[],
+        speaker_selection_method="round_robin",
+    )
+    group_manager = autogen.GroupChatManager(groupchat=group)
+    assistant2.register_nested_chats(
+        [{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
+        trigger=group_manager,
+        use_async=True,
+    )
+
+    chat_result = await user.a_initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
+    assert len(chat_result.chat_history) == 3
+    chat_messages = [msg["content"] for msg in chat_result.chat_history]
+    assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]
+
+
 if __name__ == "__main__":
     test_nested()