Distributed agent runtime API for host and worker; unit tests; documentation (#465)

* host agent runtime API and docs

* graceful shutdown of worker

* HostAgentRuntime --> WorkerAgentRuntimeHost

* Add unit tests for worker runtime

* Fix bug in worker runtime adding sender field to proto. Documentation.

* wip

* Fix unit tests; refactor API

* fix formatting

* Fix

* Update

* Make source field optional in Event proto
Eric Zhu 2024-09-13 08:17:53 -07:00 committed by GitHub
parent 0376a0b399
commit a6c1b503ad
14 changed files with 637 additions and 142 deletions


@ -15,7 +15,7 @@ message Payload {
message RpcRequest {
string request_id = 1;
AgentId source = 2;
optional AgentId source = 2;
AgentId target = 3;
string method = 4;
Payload payload = 5;
@ -32,8 +32,9 @@ message RpcResponse {
message Event {
string topic_type = 1;
string topic_source = 2;
Payload payload = 3;
map<string, string> metadata = 4;
optional AgentId source = 3;
Payload payload = 4;
map<string, string> metadata = 5;
}
message RegisterAgentType {


@ -0,0 +1,207 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Distributed Agent Runtime (Experimental)\n",
"\n",
"```{note}\n",
"The distributed agent runtime is an experimental feature. Expect breaking changes\n",
"to the API.\n",
"```\n",
"\n",
"A distributed agent runtime facilitates communication and agent lifecycle management\n",
"across process boundaries.\n",
"It consists of a host service and at least one worker runtime.\n",
"\n",
"The host service maintains connections to all active worker runtimes,\n",
"facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n",
"A worker runtime processes application code (agents) and connects to the host service.\n",
"It also advertises the agents which they support to the host service,\n",
"so the host service can deliver messages to the correct worker.\n",
"\n",
"We can start a host service using {py:class}`autogen_core.application.WorkerAgentRuntimeHost`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from autogen_core.application import WorkerAgentRuntimeHost\n",
"\n",
"host = WorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
"host.start() # Start a host service in the background."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above code starts the host service in the background and accepting\n",
"worker connections at port 50051.\n",
"\n",
"Before running worker runtimes, let's define our agent.\n",
"The agent publishes a new message on every message it receives.\n",
"It also keeps track of how many messages it has published, and \n",
"stops publishing new messages once it has published 5 messages."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"from autogen_core.base import MessageContext\n",
"from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler\n",
"\n",
"\n",
"@dataclass\n",
"class MyMessage:\n",
" content: str\n",
"\n",
"\n",
"class MyAgent(RoutedAgent):\n",
" def __init__(self, name: str) -> None:\n",
" super().__init__(\"My agent\")\n",
" self._name = name\n",
" self._counter = 0\n",
"\n",
" @message_handler\n",
" async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n",
" self._counter += 1\n",
" if self._counter > 5:\n",
" return\n",
" content = f\"{self._name}: Hello x {self._counter}\"\n",
" print(content)\n",
" await self.publish_message(MyMessage(content=content), DefaultTopicId())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can set up the worker agent runtimes.\n",
"We use {py:class}`autogen_core.application.WorkerAgentRuntime`.\n",
"We set up two worker runtimes. Each runtime hosts one agent.\n",
"All agents publish and subscribe to the default topic, so they can see all\n",
"messages being published.\n",
"\n",
"To run the agents, we publishes a message from a worker."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"worker1: Hello x 1\n",
"worker2: Hello x 1\n",
"worker2: Hello x 2\n",
"worker1: Hello x 2\n",
"worker1: Hello x 3\n",
"worker2: Hello x 3\n",
"worker2: Hello x 4\n",
"worker1: Hello x 4\n",
"worker1: Hello x 5\n",
"worker2: Hello x 5\n"
]
}
],
"source": [
"import asyncio\n",
"\n",
"from autogen_core.application import WorkerAgentRuntime\n",
"from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, try_get_known_serializers_for_type\n",
"from autogen_core.components import DefaultSubscription\n",
"\n",
"MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MyMessage))\n",
"\n",
"worker1 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker1.start()\n",
"await worker1.register(\"worker1\", lambda: MyAgent(\"worker1\"), lambda: [DefaultSubscription()])\n",
"\n",
"worker2 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker2.start()\n",
"await worker2.register(\"worker2\", lambda: MyAgent(\"worker2\"), lambda: [DefaultSubscription()])\n",
"\n",
"await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n",
"\n",
"# Let the agents run for a while.\n",
"await asyncio.sleep(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see each agent published exactly 5 messages.\n",
"\n",
"To stop the worker runtimes, we can call {py:meth}`autogen_core.application.WorkerAgentRuntime.stop`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"await worker1.stop()\n",
"await worker2.stop()\n",
"\n",
"# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n",
"# await worker1.stop_when_signal()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can call {py:meth}`autogen_core.application.WorkerAgentRuntimeHost.stop`\n",
"to stop the host service."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"await host.stop()\n",
"\n",
"# To keep the host service running until a termination signal (e.g., SIGTERM)\n",
"# await host.stop_when_signal()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "agnext",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
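The counting behaviour described in the notebook can be checked without a running host. The following is a minimal sketch that models only the message flow; `CountingAgent` and `simulate` are hypothetical stand-ins and are not part of `autogen_core`. It assumes that every published message reaches all subscribers except its publisher, and that the initial message has no agent as its sender.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List, Optional, Tuple


@dataclass
class MyMessage:
    content: str


class CountingAgent:
    """Hypothetical stand-in for the notebook's MyAgent (no runtime involved)."""

    def __init__(self, name: str, limit: int = 5) -> None:
        self.name = name
        self.limit = limit
        self.counter = 0

    def on_message(self, message: MyMessage) -> Optional[MyMessage]:
        # Count every received message; stop replying once the limit is exceeded.
        self.counter += 1
        if self.counter > self.limit:
            return None
        return MyMessage(content=f"{self.name}: Hello x {self.counter}")


def simulate(agent_names: List[str], limit: int = 5) -> List[str]:
    """Deliver messages in FIFO order until no agent publishes anything new."""
    agents = {name: CountingAgent(name, limit) for name in agent_names}
    # Each entry is (sender name or None, message). None models a message
    # published through the runtime itself rather than by an agent.
    pending: Deque[Tuple[Optional[str], MyMessage]] = deque([(None, MyMessage("Hello!"))])
    printed: List[str] = []
    while pending:
        sender, message = pending.popleft()
        for name, agent in agents.items():
            if name == sender:
                continue  # Assumption: a publisher does not receive its own message.
            reply = agent.on_message(message)
            if reply is not None:
                printed.append(reply.content)
                pending.append((name, reply))
    return printed
```

Under these assumptions each agent receives the initial message plus the other agent's five replies, so it answers five times and ignores the sixth message. The real runtime delivers messages concurrently, so the interleaving of lines it prints may differ from this FIFO model.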


@ -44,8 +44,8 @@ from `Agent and Multi-Agent Application <core-concepts/agent-and-multi-agent-app
:hidden:
guides/logging
guides/distributed-agent-runtime
guides/telemetry
guides/worker-protocol
.. toctree::
:caption: Cookbook


@ -1,44 +1,12 @@
import asyncio
import signal
import grpc
from autogen_core.application import HostRuntimeServicer
from autogen_core.application.protos import agent_worker_pb2_grpc
async def serve(server: grpc.aio.Server) -> None: # type: ignore
await server.start()
print("Server started")
await server.wait_for_termination()
from autogen_core.application import WorkerAgentRuntimeHost
async def main() -> None:
server = grpc.aio.server()
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(HostRuntimeServicer(), server)
server.add_insecure_port("[::]:50051")
# Set up signal handling for graceful shutdown
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
def signal_handler() -> None:
print("Received exit signal, shutting down gracefully...")
shutdown_event.set()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
# Start server in background task
serve_task = asyncio.create_task(serve(server))
# Wait for the signal to trigger the shutdown event
await shutdown_event.wait()
# Graceful shutdown
await server.stop(5) # 5 second grace period
await serve_task
print("Server stopped")
service = WorkerAgentRuntimeHost(address="localhost:50051")
service.start()
await service.stop_when_signal()
if __name__ == "__main__":
@ -46,8 +14,4 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.WARNING)
logging.getLogger("autogen_core").setLevel(logging.DEBUG)
try:
asyncio.run(main())
except KeyboardInterrupt:
print("Server shutdown interrupted.")
asyncio.run(main())


@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext, try_get_known_serializers_for_type
from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext, try_get_known_serializers_for_type
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
@ -66,25 +66,20 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime()
runtime = WorkerAgentRuntime(host_address="localhost:50051")
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedGreeting))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedFeedback))
await runtime.start(host_connection_string="localhost:50051")
runtime.start()
await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()])
await runtime.register("greeter", GreeterAgent, lambda: [DefaultSubscription()])
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())
# Just to keep the runtime running
try:
await asyncio.sleep(1000000)
except KeyboardInterrupt:
pass
await runtime.stop()
await runtime.stop_when_signal()
if __name__ == "__main__":


@ -60,11 +60,11 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime()
runtime = WorkerAgentRuntime(host_address="localhost:50051")
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
await runtime.start(host_connection_string="localhost:50051")
runtime.start()
await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])
await runtime.register(
@ -74,12 +74,7 @@ async def main() -> None:
)
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())
# Just to keep the runtime running
try:
await asyncio.sleep(1000000)
except KeyboardInterrupt:
pass
await runtime.stop()
await runtime.stop_when_signal()
if __name__ == "__main__":


@ -2,8 +2,12 @@
The :mod:`autogen_core.application` module provides implementations of core components that are used to compose an application
"""
from ._host_runtime_servicer import HostRuntimeServicer
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
from ._worker_runtime import WorkerAgentRuntime
from ._worker_runtime_host import WorkerAgentRuntimeHost
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "HostRuntimeServicer"]
__all__ = [
"SingleThreadedAgentRuntime",
"WorkerAgentRuntime",
"WorkerAgentRuntimeHost",
]


@ -2,6 +2,7 @@ import asyncio
import inspect
import json
import logging
import signal
import warnings
from asyncio import Future, Task
from collections import defaultdict
@ -19,6 +20,7 @@ from typing import (
Literal,
Mapping,
ParamSpec,
Sequence,
Set,
Type,
TypeVar,
@ -96,13 +98,9 @@ class HostConnection:
self._connection_task: Task[None] | None = None
@classmethod
async def from_connection_string(
cls, connection_string: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG
) -> Self:
logger.info("Connecting to %s", connection_string)
channel = grpc.aio.insecure_channel(
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
)
def from_host_address(cls, host_address: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG) -> Self:
logger.info("Connecting to %s", host_address)
channel = grpc.aio.insecure_channel(host_address, options=[("grpc.service_config", json.dumps(grpc_config))])
instance = cls(channel)
instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue)
@ -110,6 +108,8 @@ class HostConnection:
return instance
async def close(self) -> None:
if self._connection_task is None:
raise RuntimeError("Connection is not open.")
await self._channel.close()
if self._connection_task is not None:
await self._connection_task
@ -128,22 +128,15 @@ class HostConnection:
) # type: ignore
while True:
try:
logger.info("Waiting for message from host")
message = await recv_stream.read() # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
message = cast(agent_worker_pb2.Message, message)
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
except Exception as e:
print("=========================================================================")
print(e)
print("=========================================================================")
del recv_stream
recv_stream = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
logger.info("Waiting for message from host")
message = await recv_stream.read() # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
message = cast(agent_worker_pb2.Message, message)
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
async def send(self, message: agent_worker_pb2.Message) -> None:
logger.info(f"Send message to host: {message}")
@ -156,7 +149,8 @@ class HostConnection:
class WorkerAgentRuntime(AgentRuntime):
def __init__(self, tracer_provider: TracerProvider | None = None) -> None:
def __init__(self, host_address: str, tracer_provider: TracerProvider | None = None) -> None:
self._host_address = host_address
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[
@ -173,12 +167,13 @@ class WorkerAgentRuntime(AgentRuntime):
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
async def start(self, host_connection_string: str) -> None:
def start(self) -> None:
"""Start the runtime in a background task."""
if self._running:
raise ValueError("Runtime is already running.")
logger.info(f"Connecting to host: {host_connection_string}")
self._host_connection = await HostConnection.from_connection_string(host_connection_string)
logger.info("connection")
logger.info(f"Connecting to host: {self._host_address}")
self._host_connection = HostConnection.from_host_address(self._host_address)
logger.info("Connection established")
if self._read_task is None:
self._read_task = asyncio.create_task(self._run_read_loop())
self._running = True
@ -197,7 +192,7 @@ class WorkerAgentRuntime(AgentRuntime):
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentType" | "addSubscription":
logger.warn(f"Cant handle {oneofcase}, skipping.")
logger.warning(f"Cant handle {oneofcase}, skipping.")
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request))
@ -217,16 +212,51 @@ class WorkerAgentRuntime(AgentRuntime):
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warn("No message")
logger.warning("No message")
except Exception as e:
logger.error("Error in read loop", exc_info=e)
async def stop(self) -> None:
"""Stop the runtime immediately."""
if not self._running:
raise RuntimeError("Runtime is not running.")
self._running = False
# Wait for all background tasks to finish.
final_tasks_results = await asyncio.gather(*self._background_tasks, return_exceptions=True)
for task_result in final_tasks_results:
if isinstance(task_result, Exception):
logger.error("Error in background task", exc_info=task_result)
# Close the host connection.
if self._host_connection is not None:
await self._host_connection.close()
try:
await self._host_connection.close()
except asyncio.CancelledError:
pass
# Cancel the read task.
if self._read_task is not None:
await self._read_task
self._read_task.cancel()
try:
await self._read_task
except asyncio.CancelledError:
pass
async def stop_when_signal(self, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)) -> None:
"""Stop the runtime when a signal is received."""
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
def signal_handler() -> None:
logger.info("Received exit signal, shutting down gracefully...")
shutdown_event.set()
for sig in signals:
loop.add_signal_handler(sig, signal_handler)
# Wait for the signal to trigger the shutdown event.
await shutdown_event.wait()
# Stop the runtime.
await self.stop()
@property
def _known_agent_names(self) -> Set[str]:
@ -267,7 +297,6 @@ class WorkerAgentRuntime(AgentRuntime):
request_id = self._next_request_id
request_id_str = str(request_id)
self._pending_requests[request_id_str] = future
sender = cast(AgentId, sender)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
@ -276,7 +305,7 @@ class WorkerAgentRuntime(AgentRuntime):
request=agent_worker_pb2.RpcRequest(
request_id=request_id_str,
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key),
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
payload=agent_worker_pb2.Payload(
data_type=data_type,
@ -317,6 +346,7 @@ class WorkerAgentRuntime(AgentRuntime):
event=agent_worker_pb2.Event(
topic_type=topic_id.type,
topic_source=topic_id.source,
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
payload=agent_worker_pb2.Payload(
data_type=message_type,
@ -348,10 +378,13 @@ class WorkerAgentRuntime(AgentRuntime):
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
assert self._host_connection is not None
target = AgentId(request.target.type, request.target.key)
source = AgentId(request.source.type, request.source.key)
logging.info(f"Processing request from {source} to {target}")
recipient = AgentId(request.target.type, request.target.key)
sender: AgentId | None = None
if request.HasField("source"):
sender = AgentId(request.source.type, request.source.key)
logging.info(f"Processing request from {sender} to {recipient}")
else:
logging.info(f"Processing request from unknown source to {recipient}")
# Deserialize the message.
message = MESSAGE_TYPE_REGISTRY.deserialize(
@ -360,26 +393,26 @@ class WorkerAgentRuntime(AgentRuntime):
data_content_type=request.payload.data_content_type,
)
# Get the target agent and prepare the message context.
target_agent = await self._get_agent(target)
# Get the receiving agent and prepare the message context.
rec_agent = await self._get_agent(recipient)
message_context = MessageContext(
sender=source,
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=CancellationToken(),
)
# Call the target agent.
# Call the receiving agent.
try:
with MessageHandlerContext.populate_context(target_agent.id):
with MessageHandlerContext.populate_context(rec_agent.id):
with self._trace_helper.trace_block(
"process",
target_agent.id,
rec_agent.id,
parent=request.metadata,
attributes={"request_id": request.request_id},
extraAttributes={"message_type": request.payload.data_type},
):
result = await target_agent.on_message(message, ctx=message_context)
result = await rec_agent.on_message(message, ctx=message_context)
except BaseException as e:
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
@ -436,18 +469,22 @@ class WorkerAgentRuntime(AgentRuntime):
future.set_result(result)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
topic_id = TopicId(event.topic_type, event.topic_source)
message = MESSAGE_TYPE_REGISTRY.deserialize(
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
)
sender: AgentId | None = None
if event.HasField("source"):
sender = AgentId(event.source.type, event.source.key)
topic_id = TopicId(event.topic_type, event.topic_source)
# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Send the message to each recipient.
responses: List[Awaitable[Any]] = []
for agent_id in recipients:
# TODO: avoid sending to the sender.
if agent_id == sender:
continue
message_context = MessageContext(
sender=None,
sender=sender,
topic_id=topic_id,
is_rpc=False,
cancellation_token=CancellationToken(),
@ -543,7 +580,16 @@ class WorkerAgentRuntime(AgentRuntime):
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
if id.type not in self._agent_factories:
raise LookupError(f"Agent with name {id.type} not found.")
# TODO: check if remote
agent_instance = await self._get_agent(id)
if not isinstance(agent_instance, type):
raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}")
return agent_instance
async def add_subscription(self, subscription: Subscription) -> None:
if self._host_connection is None:


@ -0,0 +1,68 @@
import asyncio
import logging
import signal
from typing import Sequence
import grpc
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
from .protos import agent_worker_pb2_grpc
logger = logging.getLogger("autogen_core")
class WorkerAgentRuntimeHost:
def __init__(self, address: str) -> None:
self._server = grpc.aio.server()
self._servicer = WorkerAgentRuntimeHostServicer()
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
self._server.add_insecure_port(address)
self._address = address
self._serve_task: asyncio.Task[None] | None = None
async def _serve(self) -> None:
await self._server.start()
logger.info(f"Server started at {self._address}.")
await self._server.wait_for_termination()
def start(self) -> None:
"""Start the server in a background task."""
if self._serve_task is not None:
raise RuntimeError("Host runtime is already started.")
self._serve_task = asyncio.create_task(self._serve())
async def stop(self, grace: int = 5) -> None:
"""Stop the server."""
if self._serve_task is None:
raise RuntimeError("Host runtime is not started.")
await self._server.stop(grace=grace)
self._serve_task.cancel()
try:
await self._serve_task
except asyncio.CancelledError:
pass
logger.info("Server stopped.")
self._serve_task = None
async def stop_when_signal(
self, grace: int = 5, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)
) -> None:
"""Stop the server when a signal is received."""
if self._serve_task is None:
raise RuntimeError("Host runtime is not started.")
# Set up signal handling for graceful shutdown.
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
def signal_handler() -> None:
logger.info("Received exit signal, shutting down gracefully...")
shutdown_event.set()
for sig in signals:
loop.add_signal_handler(sig, signal_handler)
# Wait for the signal to trigger the shutdown event.
await shutdown_event.wait()
# Shutdown the server.
await self.stop(grace=grace)
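The `stop_when_signal` methods on the host and the worker runtime share one pattern: register handlers on the running event loop, wait, then shut down. Below is a minimal standalone sketch of that pattern using only the standard library. The names `wait_for_signal` and `demo_self_signal` are hypothetical and do not appear in this commit. Unlike the code above, the sketch removes its handlers once a signal has arrived. It assumes a Unix event loop running in the main thread, since `loop.add_signal_handler` is not available elsewhere.

```python
import asyncio
import os
import signal
from typing import Sequence


async def wait_for_signal(
    signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT),
) -> signal.Signals:
    """Suspend until one of `signals` arrives, then return the one received."""
    loop = asyncio.get_running_loop()
    received = loop.create_future()

    def on_signal(sig: signal.Signals) -> None:
        # Several signals may arrive close together; keep only the first.
        if not received.done():
            received.set_result(sig)

    for sig in signals:
        loop.add_signal_handler(sig, on_signal, sig)
    try:
        return await received
    finally:
        # Remove the loop-level handlers so later signals are not swallowed.
        for sig in signals:
            loop.remove_signal_handler(sig)


async def demo_self_signal(sig: signal.Signals = signal.SIGUSR1) -> signal.Signals:
    """Send `sig` to the current process shortly after starting to wait."""
    loop = asyncio.get_running_loop()
    loop.call_later(0.05, os.kill, os.getpid(), sig)
    return await wait_for_signal((sig,))
```

A caller could run `asyncio.run(demo_self_signal())` to see the wait complete. In a service, the shutdown call (for example a `stop()` coroutine) would follow the `await wait_for_signal()` line, which matches the order used in the methods above.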


@ -15,7 +15,7 @@ logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:


@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xf9\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb3\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12 \n\x07payload\x18\x03 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 
\x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 
\x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -32,27 +32,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_globals['_PAYLOAD']._serialized_start=68
_globals['_PAYLOAD']._serialized_end=137
_globals['_RPCREQUEST']._serialized_start=140
_globals['_RPCREQUEST']._serialized_end=389
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=342
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=389
_globals['_RPCRESPONSE']._serialized_start=392
_globals['_RPCRESPONSE']._serialized_end=576
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=342
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=389
_globals['_EVENT']._serialized_start=579
_globals['_EVENT']._serialized_end=758
_globals['_EVENT_METADATAENTRY']._serialized_start=342
_globals['_EVENT_METADATAENTRY']._serialized_end=389
_globals['_REGISTERAGENTTYPE']._serialized_start=760
_globals['_REGISTERAGENTTYPE']._serialized_end=793
_globals['_TYPESUBSCRIPTION']._serialized_start=795
_globals['_TYPESUBSCRIPTION']._serialized_end=853
_globals['_SUBSCRIPTION']._serialized_start=855
_globals['_SUBSCRIPTION']._serialized_end=939
_globals['_ADDSUBSCRIPTION']._serialized_start=941
_globals['_ADDSUBSCRIPTION']._serialized_end=1002
_globals['_MESSAGE']._serialized_start=1005
_globals['_MESSAGE']._serialized_end=1245
_globals['_AGENTRPC']._serialized_start=1247
_globals['_AGENTRPC']._serialized_end=1310
_globals['_RPCREQUEST']._serialized_end=405
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=347
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=394
_globals['_RPCRESPONSE']._serialized_start=408
_globals['_RPCRESPONSE']._serialized_end=592
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=347
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=394
_globals['_EVENT']._serialized_start=595
_globals['_EVENT']._serialized_end=823
_globals['_EVENT_METADATAENTRY']._serialized_start=347
_globals['_EVENT_METADATAENTRY']._serialized_end=394
_globals['_REGISTERAGENTTYPE']._serialized_start=825
_globals['_REGISTERAGENTTYPE']._serialized_end=858
_globals['_TYPESUBSCRIPTION']._serialized_start=860
_globals['_TYPESUBSCRIPTION']._serialized_end=918
_globals['_SUBSCRIPTION']._serialized_start=920
_globals['_SUBSCRIPTION']._serialized_end=1004
_globals['_ADDSUBSCRIPTION']._serialized_start=1006
_globals['_ADDSUBSCRIPTION']._serialized_end=1067
_globals['_MESSAGE']._serialized_start=1070
_globals['_MESSAGE']._serialized_end=1310
_globals['_AGENTRPC']._serialized_start=1312
_globals['_AGENTRPC']._serialized_end=1375
# @@protoc_insertion_point(module_scope)


@ -97,8 +97,9 @@ class RpcRequest(google.protobuf.message.Message):
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
global___RpcRequest = RpcRequest
@ -167,11 +168,14 @@ class Event(google.protobuf.message.Message):
TOPIC_TYPE_FIELD_NUMBER: builtins.int
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
PAYLOAD_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
topic_type: builtins.str
topic_source: builtins.str
@property
def source(self) -> global___AgentId: ...
@property
def payload(self) -> global___Payload: ...
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
@ -180,11 +184,13 @@ class Event(google.protobuf.message.Message):
*,
topic_type: builtins.str = ...,
topic_source: builtins.str = ...,
source: global___AgentId | None = ...,
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "payload", b"payload", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
global___Event = Event


@ -0,0 +1,209 @@
import asyncio

import pytest
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
from autogen_core.base import (
    MESSAGE_TYPE_REGISTRY,
    AgentId,
    AgentInstantiationContext,
    TopicId,
    try_get_known_serializers_for_type,
)
from autogen_core.components import DefaultSubscription, DefaultTopicId, TypeSubscription
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent


@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
    # Host address; keep it unique to this test only.
    host_address = "localhost:50051"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()

    MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
    worker = WorkerAgentRuntime(host_address=host_address)
    worker.start()

    def agent_factory() -> NoopAgent:
        id = AgentInstantiationContext.current_agent_id()
        assert id == AgentId("name1", "default")
        agent = NoopAgent()
        assert agent.id == id
        return agent

    await worker.register("name1", agent_factory)

    # Registering the same agent type name a second time must fail.
    with pytest.raises(ValueError):
        await worker.register("name1", NoopAgent)

    await worker.register("name3", NoopAgent)

    # Let the agent run for a bit.
    await asyncio.sleep(2)

    await worker.stop()
    await host.stop()


@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    # Host address; keep it unique to this test only.
    host_address = "localhost:50052"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()

    MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
    worker = WorkerAgentRuntime(host_address=host_address)
    worker.start()

    await worker.register("name", LoopbackAgent)
    await worker.add_subscription(TypeSubscription("default", "name"))
    agent_id = AgentId("name", key="default")
    topic_id = TopicId("default", "default")
    await worker.publish_message(MessageType(), topic_id=topic_id)

    # Let the agent run for a bit.
    await asyncio.sleep(2)

    # Agent in default namespace should have received the message
    long_running_agent = await worker.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
    assert long_running_agent.num_calls == 1

    # Agent in other namespace should not have received the message
    other_long_running_agent: LoopbackAgent = await worker.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_long_running_agent.num_calls == 0

    await worker.stop()
    await host.stop()


@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
    # Host address; keep it unique to this test only.
    host_address = "localhost:50053"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()

    MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(CascadingMessageType))
    runtime = WorkerAgentRuntime(host_address=host_address)
    runtime.start()

    num_agents = 5
    num_initial_messages = 5
    max_rounds = 5
    total_num_calls_expected = 0
    # Round i contributes num_initial_messages * (num_agents - 1) ** i calls per agent.
    for i in range(0, max_rounds):
        total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)

    # Register agents
    for i in range(num_agents):
        await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
        await runtime.add_subscription(TypeSubscription("default", f"name{i}"))

    # Publish messages
    for _ in range(num_initial_messages):
        await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())

    # Let the agents run for a bit.
    await asyncio.sleep(5)

    # Check that each agent received the correct number of messages.
    for i in range(num_agents):
        agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
        assert agent.num_calls == total_num_calls_expected

    await runtime.stop()
    await host.stop()


@pytest.mark.asyncio
async def test_default_subscription() -> None:
    # Host address; keep it unique to this test only.
    host_address = "localhost:50054"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()

    MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
    runtime = WorkerAgentRuntime(host_address=host_address)
    runtime.start()

    await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
    agent_id = AgentId("name", key="default")
    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())

    # Let the agent run for a bit.
    await asyncio.sleep(2)

    # Agent in default namespace should have received the message
    long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
    assert long_running_agent.num_calls == 1

    # Agent in other namespace should not have received the message
    other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_long_running_agent.num_calls == 0

    await runtime.stop()
    await host.stop()


@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
    # Host address; keep it unique to this test only.
    host_address = "localhost:50055"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()

    MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
    runtime = WorkerAgentRuntime(host_address=host_address)
    runtime.start()

    # Subscribe to the topic type "Other" instead of the default topic type.
    await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
    agent_id = AgentId("name", key="default")
    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))

    # Let the agent run for a bit.
    await asyncio.sleep(2)

    # Agent in default namespace should have received the message
    long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
    assert long_running_agent.num_calls == 1

    # Agent in other namespace should not have received the message
    other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_long_running_agent.num_calls == 0

    await runtime.stop()
    await host.stop()


@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
    # Host address; keep it unique to this test only.
    host_address = "localhost:50056"
    host = WorkerAgentRuntimeHost(address=host_address)
    host.start()

    MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
    runtime = WorkerAgentRuntime(host_address=host_address)
    runtime.start()

    await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
    agent_id = AgentId("name", key="default")
    # Publish with topic source "other" rather than the default source.
    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))

    # Let the agent run for a bit.
    await asyncio.sleep(2)

    # Agent in default namespace should not have received the message
    long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
    assert long_running_agent.num_calls == 0

    # Agent in other namespace should have received the message
    other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert other_long_running_agent.num_calls == 1

    await runtime.stop()
    await host.stop()
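
The expected call count asserted in `test_register_receives_publish_cascade` can be checked by hand. The sketch below reproduces only the arithmetic from that test. It assumes, without confirmation from this diff, that every `CascadingAgent` republishes each message it receives until `max_rounds` is reached and that an agent does not receive its own publications, which would explain the `num_agents - 1` factor. The helper name `expected_calls_per_agent` is hypothetical and is not part of the test suite.

```python
def expected_calls_per_agent(num_agents: int, num_initial_messages: int, max_rounds: int) -> int:
    """Sum the number of messages a single agent receives over all rounds.

    Round i (starting at 0) contributes num_initial_messages * (num_agents - 1) ** i,
    matching the loop in test_register_receives_publish_cascade.
    """
    total = 0
    for i in range(max_rounds):
        total += num_initial_messages * ((num_agents - 1) ** i)
    return total


# Same parameters as the test: 5 agents, 5 initial messages, 5 rounds.
print(expected_calls_per_agent(num_agents=5, num_initial_messages=5, max_rounds=5))  # 1705
```

With the test's parameters the sum is 5 * (1 + 4 + 16 + 64 + 256) = 1705 calls per agent, which is the value each agent's `num_calls` is compared against.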