mirror of https://github.com/microsoft/autogen.git
AgentChat streaming API (#4015)
This commit is contained in:
parent
4023454c58
commit
cff7d842a6
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Sequence
|
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
|
||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components import FunctionCall
|
from autogen_core.components import FunctionCall
|
||||||
|
@ -27,7 +27,7 @@ from ..messages import (
|
||||||
StopMessage,
|
StopMessage,
|
||||||
TextMessage,
|
TextMessage,
|
||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolCallResultMessages,
|
ToolCallResultMessage,
|
||||||
)
|
)
|
||||||
from ._base_chat_agent import BaseChatAgent
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
|
@ -98,7 +98,11 @@ class Handoff(BaseModel):
|
||||||
@property
|
@property
|
||||||
def handoff_tool(self) -> Tool:
|
def handoff_tool(self) -> Tool:
|
||||||
"""Create a handoff tool from this handoff configuration."""
|
"""Create a handoff tool from this handoff configuration."""
|
||||||
return FunctionTool(lambda: self.message, name=self.name, description=self.description)
|
|
||||||
|
def _handoff_tool() -> str:
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
return FunctionTool(_handoff_tool, name=self.name, description=self.description)
|
||||||
|
|
||||||
|
|
||||||
class AssistantAgent(BaseChatAgent):
|
class AssistantAgent(BaseChatAgent):
|
||||||
|
@ -138,7 +142,7 @@ class AssistantAgent(BaseChatAgent):
|
||||||
|
|
||||||
|
|
||||||
The following example demonstrates how to create an assistant agent with
|
The following example demonstrates how to create an assistant agent with
|
||||||
a model client and a tool, and generate a response to a simple task using the tool.
|
a model client and a tool, and generate a stream of messages for a task.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -154,7 +158,11 @@ class AssistantAgent(BaseChatAgent):
|
||||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||||
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
|
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
|
||||||
|
|
||||||
await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
|
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
|
||||||
|
|
||||||
|
async for message in stream:
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -219,6 +227,14 @@ class AssistantAgent(BaseChatAgent):
|
||||||
return [TextMessage, StopMessage]
|
return [TextMessage, StopMessage]
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
|
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||||
|
if isinstance(message, Response):
|
||||||
|
return message
|
||||||
|
raise AssertionError("The stream should have returned the final result.")
|
||||||
|
|
||||||
|
async def on_messages_stream(
|
||||||
|
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||||
|
) -> AsyncGenerator[InnerMessage | Response, None]:
|
||||||
# Add messages to the model context.
|
# Add messages to the model context.
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ResetMessage):
|
if isinstance(msg, ResetMessage):
|
||||||
|
@ -243,6 +259,7 @@ class AssistantAgent(BaseChatAgent):
|
||||||
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
||||||
# Add the tool call message to the output.
|
# Add the tool call message to the output.
|
||||||
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
|
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
|
||||||
|
yield ToolCallMessage(content=result.content, source=self.name)
|
||||||
|
|
||||||
# Execute the tool calls.
|
# Execute the tool calls.
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
|
@ -250,7 +267,8 @@ class AssistantAgent(BaseChatAgent):
|
||||||
)
|
)
|
||||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
||||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||||
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
|
inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
|
||||||
|
yield ToolCallResultMessage(content=results, source=self.name)
|
||||||
|
|
||||||
# Detect handoff requests.
|
# Detect handoff requests.
|
||||||
handoffs: List[Handoff] = []
|
handoffs: List[Handoff] = []
|
||||||
|
@ -261,12 +279,13 @@ class AssistantAgent(BaseChatAgent):
|
||||||
if len(handoffs) > 1:
|
if len(handoffs) > 1:
|
||||||
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
||||||
# Return the output messages to signal the handoff.
|
# Return the output messages to signal the handoff.
|
||||||
return Response(
|
yield Response(
|
||||||
chat_message=HandoffMessage(
|
chat_message=HandoffMessage(
|
||||||
content=handoffs[0].message, target=handoffs[0].target, source=self.name
|
content=handoffs[0].message, target=handoffs[0].target, source=self.name
|
||||||
),
|
),
|
||||||
inner_messages=inner_messages,
|
inner_messages=inner_messages,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Generate an inference result based on the current model context.
|
# Generate an inference result based on the current model context.
|
||||||
result = await self._model_client.create(
|
result = await self._model_client.create(
|
||||||
|
@ -278,13 +297,13 @@ class AssistantAgent(BaseChatAgent):
|
||||||
# Detect stop request.
|
# Detect stop request.
|
||||||
request_stop = "terminate" in result.content.strip().lower()
|
request_stop = "terminate" in result.content.strip().lower()
|
||||||
if request_stop:
|
if request_stop:
|
||||||
return Response(
|
yield Response(
|
||||||
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
return Response(
|
yield Response(
|
||||||
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_tool_call(
|
async def _execute_tool_call(
|
||||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Sequence
|
from typing import AsyncGenerator, List, Sequence
|
||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
|
from ..base import ChatAgent, Response, TaskResult
|
||||||
from ..messages import ChatMessage, InnerMessage, TextMessage
|
from ..messages import ChatMessage, InnerMessage, TextMessage
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,12 +40,22 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||||
"""Handles incoming messages and returns a response."""
|
"""Handles incoming messages and returns a response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def on_messages_stream(
|
||||||
|
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||||
|
) -> AsyncGenerator[InnerMessage | Response, None]:
|
||||||
|
"""Handles incoming messages and returns a stream of messages and
|
||||||
|
and the final item is the response. The base implementation in :class:`BaseChatAgent`
|
||||||
|
simply calls :meth:`on_messages` and yields the messages in the response."""
|
||||||
|
response = await self.on_messages(messages, cancellation_token)
|
||||||
|
for inner_message in response.inner_messages or []:
|
||||||
|
yield inner_message
|
||||||
|
yield response
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
*,
|
*,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
termination_condition: TerminationCondition | None = None,
|
|
||||||
) -> TaskResult:
|
) -> TaskResult:
|
||||||
"""Run the agent with the given task and return the result."""
|
"""Run the agent with the given task and return the result."""
|
||||||
if cancellation_token is None:
|
if cancellation_token is None:
|
||||||
|
@ -57,3 +67,25 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||||
messages += response.inner_messages
|
messages += response.inner_messages
|
||||||
messages.append(response.chat_message)
|
messages.append(response.chat_message)
|
||||||
return TaskResult(messages=messages)
|
return TaskResult(messages=messages)
|
||||||
|
|
||||||
|
async def run_stream(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
|
||||||
|
"""Run the agent with the given task and return a stream of messages
|
||||||
|
and the final task result as the last item in the stream."""
|
||||||
|
if cancellation_token is None:
|
||||||
|
cancellation_token = CancellationToken()
|
||||||
|
first_message = TextMessage(content=task, source="user")
|
||||||
|
yield first_message
|
||||||
|
messages: List[InnerMessage | ChatMessage] = [first_message]
|
||||||
|
async for message in self.on_messages_stream([first_message], cancellation_token):
|
||||||
|
if isinstance(message, Response):
|
||||||
|
yield message.chat_message
|
||||||
|
messages.append(message.chat_message)
|
||||||
|
yield TaskResult(messages=messages)
|
||||||
|
else:
|
||||||
|
messages.append(message)
|
||||||
|
yield message
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Protocol, Sequence, runtime_checkable
|
from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
|
||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..messages import ChatMessage, InnerMessage
|
from ..messages import ChatMessage, InnerMessage
|
||||||
from ._task import TaskResult, TaskRunner
|
from ._task import TaskRunner
|
||||||
from ._termination import TerminationCondition
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
|
@ -45,12 +44,9 @@ class ChatAgent(TaskRunner, Protocol):
|
||||||
"""Handles incoming messages and returns a response."""
|
"""Handles incoming messages and returns a response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def run(
|
def on_messages_stream(
|
||||||
self,
|
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||||
task: str,
|
) -> AsyncGenerator[InnerMessage | Response, None]:
|
||||||
*,
|
"""Handles incoming messages and returns a stream of inner messages and
|
||||||
cancellation_token: CancellationToken | None = None,
|
and the final item is the response."""
|
||||||
termination_condition: TerminationCondition | None = None,
|
|
||||||
) -> TaskResult:
|
|
||||||
"""Run the agent with the given task and return the result."""
|
|
||||||
...
|
...
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Protocol, Sequence
|
from typing import AsyncGenerator, Protocol, Sequence
|
||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..messages import ChatMessage, InnerMessage
|
from ..messages import ChatMessage, InnerMessage
|
||||||
from ._termination import TerminationCondition
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -23,7 +22,16 @@ class TaskRunner(Protocol):
|
||||||
task: str,
|
task: str,
|
||||||
*,
|
*,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
termination_condition: TerminationCondition | None = None,
|
|
||||||
) -> TaskResult:
|
) -> TaskResult:
|
||||||
"""Run the task."""
|
"""Run the task."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def run_stream(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
|
||||||
|
"""Run the task and produces a stream of messages and the final result
|
||||||
|
as the last item in the stream."""
|
||||||
|
...
|
||||||
|
|
|
@ -1,18 +1,7 @@
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from ._task import TaskRunner
|
||||||
|
|
||||||
from ._task import TaskResult, TaskRunner
|
|
||||||
from ._termination import TerminationCondition
|
|
||||||
|
|
||||||
|
|
||||||
class Team(TaskRunner, Protocol):
|
class Team(TaskRunner, Protocol):
|
||||||
async def run(
|
pass
|
||||||
self,
|
|
||||||
task: str,
|
|
||||||
*,
|
|
||||||
cancellation_token: CancellationToken | None = None,
|
|
||||||
termination_condition: TerminationCondition | None = None,
|
|
||||||
) -> TaskResult:
|
|
||||||
"""Run the team on a given task until the termination condition is met."""
|
|
||||||
...
|
|
||||||
|
|
|
@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage):
|
||||||
"""The tool calls."""
|
"""The tool calls."""
|
||||||
|
|
||||||
|
|
||||||
class ToolCallResultMessages(BaseMessage):
|
class ToolCallResultMessage(BaseMessage):
|
||||||
"""A message signaling the results of tool calls."""
|
"""A message signaling the results of tool calls."""
|
||||||
|
|
||||||
content: List[FunctionExecutionResult]
|
content: List[FunctionExecutionResult]
|
||||||
"""The tool call results."""
|
"""The tool call results."""
|
||||||
|
|
||||||
|
|
||||||
InnerMessage = ToolCallMessage | ToolCallResultMessages
|
InnerMessage = ToolCallMessage | ToolCallResultMessage
|
||||||
"""Messages for intra-agent monologues."""
|
"""Messages for intra-agent monologues."""
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,6 +80,6 @@ __all__ = [
|
||||||
"HandoffMessage",
|
"HandoffMessage",
|
||||||
"ResetMessage",
|
"ResetMessage",
|
||||||
"ToolCallMessage",
|
"ToolCallMessage",
|
||||||
"ToolCallResultMessages",
|
"ToolCallResultMessage",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, List
|
from typing import AsyncGenerator, Callable, List
|
||||||
|
|
||||||
from autogen_core.application import SingleThreadedAgentRuntime
|
from autogen_core.application import SingleThreadedAgentRuntime
|
||||||
from autogen_core.base import (
|
from autogen_core.base import (
|
||||||
|
@ -75,9 +76,24 @@ class BaseGroupChat(Team, ABC):
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
termination_condition: TerminationCondition | None = None,
|
termination_condition: TerminationCondition | None = None,
|
||||||
) -> TaskResult:
|
) -> TaskResult:
|
||||||
"""Run the team and return the result."""
|
"""Run the team and return the result. The base implementation uses
|
||||||
# Create intervention handler for termination.
|
:meth:`run_stream` to run the team and then returns the final result."""
|
||||||
|
async for message in self.run_stream(
|
||||||
|
task, cancellation_token=cancellation_token, termination_condition=termination_condition
|
||||||
|
):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
return message
|
||||||
|
raise AssertionError("The stream should have returned the final result.")
|
||||||
|
|
||||||
|
async def run_stream(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
termination_condition: TerminationCondition | None = None,
|
||||||
|
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
|
||||||
|
"""Run the team and produces a stream of messages and the final result
|
||||||
|
as the last item in the stream."""
|
||||||
# Create the runtime.
|
# Create the runtime.
|
||||||
runtime = SingleThreadedAgentRuntime()
|
runtime = SingleThreadedAgentRuntime()
|
||||||
|
|
||||||
|
@ -132,6 +148,7 @@ class BaseGroupChat(Team, ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
output_messages: List[InnerMessage | ChatMessage] = []
|
output_messages: List[InnerMessage | ChatMessage] = []
|
||||||
|
output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue()
|
||||||
|
|
||||||
async def collect_output_messages(
|
async def collect_output_messages(
|
||||||
_runtime: AgentRuntime,
|
_runtime: AgentRuntime,
|
||||||
|
@ -140,6 +157,7 @@ class BaseGroupChat(Team, ABC):
|
||||||
ctx: MessageContext,
|
ctx: MessageContext,
|
||||||
) -> None:
|
) -> None:
|
||||||
output_messages.append(message)
|
output_messages.append(message)
|
||||||
|
await output_message_queue.put(message)
|
||||||
|
|
||||||
await ClosureAgent.register(
|
await ClosureAgent.register(
|
||||||
runtime,
|
runtime,
|
||||||
|
@ -158,14 +176,29 @@ class BaseGroupChat(Team, ABC):
|
||||||
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
|
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
|
||||||
first_chat_message = TextMessage(content=task, source="user")
|
first_chat_message = TextMessage(content=task, source="user")
|
||||||
output_messages.append(first_chat_message)
|
output_messages.append(first_chat_message)
|
||||||
|
await output_message_queue.put(first_chat_message)
|
||||||
await runtime.publish_message(
|
await runtime.publish_message(
|
||||||
GroupChatPublishEvent(agent_message=first_chat_message),
|
GroupChatPublishEvent(agent_message=first_chat_message),
|
||||||
topic_id=team_topic_id,
|
topic_id=team_topic_id,
|
||||||
)
|
)
|
||||||
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
|
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
|
||||||
|
|
||||||
# Wait for the runtime to stop.
|
# Start a coroutine to stop the runtime and signal the output message queue is complete.
|
||||||
await runtime.stop_when_idle()
|
async def stop_runtime() -> None:
|
||||||
|
await runtime.stop_when_idle()
|
||||||
|
await output_message_queue.put(None)
|
||||||
|
|
||||||
# Return the result.
|
shutdown_task = asyncio.create_task(stop_runtime())
|
||||||
return TaskResult(messages=output_messages)
|
|
||||||
|
# Yield the messsages until the queue is empty.
|
||||||
|
while True:
|
||||||
|
message = await output_message_queue.get()
|
||||||
|
if message is None:
|
||||||
|
break
|
||||||
|
yield message
|
||||||
|
|
||||||
|
# Wait for the shutdown task to finish.
|
||||||
|
await shutdown_task
|
||||||
|
|
||||||
|
# Yield the final result.
|
||||||
|
yield TaskResult(messages=output_messages)
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any, List
|
||||||
from autogen_core.base import MessageContext
|
from autogen_core.base import MessageContext
|
||||||
from autogen_core.components import DefaultTopicId, event
|
from autogen_core.components import DefaultTopicId, event
|
||||||
|
|
||||||
from ...base import ChatAgent
|
from ...base import ChatAgent, Response
|
||||||
from ...messages import ChatMessage
|
from ...messages import ChatMessage
|
||||||
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
||||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||||
|
@ -37,28 +37,26 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
||||||
"""Handle a content request event by passing the messages in the buffer
|
"""Handle a content request event by passing the messages in the buffer
|
||||||
to the delegate agent and publish the response."""
|
to the delegate agent and publish the response."""
|
||||||
# Pass the messages in the buffer to the delegate agent.
|
# Pass the messages in the buffer to the delegate agent.
|
||||||
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
response: Response | None = None
|
||||||
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
|
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
|
||||||
raise ValueError(
|
if isinstance(msg, Response):
|
||||||
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
|
await self.publish_message(
|
||||||
f"Expected one of: {self._agent.produced_message_types}. "
|
msg.chat_message,
|
||||||
f"Check the agent's produced_message_types property."
|
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||||
)
|
)
|
||||||
|
response = msg
|
||||||
|
else:
|
||||||
|
# Publish the message to the output topic.
|
||||||
|
await self.publish_message(msg, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||||
|
if response is None:
|
||||||
|
raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")
|
||||||
|
|
||||||
# Publish inner messages to the output topic.
|
# Publish the response to the group chat.
|
||||||
if response.inner_messages is not None:
|
|
||||||
for inner_message in response.inner_messages:
|
|
||||||
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
|
||||||
|
|
||||||
# Publish the response.
|
|
||||||
self._message_buffer.clear()
|
self._message_buffer.clear()
|
||||||
await self.publish_message(
|
await self.publish_message(
|
||||||
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
|
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
|
||||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Publish the response to the output topic.
|
|
||||||
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
|
||||||
|
|
||||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||||
|
|
|
@ -61,24 +61,45 @@ class RoundRobinGroupChat(BaseGroupChat):
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from autogen_agentchat.agents import ToolUseAssistantAgent
|
from autogen_ext.models import OpenAIChatCompletionClient
|
||||||
from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination
|
from autogen_agentchat.agents import AssistantAgent
|
||||||
|
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||||
|
from autogen_agentchat.task import StopMessageTermination
|
||||||
|
|
||||||
assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...)
|
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_weather(location: str) -> str:
|
||||||
|
return f"The weather in {location} is sunny."
|
||||||
|
|
||||||
|
|
||||||
|
assistant = AssistantAgent(
|
||||||
|
"Assistant",
|
||||||
|
model_client=model_client,
|
||||||
|
tools=[get_weather],
|
||||||
|
)
|
||||||
team = RoundRobinGroupChat([assistant])
|
team = RoundRobinGroupChat([assistant])
|
||||||
await team.run("What's the weather in New York?", termination_condition=StopMessageTermination())
|
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
|
||||||
|
async for message in stream:
|
||||||
|
print(message)
|
||||||
|
|
||||||
A team with multiple participants:
|
A team with multiple participants:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from autogen_agentchat.agents import CodingAssistantAgent, CodeExecutorAgent
|
from autogen_ext.models import OpenAIChatCompletionClient
|
||||||
from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination
|
from autogen_agentchat.agents import AssistantAgent
|
||||||
|
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||||
|
from autogen_agentchat.task import StopMessageTermination
|
||||||
|
|
||||||
coding_assistant = CodingAssistantAgent("Coding_Assistant", model_client=...)
|
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||||
executor_agent = CodeExecutorAgent("Code_Executor", code_executor=...)
|
|
||||||
team = RoundRobinGroupChat([coding_assistant, executor_agent])
|
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||||
await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination())
|
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||||
|
team = RoundRobinGroupChat([agent1, agent2])
|
||||||
|
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
|
||||||
|
async for message in stream:
|
||||||
|
print(message)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -170,14 +170,48 @@ class SelectorGroupChat(BaseGroupChat):
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from autogen_agentchat.agents import ToolUseAssistantAgent
|
from autogen_ext.models import OpenAIChatCompletionClient
|
||||||
from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination
|
from autogen_agentchat.agents import AssistantAgent
|
||||||
|
from autogen_agentchat.teams import SelectorGroupChat
|
||||||
|
from autogen_agentchat.task import StopMessageTermination
|
||||||
|
|
||||||
travel_advisor = ToolUseAssistantAgent("Travel_Advisor", model_client=..., registered_tools=...)
|
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||||
hotel_agent = ToolUseAssistantAgent("Hotel_Agent", model_client=..., registered_tools=...)
|
|
||||||
flight_agent = ToolUseAssistantAgent("Flight_Agent", model_client=..., registered_tools=...)
|
|
||||||
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=...)
|
async def lookup_hotel(location: str) -> str:
|
||||||
await team.run("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
|
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||||
|
|
||||||
|
|
||||||
|
async def lookup_flight(origin: str, destination: str) -> str:
|
||||||
|
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||||
|
|
||||||
|
|
||||||
|
async def book_trip() -> str:
|
||||||
|
return "Your trip is booked!"
|
||||||
|
|
||||||
|
|
||||||
|
travel_advisor = AssistantAgent(
|
||||||
|
"Travel_Advisor",
|
||||||
|
model_client,
|
||||||
|
tools=[book_trip],
|
||||||
|
description="Helps with travel planning.",
|
||||||
|
)
|
||||||
|
hotel_agent = AssistantAgent(
|
||||||
|
"Hotel_Agent",
|
||||||
|
model_client,
|
||||||
|
tools=[lookup_hotel],
|
||||||
|
description="Helps with hotel booking.",
|
||||||
|
)
|
||||||
|
flight_agent = AssistantAgent(
|
||||||
|
"Flight_Agent",
|
||||||
|
model_client,
|
||||||
|
tools=[lookup_flight],
|
||||||
|
description="Helps with flight booking.",
|
||||||
|
)
|
||||||
|
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
|
||||||
|
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
|
||||||
|
async for message in stream:
|
||||||
|
print(message)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -79,7 +79,10 @@ class Swarm(BaseGroupChat):
|
||||||
)
|
)
|
||||||
|
|
||||||
team = Swarm([agent1, agent2])
|
team = Swarm([agent1, agent2])
|
||||||
await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
|
|
||||||
|
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
|
||||||
|
async for message in stream:
|
||||||
|
print(message)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):
|
def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):
|
||||||
|
|
|
@ -6,9 +6,9 @@ from typing import Any, AsyncGenerator, List
|
||||||
import pytest
|
import pytest
|
||||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||||
from autogen_agentchat.agents import AssistantAgent, Handoff
|
from autogen_agentchat.agents import AssistantAgent, Handoff
|
||||||
|
from autogen_agentchat.base import TaskResult
|
||||||
from autogen_agentchat.logging import FileLogHandler
|
from autogen_agentchat.logging import FileLogHandler
|
||||||
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
|
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||||
from autogen_core.base import CancellationToken
|
|
||||||
from autogen_core.components.tools import FunctionTool
|
from autogen_core.components.tools import FunctionTool
|
||||||
from autogen_ext.models import OpenAIChatCompletionClient
|
from autogen_ext.models import OpenAIChatCompletionClient
|
||||||
from openai.resources.chat.completions import AsyncCompletions
|
from openai.resources.chat.completions import AsyncCompletions
|
||||||
|
@ -114,9 +114,19 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
assert len(result.messages) == 4
|
assert len(result.messages) == 4
|
||||||
assert isinstance(result.messages[0], TextMessage)
|
assert isinstance(result.messages[0], TextMessage)
|
||||||
assert isinstance(result.messages[1], ToolCallMessage)
|
assert isinstance(result.messages[1], ToolCallMessage)
|
||||||
assert isinstance(result.messages[2], ToolCallResultMessages)
|
assert isinstance(result.messages[2], ToolCallResultMessage)
|
||||||
assert isinstance(result.messages[3], TextMessage)
|
assert isinstance(result.messages[3], TextMessage)
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
mock._curr_index = 0 # pyright: ignore
|
||||||
|
index = 0
|
||||||
|
async for message in tool_use_agent.run_stream("task"):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
@ -160,8 +170,19 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
handoffs=[handoff],
|
handoffs=[handoff],
|
||||||
)
|
)
|
||||||
assert HandoffMessage in tool_use_agent.produced_message_types
|
assert HandoffMessage in tool_use_agent.produced_message_types
|
||||||
response = await tool_use_agent.on_messages(
|
result = await tool_use_agent.run("task")
|
||||||
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
|
assert len(result.messages) == 4
|
||||||
)
|
assert isinstance(result.messages[0], TextMessage)
|
||||||
assert isinstance(response.chat_message, HandoffMessage)
|
assert isinstance(result.messages[1], ToolCallMessage)
|
||||||
assert response.chat_message.target == "agent2"
|
assert isinstance(result.messages[2], ToolCallResultMessage)
|
||||||
|
assert isinstance(result.messages[3], HandoffMessage)
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
mock._curr_index = 0 # pyright: ignore
|
||||||
|
index = 0
|
||||||
|
async for message in tool_use_agent.run_stream("task"):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
|
@ -12,7 +12,7 @@ from autogen_agentchat.agents import (
|
||||||
CodeExecutorAgent,
|
CodeExecutorAgent,
|
||||||
Handoff,
|
Handoff,
|
||||||
)
|
)
|
||||||
from autogen_agentchat.base import Response
|
from autogen_agentchat.base import Response, TaskResult
|
||||||
from autogen_agentchat.logging import FileLogHandler
|
from autogen_agentchat.logging import FileLogHandler
|
||||||
from autogen_agentchat.messages import (
|
from autogen_agentchat.messages import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
@ -20,7 +20,7 @@ from autogen_agentchat.messages import (
|
||||||
StopMessage,
|
StopMessage,
|
||||||
TextMessage,
|
TextMessage,
|
||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolCallResultMessages,
|
ToolCallResultMessage,
|
||||||
)
|
)
|
||||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
|
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
|
||||||
from autogen_agentchat.teams import (
|
from autogen_agentchat.teams import (
|
||||||
|
@ -59,6 +59,9 @@ class _MockChatCompletion:
|
||||||
self._curr_index += 1
|
self._curr_index += 1
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._curr_index = 0
|
||||||
|
|
||||||
|
|
||||||
class _EchoAgent(BaseChatAgent):
|
class _EchoAgent(BaseChatAgent):
|
||||||
def __init__(self, name: str, description: str) -> None:
|
def __init__(self, name: str, description: str) -> None:
|
||||||
|
@ -147,7 +150,8 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
)
|
)
|
||||||
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
|
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
|
||||||
result = await team.run(
|
result = await team.run(
|
||||||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
"Write a program that prints 'Hello, world!'",
|
||||||
|
termination_condition=StopMessageTermination(),
|
||||||
)
|
)
|
||||||
expected_messages = [
|
expected_messages = [
|
||||||
"Write a program that prints 'Hello, world!'",
|
"Write a program that prints 'Hello, world!'",
|
||||||
|
@ -164,6 +168,18 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
# Assert that all expected messages are in the collected messages
|
# Assert that all expected messages are in the collected messages
|
||||||
assert normalized_messages == expected_messages
|
assert normalized_messages == expected_messages
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
mock.reset()
|
||||||
|
index = 0
|
||||||
|
async for message in team.run_stream(
|
||||||
|
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||||
|
):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
@ -230,13 +246,14 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||||
echo_agent = _EchoAgent("echo_agent", description="echo agent")
|
echo_agent = _EchoAgent("echo_agent", description="echo agent")
|
||||||
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
|
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
|
||||||
result = await team.run(
|
result = await team.run(
|
||||||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
"Write a program that prints 'Hello, world!'",
|
||||||
|
termination_condition=StopMessageTermination(),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(result.messages) == 6
|
assert len(result.messages) == 6
|
||||||
assert isinstance(result.messages[0], TextMessage) # task
|
assert isinstance(result.messages[0], TextMessage) # task
|
||||||
assert isinstance(result.messages[1], ToolCallMessage) # tool call
|
assert isinstance(result.messages[1], ToolCallMessage) # tool call
|
||||||
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result
|
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
|
||||||
assert isinstance(result.messages[3], TextMessage) # tool use agent response
|
assert isinstance(result.messages[3], TextMessage) # tool use agent response
|
||||||
assert isinstance(result.messages[4], TextMessage) # echo agent response
|
assert isinstance(result.messages[4], TextMessage) # echo agent response
|
||||||
assert isinstance(result.messages[5], StopMessage) # tool use agent response
|
assert isinstance(result.messages[5], StopMessage) # tool use agent response
|
||||||
|
@ -253,6 +270,19 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||||
assert context[2].content[0].call_id == "1"
|
assert context[2].content[0].call_id == "1"
|
||||||
assert context[3].content == "Hello"
|
assert context[3].content == "Hello"
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
tool_use_agent._model_context.clear() # pyright: ignore
|
||||||
|
mock.reset()
|
||||||
|
index = 0
|
||||||
|
async for message in team.run_stream(
|
||||||
|
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||||
|
):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
@ -320,7 +350,8 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||||
)
|
)
|
||||||
result = await team.run(
|
result = await team.run(
|
||||||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
"Write a program that prints 'Hello, world!'",
|
||||||
|
termination_condition=StopMessageTermination(),
|
||||||
)
|
)
|
||||||
assert len(result.messages) == 6
|
assert len(result.messages) == 6
|
||||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||||
|
@ -330,6 +361,19 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
assert result.messages[4].source == "agent2"
|
assert result.messages[4].source == "agent2"
|
||||||
assert result.messages[5].source == "agent1"
|
assert result.messages[5].source == "agent1"
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
mock.reset()
|
||||||
|
agent1._count = 0 # pyright: ignore
|
||||||
|
index = 0
|
||||||
|
async for message in team.run_stream(
|
||||||
|
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||||
|
):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
@ -356,7 +400,8 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
|
||||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||||
)
|
)
|
||||||
result = await team.run(
|
result = await team.run(
|
||||||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
"Write a program that prints 'Hello, world!'",
|
||||||
|
termination_condition=StopMessageTermination(),
|
||||||
)
|
)
|
||||||
assert len(result.messages) == 5
|
assert len(result.messages) == 5
|
||||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||||
|
@ -367,6 +412,19 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
|
||||||
# only one chat completion was called
|
# only one chat completion was called
|
||||||
assert mock._curr_index == 1 # pyright: ignore
|
assert mock._curr_index == 1 # pyright: ignore
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
mock.reset()
|
||||||
|
agent1._count = 0 # pyright: ignore
|
||||||
|
index = 0
|
||||||
|
async for message in team.run_stream(
|
||||||
|
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||||
|
):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
@ -422,6 +480,18 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
|
||||||
assert result.messages[2].source == "agent2"
|
assert result.messages[2].source == "agent2"
|
||||||
assert result.messages[3].source == "agent1"
|
assert result.messages[3].source == "agent1"
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
mock.reset()
|
||||||
|
index = 0
|
||||||
|
async for message in team.run_stream(
|
||||||
|
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||||
|
):
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
class _HandOffAgent(BaseChatAgent):
|
class _HandOffAgent(BaseChatAgent):
|
||||||
def __init__(self, name: str, description: str, next_agent: str) -> None:
|
def __init__(self, name: str, description: str, next_agent: str) -> None:
|
||||||
|
@ -446,8 +516,8 @@ async def test_swarm_handoff() -> None:
|
||||||
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||||
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||||
|
|
||||||
team = Swarm([second_agent, first_agent, third_agent])
|
team = Swarm([second_agent, first_agent, third_agent], termination_condition=MaxMessageTermination(6))
|
||||||
result = await team.run("task", termination_condition=MaxMessageTermination(6))
|
result = await team.run("task")
|
||||||
assert len(result.messages) == 6
|
assert len(result.messages) == 6
|
||||||
assert result.messages[0].content == "task"
|
assert result.messages[0].content == "task"
|
||||||
assert result.messages[1].content == "Transferred to third_agent."
|
assert result.messages[1].content == "Transferred to third_agent."
|
||||||
|
@ -456,6 +526,16 @@ async def test_swarm_handoff() -> None:
|
||||||
assert result.messages[4].content == "Transferred to third_agent."
|
assert result.messages[4].content == "Transferred to third_agent."
|
||||||
assert result.messages[5].content == "Transferred to first_agent."
|
assert result.messages[5].content == "Transferred to first_agent."
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
index = 0
|
||||||
|
stream = team.run_stream("task", termination_condition=MaxMessageTermination(6))
|
||||||
|
async for message in stream:
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
@ -514,19 +594,31 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
|
||||||
mock = _MockChatCompletion(chat_completions)
|
mock = _MockChatCompletion(chat_completions)
|
||||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||||
|
|
||||||
agnet1 = AssistantAgent(
|
agent1 = AssistantAgent(
|
||||||
"agent1",
|
"agent1",
|
||||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||||
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
|
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
|
||||||
)
|
)
|
||||||
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
||||||
team = Swarm([agnet1, agent2])
|
team = Swarm([agent1, agent2])
|
||||||
result = await team.run("task", termination_condition=StopMessageTermination())
|
result = await team.run("task", termination_condition=StopMessageTermination())
|
||||||
assert len(result.messages) == 7
|
assert len(result.messages) == 7
|
||||||
assert result.messages[0].content == "task"
|
assert result.messages[0].content == "task"
|
||||||
assert isinstance(result.messages[1], ToolCallMessage)
|
assert isinstance(result.messages[1], ToolCallMessage)
|
||||||
assert isinstance(result.messages[2], ToolCallResultMessages)
|
assert isinstance(result.messages[2], ToolCallResultMessage)
|
||||||
assert result.messages[3].content == "handoff to agent2"
|
assert result.messages[3].content == "handoff to agent2"
|
||||||
assert result.messages[4].content == "Transferred to agent1."
|
assert result.messages[4].content == "Transferred to agent1."
|
||||||
assert result.messages[5].content == "Hello"
|
assert result.messages[5].content == "Hello"
|
||||||
assert result.messages[6].content == "TERMINATE"
|
assert result.messages[6].content == "TERMINATE"
|
||||||
|
|
||||||
|
# Test streaming.
|
||||||
|
agent1._model_context.clear() # pyright: ignore
|
||||||
|
mock.reset()
|
||||||
|
index = 0
|
||||||
|
stream = team.run_stream("task", termination_condition=StopMessageTermination())
|
||||||
|
async for message in stream:
|
||||||
|
if isinstance(message, TaskResult):
|
||||||
|
assert message == result
|
||||||
|
else:
|
||||||
|
assert message == result.messages[index]
|
||||||
|
index += 1
|
||||||
|
|
Loading…
Reference in New Issue