Implement RPC and Subscription-based broadcast for python host and worker runtime. (#389)

* Refactor subscription in single threaded agent runtime

* Update proto to support response result type

* Support RPC and subscription-based broadcast for Python host and worker runtime.

* format
This commit is contained in:
Eric Zhu 2024-08-22 09:07:28 -07:00 committed by GitHub
parent 519c981efa
commit dc847d3985
8 changed files with 358 additions and 150 deletions

View File

@ -12,15 +12,17 @@ message RpcRequest {
AgentId source = 2;
AgentId target = 3;
string method = 4;
string data = 5;
map<string, string> metadata = 6;
string data_type = 5;
string data = 6;
map<string, string> metadata = 7;
}
message RpcResponse {
string request_id = 1;
string result = 2;
string error = 3;
map<string, string> metadata = 4;
string result_type = 2;
string result = 3;
string error = 4;
map<string, string> metadata = 5;
}
message Event {
@ -36,6 +38,21 @@ message RegisterAgentType {
string type = 1;
}
message TypeSubscription {
string topic_type = 1;
string agent_type = 2;
}
message Subscription {
oneof subscription {
TypeSubscription typeSubscription = 1;
}
}
message AddSubscription {
Subscription subscription = 1;
}
service AgentRpc {
rpc OpenChannel (stream Message) returns (stream Message);
}
@ -46,6 +63,7 @@ message Message {
RpcResponse response = 2;
Event event = 3;
RegisterAgentType registerAgentType = 4;
AddSubscription addSubscription = 5;
}
}

View File

@ -1,6 +1,7 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
@ -45,9 +46,11 @@ class ReceiveAgent(TypeRoutedAgent):
@message_handler
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")
class GreeterAgent(TypeRoutedAgent):
def __init__(self) -> None:
@ -56,15 +59,16 @@ class GreeterAgent(TypeRoutedAgent):
@message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id)
@message_handler
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
assert ctx.topic_id is not None
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")
async def main() -> None:
runtime = WorkerAgentRuntime()
@ -75,8 +79,8 @@ async def main() -> None:
MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback)
await runtime.start(host_connection_string="localhost:50051")
await runtime.register("reciever", lambda: ReceiveAgent())
await runtime.add_subscription(TypeSubscription("default", "reciever"))
await runtime.register("receiver", lambda: ReceiveAgent())
await runtime.add_subscription(TypeSubscription("default", "receiver"))
await runtime.register("greeter", lambda: GreeterAgent())
await runtime.add_subscription(TypeSubscription("default", "greeter"))

View File

@ -1,9 +1,10 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, NoReturn
from agnext.application import WorkerAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components import TypeRoutedAgent, message_handler, TypeSubscription
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
@ -34,6 +35,9 @@ class ReceiveAgent(TypeRoutedAgent):
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
print(f"Feedback received: {message.content}")
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")
class GreeterAgent(TypeRoutedAgent):
def __init__(self, receive_agent_id: AgentId) -> None:
@ -46,6 +50,9 @@ class GreeterAgent(TypeRoutedAgent):
assert ctx.topic_id is not None
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")
async def main() -> None:
runtime = WorkerAgentRuntime()
@ -54,11 +61,13 @@ async def main() -> None:
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
await runtime.start(host_connection_string="localhost:50051")
await runtime.register("reciever", lambda: ReceiveAgent())
await runtime.register("receiver", lambda: ReceiveAgent())
await runtime.register(
"greeter", lambda: GreeterAgent(AgentId("reciever", AgentInstantiationContext.current_agent_id().key))
"greeter", lambda: GreeterAgent(AgentId("receiver", AgentInstantiationContext.current_agent_id().key))
)
await runtime.add_subscription(TypeSubscription(topic_type="default", agent_type="greeter"))
await runtime.add_subscription(TypeSubscription(topic_type="default", agent_type="receiver"))
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default"))
# Just to keep the runtime running

View File

@ -6,6 +6,9 @@ from typing import Any, Dict, Set
import grpc
from ..components import TypeSubscription
from ..core import TopicId
from ._helpers import SubscriptionManager
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
logger = logging.getLogger("agnext")
@ -19,9 +22,11 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
self._client_id = 0
self._client_id_lock = asyncio.Lock()
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, int] = {}
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
async def OpenChannel( # type: ignore
self,
@ -52,6 +57,7 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
break
logger.info(f"Sent message to client {client_id}: {message}")
# Wait for the receiving task to finish.
await receiving_task
@ -61,45 +67,65 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
# Cancel pending requests sent to this client.
for future in self._pending_requests.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
async with self._agent_type_to_client_id_lock:
agent_types = [
agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id
]
for agent_type in agent_types:
del self._agent_type_to_client_id[agent_type]
logger.info(f"Client {client_id} disconnected.")
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
if exception is not None:
raise exception
async def _receive_messages(
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
logger.info(f"Received message from client {client_id}: {message}")
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
logger.info(f"Received request message: {request}")
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
logger.info(f"Received response message: {response}")
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "event":
event: agent_worker_pb2.Event = message.event
logger.info(f"Received event message: {event}")
task = asyncio.create_task(self._process_event(event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentType":
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
logger.info(f"Received register agent type message: {register_agent_type}")
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "addSubscription":
add_subscription: agent_worker_pb2.AddSubscription = message.addSubscription
task = asyncio.create_task(self._process_add_subscription(add_subscription))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
# Deliver the message to a client given the target agent type.
target_client_id = self._agent_type_to_client_id.get(request.target.name)
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.name)
if target_client_id is None:
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
return
@ -116,6 +142,7 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
self._background_tasks.add(send_response_task)
send_response_task.add_done_callback(self._raise_on_exception)
send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
@ -133,18 +160,47 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
future.set_result(response)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
# Deliver the event to all the clients.
# TODO: deliver based on subscriptions.
for send_queue in self._send_queues.values():
await send_queue.put(agent_worker_pb2.Message(event=event))
topic_id = TopicId(type=event.topic_type, source=event.topic_source)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
client_ids: Set[int] = set()
for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None:
client_ids.add(client_id)
else:
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event))
async def _process_register_agent_type(
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
) -> None:
# Register the agent type with the host runtime.
if register_agent_type.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
logger.warning(
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}."
)
self._agent_type_to_client_id[register_agent_type.type] = client_id
async with self._agent_type_to_client_id_lock:
if register_agent_type.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
logger.error(
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}."
)
# TODO: send an error response back to the client.
else:
self._agent_type_to_client_id[register_agent_type.type] = client_id
# TODO: send a success response back to the client.
async def _process_add_subscription(self, add_subscription: agent_worker_pb2.AddSubscription) -> None:
oneofcase = add_subscription.subscription.WhichOneof("subscription")
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
add_subscription.subscription.typeSubscription
)
type_subscription = TypeSubscription(
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
)
await self._subscription_manager.add_subscription(type_subscription)
# TODO: send a success response back to the client.
case None:
logger.warning("Received empty subscription message")

View File

@ -28,6 +28,7 @@ import grpc
from grpc.aio import StreamStreamCall
from typing_extensions import Self
from ..components import TypeSubscription
from ..core import (
MESSAGE_TYPE_REGISTRY,
Agent,
@ -41,19 +42,11 @@ from ..core import (
Subscription,
TopicId,
)
from ._helpers import get_impl
from .protos import AgentId as AgentIdProto
from .protos import (
AgentRpcStub,
Event,
Message,
RegisterAgentType,
RpcRequest,
RpcResponse,
)
from ._helpers import SubscriptionManager, get_impl
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
if TYPE_CHECKING:
from .protos import AgentRpcAsyncStub
from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
logger = logging.getLogger("agnext")
event_logger = logging.getLogger("agnext.events")
@ -91,8 +84,8 @@ class HostConnection:
def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore
self._channel = channel
self._send_queue = asyncio.Queue[Message]()
self._recv_queue = asyncio.Queue[Message]()
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._connection_task: Task[None] | None = None
@classmethod
@ -116,24 +109,28 @@ class HostConnection:
@staticmethod
async def _connect( # type: ignore
channel: grpc.aio.Channel, send_queue: asyncio.Queue[Message], receive_queue: asyncio.Queue[Message]
channel: grpc.aio.Channel,
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
) -> None:
stub: AgentRpcAsyncStub = AgentRpcStub(channel) # type: ignore
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[Message, Message] = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue)
) # type: ignore
while True:
try:
logger.info("Waiting for message")
logger.info("Waiting for message from host")
message = await recv_stream.read() # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
message = cast(Message, message)
logger.info("Received message: %s", message)
message = cast(agent_worker_pb2.Message, message)
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in queue")
logger.info("Put message in receive queue")
except Exception as e:
print("=========================================================================")
print(e)
@ -141,10 +138,12 @@ class HostConnection:
del recv_stream
recv_stream = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
async def send(self, message: Message) -> None:
async def send(self, message: agent_worker_pb2.Message) -> None:
logger.info(f"Send message to host: {message}")
await self._send_queue.put(message)
logger.info("Put message in send queue")
async def recv(self) -> Message:
async def recv(self) -> agent_worker_pb2.Message:
logger.info("Getting message from queue")
return await self._recv_queue.get()
@ -164,9 +163,7 @@ class WorkerAgentRuntime(AgentRuntime):
self._next_request_id = 0
self._host_connection: HostConnection | None = None
self._background_tasks: Set[Task[Any]] = set()
self._subscriptions: List[Subscription] = []
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
self._subscription_manager = SubscriptionManager()
async def start(self, host_connection_string: str) -> None:
if self._running:
@ -178,31 +175,38 @@ class WorkerAgentRuntime(AgentRuntime):
self._read_task = asyncio.create_task(self._run_read_loop())
self._running = True
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
if exception is not None:
raise exception
async def _run_read_loop(self) -> None:
logger.info("Starting read loop")
# TODO: catch exceptions and reconnect
while self._running:
try:
message = await self._host_connection.recv() # type: ignore
logger.info("Got message: %s", message)
oneofcase = Message.WhichOneof(message, "message")
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentType":
logger.warn("Cant handle registerAgentType, skipping.")
case "registerAgentType" | "addSubscription":
logger.warn(f"Cant handle {oneofcase}, skipping.")
case "request":
request: RpcRequest = message.request
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: RpcResponse = message.response
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "event":
event: Event = message.event
event: agent_worker_pb2.Event = message.event
task = asyncio.create_task(self._process_event(event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warn("No message")
@ -230,7 +234,8 @@ class WorkerAgentRuntime(AgentRuntime):
) -> Any:
if not self._running:
raise ValueError("Runtime must be running when sending message.")
assert self._host_connection is not None
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
# create a new future for the result
future = asyncio.get_event_loop().create_future()
async with self._pending_requests_lock:
@ -239,20 +244,21 @@ class WorkerAgentRuntime(AgentRuntime):
request_id_str = str(request_id)
self._pending_requests[request_id_str] = future
sender = cast(AgentId, sender)
method = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=method)
runtime_message = Message(
request=RpcRequest(
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=data_type)
runtime_message = agent_worker_pb2.Message(
request=agent_worker_pb2.RpcRequest(
request_id=request_id_str,
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
source=AgentIdProto(name=sender.type, namespace=sender.key),
method=method,
target=agent_worker_pb2.AgentId(name=recipient.type, namespace=recipient.key),
source=agent_worker_pb2.AgentId(name=sender.type, namespace=sender.key),
data_type=data_type,
data=serialized_message,
)
)
# TODO: Find a way to handle timeouts/errors
task = asyncio.create_task(self._host_connection.send(runtime_message))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
return await future
@ -264,20 +270,21 @@ class WorkerAgentRuntime(AgentRuntime):
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
assert self._host_connection is not None
if not self._running:
raise ValueError("Runtime must be running when publishing message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
message = Message(
event=Event(
runtime_message = agent_worker_pb2.Message(
event=agent_worker_pb2.Event(
topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message
)
)
async def write_message() -> None:
assert self._host_connection is not None
await self._host_connection.send(message)
await asyncio.create_task(write_message())
task = asyncio.create_task(self._host_connection.send(runtime_message))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
async def save_state(self) -> Mapping[str, Any]:
raise NotImplementedError("Saving state is not yet implemented.")
@ -294,70 +301,88 @@ class WorkerAgentRuntime(AgentRuntime):
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Agent load_state is not yet implemented.")
async def _process_request(self, request: RpcRequest) -> None:
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
assert self._host_connection is not None
target = AgentId(request.target.name, request.target.namespace)
source = AgentId(request.source.name, request.source.namespace)
logging.info(f"Processing request from {source} to {target}")
# Deserialize the message.
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.data_type)
# Get the target agent and prepare the message context.
target_agent = await self._get_agent(target)
message_context = MessageContext(
sender=source,
topic_id=None,
is_rpc=True,
cancellation_token=CancellationToken(),
)
# Call the target agent.
try:
logging.info(f"Processing request from {source} to {target}")
target_agent = await self._get_agent(target)
message_context = MessageContext(
sender=source,
topic_id=None,
is_rpc=True,
cancellation_token=CancellationToken(),
)
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.method)
response = await target_agent.on_message(message, ctx=message_context)
serialized_response = MESSAGE_TYPE_REGISTRY.serialize(response, type_name=request.method)
response_message = Message(
response=RpcResponse(
request_id=request.request_id,
result=serialized_response,
)
)
result = await target_agent.on_message(message, ctx=message_context)
except BaseException as e:
response_message = Message(
response=RpcResponse(
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
request_id=request.request_id,
error=str(e),
)
)
# Send the error response.
await self._host_connection.send(response_message)
return
# Serialize the result.
result_type = MESSAGE_TYPE_REGISTRY.type_name(result)
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(result, type_name=result_type)
# Create the response message.
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
request_id=request.request_id,
result_type=result_type,
result=serialized_result,
)
)
# Send the response.
await self._host_connection.send(response_message)
async def _process_response(self, response: RpcResponse) -> None:
# TODO: deserialize the response and set the future result
async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
# Deserialize the result.
result = MESSAGE_TYPE_REGISTRY.deserialize(response.result, type_name=response.result_type)
# Get the future and set the result.
future = self._pending_requests.pop(response.request_id)
if len(response.error) > 0:
future.set_exception(Exception(response.error))
else:
future.set_result(response.result)
future.set_result(result)
async def _process_event(self, event: Event) -> None:
...
# message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
# for agent_id in self._per_type_subscribers[
# (namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
# ]:
# agent = await self._get_agent(agent_id)
# message_context = MessageContext(
# # TODO: should sender be in the proto even for published events?
# sender=None,
# # TODO: topic_id
# topic_id=None,
# is_rpc=False,
# cancellation_token=CancellationToken(),
# )
# try:
# await agent.on_message(message, ctx=message_context)
# logger.info("%s handled event %s", agent_id, message)
# except Exception as e:
# event_logger.error("Error handling message", exc_info=e)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
topic_id = TopicId(event.topic_type, event.topic_source)
# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Send the message to each recipient.
responses: List[Awaitable[Any]] = []
for agent_id in recipients:
# TODO: avoid sending to the sender.
message_context = MessageContext(
sender=None,
topic_id=topic_id,
is_rpc=False,
cancellation_token=CancellationToken(),
)
agent = await self._get_agent(agent_id)
future = agent.on_message(message, ctx=message_context)
responses.append(future)
# Wait for all responses.
try:
await asyncio.gather(*responses)
except BaseException as e:
logger.error("Error handling event", exc_info=e)
async def register(
self,
@ -368,10 +393,10 @@ class WorkerAgentRuntime(AgentRuntime):
raise ValueError(f"Agent with type {type} already exists.")
self._agent_factories[type] = agent_factory
assert self._host_connection is not None
message = Message(registerAgentType=RegisterAgentType(type=type))
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type))
await self._host_connection.send(message)
logger.info("Sent registerAgentType message for %s", type)
return AgentType(type)
async def _invoke_agent_factory(
@ -415,7 +440,23 @@ class WorkerAgentRuntime(AgentRuntime):
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
async def add_subscription(self, subscription: Subscription) -> None:
raise NotImplementedError("Subscriptions are not yet implemented.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if not isinstance(subscription, TypeSubscription):
raise ValueError("Only TypeSubscription is supported.")
# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)
# Send the subscription to the host.
message = agent_worker_pb2.Message(
addSubscription=agent_worker_pb2.AddSubscription(
subscription=agent_worker_pb2.Subscription(
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=subscription.topic_type, agent_type=subscription.agent_type
)
)
)
)
await self._host_connection.send(message)
async def remove_subscription(self, id: str) -> None:
raise NotImplementedError("Subscriptions are not yet implemented.")

View File

@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb2\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12\x11\n\tdata_type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xf8\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x11\n\tdata_type\x18\x05 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\t\x12\x32\n\x08metadata\x18\x07 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xbb\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x13\n\x0bresult_type\x18\x02 \x01(\t\x12\x0e\n\x06result\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x33\n\x08metadata\x18\x05 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc5\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x12\n\ntopic_type\x18\x02 \x01(\t\x12\x14\n\x0ctopic_source\x18\x03 \x01(\t\x12\x11\n\tdata_type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12-\n\x08metadata\x18\x06 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -30,21 +30,27 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_AGENTID']._serialized_start=30
_globals['_AGENTID']._serialized_end=72
_globals['_RPCREQUEST']._serialized_start=75
_globals['_RPCREQUEST']._serialized_end=304
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=257
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=304
_globals['_RPCRESPONSE']._serialized_start=307
_globals['_RPCRESPONSE']._serialized_end=473
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=257
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=304
_globals['_EVENT']._serialized_start=476
_globals['_EVENT']._serialized_end=654
_globals['_EVENT_METADATAENTRY']._serialized_start=257
_globals['_EVENT_METADATAENTRY']._serialized_end=304
_globals['_REGISTERAGENTTYPE']._serialized_start=656
_globals['_REGISTERAGENTTYPE']._serialized_end=689
_globals['_MESSAGE']._serialized_start=692
_globals['_MESSAGE']._serialized_end=880
_globals['_AGENTRPC']._serialized_start=882
_globals['_AGENTRPC']._serialized_end=945
_globals['_RPCREQUEST']._serialized_end=323
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=276
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=323
_globals['_RPCRESPONSE']._serialized_start=326
_globals['_RPCRESPONSE']._serialized_end=513
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=276
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=323
_globals['_EVENT']._serialized_start=516
_globals['_EVENT']._serialized_end=713
_globals['_EVENT_METADATAENTRY']._serialized_start=276
_globals['_EVENT_METADATAENTRY']._serialized_end=323
_globals['_REGISTERAGENTTYPE']._serialized_start=715
_globals['_REGISTERAGENTTYPE']._serialized_end=748
_globals['_TYPESUBSCRIPTION']._serialized_start=750
_globals['_TYPESUBSCRIPTION']._serialized_end=808
_globals['_SUBSCRIPTION']._serialized_start=810
_globals['_SUBSCRIPTION']._serialized_end=894
_globals['_ADDSUBSCRIPTION']._serialized_start=896
_globals['_ADDSUBSCRIPTION']._serialized_end=957
_globals['_MESSAGE']._serialized_start=960
_globals['_MESSAGE']._serialized_end=1200
_globals['_AGENTRPC']._serialized_start=1202
_globals['_AGENTRPC']._serialized_end=1265
# @@protoc_insertion_point(module_scope)

View File

@ -54,10 +54,12 @@ class RpcRequest(google.protobuf.message.Message):
SOURCE_FIELD_NUMBER: builtins.int
TARGET_FIELD_NUMBER: builtins.int
METHOD_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
request_id: builtins.str
method: builtins.str
data_type: builtins.str
data: builtins.str
@property
def source(self) -> global___AgentId: ...
@ -72,11 +74,12 @@ class RpcRequest(google.protobuf.message.Message):
source: global___AgentId | None = ...,
target: global___AgentId | None = ...,
method: builtins.str = ...,
data_type: builtins.str = ...,
data: builtins.str = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
global___RpcRequest = RpcRequest
@ -101,10 +104,12 @@ class RpcResponse(google.protobuf.message.Message):
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
REQUEST_ID_FIELD_NUMBER: builtins.int
RESULT_TYPE_FIELD_NUMBER: builtins.int
RESULT_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
request_id: builtins.str
result_type: builtins.str
result: builtins.str
error: builtins.str
@property
@ -113,11 +118,12 @@ class RpcResponse(google.protobuf.message.Message):
self,
*,
request_id: builtins.str = ...,
result_type: builtins.str = ...,
result: builtins.str = ...,
error: builtins.str = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result"]) -> None: ...
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result", "result_type", b"result_type"]) -> None: ...
global___RpcResponse = RpcResponse
@ -141,11 +147,13 @@ class Event(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
NAMESPACE_FIELD_NUMBER: builtins.int
TOPIC_TYPE_FIELD_NUMBER: builtins.int
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
namespace: builtins.str
topic_type: builtins.str
topic_source: builtins.str
data_type: builtins.str
@ -155,13 +163,14 @@ class Event(google.protobuf.message.Message):
def __init__(
self,
*,
namespace: builtins.str = ...,
topic_type: builtins.str = ...,
topic_source: builtins.str = ...,
data_type: builtins.str = ...,
data: builtins.str = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "namespace", b"namespace", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
global___Event = Event
@ -180,6 +189,59 @@ class RegisterAgentType(google.protobuf.message.Message):
global___RegisterAgentType = RegisterAgentType
@typing.final
class TypeSubscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPIC_TYPE_FIELD_NUMBER: builtins.int
AGENT_TYPE_FIELD_NUMBER: builtins.int
topic_type: builtins.str
agent_type: builtins.str
def __init__(
self,
*,
topic_type: builtins.str = ...,
agent_type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]) -> None: ...
global___TypeSubscription = TypeSubscription
@typing.final
class Subscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPESUBSCRIPTION_FIELD_NUMBER: builtins.int
@property
def typeSubscription(self) -> global___TypeSubscription: ...
def __init__(
self,
*,
typeSubscription: global___TypeSubscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscription", b"subscription", "typeSubscription", b"typeSubscription"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription"] | None: ...
global___Subscription = Subscription
@typing.final
class AddSubscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SUBSCRIPTION_FIELD_NUMBER: builtins.int
@property
def subscription(self) -> global___Subscription: ...
def __init__(
self,
*,
subscription: global___Subscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscription", b"subscription"]) -> None: ...
global___AddSubscription = AddSubscription
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@ -188,6 +250,7 @@ class Message(google.protobuf.message.Message):
RESPONSE_FIELD_NUMBER: builtins.int
EVENT_FIELD_NUMBER: builtins.int
REGISTERAGENTTYPE_FIELD_NUMBER: builtins.int
ADDSUBSCRIPTION_FIELD_NUMBER: builtins.int
@property
def request(self) -> global___RpcRequest: ...
@property
@ -196,6 +259,8 @@ class Message(google.protobuf.message.Message):
def event(self) -> global___Event: ...
@property
def registerAgentType(self) -> global___RegisterAgentType: ...
@property
def addSubscription(self) -> global___AddSubscription: ...
def __init__(
self,
*,
@ -203,9 +268,10 @@ class Message(google.protobuf.message.Message):
response: global___RpcResponse | None = ...,
event: global___Event | None = ...,
registerAgentType: global___RegisterAgentType | None = ...,
addSubscription: global___AddSubscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentType"] | None: ...
def HasField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentType", "addSubscription"] | None: ...
global___Message = Message

View File

@ -35,6 +35,14 @@ class TypeSubscription(Subscription):
def id(self) -> str:
return self._id
@property
def topic_type(self) -> str:
return self._topic_type
@property
def agent_type(self) -> str:
return self._agent_type
def is_match(self, topic_id: TopicId) -> bool:
return topic_id.type == self._topic_type