ensure agent name is unique, add some docs (#26)

This commit is contained in:
Jack Gerrits 2024-05-27 16:33:28 -04:00 committed by GitHub
parent cb55e00819
commit f8f7418ebf
7 changed files with 75 additions and 9 deletions

View File

@ -25,6 +25,7 @@ apidoc_output_dir = 'reference'
apidoc_template_dir = '_apidoc_templates'
apidoc_separate_modules = True
apidoc_extra_args = ["--no-toc"]
napoleon_custom_sections = [('Returns', 'params_style')]
templates_path = []
exclude_patterns = ["reference/agnext.rst"]

View File

@ -31,13 +31,9 @@ def message_handler(
class TypeRoutedAgent(BaseAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
# Self is already bound to the handlers
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
router.add_agent(self)
for attr in dir(self):
if callable(getattr(self, attr, None)):
handler = getattr(self, attr)
@ -45,6 +41,8 @@ class TypeRoutedAgent(BaseAgent):
for target_type in handler._target_types:
self._handlers[target_type] = handler
super().__init__(name, router)
@property
def subscriptions(self) -> Sequence[Type[Any]]:
return list(self._handlers.keys())

View File

@ -51,6 +51,10 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._before_send = before_send
def add_agent(self, agent: Agent) -> None:
agent_names = {agent.name for agent in self._agents}
if agent.name in agent_names:
raise ValueError(f"Agent with name {agent.name} already exists. Agent names must be unique.")
for message_type in agent.subscriptions:
if message_type not in self._per_type_subscribers:
self._per_type_subscribers[message_type] = []

View File

@ -6,9 +6,30 @@ from agnext.core.cancellation_token import CancellationToken
@runtime_checkable
class Agent(Protocol):
@property
def name(self) -> str: ...
def name(self) -> str:
"""Name of the agent.
Note:
This name should be unique within the runtime.
"""
...
@property
def subscriptions(self) -> Sequence[type]: ...
def subscriptions(self) -> Sequence[type]:
"""Types of messages that this agent can receive."""
...
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
Args:
message (Any): Received message. Type is one of the types in `subscriptions`.
cancellation_token (CancellationToken): Cancellation token for the message.
Returns:
Any: Response to the message. Can be None.
Notes:
If there was a cancellation, this function should raise a `CancelledError`.
"""
...

View File

@ -8,7 +8,16 @@ from agnext.core.cancellation_token import CancellationToken
class AgentRuntime(Protocol):
def add_agent(self, agent: Agent) -> None: ...
def add_agent(self, agent: Agent) -> None:
"""Add an agent to the runtime.
Args:
agent (Agent): Agent to add to the runtime.
Note:
The name of the agent should be unique within the runtime.
"""
...
# Returns the response of the message
def send_message(

View File

@ -18,6 +18,7 @@ class BaseAgent(ABC, Agent):
def __init__(self, name: str, router: AgentRuntime) -> None:
self._name = name
self._router = router
router.add_agent(self)
@property
def name(self) -> str:
@ -29,7 +30,7 @@ class BaseAgent(ABC, Agent):
return []
@abstractmethod
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ...
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ...
# Returns the response of the message
def _send_message(

32
tests/test_runtime.py Normal file
View File

@ -0,0 +1,32 @@
from typing import Any, Sequence
import pytest
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.base_agent import BaseAgent
from agnext.core.cancellation_token import CancellationToken
class NoopAgent(BaseAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
@property
def subscriptions(self) -> Sequence[type]:
return []
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
raise NotImplementedError
@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
router = SingleThreadedAgentRuntime()
_agent1 = NoopAgent("name1", router)
with pytest.raises(ValueError):
_agent1_again = NoopAgent("name1", router)
_agent3 = NoopAgent("name3", router)