mirror of https://github.com/microsoft/autogen.git
ensure agent name is unique, add some docs (#26)
This commit is contained in:
parent
cb55e00819
commit
f8f7418ebf
|
@ -25,6 +25,7 @@ apidoc_output_dir = 'reference'
|
|||
apidoc_template_dir = '_apidoc_templates'
|
||||
apidoc_separate_modules = True
|
||||
apidoc_extra_args = ["--no-toc"]
|
||||
napoleon_custom_sections = [('Returns', 'params_style')]
|
||||
|
||||
templates_path = []
|
||||
exclude_patterns = ["reference/agnext.rst"]
|
||||
|
|
|
@ -31,13 +31,9 @@ def message_handler(
|
|||
|
||||
class TypeRoutedAgent(BaseAgent):
|
||||
def __init__(self, name: str, router: AgentRuntime) -> None:
|
||||
super().__init__(name, router)
|
||||
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
||||
|
||||
router.add_agent(self)
|
||||
|
||||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
handler = getattr(self, attr)
|
||||
|
@ -45,6 +41,8 @@ class TypeRoutedAgent(BaseAgent):
|
|||
for target_type in handler._target_types:
|
||||
self._handlers[target_type] = handler
|
||||
|
||||
super().__init__(name, router)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[Type[Any]]:
|
||||
return list(self._handlers.keys())
|
||||
|
|
|
@ -51,6 +51,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
self._before_send = before_send
|
||||
|
||||
def add_agent(self, agent: Agent) -> None:
|
||||
agent_names = {agent.name for agent in self._agents}
|
||||
if agent.name in agent_names:
|
||||
raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.")
|
||||
|
||||
for message_type in agent.subscriptions:
|
||||
if message_type not in self._per_type_subscribers:
|
||||
self._per_type_subscribers[message_type] = []
|
||||
|
|
|
@ -6,9 +6,30 @@ from agnext.core.cancellation_token import CancellationToken
|
|||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
def name(self) -> str:
|
||||
"""Name of the agent.
|
||||
|
||||
Note:
|
||||
This name should be unique within the runtime.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]: ...
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
"""Types of messages that this agent can receive."""
|
||||
...
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
|
||||
|
||||
Args:
|
||||
message (Any): Received message. Type is one of the types in `subscriptions`.
|
||||
cancellation_token (CancellationToken): Cancellation token for the message.
|
||||
|
||||
Returns:
|
||||
Any: Response to the message. Can be None.
|
||||
|
||||
Notes:
|
||||
If there was a cancellation, this function should raise a `CancelledError`.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -8,7 +8,16 @@ from agnext.core.cancellation_token import CancellationToken
|
|||
|
||||
|
||||
class AgentRuntime(Protocol):
|
||||
def add_agent(self, agent: Agent) -> None: ...
|
||||
def add_agent(self, agent: Agent) -> None:
|
||||
"""Add an agent to the runtime.
|
||||
|
||||
Args:
|
||||
agent (Agent): Agent to add to the runtime.
|
||||
|
||||
Note:
|
||||
The name of the agent should be unique within the runtime.
|
||||
"""
|
||||
...
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(
|
||||
|
|
|
@ -18,6 +18,7 @@ class BaseAgent(ABC, Agent):
|
|||
def __init__(self, name: str, router: AgentRuntime) -> None:
|
||||
self._name = name
|
||||
self._router = router
|
||||
router.add_agent(self)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -29,7 +30,7 @@ class BaseAgent(ABC, Agent):
|
|||
return []
|
||||
|
||||
@abstractmethod
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ...
|
||||
|
||||
# Returns the response of the message
|
||||
def _send_message(
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Any, Sequence
|
||||
import pytest
|
||||
|
||||
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from agnext.core.agent_runtime import AgentRuntime
|
||||
from agnext.core.base_agent import BaseAgent
|
||||
from agnext.core.cancellation_token import CancellationToken
|
||||
|
||||
class NoopAgent(BaseAgent):
|
||||
def __init__(self, name: str, router: AgentRuntime) -> None:
|
||||
super().__init__(name, router)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
return []
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
router = SingleThreadedAgentRuntime()
|
||||
|
||||
_agent1 = NoopAgent("name1", router)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent1_again = NoopAgent("name1", router)
|
||||
|
||||
_agent3 = NoopAgent("name3", router)
|
||||
|
||||
|
Loading…
Reference in New Issue