fix: remove subscription on client disconnect in worker runtime (#3653)

* remove subscription on client disconnect in worker runtime

* address PR feedback

* remove outdated comment

* remove public properties

* fix mypy issue

* address PR feedback

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Mohammad Mazraeh 2024-10-05 15:15:01 +00:00 committed by GitHub
parent c63a034523
commit be5c0b5d3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 138 additions and 8 deletions

View File

@ -27,6 +27,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
async def OpenChannel( # type: ignore
self,
@ -68,13 +69,18 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
async with self._agent_type_to_client_id_lock:
agent_types = [
agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id
]
for agent_type in agent_types:
del self._agent_type_to_client_id[agent_type]
logger.info(f"Client {client_id} disconnected.")
await self._on_client_disconnect(client_id)
async def _on_client_disconnect(self, client_id: int) -> None:
async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types:
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, []):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
logger.info(f"Client {client_id} disconnected successfully")
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
@ -220,6 +226,8 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
)
try:
await self._subscription_manager.add_subscription(type_subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(type_subscription.id)
success = True
error = None
except ValueError as e:

View File

@ -46,3 +46,24 @@ class NoopAgent(BaseAgent):
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
raise NotImplementedError
@dataclass
class MyMessage:
content: str
@default_subscription
class MyAgent(RoutedAgent):
def __init__(self, name: str) -> None:
super().__init__("My agent")
self._name = name
self._counter = 0
@message_handler
async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:
self._counter += 1
if self._counter > 5:
return
content = f"{self._name}: Hello x {self._counter}"
await self.publish_message(MyMessage(content=content), DefaultTopicId())

View File

@ -1,5 +1,6 @@
import asyncio
import logging
import os
from typing import List
import pytest
@ -10,13 +11,14 @@ from autogen_core.base import (
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.base._subscription import Subscription
from autogen_core.components import (
DefaultTopicId,
TypeSubscription,
default_subscription,
type_subscription,
)
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, MyAgent, MyMessage, NoopAgent
@pytest.mark.asyncio
@ -300,3 +302,102 @@ async def test_type_subscription() -> None:
await worker.stop()
await publisher.stop()
await host.stop()
@pytest.mark.asyncio
async def test_duplicate_subscription() -> None:
host_address = "localhost:50059"
host = WorkerAgentRuntimeHost(address=host_address)
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1_2 = WorkerAgentRuntime(host_address=host_address)
host.start()
try:
worker1.start()
await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))
worker1_2.start()
# Note: This passes because worker1 is still running
with pytest.raises(RuntimeError, match="Agent type worker1 already registered"):
await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))
# This is somehow covered in test_disconnected_agent as well as a stop will also disconnect the agent.
# Will keep them both for now as we might replace the way we simulate a disconnect
await worker1.stop()
with pytest.raises(ValueError):
await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))
except Exception as ex:
raise ex
finally:
await worker1_2.stop()
await host.stop()
@pytest.mark.asyncio
async def test_disconnected_agent() -> None:
host_address = "localhost:50059"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1_2 = WorkerAgentRuntime(host_address=host_address)
# TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
# to some private properties. This needs to be updated once they are available publicly
def get_current_subscriptions() -> List[Subscription]:
return host._servicer._subscription_manager._subscriptions # type: ignore[reportPrivateUsage]
async def get_subscribed_recipients() -> List[AgentId]:
return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId()) # type: ignore[reportPrivateUsage]
try:
worker1.start()
await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))
subscriptions1 = get_current_subscriptions()
assert len(subscriptions1) == 1
recipients1 = await get_subscribed_recipients()
assert AgentId(type="worker1", key="default") in recipients1
first_subscription_id = subscriptions1[0].id
await worker1.publish_message(MyMessage(content="Hello!"), DefaultTopicId())
# This is a simple simulation of worker disconnct
if worker1._host_connection is not None: # type: ignore[reportPrivateUsage]
try:
await worker1._host_connection.close() # type: ignore[reportPrivateUsage]
except asyncio.CancelledError:
pass
await asyncio.sleep(1)
subscriptions2 = get_current_subscriptions()
assert len(subscriptions2) == 0
recipients2 = await get_subscribed_recipients()
assert len(recipients2) == 0
await asyncio.sleep(1)
worker1_2.start()
await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1"))
subscriptions3 = get_current_subscriptions()
assert len(subscriptions3) == 1
assert first_subscription_id not in [x.id for x in subscriptions3]
recipients3 = await get_subscribed_recipients()
assert len(set(recipients2)) == len(recipients2) # Make sure there are no duplicates
assert AgentId(type="worker1", key="default") in recipients3
except Exception as ex:
raise ex
finally:
await worker1.stop()
await worker1_2.stop()
await host.stop()
if __name__ == "__main__":
os.environ["GRPC_VERBOSITY"] = "DEBUG"
os.environ["GRPC_TRACE"] = "all"
asyncio.run(test_disconnected_agent())