mirror of https://github.com/microsoft/autogen.git
Implement RPC and Subscription-based broadcast for python host and worker runtime. (#389)
* Refactor subscription in single threaded agent runtime * Update proto to support response result type * Support RPC and subscription-based broadcast for Python host and worker runtime. * format
This commit is contained in:
parent
519c981efa
commit
dc847d3985
|
@ -12,15 +12,17 @@ message RpcRequest {
|
|||
AgentId source = 2;
|
||||
AgentId target = 3;
|
||||
string method = 4;
|
||||
string data = 5;
|
||||
map<string, string> metadata = 6;
|
||||
string data_type = 5;
|
||||
string data = 6;
|
||||
map<string, string> metadata = 7;
|
||||
}
|
||||
|
||||
message RpcResponse {
|
||||
string request_id = 1;
|
||||
string result = 2;
|
||||
string error = 3;
|
||||
map<string, string> metadata = 4;
|
||||
string result_type = 2;
|
||||
string result = 3;
|
||||
string error = 4;
|
||||
map<string, string> metadata = 5;
|
||||
}
|
||||
|
||||
message Event {
|
||||
|
@ -36,6 +38,21 @@ message RegisterAgentType {
|
|||
string type = 1;
|
||||
}
|
||||
|
||||
message TypeSubscription {
|
||||
string topic_type = 1;
|
||||
string agent_type = 2;
|
||||
}
|
||||
|
||||
message Subscription {
|
||||
oneof subscription {
|
||||
TypeSubscription typeSubscription = 1;
|
||||
}
|
||||
}
|
||||
|
||||
message AddSubscription {
|
||||
Subscription subscription = 1;
|
||||
}
|
||||
|
||||
service AgentRpc {
|
||||
rpc OpenChannel (stream Message) returns (stream Message);
|
||||
}
|
||||
|
@ -46,6 +63,7 @@ message Message {
|
|||
RpcResponse response = 2;
|
||||
Event event = 3;
|
||||
RegisterAgentType registerAgentType = 4;
|
||||
AddSubscription addSubscription = 5;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
|
@ -45,9 +46,11 @@ class ReceiveAgent(TypeRoutedAgent):
|
|||
@message_handler
|
||||
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||
assert ctx.topic_id is not None
|
||||
|
||||
await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
|
||||
print(f"Unhandled message: {message}")
|
||||
|
||||
|
||||
class GreeterAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
|
@ -56,15 +59,16 @@ class GreeterAgent(TypeRoutedAgent):
|
|||
@message_handler
|
||||
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
|
||||
assert ctx.topic_id is not None
|
||||
|
||||
await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None:
|
||||
assert ctx.topic_id is not None
|
||||
|
||||
await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
|
||||
print(f"Unhandled message: {message}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
|
@ -75,8 +79,8 @@ async def main() -> None:
|
|||
MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback)
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||
await runtime.add_subscription(TypeSubscription("default", "reciever"))
|
||||
await runtime.register("receiver", lambda: ReceiveAgent())
|
||||
await runtime.add_subscription(TypeSubscription("default", "receiver"))
|
||||
await runtime.register("greeter", lambda: GreeterAgent())
|
||||
await runtime.add_subscription(TypeSubscription("default", "greeter"))
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from agnext.application import WorkerAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components import TypeRoutedAgent, message_handler, TypeSubscription
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId
|
||||
|
||||
|
||||
|
@ -34,6 +35,9 @@ class ReceiveAgent(TypeRoutedAgent):
|
|||
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
|
||||
print(f"Feedback received: {message.content}")
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
|
||||
print(f"Unhandled message: {message}")
|
||||
|
||||
|
||||
class GreeterAgent(TypeRoutedAgent):
|
||||
def __init__(self, receive_agent_id: AgentId) -> None:
|
||||
|
@ -46,6 +50,9 @@ class GreeterAgent(TypeRoutedAgent):
|
|||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
|
||||
print(f"Unhandled message: {message}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
|
@ -54,11 +61,13 @@ async def main() -> None:
|
|||
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("reciever", lambda: ReceiveAgent())
|
||||
await runtime.register("receiver", lambda: ReceiveAgent())
|
||||
await runtime.register(
|
||||
"greeter", lambda: GreeterAgent(AgentId("reciever", AgentInstantiationContext.current_agent_id().key))
|
||||
"greeter", lambda: GreeterAgent(AgentId("receiver", AgentInstantiationContext.current_agent_id().key))
|
||||
)
|
||||
|
||||
await runtime.add_subscription(TypeSubscription(topic_type="default", agent_type="greeter"))
|
||||
await runtime.add_subscription(TypeSubscription(topic_type="default", agent_type="receiver"))
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default"))
|
||||
|
||||
# Just to keep the runtime running
|
||||
|
|
|
@ -6,6 +6,9 @@ from typing import Any, Dict, Set
|
|||
|
||||
import grpc
|
||||
|
||||
from ..components import TypeSubscription
|
||||
from ..core import TopicId
|
||||
from ._helpers import SubscriptionManager
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
|
@ -19,9 +22,11 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
self._client_id = 0
|
||||
self._client_id_lock = asyncio.Lock()
|
||||
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||
self._agent_type_to_client_id_lock = asyncio.Lock()
|
||||
self._agent_type_to_client_id: Dict[str, int] = {}
|
||||
self._pending_requests: Dict[int, Dict[str, Future[Any]]] = {}
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
|
||||
async def OpenChannel( # type: ignore
|
||||
self,
|
||||
|
@ -52,6 +57,7 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
|
||||
break
|
||||
logger.info(f"Sent message to client {client_id}: {message}")
|
||||
# Wait for the receiving task to finish.
|
||||
await receiving_task
|
||||
|
||||
|
@ -61,45 +67,65 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
# Cancel pending requests sent to this client.
|
||||
for future in self._pending_requests.pop(client_id, {}).values():
|
||||
future.cancel()
|
||||
# Remove the client id from the agent type to client id mapping.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
agent_types = [
|
||||
agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id
|
||||
]
|
||||
for agent_type in agent_types:
|
||||
del self._agent_type_to_client_id[agent_type]
|
||||
logger.info(f"Client {client_id} disconnected.")
|
||||
|
||||
def _raise_on_exception(self, task: Task[Any]) -> None:
|
||||
exception = task.exception()
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
async def _receive_messages(
|
||||
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||
) -> None:
|
||||
# Receive messages from the client and process them.
|
||||
async for message in request_iterator:
|
||||
logger.info(f"Received message from client {client_id}: {message}")
|
||||
oneofcase = message.WhichOneof("message")
|
||||
match oneofcase:
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
logger.info(f"Received request message: {request}")
|
||||
task = asyncio.create_task(self._process_request(request, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
logger.info(f"Received response message: {response}")
|
||||
task = asyncio.create_task(self._process_response(response, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "event":
|
||||
event: agent_worker_pb2.Event = message.event
|
||||
logger.info(f"Received event message: {event}")
|
||||
task = asyncio.create_task(self._process_event(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "registerAgentType":
|
||||
register_agent_type: agent_worker_pb2.RegisterAgentType = message.registerAgentType
|
||||
logger.info(f"Received register agent type message: {register_agent_type}")
|
||||
task = asyncio.create_task(self._process_register_agent_type(register_agent_type, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "addSubscription":
|
||||
add_subscription: agent_worker_pb2.AddSubscription = message.addSubscription
|
||||
task = asyncio.create_task(self._process_add_subscription(add_subscription))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.name)
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.name)
|
||||
if target_client_id is None:
|
||||
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
|
||||
return
|
||||
|
@ -116,6 +142,7 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
# Create a task to wait for the response and send it back to the client.
|
||||
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
|
||||
self._background_tasks.add(send_response_task)
|
||||
send_response_task.add_done_callback(self._raise_on_exception)
|
||||
send_response_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
|
||||
|
@ -133,18 +160,47 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
|||
future.set_result(response)
|
||||
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
# Deliver the event to all the clients.
|
||||
# TODO: deliver based on subscriptions.
|
||||
for send_queue in self._send_queues.values():
|
||||
await send_queue.put(agent_worker_pb2.Message(event=event))
|
||||
topic_id = TopicId(type=event.topic_type, source=event.topic_source)
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
# Get the client ids of the recipients.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
client_ids: Set[int] = set()
|
||||
for recipient in recipients:
|
||||
client_id = self._agent_type_to_client_id.get(recipient.type)
|
||||
if client_id is not None:
|
||||
client_ids.add(client_id)
|
||||
else:
|
||||
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
|
||||
# Deliver the event to clients.
|
||||
for client_id in client_ids:
|
||||
await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event))
|
||||
|
||||
async def _process_register_agent_type(
|
||||
self, register_agent_type: agent_worker_pb2.RegisterAgentType, client_id: int
|
||||
) -> None:
|
||||
# Register the agent type with the host runtime.
|
||||
if register_agent_type.type in self._agent_type_to_client_id:
|
||||
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
|
||||
logger.warning(
|
||||
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}, overwriting the client mapping to client {client_id}."
|
||||
)
|
||||
self._agent_type_to_client_id[register_agent_type.type] = client_id
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
if register_agent_type.type in self._agent_type_to_client_id:
|
||||
existing_client_id = self._agent_type_to_client_id[register_agent_type.type]
|
||||
logger.error(
|
||||
f"Agent type {register_agent_type.type} already registered with client {existing_client_id}."
|
||||
)
|
||||
# TODO: send an error response back to the client.
|
||||
else:
|
||||
self._agent_type_to_client_id[register_agent_type.type] = client_id
|
||||
# TODO: send a success response back to the client.
|
||||
|
||||
async def _process_add_subscription(self, add_subscription: agent_worker_pb2.AddSubscription) -> None:
|
||||
oneofcase = add_subscription.subscription.WhichOneof("subscription")
|
||||
match oneofcase:
|
||||
case "typeSubscription":
|
||||
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
|
||||
add_subscription.subscription.typeSubscription
|
||||
)
|
||||
type_subscription = TypeSubscription(
|
||||
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
|
||||
)
|
||||
await self._subscription_manager.add_subscription(type_subscription)
|
||||
# TODO: send a success response back to the client.
|
||||
case None:
|
||||
logger.warning("Received empty subscription message")
|
||||
|
|
|
@ -28,6 +28,7 @@ import grpc
|
|||
from grpc.aio import StreamStreamCall
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..components import TypeSubscription
|
||||
from ..core import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
Agent,
|
||||
|
@ -41,19 +42,11 @@ from ..core import (
|
|||
Subscription,
|
||||
TopicId,
|
||||
)
|
||||
from ._helpers import get_impl
|
||||
from .protos import AgentId as AgentIdProto
|
||||
from .protos import (
|
||||
AgentRpcStub,
|
||||
Event,
|
||||
Message,
|
||||
RegisterAgentType,
|
||||
RpcRequest,
|
||||
RpcResponse,
|
||||
)
|
||||
from ._helpers import SubscriptionManager, get_impl
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protos import AgentRpcAsyncStub
|
||||
from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
|
@ -91,8 +84,8 @@ class HostConnection:
|
|||
|
||||
def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore
|
||||
self._channel = channel
|
||||
self._send_queue = asyncio.Queue[Message]()
|
||||
self._recv_queue = asyncio.Queue[Message]()
|
||||
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
||||
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
|
||||
self._connection_task: Task[None] | None = None
|
||||
|
||||
@classmethod
|
||||
|
@ -116,24 +109,28 @@ class HostConnection:
|
|||
|
||||
@staticmethod
|
||||
async def _connect( # type: ignore
|
||||
channel: grpc.aio.Channel, send_queue: asyncio.Queue[Message], receive_queue: asyncio.Queue[Message]
|
||||
channel: grpc.aio.Channel,
|
||||
send_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
|
||||
) -> None:
|
||||
stub: AgentRpcAsyncStub = AgentRpcStub(channel) # type: ignore
|
||||
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
|
||||
|
||||
# TODO: where do exceptions from reading the iterable go? How do we recover from those?
|
||||
recv_stream: StreamStreamCall[Message, Message] = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
|
||||
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
|
||||
QueueAsyncIterable(send_queue)
|
||||
) # type: ignore
|
||||
|
||||
while True:
|
||||
try:
|
||||
logger.info("Waiting for message")
|
||||
logger.info("Waiting for message from host")
|
||||
message = await recv_stream.read() # type: ignore
|
||||
if message == grpc.aio.EOF: # type: ignore
|
||||
logger.info("EOF")
|
||||
break
|
||||
message = cast(Message, message)
|
||||
logger.info("Received message: %s", message)
|
||||
message = cast(agent_worker_pb2.Message, message)
|
||||
logger.info(f"Received a message from host: {message}")
|
||||
await receive_queue.put(message)
|
||||
logger.info("Put message in queue")
|
||||
logger.info("Put message in receive queue")
|
||||
except Exception as e:
|
||||
print("=========================================================================")
|
||||
print(e)
|
||||
|
@ -141,10 +138,12 @@ class HostConnection:
|
|||
del recv_stream
|
||||
recv_stream = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
|
||||
|
||||
async def send(self, message: Message) -> None:
|
||||
async def send(self, message: agent_worker_pb2.Message) -> None:
|
||||
logger.info(f"Send message to host: {message}")
|
||||
await self._send_queue.put(message)
|
||||
logger.info("Put message in send queue")
|
||||
|
||||
async def recv(self) -> Message:
|
||||
async def recv(self) -> agent_worker_pb2.Message:
|
||||
logger.info("Getting message from queue")
|
||||
return await self._recv_queue.get()
|
||||
|
||||
|
@ -164,9 +163,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
self._next_request_id = 0
|
||||
self._host_connection: HostConnection | None = None
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscriptions: List[Subscription] = []
|
||||
self._seen_topics: Set[TopicId] = set()
|
||||
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
|
||||
async def start(self, host_connection_string: str) -> None:
|
||||
if self._running:
|
||||
|
@ -178,31 +175,38 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
self._read_task = asyncio.create_task(self._run_read_loop())
|
||||
self._running = True
|
||||
|
||||
def _raise_on_exception(self, task: Task[Any]) -> None:
|
||||
exception = task.exception()
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
async def _run_read_loop(self) -> None:
|
||||
logger.info("Starting read loop")
|
||||
# TODO: catch exceptions and reconnect
|
||||
while self._running:
|
||||
try:
|
||||
message = await self._host_connection.recv() # type: ignore
|
||||
logger.info("Got message: %s", message)
|
||||
oneofcase = Message.WhichOneof(message, "message")
|
||||
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType":
|
||||
logger.warn("Cant handle registerAgentType, skipping.")
|
||||
case "registerAgentType" | "addSubscription":
|
||||
logger.warn(f"Cant handle {oneofcase}, skipping.")
|
||||
case "request":
|
||||
request: RpcRequest = message.request
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: RpcResponse = message.response
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
task = asyncio.create_task(self._process_response(response))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "event":
|
||||
event: Event = message.event
|
||||
event: agent_worker_pb2.Event = message.event
|
||||
task = asyncio.create_task(self._process_event(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warn("No message")
|
||||
|
@ -230,7 +234,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
) -> Any:
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when sending message.")
|
||||
assert self._host_connection is not None
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
# create a new future for the result
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
async with self._pending_requests_lock:
|
||||
|
@ -239,20 +244,21 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
sender = cast(AgentId, sender)
|
||||
method = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=method)
|
||||
runtime_message = Message(
|
||||
request=RpcRequest(
|
||||
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=data_type)
|
||||
runtime_message = agent_worker_pb2.Message(
|
||||
request=agent_worker_pb2.RpcRequest(
|
||||
request_id=request_id_str,
|
||||
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
|
||||
source=AgentIdProto(name=sender.type, namespace=sender.key),
|
||||
method=method,
|
||||
target=agent_worker_pb2.AgentId(name=recipient.type, namespace=recipient.key),
|
||||
source=agent_worker_pb2.AgentId(name=sender.type, namespace=sender.key),
|
||||
data_type=data_type,
|
||||
data=serialized_message,
|
||||
)
|
||||
)
|
||||
# TODO: Find a way to handle timeouts/errors
|
||||
task = asyncio.create_task(self._host_connection.send(runtime_message))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return await future
|
||||
|
||||
|
@ -264,20 +270,21 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
assert self._host_connection is not None
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when publishing message.")
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
|
||||
message = Message(
|
||||
event=Event(
|
||||
runtime_message = agent_worker_pb2.Message(
|
||||
event=agent_worker_pb2.Event(
|
||||
topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message
|
||||
)
|
||||
)
|
||||
|
||||
async def write_message() -> None:
|
||||
assert self._host_connection is not None
|
||||
await self._host_connection.send(message)
|
||||
|
||||
await asyncio.create_task(write_message())
|
||||
task = asyncio.create_task(self._host_connection.send(runtime_message))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
raise NotImplementedError("Saving state is not yet implemented.")
|
||||
|
@ -294,70 +301,88 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
raise NotImplementedError("Agent load_state is not yet implemented.")
|
||||
|
||||
async def _process_request(self, request: RpcRequest) -> None:
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
|
||||
assert self._host_connection is not None
|
||||
target = AgentId(request.target.name, request.target.namespace)
|
||||
source = AgentId(request.source.name, request.source.namespace)
|
||||
|
||||
logging.info(f"Processing request from {source} to {target}")
|
||||
|
||||
# Deserialize the message.
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.data_type)
|
||||
|
||||
# Get the target agent and prepare the message context.
|
||||
target_agent = await self._get_agent(target)
|
||||
message_context = MessageContext(
|
||||
sender=source,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
|
||||
# Call the target agent.
|
||||
try:
|
||||
logging.info(f"Processing request from {source} to {target}")
|
||||
target_agent = await self._get_agent(target)
|
||||
message_context = MessageContext(
|
||||
sender=source,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.method)
|
||||
response = await target_agent.on_message(message, ctx=message_context)
|
||||
serialized_response = MESSAGE_TYPE_REGISTRY.serialize(response, type_name=request.method)
|
||||
response_message = Message(
|
||||
response=RpcResponse(
|
||||
request_id=request.request_id,
|
||||
result=serialized_response,
|
||||
)
|
||||
)
|
||||
result = await target_agent.on_message(message, ctx=message_context)
|
||||
except BaseException as e:
|
||||
response_message = Message(
|
||||
response=RpcResponse(
|
||||
response_message = agent_worker_pb2.Message(
|
||||
response=agent_worker_pb2.RpcResponse(
|
||||
request_id=request.request_id,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
# Send the error response.
|
||||
await self._host_connection.send(response_message)
|
||||
return
|
||||
|
||||
# Serialize the result.
|
||||
result_type = MESSAGE_TYPE_REGISTRY.type_name(result)
|
||||
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(result, type_name=result_type)
|
||||
|
||||
# Create the response message.
|
||||
response_message = agent_worker_pb2.Message(
|
||||
response=agent_worker_pb2.RpcResponse(
|
||||
request_id=request.request_id,
|
||||
result_type=result_type,
|
||||
result=serialized_result,
|
||||
)
|
||||
)
|
||||
|
||||
# Send the response.
|
||||
await self._host_connection.send(response_message)
|
||||
|
||||
async def _process_response(self, response: RpcResponse) -> None:
|
||||
# TODO: deserialize the response and set the future result
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
|
||||
# Deserialize the result.
|
||||
result = MESSAGE_TYPE_REGISTRY.deserialize(response.result, type_name=response.result_type)
|
||||
# Get the future and set the result.
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
else:
|
||||
future.set_result(response.result)
|
||||
future.set_result(result)
|
||||
|
||||
async def _process_event(self, event: Event) -> None:
|
||||
...
|
||||
# message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
|
||||
|
||||
# for agent_id in self._per_type_subscribers[
|
||||
# (namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
|
||||
# ]:
|
||||
|
||||
# agent = await self._get_agent(agent_id)
|
||||
# message_context = MessageContext(
|
||||
# # TODO: should sender be in the proto even for published events?
|
||||
# sender=None,
|
||||
# # TODO: topic_id
|
||||
# topic_id=None,
|
||||
# is_rpc=False,
|
||||
# cancellation_token=CancellationToken(),
|
||||
# )
|
||||
# try:
|
||||
# await agent.on_message(message, ctx=message_context)
|
||||
# logger.info("%s handled event %s", agent_id, message)
|
||||
# except Exception as e:
|
||||
# event_logger.error("Error handling message", exc_info=e)
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
|
||||
topic_id = TopicId(event.topic_type, event.topic_source)
|
||||
# Get the recipients for the topic.
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
# Send the message to each recipient.
|
||||
responses: List[Awaitable[Any]] = []
|
||||
for agent_id in recipients:
|
||||
# TODO: avoid sending to the sender.
|
||||
message_context = MessageContext(
|
||||
sender=None,
|
||||
topic_id=topic_id,
|
||||
is_rpc=False,
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
future = agent.on_message(message, ctx=message_context)
|
||||
responses.append(future)
|
||||
# Wait for all responses.
|
||||
try:
|
||||
await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
logger.error("Error handling event", exc_info=e)
|
||||
|
||||
async def register(
|
||||
self,
|
||||
|
@ -368,10 +393,10 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
self._agent_factories[type] = agent_factory
|
||||
|
||||
assert self._host_connection is not None
|
||||
message = Message(registerAgentType=RegisterAgentType(type=type))
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message = agent_worker_pb2.Message(registerAgentType=agent_worker_pb2.RegisterAgentType(type=type))
|
||||
await self._host_connection.send(message)
|
||||
logger.info("Sent registerAgentType message for %s", type)
|
||||
return AgentType(type)
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
|
@ -415,7 +440,23 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||
raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.")
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
raise NotImplementedError("Subscriptions are not yet implemented.")
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
if not isinstance(subscription, TypeSubscription):
|
||||
raise ValueError("Only TypeSubscription is supported.")
|
||||
# Add to local subscription manager.
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
# Send the subscription to the host.
|
||||
message = agent_worker_pb2.Message(
|
||||
addSubscription=agent_worker_pb2.AddSubscription(
|
||||
subscription=agent_worker_pb2.Subscription(
|
||||
typeSubscription=agent_worker_pb2.TypeSubscription(
|
||||
topic_type=subscription.topic_type, agent_type=subscription.agent_type
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
await self._host_connection.send(message)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
raise NotImplementedError("Subscriptions are not yet implemented.")
|
||||
|
|
|
@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
|||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb2\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12\x11\n\tdata_type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xf8\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x11\n\tdata_type\x18\x05 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\t\x12\x32\n\x08metadata\x18\x07 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xbb\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x13\n\x0bresult_type\x18\x02 \x01(\t\x12\x0e\n\x06result\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x33\n\x08metadata\x18\x05 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc5\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x12\n\ntopic_type\x18\x02 \x01(\t\x12\x14\n\x0ctopic_source\x18\x03 \x01(\t\x12\x11\n\tdata_type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12-\n\x08metadata\x18\x06 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
|
@ -30,21 +30,27 @@ if not _descriptor._USE_C_DESCRIPTORS:
|
|||
_globals['_AGENTID']._serialized_start=30
|
||||
_globals['_AGENTID']._serialized_end=72
|
||||
_globals['_RPCREQUEST']._serialized_start=75
|
||||
_globals['_RPCREQUEST']._serialized_end=304
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=257
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=304
|
||||
_globals['_RPCRESPONSE']._serialized_start=307
|
||||
_globals['_RPCRESPONSE']._serialized_end=473
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=257
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=304
|
||||
_globals['_EVENT']._serialized_start=476
|
||||
_globals['_EVENT']._serialized_end=654
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=257
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=304
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=656
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=689
|
||||
_globals['_MESSAGE']._serialized_start=692
|
||||
_globals['_MESSAGE']._serialized_end=880
|
||||
_globals['_AGENTRPC']._serialized_start=882
|
||||
_globals['_AGENTRPC']._serialized_end=945
|
||||
_globals['_RPCREQUEST']._serialized_end=323
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=276
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=323
|
||||
_globals['_RPCRESPONSE']._serialized_start=326
|
||||
_globals['_RPCRESPONSE']._serialized_end=513
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=276
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=323
|
||||
_globals['_EVENT']._serialized_start=516
|
||||
_globals['_EVENT']._serialized_end=713
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=276
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=323
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=715
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=748
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=750
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=808
|
||||
_globals['_SUBSCRIPTION']._serialized_start=810
|
||||
_globals['_SUBSCRIPTION']._serialized_end=894
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_start=896
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_end=957
|
||||
_globals['_MESSAGE']._serialized_start=960
|
||||
_globals['_MESSAGE']._serialized_end=1200
|
||||
_globals['_AGENTRPC']._serialized_start=1202
|
||||
_globals['_AGENTRPC']._serialized_end=1265
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -54,10 +54,12 @@ class RpcRequest(google.protobuf.message.Message):
|
|||
SOURCE_FIELD_NUMBER: builtins.int
|
||||
TARGET_FIELD_NUMBER: builtins.int
|
||||
METHOD_FIELD_NUMBER: builtins.int
|
||||
DATA_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
method: builtins.str
|
||||
data_type: builtins.str
|
||||
data: builtins.str
|
||||
@property
|
||||
def source(self) -> global___AgentId: ...
|
||||
|
@ -72,11 +74,12 @@ class RpcRequest(google.protobuf.message.Message):
|
|||
source: global___AgentId | None = ...,
|
||||
target: global___AgentId | None = ...,
|
||||
method: builtins.str = ...,
|
||||
data_type: builtins.str = ...,
|
||||
data: builtins.str = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["source", b"source", "target", b"target"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
|
||||
global___RpcRequest = RpcRequest
|
||||
|
||||
|
@ -101,10 +104,12 @@ class RpcResponse(google.protobuf.message.Message):
|
|||
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
|
||||
|
||||
REQUEST_ID_FIELD_NUMBER: builtins.int
|
||||
RESULT_TYPE_FIELD_NUMBER: builtins.int
|
||||
RESULT_FIELD_NUMBER: builtins.int
|
||||
ERROR_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
result_type: builtins.str
|
||||
result: builtins.str
|
||||
error: builtins.str
|
||||
@property
|
||||
|
@ -113,11 +118,12 @@ class RpcResponse(google.protobuf.message.Message):
|
|||
self,
|
||||
*,
|
||||
request_id: builtins.str = ...,
|
||||
result_type: builtins.str = ...,
|
||||
result: builtins.str = ...,
|
||||
error: builtins.str = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result", "result_type", b"result_type"]) -> None: ...
|
||||
|
||||
global___RpcResponse = RpcResponse
|
||||
|
||||
|
@ -141,11 +147,13 @@ class Event(google.protobuf.message.Message):
|
|||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
|
||||
|
||||
NAMESPACE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_TYPE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
|
||||
DATA_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
namespace: builtins.str
|
||||
topic_type: builtins.str
|
||||
topic_source: builtins.str
|
||||
data_type: builtins.str
|
||||
|
@ -155,13 +163,14 @@ class Event(google.protobuf.message.Message):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
namespace: builtins.str = ...,
|
||||
topic_type: builtins.str = ...,
|
||||
topic_source: builtins.str = ...,
|
||||
data_type: builtins.str = ...,
|
||||
data: builtins.str = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "namespace", b"namespace", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
|
||||
global___Event = Event
|
||||
|
||||
|
@ -180,6 +189,59 @@ class RegisterAgentType(google.protobuf.message.Message):
|
|||
|
||||
global___RegisterAgentType = RegisterAgentType
|
||||
|
||||
@typing.final
|
||||
class TypeSubscription(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPIC_TYPE_FIELD_NUMBER: builtins.int
|
||||
AGENT_TYPE_FIELD_NUMBER: builtins.int
|
||||
topic_type: builtins.str
|
||||
agent_type: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topic_type: builtins.str = ...,
|
||||
agent_type: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]) -> None: ...
|
||||
|
||||
global___TypeSubscription = TypeSubscription
|
||||
|
||||
@typing.final
|
||||
class Subscription(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TYPESUBSCRIPTION_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def typeSubscription(self) -> global___TypeSubscription: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
typeSubscription: global___TypeSubscription | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["subscription", b"subscription", "typeSubscription", b"typeSubscription"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription"] | None: ...
|
||||
|
||||
global___Subscription = Subscription
|
||||
|
||||
@typing.final
|
||||
class AddSubscription(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
SUBSCRIPTION_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def subscription(self) -> global___Subscription: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
subscription: global___Subscription | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["subscription", b"subscription"]) -> None: ...
|
||||
|
||||
global___AddSubscription = AddSubscription
|
||||
|
||||
@typing.final
|
||||
class Message(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
@ -188,6 +250,7 @@ class Message(google.protobuf.message.Message):
|
|||
RESPONSE_FIELD_NUMBER: builtins.int
|
||||
EVENT_FIELD_NUMBER: builtins.int
|
||||
REGISTERAGENTTYPE_FIELD_NUMBER: builtins.int
|
||||
ADDSUBSCRIPTION_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def request(self) -> global___RpcRequest: ...
|
||||
@property
|
||||
|
@ -196,6 +259,8 @@ class Message(google.protobuf.message.Message):
|
|||
def event(self) -> global___Event: ...
|
||||
@property
|
||||
def registerAgentType(self) -> global___RegisterAgentType: ...
|
||||
@property
|
||||
def addSubscription(self) -> global___AddSubscription: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
@ -203,9 +268,10 @@ class Message(google.protobuf.message.Message):
|
|||
response: global___RpcResponse | None = ...,
|
||||
event: global___Event | None = ...,
|
||||
registerAgentType: global___RegisterAgentType | None = ...,
|
||||
addSubscription: global___AddSubscription | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentType"] | None: ...
|
||||
def HasField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["addSubscription", b"addSubscription", "event", b"event", "message", b"message", "registerAgentType", b"registerAgentType", "request", b"request", "response", b"response"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentType", "addSubscription"] | None: ...
|
||||
|
||||
global___Message = Message
|
||||
|
|
|
@ -35,6 +35,14 @@ class TypeSubscription(Subscription):
|
|||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def topic_type(self) -> str:
|
||||
return self._topic_type
|
||||
|
||||
@property
|
||||
def agent_type(self) -> str:
|
||||
return self._agent_type
|
||||
|
||||
def is_match(self, topic_id: TopicId) -> bool:
|
||||
return topic_id.type == self._topic_type
|
||||
|
||||
|
|
Loading…
Reference in New Issue