Formalize `ChatAgent` response as a dataclass with inner messages (#3990)

This commit is contained in:
Eric Zhu 2024-10-30 10:27:57 -07:00 committed by GitHub
parent e63fd17ed5
commit 3d51ab76ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 157 additions and 72 deletions

View File

@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from .. import EVENT_LOGGER_NAME
from ..base import Response
from ..messages import (
ChatMessage,
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
)
from ._base_chat_agent import BaseChatAgent
@ -214,7 +218,7 @@ class AssistantAgent(BaseChatAgent):
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ResetMessage):
@ -222,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
else:
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# Inner messages.
inner_messages: List[InnerMessage] = []
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
@ -234,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
# Detect handoff requests.
handoffs: List[Handoff] = []
@ -249,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Respond with a handoff message.
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
# Return the output messages to signal the handoff.
return Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
# Generate an inference result based on the current model context.
result = await self._model_client.create(
@ -262,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
return StopMessage(content=result.content, source=self.name)
return Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
return TextMessage(content=result.content, source=self.name)
return Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken

View File

@ -3,9 +3,8 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from ..base import ChatAgent, TaskResult, TerminationCondition
from ..messages import ChatMessage
from ..teams import RoundRobinGroupChat
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
from ..messages import ChatMessage, InnerMessage, TextMessage
class BaseChatAgent(ChatAgent, ABC):
@ -37,8 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
...
@abstractmethod
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response."""
...
async def run(
@ -49,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
group_chat = RoundRobinGroupChat(participants=[self])
result = await group_chat.run(
task=task,
cancellation_token=cancellation_token,
termination_condition=termination_condition,
)
return result
if cancellation_token is None:
cancellation_token = CancellationToken()
first_message = TextMessage(content=task, source="user")
response = await self.on_messages([first_message], cancellation_token)
messages: List[InnerMessage | ChatMessage] = [first_message]
if response.inner_messages is not None:
messages += response.inner_messages
messages.append(response.chat_message)
return TaskResult(messages=messages)

View File

@ -3,6 +3,7 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
from ..base import Response
from ..messages import ChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent
@ -25,7 +26,7 @@ class CodeExecutorAgent(BaseChatAgent):
"""The types of messages that the code executor agent produces."""
return [TextMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Extract code blocks from the messages.
code_blocks: List[CodeBlock] = []
for msg in messages:
@ -34,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
if code_blocks:
# Execute the code blocks.
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
return TextMessage(content=result.output, source=self.name)
return Response(chat_message=TextMessage(content=result.output, source=self.name))
else:
return TextMessage(content="No code blocks found in the thread.", source=self.name)
return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))

View File

@ -1,10 +1,11 @@
from ._chat_agent import ChatAgent
from ._chat_agent import ChatAgent, Response
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
__all__ = [
"ChatAgent",
"Response",
"Team",
"TerminatedException",
"TerminationCondition",

View File

@ -1,12 +1,24 @@
from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable
from autogen_core.base import CancellationToken
from ..messages import ChatMessage
from ..messages import ChatMessage, InnerMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
# Return type of `on_messages`: the single chat message the agent answers with,
# plus any inner messages it recorded while producing that answer.
# kw_only=True means both fields must be passed by keyword when constructing.
@dataclass(kw_only=True)
class Response:
"""A response from calling :meth:`ChatAgent.on_messages`."""
# Required; there is no default.
chat_message: ChatMessage
"""A chat message produced by the agent as the response."""
# Optional; defaults to None rather than an empty list, so consumers must
# check `is not None` before iterating.
inner_messages: List[InnerMessage] | None = None
"""Inner messages produced by the agent."""
@runtime_checkable
class ChatAgent(TaskRunner, Protocol):
"""Protocol for a chat agent."""
@ -29,8 +41,8 @@ class ChatAgent(TaskRunner, Protocol):
"""The types of messages that the agent produces."""
...
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response."""
...
async def run(

View File

@ -3,7 +3,7 @@ from typing import Protocol, Sequence
from autogen_core.base import CancellationToken
from ..messages import ChatMessage
from ..messages import ChatMessage, InnerMessage
from ._termination import TerminationCondition
@ -11,7 +11,7 @@ from ._termination import TerminationCondition
class TaskResult:
"""Result of running a task."""
messages: Sequence[ChatMessage]
messages: Sequence[InnerMessage | ChatMessage]
"""Messages produced by the task."""

View File

@ -1,6 +1,7 @@
from typing import List
from autogen_core.components import Image
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from pydantic import BaseModel
@ -49,8 +50,26 @@ class ResetMessage(BaseMessage):
"""The content for the reset message."""
# One member of the `InnerMessage` union defined below in this module.
# Other fields (e.g. `source`) presumably come from BaseMessage, which is
# not visible here — confirm against its definition.
class ToolCallMessage(BaseMessage):
"""A message signaling the use of tools."""
# All function calls requested in one step, carried together as a list.
content: List[FunctionCall]
"""The tool calls."""
# One member of the `InnerMessage` union defined below in this module.
# NOTE(review): the class name is plural, but each instance is a single
# message holding a list of results; renaming would break importers.
class ToolCallResultMessages(BaseMessage):
"""A message signaling the results of tool calls."""
# One FunctionExecutionResult per executed call, carried together as a list.
content: List[FunctionExecutionResult]
"""The tool call results."""
InnerMessage = ToolCallMessage | ToolCallResultMessages
"""Messages for intra-agent monologues."""
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
"""A message used by agents in a team."""
"""Messages for agent-to-agent communication."""
__all__ = [
@ -60,5 +79,7 @@ __all__ = [
"StopMessage",
"HandoffMessage",
"ResetMessage",
"ToolCallMessage",
"ToolCallResultMessages",
"ChatMessage",
]

View File

@ -15,7 +15,7 @@ from autogen_core.base import (
from autogen_core.components import ClosureAgent, TypeSubscription
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import ChatMessage, TextMessage
from ...messages import ChatMessage, InnerMessage, TextMessage
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._base_group_chat_manager import BaseGroupChatManager
from ._chat_agent_container import ChatAgentContainer
@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
def _create_participant_factory(
self,
parent_topic_type: str,
output_topic_type: str,
agent: ChatAgent,
) -> Callable[[], ChatAgentContainer]:
def _factory() -> ChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId(type=agent.name, key=self._team_id)
container = ChatAgentContainer(parent_topic_type, agent)
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
assert container.id == id
return container
@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
group_chat_manager_topic_type = group_chat_manager_agent_type.type
group_topic_type = "round_robin_group_topic"
team_topic_type = "team_topic"
output_topic_type = "output_topic"
# Register participants.
participant_topic_types: List[str] = []
@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC):
await ChatAgentContainer.register(
runtime,
type=agent_type,
factory=self._create_participant_factory(group_topic_type, participant),
factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
)
# Add subscriptions for the participant.
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC):
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
)
group_chat_messages: List[ChatMessage] = []
output_messages: List[InnerMessage | ChatMessage] = []
async def collect_group_chat_messages(
async def collect_output_messages(
_runtime: AgentRuntime,
id: AgentId,
message: GroupChatPublishEvent,
message: InnerMessage | ChatMessage,
ctx: MessageContext,
) -> None:
group_chat_messages.append(message.agent_message)
output_messages.append(message)
await ClosureAgent.register(
runtime,
type="collect_group_chat_messages",
closure=collect_group_chat_messages,
type="collect_output_messages",
closure=collect_output_messages,
subscriptions=lambda: [
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
],
)
@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC):
# Run the team by publishing the task to the team topic and then requesting the result.
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
first_chat_message = TextMessage(content=task, source="user")
output_messages.append(first_chat_message)
await runtime.publish_message(
GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
GroupChatPublishEvent(agent_message=first_chat_message),
topic_id=team_topic_id,
)
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC):
await runtime.stop_when_idle()
# Return the result.
return TaskResult(messages=group_chat_messages)
return TaskResult(messages=output_messages)

View File

@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
output_topic_type (str): The topic type for the output.
agent (ChatAgent): The agent to delegate message handling to.
"""
def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._output_topic_type = output_topic_type
self._agent = agent
self._message_buffer: List[ChatMessage] = []
@ -36,18 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent.
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
if not any(isinstance(response, msg_type) for msg_type in self._agent.produced_message_types):
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
raise ValueError(
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
f"Expected one of: {self._agent.produced_message_types}"
f"Expected one of: {self._agent.produced_message_types}. "
f"Check the agent's produced_message_types property."
)
# Publish inner messages to the output topic.
if response.inner_messages is not None:
for inner_message in response.inner_messages:
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Publish the response.
self._message_buffer.clear()
await self.publish_message(
GroupChatPublishEvent(agent_message=response, source=self.id),
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)
# Publish the response to the output topic.
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")

View File

@ -7,7 +7,7 @@ import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
from autogen_core.base import CancellationToken
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
result = await tool_use_agent.run("task")
assert len(result.messages) == 3
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], TextMessage)
assert isinstance(result.messages[2], StopMessage)
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages)
assert isinstance(result.messages[3], TextMessage)
@pytest.mark.asyncio
@ -162,5 +163,5 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
response = await tool_use_agent.on_messages(
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
)
assert isinstance(response, HandoffMessage)
assert response.target == "agent2"
assert isinstance(response.chat_message, HandoffMessage)
assert response.chat_message.target == "agent2"

View File

@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
CodeExecutorAgent,
Handoff,
)
from autogen_agentchat.base import Response
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.teams import (
@ -66,14 +69,14 @@ class _EchoAgent(BaseChatAgent):
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
if len(messages) > 0:
assert isinstance(messages[0], TextMessage)
self._last_message = messages[0].content
return TextMessage(content=messages[0].content, source=self.name)
return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
else:
assert self._last_message is not None
return TextMessage(content=self._last_message, source=self.name)
return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
class _StopAgent(_EchoAgent):
@ -86,11 +89,11 @@ class _StopAgent(_EchoAgent):
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
self._count += 1
if self._count < self._stop_at:
return await super().on_messages(messages, cancellation_token)
return StopMessage(content="TERMINATE", source=self.name)
return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
def _pass_function(input: str) -> str:
@ -230,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
)
assert len(result.messages) == 4
assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], TextMessage) # tool use agent response
assert isinstance(result.messages[2], TextMessage) # echo agent response
assert isinstance(result.messages[3], StopMessage) # tool use agent response
assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
@ -427,8 +432,12 @@ class _HandOffAgent(BaseChatAgent):
def produced_message_types(self) -> List[type[ChatMessage]]:
return [HandoffMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(
chat_message=HandoffMessage(
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
)
)
@pytest.mark.asyncio
@ -513,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agnet1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
assert len(result.messages) == 5
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert result.messages[1].content == "handoff to agent2"
assert result.messages[2].content == "Transferred to agent1."
assert result.messages[3].content == "Hello"
assert result.messages[4].content == "TERMINATE"
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages)
assert result.messages[3].content == "handoff to agent2"
assert result.messages[4].content == "Transferred to agent1."
assert result.messages[5].content == "Hello"
assert result.messages[6].content == "TERMINATE"

View File

@ -251,6 +251,7 @@
"from typing import List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import (\n",
" ChatMessage,\n",
" StopMessage,\n",
@ -266,11 +267,11 @@
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage, StopMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
" if \"TERMINATE\" in user_input:\n",
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
" return TextMessage(content=user_input, source=self.name)\n",
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
" return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
"\n",
"\n",
"user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",

View File

@ -45,6 +45,7 @@
" CodingAssistantAgent,\n",
" ToolUseAssistantAgent,\n",
")\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
@ -75,11 +76,11 @@
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage, StopMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
" if \"TERMINATE\" in user_input:\n",
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
" return TextMessage(content=user_input, source=self.name)"
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
" return Response(chat_message=TextMessage(content=user_input, source=self.name))"
]
},
{