Update group chat and message types (#20)

* Update group chat and message types

* fix type-based router
Jack Gerrits 2024-05-24 17:25:17 -04:00 committed by GitHub
parent ce58c5bc72
commit 00ffb372d1
9 changed files with 121 additions and 87 deletions

.gitignore

@@ -162,3 +162,4 @@ cython_debug/
.ruff_cache/
/docs/src/reference
.DS_Store


@@ -1,5 +1,6 @@
import argparse
import asyncio
from typing import Any
import openai
from agnext.agent_components.models_clients.openai_client import OpenAI
@@ -8,8 +9,27 @@ from agnext.application_components.single_threaded_agent_runtime import (
)
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.messages import ChatMessage
from agnext.chat.patterns.group_chat import GroupChat
from agnext.chat.patterns.group_chat import GroupChat, Output
from agnext.chat.patterns.orchestrator import Orchestrator
from agnext.chat.types import TextMessage
class ConcatOutput(Output):
def __init__(self) -> None:
self._output = ""
def on_message_received(self, message: Any) -> None:
match message:
case TextMessage(content=content):
self._output += content
case _:
...
def get_output(self) -> Any:
return self._output
def reset(self) -> None:
self._output = ""
async def group_chat(message: str) -> None:
@@ -45,13 +65,7 @@ async def group_chat(message: str) -> None:
thread_id=cathy_oai_thread.id,
)
chat = GroupChat(
"Host",
"A round-robin chat room.",
runtime,
[joe, cathy],
num_rounds=5,
)
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)


@@ -14,7 +14,7 @@ ProducesT = TypeVar("ProducesT", covariant=True)
# NOTE: this works on concrete types and not inheritance
def message_handler(
target_type: Type[ReceivesT],
*target_types: Type[ReceivesT],
) -> Callable[
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
@@ -22,7 +22,8 @@ def message_handler(
def decorator(
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
func._target_type = target_type # type: ignore
# Convert target_types to list and stash
func._target_types = list(target_types) # type: ignore
return func
return decorator
@@ -40,8 +41,9 @@ class TypeRoutedAgent(BaseAgent):
for attr in dir(self):
if callable(getattr(self, attr, None)):
handler = getattr(self, attr)
if hasattr(handler, "_target_type"):
self._handlers[handler._target_type] = handler
if hasattr(handler, "_target_types"):
for target_type in handler._target_types:
self._handlers[target_type] = handler
@property
def subscriptions(self) -> Sequence[Type[Any]]:
@@ -60,4 +62,4 @@
async def on_unhandled_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> NoReturn:
raise CantHandleException()
raise CantHandleException(f"Unhandled message: {message}")
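
With this change a single handler can be registered for several concrete message types by passing them all to the decorator. A minimal sketch of the new usage, modeled on the agents later in this commit (EchoAgent is a hypothetical example, not part of the change):

from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, TextMessage
from agnext.core.cancellation_token import CancellationToken

class EchoAgent(BaseChatAgent, TypeRoutedAgent):
    # One handler registered for two message types; the router maps each type to it.
    @message_handler(Reset, RespondNow)
    async def on_control_message(
        self, message: Reset | RespondNow, require_response: bool, cancellation_token: CancellationToken
    ) -> TextMessage | None:
        if isinstance(message, Reset):
            return None  # nothing to clear in this toy agent
        return TextMessage(content="pong", source=self.name)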


@@ -1,9 +1,8 @@
from agnext.core.agent_runtime import AgentRuntime
from ...agent_components.type_routed_agent import TypeRoutedAgent
from agnext.core.base_agent import BaseAgent
class BaseChatAgent(TypeRoutedAgent):
class BaseChatAgent(BaseAgent):
"""The BaseAgent class for the chat API."""
def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None:


@@ -1,14 +1,13 @@
import openai
from agnext.agent_components.type_routed_agent import message_handler
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, TextMessage
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken
from ..messages import ChatMessage
class OpenAIAssistantAgent(BaseChatAgent):
class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
def __init__(
self,
name: str,
@@ -25,29 +24,38 @@ class OpenAIAssistantAgent(BaseChatAgent):
self._current_session_window_length = 0
# TODO: use require_response
@message_handler(ChatMessage)
@message_handler(TextMessage)
async def on_chat_message_with_cancellation(
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
) -> ChatMessage | None:
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
) -> None:
print("---------------")
print(f"{self.name} received message from {message.sender}: {message.body}")
print(f"{self.name} received message from {message.source}: {message.content}")
print("---------------")
if message.reset:
# Reset the current session window.
self._current_session_window_length = 0
# Save the message to the thread.
_ = await self._client.beta.threads.messages.create(
thread_id=self._thread_id,
content=message.body,
content=message.content,
role="user",
metadata={"sender": message.sender},
metadata={"sender": message.source},
)
self._current_session_window_length += 1
# If the message is a save_message_only message, return early.
if message.save_message_only:
return ChatMessage(body="OK", sender=self.name)
if require_response:
# TODO ?
...
@message_handler(Reset)
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
# Reset the current session window.
self._current_session_window_length = 0
@message_handler(RespondNow)
async def on_respond_now(
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
) -> TextMessage | None:
if not require_response:
return None
# Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll(
@@ -73,4 +81,4 @@ class OpenAIAssistantAgent(BaseChatAgent):
raise ValueError(f"Expected text content in the last message: {last_message_content}")
# TODO: handle multiple text content.
return ChatMessage(body=text_content[0].text.value, sender=self.name)
return TextMessage(content=text_content[0].text.value, source=self.name)


@@ -1,21 +1,18 @@
import random
from agnext.agent_components.type_routed_agent import message_handler
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.chat.types import RespondNow, TextMessage
from agnext.core.cancellation_token import CancellationToken
from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
class RandomResponseAgent(BaseChatAgent):
class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: use require_response
@message_handler(ChatMessage)
@message_handler(RespondNow)
async def on_chat_message_with_cancellation(
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
) -> ChatMessage | None:
print(f"{self.name} received message from {message.sender}: {message.body}")
if message.save_message_only:
return ChatMessage(body="OK", sender=self.name)
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
) -> TextMessage:
# Generate a random response.
response_body = random.choice(
[
@@ -36,4 +33,4 @@ class RandomResponseAgent(BaseChatAgent):
"See you!",
]
)
return ChatMessage(body=response_body, sender=self.name)
return TextMessage(content=response_body, source=self.name)


@@ -1,10 +1,18 @@
from typing import List, Sequence
from typing import Any, List, Protocol, Sequence
from agnext.chat.types import Reset, RespondNow
from ...agent_components.type_routed_agent import message_handler
from ...core.agent_runtime import AgentRuntime
from ...core.cancellation_token import CancellationToken
from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
class Output(Protocol):
def on_message_received(self, message: Any) -> None: ...
def get_output(self) -> Any: ...
def reset(self) -> None: ...
class GroupChat(BaseChatAgent):
@@ -15,28 +23,37 @@ class GroupChat(BaseChatAgent):
runtime: AgentRuntime,
agents: Sequence[BaseChatAgent],
num_rounds: int,
output: Output,
) -> None:
super().__init__(name, description, runtime)
self._agents = agents
self._num_rounds = num_rounds
self._history: List[ChatMessage] = []
self._history: List[Any] = []
self._output = output
@message_handler(ChatMessage)
async def on_chat_message(
self,
message: ChatMessage,
require_response: bool,
cancellation_token: CancellationToken,
) -> ChatMessage | None:
if message.reset:
@property
def subscriptions(self) -> Sequence[type]:
agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
async def on_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None:
if isinstance(message, Reset):
# Reset the history.
self._history = []
if message.save_message_only:
# TODO: what should we do with save_message_only messages for this pattern?
return ChatMessage(body="OK", sender=self.name)
# TODO: reset sub-agents?
if isinstance(message, RespondNow):
# TODO reset...
return self._output.get_output()
# TODO: should we do nothing here?
# Perhaps it should be saved into the message history?
if not require_response:
return None
self._history.append(message)
previous_speaker: BaseChatAgent | None = None
round = 0
while round < self._num_rounds:
@@ -44,41 +61,31 @@ class GroupChat(BaseChatAgent):
# Select speaker (round-robin for now).
speaker = self._agents[round % len(self._agents)]
# Send the last message to non-speaking agents.
for agent in [agent for agent in self._agents if agent is not previous_speaker and agent is not speaker]:
# Send the last message to all agents.
for agent in [agent for agent in self._agents]:
# TODO gather and await
_ = await self._send_message(
ChatMessage(
body=self._history[-1].body,
sender=self._history[-1].sender,
save_message_only=True,
),
self._history[-1],
agent,
require_response=False,
cancellation_token=cancellation_token,
)
# Send the last message to the speaking agent and ask to speak.
if previous_speaker is not speaker:
response = await self._send_message(
ChatMessage(body=self._history[-1].body, sender=self._history[-1].sender),
speaker,
)
else:
# The same speaker is speaking again.
# TODO: should support a separate message type for request to speak only.
response = await self._send_message(
ChatMessage(body="", sender=self.name),
speaker,
)
response = await self._send_message(
RespondNow(),
speaker,
require_response=True,
cancellation_token=cancellation_token,
)
if response is not None:
# 4. Append the response to the history.
self._history.append(response)
# 5. Update the previous speaker.
previous_speaker = speaker
self._output.on_message_received(response)
# 6. Increment the round.
round += 1
# Construct the final response.
response_body = "\n".join([f"{message.sender}: {message.body}" for message in self._history])
return ChatMessage(body=response_body, sender=self.name)
output = self._output.get_output()
self._output.reset()
return output
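
GroupChat now accumulates speaker replies through the Output protocol defined at the top of this file: on_message_received is called for every non-None response, get_output produces the final result after the rounds finish, and reset clears it. ConcatOutput in the example script concatenates text; as an illustration, a hypothetical alternative that keeps only the latest reply could look like this (LastMessageOutput is not part of the commit):

from typing import Any

class LastMessageOutput:
    # Structurally satisfies Output: GroupChat only calls these three methods.
    def __init__(self) -> None:
        self._last: Any = None

    def on_message_received(self, message: Any) -> None:
        self._last = message

    def get_output(self) -> Any:
        return self._last

    def reset(self) -> None:
        self._last = None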


@@ -2,7 +2,7 @@ import json
from typing import Any, List, Sequence, Tuple
from ...agent_components.model_client import ModelClient
from ...agent_components.type_routed_agent import message_handler
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
from ...core.agent_runtime import AgentRuntime
from ...core.cancellation_token import CancellationToken
@@ -10,7 +10,7 @@ from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
class Orchestrator(BaseChatAgent):
class Orchestrator(BaseChatAgent, TypeRoutedAgent):
def __init__(
self,
name: str,


@@ -40,3 +40,9 @@ class FunctionExecutionResultMessage(BaseMessage):
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
class RespondNow: ...
class Reset: ...
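
RespondNow and Reset are deliberately empty marker types: a pattern broadcasts conversation messages with require_response=False, sends RespondNow when it wants an agent to produce a reply, and sends Reset to clear agent state, as GroupChat does above. A rough sketch of another pattern built the same way (AskEachAgent and its behavior are illustrative assumptions, not part of this commit):

from typing import Any, Sequence

from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken

class AskEachAgent(BaseChatAgent):
    # Forwards each incoming message to every agent, then polls each one with RespondNow.
    def __init__(self, name: str, description: str, runtime: AgentRuntime, agents: Sequence[BaseChatAgent]) -> None:
        super().__init__(name, description, runtime)
        self._agents = agents

    @property
    def subscriptions(self) -> Sequence[type]:
        return [Reset, RespondNow]

    async def on_message(
        self, message: Any, require_response: bool, cancellation_token: CancellationToken
    ) -> Any | None:
        if isinstance(message, Reset):
            # Fan the reset out to the sub-agents as well.
            for agent in self._agents:
                await self._send_message(Reset(), agent, require_response=False, cancellation_token=cancellation_token)
            return None
        replies = []
        for agent in self._agents:
            # Share the message without requesting a reply, then ask for one explicitly.
            await self._send_message(message, agent, require_response=False, cancellation_token=cancellation_token)
            reply = await self._send_message(RespondNow(), agent, require_response=True, cancellation_token=cancellation_token)
            if reply is not None:
                replies.append(reply)
        return replies if require_response else None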