mirror of https://github.com/microsoft/autogen.git
add subscription deduplication (#594)
* add subscription deduplication * format --------- Co-authored-by: Mohammad Mazraeh <mmazraeh@microsoft.com> Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
parent
58ee8b7fc1
commit
1ac52729f4
|
@ -20,7 +20,7 @@ from typing import List
|
|||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, AgentInstantiationContext
|
||||
from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from autogen_core.components._default_subscription import DefaultSubscription
|
||||
from autogen_core.components import DefaultSubscription
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
|
|
|
@ -37,7 +37,7 @@ class SubscriptionManager:
|
|||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
# Check if the subscription already exists
|
||||
if any(sub.id == subscription.id for sub in self._subscriptions):
|
||||
if any(sub == subscription for sub in self._subscriptions):
|
||||
raise ValueError("Subscription already exists")
|
||||
|
||||
self._subscriptions.append(subscription)
|
||||
|
|
|
@ -52,5 +52,11 @@ class TypeSubscription(Subscription):
|
|||
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TypeSubscription):
|
||||
return False
|
||||
|
||||
return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type)
|
||||
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, AgentRuntime, MessageContext
|
||||
from autogen_core.components import ClosureAgent
|
||||
from autogen_core.components._default_subscription import DefaultSubscription
|
||||
from autogen_core.components import DefaultSubscription
|
||||
from autogen_core.components._default_topic import DefaultTopicId
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from autogen_core.application import SingleThreadedAgentRuntime
|
|||
from autogen_core.base import AgentId, TopicId
|
||||
from autogen_core.base.exceptions import CantHandleException
|
||||
from autogen_core.components import DefaultTopicId, TypeSubscription
|
||||
from autogen_core.components import DefaultSubscription
|
||||
from test_utils import LoopbackAgent, MessageType
|
||||
|
||||
|
||||
|
@ -96,3 +97,22 @@ async def test_skipped_class_subscriptions() -> None:
|
|||
AgentId("MyAgent", key="default"), type=LoopbackAgent
|
||||
)
|
||||
assert agent_instance.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_deduplication() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent_type = "MyAgent"
|
||||
|
||||
# Test TypeSubscription
|
||||
type_subscription_1 = TypeSubscription("default", agent_type)
|
||||
type_subscription_2 = TypeSubscription("default", agent_type)
|
||||
|
||||
await runtime.add_subscription(type_subscription_1)
|
||||
with pytest.raises(ValueError, match="Subscription already exists"):
|
||||
await runtime.add_subscription(type_subscription_2)
|
||||
|
||||
# Test DefaultSubscription
|
||||
default_subscription = DefaultSubscription(agent_type=agent_type)
|
||||
with pytest.raises(ValueError, match="Subscription already exists"):
|
||||
await runtime.add_subscription(default_subscription)
|
||||
|
|
Loading…
Reference in New Issue