From cff7d842a6d4acd2733c1f6207061cd60fa10821 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 1 Nov 2024 04:12:43 -0700 Subject: [PATCH] AgentChat streaming API (#4015) --- .../agents/_assistant_agent.py | 43 +++++-- .../agents/_base_chat_agent.py | 38 +++++- .../src/autogen_agentchat/base/_chat_agent.py | 18 ++- .../src/autogen_agentchat/base/_task.py | 14 ++- .../src/autogen_agentchat/base/_team.py | 15 +-- .../src/autogen_agentchat/messages.py | 6 +- .../teams/_group_chat/_base_group_chat.py | 47 +++++-- .../_group_chat/_chat_agent_container.py | 32 +++-- .../_group_chat/_round_robin_group_chat.py | 41 +++++-- .../teams/_group_chat/_selector_group_chat.py | 48 ++++++-- .../teams/_group_chat/_swarm_group_chat.py | 5 +- .../tests/test_assistant_agent.py | 37 ++++-- .../tests/test_group_chat.py | 116 ++++++++++++++++-- 13 files changed, 353 insertions(+), 107 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 5414f782f0..86a4f39952 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from typing import Any, Awaitable, Callable, Dict, List, Sequence +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall @@ -27,7 +27,7 @@ from ..messages import ( StopMessage, TextMessage, ToolCallMessage, - ToolCallResultMessages, + ToolCallResultMessage, ) from ._base_chat_agent import BaseChatAgent @@ -98,7 +98,11 @@ class Handoff(BaseModel): @property def handoff_tool(self) -> Tool: """Create a handoff tool from this handoff configuration.""" - return FunctionTool(lambda: self.message, name=self.name, description=self.description) + + def _handoff_tool() -> str: + return self.message + + return FunctionTool(_handoff_tool, name=self.name, description=self.description) class AssistantAgent(BaseChatAgent): @@ -138,7 +142,7 @@ class AssistantAgent(BaseChatAgent): The following example demonstrates how to create an assistant agent with - a model client and a tool, and generate a response to a simple task using the tool. + a model client and a tool, and generate a stream of messages for a task. .. code-block:: python @@ -154,7 +158,11 @@ class AssistantAgent(BaseChatAgent): model_client = OpenAIChatCompletionClient(model="gpt-4o") agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) - await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3)) + stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3)) + + async for message in stream: + print(message) + """ @@ -219,6 +227,14 @@ class AssistantAgent(BaseChatAgent): return [TextMessage, StopMessage] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[InnerMessage | Response, None]: # Add messages to the model context. for msg in messages: if isinstance(msg, ResetMessage): @@ -243,6 +259,7 @@ class AssistantAgent(BaseChatAgent): event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) # Add the tool call message to the output. inner_messages.append(ToolCallMessage(content=result.content, source=self.name)) + yield ToolCallMessage(content=result.content, source=self.name) # Execute the tool calls. results = await asyncio.gather( @@ -250,7 +267,8 @@ class AssistantAgent(BaseChatAgent): ) event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name)) self._model_context.append(FunctionExecutionResultMessage(content=results)) - inner_messages.append(ToolCallResultMessages(content=results, source=self.name)) + inner_messages.append(ToolCallResultMessage(content=results, source=self.name)) + yield ToolCallResultMessage(content=results, source=self.name) # Detect handoff requests. handoffs: List[Handoff] = [] @@ -261,12 +279,13 @@ class AssistantAgent(BaseChatAgent): if len(handoffs) > 1: raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") # Return the output messages to signal the handoff. - return Response( + yield Response( chat_message=HandoffMessage( content=handoffs[0].message, target=handoffs[0].target, source=self.name ), inner_messages=inner_messages, ) + return # Generate an inference result based on the current model context. result = await self._model_client.create( @@ -278,13 +297,13 @@ class AssistantAgent(BaseChatAgent): # Detect stop request. request_stop = "terminate" in result.content.strip().lower() if request_stop: - return Response( + yield Response( chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages ) - - return Response( - chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages - ) + else: + yield Response( + chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages + ) async def _execute_tool_call( self, tool_call: FunctionCall, cancellation_token: CancellationToken diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index ac74077e27..cf146b0c10 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import List, Sequence +from typing import AsyncGenerator, List, Sequence from autogen_core.base import CancellationToken -from ..base import ChatAgent, Response, TaskResult, TerminationCondition +from ..base import ChatAgent, Response, TaskResult from ..messages import ChatMessage, InnerMessage, TextMessage @@ -40,12 +40,22 @@ class BaseChatAgent(ChatAgent, ABC): """Handles incoming messages and returns a response.""" ... + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[InnerMessage | Response, None]: + """Handles incoming messages and returns a stream of messages and + and the final item is the response. The base implementation in :class:`BaseChatAgent` + simply calls :meth:`on_messages` and yields the messages in the response.""" + response = await self.on_messages(messages, cancellation_token) + for inner_message in response.inner_messages or []: + yield inner_message + yield response + async def run( self, task: str, *, cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, ) -> TaskResult: """Run the agent with the given task and return the result.""" if cancellation_token is None: @@ -57,3 +67,25 @@ class BaseChatAgent(ChatAgent, ABC): messages += response.inner_messages messages.append(response.chat_message) return TaskResult(messages=messages) + + async def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: + """Run the agent with the given task and return a stream of messages + and the final task result as the last item in the stream.""" + if cancellation_token is None: + cancellation_token = CancellationToken() + first_message = TextMessage(content=task, source="user") + yield first_message + messages: List[InnerMessage | ChatMessage] = [first_message] + async for message in self.on_messages_stream([first_message], cancellation_token): + if isinstance(message, Response): + yield message.chat_message + messages.append(message.chat_message) + yield TaskResult(messages=messages) + else: + messages.append(message) + yield message diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index d60dba349c..ce73352dae 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -1,11 +1,10 @@ from dataclasses import dataclass -from typing import List, Protocol, Sequence, runtime_checkable +from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable from autogen_core.base import CancellationToken from ..messages import ChatMessage, InnerMessage -from ._task import TaskResult, TaskRunner -from ._termination import TerminationCondition +from ._task import TaskRunner @dataclass(kw_only=True) @@ -45,12 +44,9 @@ class ChatAgent(TaskRunner, Protocol): """Handles incoming messages and returns a response.""" ... - async def run( - self, - task: str, - *, - cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, - ) -> TaskResult: - """Run the agent with the given task and return the result.""" + def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[InnerMessage | Response, None]: + """Handles incoming messages and returns a stream of inner messages and + and the final item is the response.""" ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index 326cceecb1..2e68c2b811 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,10 +1,9 @@ from dataclasses import dataclass -from typing import Protocol, Sequence +from typing import AsyncGenerator, Protocol, Sequence from autogen_core.base import CancellationToken from ..messages import ChatMessage, InnerMessage -from ._termination import TerminationCondition @dataclass @@ -23,7 +22,16 @@ class TaskRunner(Protocol): task: str, *, cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, ) -> TaskResult: """Run the task.""" ... + + def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: + """Run the task and produces a stream of messages and the final result + as the last item in the stream.""" + ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py index b0a1dc3d2a..e112a3b512 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py @@ -1,18 +1,7 @@ from typing import Protocol -from autogen_core.base import CancellationToken - -from ._task import TaskResult, TaskRunner -from ._termination import TerminationCondition +from ._task import TaskRunner class Team(TaskRunner, Protocol): - async def run( - self, - task: str, - *, - cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, - ) -> TaskResult: - """Run the team on a given task until the termination condition is met.""" - ... + pass diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index f206250e10..51dbcca333 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage): """The tool calls.""" -class ToolCallResultMessages(BaseMessage): +class ToolCallResultMessage(BaseMessage): """A message signaling the results of tool calls.""" content: List[FunctionExecutionResult] """The tool call results.""" -InnerMessage = ToolCallMessage | ToolCallResultMessages +InnerMessage = ToolCallMessage | ToolCallResultMessage """Messages for intra-agent monologues.""" @@ -80,6 +80,6 @@ __all__ = [ "HandoffMessage", "ResetMessage", "ToolCallMessage", - "ToolCallResultMessages", + "ToolCallResultMessage", "ChatMessage", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 9f3132a749..78ec5159e3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -1,6 +1,7 @@ +import asyncio import uuid from abc import ABC, abstractmethod -from typing import Callable, List +from typing import AsyncGenerator, Callable, List from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import ( @@ -75,9 +76,24 @@ class BaseGroupChat(Team, ABC): cancellation_token: CancellationToken | None = None, termination_condition: TerminationCondition | None = None, ) -> TaskResult: - """Run the team and return the result.""" - # Create intervention handler for termination. + """Run the team and return the result. The base implementation uses + :meth:`run_stream` to run the team and then returns the final result.""" + async for message in self.run_stream( + task, cancellation_token=cancellation_token, termination_condition=termination_condition + ): + if isinstance(message, TaskResult): + return message + raise AssertionError("The stream should have returned the final result.") + async def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + termination_condition: TerminationCondition | None = None, + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: + """Run the team and produces a stream of messages and the final result + as the last item in the stream.""" # Create the runtime. runtime = SingleThreadedAgentRuntime() @@ -132,6 +148,7 @@ class BaseGroupChat(Team, ABC): ) output_messages: List[InnerMessage | ChatMessage] = [] + output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue() async def collect_output_messages( _runtime: AgentRuntime, @@ -140,6 +157,7 @@ class BaseGroupChat(Team, ABC): ctx: MessageContext, ) -> None: output_messages.append(message) + await output_message_queue.put(message) await ClosureAgent.register( runtime, @@ -158,14 +176,29 @@ class BaseGroupChat(Team, ABC): group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) first_chat_message = TextMessage(content=task, source="user") output_messages.append(first_chat_message) + await output_message_queue.put(first_chat_message) await runtime.publish_message( GroupChatPublishEvent(agent_message=first_chat_message), topic_id=team_topic_id, ) await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id) - # Wait for the runtime to stop. - await runtime.stop_when_idle() + # Start a coroutine to stop the runtime and signal the output message queue is complete. + async def stop_runtime() -> None: + await runtime.stop_when_idle() + await output_message_queue.put(None) - # Return the result. - return TaskResult(messages=output_messages) + shutdown_task = asyncio.create_task(stop_runtime()) + + # Yield the messsages until the queue is empty. + while True: + message = await output_message_queue.get() + if message is None: + break + yield message + + # Wait for the shutdown task to finish. + await shutdown_task + + # Yield the final result. + yield TaskResult(messages=output_messages) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 1423735c2f..3fde3f6864 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -3,7 +3,7 @@ from typing import Any, List from autogen_core.base import MessageContext from autogen_core.components import DefaultTopicId, event -from ...base import ChatAgent +from ...base import ChatAgent, Response from ...messages import ChatMessage from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent from ._sequential_routed_agent import SequentialRoutedAgent @@ -37,28 +37,26 @@ class ChatAgentContainer(SequentialRoutedAgent): """Handle a content request event by passing the messages in the buffer to the delegate agent and publish the response.""" # Pass the messages in the buffer to the delegate agent. - response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) - if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types): - raise ValueError( - f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. " - f"Expected one of: {self._agent.produced_message_types}. " - f"Check the agent's produced_message_types property." - ) + response: Response | None = None + async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): + if isinstance(msg, Response): + await self.publish_message( + msg.chat_message, + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + response = msg + else: + # Publish the message to the output topic. + await self.publish_message(msg, topic_id=DefaultTopicId(type=self._output_topic_type)) + if response is None: + raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.") - # Publish inner messages to the output topic. - if response.inner_messages is not None: - for inner_message in response.inner_messages: - await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type)) - - # Publish the response. + # Publish the response to the group chat. self._message_buffer.clear() await self.publish_message( GroupChatPublishEvent(agent_message=response.chat_message, source=self.id), topic_id=DefaultTopicId(type=self._parent_topic_type), ) - # Publish the response to the output topic. - await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type)) - async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: raise ValueError(f"Unhandled message in agent container: {type(message)}") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index e8f5f66533..cec47f6e1b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -61,24 +61,45 @@ class RoundRobinGroupChat(BaseGroupChat): .. code-block:: python - from autogen_agentchat.agents import ToolUseAssistantAgent - from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_agentchat.task import StopMessageTermination - assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + + async def get_weather(location: str) -> str: + return f"The weather in {location} is sunny." + + + assistant = AssistantAgent( + "Assistant", + model_client=model_client, + tools=[get_weather], + ) team = RoundRobinGroupChat([assistant]) - await team.run("What's the weather in New York?", termination_condition=StopMessageTermination()) + stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) A team with multiple participants: .. code-block:: python - from autogen_agentchat.agents import CodingAssistantAgent, CodeExecutorAgent - from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_agentchat.task import StopMessageTermination - coding_assistant = CodingAssistantAgent("Coding_Assistant", model_client=...) - executor_agent = CodeExecutorAgent("Code_Executor", code_executor=...) - team = RoundRobinGroupChat([coding_assistant, executor_agent]) - await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + team = RoundRobinGroupChat([agent1, agent2]) + stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) """ diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 3cc489daa6..ed7694d385 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -170,14 +170,48 @@ class SelectorGroupChat(BaseGroupChat): .. code-block:: python - from autogen_agentchat.agents import ToolUseAssistantAgent - from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import SelectorGroupChat + from autogen_agentchat.task import StopMessageTermination - travel_advisor = ToolUseAssistantAgent("Travel_Advisor", model_client=..., registered_tools=...) - hotel_agent = ToolUseAssistantAgent("Hotel_Agent", model_client=..., registered_tools=...) - flight_agent = ToolUseAssistantAgent("Flight_Agent", model_client=..., registered_tools=...) - team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=...) - await team.run("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + + async def lookup_hotel(location: str) -> str: + return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." + + + async def lookup_flight(origin: str, destination: str) -> str: + return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." + + + async def book_trip() -> str: + return "Your trip is booked!" + + + travel_advisor = AssistantAgent( + "Travel_Advisor", + model_client, + tools=[book_trip], + description="Helps with travel planning.", + ) + hotel_agent = AssistantAgent( + "Hotel_Agent", + model_client, + tools=[lookup_hotel], + description="Helps with hotel booking.", + ) + flight_agent = AssistantAgent( + "Flight_Agent", + model_client, + tools=[lookup_flight], + description="Helps with flight booking.", + ) + team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client) + stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) """ def __init__( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 872f12e2ba..0f4ec0e63a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -79,7 +79,10 @@ class Swarm(BaseGroupChat): ) team = Swarm([agent1, agent2]) - await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) + + stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) + async for message in stream: + print(message) """ def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None): diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 9dee76539b..4589f86860 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -6,9 +6,9 @@ from typing import Any, AsyncGenerator, List import pytest from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent, Handoff +from autogen_agentchat.base import TaskResult from autogen_agentchat.logging import FileLogHandler -from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages -from autogen_core.base import CancellationToken +from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage from autogen_core.components.tools import FunctionTool from autogen_ext.models import OpenAIChatCompletionClient from openai.resources.chat.completions import AsyncCompletions @@ -114,9 +114,19 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[1], ToolCallMessage) - assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[2], ToolCallResultMessage) assert isinstance(result.messages[3], TextMessage) + # Test streaming. + mock._curr_index = 0 # pyright: ignore + index = 0 + async for message in tool_use_agent.run_stream("task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: @@ -160,8 +170,19 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: handoffs=[handoff], ) assert HandoffMessage in tool_use_agent.produced_message_types - response = await tool_use_agent.on_messages( - [TextMessage(content="task", source="user")], cancellation_token=CancellationToken() - ) - assert isinstance(response.chat_message, HandoffMessage) - assert response.chat_message.target == "agent2" + result = await tool_use_agent.run("task") + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], ToolCallMessage) + assert isinstance(result.messages[2], ToolCallResultMessage) + assert isinstance(result.messages[3], HandoffMessage) + + # Test streaming. + mock._curr_index = 0 # pyright: ignore + index = 0 + async for message in tool_use_agent.run_stream("task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index e6510c2fa1..4e1485ce30 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -12,7 +12,7 @@ from autogen_agentchat.agents import ( CodeExecutorAgent, Handoff, ) -from autogen_agentchat.base import Response +from autogen_agentchat.base import Response, TaskResult from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.messages import ( ChatMessage, @@ -20,7 +20,7 @@ from autogen_agentchat.messages import ( StopMessage, TextMessage, ToolCallMessage, - ToolCallResultMessages, + ToolCallResultMessage, ) from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination from autogen_agentchat.teams import ( @@ -59,6 +59,9 @@ class _MockChatCompletion: self._curr_index += 1 return completion + def reset(self) -> None: + self._curr_index = 0 + class _EchoAgent(BaseChatAgent): def __init__(self, name: str, description: str) -> None: @@ -147,7 +150,8 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: ) team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent]) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) expected_messages = [ "Write a program that prints 'Hello, world!'", @@ -164,6 +168,18 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: # Assert that all expected messages are in the collected messages assert normalized_messages == expected_messages + # Test streaming. + mock.reset() + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: @@ -230,13 +246,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch echo_agent = _EchoAgent("echo_agent", description="echo agent") team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent]) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) assert len(result.messages) == 6 assert isinstance(result.messages[0], TextMessage) # task assert isinstance(result.messages[1], ToolCallMessage) # tool call - assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result + assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result assert isinstance(result.messages[3], TextMessage) # tool use agent response assert isinstance(result.messages[4], TextMessage) # echo agent response assert isinstance(result.messages[5], StopMessage) # tool use agent response @@ -253,6 +270,19 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch assert context[2].content[0].call_id == "1" assert context[3].content == "Hello" + # Test streaming. + tool_use_agent._model_context.clear() # pyright: ignore + mock.reset() + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: @@ -320,7 +350,8 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: model_client=OpenAIChatCompletionClient(model=model, api_key=""), ) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) assert len(result.messages) == 6 assert result.messages[0].content == "Write a program that prints 'Hello, world!'" @@ -330,6 +361,19 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[4].source == "agent2" assert result.messages[5].source == "agent1" + # Test streaming. + mock.reset() + agent1._count = 0 # pyright: ignore + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None: @@ -356,7 +400,8 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) model_client=OpenAIChatCompletionClient(model=model, api_key=""), ) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) assert len(result.messages) == 5 assert result.messages[0].content == "Write a program that prints 'Hello, world!'" @@ -367,6 +412,19 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) # only one chat completion was called assert mock._curr_index == 1 # pyright: ignore + # Test streaming. + mock.reset() + agent1._count = 0 # pyright: ignore + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None: @@ -422,6 +480,18 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte assert result.messages[2].source == "agent2" assert result.messages[3].source == "agent1" + # Test streaming. + mock.reset() + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + class _HandOffAgent(BaseChatAgent): def __init__(self, name: str, description: str, next_agent: str) -> None: @@ -446,8 +516,8 @@ async def test_swarm_handoff() -> None: second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") - team = Swarm([second_agent, first_agent, third_agent]) - result = await team.run("task", termination_condition=MaxMessageTermination(6)) + team = Swarm([second_agent, first_agent, third_agent], termination_condition=MaxMessageTermination(6)) + result = await team.run("task") assert len(result.messages) == 6 assert result.messages[0].content == "task" assert result.messages[1].content == "Transferred to third_agent." @@ -456,6 +526,16 @@ async def test_swarm_handoff() -> None: assert result.messages[4].content == "Transferred to third_agent." assert result.messages[5].content == "Transferred to first_agent." + # Test streaming. + index = 0 + stream = team.run_stream("task", termination_condition=MaxMessageTermination(6)) + async for message in stream: + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None: @@ -514,19 +594,31 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - mock = _MockChatCompletion(chat_completions) monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) - agnet1 = AssistantAgent( + agent1 = AssistantAgent( "agent1", model_client=OpenAIChatCompletionClient(model=model, api_key=""), handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")], ) agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1") - team = Swarm([agnet1, agent2]) + team = Swarm([agent1, agent2]) result = await team.run("task", termination_condition=StopMessageTermination()) assert len(result.messages) == 7 assert result.messages[0].content == "task" assert isinstance(result.messages[1], ToolCallMessage) - assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[2], ToolCallResultMessage) assert result.messages[3].content == "handoff to agent2" assert result.messages[4].content == "Transferred to agent1." assert result.messages[5].content == "Hello" assert result.messages[6].content == "TERMINATE" + + # Test streaming. + agent1._model_context.clear() # pyright: ignore + mock.reset() + index = 0 + stream = team.run_stream("task", termination_condition=StopMessageTermination()) + async for message in stream: + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1