mirror of https://github.com/microsoft/autogen.git
Rename fields in agent id (#334)
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
9f0bdb154c
commit
437dbefc32
|
@ -2,6 +2,7 @@ syntax = "proto3";
|
||||||
|
|
||||||
package agents;
|
package agents;
|
||||||
|
|
||||||
|
// TODO: update
|
||||||
message AgentId {
|
message AgentId {
|
||||||
string name = 1;
|
string name = 1;
|
||||||
string namespace = 2;
|
string namespace = 2;
|
||||||
|
|
|
@ -60,18 +60,18 @@ class GroupChatManager(TypeRoutedAgent):
|
||||||
for key, value in transitions.items():
|
for key, value in transitions.items():
|
||||||
if not value:
|
if not value:
|
||||||
# Make sure no empty transitions are provided.
|
# Make sure no empty transitions are provided.
|
||||||
raise ValueError(f"Empty transition list provided for {key.name}.")
|
raise ValueError(f"Empty transition list provided for {key.type}.")
|
||||||
if key not in participants:
|
if key not in participants:
|
||||||
# Make sure all keys are in the list of participants.
|
# Make sure all keys are in the list of participants.
|
||||||
raise ValueError(f"Transition key {key.name} not found in participants.")
|
raise ValueError(f"Transition key {key.type} not found in participants.")
|
||||||
for v in value:
|
for v in value:
|
||||||
if v not in participants:
|
if v not in participants:
|
||||||
# Make sure all values are in the list of participants.
|
# Make sure all values are in the list of participants.
|
||||||
raise ValueError(f"Transition value {v.name} not found in participants.")
|
raise ValueError(f"Transition value {v.type} not found in participants.")
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
# Make sure there is only one transition for each key if no model client is provided.
|
# Make sure there is only one transition for each key if no model client is provided.
|
||||||
if len(value) > 1:
|
if len(value) > 1:
|
||||||
raise ValueError(f"Multiple transitions provided for {key.name} but no model client is provided.")
|
raise ValueError(f"Multiple transitions provided for {key.type} but no model client is provided.")
|
||||||
self._tranistions = transitions
|
self._tranistions = transitions
|
||||||
self._on_message_received = on_message_received
|
self._on_message_received = on_message_received
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ class GroupChatManager(TypeRoutedAgent):
|
||||||
|
|
||||||
# Get the last speaker.
|
# Get the last speaker.
|
||||||
last_speaker_name = message.source
|
last_speaker_name = message.source
|
||||||
last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)
|
last_speaker_index = next((i for i, p in enumerate(self._participants) if p.type == last_speaker_name), None)
|
||||||
|
|
||||||
# Get the candidates for the next speaker.
|
# Get the candidates for the next speaker.
|
||||||
if last_speaker_index is not None:
|
if last_speaker_index is not None:
|
||||||
|
@ -112,7 +112,7 @@ class GroupChatManager(TypeRoutedAgent):
|
||||||
candidates = self._participants
|
candidates = self._participants
|
||||||
else:
|
else:
|
||||||
candidates = self._participants
|
candidates = self._participants
|
||||||
logger.debug(f"Group chat manager next speaker candidates: {[c.name for c in candidates]}")
|
logger.debug(f"Group chat manager next speaker candidates: {[c.type for c in candidates]}")
|
||||||
|
|
||||||
# Select speaker.
|
# Select speaker.
|
||||||
if len(candidates) == 0:
|
if len(candidates) == 0:
|
||||||
|
@ -138,7 +138,7 @@ class GroupChatManager(TypeRoutedAgent):
|
||||||
)
|
)
|
||||||
speaker = candidates[speaker_index]
|
speaker = candidates[speaker_index]
|
||||||
|
|
||||||
logger.debug(f"Group chat manager selected speaker: {speaker.name if speaker is not None else None}")
|
logger.debug(f"Group chat manager selected speaker: {speaker.type if speaker is not None else None}")
|
||||||
|
|
||||||
if speaker is not None:
|
if speaker is not None:
|
||||||
# Send the message to the selected speaker to ask it to publish a response.
|
# Send the message to the selected speaker to ask it to publish a response.
|
||||||
|
|
|
@ -136,7 +136,7 @@ Some additional points to consider:
|
||||||
|
|
||||||
# Find the speaker.
|
# Find the speaker.
|
||||||
try:
|
try:
|
||||||
speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"])
|
speaker = next(agent for agent in self._specialists if agent.type == data["next_speaker"]["answer"])
|
||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
|
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
|
||||||
|
|
||||||
|
@ -171,11 +171,11 @@ Some additional points to consider:
|
||||||
# A reusable description of the team.
|
# A reusable description of the team.
|
||||||
team = "\n".join(
|
team = "\n".join(
|
||||||
[
|
[
|
||||||
agent.name + ": " + (await self.runtime.agent_metadata(agent))["description"]
|
agent.type + ": " + (await self.runtime.agent_metadata(agent))["description"]
|
||||||
for agent in self._specialists
|
for agent in self._specialists
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
names = ", ".join([agent.name for agent in self._specialists])
|
names = ", ".join([agent.type for agent in self._specialists])
|
||||||
|
|
||||||
# A place to store relevant facts.
|
# A place to store relevant facts.
|
||||||
facts = ""
|
facts = ""
|
||||||
|
|
|
@ -98,7 +98,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||||
alice = await runtime.register_and_get_proxy(
|
alice = await runtime.register_and_get_proxy(
|
||||||
"Alice",
|
"Alice",
|
||||||
lambda: ChatRoomAgent(
|
lambda: ChatRoomAgent(
|
||||||
name=AgentInstantiationContext.current_agent_id().name,
|
name=AgentInstantiationContext.current_agent_id().type,
|
||||||
description="Alice in the chat room.",
|
description="Alice in the chat room.",
|
||||||
background_story="Alice is a software engineer who loves to code.",
|
background_story="Alice is a software engineer who loves to code.",
|
||||||
memory=BufferedChatMemory(buffer_size=10),
|
memory=BufferedChatMemory(buffer_size=10),
|
||||||
|
@ -108,7 +108,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||||
bob = await runtime.register_and_get_proxy(
|
bob = await runtime.register_and_get_proxy(
|
||||||
"Bob",
|
"Bob",
|
||||||
lambda: ChatRoomAgent(
|
lambda: ChatRoomAgent(
|
||||||
name=AgentInstantiationContext.current_agent_id().name,
|
name=AgentInstantiationContext.current_agent_id().type,
|
||||||
description="Bob in the chat room.",
|
description="Bob in the chat room.",
|
||||||
background_story="Bob is a data scientist who loves to analyze data.",
|
background_story="Bob is a data scientist who loves to analyze data.",
|
||||||
memory=BufferedChatMemory(buffer_size=10),
|
memory=BufferedChatMemory(buffer_size=10),
|
||||||
|
@ -118,7 +118,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||||
charlie = await runtime.register_and_get_proxy(
|
charlie = await runtime.register_and_get_proxy(
|
||||||
"Charlie",
|
"Charlie",
|
||||||
lambda: ChatRoomAgent(
|
lambda: ChatRoomAgent(
|
||||||
name=AgentInstantiationContext.current_agent_id().name,
|
name=AgentInstantiationContext.current_agent_id().type,
|
||||||
description="Charlie in the chat room.",
|
description="Charlie in the chat room.",
|
||||||
background_story="Charlie is a designer who loves to create art.",
|
background_story="Charlie is a designer who loves to create art.",
|
||||||
memory=BufferedChatMemory(buffer_size=10),
|
memory=BufferedChatMemory(buffer_size=10),
|
||||||
|
@ -126,9 +126,9 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
|
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
|
||||||
1. 👧 {alice.id.name}: {(await alice.metadata)['description']}
|
1. 👧 {alice.id.type}: {(await alice.metadata)['description']}
|
||||||
2. 👱🏼♂️ {bob.id.name}: {(await bob.metadata)['description']}
|
2. 👱🏼♂️ {bob.id.type}: {(await bob.metadata)['description']}
|
||||||
3. 👨🏾🦳 {charlie.id.name}: {(await charlie.metadata)['description']}
|
3. 👨🏾🦳 {charlie.id.type}: {(await charlie.metadata)['description']}
|
||||||
|
|
||||||
Each participant decides on its own whether to respond to the latest message.
|
Each participant decides on its own whether to respond to the latest message.
|
||||||
|
|
||||||
|
|
|
@ -170,16 +170,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
# )
|
# )
|
||||||
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
if recipient.name not in self._known_agent_names:
|
if recipient.type not in self._known_agent_names:
|
||||||
future.set_exception(Exception("Recipient not found"))
|
future.set_exception(Exception("Recipient not found"))
|
||||||
|
|
||||||
if sender is not None and sender.namespace != recipient.namespace:
|
if sender is not None and sender.key != recipient.key:
|
||||||
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
|
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
|
||||||
|
|
||||||
await self._process_seen_namespace(recipient.namespace)
|
await self._process_seen_namespace(recipient.key)
|
||||||
|
|
||||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {content}")
|
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
|
||||||
|
|
||||||
self._message_queue.append(
|
self._message_queue.append(
|
||||||
SendMessageEnvelope(
|
SendMessageEnvelope(
|
||||||
|
@ -221,7 +221,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
if sender is None and namespace is None:
|
if sender is None and namespace is None:
|
||||||
raise ValueError("Namespace must be provided if sender is not provided.")
|
raise ValueError("Namespace must be provided if sender is not provided.")
|
||||||
|
|
||||||
sender_namespace = sender.namespace if sender is not None else None
|
sender_namespace = sender.key if sender is not None else None
|
||||||
explicit_namespace = namespace
|
explicit_namespace = namespace
|
||||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -250,7 +250,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
for agent_id_str in state:
|
for agent_id_str in state:
|
||||||
agent_id = AgentId.from_str(agent_id_str)
|
agent_id = AgentId.from_str(agent_id_str)
|
||||||
if agent_id.name in self._known_agent_names:
|
if agent_id.type in self._known_agent_names:
|
||||||
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
||||||
|
|
||||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||||
|
@ -259,7 +259,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
# assert recipient in self._agents
|
# assert recipient in self._agents
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
|
# TODO use id
|
||||||
|
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||||
)
|
)
|
||||||
|
@ -297,15 +298,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
for agent_id in self._per_type_subscribers[
|
for agent_id in self._per_type_subscribers[
|
||||||
(target_namespace, MESSAGE_TYPE_REGISTRY.type_name(message_envelope.message))
|
(target_namespace, MESSAGE_TYPE_REGISTRY.type_name(message_envelope.message))
|
||||||
]:
|
]:
|
||||||
if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name:
|
if message_envelope.sender is not None and agent_id.type == message_envelope.sender.type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sender_agent = (
|
sender_agent = (
|
||||||
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
||||||
)
|
)
|
||||||
|
# TODO use id
|
||||||
sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown"
|
sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||||
)
|
)
|
||||||
# event_logger.info(
|
# event_logger.info(
|
||||||
# MessageEvent(
|
# MessageEvent(
|
||||||
|
@ -342,7 +344,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
else message_envelope.message
|
else message_envelope.message
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.name}: {content}"
|
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
|
||||||
)
|
)
|
||||||
# event_logger.info(
|
# event_logger.info(
|
||||||
# MessageEvent(
|
# MessageEvent(
|
||||||
|
@ -455,7 +457,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
|
|
||||||
# For all already prepared namespaces we need to prepare this agent
|
# For all already prepared namespaces we need to prepare this agent
|
||||||
for namespace in self._known_namespaces:
|
for namespace in self._known_namespaces:
|
||||||
await self._get_agent(AgentId(name=name, namespace=namespace))
|
await self._get_agent(AgentId(type=name, key=namespace))
|
||||||
|
|
||||||
async def _invoke_agent_factory(
|
async def _invoke_agent_factory(
|
||||||
self,
|
self,
|
||||||
|
@ -482,23 +484,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||||
await self._process_seen_namespace(agent_id.namespace)
|
await self._process_seen_namespace(agent_id.key)
|
||||||
if agent_id in self._instantiated_agents:
|
if agent_id in self._instantiated_agents:
|
||||||
return self._instantiated_agents[agent_id]
|
return self._instantiated_agents[agent_id]
|
||||||
|
|
||||||
if agent_id.name not in self._agent_factories:
|
if agent_id.type not in self._agent_factories:
|
||||||
raise LookupError(f"Agent with name {agent_id.name} not found.")
|
raise LookupError(f"Agent with name {agent_id.type} not found.")
|
||||||
|
|
||||||
agent_factory = self._agent_factories[agent_id.name]
|
agent_factory = self._agent_factories[agent_id.type]
|
||||||
|
|
||||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||||
for message_type in agent.metadata["subscriptions"]:
|
for message_type in agent.metadata["subscriptions"]:
|
||||||
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
|
self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id)
|
||||||
self._instantiated_agents[agent_id] = agent
|
self._instantiated_agents[agent_id] = agent
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||||
return (await self._get_agent(AgentId(name=name, namespace=namespace))).id
|
return (await self._get_agent(AgentId(type=name, key=namespace))).id
|
||||||
|
|
||||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||||
id = await self.get(name, namespace=namespace)
|
id = await self.get(name, namespace=namespace)
|
||||||
|
@ -506,14 +508,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
|
|
||||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||||
if id.name not in self._agent_factories:
|
if id.type not in self._agent_factories:
|
||||||
raise LookupError(f"Agent with name {id.name} not found.")
|
raise LookupError(f"Agent with name {id.type} not found.")
|
||||||
|
|
||||||
# TODO: check if remote
|
# TODO: check if remote
|
||||||
agent_instance = await self._get_agent(id)
|
agent_instance = await self._get_agent(id)
|
||||||
|
|
||||||
if not isinstance(agent_instance, type):
|
if not isinstance(agent_instance, type):
|
||||||
raise TypeError(f"Agent with name {id.name} is not of type {type.__name__}")
|
raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}")
|
||||||
|
|
||||||
return agent_instance
|
return agent_instance
|
||||||
|
|
||||||
|
@ -525,4 +527,4 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
|
|
||||||
self._known_namespaces.add(namespace)
|
self._known_namespaces.add(namespace)
|
||||||
for name in self._known_agent_names:
|
for name in self._known_agent_names:
|
||||||
await self._get_agent(AgentId(name=name, namespace=namespace))
|
await self._get_agent(AgentId(type=name, key=namespace))
|
||||||
|
|
|
@ -2,7 +2,7 @@ import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from ...core import Agent
|
from agnext.core import AgentId
|
||||||
|
|
||||||
|
|
||||||
class LLMCallEvent:
|
class LLMCallEvent:
|
||||||
|
@ -57,16 +57,16 @@ class MessageEvent:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
payload: Any,
|
payload: Any,
|
||||||
sender: Agent | None,
|
sender: AgentId | None,
|
||||||
receiver: Agent | None,
|
receiver: AgentId | None,
|
||||||
kind: MessageKind,
|
kind: MessageKind,
|
||||||
delivery_stage: DeliveryStage,
|
delivery_stage: DeliveryStage,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self.kwargs["payload"] = payload
|
self.kwargs["payload"] = payload
|
||||||
self.kwargs["sender"] = None if sender is None else sender.metadata["name"]
|
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||||
self.kwargs["receiver"] = None if receiver is None else receiver.metadata["name"]
|
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||||
self.kwargs["kind"] = kind
|
self.kwargs["kind"] = kind
|
||||||
self.kwargs["delivery_stage"] = delivery_stage
|
self.kwargs["delivery_stage"] = delivery_stage
|
||||||
self.kwargs["type"] = "Message"
|
self.kwargs["type"] = "Message"
|
||||||
|
|
|
@ -68,16 +68,12 @@ class ClosureAgent(Agent):
|
||||||
def metadata(self) -> AgentMetadata:
|
def metadata(self) -> AgentMetadata:
|
||||||
assert self._id is not None
|
assert self._id is not None
|
||||||
return AgentMetadata(
|
return AgentMetadata(
|
||||||
namespace=self._id.namespace,
|
namespace=self._id.key,
|
||||||
name=self._id.name,
|
name=self._id.type,
|
||||||
description=self._description,
|
description=self._description,
|
||||||
subscriptions=self._subscriptions,
|
subscriptions=self._subscriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self.id.name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> AgentId:
|
def id(self) -> AgentId:
|
||||||
return self._id
|
return self._id
|
||||||
|
|
|
@ -43,4 +43,4 @@ class TypeSubscription(Subscription):
|
||||||
raise CantHandleException("TopicId does not match the subscription")
|
raise CantHandleException("TopicId does not match the subscription")
|
||||||
|
|
||||||
# TODO: Update agentid to reflect agent type and key
|
# TODO: Update agentid to reflect agent type and key
|
||||||
return AgentId(name=self._agent_type, namespace=topic_id.source)
|
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||||
|
|
|
@ -2,33 +2,39 @@ from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
class AgentId:
|
class AgentId:
|
||||||
def __init__(self, name: str, namespace: str) -> None:
|
def __init__(self, type: str, key: str) -> None:
|
||||||
self._name = name
|
if type.isidentifier() is False:
|
||||||
self._namespace = namespace
|
raise ValueError(f"Invalid type: {type}")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
self._type = type
|
||||||
return f"{self._namespace}/{self._name}"
|
self._key = key
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash((self._namespace, self._name))
|
return hash((self._type, self._key))
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self._type}:{self._key}"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"AgentId({self._name}, {self._namespace})"
|
return f'AgentId(type="{self._type}", key="{self._key}")'
|
||||||
|
|
||||||
def __eq__(self, value: object) -> bool:
|
def __eq__(self, value: object) -> bool:
|
||||||
if not isinstance(value, AgentId):
|
if not isinstance(value, AgentId):
|
||||||
return False
|
return False
|
||||||
return self._name == value.name and self._namespace == value.namespace
|
return self._type == value.type and self._key == value.key
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_str(cls, agent_id: str) -> Self:
|
def from_str(cls, agent_id: str) -> Self:
|
||||||
namespace, name = agent_id.split("/")
|
items = agent_id.split(":", maxsplit=1)
|
||||||
return cls(name, namespace)
|
if len(items) != 2:
|
||||||
|
raise ValueError(f"Invalid agent id: {agent_id}")
|
||||||
|
type, key = items[0], items[1]
|
||||||
|
return cls(type, key)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def namespace(self) -> str:
|
def type(self) -> str:
|
||||||
return self._namespace
|
return self._type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def key(self) -> str:
|
||||||
return self._name
|
return self._key
|
||||||
|
|
|
@ -15,8 +15,8 @@ class BaseAgent(ABC, Agent):
|
||||||
def metadata(self) -> AgentMetadata:
|
def metadata(self) -> AgentMetadata:
|
||||||
assert self._id is not None
|
assert self._id is not None
|
||||||
return AgentMetadata(
|
return AgentMetadata(
|
||||||
namespace=self._id.namespace,
|
namespace=self._id.key,
|
||||||
name=self._id.name,
|
name=self._id.type,
|
||||||
description=self._description,
|
description=self._description,
|
||||||
subscriptions=self._subscriptions,
|
subscriptions=self._subscriptions,
|
||||||
)
|
)
|
||||||
|
@ -38,8 +38,8 @@ class BaseAgent(ABC, Agent):
|
||||||
self._subscriptions = subscriptions
|
self._subscriptions = subscriptions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def type(self) -> str:
|
||||||
return self.id.name
|
return self.id.type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> AgentId:
|
def id(self) -> AgentId:
|
||||||
|
|
|
@ -292,8 +292,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||||
runtime_message = Message(
|
runtime_message = Message(
|
||||||
request=RpcRequest(
|
request=RpcRequest(
|
||||||
request_id=request_id_str,
|
request_id=request_id_str,
|
||||||
target=AgentIdProto(name=recipient.name, namespace=recipient.namespace),
|
target=AgentIdProto(name=recipient.type, namespace=recipient.key),
|
||||||
source=AgentIdProto(name=sender.name, namespace=sender.namespace),
|
source=AgentIdProto(name=sender.type, namespace=sender.key),
|
||||||
data=message,
|
data=message,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -310,7 +310,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self._runtime_connection is not None
|
assert self._runtime_connection is not None
|
||||||
sender_namespace = sender.namespace if sender is not None else None
|
sender_namespace = sender.key if sender is not None else None
|
||||||
explicit_namespace = namespace
|
explicit_namespace = namespace
|
||||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -356,7 +356,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||||
|
|
||||||
# For all already prepared namespaces we need to prepare this agent
|
# For all already prepared namespaces we need to prepare this agent
|
||||||
for namespace in self._known_namespaces:
|
for namespace in self._known_namespaces:
|
||||||
await self._get_agent(AgentId(name=name, namespace=namespace))
|
await self._get_agent(AgentId(type=name, key=namespace))
|
||||||
|
|
||||||
await self.send_register_agent_type(name)
|
await self.send_register_agent_type(name)
|
||||||
|
|
||||||
|
@ -385,25 +385,25 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||||
await self._process_seen_namespace(agent_id.namespace)
|
await self._process_seen_namespace(agent_id.key)
|
||||||
if agent_id in self._instantiated_agents:
|
if agent_id in self._instantiated_agents:
|
||||||
return self._instantiated_agents[agent_id]
|
return self._instantiated_agents[agent_id]
|
||||||
|
|
||||||
if agent_id.name not in self._agent_factories:
|
if agent_id.type not in self._agent_factories:
|
||||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
raise ValueError(f"Agent with name {agent_id.type} not found.")
|
||||||
|
|
||||||
agent_factory = self._agent_factories[agent_id.name]
|
agent_factory = self._agent_factories[agent_id.type]
|
||||||
|
|
||||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||||
|
|
||||||
for message_type in agent.metadata["subscriptions"]:
|
for message_type in agent.metadata["subscriptions"]:
|
||||||
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
|
self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id)
|
||||||
|
|
||||||
self._instantiated_agents[agent_id] = agent
|
self._instantiated_agents[agent_id] = agent
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||||
return (await self._get_agent(AgentId(name=name, namespace=namespace))).id
|
return (await self._get_agent(AgentId(type=name, key=namespace))).id
|
||||||
|
|
||||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||||
id = await self.get(name, namespace=namespace)
|
id = await self.get(name, namespace=namespace)
|
||||||
|
@ -421,4 +421,4 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||||
|
|
||||||
self._known_namespaces.add(namespace)
|
self._known_namespaces.add(namespace)
|
||||||
for name in self._known_agent_names:
|
for name in self._known_agent_names:
|
||||||
await self._get_agent(AgentId(name=name, namespace=namespace))
|
await self._get_agent(AgentId(type=name, key=namespace))
|
||||||
|
|
|
@ -25,8 +25,8 @@ async def test_register_receives_publish() -> None:
|
||||||
queue = asyncio.Queue[tuple[str, str]]()
|
queue = asyncio.Queue[tuple[str, str]]()
|
||||||
|
|
||||||
async def log_message(_runtime: AgentRuntime, id: AgentId, message: Message, cancellation_token: CancellationToken) -> None:
|
async def log_message(_runtime: AgentRuntime, id: AgentId, message: Message, cancellation_token: CancellationToken) -> None:
|
||||||
namespace = id.namespace
|
key = id.key
|
||||||
await queue.put((namespace, message.content))
|
await queue.put((key, message.content))
|
||||||
|
|
||||||
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
|
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
|
||||||
run_context = runtime.start()
|
run_context = runtime.start()
|
||||||
|
|
|
@ -16,7 +16,7 @@ def test_type_subscription_match() -> None:
|
||||||
def test_type_subscription_map() -> None:
|
def test_type_subscription_map() -> None:
|
||||||
sub = TypeSubscription(topic_type="t1", agent_type="a1")
|
sub = TypeSubscription(topic_type="t1", agent_type="a1")
|
||||||
|
|
||||||
assert sub.map_to_agent(TopicId(type="t1", source="s1")) == AgentId(name="a1", namespace="s1")
|
assert sub.map_to_agent(TopicId(type="t1", source="s1")) == AgentId(type="a1", key="s1")
|
||||||
|
|
||||||
with pytest.raises(CantHandleException):
|
with pytest.raises(CantHandleException):
|
||||||
_agent_id = sub.map_to_agent(TopicId(type="t0", source="s1"))
|
_agent_id = sub.map_to_agent(TopicId(type="t0", source="s1"))
|
||||||
|
|
Loading…
Reference in New Issue