AgentChat streaming API (#4015)

This commit is contained in:
Eric Zhu 2024-11-01 04:12:43 -07:00 committed by GitHub
parent 4023454c58
commit cff7d842a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 353 additions and 107 deletions

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
import logging import logging
from typing import Any, Awaitable, Callable, Dict, List, Sequence from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
from autogen_core.base import CancellationToken from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall from autogen_core.components import FunctionCall
@ -27,7 +27,7 @@ from ..messages import (
StopMessage, StopMessage,
TextMessage, TextMessage,
ToolCallMessage, ToolCallMessage,
ToolCallResultMessages, ToolCallResultMessage,
) )
from ._base_chat_agent import BaseChatAgent from ._base_chat_agent import BaseChatAgent
@ -98,7 +98,11 @@ class Handoff(BaseModel):
@property @property
def handoff_tool(self) -> Tool: def handoff_tool(self) -> Tool:
"""Create a handoff tool from this handoff configuration.""" """Create a handoff tool from this handoff configuration."""
return FunctionTool(lambda: self.message, name=self.name, description=self.description)
def _handoff_tool() -> str:
return self.message
return FunctionTool(_handoff_tool, name=self.name, description=self.description)
class AssistantAgent(BaseChatAgent): class AssistantAgent(BaseChatAgent):
@ -138,7 +142,7 @@ class AssistantAgent(BaseChatAgent):
The following example demonstrates how to create an assistant agent with The following example demonstrates how to create an assistant agent with
a model client and a tool, and generate a response to a simple task using the tool. a model client and a tool, and generate a stream of messages for a task.
.. code-block:: python .. code-block:: python
@ -154,7 +158,11 @@ class AssistantAgent(BaseChatAgent):
model_client = OpenAIChatCompletionClient(model="gpt-4o") model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3)) stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
""" """
@ -219,6 +227,14 @@ class AssistantAgent(BaseChatAgent):
return [TextMessage, StopMessage] return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
# Add messages to the model context. # Add messages to the model context.
for msg in messages: for msg in messages:
if isinstance(msg, ResetMessage): if isinstance(msg, ResetMessage):
@ -243,6 +259,7 @@ class AssistantAgent(BaseChatAgent):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output. # Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name)) inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
yield ToolCallMessage(content=result.content, source=self.name)
# Execute the tool calls. # Execute the tool calls.
results = await asyncio.gather( results = await asyncio.gather(
@ -250,7 +267,8 @@ class AssistantAgent(BaseChatAgent):
) )
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name)) event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results)) self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessages(content=results, source=self.name)) inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
yield ToolCallResultMessage(content=results, source=self.name)
# Detect handoff requests. # Detect handoff requests.
handoffs: List[Handoff] = [] handoffs: List[Handoff] = []
@ -261,12 +279,13 @@ class AssistantAgent(BaseChatAgent):
if len(handoffs) > 1: if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff. # Return the output messages to signal the handoff.
return Response( yield Response(
chat_message=HandoffMessage( chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name content=handoffs[0].message, target=handoffs[0].target, source=self.name
), ),
inner_messages=inner_messages, inner_messages=inner_messages,
) )
return
# Generate an inference result based on the current model context. # Generate an inference result based on the current model context.
result = await self._model_client.create( result = await self._model_client.create(
@ -278,13 +297,13 @@ class AssistantAgent(BaseChatAgent):
# Detect stop request. # Detect stop request.
request_stop = "terminate" in result.content.strip().lower() request_stop = "terminate" in result.content.strip().lower()
if request_stop: if request_stop:
return Response( yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
) )
else:
return Response( yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
) )
async def _execute_tool_call( async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken self, tool_call: FunctionCall, cancellation_token: CancellationToken

View File

@ -1,9 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Sequence from typing import AsyncGenerator, List, Sequence
from autogen_core.base import CancellationToken from autogen_core.base import CancellationToken
from ..base import ChatAgent, Response, TaskResult, TerminationCondition from ..base import ChatAgent, Response, TaskResult
from ..messages import ChatMessage, InnerMessage, TextMessage from ..messages import ChatMessage, InnerMessage, TextMessage
@ -40,12 +40,22 @@ class BaseChatAgent(ChatAgent, ABC):
"""Handles incoming messages and returns a response.""" """Handles incoming messages and returns a response."""
... ...
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in :class:`BaseChatAgent`
simply calls :meth:`on_messages` and yields the messages in the response."""
response = await self.on_messages(messages, cancellation_token)
for inner_message in response.inner_messages or []:
yield inner_message
yield response
async def run( async def run(
self, self,
task: str, task: str,
*, *,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult: ) -> TaskResult:
"""Run the agent with the given task and return the result.""" """Run the agent with the given task and return the result."""
if cancellation_token is None: if cancellation_token is None:
@ -57,3 +67,25 @@ class BaseChatAgent(ChatAgent, ABC):
messages += response.inner_messages messages += response.inner_messages
messages.append(response.chat_message) messages.append(response.chat_message)
return TaskResult(messages=messages) return TaskResult(messages=messages)
async def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
and the final task result as the last item in the stream."""
if cancellation_token is None:
cancellation_token = CancellationToken()
first_message = TextMessage(content=task, source="user")
yield first_message
messages: List[InnerMessage | ChatMessage] = [first_message]
async for message in self.on_messages_stream([first_message], cancellation_token):
if isinstance(message, Response):
yield message.chat_message
messages.append(message.chat_message)
yield TaskResult(messages=messages)
else:
messages.append(message)
yield message

View File

@ -1,11 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
from autogen_core.base import CancellationToken from autogen_core.base import CancellationToken
from ..messages import ChatMessage, InnerMessage from ..messages import ChatMessage, InnerMessage
from ._task import TaskResult, TaskRunner from ._task import TaskRunner
from ._termination import TerminationCondition
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -45,12 +44,9 @@ class ChatAgent(TaskRunner, Protocol):
"""Handles incoming messages and returns a response.""" """Handles incoming messages and returns a response."""
... ...
async def run( def on_messages_stream(
self, self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
task: str, ) -> AsyncGenerator[InnerMessage | Response, None]:
*, """Handles incoming messages and returns a stream of inner messages and
cancellation_token: CancellationToken | None = None, and the final item is the response."""
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
... ...

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Protocol, Sequence from typing import AsyncGenerator, Protocol, Sequence
from autogen_core.base import CancellationToken from autogen_core.base import CancellationToken
from ..messages import ChatMessage, InnerMessage from ..messages import ChatMessage, InnerMessage
from ._termination import TerminationCondition
@dataclass @dataclass
@ -23,7 +22,16 @@ class TaskRunner(Protocol):
task: str, task: str,
*, *,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult: ) -> TaskResult:
"""Run the task.""" """Run the task."""
... ...
def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
as the last item in the stream."""
...

View File

@ -1,18 +1,7 @@
from typing import Protocol from typing import Protocol
from autogen_core.base import CancellationToken from ._task import TaskRunner
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
class Team(TaskRunner, Protocol): class Team(TaskRunner, Protocol):
async def run( pass
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team on a given task until the termination condition is met."""
...

View File

@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage):
"""The tool calls.""" """The tool calls."""
class ToolCallResultMessages(BaseMessage): class ToolCallResultMessage(BaseMessage):
"""A message signaling the results of tool calls.""" """A message signaling the results of tool calls."""
content: List[FunctionExecutionResult] content: List[FunctionExecutionResult]
"""The tool call results.""" """The tool call results."""
InnerMessage = ToolCallMessage | ToolCallResultMessages InnerMessage = ToolCallMessage | ToolCallResultMessage
"""Messages for intra-agent monologues.""" """Messages for intra-agent monologues."""
@ -80,6 +80,6 @@ __all__ = [
"HandoffMessage", "HandoffMessage",
"ResetMessage", "ResetMessage",
"ToolCallMessage", "ToolCallMessage",
"ToolCallResultMessages", "ToolCallResultMessage",
"ChatMessage", "ChatMessage",
] ]

View File

@ -1,6 +1,7 @@
import asyncio
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import AsyncGenerator, Callable, List
from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import ( from autogen_core.base import (
@ -75,9 +76,24 @@ class BaseGroupChat(Team, ABC):
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None, termination_condition: TerminationCondition | None = None,
) -> TaskResult: ) -> TaskResult:
"""Run the team and return the result.""" """Run the team and return the result. The base implementation uses
# Create intervention handler for termination. :meth:`run_stream` to run the team and then returns the final result."""
async for message in self.run_stream(
task, cancellation_token=cancellation_token, termination_condition=termination_condition
):
if isinstance(message, TaskResult):
return message
raise AssertionError("The stream should have returned the final result.")
async def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
as the last item in the stream."""
# Create the runtime. # Create the runtime.
runtime = SingleThreadedAgentRuntime() runtime = SingleThreadedAgentRuntime()
@ -132,6 +148,7 @@ class BaseGroupChat(Team, ABC):
) )
output_messages: List[InnerMessage | ChatMessage] = [] output_messages: List[InnerMessage | ChatMessage] = []
output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue()
async def collect_output_messages( async def collect_output_messages(
_runtime: AgentRuntime, _runtime: AgentRuntime,
@ -140,6 +157,7 @@ class BaseGroupChat(Team, ABC):
ctx: MessageContext, ctx: MessageContext,
) -> None: ) -> None:
output_messages.append(message) output_messages.append(message)
await output_message_queue.put(message)
await ClosureAgent.register( await ClosureAgent.register(
runtime, runtime,
@ -158,14 +176,29 @@ class BaseGroupChat(Team, ABC):
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
first_chat_message = TextMessage(content=task, source="user") first_chat_message = TextMessage(content=task, source="user")
output_messages.append(first_chat_message) output_messages.append(first_chat_message)
await output_message_queue.put(first_chat_message)
await runtime.publish_message( await runtime.publish_message(
GroupChatPublishEvent(agent_message=first_chat_message), GroupChatPublishEvent(agent_message=first_chat_message),
topic_id=team_topic_id, topic_id=team_topic_id,
) )
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id) await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
# Wait for the runtime to stop. # Start a coroutine to stop the runtime and signal the output message queue is complete.
await runtime.stop_when_idle() async def stop_runtime() -> None:
await runtime.stop_when_idle()
await output_message_queue.put(None)
# Return the result. shutdown_task = asyncio.create_task(stop_runtime())
return TaskResult(messages=output_messages)
# Yield the messsages until the queue is empty.
while True:
message = await output_message_queue.get()
if message is None:
break
yield message
# Wait for the shutdown task to finish.
await shutdown_task
# Yield the final result.
yield TaskResult(messages=output_messages)

View File

@ -3,7 +3,7 @@ from typing import Any, List
from autogen_core.base import MessageContext from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, event from autogen_core.components import DefaultTopicId, event
from ...base import ChatAgent from ...base import ChatAgent, Response
from ...messages import ChatMessage from ...messages import ChatMessage
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._sequential_routed_agent import SequentialRoutedAgent from ._sequential_routed_agent import SequentialRoutedAgent
@ -37,28 +37,26 @@ class ChatAgentContainer(SequentialRoutedAgent):
"""Handle a content request event by passing the messages in the buffer """Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response.""" to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent. # Pass the messages in the buffer to the delegate agent.
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) response: Response | None = None
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types): async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
raise ValueError( if isinstance(msg, Response):
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. " await self.publish_message(
f"Expected one of: {self._agent.produced_message_types}. " msg.chat_message,
f"Check the agent's produced_message_types property." topic_id=DefaultTopicId(type=self._output_topic_type),
) )
response = msg
else:
# Publish the message to the output topic.
await self.publish_message(msg, topic_id=DefaultTopicId(type=self._output_topic_type))
if response is None:
raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")
# Publish inner messages to the output topic. # Publish the response to the group chat.
if response.inner_messages is not None:
for inner_message in response.inner_messages:
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Publish the response.
self._message_buffer.clear() self._message_buffer.clear()
await self.publish_message( await self.publish_message(
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id), GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type), topic_id=DefaultTopicId(type=self._parent_topic_type),
) )
# Publish the response to the output topic.
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}") raise ValueError(f"Unhandled message in agent container: {type(message)}")

View File

@ -61,24 +61,45 @@ class RoundRobinGroupChat(BaseGroupChat):
.. code-block:: python .. code-block:: python
from autogen_agentchat.agents import ToolUseAssistantAgent from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...) model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant]) team = RoundRobinGroupChat([assistant])
await team.run("What's the weather in New York?", termination_condition=StopMessageTermination()) stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
A team with multiple participants: A team with multiple participants:
.. code-block:: python .. code-block:: python
from autogen_agentchat.agents import CodingAssistantAgent, CodeExecutorAgent from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
coding_assistant = CodingAssistantAgent("Coding_Assistant", model_client=...) model_client = OpenAIChatCompletionClient(model="gpt-4o")
executor_agent = CodeExecutorAgent("Code_Executor", code_executor=...)
team = RoundRobinGroupChat([coding_assistant, executor_agent]) agent1 = AssistantAgent("Assistant1", model_client=model_client)
await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()) agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
""" """

View File

@ -170,14 +170,48 @@ class SelectorGroupChat(BaseGroupChat):
.. code-block:: python .. code-block:: python
from autogen_agentchat.agents import ToolUseAssistantAgent from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import StopMessageTermination
travel_advisor = ToolUseAssistantAgent("Travel_Advisor", model_client=..., registered_tools=...) model_client = OpenAIChatCompletionClient(model="gpt-4o")
hotel_agent = ToolUseAssistantAgent("Hotel_Agent", model_client=..., registered_tools=...)
flight_agent = ToolUseAssistantAgent("Flight_Agent", model_client=..., registered_tools=...)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=...) async def lookup_hotel(location: str) -> str:
await team.run("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
async def book_trip() -> str:
return "Your trip is booked!"
travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
""" """
def __init__( def __init__(

View File

@ -79,7 +79,10 @@ class Swarm(BaseGroupChat):
) )
team = Swarm([agent1, agent2]) team = Swarm([agent1, agent2])
await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
""" """
def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None): def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):

View File

@ -6,9 +6,9 @@ from typing import Any, AsyncGenerator, List
import pytest import pytest
from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.base import TaskResult
from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from autogen_core.base import CancellationToken
from autogen_core.components.tools import FunctionTool from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient from autogen_ext.models import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions from openai.resources.chat.completions import AsyncCompletions
@ -114,9 +114,19 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(result.messages) == 4 assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], ToolCallMessage) assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages) assert isinstance(result.messages[2], ToolCallResultMessage)
assert isinstance(result.messages[3], TextMessage) assert isinstance(result.messages[3], TextMessage)
# Test streaming.
mock._curr_index = 0 # pyright: ignore
index = 0
async for message in tool_use_agent.run_stream("task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
@ -160,8 +170,19 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
handoffs=[handoff], handoffs=[handoff],
) )
assert HandoffMessage in tool_use_agent.produced_message_types assert HandoffMessage in tool_use_agent.produced_message_types
response = await tool_use_agent.on_messages( result = await tool_use_agent.run("task")
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken() assert len(result.messages) == 4
) assert isinstance(result.messages[0], TextMessage)
assert isinstance(response.chat_message, HandoffMessage) assert isinstance(result.messages[1], ToolCallMessage)
assert response.chat_message.target == "agent2" assert isinstance(result.messages[2], ToolCallResultMessage)
assert isinstance(result.messages[3], HandoffMessage)
# Test streaming.
mock._curr_index = 0 # pyright: ignore
index = 0
async for message in tool_use_agent.run_stream("task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

View File

@ -12,7 +12,7 @@ from autogen_agentchat.agents import (
CodeExecutorAgent, CodeExecutorAgent,
Handoff, Handoff,
) )
from autogen_agentchat.base import Response from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import ( from autogen_agentchat.messages import (
ChatMessage, ChatMessage,
@ -20,7 +20,7 @@ from autogen_agentchat.messages import (
StopMessage, StopMessage,
TextMessage, TextMessage,
ToolCallMessage, ToolCallMessage,
ToolCallResultMessages, ToolCallResultMessage,
) )
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.teams import ( from autogen_agentchat.teams import (
@ -59,6 +59,9 @@ class _MockChatCompletion:
self._curr_index += 1 self._curr_index += 1
return completion return completion
def reset(self) -> None:
self._curr_index = 0
class _EchoAgent(BaseChatAgent): class _EchoAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None: def __init__(self, name: str, description: str) -> None:
@ -147,7 +150,8 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
) )
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent]) team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run( result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() "Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
) )
expected_messages = [ expected_messages = [
"Write a program that prints 'Hello, world!'", "Write a program that prints 'Hello, world!'",
@ -164,6 +168,18 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
# Assert that all expected messages are in the collected messages # Assert that all expected messages are in the collected messages
assert normalized_messages == expected_messages assert normalized_messages == expected_messages
# Test streaming.
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
@ -230,13 +246,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
echo_agent = _EchoAgent("echo_agent", description="echo agent") echo_agent = _EchoAgent("echo_agent", description="echo agent")
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent]) team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
result = await team.run( result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() "Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
) )
assert len(result.messages) == 6 assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage) # task assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallMessage) # tool call assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response assert isinstance(result.messages[5], StopMessage) # tool use agent response
@ -253,6 +270,19 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert context[2].content[0].call_id == "1" assert context[2].content[0].call_id == "1"
assert context[3].content == "Hello" assert context[3].content == "Hello"
# Test streaming.
tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
@ -320,7 +350,8 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
model_client=OpenAIChatCompletionClient(model=model, api_key=""), model_client=OpenAIChatCompletionClient(model=model, api_key=""),
) )
result = await team.run( result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() "Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
) )
assert len(result.messages) == 6 assert len(result.messages) == 6
assert result.messages[0].content == "Write a program that prints 'Hello, world!'" assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -330,6 +361,19 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[4].source == "agent2" assert result.messages[4].source == "agent2"
assert result.messages[5].source == "agent1" assert result.messages[5].source == "agent1"
# Test streaming.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None: async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
@ -356,7 +400,8 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
model_client=OpenAIChatCompletionClient(model=model, api_key=""), model_client=OpenAIChatCompletionClient(model=model, api_key=""),
) )
result = await team.run( result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() "Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
) )
assert len(result.messages) == 5 assert len(result.messages) == 5
assert result.messages[0].content == "Write a program that prints 'Hello, world!'" assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -367,6 +412,19 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
# only one chat completion was called # only one chat completion was called
assert mock._curr_index == 1 # pyright: ignore assert mock._curr_index == 1 # pyright: ignore
# Test streaming.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None: async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
@ -422,6 +480,18 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
assert result.messages[2].source == "agent2" assert result.messages[2].source == "agent2"
assert result.messages[3].source == "agent1" assert result.messages[3].source == "agent1"
# Test streaming.
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
class _HandOffAgent(BaseChatAgent): class _HandOffAgent(BaseChatAgent):
def __init__(self, name: str, description: str, next_agent: str) -> None: def __init__(self, name: str, description: str, next_agent: str) -> None:
@ -446,8 +516,8 @@ async def test_swarm_handoff() -> None:
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
team = Swarm([second_agent, first_agent, third_agent]) team = Swarm([second_agent, first_agent, third_agent], termination_condition=MaxMessageTermination(6))
result = await team.run("task", termination_condition=MaxMessageTermination(6)) result = await team.run("task")
assert len(result.messages) == 6 assert len(result.messages) == 6
assert result.messages[0].content == "task" assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent." assert result.messages[1].content == "Transferred to third_agent."
@ -456,6 +526,16 @@ async def test_swarm_handoff() -> None:
assert result.messages[4].content == "Transferred to third_agent." assert result.messages[4].content == "Transferred to third_agent."
assert result.messages[5].content == "Transferred to first_agent." assert result.messages[5].content == "Transferred to first_agent."
# Test streaming.
index = 0
stream = team.run_stream("task", termination_condition=MaxMessageTermination(6))
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None: async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
@ -514,19 +594,31 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
mock = _MockChatCompletion(chat_completions) mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agnet1 = AssistantAgent( agent1 = AssistantAgent(
"agent1", "agent1",
model_client=OpenAIChatCompletionClient(model=model, api_key=""), model_client=OpenAIChatCompletionClient(model=model, api_key=""),
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")], handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
) )
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1") agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agnet1, agent2]) team = Swarm([agent1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination()) result = await team.run("task", termination_condition=StopMessageTermination())
assert len(result.messages) == 7 assert len(result.messages) == 7
assert result.messages[0].content == "task" assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallMessage) assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages) assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[3].content == "handoff to agent2" assert result.messages[3].content == "handoff to agent2"
assert result.messages[4].content == "Transferred to agent1." assert result.messages[4].content == "Transferred to agent1."
assert result.messages[5].content == "Hello" assert result.messages[5].content == "Hello"
assert result.messages[6].content == "TERMINATE" assert result.messages[6].content == "TERMINATE"
# Test streaming.
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
stream = team.run_stream("task", termination_condition=StopMessageTermination())
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1