mirror of https://github.com/microsoft/autogen.git
Distributed agent runtime API for host and worker; unit tests; documentation (#465)
* host agent runtime API and docs * graceful shutdown of worker * HostAgentRuntime --> WorkerAgentRuntimeHost * Add unit tests for worker runtime * Fix bug in worker runtime adding sender filed to proto. Documentation. * wip * Fix unit tests; refactor API * fix formatting * Fix * Update * Make source field optional in Event proto
This commit is contained in:
parent
0376a0b399
commit
a6c1b503ad
|
@ -15,7 +15,7 @@ message Payload {
|
|||
|
||||
message RpcRequest {
|
||||
string request_id = 1;
|
||||
AgentId source = 2;
|
||||
optional AgentId source = 2;
|
||||
AgentId target = 3;
|
||||
string method = 4;
|
||||
Payload payload = 5;
|
||||
|
@ -32,8 +32,9 @@ message RpcResponse {
|
|||
message Event {
|
||||
string topic_type = 1;
|
||||
string topic_source = 2;
|
||||
Payload payload = 3;
|
||||
map<string, string> metadata = 4;
|
||||
optional AgentId source = 3;
|
||||
Payload payload = 4;
|
||||
map<string, string> metadata = 5;
|
||||
}
|
||||
|
||||
message RegisterAgentType {
|
||||
|
|
|
@ -0,0 +1,207 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Distributed Agent Runtime (Experimental)\n",
|
||||
"\n",
|
||||
"```{note}\n",
|
||||
"The distributed agent runtime is an experimental feature. Expect breaking changes\n",
|
||||
"to the API.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"A distributed agent runtime facilitates communication and agent lifecycle management\n",
|
||||
"across process boundaries.\n",
|
||||
"It consists of a host service and at least one worker runtime.\n",
|
||||
"\n",
|
||||
"The host service maintains connections to all active worker runtimes,\n",
|
||||
"facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n",
|
||||
"A worker runtime processes application code (agents) and connects to the host service.\n",
|
||||
"It also advertises the agents which they support to the host service,\n",
|
||||
"so the host service can deliver messages to the correct worker.\n",
|
||||
"\n",
|
||||
"We can start a host service using {py:class}`autogen_core.application.WorkerAgentRuntimeHost`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core.application import WorkerAgentRuntimeHost\n",
|
||||
"\n",
|
||||
"host = WorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
|
||||
"host.start() # Start a host service in the background."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above code starts the host service in the background and accepting\n",
|
||||
"worker connections at port 50051.\n",
|
||||
"\n",
|
||||
"Before running worker runtimes, let's define our agent.\n",
|
||||
"The agent publishes a new message on every message it receives.\n",
|
||||
"It also keeps track of how many messages it has published, and \n",
|
||||
"stops publishing new messages once it has published 5 messages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from autogen_core.base import MessageContext\n",
|
||||
"from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class MyMessage:\n",
|
||||
" content: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MyAgent(RoutedAgent):\n",
|
||||
" def __init__(self, name: str) -> None:\n",
|
||||
" super().__init__(\"My agent\")\n",
|
||||
" self._name = name\n",
|
||||
" self._counter = 0\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n",
|
||||
" self._counter += 1\n",
|
||||
" if self._counter > 5:\n",
|
||||
" return\n",
|
||||
" content = f\"{self._name}: Hello x {self._counter}\"\n",
|
||||
" print(content)\n",
|
||||
" await self.publish_message(MyMessage(content=content), DefaultTopicId())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can set up the worker agent runtimes.\n",
|
||||
"We use {py:class}`autogen_core.application.WorkerAgentRuntime`.\n",
|
||||
"We set up two worker runtimes. Each runtime hosts one agent.\n",
|
||||
"All agents publish and subscribe to the default topic, so they can see all\n",
|
||||
"messages being published.\n",
|
||||
"\n",
|
||||
"To run the agents, we publishes a message from a worker."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"worker1: Hello x 1\n",
|
||||
"worker2: Hello x 1\n",
|
||||
"worker2: Hello x 2\n",
|
||||
"worker1: Hello x 2\n",
|
||||
"worker1: Hello x 3\n",
|
||||
"worker2: Hello x 3\n",
|
||||
"worker2: Hello x 4\n",
|
||||
"worker1: Hello x 4\n",
|
||||
"worker1: Hello x 5\n",
|
||||
"worker2: Hello x 5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"from autogen_core.application import WorkerAgentRuntime\n",
|
||||
"from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, try_get_known_serializers_for_type\n",
|
||||
"from autogen_core.components import DefaultSubscription\n",
|
||||
"\n",
|
||||
"MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MyMessage))\n",
|
||||
"\n",
|
||||
"worker1 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker1.start()\n",
|
||||
"await worker1.register(\"worker1\", lambda: MyAgent(\"worker1\"), lambda: [DefaultSubscription()])\n",
|
||||
"\n",
|
||||
"worker2 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker2.start()\n",
|
||||
"await worker2.register(\"worker2\", lambda: MyAgent(\"worker2\"), lambda: [DefaultSubscription()])\n",
|
||||
"\n",
|
||||
"await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n",
|
||||
"\n",
|
||||
"# Let the agents run for a while.\n",
|
||||
"await asyncio.sleep(5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can see each agent published exactly 5 messages.\n",
|
||||
"\n",
|
||||
"To stop the worker runtimes, we can call {py:meth}`autogen_core.application.WorkerAgentRuntime.stop`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await worker1.stop()\n",
|
||||
"await worker2.stop()\n",
|
||||
"\n",
|
||||
"# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n",
|
||||
"# await worker1.stop_when_signal()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can call {py:meth}`autogen_core.application.WorkerAgentRuntimeHost.stop`\n",
|
||||
"to stop the host service."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await host.stop()\n",
|
||||
"\n",
|
||||
"# To keep the host service running until a termination signal (e.g., SIGTERM)\n",
|
||||
"# await host.stop_when_signal()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "agnext",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -44,8 +44,8 @@ from `Agent and Multi-Agent Application <core-concepts/agent-and-multi-agent-app
|
|||
:hidden:
|
||||
|
||||
guides/logging
|
||||
guides/distributed-agent-runtime
|
||||
guides/telemetry
|
||||
guides/worker-protocol
|
||||
|
||||
.. toctree::
|
||||
:caption: Cookbook
|
||||
|
|
|
@ -1,44 +1,12 @@
|
|||
import asyncio
|
||||
import signal
|
||||
|
||||
import grpc
|
||||
from autogen_core.application import HostRuntimeServicer
|
||||
from autogen_core.application.protos import agent_worker_pb2_grpc
|
||||
|
||||
|
||||
async def serve(server: grpc.aio.Server) -> None: # type: ignore
|
||||
await server.start()
|
||||
print("Server started")
|
||||
await server.wait_for_termination()
|
||||
from autogen_core.application import WorkerAgentRuntimeHost
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
server = grpc.aio.server()
|
||||
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(HostRuntimeServicer(), server)
|
||||
server.add_insecure_port("[::]:50051")
|
||||
|
||||
# Set up signal handling for graceful shutdown
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler() -> None:
|
||||
print("Received exit signal, shutting down gracefully...")
|
||||
shutdown_event.set()
|
||||
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Start server in background task
|
||||
serve_task = asyncio.create_task(serve(server))
|
||||
|
||||
# Wait for the signal to trigger the shutdown event
|
||||
await shutdown_event.wait()
|
||||
|
||||
# Graceful shutdown
|
||||
await server.stop(5) # 5 second grace period
|
||||
await serve_task
|
||||
print("Server stopped")
|
||||
service = WorkerAgentRuntimeHost(address="localhost:50051")
|
||||
service.start()
|
||||
await service.stop_when_signal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -46,8 +14,4 @@ if __name__ == "__main__":
|
|||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("autogen_core").setLevel(logging.DEBUG)
|
||||
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("Server shutdown interrupted.")
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from typing import Any, NoReturn
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext, try_get_known_serializers_for_type
|
||||
from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext, try_get_known_serializers_for_type
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
|
||||
|
||||
|
@ -66,25 +66,20 @@ class GreeterAgent(RoutedAgent):
|
|||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedGreeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedFeedback))
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()])
|
||||
await runtime.register("greeter", GreeterAgent, lambda: [DefaultSubscription()])
|
||||
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())
|
||||
|
||||
# Just to keep the runtime running
|
||||
try:
|
||||
await asyncio.sleep(1000000)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
await runtime.stop()
|
||||
await runtime.stop_when_signal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -60,11 +60,11 @@ class GreeterAgent(RoutedAgent):
|
|||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])
|
||||
await runtime.register(
|
||||
|
@ -74,12 +74,7 @@ async def main() -> None:
|
|||
)
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())
|
||||
|
||||
# Just to keep the runtime running
|
||||
try:
|
||||
await asyncio.sleep(1000000)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
await runtime.stop()
|
||||
await runtime.stop_when_signal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -2,8 +2,12 @@
|
|||
The :mod:`autogen_core.application` module provides implementations of core components that are used to compose an application
|
||||
"""
|
||||
|
||||
from ._host_runtime_servicer import HostRuntimeServicer
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from ._worker_runtime import WorkerAgentRuntime
|
||||
from ._worker_runtime_host import WorkerAgentRuntimeHost
|
||||
|
||||
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "HostRuntimeServicer"]
|
||||
__all__ = [
|
||||
"SingleThreadedAgentRuntime",
|
||||
"WorkerAgentRuntime",
|
||||
"WorkerAgentRuntimeHost",
|
||||
]
|
||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import warnings
|
||||
from asyncio import Future, Task
|
||||
from collections import defaultdict
|
||||
|
@ -19,6 +20,7 @@ from typing import (
|
|||
Literal,
|
||||
Mapping,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
|
@ -96,13 +98,9 @@ class HostConnection:
|
|||
self._connection_task: Task[None] | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_connection_string(
|
||||
cls, connection_string: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG
|
||||
) -> Self:
|
||||
logger.info("Connecting to %s", connection_string)
|
||||
channel = grpc.aio.insecure_channel(
|
||||
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
|
||||
)
|
||||
def from_host_address(cls, host_address: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG) -> Self:
|
||||
logger.info("Connecting to %s", host_address)
|
||||
channel = grpc.aio.insecure_channel(host_address, options=[("grpc.service_config", json.dumps(grpc_config))])
|
||||
instance = cls(channel)
|
||||
instance._connection_task = asyncio.create_task(
|
||||
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
||||
|
@ -110,6 +108,8 @@ class HostConnection:
|
|||
return instance
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._connection_task is None:
|
||||
raise RuntimeError("Connection is not open.")
|
||||
await self._channel.close()
|
||||
if self._connection_task is not None:
|
||||
await self._connection_task
|
||||
|
@ -128,22 +128,15 @@ class HostConnection:
|
|||
) # type: ignore
|
||||
|
||||
while True:
|
||||
try:
|
||||
logger.info("Waiting for message from host")
|
||||
message = await recv_stream.read() # type: ignore
|
||||
if message == grpc.aio.EOF: # type: ignore
|
||||
logger.info("EOF")
|
||||
break
|
||||
message = cast(agent_worker_pb2.Message, message)
|
||||
logger.info(f"Received a message from host: {message}")
|
||||
await receive_queue.put(message)
|
||||
logger.info("Put message in receive queue")
|
||||
except Exception as e:
|
||||
print("=========================================================================")
|
||||
print(e)
|
||||
print("=========================================================================")
|
||||
del recv_stream
|
||||
recv_stream = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
|
||||
logger.info("Waiting for message from host")
|
||||
message = await recv_stream.read() # type: ignore
|
||||
if message == grpc.aio.EOF: # type: ignore
|
||||
logger.info("EOF")
|
||||
break
|
||||
message = cast(agent_worker_pb2.Message, message)
|
||||
logger.info(f"Received a message from host: {message}")
|
||||
await receive_queue.put(message)
|
||||
logger.info("Put message in receive queue")
|
||||
|
||||
async def send(self, message: agent_worker_pb2.Message) -> None:
|
||||
logger.info(f"Send message to host: {message}")
|
||||
|
@ -156,7 +149,8 @@ class HostConnection:
|
|||
|
||||
|
||||
class WorkerAgentRuntime(AgentRuntime):
|
||||
def __init__(self, tracer_provider: TracerProvider | None = None) -> None:
|
||||
def __init__(self, host_address: str, tracer_provider: TracerProvider | None = None) -> None:
|
||||
self._host_address = host_address
|
||||
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[
|
||||
|
@ -173,12 +167,13 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
|
||||
async def start(self, host_connection_string: str) -> None:
|
||||
def start(self) -> None:
|
||||
"""Start the runtime in a background task."""
|
||||
if self._running:
|
||||
raise ValueError("Runtime is already running.")
|
||||
logger.info(f"Connecting to host: {host_connection_string}")
|
||||
self._host_connection = await HostConnection.from_connection_string(host_connection_string)
|
||||
logger.info("connection")
|
||||
logger.info(f"Connecting to host: {self._host_address}")
|
||||
self._host_connection = HostConnection.from_host_address(self._host_address)
|
||||
logger.info("Connection established")
|
||||
if self._read_task is None:
|
||||
self._read_task = asyncio.create_task(self._run_read_loop())
|
||||
self._running = True
|
||||
|
@ -197,7 +192,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType" | "addSubscription":
|
||||
logger.warn(f"Cant handle {oneofcase}, skipping.")
|
||||
logger.warning(f"Cant handle {oneofcase}, skipping.")
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request))
|
||||
|
@ -217,16 +212,51 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warn("No message")
|
||||
logger.warning("No message")
|
||||
except Exception as e:
|
||||
logger.error("Error in read loop", exc_info=e)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the runtime immediately."""
|
||||
if not self._running:
|
||||
raise RuntimeError("Runtime is not running.")
|
||||
self._running = False
|
||||
# Wait for all background tasks to finish.
|
||||
final_tasks_results = await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
for task_result in final_tasks_results:
|
||||
if isinstance(task_result, Exception):
|
||||
logger.error("Error in background task", exc_info=task_result)
|
||||
# Close the host connection.
|
||||
if self._host_connection is not None:
|
||||
await self._host_connection.close()
|
||||
try:
|
||||
await self._host_connection.close()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Cancel the read task.
|
||||
if self._read_task is not None:
|
||||
await self._read_task
|
||||
self._read_task.cancel()
|
||||
try:
|
||||
await self._read_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def stop_when_signal(self, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)) -> None:
|
||||
"""Stop the runtime when a signal is received."""
|
||||
loop = asyncio.get_running_loop()
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler() -> None:
|
||||
logger.info("Received exit signal, shutting down gracefully...")
|
||||
shutdown_event.set()
|
||||
|
||||
for sig in signals:
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
# Wait for the signal to trigger the shutdown event.
|
||||
await shutdown_event.wait()
|
||||
|
||||
# Stop the runtime.
|
||||
await self.stop()
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
|
@ -267,7 +297,6 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
request_id = self._next_request_id
|
||||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
sender = cast(AgentId, sender)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
|
@ -276,7 +305,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
request=agent_worker_pb2.RpcRequest(
|
||||
request_id=request_id_str,
|
||||
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
|
||||
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key),
|
||||
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
|
||||
metadata=telemetry_metadata,
|
||||
payload=agent_worker_pb2.Payload(
|
||||
data_type=data_type,
|
||||
|
@ -317,6 +346,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
event=agent_worker_pb2.Event(
|
||||
topic_type=topic_id.type,
|
||||
topic_source=topic_id.source,
|
||||
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
|
||||
metadata=telemetry_metadata,
|
||||
payload=agent_worker_pb2.Payload(
|
||||
data_type=message_type,
|
||||
|
@ -348,10 +378,13 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
|
||||
assert self._host_connection is not None
|
||||
target = AgentId(request.target.type, request.target.key)
|
||||
source = AgentId(request.source.type, request.source.key)
|
||||
|
||||
logging.info(f"Processing request from {source} to {target}")
|
||||
recipient = AgentId(request.target.type, request.target.key)
|
||||
sender: AgentId | None = None
|
||||
if request.HasField("source"):
|
||||
sender = AgentId(request.source.type, request.source.key)
|
||||
logging.info(f"Processing request from {sender} to {recipient}")
|
||||
else:
|
||||
logging.info(f"Processing request from unknown source to {recipient}")
|
||||
|
||||
# Deserialize the message.
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
|
@ -360,26 +393,26 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
data_content_type=request.payload.data_content_type,
|
||||
)
|
||||
|
||||
# Get the target agent and prepare the message context.
|
||||
target_agent = await self._get_agent(target)
|
||||
# Get the receiving agent and prepare the message context.
|
||||
rec_agent = await self._get_agent(recipient)
|
||||
message_context = MessageContext(
|
||||
sender=source,
|
||||
sender=sender,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
|
||||
# Call the target agent.
|
||||
# Call the receiving agent.
|
||||
try:
|
||||
with MessageHandlerContext.populate_context(target_agent.id):
|
||||
with MessageHandlerContext.populate_context(rec_agent.id):
|
||||
with self._trace_helper.trace_block(
|
||||
"process",
|
||||
target_agent.id,
|
||||
rec_agent.id,
|
||||
parent=request.metadata,
|
||||
attributes={"request_id": request.request_id},
|
||||
extraAttributes={"message_type": request.payload.data_type},
|
||||
):
|
||||
result = await target_agent.on_message(message, ctx=message_context)
|
||||
result = await rec_agent.on_message(message, ctx=message_context)
|
||||
except BaseException as e:
|
||||
response_message = agent_worker_pb2.Message(
|
||||
response=agent_worker_pb2.RpcResponse(
|
||||
|
@ -436,18 +469,22 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
future.set_result(result)
|
||||
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
topic_id = TopicId(event.topic_type, event.topic_source)
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
|
||||
)
|
||||
sender: AgentId | None = None
|
||||
if event.HasField("source"):
|
||||
sender = AgentId(event.source.type, event.source.key)
|
||||
topic_id = TopicId(event.topic_type, event.topic_source)
|
||||
# Get the recipients for the topic.
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
# Send the message to each recipient.
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent_id in recipients:
|
||||
# TODO: avoid sending to the sender.
|
||||
if agent_id == sender:
|
||||
continue
|
||||
message_context = MessageContext(
|
||||
sender=None,
|
||||
sender=sender,
|
||||
topic_id=topic_id,
|
||||
is_rpc=False,
|
||||
cancellation_token=CancellationToken(),
|
||||
|
@ -543,7 +580,16 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
|
||||
if id.type not in self._agent_factories:
|
||||
raise LookupError(f"Agent with name {id.type} not found.")
|
||||
|
||||
# TODO: check if remote
|
||||
agent_instance = await self._get_agent(id)
|
||||
|
||||
if not isinstance(agent_instance, type):
|
||||
raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}")
|
||||
|
||||
return agent_instance
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
if self._host_connection is None:
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
from typing import Sequence
|
||||
|
||||
import grpc
|
||||
|
||||
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
|
||||
from .protos import agent_worker_pb2_grpc
|
||||
|
||||
logger = logging.getLogger("autogen_core")
|
||||
|
||||
|
||||
class WorkerAgentRuntimeHost:
|
||||
def __init__(self, address: str) -> None:
|
||||
self._server = grpc.aio.server()
|
||||
self._servicer = WorkerAgentRuntimeHostServicer()
|
||||
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
|
||||
self._server.add_insecure_port(address)
|
||||
self._address = address
|
||||
self._serve_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def _serve(self) -> None:
|
||||
await self._server.start()
|
||||
logger.info(f"Server started at {self._address}.")
|
||||
await self._server.wait_for_termination()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the server in a background task."""
|
||||
if self._serve_task is not None:
|
||||
raise RuntimeError("Host runtime is already started.")
|
||||
self._serve_task = asyncio.create_task(self._serve())
|
||||
|
||||
async def stop(self, grace: int = 5) -> None:
|
||||
"""Stop the server."""
|
||||
if self._serve_task is None:
|
||||
raise RuntimeError("Host runtime is not started.")
|
||||
await self._server.stop(grace=grace)
|
||||
self._serve_task.cancel()
|
||||
try:
|
||||
await self._serve_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Server stopped.")
|
||||
self._serve_task = None
|
||||
|
||||
async def stop_when_signal(
|
||||
self, grace: int = 5, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)
|
||||
) -> None:
|
||||
"""Stop the server when a signal is received."""
|
||||
if self._serve_task is None:
|
||||
raise RuntimeError("Host runtime is not started.")
|
||||
# Set up signal handling for graceful shutdown.
|
||||
loop = asyncio.get_running_loop()
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler() -> None:
|
||||
logger.info("Received exit signal, shutting down gracefully...")
|
||||
shutdown_event.set()
|
||||
|
||||
for sig in signals:
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
# Wait for the signal to trigger the shutdown event.
|
||||
await shutdown_event.wait()
|
||||
|
||||
# Shutdown the server.
|
||||
await self.stop(grace=grace)
|
|
@ -15,7 +15,7 @@ logger = logging.getLogger("autogen_core")
|
|||
event_logger = logging.getLogger("autogen_core.events")
|
||||
|
||||
|
||||
class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||
|
||||
def __init__(self) -> None:
|
|
@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
|||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xf9\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb3\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12 \n\x07payload\x18\x03 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
|
@ -32,27 +32,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|||
_globals['_PAYLOAD']._serialized_start=68
|
||||
_globals['_PAYLOAD']._serialized_end=137
|
||||
_globals['_RPCREQUEST']._serialized_start=140
|
||||
_globals['_RPCREQUEST']._serialized_end=389
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=342
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=389
|
||||
_globals['_RPCRESPONSE']._serialized_start=392
|
||||
_globals['_RPCRESPONSE']._serialized_end=576
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=342
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=389
|
||||
_globals['_EVENT']._serialized_start=579
|
||||
_globals['_EVENT']._serialized_end=758
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=342
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=389
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=760
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=793
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=795
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=853
|
||||
_globals['_SUBSCRIPTION']._serialized_start=855
|
||||
_globals['_SUBSCRIPTION']._serialized_end=939
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_start=941
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_end=1002
|
||||
_globals['_MESSAGE']._serialized_start=1005
|
||||
_globals['_MESSAGE']._serialized_end=1245
|
||||
_globals['_AGENTRPC']._serialized_start=1247
|
||||
_globals['_AGENTRPC']._serialized_end=1310
|
||||
_globals['_RPCREQUEST']._serialized_end=405
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=347
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=394
|
||||
_globals['_RPCRESPONSE']._serialized_start=408
|
||||
_globals['_RPCRESPONSE']._serialized_end=592
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=347
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=394
|
||||
_globals['_EVENT']._serialized_start=595
|
||||
_globals['_EVENT']._serialized_end=823
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=347
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=394
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=825
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=858
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=860
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=918
|
||||
_globals['_SUBSCRIPTION']._serialized_start=920
|
||||
_globals['_SUBSCRIPTION']._serialized_end=1004
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_start=1006
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_end=1067
|
||||
_globals['_MESSAGE']._serialized_start=1070
|
||||
_globals['_MESSAGE']._serialized_end=1310
|
||||
_globals['_AGENTRPC']._serialized_start=1312
|
||||
_globals['_AGENTRPC']._serialized_end=1375
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -97,8 +97,9 @@ class RpcRequest(google.protobuf.message.Message):
|
|||
payload: global___Payload | None = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
|
||||
|
||||
global___RpcRequest = RpcRequest
|
||||
|
||||
|
@ -167,11 +168,14 @@ class Event(google.protobuf.message.Message):
|
|||
|
||||
TOPIC_TYPE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
|
||||
SOURCE_FIELD_NUMBER: builtins.int
|
||||
PAYLOAD_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
topic_type: builtins.str
|
||||
topic_source: builtins.str
|
||||
@property
|
||||
def source(self) -> global___AgentId: ...
|
||||
@property
|
||||
def payload(self) -> global___Payload: ...
|
||||
@property
|
||||
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
|
||||
|
@ -180,11 +184,13 @@ class Event(google.protobuf.message.Message):
|
|||
*,
|
||||
topic_type: builtins.str = ...,
|
||||
topic_source: builtins.str = ...,
|
||||
source: global___AgentId | None = ...,
|
||||
payload: global___Payload | None = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "payload", b"payload", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
|
||||
|
||||
global___Event = Event
|
||||
|
||||
|
|
|
@ -0,0 +1,209 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
|
||||
from autogen_core.base import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
TopicId,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, TypeSubscription
|
||||
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50051"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
|
||||
def agent_factory() -> NoopAgent:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId("name1", "default")
|
||||
agent = NoopAgent()
|
||||
assert agent.id == id
|
||||
return agent
|
||||
|
||||
await worker.register("name1", agent_factory)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await worker.register("name1", NoopAgent)
|
||||
|
||||
await worker.register("name3", NoopAgent)
|
||||
|
||||
# Let the agent run for a bit.
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50052"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
|
||||
await worker.register("name", LoopbackAgent)
|
||||
await worker.add_subscription(TypeSubscription("default", "name"))
|
||||
agent_id = AgentId("name", key="default")
|
||||
topic_id = TopicId("default", "default")
|
||||
await worker.publish_message(MessageType(), topic_id=topic_id)
|
||||
|
||||
# Let the agent run for a bit.
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await worker.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await worker.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50053"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
|
||||
num_agents = 5
|
||||
num_initial_messages = 5
|
||||
max_rounds = 5
|
||||
total_num_calls_expected = 0
|
||||
for i in range(0, max_rounds):
|
||||
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
|
||||
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
await runtime.add_subscription(TypeSubscription("default", f"name{i}"))
|
||||
|
||||
# Publish messages
|
||||
for _ in range(num_initial_messages):
|
||||
await runtime.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())
|
||||
|
||||
# Let the agents run for a bit.
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Check that each agent received the correct number of messages.
|
||||
for i in range(num_agents):
|
||||
agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent)
|
||||
assert agent.num_calls == total_num_calls_expected
|
||||
|
||||
await runtime.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50054"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_default_default_subscription() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50055"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
await runtime.stop()
|
||||
await host.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_publish_to_other_source() -> None:
|
||||
# Keep it unique to this test only.
|
||||
host_address = "localhost:50056"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
agent_id = AgentId("name", key="default")
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
|
||||
AgentId("name", key="other"), type=LoopbackAgent
|
||||
)
|
||||
assert other_long_running_agent.num_calls == 1
|
||||
|
||||
await runtime.stop()
|
||||
await host.stop()
|
Loading…
Reference in New Issue