mirror of https://github.com/microsoft/autogen.git
Update group chat and message types (#20)
* Update group chat and message types * fix type based router
This commit is contained in:
parent
ce58c5bc72
commit
00ffb372d1
|
@ -162,3 +162,4 @@ cython_debug/
|
|||
.ruff_cache/
|
||||
|
||||
/docs/src/reference
|
||||
.DS_Store
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from agnext.agent_components.models_clients.openai_client import OpenAI
|
||||
|
@ -8,8 +9,27 @@ from agnext.application_components.single_threaded_agent_runtime import (
|
|||
)
|
||||
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
|
||||
from agnext.chat.messages import ChatMessage
|
||||
from agnext.chat.patterns.group_chat import GroupChat
|
||||
from agnext.chat.patterns.group_chat import GroupChat, Output
|
||||
from agnext.chat.patterns.orchestrator import Orchestrator
|
||||
from agnext.chat.types import TextMessage
|
||||
|
||||
|
||||
class ConcatOutput(Output):
|
||||
def __init__(self) -> None:
|
||||
self._output = ""
|
||||
|
||||
def on_message_received(self, message: Any) -> None:
|
||||
match message:
|
||||
case TextMessage(content=content):
|
||||
self._output += content
|
||||
case _:
|
||||
...
|
||||
|
||||
def get_output(self) -> Any:
|
||||
return self._output
|
||||
|
||||
def reset(self) -> None:
|
||||
self._output = ""
|
||||
|
||||
|
||||
async def group_chat(message: str) -> None:
|
||||
|
@ -45,13 +65,7 @@ async def group_chat(message: str) -> None:
|
|||
thread_id=cathy_oai_thread.id,
|
||||
)
|
||||
|
||||
chat = GroupChat(
|
||||
"Host",
|
||||
"A round-robin chat room.",
|
||||
runtime,
|
||||
[joe, cathy],
|
||||
num_rounds=5,
|
||||
)
|
||||
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
|
||||
|
||||
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ ProducesT = TypeVar("ProducesT", covariant=True)
|
|||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
def message_handler(
|
||||
target_type: Type[ReceivesT],
|
||||
*target_types: Type[ReceivesT],
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
|
||||
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
|
@ -22,7 +22,8 @@ def message_handler(
|
|||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
|
||||
func._target_type = target_type # type: ignore
|
||||
# Convert target_types to list and stash
|
||||
func._target_types = list(target_types) # type: ignore
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
@ -40,8 +41,9 @@ class TypeRoutedAgent(BaseAgent):
|
|||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
handler = getattr(self, attr)
|
||||
if hasattr(handler, "_target_type"):
|
||||
self._handlers[handler._target_type] = handler
|
||||
if hasattr(handler, "_target_types"):
|
||||
for target_type in handler._target_types:
|
||||
self._handlers[target_type] = handler
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Type[Any]]:
|
||||
|
@ -60,4 +62,4 @@ class TypeRoutedAgent(BaseAgent):
|
|||
async def on_unhandled_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> NoReturn:
|
||||
raise CantHandleException()
|
||||
raise CantHandleException(f"Unhandled message: {message}")
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from agnext.core.agent_runtime import AgentRuntime
|
||||
|
||||
from ...agent_components.type_routed_agent import TypeRoutedAgent
|
||||
from agnext.core.base_agent import BaseAgent
|
||||
|
||||
|
||||
class BaseChatAgent(TypeRoutedAgent):
|
||||
class BaseChatAgent(BaseAgent):
|
||||
"""The BaseAgent class for the chat API."""
|
||||
|
||||
def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None:
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
import openai
|
||||
|
||||
from agnext.agent_components.type_routed_agent import message_handler
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.chat.agents.base import BaseChatAgent
|
||||
from agnext.chat.types import Reset, RespondNow, TextMessage
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
from ..messages import ChatMessage
|
||||
|
||||
|
||||
class OpenAIAssistantAgent(BaseChatAgent):
|
||||
class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -25,29 +24,38 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
self._current_session_window_length = 0
|
||||
|
||||
# TODO: use require_response
|
||||
@message_handler(ChatMessage)
|
||||
@message_handler(TextMessage)
|
||||
async def on_chat_message_with_cancellation(
|
||||
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> ChatMessage | None:
|
||||
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
print("---------------")
|
||||
print(f"{self.name} received message from {message.sender}: {message.body}")
|
||||
print(f"{self.name} received message from {message.source}: {message.content}")
|
||||
print("---------------")
|
||||
if message.reset:
|
||||
# Reset the current session window.
|
||||
self._current_session_window_length = 0
|
||||
|
||||
# Save the message to the thread.
|
||||
_ = await self._client.beta.threads.messages.create(
|
||||
thread_id=self._thread_id,
|
||||
content=message.body,
|
||||
content=message.content,
|
||||
role="user",
|
||||
metadata={"sender": message.sender},
|
||||
metadata={"sender": message.source},
|
||||
)
|
||||
self._current_session_window_length += 1
|
||||
|
||||
# If the message is a save_message_only message, return early.
|
||||
if message.save_message_only:
|
||||
return ChatMessage(body="OK", sender=self.name)
|
||||
if require_response:
|
||||
# TODO ?
|
||||
...
|
||||
|
||||
@message_handler(Reset)
|
||||
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
|
||||
# Reset the current session window.
|
||||
self._current_session_window_length = 0
|
||||
|
||||
@message_handler(RespondNow)
|
||||
async def on_respond_now(
|
||||
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> TextMessage | None:
|
||||
if not require_response:
|
||||
return None
|
||||
|
||||
# Create a run and wait until it finishes.
|
||||
run = await self._client.beta.threads.runs.create_and_poll(
|
||||
|
@ -73,4 +81,4 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
raise ValueError(f"Expected text content in the last message: {last_message_content}")
|
||||
|
||||
# TODO: handle multiple text content.
|
||||
return ChatMessage(body=text_content[0].text.value, sender=self.name)
|
||||
return TextMessage(content=text_content[0].text.value, source=self.name)
|
||||
|
|
|
@ -1,21 +1,18 @@
|
|||
import random
|
||||
|
||||
from agnext.agent_components.type_routed_agent import message_handler
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.chat.types import RespondNow, TextMessage
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
from ..agents.base import BaseChatAgent
|
||||
from ..messages import ChatMessage
|
||||
|
||||
|
||||
class RandomResponseAgent(BaseChatAgent):
|
||||
class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
# TODO: use require_response
|
||||
@message_handler(ChatMessage)
|
||||
@message_handler(RespondNow)
|
||||
async def on_chat_message_with_cancellation(
|
||||
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> ChatMessage | None:
|
||||
print(f"{self.name} received message from {message.sender}: {message.body}")
|
||||
if message.save_message_only:
|
||||
return ChatMessage(body="OK", sender=self.name)
|
||||
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> TextMessage:
|
||||
# Generate a random response.
|
||||
response_body = random.choice(
|
||||
[
|
||||
|
@ -36,4 +33,4 @@ class RandomResponseAgent(BaseChatAgent):
|
|||
"See you!",
|
||||
]
|
||||
)
|
||||
return ChatMessage(body=response_body, sender=self.name)
|
||||
return TextMessage(content=response_body, source=self.name)
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
from typing import List, Sequence
|
||||
from typing import Any, List, Protocol, Sequence
|
||||
|
||||
from agnext.chat.types import Reset, RespondNow
|
||||
|
||||
from ...agent_components.type_routed_agent import message_handler
|
||||
from ...core.agent_runtime import AgentRuntime
|
||||
from ...core.cancellation_token import CancellationToken
|
||||
from ..agents.base import BaseChatAgent
|
||||
from ..messages import ChatMessage
|
||||
|
||||
|
||||
class Output(Protocol):
|
||||
def on_message_received(self, message: Any) -> None: ...
|
||||
|
||||
def get_output(self) -> Any: ...
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
|
||||
class GroupChat(BaseChatAgent):
|
||||
|
@ -15,28 +23,37 @@ class GroupChat(BaseChatAgent):
|
|||
runtime: AgentRuntime,
|
||||
agents: Sequence[BaseChatAgent],
|
||||
num_rounds: int,
|
||||
output: Output,
|
||||
) -> None:
|
||||
super().__init__(name, description, runtime)
|
||||
self._agents = agents
|
||||
self._num_rounds = num_rounds
|
||||
self._history: List[ChatMessage] = []
|
||||
self._history: List[Any] = []
|
||||
self._output = output
|
||||
|
||||
@message_handler(ChatMessage)
|
||||
async def on_chat_message(
|
||||
self,
|
||||
message: ChatMessage,
|
||||
require_response: bool,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> ChatMessage | None:
|
||||
if message.reset:
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
agent_sublists = [agent.subscriptions for agent in self._agents]
|
||||
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
||||
|
||||
async def on_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> Any | None:
|
||||
if isinstance(message, Reset):
|
||||
# Reset the history.
|
||||
self._history = []
|
||||
if message.save_message_only:
|
||||
# TODO: what should we do with save_message_only messages for this pattern?
|
||||
return ChatMessage(body="OK", sender=self.name)
|
||||
# TODO: reset sub-agents?
|
||||
|
||||
if isinstance(message, RespondNow):
|
||||
# TODO reset...
|
||||
return self._output.get_output()
|
||||
|
||||
# TODO: should we do nothing here?
|
||||
# Perhaps it should be saved into the message history?
|
||||
if not require_response:
|
||||
return None
|
||||
|
||||
self._history.append(message)
|
||||
previous_speaker: BaseChatAgent | None = None
|
||||
round = 0
|
||||
|
||||
while round < self._num_rounds:
|
||||
|
@ -44,41 +61,31 @@ class GroupChat(BaseChatAgent):
|
|||
# Select speaker (round-robin for now).
|
||||
speaker = self._agents[round % len(self._agents)]
|
||||
|
||||
# Send the last message to non-speaking agents.
|
||||
for agent in [agent for agent in self._agents if agent is not previous_speaker and agent is not speaker]:
|
||||
# Send the last message to all agents.
|
||||
for agent in [agent for agent in self._agents]:
|
||||
# TODO gather and await
|
||||
_ = await self._send_message(
|
||||
ChatMessage(
|
||||
body=self._history[-1].body,
|
||||
sender=self._history[-1].sender,
|
||||
save_message_only=True,
|
||||
),
|
||||
self._history[-1],
|
||||
agent,
|
||||
require_response=False,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Send the last message to the speaking agent and ask to speak.
|
||||
if previous_speaker is not speaker:
|
||||
response = await self._send_message(
|
||||
ChatMessage(body=self._history[-1].body, sender=self._history[-1].sender),
|
||||
speaker,
|
||||
)
|
||||
else:
|
||||
# The same speaker is speaking again.
|
||||
# TODO: should support a separate message type for request to speak only.
|
||||
response = await self._send_message(
|
||||
ChatMessage(body="", sender=self.name),
|
||||
speaker,
|
||||
)
|
||||
response = await self._send_message(
|
||||
RespondNow(),
|
||||
speaker,
|
||||
require_response=True,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
if response is not None:
|
||||
# 4. Append the response to the history.
|
||||
self._history.append(response)
|
||||
|
||||
# 5. Update the previous speaker.
|
||||
previous_speaker = speaker
|
||||
self._output.on_message_received(response)
|
||||
|
||||
# 6. Increment the round.
|
||||
round += 1
|
||||
|
||||
# Construct the final response.
|
||||
response_body = "\n".join([f"{message.sender}: {message.body}" for message in self._history])
|
||||
return ChatMessage(body=response_body, sender=self.name)
|
||||
output = self._output.get_output()
|
||||
self._output.reset()
|
||||
return output
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
from typing import Any, List, Sequence, Tuple
|
||||
|
||||
from ...agent_components.model_client import ModelClient
|
||||
from ...agent_components.type_routed_agent import message_handler
|
||||
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
|
||||
from ...core.agent_runtime import AgentRuntime
|
||||
from ...core.cancellation_token import CancellationToken
|
||||
|
@ -10,7 +10,7 @@ from ..agents.base import BaseChatAgent
|
|||
from ..messages import ChatMessage
|
||||
|
||||
|
||||
class Orchestrator(BaseChatAgent):
|
||||
class Orchestrator(BaseChatAgent, TypeRoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
|
|
|
@ -40,3 +40,9 @@ class FunctionExecutionResultMessage(BaseMessage):
|
|||
|
||||
|
||||
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
class RespondNow: ...
|
||||
|
||||
|
||||
class Reset: ...
|
||||
|
|
Loading…
Reference in New Issue