mirror of https://github.com/microsoft/autogen.git
Support async nested chats (#3309)
* Allow async nested chats in agent chat * Fix pre-comit * Minor fix * Fix * Address feedback * Update * Fix build error --------- Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
parent
4dab28c769
commit
aac6f05117
|
@ -377,9 +377,9 @@ class ConversableAgent(LLMAgent):
|
|||
f["reply_func"] = new_reply_func
|
||||
|
||||
@staticmethod
|
||||
def _summary_from_nested_chats(
|
||||
def _get_chats_to_run(
|
||||
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
|
||||
) -> Tuple[bool, str]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""A simple chat reply function.
|
||||
This function initiate one or a sequence of chats between the "recipient" and the agents in the
|
||||
chat_queue.
|
||||
|
@ -406,22 +406,59 @@ class ConversableAgent(LLMAgent):
|
|||
if message:
|
||||
current_c["message"] = message
|
||||
chat_to_run.append(current_c)
|
||||
return chat_to_run
|
||||
|
||||
@staticmethod
|
||||
def _summary_from_nested_chats(
|
||||
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
|
||||
) -> Tuple[bool, Union[str, None]]:
|
||||
"""A simple chat reply function.
|
||||
This function initiate one or a sequence of chats between the "recipient" and the agents in the
|
||||
chat_queue.
|
||||
|
||||
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
|
||||
"""
|
||||
chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
|
||||
if not chat_to_run:
|
||||
return True, None
|
||||
res = initiate_chats(chat_to_run)
|
||||
return True, res[-1].summary
|
||||
|
||||
@staticmethod
|
||||
async def _a_summary_from_nested_chats(
|
||||
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
|
||||
) -> Tuple[bool, Union[str, None]]:
|
||||
"""A simple chat reply function.
|
||||
This function initiate one or a sequence of chats between the "recipient" and the agents in the
|
||||
chat_queue.
|
||||
|
||||
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
|
||||
"""
|
||||
chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
|
||||
if not chat_to_run:
|
||||
return True, None
|
||||
res = await a_initiate_chats(chat_to_run)
|
||||
index_of_last_chat = chat_to_run[-1]["chat_id"]
|
||||
return True, res[index_of_last_chat].summary
|
||||
|
||||
def register_nested_chats(
|
||||
self,
|
||||
chat_queue: List[Dict[str, Any]],
|
||||
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
|
||||
reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
|
||||
position: int = 2,
|
||||
use_async: Union[bool, None] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Register a nested chat reply function.
|
||||
Args:
|
||||
chat_queue (list): a list of chat objects to be initiated.
|
||||
chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them.
|
||||
trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
|
||||
reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
|
||||
The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
|
||||
|
@ -436,15 +473,33 @@ class ConversableAgent(LLMAgent):
|
|||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
```
|
||||
position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
|
||||
use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync.
|
||||
kwargs: Ref to `register_reply` for details.
|
||||
"""
|
||||
if reply_func_from_nested_chats == "summary_from_nested_chats":
|
||||
reply_func_from_nested_chats = self._summary_from_nested_chats
|
||||
if not callable(reply_func_from_nested_chats):
|
||||
raise ValueError("reply_func_from_nested_chats must be a callable")
|
||||
if use_async:
|
||||
for chat in chat_queue:
|
||||
if chat.get("chat_id") is None:
|
||||
raise ValueError("chat_id is required for async nested chats")
|
||||
|
||||
def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
|
||||
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
|
||||
if use_async:
|
||||
if reply_func_from_nested_chats == "summary_from_nested_chats":
|
||||
reply_func_from_nested_chats = self._a_summary_from_nested_chats
|
||||
if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
|
||||
reply_func_from_nested_chats
|
||||
):
|
||||
raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")
|
||||
|
||||
async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
|
||||
return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
|
||||
|
||||
else:
|
||||
if reply_func_from_nested_chats == "summary_from_nested_chats":
|
||||
reply_func_from_nested_chats = self._summary_from_nested_chats
|
||||
if not callable(reply_func_from_nested_chats):
|
||||
raise ValueError("reply_func_from_nested_chats must be a callable")
|
||||
|
||||
def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
|
||||
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
|
||||
|
||||
functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
|
||||
|
||||
|
@ -454,7 +509,9 @@ class ConversableAgent(LLMAgent):
|
|||
position,
|
||||
kwargs.get("config"),
|
||||
kwargs.get("reset_config"),
|
||||
ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
|
||||
ignore_async_in_sync_chat=(
|
||||
not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
import autogen
|
||||
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
@ -13,6 +15,23 @@ from conftest import reason, skip_openai # noqa: E402
|
|||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
||||
|
||||
|
||||
class MockAgentReplies(AgentCapability):
|
||||
def __init__(self, mock_messages: List[str]):
|
||||
self.mock_messages = mock_messages
|
||||
self.mock_message_index = 0
|
||||
|
||||
def add_to_agent(self, agent: autogen.ConversableAgent):
|
||||
def mock_reply(recipient, messages, sender, config):
|
||||
if self.mock_message_index < len(self.mock_messages):
|
||||
reply_msg = self.mock_messages[self.mock_message_index]
|
||||
self.mock_message_index += 1
|
||||
return [True, reply_msg]
|
||||
else:
|
||||
raise ValueError(f"No more mock messages available for {sender.name} to reply to {recipient.name}")
|
||||
|
||||
agent.register_reply([autogen.Agent, None], mock_reply, position=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai, reason=reason)
|
||||
def test_nested():
|
||||
config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC)
|
||||
|
@ -142,5 +161,216 @@ def test_nested():
|
|||
)
|
||||
|
||||
|
||||
def test_sync_nested_chat():
|
||||
def is_termination(msg):
|
||||
if isinstance(msg, str) and msg == "FINAL_RESULT":
|
||||
return True
|
||||
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
|
||||
return True
|
||||
return False
|
||||
|
||||
inner_assistant = autogen.AssistantAgent(
|
||||
"Inner-assistant",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
|
||||
|
||||
inner_assistant_2 = autogen.AssistantAgent(
|
||||
"Inner-assistant-2",
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
|
||||
inner_assistant_2
|
||||
)
|
||||
|
||||
assistant = autogen.AssistantAgent(
|
||||
"Assistant",
|
||||
)
|
||||
user = autogen.UserProxyAgent(
|
||||
"User",
|
||||
human_input_mode="NEVER",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
assistant.register_nested_chats(
|
||||
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], trigger=user
|
||||
)
|
||||
chat_result = user.initiate_chat(assistant, message="Start chat")
|
||||
assert len(chat_result.chat_history) == 2
|
||||
chat_messages = [msg["content"] for msg in chat_result.chat_history]
|
||||
assert chat_messages == ["Start chat", "FINAL_RESULT"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_nested_chat():
|
||||
def is_termination(msg):
|
||||
if isinstance(msg, str) and msg == "FINAL_RESULT":
|
||||
return True
|
||||
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
|
||||
return True
|
||||
return False
|
||||
|
||||
inner_assistant = autogen.AssistantAgent(
|
||||
"Inner-assistant",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
|
||||
|
||||
inner_assistant_2 = autogen.AssistantAgent(
|
||||
"Inner-assistant-2",
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
|
||||
inner_assistant_2
|
||||
)
|
||||
|
||||
assistant = autogen.AssistantAgent(
|
||||
"Assistant",
|
||||
)
|
||||
user = autogen.UserProxyAgent(
|
||||
"User",
|
||||
human_input_mode="NEVER",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
assistant.register_nested_chats(
|
||||
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
|
||||
trigger=user,
|
||||
use_async=True,
|
||||
)
|
||||
chat_result = await user.a_initiate_chat(assistant, message="Start chat")
|
||||
assert len(chat_result.chat_history) == 2
|
||||
chat_messages = [msg["content"] for msg in chat_result.chat_history]
|
||||
assert chat_messages == ["Start chat", "FINAL_RESULT"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_nested_chat_chat_id_validation():
|
||||
def is_termination(msg):
|
||||
if isinstance(msg, str) and msg == "FINAL_RESULT":
|
||||
return True
|
||||
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
|
||||
return True
|
||||
return False
|
||||
|
||||
inner_assistant = autogen.AssistantAgent(
|
||||
"Inner-assistant",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
|
||||
|
||||
inner_assistant_2 = autogen.AssistantAgent(
|
||||
"Inner-assistant-2",
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
|
||||
inner_assistant_2
|
||||
)
|
||||
|
||||
assistant = autogen.AssistantAgent(
|
||||
"Assistant",
|
||||
)
|
||||
user = autogen.UserProxyAgent(
|
||||
"User",
|
||||
human_input_mode="NEVER",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
with pytest.raises(ValueError, match="chat_id is required for async nested chats"):
|
||||
assistant.register_nested_chats(
|
||||
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
|
||||
trigger=user,
|
||||
use_async=True,
|
||||
)
|
||||
|
||||
|
||||
def test_sync_nested_chat_in_group():
|
||||
def is_termination(msg):
|
||||
if isinstance(msg, str) and msg == "FINAL_RESULT":
|
||||
return True
|
||||
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
|
||||
return True
|
||||
return False
|
||||
|
||||
inner_assistant = autogen.AssistantAgent(
|
||||
"Inner-assistant",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
|
||||
|
||||
inner_assistant_2 = autogen.AssistantAgent(
|
||||
"Inner-assistant-2",
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
|
||||
inner_assistant_2
|
||||
)
|
||||
|
||||
assistant = autogen.AssistantAgent(
|
||||
"Assistant_In_Group_1",
|
||||
)
|
||||
MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
|
||||
assistant2 = autogen.AssistantAgent(
|
||||
"Assistant_In_Group_2",
|
||||
)
|
||||
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
|
||||
group = autogen.GroupChat(
|
||||
agents=[assistant, assistant2, user],
|
||||
messages=[],
|
||||
speaker_selection_method="round_robin",
|
||||
)
|
||||
group_manager = autogen.GroupChatManager(groupchat=group)
|
||||
assistant2.register_nested_chats(
|
||||
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
|
||||
trigger=group_manager,
|
||||
)
|
||||
|
||||
chat_result = user.initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
|
||||
assert len(chat_result.chat_history) == 3
|
||||
chat_messages = [msg["content"] for msg in chat_result.chat_history]
|
||||
assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_nested_chat_in_group():
|
||||
def is_termination(msg):
|
||||
if isinstance(msg, str) and msg == "FINAL_RESULT":
|
||||
return True
|
||||
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
|
||||
return True
|
||||
return False
|
||||
|
||||
inner_assistant = autogen.AssistantAgent(
|
||||
"Inner-assistant",
|
||||
is_termination_msg=is_termination,
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)
|
||||
|
||||
inner_assistant_2 = autogen.AssistantAgent(
|
||||
"Inner-assistant-2",
|
||||
)
|
||||
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
|
||||
inner_assistant_2
|
||||
)
|
||||
|
||||
assistant = autogen.AssistantAgent(
|
||||
"Assistant_In_Group_1",
|
||||
)
|
||||
MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
|
||||
assistant2 = autogen.AssistantAgent(
|
||||
"Assistant_In_Group_2",
|
||||
)
|
||||
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
|
||||
group = autogen.GroupChat(
|
||||
agents=[assistant, assistant2, user],
|
||||
messages=[],
|
||||
speaker_selection_method="round_robin",
|
||||
)
|
||||
group_manager = autogen.GroupChatManager(groupchat=group)
|
||||
assistant2.register_nested_chats(
|
||||
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
|
||||
trigger=group_manager,
|
||||
use_async=True,
|
||||
)
|
||||
|
||||
chat_result = await user.a_initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
|
||||
assert len(chat_result.chat_history) == 3
|
||||
chat_messages = [msg["content"] for msg in chat_result.chat_history]
|
||||
assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_nested()
|
||||
|
|
Loading…
Reference in New Issue