AgentChat streaming API (#4015)

This commit is contained in:
Eric Zhu 2024-11-01 04:12:43 -07:00 committed by GitHub
parent 4023454c58
commit cff7d842a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 353 additions and 107 deletions

View File

@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, Dict, List, Sequence
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
@ -27,7 +27,7 @@ from ..messages import (
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseChatAgent
@ -98,7 +98,11 @@ class Handoff(BaseModel):
@property
def handoff_tool(self) -> Tool:
"""Create a handoff tool from this handoff configuration."""
return FunctionTool(lambda: self.message, name=self.name, description=self.description)
def _handoff_tool() -> str:
return self.message
return FunctionTool(_handoff_tool, name=self.name, description=self.description)
class AssistantAgent(BaseChatAgent):
@ -138,7 +142,7 @@ class AssistantAgent(BaseChatAgent):
The following example demonstrates how to create an assistant agent with
a model client and a tool, and generate a response to a simple task using the tool.
a model client and a tool, and generate a stream of messages for a task.
.. code-block:: python
@ -154,7 +158,11 @@ class AssistantAgent(BaseChatAgent):
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
"""
@ -219,6 +227,14 @@ class AssistantAgent(BaseChatAgent):
return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ResetMessage):
@ -243,6 +259,7 @@ class AssistantAgent(BaseChatAgent):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
yield ToolCallMessage(content=result.content, source=self.name)
# Execute the tool calls.
results = await asyncio.gather(
@ -250,7 +267,8 @@ class AssistantAgent(BaseChatAgent):
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
yield ToolCallResultMessage(content=results, source=self.name)
# Detect handoff requests.
handoffs: List[Handoff] = []
@ -261,12 +279,13 @@ class AssistantAgent(BaseChatAgent):
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff.
return Response(
yield Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
return
# Generate an inference result based on the current model context.
result = await self._model_client.create(
@ -278,11 +297,11 @@ class AssistantAgent(BaseChatAgent):
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
return Response(
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
return Response(
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)

View File

@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Sequence
from typing import AsyncGenerator, List, Sequence
from autogen_core.base import CancellationToken
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
from ..base import ChatAgent, Response, TaskResult
from ..messages import ChatMessage, InnerMessage, TextMessage
@ -40,12 +40,22 @@ class BaseChatAgent(ChatAgent, ABC):
"""Handles incoming messages and returns a response."""
...
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in :class:`BaseChatAgent`
simply calls :meth:`on_messages` and yields the messages in the response."""
response = await self.on_messages(messages, cancellation_token)
for inner_message in response.inner_messages or []:
yield inner_message
yield response
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
if cancellation_token is None:
@ -57,3 +67,25 @@ class BaseChatAgent(ChatAgent, ABC):
messages += response.inner_messages
messages.append(response.chat_message)
return TaskResult(messages=messages)
async def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
and the final task result as the last item in the stream."""
if cancellation_token is None:
cancellation_token = CancellationToken()
first_message = TextMessage(content=task, source="user")
yield first_message
messages: List[InnerMessage | ChatMessage] = [first_message]
async for message in self.on_messages_stream([first_message], cancellation_token):
if isinstance(message, Response):
yield message.chat_message
messages.append(message.chat_message)
yield TaskResult(messages=messages)
else:
messages.append(message)
yield message

View File

@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable
from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
from autogen_core.base import CancellationToken
from ..messages import ChatMessage, InnerMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
from ._task import TaskRunner
@dataclass(kw_only=True)
@ -45,12 +44,9 @@ class ChatAgent(TaskRunner, Protocol):
"""Handles incoming messages and returns a response."""
...
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
"""Handles incoming messages and returns a stream of inner messages and
and the final item is the response."""
...

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass
from typing import Protocol, Sequence
from typing import AsyncGenerator, Protocol, Sequence
from autogen_core.base import CancellationToken
from ..messages import ChatMessage, InnerMessage
from ._termination import TerminationCondition
@dataclass
@ -23,7 +22,16 @@ class TaskRunner(Protocol):
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the task."""
...
def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
as the last item in the stream."""
...

View File

@ -1,18 +1,7 @@
from typing import Protocol
from autogen_core.base import CancellationToken
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
from ._task import TaskRunner
class Team(TaskRunner, Protocol):
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team on a given task until the termination condition is met."""
...
pass

View File

@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage):
"""The tool calls."""
class ToolCallResultMessages(BaseMessage):
class ToolCallResultMessage(BaseMessage):
"""A message signaling the results of tool calls."""
content: List[FunctionExecutionResult]
"""The tool call results."""
InnerMessage = ToolCallMessage | ToolCallResultMessages
InnerMessage = ToolCallMessage | ToolCallResultMessage
"""Messages for intra-agent monologues."""
@ -80,6 +80,6 @@ __all__ = [
"HandoffMessage",
"ResetMessage",
"ToolCallMessage",
"ToolCallResultMessages",
"ToolCallResultMessage",
"ChatMessage",
]

View File

@ -1,6 +1,7 @@
import asyncio
import uuid
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import AsyncGenerator, Callable, List
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import (
@ -75,9 +76,24 @@ class BaseGroupChat(Team, ABC):
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team and return the result."""
# Create intervention handler for termination.
"""Run the team and return the result. The base implementation uses
:meth:`run_stream` to run the team and then returns the final result."""
async for message in self.run_stream(
task, cancellation_token=cancellation_token, termination_condition=termination_condition
):
if isinstance(message, TaskResult):
return message
raise AssertionError("The stream should have returned the final result.")
async def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
as the last item in the stream."""
# Create the runtime.
runtime = SingleThreadedAgentRuntime()
@ -132,6 +148,7 @@ class BaseGroupChat(Team, ABC):
)
output_messages: List[InnerMessage | ChatMessage] = []
output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue()
async def collect_output_messages(
_runtime: AgentRuntime,
@ -140,6 +157,7 @@ class BaseGroupChat(Team, ABC):
ctx: MessageContext,
) -> None:
output_messages.append(message)
await output_message_queue.put(message)
await ClosureAgent.register(
runtime,
@ -158,14 +176,29 @@ class BaseGroupChat(Team, ABC):
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
first_chat_message = TextMessage(content=task, source="user")
output_messages.append(first_chat_message)
await output_message_queue.put(first_chat_message)
await runtime.publish_message(
GroupChatPublishEvent(agent_message=first_chat_message),
topic_id=team_topic_id,
)
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
# Wait for the runtime to stop.
# Start a coroutine to stop the runtime and signal the output message queue is complete.
async def stop_runtime() -> None:
await runtime.stop_when_idle()
await output_message_queue.put(None)
# Return the result.
return TaskResult(messages=output_messages)
shutdown_task = asyncio.create_task(stop_runtime())
# Yield the messsages until the queue is empty.
while True:
message = await output_message_queue.get()
if message is None:
break
yield message
# Wait for the shutdown task to finish.
await shutdown_task
# Yield the final result.
yield TaskResult(messages=output_messages)

View File

@ -3,7 +3,7 @@ from typing import Any, List
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, event
from ...base import ChatAgent
from ...base import ChatAgent, Response
from ...messages import ChatMessage
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._sequential_routed_agent import SequentialRoutedAgent
@ -37,28 +37,26 @@ class ChatAgentContainer(SequentialRoutedAgent):
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent.
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
raise ValueError(
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
f"Expected one of: {self._agent.produced_message_types}. "
f"Check the agent's produced_message_types property."
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self.publish_message(
msg.chat_message,
topic_id=DefaultTopicId(type=self._output_topic_type),
)
response = msg
else:
# Publish the message to the output topic.
await self.publish_message(msg, topic_id=DefaultTopicId(type=self._output_topic_type))
if response is None:
raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")
# Publish inner messages to the output topic.
if response.inner_messages is not None:
for inner_message in response.inner_messages:
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Publish the response.
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)
# Publish the response to the output topic.
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")

View File

@ -61,24 +61,45 @@ class RoundRobinGroupChat(BaseGroupChat):
.. code-block:: python
from autogen_agentchat.agents import ToolUseAssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...)
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant])
await team.run("What's the weather in New York?", termination_condition=StopMessageTermination())
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
A team with multiple participants:
.. code-block:: python
from autogen_agentchat.agents import CodingAssistantAgent, CodeExecutorAgent
from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
coding_assistant = CodingAssistantAgent("Coding_Assistant", model_client=...)
executor_agent = CodeExecutorAgent("Code_Executor", code_executor=...)
team = RoundRobinGroupChat([coding_assistant, executor_agent])
await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination())
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
"""

View File

@ -170,14 +170,48 @@ class SelectorGroupChat(BaseGroupChat):
.. code-block:: python
from autogen_agentchat.agents import ToolUseAssistantAgent
from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import StopMessageTermination
travel_advisor = ToolUseAssistantAgent("Travel_Advisor", model_client=..., registered_tools=...)
hotel_agent = ToolUseAssistantAgent("Hotel_Agent", model_client=..., registered_tools=...)
flight_agent = ToolUseAssistantAgent("Flight_Agent", model_client=..., registered_tools=...)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=...)
await team.run("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def lookup_hotel(location: str) -> str:
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
async def book_trip() -> str:
return "Your trip is booked!"
travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
"""
def __init__(

View File

@ -79,7 +79,10 @@ class Swarm(BaseGroupChat):
)
team = Swarm([agent1, agent2])
await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
"""
def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):

View File

@ -6,9 +6,9 @@ from typing import Any, AsyncGenerator, List
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.base import TaskResult
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
from autogen_core.base import CancellationToken
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
@ -114,9 +114,19 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages)
assert isinstance(result.messages[2], ToolCallResultMessage)
assert isinstance(result.messages[3], TextMessage)
# Test streaming.
mock._curr_index = 0 # pyright: ignore
index = 0
async for message in tool_use_agent.run_stream("task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
@ -160,8 +170,19 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
handoffs=[handoff],
)
assert HandoffMessage in tool_use_agent.produced_message_types
response = await tool_use_agent.on_messages(
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
)
assert isinstance(response.chat_message, HandoffMessage)
assert response.chat_message.target == "agent2"
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessage)
assert isinstance(result.messages[3], HandoffMessage)
# Test streaming.
mock._curr_index = 0 # pyright: ignore
index = 0
async for message in tool_use_agent.run_stream("task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

View File

@ -12,7 +12,7 @@ from autogen_agentchat.agents import (
CodeExecutorAgent,
Handoff,
)
from autogen_agentchat.base import Response
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import (
ChatMessage,
@ -20,7 +20,7 @@ from autogen_agentchat.messages import (
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
ToolCallResultMessage,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.teams import (
@ -59,6 +59,9 @@ class _MockChatCompletion:
self._curr_index += 1
return completion
def reset(self) -> None:
self._curr_index = 0
class _EchoAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None:
@ -147,7 +150,8 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
)
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
)
expected_messages = [
"Write a program that prints 'Hello, world!'",
@ -164,6 +168,18 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
# Assert that all expected messages are in the collected messages
assert normalized_messages == expected_messages
# Test streaming.
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
@ -230,13 +246,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
echo_agent = _EchoAgent("echo_agent", description="echo agent")
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
)
assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
@ -253,6 +270,19 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert context[2].content[0].call_id == "1"
assert context[3].content == "Hello"
# Test streaming.
tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
@ -320,7 +350,8 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
)
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
)
assert len(result.messages) == 6
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -330,6 +361,19 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[4].source == "agent2"
assert result.messages[5].source == "agent1"
# Test streaming.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
@ -356,7 +400,8 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
)
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
)
assert len(result.messages) == 5
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
@ -367,6 +412,19 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
# only one chat completion was called
assert mock._curr_index == 1 # pyright: ignore
# Test streaming.
mock.reset()
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
@ -422,6 +480,18 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
assert result.messages[2].source == "agent2"
assert result.messages[3].source == "agent1"
# Test streaming.
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
class _HandOffAgent(BaseChatAgent):
def __init__(self, name: str, description: str, next_agent: str) -> None:
@ -446,8 +516,8 @@ async def test_swarm_handoff() -> None:
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
team = Swarm([second_agent, first_agent, third_agent])
result = await team.run("task", termination_condition=MaxMessageTermination(6))
team = Swarm([second_agent, first_agent, third_agent], termination_condition=MaxMessageTermination(6))
result = await team.run("task")
assert len(result.messages) == 6
assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent."
@ -456,6 +526,16 @@ async def test_swarm_handoff() -> None:
assert result.messages[4].content == "Transferred to third_agent."
assert result.messages[5].content == "Transferred to first_agent."
# Test streaming.
index = 0
stream = team.run_stream("task", termination_condition=MaxMessageTermination(6))
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
@ -514,19 +594,31 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agnet1 = AssistantAgent(
agent1 = AssistantAgent(
"agent1",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
)
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agnet1, agent2])
team = Swarm([agent1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages)
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[3].content == "handoff to agent2"
assert result.messages[4].content == "Transferred to agent1."
assert result.messages[5].content == "Hello"
assert result.messages[6].content == "TERMINATE"
# Test streaming.
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
stream = team.run_stream("task", termination_condition=StopMessageTermination())
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1