add subscription deduplication (#594)

* add subscription deduplication

* format

---------

Co-authored-by: Mohammad Mazraeh <mmazraeh@microsoft.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
Mohammad Mazraeh 2024-09-23 15:57:48 +00:00 committed by GitHub
parent 58ee8b7fc1
commit 1ac52729f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 29 additions and 3 deletions

View File

@ -20,7 +20,7 @@ from typing import List
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext
from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler
from autogen_core.components._default_subscription import DefaultSubscription
from autogen_core.components import DefaultSubscription
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,

View File

@ -37,7 +37,7 @@ class SubscriptionManager:
async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub.id == subscription.id for sub in self._subscriptions):
if any(sub == subscription for sub in self._subscriptions):
raise ValueError("Subscription already exists")
self._subscriptions.append(subscription)

View File

@ -52,5 +52,11 @@ class TypeSubscription(Subscription):
return AgentId(type=self._agent_type, key=topic_id.source)
def __eq__(self, other: object) -> bool:
if not isinstance(other, TypeSubscription):
return False
return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type)
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")

View File

@ -5,7 +5,7 @@ import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentRuntime, MessageContext
from autogen_core.components import ClosureAgent
from autogen_core.components._default_subscription import DefaultSubscription
from autogen_core.components import DefaultSubscription
from autogen_core.components._default_topic import DefaultTopicId

View File

@ -3,6 +3,7 @@ from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, TopicId
from autogen_core.base.exceptions import CantHandleException
from autogen_core.components import DefaultTopicId, TypeSubscription
from autogen_core.components import DefaultSubscription
from test_utils import LoopbackAgent, MessageType
@ -96,3 +97,22 @@ async def test_skipped_class_subscriptions() -> None:
AgentId("MyAgent", key="default"), type=LoopbackAgent
)
assert agent_instance.num_calls == 0
@pytest.mark.asyncio
async def test_subscription_deduplication() -> None:
runtime = SingleThreadedAgentRuntime()
agent_type = "MyAgent"
# Test TypeSubscription
type_subscription_1 = TypeSubscription("default", agent_type)
type_subscription_2 = TypeSubscription("default", agent_type)
await runtime.add_subscription(type_subscription_1)
with pytest.raises(ValueError, match="Subscription already exists"):
await runtime.add_subscription(type_subscription_2)
# Test DefaultSubscription
default_subscription = DefaultSubscription(agent_type=agent_type)
with pytest.raises(ValueError, match="Subscription already exists"):
await runtime.add_subscription(default_subscription)