Add response for registering agent type and adding subscriptions in distributed runtime, and throws an exception when error (#582)

* Add response for registering agent type and adding subscriptions in distributed runtime

* Update unit tests

* lint
This commit is contained in:
Eric Zhu 2024-09-19 10:50:17 -07:00 committed by GitHub
parent d624367894
commit 8018677234
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 398 additions and 132 deletions

View File

@ -45,8 +45,15 @@ message Event {
map<string, string> metadata = 5;
}
message RegisterAgentType {
string type = 1;
message RegisterAgentTypeRequest {
string request_id = 1;
string type = 2;
}
message RegisterAgentTypeResponse {
string request_id = 1;
bool success = 2;
optional string error = 3;
}
message TypeSubscription {
@ -60,8 +67,15 @@ message Subscription {
}
}
message AddSubscription {
Subscription subscription = 1;
message AddSubscriptionRequest {
string request_id = 1;
Subscription subscription = 2;
}
message AddSubscriptionResponse {
string request_id = 1;
bool success = 2;
optional string error = 3;
}
service AgentRpc {
@ -73,9 +87,11 @@ message Message {
RpcRequest request = 1;
RpcResponse response = 2;
Event event = 3;
RegisterAgentType registerAgentType = 4;
AddSubscription addSubscription = 5;
cloudevent.CloudEvent cloudEvent = 6;
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
AddSubscriptionRequest addSubscriptionRequest = 6;
AddSubscriptionResponse addSubscriptionResponse = 7;
cloudevent.CloudEvent cloudEvent = 8;
}
}

View File

@ -0,0 +1,43 @@
from dataclasses import dataclass
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
@dataclass
class CascadingMessage:
round: int
@dataclass
class ReceiveMessageEvent:
round: int
sender: str
recipient: str
@default_subscription
class CascadingAgent(RoutedAgent):
def __init__(self, max_rounds: int) -> None:
super().__init__("A cascading agent.")
self.max_rounds = max_rounds
@message_handler
async def on_new_message(self, message: CascadingMessage, ctx: MessageContext) -> None:
await self.publish_message(
ReceiveMessageEvent(round=message.round, sender=str(ctx.sender), recipient=str(self.id)),
topic_id=DefaultTopicId(),
)
if message.round == self.max_rounds:
return
await self.publish_message(CascadingMessage(round=message.round + 1), topic_id=DefaultTopicId())
@default_subscription
class ObserverAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("An observer agent.")
@message_handler
async def on_receive_message(self, message: ReceiveMessageEvent, ctx: MessageContext) -> None:
print(f"[Round {message.round}]: Message from {message.sender} to {message.recipient}.")

View File

@ -0,0 +1,22 @@
from agents import CascadingMessage, ObserverAgent
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import try_get_known_serializers_for_type
from autogen_core.components import DefaultTopicId
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
runtime.start()
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())
await runtime.publish_message(CascadingMessage(round=1), topic_id=DefaultTopicId())
await runtime.stop_when_signal()
if __name__ == "__main__":
# import logging
# logging.basicConfig(level=logging.DEBUG)
# logger = logging.getLogger("autogen_core")
import asyncio
asyncio.run(main())

View File

@ -0,0 +1,24 @@
import uuid
from agents import CascadingAgent, ReceiveMessageEvent
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import try_get_known_serializers_for_type
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
runtime.start()
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")
await CascadingAgent.register(runtime, agent_type, lambda: CascadingAgent(max_rounds=3))
await runtime.stop_when_signal()
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("autogen_core")
import asyncio
asyncio.run(main())

View File

@ -29,7 +29,7 @@ from typing import (
import grpc
from grpc.aio import StreamStreamCall
from opentelemetry.trace import NoOpTracerProvider, TracerProvider
from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated
from autogen_core.base import JSON_DATA_CONTENT_TYPE
@ -194,23 +194,34 @@ class WorkerAgentRuntime(AgentRuntime):
message = await self._host_connection.recv() # type: ignore
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentType" | "addSubscription":
case "registerAgentTypeRequest" | "addSubscriptionRequest":
logger.warning(f"Cant handle {oneofcase}, skipping.")
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request))
task = asyncio.create_task(self._process_request(message.request))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response))
task = asyncio.create_task(self._process_response(message.response))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "event":
event: agent_worker_pb2.Event = message.event
task = asyncio.create_task(self._process_event(event))
task = asyncio.create_task(self._process_event(message.event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeResponse":
task = asyncio.create_task(
self._process_register_agent_type_response(message.registerAgentTypeResponse)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "addSubscriptionResponse":
task = asyncio.create_task(
self._process_add_subscription_response(message.addSubscriptionResponse)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
@ -295,18 +306,15 @@ class WorkerAgentRuntime(AgentRuntime):
):
# create a new future for the result
future = asyncio.get_event_loop().create_future()
async with self._pending_requests_lock:
self._next_request_id += 1
request_id = self._next_request_id
request_id_str = str(request_id)
self._pending_requests[request_id_str] = future
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future
serialized_message = self._serialization_registry.serialize(
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
telemetry_metadata = get_telemetry_grpc_metadata()
runtime_message = agent_worker_pb2.Message(
request=agent_worker_pb2.RpcRequest(
request_id=request_id_str,
request_id=request_id,
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
@ -379,6 +387,11 @@ class WorkerAgentRuntime(AgentRuntime):
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Agent load_state is not yet implemented.")
async def _get_new_request_id(self) -> str:
async with self._pending_requests_lock:
self._next_request_id += 1
return str(self._next_request_id)
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
assert self._host_connection is not None
recipient = AgentId(request.target.type, request.target.key)
@ -529,9 +542,21 @@ class WorkerAgentRuntime(AgentRuntime):
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type))
# Create a future for the registration response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future
# Send the registration request message to the host.
message = agent_worker_pb2.Message(
registerAgentTypeRequest=agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type)
)
await self._host_connection.send(message)
# Wait for the registration response.
await future
if subscriptions is not None:
if callable(subscriptions):
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
@ -574,11 +599,29 @@ class WorkerAgentRuntime(AgentRuntime):
self._agent_factories[type.type] = factory_wrapper
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type.type))
# Create a future for the registration response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future
# Send the registration request message to the host.
message = agent_worker_pb2.Message(
registerAgentTypeRequest=agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
)
await self._host_connection.send(message)
# Wait for the registration response.
await future
return type
async def _process_register_agent_type_response(self, response: agent_worker_pb2.RegisterAgentTypeResponse) -> None:
future = self._pending_requests.pop(response.request_id)
if response.HasField("error"):
future.set_exception(RuntimeError(response.error))
else:
future.set_result(None)
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
@ -635,18 +678,35 @@ class WorkerAgentRuntime(AgentRuntime):
raise ValueError("Only TypeSubscription is supported.")
# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)
# Create a future for the subscription response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future
# Send the subscription to the host.
message = agent_worker_pb2.Message(
addSubscription=agent_worker_pb2.AddSubscription(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=subscription.topic_type, agent_type=subscription.agent_type
)
)
),
)
)
await self._host_connection.send(message)
# Wait for the subscription response.
await future
async def _process_add_subscription_response(self, response: agent_worker_pb2.AddSubscriptionResponse) -> None:
future = self._pending_requests.pop(response.request_id)
if response.HasField("error"):
future.set_exception(RuntimeError(response.error))
else:
future.set_result(None)
async def remove_subscription(self, id: str) -> None:
raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")

View File

@ -107,18 +107,22 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentType":
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
case "registerAgentTypeRequest":
register_agent_type: agent_worker_pb2.RegisterAgentTypeRequest = message.registerAgentTypeRequest
task = asyncio.create_task(
self._process_register_agent_type_request(register_agent_type, client_id)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "addSubscription":
add_subscription: agent_worker_pb2.AddSubscription = message.addSubscription
task = asyncio.create_task(self._process_add_subscription(add_subscription))
case "addSubscriptionRequest":
add_subscription: agent_worker_pb2.AddSubscriptionRequest = message.addSubscriptionRequest
task = asyncio.create_task(self._process_add_subscription_request(add_subscription, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeResponse" | "addSubscriptionResponse":
logger.warning(f"Received unexpected message type: {oneofcase}")
case None:
logger.warning("Received empty message")
@ -175,32 +179,57 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event))
async def _process_register_agent_type(
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
async def _process_register_agent_type_request(
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
) -> None:
# Register the agent type with the host runtime.
async with self._agent_type_to_client_id_lock:
if register_agent_type.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
if register_agent_type_req.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type_req.type]
logger.error(
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}."
f"Agent type {register_agent_type_req.type} already registered with client {existing_client_id}."
)
# TODO: send an error response back to the client.
success = False
error = f"Agent type {register_agent_type_req.type} already registered."
else:
self._agent_type_to_client_id[register_agent_type.type] = client_id
# TODO: send a success response back to the client.
self._agent_type_to_client_id[register_agent_type_req.type] = client_id
success = True
error = None
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
registerAgentTypeResponse=agent_worker_pb2.RegisterAgentTypeResponse(
request_id=register_agent_type_req.request_id, success=success, error=error
)
)
)
async def _process_add_subscription(self, add_subscription: agent_worker_pb2.AddSubscription) -> None:
oneofcase = add_subscription.subscription.WhichOneof("subscription")
async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
) -> None:
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
add_subscription.subscription.typeSubscription
add_subscription_req.subscription.typeSubscription
)
type_subscription = TypeSubscription(
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
)
await self._subscription_manager.add_subscription(type_subscription)
# TODO: send a success response back to the client.
try:
await self._subscription_manager.add_subscription(type_subscription)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)
)
case None:
logger.warning("Received empty subscription message")

View File

@ -6,24 +6,3 @@ import os
import sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from typing import TYPE_CHECKING
from .agent_worker_pb2 import AgentId, Event, Message, RegisterAgentType, RpcRequest, RpcResponse
from .agent_worker_pb2_grpc import AgentRpcServicer, AgentRpcStub, add_AgentRpcServicer_to_server
if TYPE_CHECKING:
from .agent_worker_pb2_grpc import AgentRpcAsyncStub
__all__ = [
"RpcRequest",
"RpcResponse",
"Event",
"RegisterAgentType",
"AgentRpcAsyncStub",
"AgentRpcStub",
"Message",
"AgentId"
]
else:
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"]

View File

@ -16,7 +16,7 @@ import cloudevent_pb2 as cloudevent__pb2
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x9e\x02\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x12,\n\ncloudEvent\x18\x06 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xc6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x12,\n\ncloudEvent\x18\x08 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -47,16 +47,20 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_globals['_EVENT']._serialized_end=909
_globals['_EVENT_METADATAENTRY']._serialized_start=433
_globals['_EVENT_METADATAENTRY']._serialized_end=480
_globals['_REGISTERAGENTTYPE']._serialized_start=911
_globals['_REGISTERAGENTTYPE']._serialized_end=944
_globals['_TYPESUBSCRIPTION']._serialized_start=946
_globals['_TYPESUBSCRIPTION']._serialized_end=1004
_globals['_SUBSCRIPTION']._serialized_start=1006
_globals['_SUBSCRIPTION']._serialized_end=1090
_globals['_ADDSUBSCRIPTION']._serialized_start=1092
_globals['_ADDSUBSCRIPTION']._serialized_end=1153
_globals['_MESSAGE']._serialized_start=1156
_globals['_MESSAGE']._serialized_end=1442
_globals['_AGENTRPC']._serialized_start=1444
_globals['_AGENTRPC']._serialized_end=1507
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067
_globals['_TYPESUBSCRIPTION']._serialized_start=1069
_globals['_TYPESUBSCRIPTION']._serialized_end=1127
_globals['_SUBSCRIPTION']._serialized_start=1129
_globals['_SUBSCRIPTION']._serialized_end=1213
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1215
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1303
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1305
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1397
_globals['_MESSAGE']._serialized_start=1400
_globals['_MESSAGE']._serialized_end=1854
_globals['_AGENTRPC']._serialized_start=1856
_globals['_AGENTRPC']._serialized_end=1919
# @@protoc_insertion_point(module_scope)

View File

@ -214,19 +214,45 @@ class Event(google.protobuf.message.Message):
global___Event = Event
@typing.final
class RegisterAgentType(google.protobuf.message.Message):
class RegisterAgentTypeRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
REQUEST_ID_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
request_id: builtins.str
type: builtins.str
def __init__(
self,
*,
request_id: builtins.str = ...,
type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["type", b"type"]) -> None: ...
def ClearField(self, field_name: typing.Literal["request_id", b"request_id", "type", b"type"]) -> None: ...
global___RegisterAgentType = RegisterAgentType
global___RegisterAgentTypeRequest = RegisterAgentTypeRequest
@typing.final
class RegisterAgentTypeResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
REQUEST_ID_FIELD_NUMBER: builtins.int
SUCCESS_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
request_id: builtins.str
success: builtins.bool
error: builtins.str
def __init__(
self,
*,
request_id: builtins.str = ...,
success: builtins.bool = ...,
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___RegisterAgentTypeResponse = RegisterAgentTypeResponse
@typing.final
class TypeSubscription(google.protobuf.message.Message):
@ -265,21 +291,47 @@ class Subscription(google.protobuf.message.Message):
global___Subscription = Subscription
@typing.final
class AddSubscription(google.protobuf.message.Message):
class AddSubscriptionRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
REQUEST_ID_FIELD_NUMBER: builtins.int
SUBSCRIPTION_FIELD_NUMBER: builtins.int
request_id: builtins.str
@property
def subscription(self) -> global___Subscription: ...
def __init__(
self,
*,
request_id: builtins.str = ...,
subscription: global___Subscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscription", b"subscription"]) -> None: ...
def ClearField(self, field_name: typing.Literal["request_id", b"request_id", "subscription", b"subscription"]) -> None: ...
global___AddSubscription = AddSubscription
global___AddSubscriptionRequest = AddSubscriptionRequest
@typing.final
class AddSubscriptionResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
REQUEST_ID_FIELD_NUMBER: builtins.int
SUCCESS_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
request_id: builtins.str
success: builtins.bool
error: builtins.str
def __init__(
self,
*,
request_id: builtins.str = ...,
success: builtins.bool = ...,
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___AddSubscriptionResponse = AddSubscriptionResponse
@typing.final
class Message(google.protobuf.message.Message):
@ -288,8 +340,10 @@ class Message(google.protobuf.message.Message):
REQUEST_FIELD_NUMBER: builtins.int
RESPONSE_FIELD_NUMBER: builtins.int
EVENT_FIELD_NUMBER: builtins.int
REGISTERAGENTTYPE_FIELD_NUMBER: builtins.int
ADDSUBSCRIPTION_FIELD_NUMBER: builtins.int
REGISTERAGENTTYPEREQUEST_FIELD_NUMBER: builtins.int
REGISTERAGENTTYPERESPONSE_FIELD_NUMBER: builtins.int
ADDSUBSCRIPTIONREQUEST_FIELD_NUMBER: builtins.int
ADDSUBSCRIPTIONRESPONSE_FIELD_NUMBER: builtins.int
CLOUDEVENT_FIELD_NUMBER: builtins.int
@property
def request(self) -> global___RpcRequest: ...
@ -298,9 +352,13 @@ class Message(google.protobuf.message.Message):
@property
def event(self) -> global___Event: ...
@property
def registerAgentType(self) -> global___RegisterAgentType: ...
def registerAgentTypeRequest(self) -> global___RegisterAgentTypeRequest: ...
@property
def addSubscription(self) -> global___AddSubscription: ...
def registerAgentTypeResponse(self) -> global___RegisterAgentTypeResponse: ...
@property
def addSubscriptionRequest(self) -> global___AddSubscriptionRequest: ...
@property
def addSubscriptionResponse(self) -> global___AddSubscriptionResponse: ...
@property
def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ...
def __init__(
@ -309,12 +367,14 @@ class Message(google.protobuf.message.Message):
request: global___RpcRequest | None = ...,
response: global___RpcResponse | None = ...,
event: global___Event | None = ...,
registerAgentType: global___RegisterAgentType | None = ...,
addSubscription: global___AddSubscription | None = ...,
registerAgentTypeRequest: global___RegisterAgentTypeRequest | None = ...,
registerAgentTypeResponse: global___RegisterAgentTypeResponse | None = ...,
addSubscriptionRequest: global___AddSubscriptionRequest | None = ...,
addSubscriptionResponse: global___AddSubscriptionResponse | None = ...,
cloudEvent: cloudevent_pb2.CloudEvent | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentType", "addSubscription", "cloudEvent"] | None: ...
def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse", "cloudEvent"] | None: ...
global___Message = Message

View File

@ -22,4 +22,5 @@ __all__ = [
"TypeSubscription",
"DefaultSubscription",
"DefaultTopicId",
"default_subscription",
]

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Any
from autogen_core.base import BaseAgent, MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
@dataclass
@ -25,6 +25,7 @@ class LoopbackAgent(RoutedAgent):
return message
@default_subscription
class CascadingAgent(RoutedAgent):
def __init__(self, max_rounds: int) -> None:
super().__init__("A cascading agent.")

View File

@ -5,7 +5,7 @@ import pytest
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
from autogen_core.base import (
AgentId,
AgentInstantiationContext,
AgentType,
TopicId,
try_get_known_serializers_for_type,
)
@ -14,7 +14,7 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess
@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
async def test_agent_types_must_be_unique_single_worker() -> None:
host_address = "localhost:50051"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
@ -22,67 +22,95 @@ async def test_agent_names_must_be_unique() -> None:
worker = WorkerAgentRuntime(host_address=host_address)
worker.start()
def agent_factory() -> NoopAgent:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId("name1", "default")
agent = NoopAgent()
assert agent.id == id
return agent
await worker.register("name1", agent_factory)
await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
with pytest.raises(ValueError):
await worker.register("name1", NoopAgent)
await worker.register_factory(
type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent
)
await worker.register("name3", NoopAgent)
# Let the agent run for a bit.
await asyncio.sleep(2)
await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
await worker.stop()
await host.stop()
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_multiple_workers() -> None:
host_address = "localhost:50059"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1.start()
worker2 = WorkerAgentRuntime(host_address=host_address)
worker2.start()
await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
with pytest.raises(RuntimeError):
await worker2.register_factory(
type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent
)
await worker2.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
await worker1.stop()
await worker2.stop()
await host.stop()
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
host_address = "localhost:50052"
host_address = "localhost:50060"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
worker = WorkerAgentRuntime(host_address=host_address)
worker.add_message_serializer(try_get_known_serializers_for_type(MessageType))
worker.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
)
await worker1.add_subscription(TypeSubscription("default", "name1"))
await worker.register("name", LoopbackAgent)
await worker.add_subscription(TypeSubscription("default", "name"))
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await worker.publish_message(MessageType(), topic_id=topic_id)
worker2 = WorkerAgentRuntime(host_address=host_address)
worker2.start()
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker2.register_factory(
type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
)
await worker2.add_subscription(TypeSubscription("default", "name2"))
# Publish message from worker1
await worker1.publish_message(MessageType(), topic_id=TopicId("default", "default"))
# Let the agent run for a bit.
await asyncio.sleep(2)
# Agent in default namespace should have received the message
long_running_agent = await worker.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
# Agents in default topic source should have received the message.
worker1_agent = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent)
assert worker1_agent.num_calls == 1
worker2_agent = await worker2.try_get_underlying_agent_instance(AgentId("name2", "default"), LoopbackAgent)
assert worker2_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await worker.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0
# Agents in other topic source should not have received the message.
worker1_agent = await worker1.try_get_underlying_agent_instance(AgentId("name1", "other"), LoopbackAgent)
assert worker1_agent.num_calls == 0
worker2_agent = await worker2.try_get_underlying_agent_instance(AgentId("name2", "other"), LoopbackAgent)
assert worker2_agent.num_calls == 0
await worker.stop()
await worker1.stop()
await worker2.stop()
await host.stop()
@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
async def test_register_receives_publish_cascade_single_worker() -> None:
host_address = "localhost:50053"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
runtime.start()
num_agents = 5
@ -94,7 +122,7 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
# Publish messages
for _ in range(num_initial_messages):
@ -133,9 +161,8 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
# Register agents
for i in range(num_agents):
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
runtime.start()
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
workers.append(runtime)
# Publish messages