From 00ffb372d13ce52cc6f63e2e1d491f5485605705 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 24 May 2024 17:25:17 -0400 Subject: [PATCH] Update group chat and message types (#20) * Update group chat and message types * fix type based router --- .gitignore | 1 + examples/patterns.py | 30 ++++-- .../agent_components/type_routed_agent.py | 12 ++- src/agnext/chat/agents/base.py | 5 +- src/agnext/chat/agents/oai_assistant.py | 42 +++++---- src/agnext/chat/agents/random_agent.py | 17 ++-- src/agnext/chat/patterns/group_chat.py | 91 ++++++++++--------- src/agnext/chat/patterns/orchestrator.py | 4 +- src/agnext/chat/types.py | 6 ++ 9 files changed, 121 insertions(+), 87 deletions(-) diff --git a/.gitignore b/.gitignore index bf95311d43..82455027b2 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ cython_debug/ .ruff_cache/ /docs/src/reference +.DS_Store diff --git a/examples/patterns.py b/examples/patterns.py index 32f889ffdd..defa872988 100644 --- a/examples/patterns.py +++ b/examples/patterns.py @@ -1,5 +1,6 @@ import argparse import asyncio +from typing import Any import openai from agnext.agent_components.models_clients.openai_client import OpenAI @@ -8,8 +9,27 @@ from agnext.application_components.single_threaded_agent_runtime import ( ) from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent from agnext.chat.messages import ChatMessage -from agnext.chat.patterns.group_chat import GroupChat +from agnext.chat.patterns.group_chat import GroupChat, Output from agnext.chat.patterns.orchestrator import Orchestrator +from agnext.chat.types import TextMessage + + +class ConcatOutput(Output): + def __init__(self) -> None: + self._output = "" + + def on_message_received(self, message: Any) -> None: + match message: + case TextMessage(content=content): + self._output += content + case _: + ... + + def get_output(self) -> Any: + return self._output + + def reset(self) -> None: + self._output = "" async def group_chat(message: str) -> None: @@ -45,13 +65,7 @@ async def group_chat(message: str) -> None: thread_id=cathy_oai_thread.id, ) - chat = GroupChat( - "Host", - "A round-robin chat room.", - runtime, - [joe, cathy], - num_rounds=5, - ) + chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput()) response = runtime.send_message(ChatMessage(body=message, sender="host"), chat) diff --git a/src/agnext/agent_components/type_routed_agent.py b/src/agnext/agent_components/type_routed_agent.py index 490f3c03ca..af6042d87b 100644 --- a/src/agnext/agent_components/type_routed_agent.py +++ b/src/agnext/agent_components/type_routed_agent.py @@ -14,7 +14,7 @@ ProducesT = TypeVar("ProducesT", covariant=True) # NOTE: this works on concrete types and not inheritance def message_handler( - target_type: Type[ReceivesT], + *target_types: Type[ReceivesT], ) -> Callable[ [Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]], Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], @@ -22,7 +22,8 @@ def message_handler( def decorator( func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], ) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]: - func._target_type = target_type # type: ignore + # Convert target_types to list and stash + func._target_types = list(target_types) # type: ignore return func return decorator @@ -40,8 +41,9 @@ class TypeRoutedAgent(BaseAgent): for attr in dir(self): if callable(getattr(self, attr, None)): handler = getattr(self, attr) - if hasattr(handler, "_target_type"): - self._handlers[handler._target_type] = handler + if hasattr(handler, "_target_types"): + for target_type in handler._target_types: + self._handlers[target_type] = handler @property def subscriptions(self) -> Sequence[Type[Any]]: @@ -60,4 +62,4 @@ class TypeRoutedAgent(BaseAgent): async def on_unhandled_message( self, message: Any, require_response: bool, cancellation_token: CancellationToken ) -> NoReturn: - raise CantHandleException() + raise CantHandleException(f"Unhandled message: {message}") diff --git a/src/agnext/chat/agents/base.py b/src/agnext/chat/agents/base.py index 3cece5005c..14934f17d7 100644 --- a/src/agnext/chat/agents/base.py +++ b/src/agnext/chat/agents/base.py @@ -1,9 +1,8 @@ from agnext.core.agent_runtime import AgentRuntime - -from ...agent_components.type_routed_agent import TypeRoutedAgent +from agnext.core.base_agent import BaseAgent -class BaseChatAgent(TypeRoutedAgent): +class BaseChatAgent(BaseAgent): """The BaseAgent class for the chat API.""" def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None: diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 11ef741f93..8ee8980313 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -1,14 +1,13 @@ import openai -from agnext.agent_components.type_routed_agent import message_handler +from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler from agnext.chat.agents.base import BaseChatAgent +from agnext.chat.types import Reset, RespondNow, TextMessage from agnext.core.agent_runtime import AgentRuntime from agnext.core.cancellation_token import CancellationToken -from ..messages import ChatMessage - -class OpenAIAssistantAgent(BaseChatAgent): +class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent): def __init__( self, name: str, @@ -25,29 +24,38 @@ class OpenAIAssistantAgent(BaseChatAgent): self._current_session_window_length = 0 # TODO: use require_response - @message_handler(ChatMessage) + @message_handler(TextMessage) async def on_chat_message_with_cancellation( - self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken - ) -> ChatMessage | None: + self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken + ) -> None: print("---------------") - print(f"{self.name} received message from {message.sender}: {message.body}") + print(f"{self.name} received message from {message.source}: {message.content}") print("---------------") - if message.reset: - # Reset the current session window. - self._current_session_window_length = 0 # Save the message to the thread. _ = await self._client.beta.threads.messages.create( thread_id=self._thread_id, - content=message.body, + content=message.content, role="user", - metadata={"sender": message.sender}, + metadata={"sender": message.source}, ) self._current_session_window_length += 1 - # If the message is a save_message_only message, return early. - if message.save_message_only: - return ChatMessage(body="OK", sender=self.name) + if require_response: + # TODO ? + ... + + @message_handler(Reset) + async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None: + # Reset the current session window. + self._current_session_window_length = 0 + + @message_handler(RespondNow) + async def on_respond_now( + self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken + ) -> TextMessage | None: + if not require_response: + return None # Create a run and wait until it finishes. run = await self._client.beta.threads.runs.create_and_poll( @@ -73,4 +81,4 @@ class OpenAIAssistantAgent(BaseChatAgent): raise ValueError(f"Expected text content in the last message: {last_message_content}") # TODO: handle multiple text content. - return ChatMessage(body=text_content[0].text.value, sender=self.name) + return TextMessage(content=text_content[0].text.value, source=self.name) diff --git a/src/agnext/chat/agents/random_agent.py b/src/agnext/chat/agents/random_agent.py index 0581ac437b..f96a2d949d 100644 --- a/src/agnext/chat/agents/random_agent.py +++ b/src/agnext/chat/agents/random_agent.py @@ -1,21 +1,18 @@ import random -from agnext.agent_components.type_routed_agent import message_handler +from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler +from agnext.chat.types import RespondNow, TextMessage from agnext.core.cancellation_token import CancellationToken from ..agents.base import BaseChatAgent -from ..messages import ChatMessage -class RandomResponseAgent(BaseChatAgent): +class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent): # TODO: use require_response - @message_handler(ChatMessage) + @message_handler(RespondNow) async def on_chat_message_with_cancellation( - self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken - ) -> ChatMessage | None: - print(f"{self.name} received message from {message.sender}: {message.body}") - if message.save_message_only: - return ChatMessage(body="OK", sender=self.name) + self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken + ) -> TextMessage: # Generate a random response. response_body = random.choice( [ @@ -36,4 +33,4 @@ class RandomResponseAgent(BaseChatAgent): "See you!", ] ) - return ChatMessage(body=response_body, sender=self.name) + return TextMessage(content=response_body, source=self.name) diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 4efd490f21..0120522d9b 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -1,10 +1,18 @@ -from typing import List, Sequence +from typing import Any, List, Protocol, Sequence + +from agnext.chat.types import Reset, RespondNow -from ...agent_components.type_routed_agent import message_handler from ...core.agent_runtime import AgentRuntime from ...core.cancellation_token import CancellationToken from ..agents.base import BaseChatAgent -from ..messages import ChatMessage + + +class Output(Protocol): + def on_message_received(self, message: Any) -> None: ... + + def get_output(self) -> Any: ... + + def reset(self) -> None: ... class GroupChat(BaseChatAgent): @@ -15,28 +23,37 @@ class GroupChat(BaseChatAgent): runtime: AgentRuntime, agents: Sequence[BaseChatAgent], num_rounds: int, + output: Output, ) -> None: super().__init__(name, description, runtime) self._agents = agents self._num_rounds = num_rounds - self._history: List[ChatMessage] = [] + self._history: List[Any] = [] + self._output = output - @message_handler(ChatMessage) - async def on_chat_message( - self, - message: ChatMessage, - require_response: bool, - cancellation_token: CancellationToken, - ) -> ChatMessage | None: - if message.reset: + @property + def subscriptions(self) -> Sequence[type]: + agent_sublists = [agent.subscriptions for agent in self._agents] + return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] + + async def on_message( + self, message: Any, require_response: bool, cancellation_token: CancellationToken + ) -> Any | None: + if isinstance(message, Reset): # Reset the history. self._history = [] - if message.save_message_only: - # TODO: what should we do with save_message_only messages for this pattern? - return ChatMessage(body="OK", sender=self.name) + # TODO: reset sub-agents? + + if isinstance(message, RespondNow): + # TODO reset... + return self._output.get_output() + + # TODO: should we do nothing here? + # Perhaps it should be saved into the message history? + if not require_response: + return None self._history.append(message) - previous_speaker: BaseChatAgent | None = None round = 0 while round < self._num_rounds: @@ -44,41 +61,31 @@ class GroupChat(BaseChatAgent): # Select speaker (round-robin for now). speaker = self._agents[round % len(self._agents)] - # Send the last message to non-speaking agents. - for agent in [agent for agent in self._agents if agent is not previous_speaker and agent is not speaker]: + # Send the last message to all agents. + for agent in [agent for agent in self._agents]: + # TODO gather and await _ = await self._send_message( - ChatMessage( - body=self._history[-1].body, - sender=self._history[-1].sender, - save_message_only=True, - ), + self._history[-1], agent, + require_response=False, + cancellation_token=cancellation_token, ) - # Send the last message to the speaking agent and ask to speak. - if previous_speaker is not speaker: - response = await self._send_message( - ChatMessage(body=self._history[-1].body, sender=self._history[-1].sender), - speaker, - ) - else: - # The same speaker is speaking again. - # TODO: should support a separate message type for request to speak only. - response = await self._send_message( - ChatMessage(body="", sender=self.name), - speaker, - ) + response = await self._send_message( + RespondNow(), + speaker, + require_response=True, + cancellation_token=cancellation_token, + ) if response is not None: # 4. Append the response to the history. self._history.append(response) - - # 5. Update the previous speaker. - previous_speaker = speaker + self._output.on_message_received(response) # 6. Increment the round. round += 1 - # Construct the final response. - response_body = "\n".join([f"{message.sender}: {message.body}" for message in self._history]) - return ChatMessage(body=response_body, sender=self.name) + output = self._output.get_output() + self._output.reset() + return output diff --git a/src/agnext/chat/patterns/orchestrator.py b/src/agnext/chat/patterns/orchestrator.py index 0798da8260..fbf597a97f 100644 --- a/src/agnext/chat/patterns/orchestrator.py +++ b/src/agnext/chat/patterns/orchestrator.py @@ -2,7 +2,7 @@ import json from typing import Any, List, Sequence, Tuple from ...agent_components.model_client import ModelClient -from ...agent_components.type_routed_agent import message_handler +from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage from ...core.agent_runtime import AgentRuntime from ...core.cancellation_token import CancellationToken @@ -10,7 +10,7 @@ from ..agents.base import BaseChatAgent from ..messages import ChatMessage -class Orchestrator(BaseChatAgent): +class Orchestrator(BaseChatAgent, TypeRoutedAgent): def __init__( self, name: str, diff --git a/src/agnext/chat/types.py b/src/agnext/chat/types.py index 1a2bbb5c54..c49b746882 100644 --- a/src/agnext/chat/types.py +++ b/src/agnext/chat/types.py @@ -40,3 +40,9 @@ class FunctionExecutionResultMessage(BaseMessage): Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage] + + +class RespondNow: ... + + +class Reset: ...