mirror of https://github.com/microsoft/autogen.git
Implement default sub and topic (#398)
* Implement default sub and topic * format * update test
This commit is contained in:
parent
8f082cecda
commit
4c964fa772
|
@ -20,6 +20,7 @@ from ..core import (
|
|||
AgentType,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
MessageHandlerContext,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
|
@ -264,10 +265,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
is_rpc=True,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
response = await recipient_agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
with MessageHandlerContext.populate_context(recipient_agent.id):
|
||||
response = await recipient_agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
@ -313,10 +315,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
future = agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
with MessageHandlerContext.populate_context(agent.id):
|
||||
future = agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
responses.append(future)
|
||||
|
||||
try:
|
||||
|
|
|
@ -39,6 +39,7 @@ from ..core import (
|
|||
AgentType,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
MessageHandlerContext,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
|
@ -323,7 +324,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
|
||||
# Call the target agent.
|
||||
try:
|
||||
result = await target_agent.on_message(message, ctx=message_context)
|
||||
with MessageHandlerContext.populate_context(target_agent.id):
|
||||
result = await target_agent.on_message(message, ctx=message_context)
|
||||
except BaseException as e:
|
||||
response_message = agent_worker_pb2.Message(
|
||||
response=agent_worker_pb2.RpcResponse(
|
||||
|
@ -377,7 +379,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
future = agent.on_message(message, ctx=message_context)
|
||||
with MessageHandlerContext.populate_context(agent.id):
|
||||
future = agent.on_message(message, ctx=message_context)
|
||||
responses.append(future)
|
||||
# Wait for all responses.
|
||||
try:
|
||||
|
|
|
@ -3,9 +3,20 @@ The :mod:`agnext.components` module provides building blocks for creating single
|
|||
"""
|
||||
|
||||
from ._closure_agent import ClosureAgent
|
||||
from ._default_subscription import DefaultSubscription
|
||||
from ._default_topic import DefaultTopicId
|
||||
from ._image import Image
|
||||
from ._type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from ._type_subscription import TypeSubscription
|
||||
from ._types import FunctionCall
|
||||
|
||||
__all__ = ["Image", "TypeRoutedAgent", "ClosureAgent", "message_handler", "FunctionCall", "TypeSubscription"]
|
||||
__all__ = [
|
||||
"Image",
|
||||
"TypeRoutedAgent",
|
||||
"ClosureAgent",
|
||||
"message_handler",
|
||||
"FunctionCall",
|
||||
"TypeSubscription",
|
||||
"DefaultSubscription",
|
||||
"DefaultTopicId",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
from agnext.core.exceptions import CantHandleException
|
||||
|
||||
from ..core import SubscriptionInstantiationContext
|
||||
from ._type_subscription import TypeSubscription
|
||||
|
||||
|
||||
class DefaultSubscription(TypeSubscription):
|
||||
def __init__(self, topic_type: str = "default", agent_type: str | None = None):
|
||||
"""The default subscription is designed to be a sensible default for applications that only need global scope for agents.
|
||||
|
||||
This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
await runtime.register("MyAgent", agent_factory, lambda: [DefaultSubscription()])
|
||||
|
||||
Args:
|
||||
topic_type (str, optional): The topic type to subscribe to. Defaults to "default".
|
||||
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context.
|
||||
"""
|
||||
|
||||
if agent_type is None:
|
||||
try:
|
||||
agent_type = SubscriptionInstantiationContext.agent_type().type
|
||||
except RuntimeError as e:
|
||||
raise CantHandleException(
|
||||
"If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register"
|
||||
) from e
|
||||
|
||||
super().__init__(topic_type, agent_type)
|
|
@ -0,0 +1,21 @@
|
|||
from ..core import MessageHandlerContext, TopicId
|
||||
|
||||
|
||||
class DefaultTopicId(TopicId):
|
||||
def __init__(self, type: str = "default", source: str | None = None) -> None:
|
||||
"""DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId.
|
||||
|
||||
If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default".
|
||||
|
||||
Args:
|
||||
type (str, optional): Topic type to publish message to. Defaults to "default".
|
||||
source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None.
|
||||
"""
|
||||
if source is None:
|
||||
try:
|
||||
source = MessageHandlerContext.agent_id().key
|
||||
# If we aren't in the context of a message handler, we use the default source
|
||||
except RuntimeError:
|
||||
source = "default"
|
||||
|
||||
super().__init__(type, source)
|
|
@ -13,6 +13,7 @@ from ._agent_type import AgentType
|
|||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
|
@ -37,4 +38,5 @@ __all__ = [
|
|||
"Serialization",
|
||||
"AgentType",
|
||||
"SubscriptionInstantiationContext",
|
||||
"MessageHandlerContext",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, ClassVar, Generator
|
||||
|
||||
from ._agent_id import AgentId
|
||||
|
||||
|
||||
class MessageHandlerContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
|
||||
)
|
||||
|
||||
MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("MESSAGE_HANDLER_CONTEXT")
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]:
|
||||
token = MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.reset(token)
|
||||
|
||||
@classmethod
|
||||
def agent_id(cls) -> AgentId:
|
||||
try:
|
||||
return cls.MESSAGE_HANDLER_CONTEXT.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e
|
|
@ -1,11 +1,11 @@
|
|||
import asyncio
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components import TypeSubscription, DefaultTopicId, DefaultSubscription
|
||||
from agnext.core import AgentId, AgentInstantiationContext
|
||||
from agnext.core import TopicId
|
||||
from agnext.core._subscription import Subscription
|
||||
from agnext.core._subscription_context import SubscriptionInstantiationContext
|
||||
from agnext.core import Subscription
|
||||
from agnext.core import SubscriptionInstantiationContext
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
|
||||
|
||||
|
@ -163,3 +163,62 @@ async def test_register_factory_direct_list() -> None:
|
|||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_default_default_subscription() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_publish_to_other_source() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
runtime.start()
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
|
||||
assert other_long_running_agent.num_calls == 1
|
||||
|
|
Loading…
Reference in New Issue