From 1ac52729f47388e6ed35327c445be4cde9150604 Mon Sep 17 00:00:00 2001 From: Mohammad Mazraeh Date: Mon, 23 Sep 2024 15:57:48 +0000 Subject: [PATCH] add subscription deduplication (#594) * add subscription deduplication * format --------- Co-authored-by: Mohammad Mazraeh Co-authored-by: Ryan Sweet --- .../samples/patterns/group_chat.py | 2 +- .../src/autogen_core/application/_helpers.py | 2 +- .../components/_type_subscription.py | 6 ++++++ .../autogen-core/tests/test_closure_agent.py | 2 +- .../autogen-core/tests/test_subscription.py | 20 +++++++++++++++++++ 5 files changed, 29 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-core/samples/patterns/group_chat.py b/python/packages/autogen-core/samples/patterns/group_chat.py index 728728566..c5fc6dddd 100644 --- a/python/packages/autogen-core/samples/patterns/group_chat.py +++ b/python/packages/autogen-core/samples/patterns/group_chat.py @@ -20,7 +20,7 @@ from typing import List from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, AgentInstantiationContext from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler -from autogen_core.components._default_subscription import DefaultSubscription +from autogen_core.components import DefaultSubscription from autogen_core.components.models import ( AssistantMessage, ChatCompletionClient, diff --git a/python/packages/autogen-core/src/autogen_core/application/_helpers.py b/python/packages/autogen-core/src/autogen_core/application/_helpers.py index abb38b98d..18286c617 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_helpers.py +++ b/python/packages/autogen-core/src/autogen_core/application/_helpers.py @@ -37,7 +37,7 @@ class SubscriptionManager: async def add_subscription(self, subscription: Subscription) -> None: # Check if the subscription already exists - if any(sub.id == subscription.id for sub in self._subscriptions): + if any(sub == subscription for sub in self._subscriptions): raise ValueError("Subscription already exists") self._subscriptions.append(subscription) diff --git a/python/packages/autogen-core/src/autogen_core/components/_type_subscription.py b/python/packages/autogen-core/src/autogen_core/components/_type_subscription.py index 18e478de5..c745d2d0e 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_type_subscription.py +++ b/python/packages/autogen-core/src/autogen_core/components/_type_subscription.py @@ -52,5 +52,11 @@ class TypeSubscription(Subscription): return AgentId(type=self._agent_type, key=topic_id.source) + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypeSubscription): + return False + + return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type) + BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent") diff --git a/python/packages/autogen-core/tests/test_closure_agent.py b/python/packages/autogen-core/tests/test_closure_agent.py index 0b3c07ae7..4b4482e22 100644 --- a/python/packages/autogen-core/tests/test_closure_agent.py +++ b/python/packages/autogen-core/tests/test_closure_agent.py @@ -5,7 +5,7 @@ import pytest from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, AgentRuntime, MessageContext from autogen_core.components import ClosureAgent -from autogen_core.components._default_subscription import DefaultSubscription +from autogen_core.components import DefaultSubscription from autogen_core.components._default_topic import DefaultTopicId diff --git a/python/packages/autogen-core/tests/test_subscription.py b/python/packages/autogen-core/tests/test_subscription.py index cb531738c..7a5403f99 100644 --- a/python/packages/autogen-core/tests/test_subscription.py +++ b/python/packages/autogen-core/tests/test_subscription.py @@ -3,6 +3,7 @@ from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, TopicId from autogen_core.base.exceptions import CantHandleException from autogen_core.components import DefaultTopicId, TypeSubscription +from autogen_core.components import DefaultSubscription from test_utils import LoopbackAgent, MessageType @@ -96,3 +97,22 @@ async def test_skipped_class_subscriptions() -> None: AgentId("MyAgent", key="default"), type=LoopbackAgent ) assert agent_instance.num_calls == 0 + + +@pytest.mark.asyncio +async def test_subscription_deduplication() -> None: + runtime = SingleThreadedAgentRuntime() + agent_type = "MyAgent" + + # Test TypeSubscription + type_subscription_1 = TypeSubscription("default", agent_type) + type_subscription_2 = TypeSubscription("default", agent_type) + + await runtime.add_subscription(type_subscription_1) + with pytest.raises(ValueError, match="Subscription already exists"): + await runtime.add_subscription(type_subscription_2) + + # Test DefaultSubscription + default_subscription = DefaultSubscription(agent_type=agent_type) + with pytest.raises(ValueError, match="Subscription already exists"): + await runtime.add_subscription(default_subscription)