mirror of https://github.com/microsoft/autogen.git
Remove require_response, rename broadcast to publish, remove publish responses (#25)
* rename broadcast to publish * remove require response, remove responses from publishing
This commit is contained in:
parent
b6dd861166
commit
cb55e00819
|
@ -19,10 +19,7 @@ class Inner(TypeRoutedAgent):
|
|||
super().__init__(name, router)
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(
|
||||
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> MessageType:
|
||||
assert require_response
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
return MessageType(body=f"Inner: {message.body}", sender=self.name)
|
||||
|
||||
|
||||
|
@ -32,11 +29,8 @@ class Outer(TypeRoutedAgent):
|
|||
self._inner = inner
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(
|
||||
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> MessageType:
|
||||
assert require_response
|
||||
inner_response = self._send_message(message, self._inner, require_response=True)
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
inner_response = self._send_message(message, self._inner)
|
||||
inner_message = await inner_response
|
||||
assert isinstance(inner_message, MessageType)
|
||||
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)
|
||||
|
|
|
@ -34,7 +34,7 @@ select = ["E", "F", "W", "B", "Q", "I"]
|
|||
ignore = ["F401", "E501"]
|
||||
|
||||
[tool.mypy]
|
||||
files = ["src", "examples"]
|
||||
files = ["src", "examples", "tests"]
|
||||
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
|
@ -53,7 +53,7 @@ disallow_untyped_decorators = true
|
|||
disallow_any_unimported = true
|
||||
|
||||
[tool.pyright]
|
||||
include = ["src", "examples"]
|
||||
include = ["src", "examples", "tests"]
|
||||
typeCheckingMode = "strict"
|
||||
reportUnnecessaryIsInstance = false
|
||||
reportMissingTypeStubs = false
|
||||
|
|
|
@ -16,12 +16,12 @@ ProducesT = TypeVar("ProducesT", covariant=True)
|
|||
def message_handler(
|
||||
*target_types: Type[ReceivesT],
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
|
||||
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
|
||||
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
]:
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
|
||||
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
|
||||
# Convert target_types to list and stash
|
||||
func._target_types = list(target_types) # type: ignore
|
||||
return func
|
||||
|
@ -34,7 +34,7 @@ class TypeRoutedAgent(BaseAgent):
|
|||
super().__init__(name, router)
|
||||
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
||||
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
||||
|
||||
router.add_agent(self)
|
||||
|
||||
|
@ -49,17 +49,13 @@ class TypeRoutedAgent(BaseAgent):
|
|||
def subscriptions(self) -> Sequence[Type[Any]]:
|
||||
return list(self._handlers.keys())
|
||||
|
||||
async def on_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> Any | None:
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
handler = self._handlers.get(key_type) # type: ignore
|
||||
if handler is not None:
|
||||
return await handler(message, require_response, cancellation_token)
|
||||
return await handler(message, cancellation_token)
|
||||
else:
|
||||
return await self.on_unhandled_message(message, require_response, cancellation_token)
|
||||
return await self.on_unhandled_message(message, cancellation_token)
|
||||
|
||||
async def on_unhandled_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> NoReturn:
|
||||
async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
|
||||
raise CantHandleException(f"Unhandled message: {message}")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from asyncio import Future
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Dict, List, Sequence, Set, cast
|
||||
from typing import Any, Awaitable, Dict, List, Set
|
||||
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
from agnext.core.exceptions import MessageDroppedException
|
||||
|
@ -12,15 +12,13 @@ from ..core.agent_runtime import AgentRuntime
|
|||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BroadcastMessageEnvelope:
|
||||
"""A message envelope for broadcasting messages to all agents that can handle
|
||||
class PublishMessageEnvelope:
|
||||
"""A message envelope for publishing messages to all agents that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
future: Future[Sequence[Any] | None]
|
||||
cancellation_token: CancellationToken
|
||||
sender: Agent | None
|
||||
require_response: bool
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
|
@ -31,9 +29,8 @@ class SendMessageEnvelope:
|
|||
message: Any
|
||||
sender: Agent | None
|
||||
recipient: Agent
|
||||
future: Future[Any | None]
|
||||
future: Future[Any]
|
||||
cancellation_token: CancellationToken
|
||||
require_response: bool
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
|
@ -46,20 +43,9 @@ class ResponseMessageEnvelope:
|
|||
recipient: Agent | None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BroadcastResponseMessageEnvelope:
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: Sequence[Any]
|
||||
future: Future[Sequence[Any]]
|
||||
recipient: Agent | None
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[
|
||||
BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope
|
||||
] = []
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
self._per_type_subscribers: Dict[type, List[Agent]] = {}
|
||||
self._agents: Set[Agent] = set()
|
||||
self._before_send = before_send
|
||||
|
@ -77,7 +63,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
message: Any,
|
||||
recipient: Agent,
|
||||
*,
|
||||
require_response: bool = True,
|
||||
sender: Agent | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any | None]:
|
||||
|
@ -95,36 +80,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
future=future,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
require_response=require_response,
|
||||
)
|
||||
)
|
||||
|
||||
return future
|
||||
|
||||
# send message, require_response=False -> returns after delivery, gives None
|
||||
# send message, require_response=True -> returns after handling, gives Response
|
||||
def broadcast_message(
|
||||
def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
require_response: bool = True,
|
||||
sender: Agent | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Sequence[Any] | None]:
|
||||
) -> Future[None]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._message_queue.append(
|
||||
BroadcastMessageEnvelope(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
future=future,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
require_response=require_response,
|
||||
)
|
||||
)
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
|
@ -134,64 +114,41 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
try:
|
||||
response = await recipient.on_message(
|
||||
message_envelope.message,
|
||||
require_response=message_envelope.require_response,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
||||
if not message_envelope.require_response and response is not None:
|
||||
raise Exception("Recipient returned a response for a message that did not request a response")
|
||||
|
||||
if message_envelope.require_response and response is None:
|
||||
raise Exception("Recipient did not return a response for a message that requested a response")
|
||||
|
||||
if message_envelope.require_response:
|
||||
self._message_queue.append(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
sender=message_envelope.recipient,
|
||||
recipient=message_envelope.sender,
|
||||
)
|
||||
self._message_queue.append(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
sender=message_envelope.recipient,
|
||||
recipient=message_envelope.sender,
|
||||
)
|
||||
else:
|
||||
message_envelope.future.set_result(None)
|
||||
)
|
||||
|
||||
async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None:
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
|
||||
future = agent.on_message(
|
||||
message_envelope.message,
|
||||
require_response=message_envelope.require_response,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
responses.append(future)
|
||||
|
||||
try:
|
||||
all_responses = await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
_all_responses = await asyncio.gather(*responses)
|
||||
except BaseException:
|
||||
# TODO log error
|
||||
return
|
||||
|
||||
if message_envelope.require_response:
|
||||
self._message_queue.append(
|
||||
BroadcastResponseMessageEnvelope(
|
||||
message=all_responses,
|
||||
future=cast(Future[Sequence[Any]], message_envelope.future),
|
||||
recipient=message_envelope.sender,
|
||||
)
|
||||
)
|
||||
else:
|
||||
message_envelope.future.set_result(None)
|
||||
# TODO if responses are given for a publish
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None:
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
async def process_next(self) -> None:
|
||||
if len(self._message_queue) == 0:
|
||||
# Yield control to the event loop to allow other tasks to run
|
||||
|
@ -211,20 +168,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
message_envelope.message = temp_message
|
||||
|
||||
asyncio.create_task(self._process_send(message_envelope))
|
||||
case BroadcastMessageEnvelope(
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
future=future,
|
||||
):
|
||||
if self._before_send is not None:
|
||||
temp_message = await self._before_send.on_broadcast(message, sender=sender)
|
||||
temp_message = await self._before_send.on_publish(message, sender=sender)
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
# TODO log message dropped
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
|
||||
asyncio.create_task(self._process_broadcast(message_envelope))
|
||||
asyncio.create_task(self._process_publish(message_envelope))
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
||||
|
@ -236,16 +192,5 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
|
||||
asyncio.create_task(self._process_response(message_envelope))
|
||||
|
||||
case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
|
||||
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = list(temp_message_list) # type: ignore
|
||||
|
||||
asyncio.create_task(self._process_broadcast_response(message_envelope))
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
|
|
@ -26,7 +26,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
|||
# TODO: use require_response
|
||||
@message_handler(TextMessage)
|
||||
async def on_chat_message_with_cancellation(
|
||||
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
|
||||
self, message: TextMessage, cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
print("---------------")
|
||||
print(f"{self.name} received message from {message.source}: {message.content}")
|
||||
|
@ -41,22 +41,13 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
|||
)
|
||||
self._current_session_window_length += 1
|
||||
|
||||
if require_response:
|
||||
# TODO ?
|
||||
...
|
||||
|
||||
@message_handler(Reset)
|
||||
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
# Reset the current session window.
|
||||
self._current_session_window_length = 0
|
||||
|
||||
@message_handler(RespondNow)
|
||||
async def on_respond_now(
|
||||
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> TextMessage | None:
|
||||
if not require_response:
|
||||
return None
|
||||
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
|
||||
# Create a run and wait until it finishes.
|
||||
run = await self._client.beta.threads.runs.create_and_poll(
|
||||
thread_id=self._thread_id,
|
||||
|
|
|
@ -11,7 +11,7 @@ class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
|
|||
# TODO: use require_response
|
||||
@message_handler(RespondNow)
|
||||
async def on_chat_message_with_cancellation(
|
||||
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
|
||||
self, message: RespondNow, cancellation_token: CancellationToken
|
||||
) -> TextMessage:
|
||||
# Generate a random response.
|
||||
response_body = random.choice(
|
||||
|
|
|
@ -36,9 +36,7 @@ class GroupChat(BaseChatAgent):
|
|||
agent_sublists = [agent.subscriptions for agent in self._agents]
|
||||
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
||||
|
||||
async def on_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> Any | None:
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||
if isinstance(message, Reset):
|
||||
# Reset the history.
|
||||
self._history = []
|
||||
|
@ -48,10 +46,8 @@ class GroupChat(BaseChatAgent):
|
|||
# TODO reset...
|
||||
return self._output.get_output()
|
||||
|
||||
# TODO: should we do nothing here?
|
||||
# Perhaps it should be saved into the message history?
|
||||
if not require_response:
|
||||
return None
|
||||
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
|
||||
# Should this class disallow it?
|
||||
|
||||
self._history.append(message)
|
||||
round = 0
|
||||
|
@ -67,14 +63,13 @@ class GroupChat(BaseChatAgent):
|
|||
_ = await self._send_message(
|
||||
self._history[-1],
|
||||
agent,
|
||||
require_response=False,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# TODO handle if response is not None
|
||||
|
||||
response = await self._send_message(
|
||||
RespondNow(),
|
||||
speaker,
|
||||
require_response=True,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
|
@ -88,4 +83,5 @@ class GroupChat(BaseChatAgent):
|
|||
|
||||
output = self._output.get_output()
|
||||
self._output.reset()
|
||||
self._history.clear()
|
||||
return output
|
||||
|
|
|
@ -34,7 +34,6 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
|
|||
async def on_chat_message(
|
||||
self,
|
||||
message: ChatMessage,
|
||||
require_response: bool,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> ChatMessage | None:
|
||||
# A task is received.
|
||||
|
|
|
@ -11,6 +11,4 @@ class Agent(Protocol):
|
|||
@property
|
||||
def subscriptions(self) -> Sequence[type]: ...
|
||||
|
||||
async def on_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> Any | None: ...
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from asyncio import Future
|
||||
from typing import Any, Protocol, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from agnext.core.agent import Agent
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
@ -16,17 +16,15 @@ class AgentRuntime(Protocol):
|
|||
message: Any,
|
||||
recipient: Agent,
|
||||
*,
|
||||
require_response: bool = True,
|
||||
sender: Agent | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any | None]: ...
|
||||
) -> Future[Any]: ...
|
||||
|
||||
# Returns the response of all handling agents
|
||||
def broadcast_message(
|
||||
# No responses from publishing
|
||||
def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
require_response: bool = True,
|
||||
sender: Agent | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Sequence[Any] | None]: ...
|
||||
) -> Future[None]: ...
|
||||
|
|
|
@ -29,9 +29,7 @@ class BaseAgent(ABC, Agent):
|
|||
return []
|
||||
|
||||
@abstractmethod
|
||||
async def on_message(
|
||||
self, message: Any, require_response: bool, cancellation_token: CancellationToken
|
||||
) -> Any | None: ...
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
|
||||
|
||||
# Returns the response of the message
|
||||
def _send_message(
|
||||
|
@ -39,9 +37,8 @@ class BaseAgent(ABC, Agent):
|
|||
message: Any,
|
||||
recipient: Agent,
|
||||
*,
|
||||
require_response: bool = True,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any | None]:
|
||||
) -> Future[Any]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
|
@ -49,23 +46,18 @@ class BaseAgent(ABC, Agent):
|
|||
message,
|
||||
sender=self,
|
||||
recipient=recipient,
|
||||
require_response=require_response,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
# Returns the response of all handling agents
|
||||
def _broadcast_message(
|
||||
def _publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
require_response: bool = True,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Sequence[Any] | None]:
|
||||
) -> Future[None]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
future = self._router.broadcast_message(
|
||||
message, sender=self, require_response=require_response, cancellation_token=cancellation_token
|
||||
)
|
||||
future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token)
|
||||
return future
|
||||
|
|
|
@ -12,9 +12,9 @@ InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
|
|||
|
||||
class InterventionHandler(Protocol):
|
||||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ...
|
||||
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
|
||||
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ...
|
||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ...
|
||||
async def on_broadcast_response(
|
||||
async def on_publish_response(
|
||||
self, message: Sequence[Any], *, recipient: Agent | None
|
||||
) -> Sequence[Any] | type[DropMessage]: ...
|
||||
|
||||
|
@ -23,13 +23,13 @@ class DefaultInterventionHandler(InterventionHandler):
|
|||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
|
||||
async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_broadcast_response(
|
||||
async def on_publish_response(
|
||||
self, message: Sequence[Any], *, recipient: Agent | None
|
||||
) -> Sequence[Any] | type[DropMessage]:
|
||||
return message
|
||||
|
|
|
@ -23,7 +23,7 @@ class LongRunningAgent(TypeRoutedAgent):
|
|||
self.cancelled = False
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
sleep = asyncio.ensure_future(asyncio.sleep(100))
|
||||
cancellation_token.link_future(sleep)
|
||||
|
@ -42,10 +42,9 @@ class NestingLongRunningAgent(TypeRoutedAgent):
|
|||
self._nested_agent = nested_agent
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
|
||||
assert require_response == True
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token)
|
||||
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
|
||||
try:
|
||||
val = await response
|
||||
assert isinstance(val, MessageType)
|
||||
|
|
|
@ -20,7 +20,7 @@ class LoopbackAgent(TypeRoutedAgent):
|
|||
|
||||
|
||||
@message_handler(MessageType)
|
||||
async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType:
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.num_calls += 1
|
||||
return message
|
||||
|
||||
|
@ -28,7 +28,7 @@ class LoopbackAgent(TypeRoutedAgent):
|
|||
async def test_intervention_count_messages() -> None:
|
||||
|
||||
class DebugInterventionHandler(DefaultInterventionHandler):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.num_messages = 0
|
||||
|
||||
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType:
|
||||
|
|
Loading…
Reference in New Issue