mirror of https://github.com/microsoft/autogen.git

AgentChat streaming API (#4015)

This commit is contained in:
parent 4023454c58
commit cff7d842a6
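This change adds streaming variants of the AgentChat task APIs: `run_stream` on agents and teams (via the `TaskRunner` protocol) and `on_messages_stream` on chat agents. Inner messages such as tool calls and tool-call results are yielded as they are produced, and the final `TaskResult` is always the last item of the stream. A minimal consumer of the new API as it appears in this diff (assuming an OpenAI API key is available in the environment):

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.task import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    agent = AssistantAgent("assistant", model_client=model_client)
    team = RoundRobinGroupChat([agent])
    # Items arrive as they are produced; the TaskResult is always last.
    async for item in team.run_stream("Write a haiku about streams.", termination_condition=MaxMessageTermination(3)):
        if isinstance(item, TaskResult):
            print(f"Task finished with {len(item.messages)} messages.")
        else:
            print(item)


asyncio.run(main())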
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Awaitable, Callable, Dict, List, Sequence
+from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
 
 from autogen_core.base import CancellationToken
 from autogen_core.components import FunctionCall
@@ -27,7 +27,7 @@ from ..messages import (
     StopMessage,
     TextMessage,
     ToolCallMessage,
-    ToolCallResultMessages,
+    ToolCallResultMessage,
 )
 from ._base_chat_agent import BaseChatAgent
 
 
@@ -98,7 +98,11 @@ class Handoff(BaseModel):
     @property
     def handoff_tool(self) -> Tool:
         """Create a handoff tool from this handoff configuration."""
-        return FunctionTool(lambda: self.message, name=self.name, description=self.description)
+
+        def _handoff_tool() -> str:
+            return self.message
+
+        return FunctionTool(_handoff_tool, name=self.name, description=self.description)
 
 
 class AssistantAgent(BaseChatAgent):
@@ -138,7 +142,7 @@ class AssistantAgent(BaseChatAgent):
 
 
     The following example demonstrates how to create an assistant agent with
-    a model client and a tool, and generate a response to a simple task using the tool.
+    a model client and a tool, and generate a stream of messages for a task.
 
     .. code-block:: python
 
@@ -154,7 +158,11 @@ class AssistantAgent(BaseChatAgent):
             model_client = OpenAIChatCompletionClient(model="gpt-4o")
            agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
 
-            await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
+            stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
+
+            async for message in stream:
+                print(message)
+
 
     """
@@ -219,6 +227,14 @@ class AssistantAgent(BaseChatAgent):
         return [TextMessage, StopMessage]
 
     async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+        async for message in self.on_messages_stream(messages, cancellation_token):
+            if isinstance(message, Response):
+                return message
+        raise AssertionError("The stream should have returned the final result.")
+
+    async def on_messages_stream(
+        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[InnerMessage | Response, None]:
         # Add messages to the model context.
         for msg in messages:
             if isinstance(msg, ResetMessage):
@@ -243,6 +259,7 @@ class AssistantAgent(BaseChatAgent):
             event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
             # Add the tool call message to the output.
             inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
+            yield ToolCallMessage(content=result.content, source=self.name)
 
             # Execute the tool calls.
             results = await asyncio.gather(
@@ -250,7 +267,8 @@ class AssistantAgent(BaseChatAgent):
             )
             event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
             self._model_context.append(FunctionExecutionResultMessage(content=results))
-            inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
+            inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
+            yield ToolCallResultMessage(content=results, source=self.name)
 
             # Detect handoff requests.
             handoffs: List[Handoff] = []
@@ -261,12 +279,13 @@ class AssistantAgent(BaseChatAgent):
             if len(handoffs) > 1:
                 raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
             # Return the output messages to signal the handoff.
-            return Response(
+            yield Response(
                 chat_message=HandoffMessage(
                     content=handoffs[0].message, target=handoffs[0].target, source=self.name
                 ),
                 inner_messages=inner_messages,
             )
+            return
 
         # Generate an inference result based on the current model context.
         result = await self._model_client.create(
@@ -278,13 +297,13 @@ class AssistantAgent(BaseChatAgent):
         # Detect stop request.
         request_stop = "terminate" in result.content.strip().lower()
         if request_stop:
-            return Response(
+            yield Response(
                 chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
             )
-
-        return Response(
-            chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
-        )
+        else:
+            yield Response(
+                chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
+            )
 
     async def _execute_tool_call(
         self, tool_call: FunctionCall, cancellation_token: CancellationToken
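The `on_messages`/`on_messages_stream` split above is a pattern custom agents can copy: implement the stream, then derive the blocking call by draining it. A minimal sketch against the interfaces in this diff; the `EchoStreamAgent` name and behavior are illustrative only, and the constructor signature is assumed to be `(name, description)` as used by the test agents later in this commit:

import asyncio
from typing import AsyncGenerator, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, InnerMessage, TextMessage
from autogen_core.base import CancellationToken


class EchoStreamAgent(BaseChatAgent):
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [TextMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Same pattern as AssistantAgent: drain the stream, return the final Response.
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                return message
        raise AssertionError("The stream should have returned the final result.")

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[InnerMessage | Response, None]:
        # No inner messages here; a real agent would yield tool calls and results first.
        yield Response(
            chat_message=TextMessage(content=f"Echo: {messages[-1].content}", source=self.name),
            inner_messages=None,
        )


async def main() -> None:
    agent = EchoStreamAgent("echo", description="Echoes the last message.")
    async for item in agent.run_stream("hello"):
        print(item)


asyncio.run(main())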
@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import List, Sequence
+from typing import AsyncGenerator, List, Sequence
 
 from autogen_core.base import CancellationToken
 
-from ..base import ChatAgent, Response, TaskResult, TerminationCondition
+from ..base import ChatAgent, Response, TaskResult
 from ..messages import ChatMessage, InnerMessage, TextMessage
 
 
@@ -40,12 +40,22 @@ class BaseChatAgent(ChatAgent, ABC):
         """Handles incoming messages and returns a response."""
         ...
 
+    async def on_messages_stream(
+        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[InnerMessage | Response, None]:
+        """Handles incoming messages and returns a stream of messages,
+        and the final item is the response. The base implementation in :class:`BaseChatAgent`
+        simply calls :meth:`on_messages` and yields the messages in the response."""
+        response = await self.on_messages(messages, cancellation_token)
+        for inner_message in response.inner_messages or []:
+            yield inner_message
+        yield response
+
     async def run(
         self,
         task: str,
         *,
         cancellation_token: CancellationToken | None = None,
-        termination_condition: TerminationCondition | None = None,
     ) -> TaskResult:
         """Run the agent with the given task and return the result."""
         if cancellation_token is None:
@@ -57,3 +67,25 @@ class BaseChatAgent(ChatAgent, ABC):
             messages += response.inner_messages
         messages.append(response.chat_message)
         return TaskResult(messages=messages)
+
+    async def run_stream(
+        self,
+        task: str,
+        *,
+        cancellation_token: CancellationToken | None = None,
+    ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
+        """Run the agent with the given task and return a stream of messages
+        and the final task result as the last item in the stream."""
+        if cancellation_token is None:
+            cancellation_token = CancellationToken()
+        first_message = TextMessage(content=task, source="user")
+        yield first_message
+        messages: List[InnerMessage | ChatMessage] = [first_message]
+        async for message in self.on_messages_stream([first_message], cancellation_token):
+            if isinstance(message, Response):
+                yield message.chat_message
+                messages.append(message.chat_message)
+                yield TaskResult(messages=messages)
+            else:
+                messages.append(message)
+                yield message
@@ -1,11 +1,10 @@
 from dataclasses import dataclass
-from typing import List, Protocol, Sequence, runtime_checkable
+from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
 
 from autogen_core.base import CancellationToken
 
 from ..messages import ChatMessage, InnerMessage
-from ._task import TaskResult, TaskRunner
-from ._termination import TerminationCondition
+from ._task import TaskRunner
 
 
 @dataclass(kw_only=True)
@@ -45,12 +44,9 @@ class ChatAgent(TaskRunner, Protocol):
         """Handles incoming messages and returns a response."""
         ...
 
-    async def run(
-        self,
-        task: str,
-        *,
-        cancellation_token: CancellationToken | None = None,
-        termination_condition: TerminationCondition | None = None,
-    ) -> TaskResult:
-        """Run the agent with the given task and return the result."""
+    def on_messages_stream(
+        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[InnerMessage | Response, None]:
+        """Handles incoming messages and returns a stream of inner messages,
+        and the final item is the response."""
         ...
@@ -1,10 +1,9 @@
 from dataclasses import dataclass
-from typing import Protocol, Sequence
+from typing import AsyncGenerator, Protocol, Sequence
 
 from autogen_core.base import CancellationToken
 
 from ..messages import ChatMessage, InnerMessage
-from ._termination import TerminationCondition
 
 
 @dataclass
@@ -23,7 +22,16 @@ class TaskRunner(Protocol):
         task: str,
         *,
         cancellation_token: CancellationToken | None = None,
-        termination_condition: TerminationCondition | None = None,
     ) -> TaskResult:
         """Run the task."""
         ...
+
+    def run_stream(
+        self,
+        task: str,
+        *,
+        cancellation_token: CancellationToken | None = None,
+    ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
+        """Run the task and produce a stream of messages and the final result
+        as the last item in the stream."""
+        ...
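`TaskRunner` is now the protocol shared by agents and teams, so any object with compatible `run`/`run_stream` methods satisfies it. A toy conforming implementation, following the yield order this diff establishes (task message first, final `TaskResult` last); the `FixedReplyRunner` name and behavior are illustrative only:

import asyncio
from typing import AsyncGenerator, List

from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ChatMessage, InnerMessage, TextMessage
from autogen_core.base import CancellationToken


class FixedReplyRunner:
    """A toy TaskRunner that replies to every task with a canned message."""

    async def run(self, task: str, *, cancellation_token: CancellationToken | None = None) -> TaskResult:
        # Drain the stream and return the final result, as BaseChatAgent and BaseGroupChat do.
        async for item in self.run_stream(task, cancellation_token=cancellation_token):
            if isinstance(item, TaskResult):
                return item
        raise AssertionError("The stream should have returned the final result.")

    async def run_stream(
        self, task: str, *, cancellation_token: CancellationToken | None = None
    ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
        messages: List[InnerMessage | ChatMessage] = [TextMessage(content=task, source="user")]
        yield messages[0]
        reply = TextMessage(content="done", source="runner")
        messages.append(reply)
        yield reply
        yield TaskResult(messages=messages)


async def main() -> None:
    result = await FixedReplyRunner().run("do the thing")
    print([message.content for message in result.messages])


asyncio.run(main())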
@@ -1,18 +1,7 @@
 from typing import Protocol
 
-from autogen_core.base import CancellationToken
-
-from ._task import TaskResult, TaskRunner
-from ._termination import TerminationCondition
+from ._task import TaskRunner
 
 
 class Team(TaskRunner, Protocol):
-    async def run(
-        self,
-        task: str,
-        *,
-        cancellation_token: CancellationToken | None = None,
-        termination_condition: TerminationCondition | None = None,
-    ) -> TaskResult:
-        """Run the team on a given task until the termination condition is met."""
-        ...
+    pass
@@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage):
     """The tool calls."""
 
 
-class ToolCallResultMessages(BaseMessage):
+class ToolCallResultMessage(BaseMessage):
     """A message signaling the results of tool calls."""
 
     content: List[FunctionExecutionResult]
     """The tool call results."""
 
 
-InnerMessage = ToolCallMessage | ToolCallResultMessages
+InnerMessage = ToolCallMessage | ToolCallResultMessage
 """Messages for intra-agent monologues."""
 
 
@@ -80,6 +80,6 @@ __all__ = [
     "HandoffMessage",
     "ResetMessage",
     "ToolCallMessage",
-    "ToolCallResultMessages",
+    "ToolCallResultMessage",
     "ChatMessage",
 ]
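For orientation, the renamed `ToolCallResultMessage` pairs with `ToolCallMessage` under the `InnerMessage` union, which is what `run_stream` consumers receive for intra-agent activity. A small sketch; the import paths for `FunctionCall` and `FunctionExecutionResult` are assumptions inferred from how those types are used elsewhere in this diff:

from autogen_agentchat.messages import InnerMessage, ToolCallMessage, ToolCallResultMessage
from autogen_core.components import FunctionCall
from autogen_core.components.models import FunctionExecutionResult

call = ToolCallMessage(
    content=[FunctionCall(id="1", name="get_weather", arguments='{"location": "Seattle"}')],
    source="assistant",
)
result = ToolCallResultMessage(
    content=[FunctionExecutionResult(content="Sunny, 72F", call_id="1")],
    source="assistant",
)

# Both variants are InnerMessage members, so stream consumers can filter on the union.
inner: list[InnerMessage] = [call, result]
for message in inner:
    print(type(message).__name__, message.content)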
@@ -1,6 +1,7 @@
+import asyncio
 import uuid
 from abc import ABC, abstractmethod
-from typing import Callable, List
+from typing import AsyncGenerator, Callable, List
 
 from autogen_core.application import SingleThreadedAgentRuntime
 from autogen_core.base import (
@@ -75,9 +76,24 @@ class BaseGroupChat(Team, ABC):
         cancellation_token: CancellationToken | None = None,
         termination_condition: TerminationCondition | None = None,
     ) -> TaskResult:
-        """Run the team and return the result."""
-        # Create intervention handler for termination.
-
+        """Run the team and return the result. The base implementation uses
+        :meth:`run_stream` to run the team and then returns the final result."""
+        async for message in self.run_stream(
+            task, cancellation_token=cancellation_token, termination_condition=termination_condition
+        ):
+            if isinstance(message, TaskResult):
+                return message
+        raise AssertionError("The stream should have returned the final result.")
+
+    async def run_stream(
+        self,
+        task: str,
+        *,
+        cancellation_token: CancellationToken | None = None,
+        termination_condition: TerminationCondition | None = None,
+    ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
+        """Run the team and produce a stream of messages and the final result
+        as the last item in the stream."""
         # Create the runtime.
         runtime = SingleThreadedAgentRuntime()
 
@@ -132,6 +148,7 @@ class BaseGroupChat(Team, ABC):
         )
 
         output_messages: List[InnerMessage | ChatMessage] = []
+        output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue()
 
         async def collect_output_messages(
             _runtime: AgentRuntime,
@@ -140,6 +157,7 @@ class BaseGroupChat(Team, ABC):
             ctx: MessageContext,
         ) -> None:
             output_messages.append(message)
+            await output_message_queue.put(message)
 
         await ClosureAgent.register(
             runtime,
@@ -158,14 +176,29 @@ class BaseGroupChat(Team, ABC):
         group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
         first_chat_message = TextMessage(content=task, source="user")
         output_messages.append(first_chat_message)
+        await output_message_queue.put(first_chat_message)
         await runtime.publish_message(
             GroupChatPublishEvent(agent_message=first_chat_message),
             topic_id=team_topic_id,
         )
         await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
 
-        # Wait for the runtime to stop.
-        await runtime.stop_when_idle()
+        # Start a coroutine to stop the runtime and signal the output message queue is complete.
+        async def stop_runtime() -> None:
+            await runtime.stop_when_idle()
+            await output_message_queue.put(None)
 
-        # Return the result.
-        return TaskResult(messages=output_messages)
+        shutdown_task = asyncio.create_task(stop_runtime())
+
+        # Yield the messages until the queue is empty.
+        while True:
+            message = await output_message_queue.get()
+            if message is None:
+                break
+            yield message
+
+        # Wait for the shutdown task to finish.
+        await shutdown_task
+
+        # Yield the final result.
+        yield TaskResult(messages=output_messages)
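The team-level run_stream above is built on a general asyncio pattern: the closure agent produces into a queue, a background task signals completion by enqueuing a None sentinel once the runtime is idle, and the generator yields until it sees the sentinel. Stripped of the group-chat machinery, a minimal sketch of the same pattern (all names here are illustrative):

import asyncio
from typing import AsyncGenerator


async def stream_queue(
    queue: "asyncio.Queue[str | None]", shutdown_task: "asyncio.Task[None]"
) -> AsyncGenerator[str, None]:
    # Yield items until the producer signals completion with the None sentinel.
    while True:
        item = await queue.get()
        if item is None:
            break
        yield item
    # Wait for the shutdown task, propagating any exception it raised.
    await shutdown_task


async def main() -> None:
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def produce() -> None:
        for word in ["tool call", "tool result", "final response"]:
            await queue.put(word)
        await queue.put(None)  # Sentinel: no more items.

    shutdown_task = asyncio.create_task(produce())
    async for item in stream_queue(queue, shutdown_task):
        print(item)


asyncio.run(main())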
@@ -3,7 +3,7 @@ from typing import Any, List
 from autogen_core.base import MessageContext
 from autogen_core.components import DefaultTopicId, event
 
-from ...base import ChatAgent
+from ...base import ChatAgent, Response
 from ...messages import ChatMessage
 from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
 from ._sequential_routed_agent import SequentialRoutedAgent
@@ -37,28 +37,26 @@ class ChatAgentContainer(SequentialRoutedAgent):
         """Handle a content request event by passing the messages in the buffer
         to the delegate agent and publish the response."""
-        # Pass the messages in the buffer to the delegate agent.
-        response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
-        if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
-            raise ValueError(
-                f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
-                f"Expected one of: {self._agent.produced_message_types}. "
-                f"Check the agent's produced_message_types property."
-            )
+        response: Response | None = None
+        async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
+            if isinstance(msg, Response):
+                await self.publish_message(
+                    msg.chat_message,
+                    topic_id=DefaultTopicId(type=self._output_topic_type),
+                )
+                response = msg
+            else:
+                # Publish the message to the output topic.
+                await self.publish_message(msg, topic_id=DefaultTopicId(type=self._output_topic_type))
+        if response is None:
+            raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")
 
-        # Publish inner messages to the output topic.
-        if response.inner_messages is not None:
-            for inner_message in response.inner_messages:
-                await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
-
-        # Publish the response.
+        # Publish the response to the group chat.
         self._message_buffer.clear()
         await self.publish_message(
             GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
             topic_id=DefaultTopicId(type=self._parent_topic_type),
         )
 
-        # Publish the response to the output topic.
-        await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
 
     async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
         raise ValueError(f"Unhandled message in agent container: {type(message)}")
@@ -61,24 +61,45 @@ class RoundRobinGroupChat(BaseGroupChat):
 
     .. code-block:: python
 
-        from autogen_agentchat.agents import ToolUseAssistantAgent
-        from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination
+        from autogen_ext.models import OpenAIChatCompletionClient
+        from autogen_agentchat.agents import AssistantAgent
+        from autogen_agentchat.teams import RoundRobinGroupChat
+        from autogen_agentchat.task import StopMessageTermination
 
-        assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...)
+        model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+
+        async def get_weather(location: str) -> str:
+            return f"The weather in {location} is sunny."
+
+
+        assistant = AssistantAgent(
+            "Assistant",
+            model_client=model_client,
+            tools=[get_weather],
+        )
         team = RoundRobinGroupChat([assistant])
-        await team.run("What's the weather in New York?", termination_condition=StopMessageTermination())
+        stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
+        async for message in stream:
+            print(message)
 
     A team with multiple participants:
 
     .. code-block:: python
 
-        from autogen_agentchat.agents import CodingAssistantAgent, CodeExecutorAgent
-        from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination
+        from autogen_ext.models import OpenAIChatCompletionClient
+        from autogen_agentchat.agents import AssistantAgent
+        from autogen_agentchat.teams import RoundRobinGroupChat
+        from autogen_agentchat.task import StopMessageTermination
 
-        coding_assistant = CodingAssistantAgent("Coding_Assistant", model_client=...)
-        executor_agent = CodeExecutorAgent("Code_Executor", code_executor=...)
-        team = RoundRobinGroupChat([coding_assistant, executor_agent])
-        await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination())
+        model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+        agent1 = AssistantAgent("Assistant1", model_client=model_client)
+        agent2 = AssistantAgent("Assistant2", model_client=model_client)
+        team = RoundRobinGroupChat([agent1, agent2])
+        stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
+        async for message in stream:
+            print(message)
 
     """
 
@@ -170,14 +170,48 @@ class SelectorGroupChat(BaseGroupChat):
 
     .. code-block:: python
 
-        from autogen_agentchat.agents import ToolUseAssistantAgent
-        from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination
+        from autogen_ext.models import OpenAIChatCompletionClient
+        from autogen_agentchat.agents import AssistantAgent
+        from autogen_agentchat.teams import SelectorGroupChat
+        from autogen_agentchat.task import StopMessageTermination
 
-        travel_advisor = ToolUseAssistantAgent("Travel_Advisor", model_client=..., registered_tools=...)
-        hotel_agent = ToolUseAssistantAgent("Hotel_Agent", model_client=..., registered_tools=...)
-        flight_agent = ToolUseAssistantAgent("Flight_Agent", model_client=..., registered_tools=...)
-        team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=...)
-        await team.run("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
+        model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+
+        async def lookup_hotel(location: str) -> str:
+            return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
+
+
+        async def lookup_flight(origin: str, destination: str) -> str:
+            return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
+
+
+        async def book_trip() -> str:
+            return "Your trip is booked!"
+
+
+        travel_advisor = AssistantAgent(
+            "Travel_Advisor",
+            model_client,
+            tools=[book_trip],
+            description="Helps with travel planning.",
+        )
+        hotel_agent = AssistantAgent(
+            "Hotel_Agent",
+            model_client,
+            tools=[lookup_hotel],
+            description="Helps with hotel booking.",
+        )
+        flight_agent = AssistantAgent(
+            "Flight_Agent",
+            model_client,
+            tools=[lookup_flight],
+            description="Helps with flight booking.",
+        )
+        team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
+        stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
+        async for message in stream:
+            print(message)
     """
 
     def __init__(
@@ -79,7 +79,10 @@ class Swarm(BaseGroupChat):
         )
 
         team = Swarm([agent1, agent2])
-        await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
+
+        stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
+        async for message in stream:
+            print(message)
     """
 
     def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):
@@ -6,9 +6,9 @@ from typing import Any, AsyncGenerator, List
 import pytest
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import AssistantAgent, Handoff
+from autogen_agentchat.base import TaskResult
 from autogen_agentchat.logging import FileLogHandler
-from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
-from autogen_core.base import CancellationToken
+from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
 from autogen_core.components.tools import FunctionTool
 from autogen_ext.models import OpenAIChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
@@ -114,9 +114,19 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(result.messages) == 4
     assert isinstance(result.messages[0], TextMessage)
     assert isinstance(result.messages[1], ToolCallMessage)
-    assert isinstance(result.messages[2], ToolCallResultMessages)
+    assert isinstance(result.messages[2], ToolCallResultMessage)
     assert isinstance(result.messages[3], TextMessage)
 
+    # Test streaming.
+    mock._curr_index = 0  # pyright: ignore
+    index = 0
+    async for message in tool_use_agent.run_stream("task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 @pytest.mark.asyncio
 async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -160,8 +170,19 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
         handoffs=[handoff],
     )
     assert HandoffMessage in tool_use_agent.produced_message_types
-    response = await tool_use_agent.on_messages(
-        [TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
-    )
-    assert isinstance(response.chat_message, HandoffMessage)
-    assert response.chat_message.target == "agent2"
+    result = await tool_use_agent.run("task")
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert isinstance(result.messages[1], ToolCallMessage)
+    assert isinstance(result.messages[2], ToolCallResultMessage)
+    assert isinstance(result.messages[3], HandoffMessage)
+
+    # Test streaming.
+    mock._curr_index = 0  # pyright: ignore
+    index = 0
+    async for message in tool_use_agent.run_stream("task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
@@ -12,7 +12,7 @@ from autogen_agentchat.agents import (
     CodeExecutorAgent,
     Handoff,
 )
-from autogen_agentchat.base import Response
+from autogen_agentchat.base import Response, TaskResult
 from autogen_agentchat.logging import FileLogHandler
 from autogen_agentchat.messages import (
     ChatMessage,
@@ -20,7 +20,7 @@ from autogen_agentchat.messages import (
     StopMessage,
     TextMessage,
     ToolCallMessage,
-    ToolCallResultMessages,
+    ToolCallResultMessage,
 )
 from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
 from autogen_agentchat.teams import (
@@ -59,6 +59,9 @@ class _MockChatCompletion:
         self._curr_index += 1
         return completion
 
+    def reset(self) -> None:
+        self._curr_index = 0
+
 
 class _EchoAgent(BaseChatAgent):
     def __init__(self, name: str, description: str) -> None:
@@ -147,7 +150,8 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
     result = await team.run(
-        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+        "Write a program that prints 'Hello, world!'",
+        termination_condition=StopMessageTermination(),
     )
     expected_messages = [
         "Write a program that prints 'Hello, world!'",
@@ -164,6 +168,18 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     # Assert that all expected messages are in the collected messages
     assert normalized_messages == expected_messages
 
+    # Test streaming.
+    mock.reset()
+    index = 0
+    async for message in team.run_stream(
+        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+    ):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 @pytest.mark.asyncio
 async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -230,13 +246,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     echo_agent = _EchoAgent("echo_agent", description="echo agent")
     team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
     result = await team.run(
-        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+        "Write a program that prints 'Hello, world!'",
+        termination_condition=StopMessageTermination(),
     )
 
     assert len(result.messages) == 6
     assert isinstance(result.messages[0], TextMessage)  # task
     assert isinstance(result.messages[1], ToolCallMessage)  # tool call
-    assert isinstance(result.messages[2], ToolCallResultMessages)  # tool call result
+    assert isinstance(result.messages[2], ToolCallResultMessage)  # tool call result
     assert isinstance(result.messages[3], TextMessage)  # tool use agent response
     assert isinstance(result.messages[4], TextMessage)  # echo agent response
     assert isinstance(result.messages[5], StopMessage)  # tool use agent response
@@ -253,6 +270,19 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert context[2].content[0].call_id == "1"
     assert context[3].content == "Hello"
 
+    # Test streaming.
+    tool_use_agent._model_context.clear()  # pyright: ignore
+    mock.reset()
+    index = 0
+    async for message in team.run_stream(
+        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+    ):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 @pytest.mark.asyncio
 async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -320,7 +350,8 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
     )
     result = await team.run(
-        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+        "Write a program that prints 'Hello, world!'",
+        termination_condition=StopMessageTermination(),
     )
     assert len(result.messages) == 6
     assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@@ -330,6 +361,19 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result.messages[4].source == "agent2"
     assert result.messages[5].source == "agent1"
 
+    # Test streaming.
+    mock.reset()
+    agent1._count = 0  # pyright: ignore
+    index = 0
+    async for message in team.run_stream(
+        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+    ):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 @pytest.mark.asyncio
 async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -356,7 +400,8 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
     )
     result = await team.run(
-        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+        "Write a program that prints 'Hello, world!'",
+        termination_condition=StopMessageTermination(),
    )
     assert len(result.messages) == 5
     assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@@ -367,6 +412,19 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
     # only one chat completion was called
     assert mock._curr_index == 1  # pyright: ignore
 
+    # Test streaming.
+    mock.reset()
+    agent1._count = 0  # pyright: ignore
+    index = 0
+    async for message in team.run_stream(
+        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+    ):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 @pytest.mark.asyncio
 async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -422,6 +480,18 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result.messages[2].source == "agent2"
     assert result.messages[3].source == "agent1"
 
+    # Test streaming.
+    mock.reset()
+    index = 0
+    async for message in team.run_stream(
+        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
+    ):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 class _HandOffAgent(BaseChatAgent):
     def __init__(self, name: str, description: str, next_agent: str) -> None:
@@ -446,8 +516,8 @@ async def test_swarm_handoff() -> None:
     second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
     third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
 
-    team = Swarm([second_agent, first_agent, third_agent])
-    result = await team.run("task", termination_condition=MaxMessageTermination(6))
+    team = Swarm([second_agent, first_agent, third_agent], termination_condition=MaxMessageTermination(6))
+    result = await team.run("task")
     assert len(result.messages) == 6
     assert result.messages[0].content == "task"
     assert result.messages[1].content == "Transferred to third_agent."
@@ -456,6 +526,16 @@ async def test_swarm_handoff() -> None:
     assert result.messages[4].content == "Transferred to third_agent."
    assert result.messages[5].content == "Transferred to first_agent."
 
+    # Test streaming.
+    index = 0
+    stream = team.run_stream("task", termination_condition=MaxMessageTermination(6))
+    async for message in stream:
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
 
 @pytest.mark.asyncio
 async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -514,19 +594,31 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
 
-    agnet1 = AssistantAgent(
+    agent1 = AssistantAgent(
         "agent1",
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
         handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
     )
     agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
-    team = Swarm([agnet1, agent2])
+    team = Swarm([agent1, agent2])
     result = await team.run("task", termination_condition=StopMessageTermination())
     assert len(result.messages) == 7
     assert result.messages[0].content == "task"
     assert isinstance(result.messages[1], ToolCallMessage)
-    assert isinstance(result.messages[2], ToolCallResultMessages)
+    assert isinstance(result.messages[2], ToolCallResultMessage)
     assert result.messages[3].content == "handoff to agent2"
     assert result.messages[4].content == "Transferred to agent1."
     assert result.messages[5].content == "Hello"
     assert result.messages[6].content == "TERMINATE"
+
+    # Test streaming.
+    agent1._model_context.clear()  # pyright: ignore
+    mock.reset()
+    index = 0
+    stream = team.run_stream("task", termination_condition=StopMessageTermination())
+    async for message in stream:
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1