Remove require_response, rename broadcast to publish, remove publish responses (#25)

* rename broadcast to publish

* remove require response, remove responses from publishing
This commit is contained in:
Jack Gerrits 2024-05-26 08:45:02 -04:00 committed by GitHub
parent b6dd861166
commit cb55e00819
14 changed files with 69 additions and 161 deletions

View File

@ -19,10 +19,7 @@ class Inner(TypeRoutedAgent):
super().__init__(name, router)
@message_handler(MessageType)
async def on_new_message(
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
) -> MessageType:
assert require_response
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
return MessageType(body=f"Inner: {message.body}", sender=self.name)
@ -32,11 +29,8 @@ class Outer(TypeRoutedAgent):
self._inner = inner
@message_handler(MessageType)
async def on_new_message(
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
) -> MessageType:
assert require_response
inner_response = self._send_message(message, self._inner, require_response=True)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
inner_response = self._send_message(message, self._inner)
inner_message = await inner_response
assert isinstance(inner_message, MessageType)
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)

View File

@ -34,7 +34,7 @@ select = ["E", "F", "W", "B", "Q", "I"]
ignore = ["F401", "E501"]
[tool.mypy]
files = ["src", "examples"]
files = ["src", "examples", "tests"]
strict = true
python_version = "3.10"
@ -53,7 +53,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.pyright]
include = ["src", "examples"]
include = ["src", "examples", "tests"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false

View File

@ -16,12 +16,12 @@ ProducesT = TypeVar("ProducesT", covariant=True)
def message_handler(
*target_types: Type[ReceivesT],
) -> Callable[
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
]:
def decorator(
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
# Convert target_types to list and stash
func._target_types = list(target_types) # type: ignore
return func
@ -34,7 +34,7 @@ class TypeRoutedAgent(BaseAgent):
super().__init__(name, router)
# Self is already bound to the handlers
self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
router.add_agent(self)
@ -49,17 +49,13 @@ class TypeRoutedAgent(BaseAgent):
def subscriptions(self) -> Sequence[Type[Any]]:
return list(self._handlers.keys())
async def on_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None:
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
key_type: Type[Any] = type(message) # type: ignore
handler = self._handlers.get(key_type) # type: ignore
if handler is not None:
return await handler(message, require_response, cancellation_token)
return await handler(message, cancellation_token)
else:
return await self.on_unhandled_message(message, require_response, cancellation_token)
return await self.on_unhandled_message(message, cancellation_token)
async def on_unhandled_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> NoReturn:
async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
raise CantHandleException(f"Unhandled message: {message}")

View File

@ -1,7 +1,7 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Sequence, Set, cast
from typing import Any, Awaitable, Dict, List, Set
from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import MessageDroppedException
@ -12,15 +12,13 @@ from ..core.agent_runtime import AgentRuntime
@dataclass(kw_only=True)
class BroadcastMessageEnvelope:
"""A message envelope for broadcasting messages to all agents that can handle
class PublishMessageEnvelope:
"""A message envelope for publishing messages to all agents that can handle
the message of the type T."""
message: Any
future: Future[Sequence[Any] | None]
cancellation_token: CancellationToken
sender: Agent | None
require_response: bool
@dataclass(kw_only=True)
@ -31,9 +29,8 @@ class SendMessageEnvelope:
message: Any
sender: Agent | None
recipient: Agent
future: Future[Any | None]
future: Future[Any]
cancellation_token: CancellationToken
require_response: bool
@dataclass(kw_only=True)
@ -46,20 +43,9 @@ class ResponseMessageEnvelope:
recipient: Agent | None
@dataclass(kw_only=True)
class BroadcastResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""
message: Sequence[Any]
future: Future[Sequence[Any]]
recipient: Agent | None
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
self._message_queue: List[
BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope
] = []
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
self._per_type_subscribers: Dict[type, List[Agent]] = {}
self._agents: Set[Agent] = set()
self._before_send = before_send
@ -77,7 +63,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message: Any,
recipient: Agent,
*,
require_response: bool = True,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
@ -95,36 +80,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future=future,
cancellation_token=cancellation_token,
sender=sender,
require_response=require_response,
)
)
return future
# send message, require_response=False -> returns after delivery, gives None
# send message, require_response=True -> returns after handling, gives Response
def broadcast_message(
def publish_message(
self,
message: Any,
*,
require_response: bool = True,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]:
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = asyncio.get_event_loop().create_future()
self._message_queue.append(
BroadcastMessageEnvelope(
PublishMessageEnvelope(
message=message,
future=future,
cancellation_token=cancellation_token,
sender=sender,
require_response=require_response,
)
)
future = asyncio.get_event_loop().create_future()
future.set_result(None)
return future
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
@ -134,64 +114,41 @@ class SingleThreadedAgentRuntime(AgentRuntime):
try:
response = await recipient.on_message(
message_envelope.message,
require_response=message_envelope.require_response,
cancellation_token=message_envelope.cancellation_token,
)
except BaseException as e:
message_envelope.future.set_exception(e)
return
if not message_envelope.require_response and response is not None:
raise Exception("Recipient returned a response for a message that did not request a response")
if message_envelope.require_response and response is None:
raise Exception("Recipient did not return a response for a message that requested a response")
if message_envelope.require_response:
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
else:
message_envelope.future.set_result(None)
)
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None:
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
responses: List[Awaitable[Any]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
future = agent.on_message(
message_envelope.message,
require_response=message_envelope.require_response,
cancellation_token=message_envelope.cancellation_token,
)
responses.append(future)
try:
all_responses = await asyncio.gather(*responses)
except BaseException as e:
message_envelope.future.set_exception(e)
_all_responses = await asyncio.gather(*responses)
except BaseException:
# TODO log error
return
if message_envelope.require_response:
self._message_queue.append(
BroadcastResponseMessageEnvelope(
message=all_responses,
future=cast(Future[Sequence[Any]], message_envelope.future),
recipient=message_envelope.sender,
)
)
else:
message_envelope.future.set_result(None)
# TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
message_envelope.future.set_result(message_envelope.message)
async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None:
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:
if len(self._message_queue) == 0:
# Yield control to the event loop to allow other tasks to run
@ -211,20 +168,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
message_envelope.message = temp_message
asyncio.create_task(self._process_send(message_envelope))
case BroadcastMessageEnvelope(
case PublishMessageEnvelope(
message=message,
sender=sender,
future=future,
):
if self._before_send is not None:
temp_message = await self._before_send.on_broadcast(message, sender=sender)
temp_message = await self._before_send.on_publish(message, sender=sender)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
# TODO log message dropped
return
message_envelope.message = temp_message
asyncio.create_task(self._process_broadcast(message_envelope))
asyncio.create_task(self._process_publish(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
@ -236,16 +192,5 @@ class SingleThreadedAgentRuntime(AgentRuntime):
asyncio.create_task(self._process_response(message_envelope))
case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
if self._before_send is not None:
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = list(temp_message_list) # type: ignore
asyncio.create_task(self._process_broadcast_response(message_envelope))
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)

View File

@ -26,7 +26,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: use require_response
@message_handler(TextMessage)
async def on_chat_message_with_cancellation(
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
self, message: TextMessage, cancellation_token: CancellationToken
) -> None:
print("---------------")
print(f"{self.name} received message from {message.source}: {message.content}")
@ -41,22 +41,13 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
)
self._current_session_window_length += 1
if require_response:
# TODO ?
...
@message_handler(Reset)
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the current session window.
self._current_session_window_length = 0
@message_handler(RespondNow)
async def on_respond_now(
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
) -> TextMessage | None:
if not require_response:
return None
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
# Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll(
thread_id=self._thread_id,

View File

@ -11,7 +11,7 @@ class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: use require_response
@message_handler(RespondNow)
async def on_chat_message_with_cancellation(
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage:
# Generate a random response.
response_body = random.choice(

View File

@ -36,9 +36,7 @@ class GroupChat(BaseChatAgent):
agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
async def on_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None:
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
if isinstance(message, Reset):
# Reset the history.
self._history = []
@ -48,10 +46,8 @@ class GroupChat(BaseChatAgent):
# TODO reset...
return self._output.get_output()
# TODO: should we do nothing here?
# Perhaps it should be saved into the message history?
if not require_response:
return None
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
# Should this class disallow it?
self._history.append(message)
round = 0
@ -67,14 +63,13 @@ class GroupChat(BaseChatAgent):
_ = await self._send_message(
self._history[-1],
agent,
require_response=False,
cancellation_token=cancellation_token,
)
# TODO handle if response is not None
response = await self._send_message(
RespondNow(),
speaker,
require_response=True,
cancellation_token=cancellation_token,
)
@ -88,4 +83,5 @@ class GroupChat(BaseChatAgent):
output = self._output.get_output()
self._output.reset()
self._history.clear()
return output

View File

@ -34,7 +34,6 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
async def on_chat_message(
self,
message: ChatMessage,
require_response: bool,
cancellation_token: CancellationToken,
) -> ChatMessage | None:
# A task is received.

View File

@ -11,6 +11,4 @@ class Agent(Protocol):
@property
def subscriptions(self) -> Sequence[type]: ...
async def on_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None: ...
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...

View File

@ -1,5 +1,5 @@
from asyncio import Future
from typing import Any, Protocol, Sequence
from typing import Any, Protocol
from agnext.core.agent import Agent
from agnext.core.cancellation_token import CancellationToken
@ -16,17 +16,15 @@ class AgentRuntime(Protocol):
message: Any,
recipient: Agent,
*,
require_response: bool = True,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]: ...
) -> Future[Any]: ...
# Returns the response of all handling agents
def broadcast_message(
# No responses from publishing
def publish_message(
self,
message: Any,
*,
require_response: bool = True,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]: ...
) -> Future[None]: ...

View File

@ -29,9 +29,7 @@ class BaseAgent(ABC, Agent):
return []
@abstractmethod
async def on_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None: ...
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
# Returns the response of the message
def _send_message(
@ -39,9 +37,8 @@ class BaseAgent(ABC, Agent):
message: Any,
recipient: Agent,
*,
require_response: bool = True,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
) -> Future[Any]:
if cancellation_token is None:
cancellation_token = CancellationToken()
@ -49,23 +46,18 @@ class BaseAgent(ABC, Agent):
message,
sender=self,
recipient=recipient,
require_response=require_response,
cancellation_token=cancellation_token,
)
cancellation_token.link_future(future)
return future
# Returns the response of all handling agents
def _broadcast_message(
def _publish_message(
self,
message: Any,
*,
require_response: bool = True,
cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]:
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.broadcast_message(
message, sender=self, require_response=require_response, cancellation_token=cancellation_token
)
future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token)
return future

View File

@ -12,9 +12,9 @@ InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
class InterventionHandler(Protocol):
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
async def on_broadcast_response(
async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]: ...
@ -23,13 +23,13 @@ class DefaultInterventionHandler(InterventionHandler):
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
return message
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
return message
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
return message
async def on_broadcast_response(
async def on_publish_response(
self, message: Sequence[Any], *, recipient: Agent | None
) -> Sequence[Any] | type[DropMessage]:
return message

View File

@ -23,7 +23,7 @@ class LongRunningAgent(TypeRoutedAgent):
self.cancelled = False
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
cancellation_token.link_future(sleep)
@ -42,10 +42,9 @@ class NestingLongRunningAgent(TypeRoutedAgent):
self._nested_agent = nested_agent
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
assert require_response == True
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token)
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
try:
val = await response
assert isinstance(val, MessageType)

View File

@ -20,7 +20,7 @@ class LoopbackAgent(TypeRoutedAgent):
@message_handler(MessageType)
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.num_calls += 1
return message
@ -28,7 +28,7 @@ class LoopbackAgent(TypeRoutedAgent):
async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self):
def __init__(self) -> None:
self.num_messages = 0
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: