Agent factory can be async (#247)

This commit is contained in:
Jack Gerrits 2024-07-23 11:49:38 -07:00 committed by GitHub
parent 718fad6e0d
commit a52d3bab53
47 changed files with 352 additions and 299 deletions

View File

@ -1,9 +1,8 @@
import asyncio
import logging
import json
import os
from typing import Any, Dict, List, Tuple, Union
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.models import (
@ -18,13 +17,17 @@ from agnext.application.logging import EVENT_LOGGER_NAME
from team_one.markdown_browser import MarkdownConverter, UnsupportedFormatException
from team_one.agents.coder import Coder, Executor
from team_one.agents.orchestrator import LedgerOrchestrator
from team_one.messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from team_one.messages import BroadcastMessage
from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer
from team_one.agents.file_surfer import FileSurfer
from team_one.utils import LogHandler, message_content_to_str, create_completion_client_from_env
from team_one.utils import LogHandler, message_content_to_str
import re
from agnext.components.models import AssistantMessage
async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]):
async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]) -> str:
messages: List[LLMMessage] = [
UserMessage(
content=f"Earlier you were asked the following:\n\n{task}\n\nYour team then worked diligently to address that request. Here is a transcript of that conversation:",
@ -37,7 +40,8 @@ async def response_preparer(task: str, source: str, client: ChatCompletionClient
messages.append(
UserMessage(
content = message_content_to_str(message.content),
source=message.source,
# TODO fix this -> remove type ignore
source=message.source, # type: ignore
)
)
@ -68,7 +72,7 @@ If you are asked for a comma separated list, apply the above rules depending on
# No answer
if "unable to determine" in response.content.lower():
messages.append( AssistantMessage(content=response.content, source="self" ) )
messages.append(
messages.append(
UserMessage(
content= f"""
I understand that a definitive answer could not be determined. Please make a well-informed EDUCATED GUESS based on the conversation.
@ -115,29 +119,29 @@ async def main() -> None:
)
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=client),
)
executor = runtime.register_and_get_proxy(
executor = await runtime.register_and_get_proxy(
"Executor",
lambda: Executor(
"A agent for executing code", executor=LocalCommandLineCodeExecutor()
),
)
file_surfer = runtime.register_and_get_proxy(
file_surfer = await runtime.register_and_get_proxy(
"file_surfer",
lambda: FileSurfer(model_client=client),
)
web_surfer = runtime.register_and_get_proxy(
web_surfer = await runtime.register_and_get_proxy(
"WebSurfer",
lambda: MultimodalWebSurfer(), # Configuration is set later by init()
)
orchestrator = runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator(
orchestrator = await runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator(
agents=[coder, executor, file_surfer, web_surfer],
model_client=client,
))
@ -185,7 +189,7 @@ async def main() -> None:
actual_orchestrator = runtime._get_agent(orchestrator.id) # type: ignore
assert isinstance(actual_orchestrator, LedgerOrchestrator)
transcript: List[LLMMessage] = actual_orchestrator._chat_history # type: ignore
print(await response_preparer(task=task, source=orchestrator.metadata["name"], client=client, transcript=transcript))
print(await response_preparer(task=task, source=(await orchestrator.metadata)["name"], client=client, transcript=transcript))

View File

@ -34,18 +34,18 @@ async def main() -> None:
)
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=client),
)
executor = runtime.register_and_get_proxy(
executor = await runtime.register_and_get_proxy(
"Executor",
lambda: Executor(
"A agent for executing code", executor=LocalCommandLineCodeExecutor()
),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor]))
prompt = ""
with open("prompt.txt", "rt") as fh:

View File

@ -1,12 +1,11 @@
import asyncio
import json
import re
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing import Dict, List, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.code_executor import (
CodeBlock,
CodeExecutor,
@ -16,16 +15,12 @@ from agnext.components.models import (
AssistantMessage,
AzureOpenAIChatCompletionClient,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
ModelCapabilities,
OpenAIChatCompletionClient,
SystemMessage,
UserMessage,
)
from agnext.components.tools import CodeExecutionResult, PythonCodeExecutionTool
from agnext.core import AgentId, CancellationToken
from agnext.core import CancellationToken
# from azure.identity import DefaultAzureCredential, get_bearer_token_provider
@ -66,7 +61,7 @@ if __name__ == "__main__":
main()
```
The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions.
The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions.
Check the execution result returned by the user. If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes -- code blocks must stand alone and be ready to execute without modification. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, and think of a different approach to try.
@ -222,11 +217,11 @@ async def main() -> None:
)
# Register agents.
coder = runtime.register_and_get(
coder = await runtime.register_and_get(
"Coder",
lambda: Coder(model_client=client),
)
runtime.register(
await runtime.register(
"Executor",
lambda: Executor(
"A agent for executing code", executor=LocalCommandLineCodeExecutor()

View File

@ -32,7 +32,7 @@ dependencies = [
"mypy==1.10.0",
"ruff==0.4.8",
"tiktoken",
"types-Pillow",
"types-pillow",
"polars",
"chess",
"tavily-python",

View File

@ -22,10 +22,12 @@ async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClie
history = "\n".join(history_messages)
# Construct agent roles.
roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents])
roles = "\n".join(
[f"{(await agent.metadata)['name']}: {(await agent.metadata)['description']}".strip() for agent in agents]
)
# Construct agent list.
participants = str([agent.metadata["name"] for agent in agents])
participants = str([(await agent.metadata)["name"] for agent in agents])
# Select the next speaker.
select_speaker_prompt = f"""You are in a role play game. The following roles are available:
@ -39,16 +41,22 @@ Read the above conversation. Then select the next role from {participants} to pl
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
response = await client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
mentions = mentioned_agents(response.content, agents)
mentions = await mentioned_agents(response.content, agents)
if len(mentions) != 1:
raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}")
agent_name = list(mentions.keys())[0]
agent_index = next((i for i, agent in enumerate(agents) if agent.metadata["name"] == agent_name), None)
# Get the index of the selected agent by name
agent_index = 0
for i, agent in enumerate(agents):
if (await agent.metadata)["name"] == agent_name:
agent_index = i
break
assert agent_index is not None
return agent_index
def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
async def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
"""Counts the number of times each agent is mentioned in the provided message content.
Agent names will match under any of the following conditions (all case-sensitive):
- Exact name match
@ -66,7 +74,7 @@ def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str
for agent in agents:
# Finds agent mentions, taking word boundaries into account,
# accommodates escaping underscores and underscores as spaces
name = agent.metadata["name"]
name = (await agent.metadata)["name"]
regex = (
r"(?<=\W)("
+ re.escape(name)

View File

@ -170,7 +170,10 @@ Some additional points to consider:
# A reusable description of the team.
team = "\n".join(
[agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists]
[
agent.name + ": " + (await self.runtime.agent_metadata(agent))["description"]
for agent in self._specialists
]
)
names = ", ".join([agent.name for agent in self._specialists])

View File

@ -45,8 +45,8 @@ class Outer(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
inner = runtime.register_and_get("inner", Inner)
outer = runtime.register_and_get("outer", lambda: Outer(inner))
inner = await runtime.register_and_get("inner", Inner)
outer = await runtime.register_and_get("outer", lambda: Outer(inner))
run_context = runtime.start()

View File

@ -45,7 +45,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
agent = runtime.register_and_get(
agent = await runtime.register_and_get(
"chat_agent",
lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-3.5-turbo")),
)

View File

@ -77,7 +77,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register the agents.
jack = runtime.register_and_get(
jack = await runtime.register_and_get(
"Jack",
lambda: ChatCompletionAgent(
description="Jack a comedian",
@ -88,7 +88,7 @@ async def main() -> None:
termination_word="TERMINATE",
),
)
runtime.register_and_get(
await runtime.register_and_get(
"Cathy",
lambda: ChatCompletionAgent(
description="Cathy a poet",

View File

@ -166,7 +166,7 @@ class EventHandler(AsyncAssistantEventHandler):
print("\n".join(citations))
def assistant_chat(runtime: AgentRuntime) -> AgentId:
async def assistant_chat(runtime: AgentRuntime) -> AgentId:
oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
description="An AI assistant that helps with everyday tasks.",
@ -177,7 +177,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId:
thread = openai.beta.threads.create(
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
assistant = runtime.register_and_get(
assistant = await runtime.register_and_get(
"Assistant",
lambda: OpenAIAssistantAgent(
description="An AI assistant that helps with everyday tasks.",
@ -188,7 +188,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId:
),
)
user = runtime.register_and_get(
user = await runtime.register_and_get(
"User",
lambda: UserProxyAgent(
client=openai.AsyncClient(),
@ -198,7 +198,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId:
),
)
# Create a group chat manager to facilitate a turn-based conversation.
runtime.register(
await runtime.register(
"GroupChatManager",
lambda: GroupChatManager(
description="A group chat manager.",
@ -225,7 +225,7 @@ This will upload data.csv to the assistant for use with the code interpreter too
Type "exit" to exit the chat.
"""
runtime = SingleThreadedAgentRuntime()
user = assistant_chat(runtime)
user = await assistant_chat(runtime)
_run_context = runtime.start()
print(usage)
# Request the user to start the conversation.

View File

@ -87,15 +87,15 @@ class ChatRoomUserAgent(TextualUserAgent):
# Define a chat room with participants -- the runtime is the chat room.
def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
runtime.register(
async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
await runtime.register(
"User",
lambda: ChatRoomUserAgent(
description="The user in the chat room.",
app=app,
),
)
alice = runtime.register_and_get_proxy(
alice = await runtime.register_and_get_proxy(
"Alice",
lambda rt, id: ChatRoomAgent(
name=id.name,
@ -105,7 +105,7 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
bob = runtime.register_and_get_proxy(
bob = await runtime.register_and_get_proxy(
"Bob",
lambda rt, id: ChatRoomAgent(
name=id.name,
@ -115,7 +115,7 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
charlie = runtime.register_and_get_proxy(
charlie = await runtime.register_and_get_proxy(
"Charlie",
lambda rt, id: ChatRoomAgent(
name=id.name,
@ -126,9 +126,9 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
),
)
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
1. 👧 {alice.id.name}: {alice.metadata['description']}
2. 👱🏼 {bob.id.name}: {bob.metadata['description']}
3. 👨🏾🦳 {charlie.id.name}: {charlie.metadata['description']}
1. 👧 {alice.id.name}: {(await alice.metadata)['description']}
2. 👱🏼 {bob.id.name}: {(await bob.metadata)['description']}
3. 👨🏾🦳 {charlie.id.name}: {(await charlie.metadata)['description']}
Each participant decides on its own whether to respond to the latest message.
@ -139,7 +139,7 @@ You can greet the chat room by typing your first message below.
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
app = TextualChatApp(runtime, user_name="You")
chat_room(runtime, app)
await chat_room(runtime, app)
_run_context = runtime.start()
await app.run_async()

View File

@ -88,7 +88,7 @@ def make_move(
return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[newMove.from_square]} to {SQUARE_NAMES[newMove.to_square]}."
def chess_game(runtime: AgentRuntime) -> None: # type: ignore
async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
"""Create agents for a chess game and return the group chat."""
# Create the board.
@ -156,7 +156,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
),
]
black = runtime.register_and_get(
black = await runtime.register_and_get(
"PlayerBlack",
lambda: ChatCompletionAgent(
description="Player playing black.",
@ -173,7 +173,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
tools=black_tools,
),
)
white = runtime.register_and_get(
white = await runtime.register_and_get(
"PlayerWhite",
lambda: ChatCompletionAgent(
description="Player playing white.",
@ -192,7 +192,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
)
# Create a group chat manager for the chess game to orchestrate a turn-based
# conversation between the two agents.
runtime.register(
await runtime.register(
"ChessGame",
lambda: GroupChatManager(
description="A chess game between two agents.",
@ -204,7 +204,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
chess_game(runtime)
await chess_game(runtime)
# Publish an initial message to trigger the group chat manager to start orchestration.
await runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default")
while True:

View File

@ -19,15 +19,15 @@ from common.utils import get_chat_completion_client_from_envs
from utils import TextualChatApp, TextualUserAgent
def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
runtime.register(
async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
await runtime.register(
"User",
lambda: TextualUserAgent(
description="A user looking for illustration.",
app=app,
),
)
descriptor = runtime.register_and_get_proxy(
descriptor = await runtime.register_and_get_proxy(
"Descriptor",
lambda: ChatCompletionAgent(
description="An AI agent that provides a description of the image.",
@ -46,7 +46,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo", max_tokens=500),
),
)
illustrator = runtime.register_and_get_proxy(
illustrator = await runtime.register_and_get_proxy(
"Illustrator",
lambda: ImageGenerationAgent(
description="An AI agent that generates images.",
@ -55,7 +55,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
memory=BufferedChatMemory(buffer_size=1),
),
)
critic = runtime.register_and_get_proxy(
critic = await runtime.register_and_get_proxy(
"Critic",
lambda: ChatCompletionAgent(
description="An AI agent that provides feedback on images given user's requirements.",
@ -74,7 +74,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
),
)
runtime.register(
await runtime.register(
"GroupChatManager",
lambda: GroupChatManager(
description="A chat manager that handles group chat.",
@ -86,9 +86,9 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
app.welcoming_notice = f"""You are now in a group chat with the following agents:
1. 🤖 {descriptor.metadata['name']}: {descriptor.metadata.get('description')}
2. 🤖 {illustrator.metadata['name']}: {illustrator.metadata.get('description')}
3. 🤖 {critic.metadata['name']}: {critic.metadata.get('description')}
1. 🤖 {(await descriptor.metadata)['name']}: {(await descriptor.metadata).get('description')}
2. 🤖 {(await illustrator.metadata)['name']}: {(await illustrator.metadata).get('description')}
3. 🤖 {(await critic.metadata)['name']}: {(await critic.metadata).get('description')}
Provide a prompt for the illustrator to generate an image.
"""
@ -97,7 +97,7 @@ Provide a prompt for the illustrator to generate an image.
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
app = TextualChatApp(runtime, user_name="You")
illustrator_critics(runtime, app)
await illustrator_critics(runtime, app)
_run_context = runtime.start()
await app.run_async()

View File

@ -105,15 +105,15 @@ async def create_image(
return f"Image created and saved to {filename}."
def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
user_agent = runtime.register_and_get(
async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
user_agent = await runtime.register_and_get(
"Customer",
lambda: TextualUserAgent(
description="A customer looking for help.",
app=app,
),
)
developer = runtime.register_and_get(
developer = await runtime.register_and_get(
"Developer",
lambda: ChatCompletionAgent(
description="A Python software developer.",
@ -153,7 +153,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
),
)
product_manager = runtime.register_and_get(
product_manager = await runtime.register_and_get(
"ProductManager",
lambda: ChatCompletionAgent(
description="A product manager. "
@ -182,7 +182,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
tool_approver=user_agent,
),
)
ux_designer = runtime.register_and_get(
ux_designer = await runtime.register_and_get(
"UserExperienceDesigner",
lambda: ChatCompletionAgent(
description="A user experience designer for creating user interfaces.",
@ -215,7 +215,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
),
)
illustrator = runtime.register_and_get(
illustrator = await runtime.register_and_get(
"Illustrator",
lambda: ChatCompletionAgent(
description="An illustrator for creating images.",
@ -240,7 +240,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
tool_approver=user_agent,
),
)
runtime.register(
await runtime.register(
"GroupChatManager",
lambda: GroupChatManager(
description="A group chat manager.",
@ -279,7 +279,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
app = TextualChatApp(runtime, user_name="You")
software_consultancy(runtime, app)
await software_consultancy(runtime, app)
# Start the runtime.
_run_context = runtime.start()
# Start the app.

View File

@ -27,8 +27,8 @@ async def build_app(runtime: AgentRuntime) -> None:
api_version="2024-02-01",
)
runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
await runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
await runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
runtime.get("GraphicDesigner")
runtime.get("Auditor")
await runtime.get("GraphicDesigner")
await runtime.get("Auditor")

View File

@ -30,7 +30,7 @@ class Printer(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
await build_app(runtime)
runtime.register("Printer", lambda: Printer())
await runtime.register("Printer", lambda: Printer())
ctx = runtime.start()

View File

@ -180,8 +180,10 @@ async def main(task: str, temp_dir: str) -> None:
runtime = SingleThreadedAgentRuntime()
# Register the agents.
runtime.register("coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo")))
runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
await runtime.register(
"coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"))
)
await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
run_context = runtime.start()
# Publish the task message.

View File

@ -251,14 +251,14 @@ Code: <Your code>
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.register(
await runtime.register(
"ReviewerAgent",
lambda: ReviewerAgent(
description="Code Reviewer",
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
),
)
runtime.register(
await runtime.register(
"CoderAgent",
lambda: CoderAgent(
description="Coder",

View File

@ -113,7 +113,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register the participants.
agent1 = runtime.register_and_get(
agent1 = await runtime.register_and_get(
"DataScientist",
lambda: GroupChatParticipant(
description="A data scientist",
@ -121,7 +121,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
),
)
agent2 = runtime.register_and_get(
agent2 = await runtime.register_and_get(
"Engineer",
lambda: GroupChatParticipant(
description="An engineer",
@ -129,7 +129,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
),
)
agent3 = runtime.register_and_get(
agent3 = await runtime.register_and_get(
"Artist",
lambda: GroupChatParticipant(
description="An artist",
@ -139,7 +139,7 @@ async def main() -> None:
)
# Register the group chat manager.
runtime.register(
await runtime.register(
"GroupChatManager",
lambda: RoundRobinGroupChatManager(
description="A group chat manager",

View File

@ -112,7 +112,7 @@ class AggregatorAgent(TypeRoutedAgent):
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# TODO: use different models for each agent.
runtime.register(
await runtime.register(
"ReferenceAgent1",
lambda: ReferenceAgent(
description="Reference Agent 1",
@ -120,7 +120,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=0.1),
),
)
runtime.register(
await runtime.register(
"ReferenceAgent2",
lambda: ReferenceAgent(
description="Reference Agent 2",
@ -128,7 +128,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=0.5),
),
)
runtime.register(
await runtime.register(
"ReferenceAgent3",
lambda: ReferenceAgent(
description="Reference Agent 3",
@ -136,7 +136,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=1.0),
),
)
runtime.register(
await runtime.register(
"AggregatorAgent",
lambda: AggregatorAgent(
description="Aggregator Agent",

View File

@ -211,7 +211,7 @@ async def main(question: str) -> None:
# Register the solver agents.
# Create a sparse connection: each solver agent has two neighbors.
# NOTE: to create a dense connection, each solver agent should be connected to all other solver agents.
runtime.register(
await runtime.register(
"MathSolver1",
lambda: MathSolver(
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
@ -219,7 +219,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
runtime.register(
await runtime.register(
"MathSolver2",
lambda: MathSolver(
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
@ -227,7 +227,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
runtime.register(
await runtime.register(
"MathSolver3",
lambda: MathSolver(
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
@ -235,7 +235,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
runtime.register(
await runtime.register(
"MathSolver4",
lambda: MathSolver(
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
@ -244,7 +244,7 @@ async def main(question: str) -> None:
),
)
# Register the aggregator agent.
runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
run_context = runtime.start()

View File

@ -130,7 +130,7 @@ async def main() -> None:
)
]
# Register agents.
tool_agent = runtime.register_and_get(
tool_agent = await runtime.register_and_get(
"tool_enabled_agent",
lambda: ToolEnabledAgent(
description="Tool Use Agent",

View File

@ -191,8 +191,8 @@ async def main() -> None:
)
]
# Register agents.
runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools))
runtime.register(
await runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools))
await runtime.register(
"tool_use_agent",
lambda: ToolUseAgent(
description="Tool Use Agent",

View File

@ -32,7 +32,7 @@ async def main() -> None:
# Create the runtime.
runtime = SingleThreadedAgentRuntime()
# Register agents.
tool_agent = runtime.register_and_get(
tool_agent = await runtime.register_and_get(
"tool_enabled_agent",
lambda: ToolEnabledAgent(
description="Tool Use Agent",

View File

@ -123,7 +123,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._intervention_handler = intervention_handler
self._known_namespaces: set[str] = set()
@ -173,7 +175,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if sender is not None and sender.namespace != recipient.namespace:
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
self._process_seen_namespace(recipient.namespace)
await self._process_seen_namespace(recipient.namespace)
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {content}")
@ -227,7 +229,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
assert explicit_namespace is not None or sender_namespace is not None
namespace = cast(str, explicit_namespace or sender_namespace)
self._process_seen_namespace(namespace)
await self._process_seen_namespace(namespace)
self._message_queue.append(
PublishMessageEnvelope(
@ -238,17 +240,17 @@ class SingleThreadedAgentRuntime(AgentRuntime):
)
)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}
for agent_id in self._instantiated_agents:
state[str(agent_id)] = dict(self._get_agent(agent_id).save_state())
state[str(agent_id)] = dict((await self._get_agent(agent_id)).save_state())
return state
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
for agent_id_str in state:
agent_id = AgentId.from_str(agent_id_str)
if agent_id.name in self._known_agent_names:
self._get_agent(agent_id).load_state(state[str(agent_id)])
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
recipient = message_envelope.recipient
@ -269,7 +271,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
recipient_agent = self._get_agent(recipient)
recipient_agent = await self._get_agent(recipient)
response = await recipient_agent.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
@ -297,7 +299,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name:
continue
sender_agent = self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
sender_agent = (
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
)
sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown"
logger.info(
f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
@ -312,7 +316,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
# )
# )
agent = self._get_agent(agent_id)
agent = await self._get_agent(agent_id)
future = agent.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
@ -430,19 +434,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
def start(self) -> RunContext:
return RunContext(self)
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return self._get_agent(agent).metadata
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return (await self._get_agent(agent)).metadata
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
return self._get_agent(agent).save_state()
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
return (await self._get_agent(agent)).save_state()
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
self._get_agent(agent).load_state(state)
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
(await self._get_agent(agent)).load_state(state)
def register(
async def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
@ -450,28 +454,30 @@ class SingleThreadedAgentRuntime(AgentRuntime):
# For all already prepared namespaces we need to prepare this agent
for namespace in self._known_namespaces:
self._get_agent(AgentId(name=name, namespace=namespace))
await self._get_agent(AgentId(name=name, namespace=namespace))
def _invoke_agent_factory(
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
token = agent_instantiation_context.set((self, agent_id))
with agent_instantiation_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
agent_instantiation_context.reset(token)
return agent
return agent
def _get_agent(self, agent_id: AgentId) -> Agent:
self._process_seen_namespace(agent_id.namespace)
async def _get_agent(self, agent_id: AgentId) -> Agent:
await self._process_seen_namespace(agent_id.namespace)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
@ -480,25 +486,25 @@ class SingleThreadedAgentRuntime(AgentRuntime):
agent_factory = self._agent_factories[agent_id.name]
agent = self._invoke_agent_factory(agent_factory, agent_id)
agent = await self._invoke_agent_factory(agent_factory, agent_id)
for message_type in agent.metadata["subscriptions"]:
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
self._instantiated_agents[agent_id] = agent
return agent
def get(self, name: str, *, namespace: str = "default") -> AgentId:
return self._get_agent(AgentId(name=name, namespace=namespace)).id
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
return (await self._get_agent(AgentId(name=name, namespace=namespace))).id
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = self.get(name, namespace=namespace)
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = await self.get(name, namespace=namespace)
return AgentProxy(id, self)
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
def _process_seen_namespace(self, namespace: str) -> None:
async def _process_seen_namespace(self, namespace: str) -> None:
if namespace in self._known_namespaces:
return
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
self._get_agent(AgentId(name=name, namespace=namespace))
await self._get_agent(AgentId(name=name, namespace=namespace))

View File

@ -4,7 +4,7 @@ from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_typ
from ..core._agent import Agent
from ..core._agent_id import AgentId
from ..core._agent_metadata import AgentMetadata
from ..core._agent_runtime import AgentRuntime, agent_instantiation_context
from ..core._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
from ..core._cancellation_token import CancellationToken
from ..core._serialization import MESSAGE_TYPE_REGISTRY
from ..core.exceptions import CantHandleException
@ -46,7 +46,7 @@ class ClosureAgent(Agent):
self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]]
) -> None:
try:
runtime, id = agent_instantiation_context.get()
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
except LookupError as e:
raise RuntimeError(
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."

View File

@ -7,7 +7,7 @@ from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime, agent_instantiation_context
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime, agent_instantiation_context
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer
@ -22,6 +22,7 @@ __all__ = [
"CancellationToken",
"AgentChildren",
"agent_instantiation_context",
"AGENT_INSTANTIATION_CONTEXT_VAR",
"MESSAGE_TYPE_REGISTRY",
"TypeSerializer",
"TypeDeserializer",

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping
from typing import TYPE_CHECKING, Any, Awaitable, Mapping
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
@ -21,7 +21,7 @@ class AgentProxy:
return self._agent
@property
def metadata(self) -> AgentMetadata:
def metadata(self) -> Awaitable[AgentMetadata]:
"""Metadata of the agent."""
return self._runtime.agent_metadata(self._agent)
@ -39,14 +39,14 @@ class AgentProxy:
cancellation_token=cancellation_token,
)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the agent. The result must be JSON serializable."""
return self._runtime.agent_save_state(self._agent)
return await self._runtime.agent_save_state(self._agent)
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load in the state of the agent obtained from `save_state`.
Args:
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
"""
self._runtime.agent_load_state(self._agent, state)
await self._runtime.agent_load_state(self._agent, state)

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable
from ._agent import Agent
from ._agent_id import AgentId
@ -13,7 +14,18 @@ from ._cancellation_token import CancellationToken
T = TypeVar("T", bound=Agent)
agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context")
AGENT_INSTANTIATION_CONTEXT_VAR: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar(
"AGENT_INSTANTIATION_CONTEXT_VAR"
)
@contextmanager
def agent_instantiation_context(ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
token = AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
try:
yield
finally:
AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
@runtime_checkable
@ -68,23 +80,23 @@ class AgentRuntime(Protocol):
"""
@overload
def register(
async def register(
self,
name: str,
agent_factory: Callable[[], T],
agent_factory: Callable[[], T | Awaitable[T]],
) -> None: ...
@overload
def register(
async def register(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None: ...
def register(
async def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
@ -110,7 +122,7 @@ class AgentRuntime(Protocol):
...
def get(self, name: str, *, namespace: str = "default") -> AgentId:
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
"""Get an agent by name and namespace.
Args:
@ -122,7 +134,7 @@ class AgentRuntime(Protocol):
"""
...
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
"""Get a proxy for an agent by name and namespace.
Args:
@ -135,27 +147,27 @@ class AgentRuntime(Protocol):
...
@overload
def register_and_get(
async def register_and_get(
self,
name: str,
agent_factory: Callable[[], T],
agent_factory: Callable[[], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId: ...
@overload
def register_and_get(
async def register_and_get(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId: ...
def register_and_get(
async def register_and_get(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentId:
@ -169,31 +181,31 @@ class AgentRuntime(Protocol):
Returns:
AgentId: The agent id.
"""
self.register(name, agent_factory)
return self.get(name, namespace=namespace)
await self.register(name, agent_factory)
return await self.get(name, namespace=namespace)
@overload
def register_and_get_proxy(
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T],
agent_factory: Callable[[], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy: ...
@overload
def register_and_get_proxy(
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy: ...
def register_and_get_proxy(
async def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
*,
namespace: str = "default",
) -> AgentProxy:
@ -207,10 +219,10 @@ class AgentRuntime(Protocol):
Returns:
AgentProxy: The agent proxy.
"""
self.register(name, agent_factory)
return self.get_proxy(name, namespace=namespace)
await self.register(name, agent_factory)
return await self.get_proxy(name, namespace=namespace)
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
The structure of the state is implementation defined and can be any JSON serializable object.
@ -220,7 +232,7 @@ class AgentRuntime(Protocol):
"""
...
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`.
Args:
@ -228,7 +240,7 @@ class AgentRuntime(Protocol):
"""
...
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
"""Get the metadata for an agent.
Args:
@ -239,7 +251,7 @@ class AgentRuntime(Protocol):
"""
...
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
"""Save the state of a single agent.
The structure of the state is implementation defined and can be any JSON serializable object.
@ -252,7 +264,7 @@ class AgentRuntime(Protocol):
"""
...
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
"""Load the state of a single agent.
Args:

View File

@ -5,7 +5,7 @@ from typing import Any, Mapping, Sequence
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime, agent_instantiation_context
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
from ._cancellation_token import CancellationToken
@ -22,7 +22,7 @@ class BaseAgent(ABC, Agent):
def __init__(self, description: str, subscriptions: Sequence[str]) -> None:
try:
runtime, id = agent_instantiation_context.get()
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
except LookupError as e:
raise RuntimeError(
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."

View File

@ -12,6 +12,7 @@ from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
ClassVar,
DefaultDict,
@ -188,7 +189,9 @@ class WorkerAgentRuntime(AgentRuntime):
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
# If empty, then all namespaces are valid for that agent type
self._valid_namespaces: Dict[str, Sequence[str]] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
@ -249,7 +252,7 @@ class WorkerAgentRuntime(AgentRuntime):
(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
]:
logger.info("Sending message to %s", agent_id)
agent = self._get_agent(agent_id)
agent = await self._get_agent(agent_id)
try:
await agent.on_message(message, CancellationToken())
logger.info("%s handled event %s", agent_id, message)
@ -321,7 +324,7 @@ class WorkerAgentRuntime(AgentRuntime):
assert explicit_namespace is not None or sender_namespace is not None
actual_namespace = cast(str, explicit_namespace or sender_namespace)
self._process_seen_namespace(actual_namespace)
await self._process_seen_namespace(actual_namespace)
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
@ -332,25 +335,25 @@ class WorkerAgentRuntime(AgentRuntime):
await asyncio.create_task(write_message())
def save_state(self) -> Mapping[str, Any]:
async def save_state(self) -> Mapping[str, Any]:
raise NotImplementedError("Saving state is not yet implemented.")
def load_state(self, state: Mapping[str, Any]) -> None:
async def load_state(self, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Loading state is not yet implemented.")
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
raise NotImplementedError("Agent metadata is not yet implemented.")
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
raise NotImplementedError("Agent save_state is not yet implemented.")
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Agent load_state is not yet implemented.")
def register(
async def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
) -> None:
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
@ -358,29 +361,32 @@ class WorkerAgentRuntime(AgentRuntime):
# For all already prepared namespaces we need to prepare this agent
for namespace in self._known_namespaces:
self._get_agent(AgentId(name=name, namespace=namespace))
await self._get_agent(AgentId(name=name, namespace=namespace))
# TODO do we need to convert register to async?
asyncio.create_task(self.send_register_agent_type(name))
await self.send_register_agent_type(name)
def _invoke_agent_factory(
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
with agent_instantiation_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
return agent
def _get_agent(self, agent_id: AgentId) -> Agent:
self._process_seen_namespace(agent_id.namespace)
async def _get_agent(self, agent_id: AgentId) -> Agent:
await self._process_seen_namespace(agent_id.namespace)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
@ -389,9 +395,7 @@ class WorkerAgentRuntime(AgentRuntime):
agent_factory = self._agent_factories[agent_id.name]
token = agent_instantiation_context.set((self, agent_id))
agent = self._invoke_agent_factory(agent_factory, agent_id)
agent_instantiation_context.reset(token)
agent = await self._invoke_agent_factory(agent_factory, agent_id)
for message_type in agent.metadata["subscriptions"]:
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
@ -399,19 +403,19 @@ class WorkerAgentRuntime(AgentRuntime):
self._instantiated_agents[agent_id] = agent
return agent
def get(self, name: str, *, namespace: str = "default") -> AgentId:
return self._get_agent(AgentId(name=name, namespace=namespace)).id
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
return (await self._get_agent(AgentId(name=name, namespace=namespace))).id
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = self.get(name, namespace=namespace)
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = await self.get(name, namespace=namespace)
return AgentProxy(id, self)
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
def _process_seen_namespace(self, namespace: str) -> None:
async def _process_seen_namespace(self, namespace: str) -> None:
if namespace in self._known_namespaces:
return
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
self._get_agent(AgentId(name=name, namespace=namespace))
await self._get_agent(AgentId(name=name, namespace=namespace))

View File

@ -15,19 +15,19 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(description="The current user interacting with you."),
)
runtime.register(
await runtime.register(
"orchestrator",
lambda: LedgerOrchestrator(
model_client=create_completion_client_from_env(), agents=[coder, executor, user_proxy]

View File

@ -15,19 +15,19 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)

View File

@ -18,16 +18,16 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
file_surfer = runtime.register_and_get_proxy(
file_surfer = await runtime.register_and_get_proxy(
"file_surfer",
lambda: FileSurfer(model_client=client),
)
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)

View File

@ -13,10 +13,10 @@ from team_one.utils import LogHandler
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
fake1 = runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake2 = runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake3 = runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
task_message = UserMessage(content="Test Message", source="User")
run_context = runtime.start()

View File

@ -19,16 +19,16 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=client),
)
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)

View File

@ -21,17 +21,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
web_surfer = runtime.register_and_get_proxy(
web_surfer = await runtime.register_and_get_proxy(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
run_context = runtime.start()

View File

@ -32,7 +32,7 @@ dependencies = [
"youtube-transcript-api",
"SpeechRecognition",
"pathvalidate",
"playwright"
"playwright",
]
[tool.hatch.envs.default]
@ -45,7 +45,8 @@ dependencies = [
"aiofiles",
"types-aiofiles",
"types-requests",
"azure-identity"
"types-pillow",
"azure-identity",
]
[tool.hatch.envs.default.extra-scripts]
@ -71,7 +72,13 @@ line-length = 120
fix = true
exclude = ["build", "dist", "page_script.js"]
target-version = "py310"
include = ["src/**", "examples/*.py"]
include = [
"src/**",
"examples/*.py",
"../../benchmarks/HumanEval/Templates/TeamOne/scenario.py",
"../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py",
"../../benchmarks/GAIA/TeamOne/TwoAgents/scenario.py",
]
[tool.ruff.format]
docstring-code-format = true
@ -81,7 +88,11 @@ select = ["E", "F", "W", "B", "Q", "I", "ASYNC"]
ignore = ["F401", "E501"]
[tool.mypy]
files = ["src", "examples", "tests"]
files = [
"src",
"tests",
"examples",
]
strict = true
python_version = "3.10"
@ -100,7 +111,14 @@ disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.pyright]
include = ["src", "tests", "examples"]
include = [
"src",
"tests",
"examples",
"../../benchmarks/HumanEval/Templates/TeamOne/scenario.py",
"../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py",
"../../benchmarks/GAIA/Templates/TeamOne/scenario.py",
]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false

View File

@ -69,7 +69,7 @@ class BaseOrchestrator(TypeRoutedAgent):
logger.info(
OrchestrationEvent(
source=f"{self.metadata['name']} (thought)",
message=f"Next speaker {next_agent.metadata['name']}" "",
message=f"Next speaker {(await next_agent.metadata)['name']}" "",
)
)

View File

@ -68,7 +68,7 @@ def _draw_roi(
luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11
text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255)
roi = [(rect["left"], rect["top"]), (rect["right"], rect["bottom"])]
roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"]))
label_location = (rect["right"], rect["top"])
label_anchor = "rb"

View File

@ -79,16 +79,16 @@ class LedgerOrchestrator(BaseOrchestrator):
def _get_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
return self._ledger_prompt.format(task=task, team=team, names=names)
def _get_team_description(self) -> str:
async def _get_team_description(self) -> str:
team_description = ""
for agent in self._agents:
name = agent.metadata["name"]
description = agent.metadata["description"]
name = (await agent.metadata)["name"]
description = (await agent.metadata)["description"]
team_description += f"{name}: {description}\n"
return team_description
def _get_team_names(self) -> List[str]:
return [agent.metadata["name"] for agent in self._agents]
async def _get_team_names(self) -> List[str]:
return [(await agent.metadata)["name"] for agent in self._agents]
def _set_task_str(self, message: LLMMessage) -> None:
if len(self._chat_history) == 1:
@ -112,7 +112,7 @@ class LedgerOrchestrator(BaseOrchestrator):
return False
async def _plan(self) -> str:
team_description = self._get_team_description()
team_description = await self._get_team_description()
# 1. GATHER FACTS
# create a closed book task and generate a response and update the chat history
@ -144,8 +144,8 @@ class LedgerOrchestrator(BaseOrchestrator):
async def update_ledger(self) -> Dict[str, Any]:
max_json_retries = 10
team_description = self._get_team_description()
names = self._get_team_names()
team_description = await self._get_team_description()
names = await self._get_team_names()
ledger_prompt = self._get_ledger_prompt(self.task_str, team_description, names)
ledger_user_message = UserMessage(content=ledger_prompt, source=self.metadata["name"])
@ -234,7 +234,7 @@ class LedgerOrchestrator(BaseOrchestrator):
next_agent_name = ledger_dict["next_speaker"]["answer"]
for agent in self._agents:
if agent.metadata["name"] == next_agent_name:
if (await agent.metadata)["name"] == next_agent_name:
# broadcast a new message
instruction = ledger_dict["instruction_or_question"]["answer"]
user_message = UserMessage(content=instruction, source=self.metadata["name"])

View File

@ -1,6 +1,6 @@
import pytest
from pytest_mock import MockerFixture
from agnext.core import AgentRuntime, agent_instantiation_context, AgentId
from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId
from test_utils import NoopAgent
@ -11,7 +11,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)
# Shows how to set the context for the agent instantiation in a test context
agent_instantiation_context.set((runtime, AgentId("name", "namespace")))
AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace")))
agent = NoopAgent()
assert agent.runtime == runtime

View File

@ -57,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
async def test_cancellation_with_token() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = runtime.register_and_get("long_running", LongRunningAgent)
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
assert not response.done()
@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
await response
assert response.done()
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.called
assert long_running_agent.cancelled
@ -83,8 +83,8 @@ async def test_cancellation_with_token() -> None:
async def test_nested_cancellation_only_outer_called() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = runtime.register_and_get("long_running", LongRunningAgent)
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
long_running = await runtime.register_and_get("long_running", LongRunningAgent)
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
await response
assert response.done()
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.called is False
assert long_running_agent.cancelled is False
@ -111,8 +111,8 @@ async def test_nested_cancellation_only_outer_called() -> None:
async def test_nested_cancellation_inner_called() -> None:
runtime = SingleThreadedAgentRuntime()
long_running = runtime.register_and_get("long_running", LongRunningAgent )
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
long_running = await runtime.register_and_get("long_running", LongRunningAgent )
nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))
@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
await response
assert response.done()
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore
assert nested_agent.called
assert nested_agent.cancelled
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.called
assert long_running_agent.cancelled

View File

@ -28,7 +28,7 @@ async def test_register_receives_publish() -> None:
namespace = id.namespace
await queue.put((namespace, message.content))
runtime.register("name", lambda: ClosureAgent("My agent", log_message))
await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
run_context = runtime.start()
await runtime.publish_message(Message("first message"), namespace="default")
await runtime.publish_message(Message("second message"), namespace="default")

View File

@ -19,7 +19,7 @@ async def test_intervention_count_messages() -> None:
handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
loopback = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
_response = await runtime.send_message(MessageType(), recipient=loopback)
@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
await run_context.stop()
assert handler.num_messages == 1
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
assert loopback_agent.num_calls == 1
@pytest.mark.asyncio
@ -40,7 +40,7 @@ async def test_intervention_drop_send() -> None:
handler = DropSendInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
loopback = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:
await run_context.stop()
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore
assert loopback_agent.num_calls == 0
@ -62,7 +62,7 @@ async def test_intervention_drop_response() -> None:
handler = DropResponseInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
loopback = runtime.register_and_get("name", LoopbackAgent)
loopback = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(MessageDroppedException):
@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
long_running = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(InterventionException):
@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:
await run_context.stop()
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.num_calls == 0
@pytest.mark.asyncio
@ -108,12 +108,12 @@ async def test_intervention_raise_exception_on_respond() -> None:
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
long_running = runtime.register_and_get("name", LoopbackAgent)
long_running = await runtime.register_and_get("name", LoopbackAgent)
run_context = runtime.start()
with pytest.raises(InterventionException):
_response = await runtime.send_message(MessageType(), recipient=long_running)
await run_context.stop()
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore
assert long_running_agent.num_calls == 1

View File

@ -14,31 +14,31 @@ async def test_agent_names_must_be_unique() -> None:
assert agent.id == id
return agent
agent1 = runtime.register_and_get("name1", agent_factory)
agent1 = await runtime.register_and_get("name1", agent_factory)
assert agent1 == AgentId("name1", "default")
with pytest.raises(ValueError):
_agent1 = runtime.register_and_get("name1", NoopAgent)
_agent1 = await runtime.register_and_get("name1", NoopAgent)
_agent1 = runtime.register_and_get("name3", NoopAgent)
_agent1 = await runtime.register_and_get("name3", NoopAgent)
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.register("name", LoopbackAgent)
await runtime.register("name", LoopbackAgent)
run_context = runtime.start()
await runtime.publish_message(MessageType(), namespace="default")
await run_context.stop_when_idle()
# Agent in default namespace should have received the message
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore
assert long_running_agent.num_calls == 1
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore
other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore
assert other_long_running_agent.num_calls == 0
@ -54,7 +54,7 @@ async def test_register_receives_publish_cascade() -> None:
# Register agents
for i in range(num_agents):
runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
run_context = runtime.start()
@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:
# Check that each agent received the correct number of messages.
for i in range(num_agents):
agent: CascadingAgent = runtime._get_agent(runtime.get(f"name{i}")) # type: ignore
agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore
assert agent.num_calls == total_num_calls_expected

View File

@ -5,8 +5,8 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.core import BaseAgent, CancellationToken
class StatefulAgent(BaseAgent): # type: ignore
def __init__(self) -> None: # type: ignore
class StatefulAgent(BaseAgent):
def __init__(self) -> None:
super().__init__("A stateful agent", [])
self.state = 0
@ -14,7 +14,7 @@ class StatefulAgent(BaseAgent): # type: ignore
def subscriptions(self) -> Sequence[type]:
return []
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> None:
raise NotImplementedError
def save_state(self) -> Mapping[str, Any]:
@ -28,8 +28,8 @@ class StatefulAgent(BaseAgent): # type: ignore
async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
@ -46,19 +46,19 @@ async def test_agent_can_save_state() -> None:
async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
agent1_id = await runtime.register_and_get("name1", StatefulAgent)
agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore
assert agent1.state == 0
agent1.state = 1
assert agent1.state == 1
runtime_state = runtime.save_state()
runtime_state = await runtime.save_state()
runtime2 = SingleThreadedAgentRuntime()
agent2_id = runtime2.register_and_get("name1", StatefulAgent)
agent2: StatefulAgent = runtime2._get_agent(agent2_id) # type: ignore
agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore
runtime2.load_state(runtime_state)
await runtime2.load_state(runtime_state)
assert agent2.state == 1