mirror of https://github.com/microsoft/autogen.git
Formalize `ChatAgent` response as a dataclass with inner messages (#3990)
This commit is contained in:
parent
e63fd17ed5
commit
3d51ab76ae
|
@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
|
|||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..base import Response
|
||||
from ..messages import (
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
InnerMessage,
|
||||
ResetMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessages,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
@ -214,7 +218,7 @@ class AssistantAgent(BaseChatAgent):
|
|||
return [TextMessage, HandoffMessage, StopMessage]
|
||||
return [TextMessage, StopMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
if isinstance(msg, ResetMessage):
|
||||
|
@ -222,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
|
|||
else:
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Inner messages.
|
||||
inner_messages: List[InnerMessage] = []
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(
|
||||
|
@ -234,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
|
|||
# Run tool calls until the model produces a string response.
|
||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
||||
# Add the tool call message to the output.
|
||||
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
|
||||
|
||||
# Execute the tool calls.
|
||||
results = await asyncio.gather(
|
||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||
)
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
|
||||
|
||||
# Detect handoff requests.
|
||||
handoffs: List[Handoff] = []
|
||||
|
@ -249,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
|
|||
if len(handoffs) > 0:
|
||||
if len(handoffs) > 1:
|
||||
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
||||
# Respond with a handoff message.
|
||||
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
|
||||
# Return the output messages to signal the handoff.
|
||||
return Response(
|
||||
chat_message=HandoffMessage(
|
||||
content=handoffs[0].message, target=handoffs[0].target, source=self.name
|
||||
),
|
||||
inner_messages=inner_messages,
|
||||
)
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
result = await self._model_client.create(
|
||||
|
@ -262,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
|
|||
# Detect stop request.
|
||||
request_stop = "terminate" in result.content.strip().lower()
|
||||
if request_stop:
|
||||
return StopMessage(content=result.content, source=self.name)
|
||||
return Response(
|
||||
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||
)
|
||||
|
||||
return TextMessage(content=result.content, source=self.name)
|
||||
return Response(
|
||||
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||
)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||
|
|
|
@ -3,9 +3,8 @@ from typing import List, Sequence
|
|||
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..base import ChatAgent, TaskResult, TerminationCondition
|
||||
from ..messages import ChatMessage
|
||||
from ..teams import RoundRobinGroupChat
|
||||
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
|
||||
from ..messages import ChatMessage, InnerMessage, TextMessage
|
||||
|
||||
|
||||
class BaseChatAgent(ChatAgent, ABC):
|
||||
|
@ -37,8 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
"""Handle incoming messages and return a response message."""
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response."""
|
||||
...
|
||||
|
||||
async def run(
|
||||
|
@ -49,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
|
|||
termination_condition: TerminationCondition | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
group_chat = RoundRobinGroupChat(participants=[self])
|
||||
result = await group_chat.run(
|
||||
task=task,
|
||||
cancellation_token=cancellation_token,
|
||||
termination_condition=termination_condition,
|
||||
)
|
||||
return result
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
first_message = TextMessage(content=task, source="user")
|
||||
response = await self.on_messages([first_message], cancellation_token)
|
||||
messages: List[InnerMessage | ChatMessage] = [first_message]
|
||||
if response.inner_messages is not None:
|
||||
messages += response.inner_messages
|
||||
messages.append(response.chat_message)
|
||||
return TaskResult(messages=messages)
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import List, Sequence
|
|||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
|
||||
|
||||
from ..base import Response
|
||||
from ..messages import ChatMessage, TextMessage
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
@ -25,7 +26,7 @@ class CodeExecutorAgent(BaseChatAgent):
|
|||
"""The types of messages that the code executor agent produces."""
|
||||
return [TextMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Extract code blocks from the messages.
|
||||
code_blocks: List[CodeBlock] = []
|
||||
for msg in messages:
|
||||
|
@ -34,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
|
|||
if code_blocks:
|
||||
# Execute the code blocks.
|
||||
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
||||
return TextMessage(content=result.output, source=self.name)
|
||||
return Response(chat_message=TextMessage(content=result.output, source=self.name))
|
||||
else:
|
||||
return TextMessage(content="No code blocks found in the thread.", source=self.name)
|
||||
return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from ._chat_agent import ChatAgent
|
||||
from ._chat_agent import ChatAgent, Response
|
||||
from ._task import TaskResult, TaskRunner
|
||||
from ._team import Team
|
||||
from ._termination import TerminatedException, TerminationCondition
|
||||
|
||||
__all__ = [
|
||||
"ChatAgent",
|
||||
"Response",
|
||||
"Team",
|
||||
"TerminatedException",
|
||||
"TerminationCondition",
|
||||
|
|
|
@ -1,12 +1,24 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Protocol, Sequence, runtime_checkable
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..messages import ChatMessage
|
||||
from ..messages import ChatMessage, InnerMessage
|
||||
from ._task import TaskResult, TaskRunner
|
||||
from ._termination import TerminationCondition
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Response:
|
||||
"""A response from calling :meth:`ChatAgent.on_messages`."""
|
||||
|
||||
chat_message: ChatMessage
|
||||
"""A chat message produced by the agent as the response."""
|
||||
|
||||
inner_messages: List[InnerMessage] | None = None
|
||||
"""Inner messages produced by the agent."""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ChatAgent(TaskRunner, Protocol):
|
||||
"""Protocol for a chat agent."""
|
||||
|
@ -29,8 +41,8 @@ class ChatAgent(TaskRunner, Protocol):
|
|||
"""The types of messages that the agent produces."""
|
||||
...
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
"""Handle incoming messages and return a response message."""
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
"""Handles incoming messages and returns a response."""
|
||||
...
|
||||
|
||||
async def run(
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Protocol, Sequence
|
|||
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..messages import ChatMessage
|
||||
from ..messages import ChatMessage, InnerMessage
|
||||
from ._termination import TerminationCondition
|
||||
|
||||
|
||||
|
@ -11,7 +11,7 @@ from ._termination import TerminationCondition
|
|||
class TaskResult:
|
||||
"""Result of running a task."""
|
||||
|
||||
messages: Sequence[ChatMessage]
|
||||
messages: Sequence[InnerMessage | ChatMessage]
|
||||
"""Messages produced by the task."""
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
|
||||
from autogen_core.components import Image
|
||||
from autogen_core.components import FunctionCall, Image
|
||||
from autogen_core.components.models import FunctionExecutionResult
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
@ -49,8 +50,26 @@ class ResetMessage(BaseMessage):
|
|||
"""The content for the reset message."""
|
||||
|
||||
|
||||
class ToolCallMessage(BaseMessage):
|
||||
"""A message signaling the use of tools."""
|
||||
|
||||
content: List[FunctionCall]
|
||||
"""The tool calls."""
|
||||
|
||||
|
||||
class ToolCallResultMessages(BaseMessage):
|
||||
"""A message signaling the results of tool calls."""
|
||||
|
||||
content: List[FunctionExecutionResult]
|
||||
"""The tool call results."""
|
||||
|
||||
|
||||
InnerMessage = ToolCallMessage | ToolCallResultMessages
|
||||
"""Messages for intra-agent monologues."""
|
||||
|
||||
|
||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
|
||||
"""A message used by agents in a team."""
|
||||
"""Messages for agent-to-agent communication."""
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -60,5 +79,7 @@ __all__ = [
|
|||
"StopMessage",
|
||||
"HandoffMessage",
|
||||
"ResetMessage",
|
||||
"ToolCallMessage",
|
||||
"ToolCallResultMessages",
|
||||
"ChatMessage",
|
||||
]
|
||||
|
|
|
@ -15,7 +15,7 @@ from autogen_core.base import (
|
|||
from autogen_core.components import ClosureAgent, TypeSubscription
|
||||
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import ChatMessage, TextMessage
|
||||
from ...messages import ChatMessage, InnerMessage, TextMessage
|
||||
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
|
@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
|
|||
def _create_participant_factory(
|
||||
self,
|
||||
parent_topic_type: str,
|
||||
output_topic_type: str,
|
||||
agent: ChatAgent,
|
||||
) -> Callable[[], ChatAgentContainer]:
|
||||
def _factory() -> ChatAgentContainer:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId(type=agent.name, key=self._team_id)
|
||||
container = ChatAgentContainer(parent_topic_type, agent)
|
||||
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
|
||||
assert container.id == id
|
||||
return container
|
||||
|
||||
|
@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
|
|||
group_chat_manager_topic_type = group_chat_manager_agent_type.type
|
||||
group_topic_type = "round_robin_group_topic"
|
||||
team_topic_type = "team_topic"
|
||||
output_topic_type = "output_topic"
|
||||
|
||||
# Register participants.
|
||||
participant_topic_types: List[str] = []
|
||||
|
@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC):
|
|||
await ChatAgentContainer.register(
|
||||
runtime,
|
||||
type=agent_type,
|
||||
factory=self._create_participant_factory(group_topic_type, participant),
|
||||
factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
|
||||
)
|
||||
# Add subscriptions for the participant.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
|
||||
|
@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC):
|
|||
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
|
||||
group_chat_messages: List[ChatMessage] = []
|
||||
output_messages: List[InnerMessage | ChatMessage] = []
|
||||
|
||||
async def collect_group_chat_messages(
|
||||
async def collect_output_messages(
|
||||
_runtime: AgentRuntime,
|
||||
id: AgentId,
|
||||
message: GroupChatPublishEvent,
|
||||
message: InnerMessage | ChatMessage,
|
||||
ctx: MessageContext,
|
||||
) -> None:
|
||||
group_chat_messages.append(message.agent_message)
|
||||
output_messages.append(message)
|
||||
|
||||
await ClosureAgent.register(
|
||||
runtime,
|
||||
type="collect_group_chat_messages",
|
||||
closure=collect_group_chat_messages,
|
||||
type="collect_output_messages",
|
||||
closure=collect_output_messages,
|
||||
subscriptions=lambda: [
|
||||
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
|
||||
TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC):
|
|||
# Run the team by publishing the task to the team topic and then requesting the result.
|
||||
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
|
||||
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
|
||||
first_chat_message = TextMessage(content=task, source="user")
|
||||
output_messages.append(first_chat_message)
|
||||
await runtime.publish_message(
|
||||
GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
|
||||
GroupChatPublishEvent(agent_message=first_chat_message),
|
||||
topic_id=team_topic_id,
|
||||
)
|
||||
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
|
||||
|
@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC):
|
|||
await runtime.stop_when_idle()
|
||||
|
||||
# Return the result.
|
||||
return TaskResult(messages=group_chat_messages)
|
||||
return TaskResult(messages=output_messages)
|
||||
|
|
|
@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||
|
||||
Args:
|
||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||
output_topic_type (str): The topic type for the output.
|
||||
agent (ChatAgent): The agent to delegate message handling to.
|
||||
"""
|
||||
|
||||
def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
|
||||
def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
|
||||
super().__init__(description=agent.description)
|
||||
self._parent_topic_type = parent_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[ChatMessage] = []
|
||||
|
||||
|
@ -36,18 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||
to the delegate agent and publish the response."""
|
||||
# Pass the messages in the buffer to the delegate agent.
|
||||
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
||||
if not any(isinstance(response, msg_type) for msg_type in self._agent.produced_message_types):
|
||||
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
|
||||
raise ValueError(
|
||||
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
|
||||
f"Expected one of: {self._agent.produced_message_types}"
|
||||
f"Expected one of: {self._agent.produced_message_types}. "
|
||||
f"Check the agent's produced_message_types property."
|
||||
)
|
||||
|
||||
# Publish inner messages to the output topic.
|
||||
if response.inner_messages is not None:
|
||||
for inner_message in response.inner_messages:
|
||||
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
|
||||
# Publish the response.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatPublishEvent(agent_message=response, source=self.id),
|
||||
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
)
|
||||
|
||||
# Publish the response to the output topic.
|
||||
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||
|
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import AssistantAgent, Handoff
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
|
||||
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
|
@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
)
|
||||
result = await tool_use_agent.run("task")
|
||||
assert len(result.messages) == 3
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert isinstance(result.messages[2], StopMessage)
|
||||
assert isinstance(result.messages[1], ToolCallMessage)
|
||||
assert isinstance(result.messages[2], ToolCallResultMessages)
|
||||
assert isinstance(result.messages[3], TextMessage)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -162,5 +163,5 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
response = await tool_use_agent.on_messages(
|
||||
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
|
||||
)
|
||||
assert isinstance(response, HandoffMessage)
|
||||
assert response.target == "agent2"
|
||||
assert isinstance(response.chat_message, HandoffMessage)
|
||||
assert response.chat_message.target == "agent2"
|
||||
|
|
|
@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
|
|||
CodeExecutorAgent,
|
||||
Handoff,
|
||||
)
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import (
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessages,
|
||||
)
|
||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
|
||||
from autogen_agentchat.teams import (
|
||||
|
@ -66,14 +69,14 @@ class _EchoAgent(BaseChatAgent):
|
|||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
if len(messages) > 0:
|
||||
assert isinstance(messages[0], TextMessage)
|
||||
self._last_message = messages[0].content
|
||||
return TextMessage(content=messages[0].content, source=self.name)
|
||||
return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
|
||||
else:
|
||||
assert self._last_message is not None
|
||||
return TextMessage(content=self._last_message, source=self.name)
|
||||
return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
|
||||
|
||||
|
||||
class _StopAgent(_EchoAgent):
|
||||
|
@ -86,11 +89,11 @@ class _StopAgent(_EchoAgent):
|
|||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage, StopMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
self._count += 1
|
||||
if self._count < self._stop_at:
|
||||
return await super().on_messages(messages, cancellation_token)
|
||||
return StopMessage(content="TERMINATE", source=self.name)
|
||||
return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
|
||||
|
||||
|
||||
def _pass_function(input: str) -> str:
|
||||
|
@ -230,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
|||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||
)
|
||||
|
||||
assert len(result.messages) == 4
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[0], TextMessage) # task
|
||||
assert isinstance(result.messages[1], TextMessage) # tool use agent response
|
||||
assert isinstance(result.messages[2], TextMessage) # echo agent response
|
||||
assert isinstance(result.messages[3], StopMessage) # tool use agent response
|
||||
assert isinstance(result.messages[1], ToolCallMessage) # tool call
|
||||
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result
|
||||
assert isinstance(result.messages[3], TextMessage) # tool use agent response
|
||||
assert isinstance(result.messages[4], TextMessage) # echo agent response
|
||||
assert isinstance(result.messages[5], StopMessage) # tool use agent response
|
||||
|
||||
context = tool_use_agent._model_context # pyright: ignore
|
||||
assert context[0].content == "Write a program that prints 'Hello, world!'"
|
||||
|
@ -427,8 +432,12 @@ class _HandOffAgent(BaseChatAgent):
|
|||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [HandoffMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(
|
||||
chat_message=HandoffMessage(
|
||||
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -513,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
|
|||
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
||||
team = Swarm([agnet1, agent2])
|
||||
result = await team.run("task", termination_condition=StopMessageTermination())
|
||||
assert len(result.messages) == 5
|
||||
assert len(result.messages) == 7
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "handoff to agent2"
|
||||
assert result.messages[2].content == "Transferred to agent1."
|
||||
assert result.messages[3].content == "Hello"
|
||||
assert result.messages[4].content == "TERMINATE"
|
||||
assert isinstance(result.messages[1], ToolCallMessage)
|
||||
assert isinstance(result.messages[2], ToolCallResultMessages)
|
||||
assert result.messages[3].content == "handoff to agent2"
|
||||
assert result.messages[4].content == "Transferred to agent1."
|
||||
assert result.messages[5].content == "Hello"
|
||||
assert result.messages[6].content == "TERMINATE"
|
||||
|
|
|
@ -251,6 +251,7 @@
|
|||
"from typing import List, Sequence\n",
|
||||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import (\n",
|
||||
" ChatMessage,\n",
|
||||
" StopMessage,\n",
|
||||
|
@ -266,11 +267,11 @@
|
|||
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
||||
" return [TextMessage, StopMessage]\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
|
||||
" if \"TERMINATE\" in user_input:\n",
|
||||
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
|
||||
" return TextMessage(content=user_input, source=self.name)\n",
|
||||
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
|
||||
" return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
" CodingAssistantAgent,\n",
|
||||
" ToolUseAssistantAgent,\n",
|
||||
")\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
|
||||
"from autogen_agentchat.task import StopMessageTermination\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
|
@ -75,11 +76,11 @@
|
|||
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
||||
" return [TextMessage, StopMessage]\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
|
||||
" if \"TERMINATE\" in user_input:\n",
|
||||
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
|
||||
" return TextMessage(content=user_input, source=self.name)"
|
||||
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
|
||||
" return Response(chat_message=TextMessage(content=user_input, source=self.name))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue