mirror of https://github.com/microsoft/autogen.git
fix: remove subscription on client disconnect in worker runtime (#3653)
* remove subscription on client disconnect in worker runtime * address PR feedback * remove outdated comment * remove public properties * fix mypy issue * address PR feedback --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
c63a034523
commit
be5c0b5d3e
|
@ -27,6 +27,7 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
|
||||
|
||||
async def OpenChannel( # type: ignore
|
||||
self,
|
||||
|
@ -68,13 +69,18 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
for future in self._pending_responses.pop(client_id, {}).values():
|
||||
future.cancel()
|
||||
# Remove the client id from the agent type to client id mapping.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
agent_types = [
|
||||
agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id
|
||||
]
|
||||
for agent_type in agent_types:
|
||||
del self._agent_type_to_client_id[agent_type]
|
||||
logger.info(f"Client {client_id} disconnected.")
|
||||
await self._on_client_disconnect(client_id)
|
||||
|
||||
async def _on_client_disconnect(self, client_id: int) -> None:
    """Drop the host-side registrations recorded for a disconnected client.

    Removes every agent type currently mapped to ``client_id`` and every
    subscription id recorded for that client, then forgets the client's
    subscription bookkeeping entry.

    Args:
        client_id: Identifier of the client whose channel has closed.
    """
    # NOTE(review): the source view lost its indentation; the subscription
    # loop is assumed to run under the same lock as the agent-type cleanup.
    async with self._agent_type_to_client_id_lock:
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
        for agent_type in agent_types:
            logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
            del self._agent_type_to_client_id[agent_type]
        # pop() rather than get(): once the client is gone its entry must not
        # linger in the mapping.
        for sub_id in self._client_id_to_subscription_id_mapping.pop(client_id, set()):
            # Log the subscription id being removed (previously this
            # interpolated the ``id`` builtin by mistake).
            logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {sub_id}")
            await self._subscription_manager.remove_subscription(sub_id)
    logger.info(f"Client {client_id} disconnected successfully")
|
||||
|
||||
def _raise_on_exception(self, task: Task[Any]) -> None:
|
||||
exception = task.exception()
|
||||
|
@ -220,6 +226,8 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
)
|
||||
try:
|
||||
await self._subscription_manager.add_subscription(type_subscription)
|
||||
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
|
||||
subscription_ids.add(type_subscription.id)
|
||||
success = True
|
||||
error = None
|
||||
except ValueError as e:
|
||||
|
|
|
@ -46,3 +46,24 @@ class NoopAgent(BaseAgent):
|
|||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
    """Reject every incoming message.

    Raises:
        NotImplementedError: Always; this agent defines no message handling.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class MyMessage:
    """Message type carrying a single text payload."""

    # Text body of the message.
    content: str
|
||||
|
||||
|
||||
@default_subscription
class MyAgent(RoutedAgent):
    """Agent that re-publishes a greeting for each of the first five messages it handles."""

    def __init__(self, name: str) -> None:
        """Create the agent.

        Args:
            name: Label prefixed to every greeting this agent publishes.
        """
        super().__init__("My agent")
        self._counter = 0
        self._name = name

    @message_handler
    async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:
        """Count the incoming message and publish a greeting while the count is at most five."""
        self._counter += 1
        if self._counter <= 5:
            greeting = MyMessage(content=f"{self._name}: Hello x {self._counter}")
            await self.publish_message(greeting, DefaultTopicId())
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
@ -10,13 +11,14 @@ from autogen_core.base import (
|
|||
TopicId,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.base._subscription import Subscription
|
||||
from autogen_core.components import (
|
||||
DefaultTopicId,
|
||||
TypeSubscription,
|
||||
default_subscription,
|
||||
type_subscription,
|
||||
)
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, MyAgent, MyMessage, NoopAgent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -300,3 +302,102 @@ async def test_type_subscription() -> None:
|
|||
await worker.stop()
|
||||
await publisher.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_duplicate_subscription() -> None:
    """Check how the host treats a second worker registering an already used agent type.

    While ``worker1`` is running, registering the same type from ``worker1_2``
    must raise ``RuntimeError``; after ``worker1`` is stopped the same
    registration is expected to raise ``ValueError``.
    """
    host_address = "localhost:50059"
    host = WorkerAgentRuntimeHost(address=host_address)
    worker1 = WorkerAgentRuntime(host_address=host_address)
    worker1_2 = WorkerAgentRuntime(host_address=host_address)
    host.start()
    try:
        worker1.start()
        await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))

        worker1_2.start()

        # Note: This passes because worker1 is still running
        with pytest.raises(RuntimeError, match="Agent type worker1 already registered"):
            await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))

        # This is somehow covered in test_disconnected_agent as well as a stop will also disconnect the agent.
        # Will keep them both for now as we might replace the way we simulate a disconnect
        await worker1.stop()

        with pytest.raises(ValueError):
            await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))

    # No ``except`` clause: catching and immediately re-raising the same
    # exception added nothing, so failures propagate straight to pytest.
    finally:
        await worker1_2.stop()
        await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnected_agent() -> None:
    """Check that the host drops a worker's subscription once its connection closes.

    The test registers an agent from ``worker1``, closes that worker's host
    connection, asserts the host no longer holds any subscription or
    recipient, then registers the same agent type from ``worker1_2`` and
    asserts a single, new subscription exists without duplicate recipients.
    """
    host_address = "localhost:50059"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()
    worker1 = WorkerAgentRuntime(host_address=host_address)
    worker1_2 = WorkerAgentRuntime(host_address=host_address)

    # TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
    # to some private properties. This needs to be updated once they are available publicly

    def get_current_subscriptions() -> List[Subscription]:
        return host._servicer._subscription_manager._subscriptions  # type: ignore[reportPrivateUsage]

    async def get_subscribed_recipients() -> List[AgentId]:
        return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId())  # type: ignore[reportPrivateUsage]

    try:
        worker1.start()
        await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))

        subscriptions1 = get_current_subscriptions()
        assert len(subscriptions1) == 1
        recipients1 = await get_subscribed_recipients()
        assert AgentId(type="worker1", key="default") in recipients1

        first_subscription_id = subscriptions1[0].id

        await worker1.publish_message(MyMessage(content="Hello!"), DefaultTopicId())
        # This is a simple simulation of worker disconnect
        if worker1._host_connection is not None:  # type: ignore[reportPrivateUsage]
            try:
                await worker1._host_connection.close()  # type: ignore[reportPrivateUsage]
            except asyncio.CancelledError:
                pass

        # Give the host time to observe the closed channel and run its cleanup.
        await asyncio.sleep(1)

        subscriptions2 = get_current_subscriptions()
        assert len(subscriptions2) == 0
        recipients2 = await get_subscribed_recipients()
        assert len(recipients2) == 0
        await asyncio.sleep(1)

        worker1_2.start()
        await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1"))

        subscriptions3 = get_current_subscriptions()
        assert len(subscriptions3) == 1
        assert first_subscription_id not in [x.id for x in subscriptions3]

        recipients3 = await get_subscribed_recipients()
        # Check the list fetched after re-registration; the earlier check used
        # ``recipients2``, which is already asserted empty and so never failed.
        assert len(set(recipients3)) == len(recipients3)  # Make sure there are no duplicates
        assert AgentId(type="worker1", key="default") in recipients3
    # No ``except`` clause: catching and immediately re-raising the same
    # exception added nothing, so failures propagate straight to pytest.
    finally:
        await worker1.stop()
        await worker1_2.stop()
        await host.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Set the gRPC logging environment variables, then run the
    # disconnect test directly without pytest.
    os.environ.update({"GRPC_VERBOSITY": "DEBUG", "GRPC_TRACE": "all"})
    asyncio.run(test_disconnected_agent())
|
||||
|
|
Loading…
Reference in New Issue