mirror of https://github.com/microsoft/autogen.git
Add response for registering agent type and adding subscriptions in distributed runtime, and throws an exception when error (#582)
* Add response for registering agent type and adding subscriptions in distributed runtime * Update unit tests * lint
This commit is contained in:
parent
d624367894
commit
8018677234
|
@ -45,8 +45,15 @@ message Event {
|
|||
map<string, string> metadata = 5;
|
||||
}
|
||||
|
||||
message RegisterAgentType {
|
||||
string type = 1;
|
||||
message RegisterAgentTypeRequest {
|
||||
string request_id = 1;
|
||||
string type = 2;
|
||||
}
|
||||
|
||||
message RegisterAgentTypeResponse {
|
||||
string request_id = 1;
|
||||
bool success = 2;
|
||||
optional string error = 3;
|
||||
}
|
||||
|
||||
message TypeSubscription {
|
||||
|
@ -60,8 +67,15 @@ message Subscription {
|
|||
}
|
||||
}
|
||||
|
||||
message AddSubscription {
|
||||
Subscription subscription = 1;
|
||||
message AddSubscriptionRequest {
|
||||
string request_id = 1;
|
||||
Subscription subscription = 2;
|
||||
}
|
||||
|
||||
message AddSubscriptionResponse {
|
||||
string request_id = 1;
|
||||
bool success = 2;
|
||||
optional string error = 3;
|
||||
}
|
||||
|
||||
service AgentRpc {
|
||||
|
@ -73,9 +87,11 @@ message Message {
|
|||
RpcRequest request = 1;
|
||||
RpcResponse response = 2;
|
||||
Event event = 3;
|
||||
RegisterAgentType registerAgentType = 4;
|
||||
AddSubscription addSubscription = 5;
|
||||
cloudevent.CloudEvent cloudEvent = 6;
|
||||
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
|
||||
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
|
||||
AddSubscriptionRequest addSubscriptionRequest = 6;
|
||||
AddSubscriptionResponse addSubscriptionResponse = 7;
|
||||
cloudevent.CloudEvent cloudEvent = 8;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from autogen_core.base import MessageContext
|
||||
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
class CascadingMessage:
|
||||
round: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReceiveMessageEvent:
|
||||
round: int
|
||||
sender: str
|
||||
recipient: str
|
||||
|
||||
|
||||
@default_subscription
|
||||
class CascadingAgent(RoutedAgent):
|
||||
def __init__(self, max_rounds: int) -> None:
|
||||
super().__init__("A cascading agent.")
|
||||
self.max_rounds = max_rounds
|
||||
|
||||
@message_handler
|
||||
async def on_new_message(self, message: CascadingMessage, ctx: MessageContext) -> None:
|
||||
await self.publish_message(
|
||||
ReceiveMessageEvent(round=message.round, sender=str(ctx.sender), recipient=str(self.id)),
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
if message.round == self.max_rounds:
|
||||
return
|
||||
await self.publish_message(CascadingMessage(round=message.round + 1), topic_id=DefaultTopicId())
|
||||
|
||||
|
||||
@default_subscription
|
||||
class ObserverAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("An observer agent.")
|
||||
|
||||
@message_handler
|
||||
async def on_receive_message(self, message: ReceiveMessageEvent, ctx: MessageContext) -> None:
|
||||
print(f"[Round {message.round}]: Message from {message.sender} to {message.recipient}.")
|
|
@ -0,0 +1,22 @@
|
|||
from agents import CascadingMessage, ObserverAgent
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import try_get_known_serializers_for_type
|
||||
from autogen_core.components import DefaultTopicId
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
|
||||
runtime.start()
|
||||
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())
|
||||
await runtime.publish_message(CascadingMessage(round=1), topic_id=DefaultTopicId())
|
||||
await runtime.stop_when_signal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import logging
|
||||
# logging.basicConfig(level=logging.DEBUG)
|
||||
# logger = logging.getLogger("autogen_core")
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
|
@ -0,0 +1,24 @@
|
|||
import uuid
|
||||
|
||||
from agents import CascadingAgent, ReceiveMessageEvent
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import try_get_known_serializers_for_type
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
|
||||
runtime.start()
|
||||
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")
|
||||
await CascadingAgent.register(runtime, agent_type, lambda: CascadingAgent(max_rounds=3))
|
||||
await runtime.stop_when_signal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger("autogen_core")
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
|
@ -29,7 +29,7 @@ from typing import (
|
|||
|
||||
import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from opentelemetry.trace import NoOpTracerProvider, TracerProvider
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from autogen_core.base import JSON_DATA_CONTENT_TYPE
|
||||
|
@ -194,23 +194,34 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
message = await self._host_connection.recv() # type: ignore
|
||||
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType" | "addSubscription":
|
||||
case "registerAgentTypeRequest" | "addSubscriptionRequest":
|
||||
logger.warning(f"Cant handle {oneofcase}, skipping.")
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request))
|
||||
task = asyncio.create_task(self._process_request(message.request))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
task = asyncio.create_task(self._process_response(response))
|
||||
task = asyncio.create_task(self._process_response(message.response))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "event":
|
||||
event: agent_worker_pb2.Event = message.event
|
||||
task = asyncio.create_task(self._process_event(event))
|
||||
task = asyncio.create_task(self._process_event(message.event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "registerAgentTypeResponse":
|
||||
task = asyncio.create_task(
|
||||
self._process_register_agent_type_response(message.registerAgentTypeResponse)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "addSubscriptionResponse":
|
||||
task = asyncio.create_task(
|
||||
self._process_add_subscription_response(message.addSubscriptionResponse)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
@ -295,18 +306,15 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
):
|
||||
# create a new future for the result
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
async with self._pending_requests_lock:
|
||||
self._next_request_id += 1
|
||||
request_id = self._next_request_id
|
||||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
request_id = await self._get_new_request_id()
|
||||
self._pending_requests[request_id] = future
|
||||
serialized_message = self._serialization_registry.serialize(
|
||||
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
telemetry_metadata = get_telemetry_grpc_metadata()
|
||||
runtime_message = agent_worker_pb2.Message(
|
||||
request=agent_worker_pb2.RpcRequest(
|
||||
request_id=request_id_str,
|
||||
request_id=request_id,
|
||||
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
|
||||
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
|
||||
metadata=telemetry_metadata,
|
||||
|
@ -379,6 +387,11 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
raise NotImplementedError("Agent load_state is not yet implemented.")
|
||||
|
||||
async def _get_new_request_id(self) -> str:
|
||||
async with self._pending_requests_lock:
|
||||
self._next_request_id += 1
|
||||
return str(self._next_request_id)
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
|
||||
assert self._host_connection is not None
|
||||
recipient = AgentId(request.target.type, request.target.key)
|
||||
|
@ -529,9 +542,21 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type))
|
||||
|
||||
# Create a future for the registration response.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
request_id = await self._get_new_request_id()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
# Send the registration request message to the host.
|
||||
message = agent_worker_pb2.Message(
|
||||
registerAgentTypeRequest=agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type)
|
||||
)
|
||||
await self._host_connection.send(message)
|
||||
|
||||
# Wait for the registration response.
|
||||
await future
|
||||
|
||||
if subscriptions is not None:
|
||||
if callable(subscriptions):
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
|
||||
|
@ -574,11 +599,29 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type.type))
|
||||
# Create a future for the registration response.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
request_id = await self._get_new_request_id()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
# Send the registration request message to the host.
|
||||
message = agent_worker_pb2.Message(
|
||||
registerAgentTypeRequest=agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
|
||||
)
|
||||
await self._host_connection.send(message)
|
||||
|
||||
# Wait for the registration response.
|
||||
await future
|
||||
|
||||
return type
|
||||
|
||||
async def _process_register_agent_type_response(self, response: agent_worker_pb2.RegisterAgentTypeResponse) -> None:
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if response.HasField("error"):
|
||||
future.set_exception(RuntimeError(response.error))
|
||||
else:
|
||||
future.set_result(None)
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
|
@ -635,18 +678,35 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
raise ValueError("Only TypeSubscription is supported.")
|
||||
# Add to local subscription manager.
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
|
||||
# Create a future for the subscription response.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
request_id = await self._get_new_request_id()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
# Send the subscription to the host.
|
||||
message = agent_worker_pb2.Message(
|
||||
addSubscription=agent_worker_pb2.AddSubscription(
|
||||
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
|
||||
request_id=request_id,
|
||||
subscription=agent_worker_pb2.Subscription(
|
||||
typeSubscription=agent_worker_pb2.TypeSubscription(
|
||||
topic_type=subscription.topic_type, agent_type=subscription.agent_type
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
await self._host_connection.send(message)
|
||||
|
||||
# Wait for the subscription response.
|
||||
await future
|
||||
|
||||
async def _process_add_subscription_response(self, response: agent_worker_pb2.AddSubscriptionResponse) -> None:
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if response.HasField("error"):
|
||||
future.set_exception(RuntimeError(response.error))
|
||||
else:
|
||||
future.set_result(None)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")
|
||||
|
||||
|
|
|
@ -107,18 +107,22 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "registerAgentType":
|
||||
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
|
||||
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
|
||||
case "registerAgentTypeRequest":
|
||||
register_agent_type: agent_worker_pb2.RegisterAgentTypeRequest = message.registerAgentTypeRequest
|
||||
task = asyncio.create_task(
|
||||
self._process_register_agent_type_request(register_agent_type, client_id)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "addSubscription":
|
||||
add_subscription: agent_worker_pb2.AddSubscription = message.addSubscription
|
||||
task = asyncio.create_task(self._process_add_subscription(add_subscription))
|
||||
case "addSubscriptionRequest":
|
||||
add_subscription: agent_worker_pb2.AddSubscriptionRequest = message.addSubscriptionRequest
|
||||
task = asyncio.create_task(self._process_add_subscription_request(add_subscription, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "registerAgentTypeResponse" | "addSubscriptionResponse":
|
||||
logger.warning(f"Received unexpected message type: {oneofcase}")
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
|
||||
|
@ -175,32 +179,57 @@ class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
for client_id in client_ids:
|
||||
await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event))
|
||||
|
||||
async def _process_register_agent_type(
|
||||
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
|
||||
async def _process_register_agent_type_request(
|
||||
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
|
||||
) -> None:
|
||||
# Register the agent type with the host runtime.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
if register_agent_type.type in self._agent_type_to_client_id:
|
||||
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
|
||||
if register_agent_type_req.type in self._agent_type_to_client_id:
|
||||
existing_client_id = self._agent_type_to_client_id[register_agent_type_req.type]
|
||||
logger.error(
|
||||
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}."
|
||||
f"Agent type {register_agent_type_req.type} already registered with client {existing_client_id}."
|
||||
)
|
||||
# TODO: send an error response back to the client.
|
||||
success = False
|
||||
error = f"Agent type {register_agent_type_req.type} already registered."
|
||||
else:
|
||||
self._agent_type_to_client_id[register_agent_type.type] = client_id
|
||||
# TODO: send a success response back to the client.
|
||||
self._agent_type_to_client_id[register_agent_type_req.type] = client_id
|
||||
success = True
|
||||
error = None
|
||||
# Send a response back to the client.
|
||||
await self._send_queues[client_id].put(
|
||||
agent_worker_pb2.Message(
|
||||
registerAgentTypeResponse=agent_worker_pb2.RegisterAgentTypeResponse(
|
||||
request_id=register_agent_type_req.request_id, success=success, error=error
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _process_add_subscription(self, add_subscription: agent_worker_pb2.AddSubscription) -> None:
|
||||
oneofcase = add_subscription.subscription.WhichOneof("subscription")
|
||||
async def _process_add_subscription_request(
|
||||
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
|
||||
) -> None:
|
||||
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
|
||||
match oneofcase:
|
||||
case "typeSubscription":
|
||||
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
|
||||
add_subscription.subscription.typeSubscription
|
||||
add_subscription_req.subscription.typeSubscription
|
||||
)
|
||||
type_subscription = TypeSubscription(
|
||||
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
|
||||
)
|
||||
await self._subscription_manager.add_subscription(type_subscription)
|
||||
# TODO: send a success response back to the client.
|
||||
try:
|
||||
await self._subscription_manager.add_subscription(type_subscription)
|
||||
success = True
|
||||
error = None
|
||||
except ValueError as e:
|
||||
success = False
|
||||
error = str(e)
|
||||
# Send a response back to the client.
|
||||
await self._send_queues[client_id].put(
|
||||
agent_worker_pb2.Message(
|
||||
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
|
||||
request_id=add_subscription_req.request_id, success=success, error=error
|
||||
)
|
||||
)
|
||||
)
|
||||
case None:
|
||||
logger.warning("Received empty subscription message")
|
||||
|
|
|
@ -6,24 +6,3 @@ import os
|
|||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .agent_worker_pb2 import AgentId, Event, Message, RegisterAgentType, RpcRequest, RpcResponse
|
||||
from .agent_worker_pb2_grpc import AgentRpcServicer, AgentRpcStub, add_AgentRpcServicer_to_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent_worker_pb2_grpc import AgentRpcAsyncStub
|
||||
|
||||
__all__ = [
|
||||
"RpcRequest",
|
||||
"RpcResponse",
|
||||
"Event",
|
||||
"RegisterAgentType",
|
||||
"AgentRpcAsyncStub",
|
||||
"AgentRpcStub",
|
||||
"Message",
|
||||
"AgentId"
|
||||
]
|
||||
else:
|
||||
__all__ = ["RpcRequest", "RpcResponse", "Event", "RegisterAgentType", "AgentRpcStub", "Message", "AgentId"]
|
||||
|
|
|
@ -16,7 +16,7 @@ import cloudevent_pb2 as cloudevent__pb2
|
|||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x9e\x02\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x12,\n\ncloudEvent\x18\x06 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xc6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x12,\n\ncloudEvent\x18\x08 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
|
@ -47,16 +47,20 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|||
_globals['_EVENT']._serialized_end=909
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=433
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=480
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=911
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=944
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=946
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=1004
|
||||
_globals['_SUBSCRIPTION']._serialized_start=1006
|
||||
_globals['_SUBSCRIPTION']._serialized_end=1090
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_start=1092
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_end=1153
|
||||
_globals['_MESSAGE']._serialized_start=1156
|
||||
_globals['_MESSAGE']._serialized_end=1442
|
||||
_globals['_AGENTRPC']._serialized_start=1444
|
||||
_globals['_AGENTRPC']._serialized_end=1507
|
||||
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911
|
||||
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971
|
||||
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973
|
||||
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=1069
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=1127
|
||||
_globals['_SUBSCRIPTION']._serialized_start=1129
|
||||
_globals['_SUBSCRIPTION']._serialized_end=1213
|
||||
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1215
|
||||
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1303
|
||||
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1305
|
||||
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1397
|
||||
_globals['_MESSAGE']._serialized_start=1400
|
||||
_globals['_MESSAGE']._serialized_end=1854
|
||||
_globals['_AGENTRPC']._serialized_start=1856
|
||||
_globals['_AGENTRPC']._serialized_end=1919
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -214,19 +214,45 @@ class Event(google.protobuf.message.Message):
|
|||
global___Event = Event
|
||||
|
||||
@typing.final
|
||||
class RegisterAgentType(google.protobuf.message.Message):
|
||||
class RegisterAgentTypeRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
REQUEST_ID_FIELD_NUMBER: builtins.int
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
type: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_id: builtins.str = ...,
|
||||
type: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["type", b"type"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["request_id", b"request_id", "type", b"type"]) -> None: ...
|
||||
|
||||
global___RegisterAgentType = RegisterAgentType
|
||||
global___RegisterAgentTypeRequest = RegisterAgentTypeRequest
|
||||
|
||||
@typing.final
|
||||
class RegisterAgentTypeResponse(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
REQUEST_ID_FIELD_NUMBER: builtins.int
|
||||
SUCCESS_FIELD_NUMBER: builtins.int
|
||||
ERROR_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
success: builtins.bool
|
||||
error: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_id: builtins.str = ...,
|
||||
success: builtins.bool = ...,
|
||||
error: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
|
||||
|
||||
global___RegisterAgentTypeResponse = RegisterAgentTypeResponse
|
||||
|
||||
@typing.final
|
||||
class TypeSubscription(google.protobuf.message.Message):
|
||||
|
@ -265,21 +291,47 @@ class Subscription(google.protobuf.message.Message):
|
|||
global___Subscription = Subscription
|
||||
|
||||
@typing.final
|
||||
class AddSubscription(google.protobuf.message.Message):
|
||||
class AddSubscriptionRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
REQUEST_ID_FIELD_NUMBER: builtins.int
|
||||
SUBSCRIPTION_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
@property
|
||||
def subscription(self) -> global___Subscription: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_id: builtins.str = ...,
|
||||
subscription: global___Subscription | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["subscription", b"subscription"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["request_id", b"request_id", "subscription", b"subscription"]) -> None: ...
|
||||
|
||||
global___AddSubscription = AddSubscription
|
||||
global___AddSubscriptionRequest = AddSubscriptionRequest
|
||||
|
||||
@typing.final
|
||||
class AddSubscriptionResponse(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
REQUEST_ID_FIELD_NUMBER: builtins.int
|
||||
SUCCESS_FIELD_NUMBER: builtins.int
|
||||
ERROR_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
success: builtins.bool
|
||||
error: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_id: builtins.str = ...,
|
||||
success: builtins.bool = ...,
|
||||
error: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
|
||||
|
||||
global___AddSubscriptionResponse = AddSubscriptionResponse
|
||||
|
||||
@typing.final
|
||||
class Message(google.protobuf.message.Message):
|
||||
|
@ -288,8 +340,10 @@ class Message(google.protobuf.message.Message):
|
|||
REQUEST_FIELD_NUMBER: builtins.int
|
||||
RESPONSE_FIELD_NUMBER: builtins.int
|
||||
EVENT_FIELD_NUMBER: builtins.int
|
||||
REGISTERAGENTTYPE_FIELD_NUMBER: builtins.int
|
||||
ADDSUBSCRIPTION_FIELD_NUMBER: builtins.int
|
||||
REGISTERAGENTTYPEREQUEST_FIELD_NUMBER: builtins.int
|
||||
REGISTERAGENTTYPERESPONSE_FIELD_NUMBER: builtins.int
|
||||
ADDSUBSCRIPTIONREQUEST_FIELD_NUMBER: builtins.int
|
||||
ADDSUBSCRIPTIONRESPONSE_FIELD_NUMBER: builtins.int
|
||||
CLOUDEVENT_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def request(self) -> global___RpcRequest: ...
|
||||
|
@ -298,9 +352,13 @@ class Message(google.protobuf.message.Message):
|
|||
@property
|
||||
def event(self) -> global___Event: ...
|
||||
@property
|
||||
def registerAgentType(self) -> global___RegisterAgentType: ...
|
||||
def registerAgentTypeRequest(self) -> global___RegisterAgentTypeRequest: ...
|
||||
@property
|
||||
def addSubscription(self) -> global___AddSubscription: ...
|
||||
def registerAgentTypeResponse(self) -> global___RegisterAgentTypeResponse: ...
|
||||
@property
|
||||
def addSubscriptionRequest(self) -> global___AddSubscriptionRequest: ...
|
||||
@property
|
||||
def addSubscriptionResponse(self) -> global___AddSubscriptionResponse: ...
|
||||
@property
|
||||
def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ...
|
||||
def __init__(
|
||||
|
@ -309,12 +367,14 @@ class Message(google.protobuf.message.Message):
|
|||
request: global___RpcRequest | None = ...,
|
||||
response: global___RpcResponse | None = ...,
|
||||
event: global___Event | None = ...,
|
||||
registerAgentType: global___RegisterAgentType | None = ...,
|
||||
addSubscription: global___AddSubscription | None = ...,
|
||||
registerAgentTypeRequest: global___RegisterAgentTypeRequest | None = ...,
|
||||
registerAgentTypeResponse: global___RegisterAgentTypeResponse | None = ...,
|
||||
addSubscriptionRequest: global___AddSubscriptionRequest | None = ...,
|
||||
addSubscriptionResponse: global___AddSubscriptionResponse | None = ...,
|
||||
cloudEvent: cloudevent_pb2.CloudEvent | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentType", "addSubscription", "cloudEvent"] | None: ...
|
||||
def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse", "cloudEvent"] | None: ...
|
||||
|
||||
global___Message = Message
|
||||
|
|
|
@ -22,4 +22,5 @@ __all__ = [
|
|||
"TypeSubscription",
|
||||
"DefaultSubscription",
|
||||
"DefaultTopicId",
|
||||
"default_subscription",
|
||||
]
|
||||
|
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from typing import Any
|
||||
|
||||
from autogen_core.base import BaseAgent, MessageContext
|
||||
from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -25,6 +25,7 @@ class LoopbackAgent(RoutedAgent):
|
|||
return message
|
||||
|
||||
|
||||
@default_subscription
|
||||
class CascadingAgent(RoutedAgent):
|
||||
def __init__(self, max_rounds: int) -> None:
|
||||
super().__init__("A cascading agent.")
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
|
||||
from autogen_core.base import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentType,
|
||||
TopicId,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
|
@ -14,7 +14,7 @@ from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, Mess
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
async def test_agent_types_must_be_unique_single_worker() -> None:
|
||||
host_address = "localhost:50051"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
@ -22,67 +22,95 @@ async def test_agent_names_must_be_unique() -> None:
|
|||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
|
||||
def agent_factory() -> NoopAgent:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId("name1", "default")
|
||||
agent = NoopAgent()
|
||||
assert agent.id == id
|
||||
return agent
|
||||
|
||||
await worker.register("name1", agent_factory)
|
||||
await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await worker.register("name1", NoopAgent)
|
||||
await worker.register_factory(
|
||||
type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent
|
||||
)
|
||||
|
||||
await worker.register("name3", NoopAgent)
|
||||
|
||||
# Let the agent run for a bit.
|
||||
await asyncio.sleep(2)
|
||||
await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_types_must_be_unique_multiple_workers() -> None:
|
||||
host_address = "localhost:50059"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
worker2 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker2.start()
|
||||
|
||||
await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await worker2.register_factory(
|
||||
type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent
|
||||
)
|
||||
|
||||
await worker2.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
|
||||
await worker1.stop()
|
||||
await worker2.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish() -> None:
|
||||
host_address = "localhost:50052"
|
||||
host_address = "localhost:50060"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
worker.start()
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker1.register_factory(
|
||||
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
)
|
||||
await worker1.add_subscription(TypeSubscription("default", "name1"))
|
||||
|
||||
await worker.register("name", LoopbackAgent)
|
||||
await worker.add_subscription(TypeSubscription("default", "name"))
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await worker.publish_message(MessageType(), topic_id=topic_id)
|
||||
worker2 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker2.start()
|
||||
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker2.register_factory(
|
||||
type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
|
||||
)
|
||||
await worker2.add_subscription(TypeSubscription("default", "name2"))
|
||||
|
||||
# Publish message from worker1
|
||||
await worker1.publish_message(MessageType(), topic_id=TopicId("default", "default"))
|
||||
|
||||
# Let the agent run for a bit.
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await worker.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
# Agents in default topic source should have received the message.
|
||||
worker1_agent = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent)
|
||||
assert worker1_agent.num_calls == 1
|
||||
worker2_agent = await worker2.try_get_underlying_agent_instance(AgentId("name2", "default"), LoopbackAgent)
|
||||
assert worker2_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await worker.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
# Agents in other topic source should not have received the message.
|
||||
worker1_agent = await worker1.try_get_underlying_agent_instance(AgentId("name1", "other"), LoopbackAgent)
|
||||
assert worker1_agent.num_calls == 0
|
||||
worker2_agent = await worker2.try_get_underlying_agent_instance(AgentId("name2", "other"), LoopbackAgent)
|
||||
assert worker2_agent.num_calls == 0
|
||||
|
||||
await worker.stop()
|
||||
await worker1.stop()
|
||||
await worker2.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade() -> None:
|
||||
async def test_register_receives_publish_cascade_single_worker() -> None:
|
||||
host_address = "localhost:50053"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime.start()
|
||||
|
||||
num_agents = 5
|
||||
|
@ -94,7 +122,7 @@ async def test_register_receives_publish_cascade() -> None:
|
|||
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
|
||||
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
|
||||
# Publish messages
|
||||
for _ in range(num_initial_messages):
|
||||
|
@ -133,9 +161,8 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
|||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime.start()
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
|
||||
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
workers.append(runtime)
|
||||
|
||||
# Publish messages
|
||||
|
|
Loading…
Reference in New Issue