Implement default sub and topic (#398)

* Implement default sub and topic

* format

* update test
This commit is contained in:
Jack Gerrits 2024-08-23 16:01:57 -04:00 committed by GitHub
parent 8f082cecda
commit 4c964fa772
8 changed files with 175 additions and 14 deletions

View File

@ -20,6 +20,7 @@ from ..core import (
AgentType,
CancellationToken,
MessageContext,
MessageHandlerContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
@ -264,10 +265,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
)
response = await recipient_agent.on_message(
message_envelope.message,
ctx=message_context,
)
with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message(
message_envelope.message,
ctx=message_context,
)
except BaseException as e:
message_envelope.future.set_exception(e)
return
@ -313,10 +315,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
cancellation_token=message_envelope.cancellation_token,
)
agent = await self._get_agent(agent_id)
future = agent.on_message(
message_envelope.message,
ctx=message_context,
)
with MessageHandlerContext.populate_context(agent.id):
future = agent.on_message(
message_envelope.message,
ctx=message_context,
)
responses.append(future)
try:

View File

@ -39,6 +39,7 @@ from ..core import (
AgentType,
CancellationToken,
MessageContext,
MessageHandlerContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
@ -323,7 +324,8 @@ class WorkerAgentRuntime(AgentRuntime):
# Call the target agent.
try:
result = await target_agent.on_message(message, ctx=message_context)
with MessageHandlerContext.populate_context(target_agent.id):
result = await target_agent.on_message(message, ctx=message_context)
except BaseException as e:
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
@ -377,7 +379,8 @@ class WorkerAgentRuntime(AgentRuntime):
cancellation_token=CancellationToken(),
)
agent = await self._get_agent(agent_id)
future = agent.on_message(message, ctx=message_context)
with MessageHandlerContext.populate_context(agent.id):
future = agent.on_message(message, ctx=message_context)
responses.append(future)
# Wait for all responses.
try:

View File

@ -3,9 +3,20 @@ The :mod:`agnext.components` module provides building blocks for creating single
"""
from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._type_routed_agent import TypeRoutedAgent, message_handler
from ._type_subscription import TypeSubscription
from ._types import FunctionCall
__all__ = ["Image", "TypeRoutedAgent", "ClosureAgent", "message_handler", "FunctionCall", "TypeSubscription"]
__all__ = [
"Image",
"TypeRoutedAgent",
"ClosureAgent",
"message_handler",
"FunctionCall",
"TypeSubscription",
"DefaultSubscription",
"DefaultTopicId",
]

View File

@ -0,0 +1,32 @@
from agnext.core.exceptions import CantHandleException
from ..core import SubscriptionInstantiationContext
from ._type_subscription import TypeSubscription
class DefaultSubscription(TypeSubscription):
def __init__(self, topic_type: str = "default", agent_type: str | None = None):
"""The default subscription is designed to be a sensible default for applications that only need global scope for agents.
This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context.
Example:
.. code-block:: python
await runtime.register("MyAgent", agent_factory, lambda: [DefaultSubscription()])
Args:
topic_type (str, optional): The topic type to subscribe to. Defaults to "default".
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context.
"""
if agent_type is None:
try:
agent_type = SubscriptionInstantiationContext.agent_type().type
except RuntimeError as e:
raise CantHandleException(
"If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register"
) from e
super().__init__(topic_type, agent_type)

View File

@ -0,0 +1,21 @@
from ..core import MessageHandlerContext, TopicId
class DefaultTopicId(TopicId):
def __init__(self, type: str = "default", source: str | None = None) -> None:
"""DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId.
If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default".
Args:
type (str, optional): Topic type to publish message to. Defaults to "default".
source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None.
"""
if source is None:
try:
source = MessageHandlerContext.agent_id().key
# If we aren't in the context of a message handler, we use the default source
except RuntimeError:
source = "default"
super().__init__(type, source)

View File

@ -13,6 +13,7 @@ from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
from ._subscription import Subscription
from ._subscription_context import SubscriptionInstantiationContext
@ -37,4 +38,5 @@ __all__ = [
"Serialization",
"AgentType",
"SubscriptionInstantiationContext",
"MessageHandlerContext",
]

View File

@ -0,0 +1,30 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generator
from ._agent_id import AgentId
class MessageHandlerContext:
def __init__(self) -> None:
raise RuntimeError(
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
)
MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("MESSAGE_HANDLER_CONTEXT")
@classmethod
@contextmanager
def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]:
token = MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.set(ctx)
try:
yield
finally:
MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.reset(token)
@classmethod
def agent_id(cls) -> AgentId:
try:
return cls.MESSAGE_HANDLER_CONTEXT.get()
except LookupError as e:
raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e

View File

@ -1,11 +1,11 @@
import asyncio
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.components._type_subscription import TypeSubscription
from agnext.components import TypeSubscription, DefaultTopicId, DefaultSubscription
from agnext.core import AgentId, AgentInstantiationContext
from agnext.core import TopicId
from agnext.core._subscription import Subscription
from agnext.core._subscription_context import SubscriptionInstantiationContext
from agnext.core import Subscription
from agnext.core import SubscriptionInstantiationContext
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
@ -163,3 +163,62 @@ async def test_register_factory_direct_list() -> None:
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0
@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
await runtime.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 0
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 1