From 3d51ab76ae94cfae25ce7c1a3e5d4fe3eb8134bd Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 30 Oct 2024 10:27:57 -0700 Subject: [PATCH] Formalize `ChatAgent` response as a dataclass with inner messages (#3990) --- .../agents/_assistant_agent.py | 30 ++++++++++--- .../agents/_base_chat_agent.py | 25 +++++------ .../agents/_code_executor_agent.py | 7 +-- .../src/autogen_agentchat/base/__init__.py | 3 +- .../src/autogen_agentchat/base/_chat_agent.py | 18 ++++++-- .../src/autogen_agentchat/base/_task.py | 4 +- .../src/autogen_agentchat/messages.py | 25 ++++++++++- .../teams/_group_chat/_base_group_chat.py | 28 ++++++------ .../_group_chat/_chat_agent_container.py | 19 ++++++-- .../tests/test_assistant_agent.py | 13 +++--- .../tests/test_group_chat.py | 43 ++++++++++++------- .../tutorial/agents.ipynb | 7 +-- .../tutorial/selector-group-chat.ipynb | 7 +-- 13 files changed, 157 insertions(+), 72 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index a5dece3f62..5414f782f0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool from pydantic import BaseModel, ConfigDict, Field, model_validator from .. import EVENT_LOGGER_NAME +from ..base import Response from ..messages import ( ChatMessage, HandoffMessage, + InnerMessage, ResetMessage, StopMessage, TextMessage, + ToolCallMessage, + ToolCallResultMessages, ) from ._base_chat_agent import BaseChatAgent @@ -214,7 +218,7 @@ class AssistantAgent(BaseChatAgent): return [TextMessage, HandoffMessage, StopMessage] return [TextMessage, StopMessage] - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: # Add messages to the model context. for msg in messages: if isinstance(msg, ResetMessage): @@ -222,6 +226,9 @@ class AssistantAgent(BaseChatAgent): else: self._model_context.append(UserMessage(content=msg.content, source=msg.source)) + # Inner messages. + inner_messages: List[InnerMessage] = [] + # Generate an inference result based on the current model context. llm_messages = self._system_messages + self._model_context result = await self._model_client.create( @@ -234,12 +241,16 @@ class AssistantAgent(BaseChatAgent): # Run tool calls until the model produces a string response. while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) + # Add the tool call message to the output. + inner_messages.append(ToolCallMessage(content=result.content, source=self.name)) + # Execute the tool calls. results = await asyncio.gather( *[self._execute_tool_call(call, cancellation_token) for call in result.content] ) event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name)) self._model_context.append(FunctionExecutionResultMessage(content=results)) + inner_messages.append(ToolCallResultMessages(content=results, source=self.name)) # Detect handoff requests. handoffs: List[Handoff] = [] @@ -249,8 +260,13 @@ class AssistantAgent(BaseChatAgent): if len(handoffs) > 0: if len(handoffs) > 1: raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") - # Respond with a handoff message. - return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name) + # Return the output messages to signal the handoff. + return Response( + chat_message=HandoffMessage( + content=handoffs[0].message, target=handoffs[0].target, source=self.name + ), + inner_messages=inner_messages, + ) # Generate an inference result based on the current model context. result = await self._model_client.create( @@ -262,9 +278,13 @@ class AssistantAgent(BaseChatAgent): # Detect stop request. request_stop = "terminate" in result.content.strip().lower() if request_stop: - return StopMessage(content=result.content, source=self.name) + return Response( + chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages + ) - return TextMessage(content=result.content, source=self.name) + return Response( + chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages + ) async def _execute_tool_call( self, tool_call: FunctionCall, cancellation_token: CancellationToken diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index bc53528678..ac74077e27 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -3,9 +3,8 @@ from typing import List, Sequence from autogen_core.base import CancellationToken -from ..base import ChatAgent, TaskResult, TerminationCondition -from ..messages import ChatMessage -from ..teams import RoundRobinGroupChat +from ..base import ChatAgent, Response, TaskResult, TerminationCondition +from ..messages import ChatMessage, InnerMessage, TextMessage class BaseChatAgent(ChatAgent, ABC): @@ -37,8 +36,8 @@ class BaseChatAgent(ChatAgent, ABC): ... @abstractmethod - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: - """Handle incoming messages and return a response message.""" + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + """Handles incoming messages and returns a response.""" ... async def run( @@ -49,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC): termination_condition: TerminationCondition | None = None, ) -> TaskResult: """Run the agent with the given task and return the result.""" - group_chat = RoundRobinGroupChat(participants=[self]) - result = await group_chat.run( - task=task, - cancellation_token=cancellation_token, - termination_condition=termination_condition, - ) - return result + if cancellation_token is None: + cancellation_token = CancellationToken() + first_message = TextMessage(content=task, source="user") + response = await self.on_messages([first_message], cancellation_token) + messages: List[InnerMessage | ChatMessage] = [first_message] + if response.inner_messages is not None: + messages += response.inner_messages + messages.append(response.chat_message) + return TaskResult(messages=messages) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py index c5c216e52e..8c21d53fb8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py @@ -3,6 +3,7 @@ from typing import List, Sequence from autogen_core.base import CancellationToken from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks +from ..base import Response from ..messages import ChatMessage, TextMessage from ._base_chat_agent import BaseChatAgent @@ -25,7 +26,7 @@ class CodeExecutorAgent(BaseChatAgent): """The types of messages that the code executor agent produces.""" return [TextMessage] - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: # Extract code blocks from the messages. code_blocks: List[CodeBlock] = [] for msg in messages: @@ -34,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent): if code_blocks: # Execute the code blocks. result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token) - return TextMessage(content=result.output, source=self.name) + return Response(chat_message=TextMessage(content=result.output, source=self.name)) else: - return TextMessage(content="No code blocks found in the thread.", source=self.name) + return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name)) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py index 436d69fb04..1b95d6e180 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py @@ -1,10 +1,11 @@ -from ._chat_agent import ChatAgent +from ._chat_agent import ChatAgent, Response from ._task import TaskResult, TaskRunner from ._team import Team from ._termination import TerminatedException, TerminationCondition __all__ = [ "ChatAgent", + "Response", "Team", "TerminatedException", "TerminationCondition", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index 689f6e6d5e..d60dba349c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -1,12 +1,24 @@ +from dataclasses import dataclass from typing import List, Protocol, Sequence, runtime_checkable from autogen_core.base import CancellationToken -from ..messages import ChatMessage +from ..messages import ChatMessage, InnerMessage from ._task import TaskResult, TaskRunner from ._termination import TerminationCondition +@dataclass(kw_only=True) +class Response: + """A response from calling :meth:`ChatAgent.on_messages`.""" + + chat_message: ChatMessage + """A chat message produced by the agent as the response.""" + + inner_messages: List[InnerMessage] | None = None + """Inner messages produced by the agent.""" + + @runtime_checkable class ChatAgent(TaskRunner, Protocol): """Protocol for a chat agent.""" @@ -29,8 +41,8 @@ class ChatAgent(TaskRunner, Protocol): """The types of messages that the agent produces.""" ... - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: - """Handle incoming messages and return a response message.""" + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + """Handles incoming messages and returns a response.""" ... async def run( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index 1d9a768b90..326cceecb1 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -3,7 +3,7 @@ from typing import Protocol, Sequence from autogen_core.base import CancellationToken -from ..messages import ChatMessage +from ..messages import ChatMessage, InnerMessage from ._termination import TerminationCondition @@ -11,7 +11,7 @@ from ._termination import TerminationCondition class TaskResult: """Result of running a task.""" - messages: Sequence[ChatMessage] + messages: Sequence[InnerMessage | ChatMessage] """Messages produced by the task.""" diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index feb8b867c7..f206250e10 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -1,6 +1,7 @@ from typing import List -from autogen_core.components import Image +from autogen_core.components import FunctionCall, Image +from autogen_core.components.models import FunctionExecutionResult from pydantic import BaseModel @@ -49,8 +50,26 @@ class ResetMessage(BaseMessage): """The content for the reset message.""" +class ToolCallMessage(BaseMessage): + """A message signaling the use of tools.""" + + content: List[FunctionCall] + """The tool calls.""" + + +class ToolCallResultMessages(BaseMessage): + """A message signaling the results of tool calls.""" + + content: List[FunctionExecutionResult] + """The tool call results.""" + + +InnerMessage = ToolCallMessage | ToolCallResultMessages +"""Messages for intra-agent monologues.""" + + ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage -"""A message used by agents in a team.""" +"""Messages for agent-to-agent communication.""" __all__ = [ @@ -60,5 +79,7 @@ __all__ = [ "StopMessage", "HandoffMessage", "ResetMessage", + "ToolCallMessage", + "ToolCallResultMessages", "ChatMessage", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index f5268a3a9a..9f3132a749 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -15,7 +15,7 @@ from autogen_core.base import ( from autogen_core.components import ClosureAgent, TypeSubscription from ...base import ChatAgent, TaskResult, Team, TerminationCondition -from ...messages import ChatMessage, TextMessage +from ...messages import ChatMessage, InnerMessage, TextMessage from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent from ._base_group_chat_manager import BaseGroupChatManager from ._chat_agent_container import ChatAgentContainer @@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC): def _create_participant_factory( self, parent_topic_type: str, + output_topic_type: str, agent: ChatAgent, ) -> Callable[[], ChatAgentContainer]: def _factory() -> ChatAgentContainer: id = AgentInstantiationContext.current_agent_id() assert id == AgentId(type=agent.name, key=self._team_id) - container = ChatAgentContainer(parent_topic_type, agent) + container = ChatAgentContainer(parent_topic_type, output_topic_type, agent) assert container.id == id return container @@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC): group_chat_manager_topic_type = group_chat_manager_agent_type.type group_topic_type = "round_robin_group_topic" team_topic_type = "team_topic" + output_topic_type = "output_topic" # Register participants. participant_topic_types: List[str] = [] @@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC): await ChatAgentContainer.register( runtime, type=agent_type, - factory=self._create_participant_factory(group_topic_type, participant), + factory=self._create_participant_factory(group_topic_type, output_topic_type, participant), ) # Add subscriptions for the participant. await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type)) @@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC): TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type) ) - group_chat_messages: List[ChatMessage] = [] + output_messages: List[InnerMessage | ChatMessage] = [] - async def collect_group_chat_messages( + async def collect_output_messages( _runtime: AgentRuntime, id: AgentId, - message: GroupChatPublishEvent, + message: InnerMessage | ChatMessage, ctx: MessageContext, ) -> None: - group_chat_messages.append(message.agent_message) + output_messages.append(message) await ClosureAgent.register( runtime, - type="collect_group_chat_messages", - closure=collect_group_chat_messages, + type="collect_output_messages", + closure=collect_output_messages, subscriptions=lambda: [ - TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"), + TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"), ], ) @@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC): # Run the team by publishing the task to the team topic and then requesting the result. team_topic_id = TopicId(type=team_topic_type, source=self._team_id) group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) + first_chat_message = TextMessage(content=task, source="user") + output_messages.append(first_chat_message) await runtime.publish_message( - GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")), + GroupChatPublishEvent(agent_message=first_chat_message), topic_id=team_topic_id, ) await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id) @@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC): await runtime.stop_when_idle() # Return the result. - return TaskResult(messages=group_chat_messages) + return TaskResult(messages=output_messages) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index e2970ffe63..1423735c2f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent): Args: parent_topic_type (str): The topic type of the parent orchestrator. + output_topic_type (str): The topic type for the output. agent (ChatAgent): The agent to delegate message handling to. """ - def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None: + def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None: super().__init__(description=agent.description) self._parent_topic_type = parent_topic_type + self._output_topic_type = output_topic_type self._agent = agent self._message_buffer: List[ChatMessage] = [] @@ -36,18 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent): to the delegate agent and publish the response.""" # Pass the messages in the buffer to the delegate agent. response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) - if not any(isinstance(response, msg_type) for msg_type in self._agent.produced_message_types): + if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types): raise ValueError( f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. " - f"Expected one of: {self._agent.produced_message_types}" + f"Expected one of: {self._agent.produced_message_types}. " + f"Check the agent's produced_message_types property." ) + # Publish inner messages to the output topic. + if response.inner_messages is not None: + for inner_message in response.inner_messages: + await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type)) + # Publish the response. self._message_buffer.clear() await self.publish_message( - GroupChatPublishEvent(agent_message=response, source=self.id), + GroupChatPublishEvent(agent_message=response.chat_message, source=self.id), topic_id=DefaultTopicId(type=self._parent_topic_type), ) + # Publish the response to the output topic. + await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type)) + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: raise ValueError(f"Unhandled message in agent container: {type(message)}") diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 332a7bab15..9dee76539b 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -7,7 +7,7 @@ import pytest from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent, Handoff from autogen_agentchat.logging import FileLogHandler -from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage +from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages from autogen_core.base import CancellationToken from autogen_core.components.tools import FunctionTool from autogen_ext.models import OpenAIChatCompletionClient @@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], ) result = await tool_use_agent.run("task") - assert len(result.messages) == 3 + assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) - assert isinstance(result.messages[1], TextMessage) - assert isinstance(result.messages[2], StopMessage) + assert isinstance(result.messages[1], ToolCallMessage) + assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[3], TextMessage) @pytest.mark.asyncio @@ -162,5 +163,5 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: response = await tool_use_agent.on_messages( [TextMessage(content="task", source="user")], cancellation_token=CancellationToken() ) - assert isinstance(response, HandoffMessage) - assert response.target == "agent2" + assert isinstance(response.chat_message, HandoffMessage) + assert response.chat_message.target == "agent2" diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 1d04c78b60..e6510c2fa1 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -12,12 +12,15 @@ from autogen_agentchat.agents import ( CodeExecutorAgent, Handoff, ) +from autogen_agentchat.base import Response from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, StopMessage, TextMessage, + ToolCallMessage, + ToolCallResultMessages, ) from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination from autogen_agentchat.teams import ( @@ -66,14 +69,14 @@ class _EchoAgent(BaseChatAgent): def produced_message_types(self) -> List[type[ChatMessage]]: return [TextMessage] - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: if len(messages) > 0: assert isinstance(messages[0], TextMessage) self._last_message = messages[0].content - return TextMessage(content=messages[0].content, source=self.name) + return Response(chat_message=TextMessage(content=messages[0].content, source=self.name)) else: assert self._last_message is not None - return TextMessage(content=self._last_message, source=self.name) + return Response(chat_message=TextMessage(content=self._last_message, source=self.name)) class _StopAgent(_EchoAgent): @@ -86,11 +89,11 @@ class _StopAgent(_EchoAgent): def produced_message_types(self) -> List[type[ChatMessage]]: return [TextMessage, StopMessage] - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: self._count += 1 if self._count < self._stop_at: return await super().on_messages(messages, cancellation_token) - return StopMessage(content="TERMINATE", source=self.name) + return Response(chat_message=StopMessage(content="TERMINATE", source=self.name)) def _pass_function(input: str) -> str: @@ -230,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() ) - assert len(result.messages) == 4 + assert len(result.messages) == 6 assert isinstance(result.messages[0], TextMessage) # task - assert isinstance(result.messages[1], TextMessage) # tool use agent response - assert isinstance(result.messages[2], TextMessage) # echo agent response - assert isinstance(result.messages[3], StopMessage) # tool use agent response + assert isinstance(result.messages[1], ToolCallMessage) # tool call + assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result + assert isinstance(result.messages[3], TextMessage) # tool use agent response + assert isinstance(result.messages[4], TextMessage) # echo agent response + assert isinstance(result.messages[5], StopMessage) # tool use agent response context = tool_use_agent._model_context # pyright: ignore assert context[0].content == "Write a program that prints 'Hello, world!'" @@ -427,8 +432,12 @@ class _HandOffAgent(BaseChatAgent): def produced_message_types(self) -> List[type[ChatMessage]]: return [HandoffMessage] - async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: - return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name) + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + return Response( + chat_message=HandoffMessage( + content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name + ) + ) @pytest.mark.asyncio @@ -513,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1") team = Swarm([agnet1, agent2]) result = await team.run("task", termination_condition=StopMessageTermination()) - assert len(result.messages) == 5 + assert len(result.messages) == 7 assert result.messages[0].content == "task" - assert result.messages[1].content == "handoff to agent2" - assert result.messages[2].content == "Transferred to agent1." - assert result.messages[3].content == "Hello" - assert result.messages[4].content == "TERMINATE" + assert isinstance(result.messages[1], ToolCallMessage) + assert isinstance(result.messages[2], ToolCallResultMessages) + assert result.messages[3].content == "handoff to agent2" + assert result.messages[4].content == "Transferred to agent1." + assert result.messages[5].content == "Hello" + assert result.messages[6].content == "TERMINATE" diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb index 0e58c2522a..d367ba29b7 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb @@ -251,6 +251,7 @@ "from typing import List, Sequence\n", "\n", "from autogen_agentchat.agents import BaseChatAgent\n", + "from autogen_agentchat.base import Response\n", "from autogen_agentchat.messages import (\n", " ChatMessage,\n", " StopMessage,\n", @@ -266,11 +267,11 @@ " def produced_message_types(self) -> List[type[ChatMessage]]:\n", " return [TextMessage, StopMessage]\n", "\n", - " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n", + " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", " user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n", " if \"TERMINATE\" in user_input:\n", - " return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n", - " return TextMessage(content=user_input, source=self.name)\n", + " return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n", + " return Response(chat_message=TextMessage(content=user_input, source=self.name))\n", "\n", "\n", "user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n", diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb index f5a5358aae..633c81867b 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb @@ -45,6 +45,7 @@ " CodingAssistantAgent,\n", " ToolUseAssistantAgent,\n", ")\n", + "from autogen_agentchat.base import Response\n", "from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n", "from autogen_agentchat.task import StopMessageTermination\n", "from autogen_agentchat.teams import SelectorGroupChat\n", @@ -75,11 +76,11 @@ " def produced_message_types(self) -> List[type[ChatMessage]]:\n", " return [TextMessage, StopMessage]\n", "\n", - " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n", + " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", " user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n", " if \"TERMINATE\" in user_input:\n", - " return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n", - " return TextMessage(content=user_input, source=self.name)" + " return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n", + " return Response(chat_message=TextMessage(content=user_input, source=self.name))" ] }, {