mirror of https://github.com/microsoft/autogen.git
Agent factory can be async (#247)
This commit is contained in:
parent
718fad6e0d
commit
a52d3bab53
|
@ -1,9 +1,8 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import List
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
from agnext.components.models import (
|
||||
|
@ -18,13 +17,17 @@ from agnext.application.logging import EVENT_LOGGER_NAME
|
|||
from team_one.markdown_browser import MarkdownConverter, UnsupportedFormatException
|
||||
from team_one.agents.coder import Coder, Executor
|
||||
from team_one.agents.orchestrator import LedgerOrchestrator
|
||||
from team_one.messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
|
||||
from team_one.messages import BroadcastMessage
|
||||
from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer
|
||||
from team_one.agents.file_surfer import FileSurfer
|
||||
from team_one.utils import LogHandler, message_content_to_str, create_completion_client_from_env
|
||||
from team_one.utils import LogHandler, message_content_to_str
|
||||
|
||||
import re
|
||||
|
||||
from agnext.components.models import AssistantMessage
|
||||
|
||||
|
||||
async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]):
|
||||
async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]) -> str:
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(
|
||||
content=f"Earlier you were asked the following:\n\n{task}\n\nYour team then worked diligently to address that request. Here is a transcript of that conversation:",
|
||||
|
@ -37,7 +40,8 @@ async def response_preparer(task: str, source: str, client: ChatCompletionClient
|
|||
messages.append(
|
||||
UserMessage(
|
||||
content = message_content_to_str(message.content),
|
||||
source=message.source,
|
||||
# TODO fix this -> remove type ignore
|
||||
source=message.source, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -68,7 +72,7 @@ If you are asked for a comma separated list, apply the above rules depending on
|
|||
# No answer
|
||||
if "unable to determine" in response.content.lower():
|
||||
messages.append( AssistantMessage(content=response.content, source="self" ) )
|
||||
messages.append(
|
||||
messages.append(
|
||||
UserMessage(
|
||||
content= f"""
|
||||
I understand that a definitive answer could not be determined. Please make a well-informed EDUCATED GUESS based on the conversation.
|
||||
|
@ -115,29 +119,29 @@ async def main() -> None:
|
|||
)
|
||||
|
||||
# Register agents.
|
||||
coder = runtime.register_and_get_proxy(
|
||||
coder = await runtime.register_and_get_proxy(
|
||||
"Coder",
|
||||
lambda: Coder(model_client=client),
|
||||
)
|
||||
|
||||
executor = runtime.register_and_get_proxy(
|
||||
executor = await runtime.register_and_get_proxy(
|
||||
"Executor",
|
||||
lambda: Executor(
|
||||
"A agent for executing code", executor=LocalCommandLineCodeExecutor()
|
||||
),
|
||||
)
|
||||
|
||||
file_surfer = runtime.register_and_get_proxy(
|
||||
file_surfer = await runtime.register_and_get_proxy(
|
||||
"file_surfer",
|
||||
lambda: FileSurfer(model_client=client),
|
||||
)
|
||||
|
||||
web_surfer = runtime.register_and_get_proxy(
|
||||
web_surfer = await runtime.register_and_get_proxy(
|
||||
"WebSurfer",
|
||||
lambda: MultimodalWebSurfer(), # Configuration is set later by init()
|
||||
)
|
||||
|
||||
orchestrator = runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator(
|
||||
orchestrator = await runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator(
|
||||
agents=[coder, executor, file_surfer, web_surfer],
|
||||
model_client=client,
|
||||
))
|
||||
|
@ -185,7 +189,7 @@ async def main() -> None:
|
|||
actual_orchestrator = runtime._get_agent(orchestrator.id) # type: ignore
|
||||
assert isinstance(actual_orchestrator, LedgerOrchestrator)
|
||||
transcript: List[LLMMessage] = actual_orchestrator._chat_history # type: ignore
|
||||
print(await response_preparer(task=task, source=orchestrator.metadata["name"], client=client, transcript=transcript))
|
||||
print(await response_preparer(task=task, source=(await orchestrator.metadata)["name"], client=client, transcript=transcript))
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -34,18 +34,18 @@ async def main() -> None:
|
|||
)
|
||||
|
||||
# Register agents.
|
||||
coder = runtime.register_and_get_proxy(
|
||||
coder = await runtime.register_and_get_proxy(
|
||||
"Coder",
|
||||
lambda: Coder(model_client=client),
|
||||
)
|
||||
executor = runtime.register_and_get_proxy(
|
||||
executor = await runtime.register_and_get_proxy(
|
||||
"Executor",
|
||||
lambda: Executor(
|
||||
"A agent for executing code", executor=LocalCommandLineCodeExecutor()
|
||||
),
|
||||
)
|
||||
|
||||
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor]))
|
||||
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor]))
|
||||
|
||||
prompt = ""
|
||||
with open("prompt.txt", "rt") as fh:
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.code_executor import (
|
||||
CodeBlock,
|
||||
CodeExecutor,
|
||||
|
@ -16,16 +15,12 @@ from agnext.components.models import (
|
|||
AssistantMessage,
|
||||
AzureOpenAIChatCompletionClient,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
ModelCapabilities,
|
||||
OpenAIChatCompletionClient,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.components.tools import CodeExecutionResult, PythonCodeExecutionTool
|
||||
from agnext.core import AgentId, CancellationToken
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
# from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
|
||||
|
@ -66,7 +61,7 @@ if __name__ == "__main__":
|
|||
main()
|
||||
```
|
||||
|
||||
The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions.
|
||||
The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions.
|
||||
|
||||
Check the execution result returned by the user. If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes -- code blocks must stand alone and be ready to execute without modification. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, and think of a different approach to try.
|
||||
|
||||
|
@ -222,11 +217,11 @@ async def main() -> None:
|
|||
)
|
||||
|
||||
# Register agents.
|
||||
coder = runtime.register_and_get(
|
||||
coder = await runtime.register_and_get(
|
||||
"Coder",
|
||||
lambda: Coder(model_client=client),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"Executor",
|
||||
lambda: Executor(
|
||||
"A agent for executing code", executor=LocalCommandLineCodeExecutor()
|
||||
|
|
|
@ -32,7 +32,7 @@ dependencies = [
|
|||
"mypy==1.10.0",
|
||||
"ruff==0.4.8",
|
||||
"tiktoken",
|
||||
"types-Pillow",
|
||||
"types-pillow",
|
||||
"polars",
|
||||
"chess",
|
||||
"tavily-python",
|
||||
|
|
|
@ -22,10 +22,12 @@ async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClie
|
|||
history = "\n".join(history_messages)
|
||||
|
||||
# Construct agent roles.
|
||||
roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents])
|
||||
roles = "\n".join(
|
||||
[f"{(await agent.metadata)['name']}: {(await agent.metadata)['description']}".strip() for agent in agents]
|
||||
)
|
||||
|
||||
# Construct agent list.
|
||||
participants = str([agent.metadata["name"] for agent in agents])
|
||||
participants = str([(await agent.metadata)["name"] for agent in agents])
|
||||
|
||||
# Select the next speaker.
|
||||
select_speaker_prompt = f"""You are in a role play game. The following roles are available:
|
||||
|
@ -39,16 +41,22 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
|
||||
response = await client.create(messages=select_speaker_messages)
|
||||
assert isinstance(response.content, str)
|
||||
mentions = mentioned_agents(response.content, agents)
|
||||
mentions = await mentioned_agents(response.content, agents)
|
||||
if len(mentions) != 1:
|
||||
raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}")
|
||||
agent_name = list(mentions.keys())[0]
|
||||
agent_index = next((i for i, agent in enumerate(agents) if agent.metadata["name"] == agent_name), None)
|
||||
# Get the index of the selected agent by name
|
||||
agent_index = 0
|
||||
for i, agent in enumerate(agents):
|
||||
if (await agent.metadata)["name"] == agent_name:
|
||||
agent_index = i
|
||||
break
|
||||
|
||||
assert agent_index is not None
|
||||
return agent_index
|
||||
|
||||
|
||||
def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
|
||||
async def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
|
||||
"""Counts the number of times each agent is mentioned in the provided message content.
|
||||
Agent names will match under any of the following conditions (all case-sensitive):
|
||||
- Exact name match
|
||||
|
@ -66,7 +74,7 @@ def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str
|
|||
for agent in agents:
|
||||
# Finds agent mentions, taking word boundaries into account,
|
||||
# accommodates escaping underscores and underscores as spaces
|
||||
name = agent.metadata["name"]
|
||||
name = (await agent.metadata)["name"]
|
||||
regex = (
|
||||
r"(?<=\W)("
|
||||
+ re.escape(name)
|
||||
|
|
|
@ -170,7 +170,10 @@ Some additional points to consider:
|
|||
|
||||
# A reusable description of the team.
|
||||
team = "\n".join(
|
||||
[agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists]
|
||||
[
|
||||
agent.name + ": " + (await self.runtime.agent_metadata(agent))["description"]
|
||||
for agent in self._specialists
|
||||
]
|
||||
)
|
||||
names = ", ".join([agent.name for agent in self._specialists])
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ class Outer(TypeRoutedAgent):
|
|||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
inner = runtime.register_and_get("inner", Inner)
|
||||
outer = runtime.register_and_get("outer", lambda: Outer(inner))
|
||||
inner = await runtime.register_and_get("inner", Inner)
|
||||
outer = await runtime.register_and_get("outer", lambda: Outer(inner))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
|||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
agent = runtime.register_and_get(
|
||||
agent = await runtime.register_and_get(
|
||||
"chat_agent",
|
||||
lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-3.5-turbo")),
|
||||
)
|
||||
|
|
|
@ -77,7 +77,7 @@ async def main() -> None:
|
|||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register the agents.
|
||||
jack = runtime.register_and_get(
|
||||
jack = await runtime.register_and_get(
|
||||
"Jack",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Jack a comedian",
|
||||
|
@ -88,7 +88,7 @@ async def main() -> None:
|
|||
termination_word="TERMINATE",
|
||||
),
|
||||
)
|
||||
runtime.register_and_get(
|
||||
await runtime.register_and_get(
|
||||
"Cathy",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Cathy a poet",
|
||||
|
|
|
@ -166,7 +166,7 @@ class EventHandler(AsyncAssistantEventHandler):
|
|||
print("\n".join(citations))
|
||||
|
||||
|
||||
def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
async def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
oai_assistant = openai.beta.assistants.create(
|
||||
model="gpt-4-turbo",
|
||||
description="An AI assistant that helps with everyday tasks.",
|
||||
|
@ -177,7 +177,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
|||
thread = openai.beta.threads.create(
|
||||
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
|
||||
)
|
||||
assistant = runtime.register_and_get(
|
||||
assistant = await runtime.register_and_get(
|
||||
"Assistant",
|
||||
lambda: OpenAIAssistantAgent(
|
||||
description="An AI assistant that helps with everyday tasks.",
|
||||
|
@ -188,7 +188,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
|||
),
|
||||
)
|
||||
|
||||
user = runtime.register_and_get(
|
||||
user = await runtime.register_and_get(
|
||||
"User",
|
||||
lambda: UserProxyAgent(
|
||||
client=openai.AsyncClient(),
|
||||
|
@ -198,7 +198,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
|||
),
|
||||
)
|
||||
# Create a group chat manager to facilitate a turn-based conversation.
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A group chat manager.",
|
||||
|
@ -225,7 +225,7 @@ This will upload data.csv to the assistant for use with the code interpreter too
|
|||
Type "exit" to exit the chat.
|
||||
"""
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
user = assistant_chat(runtime)
|
||||
user = await assistant_chat(runtime)
|
||||
_run_context = runtime.start()
|
||||
print(usage)
|
||||
# Request the user to start the conversation.
|
||||
|
|
|
@ -87,15 +87,15 @@ class ChatRoomUserAgent(TextualUserAgent):
|
|||
|
||||
|
||||
# Define a chat room with participants -- the runtime is the chat room.
|
||||
def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
runtime.register(
|
||||
async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
await runtime.register(
|
||||
"User",
|
||||
lambda: ChatRoomUserAgent(
|
||||
description="The user in the chat room.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
alice = runtime.register_and_get_proxy(
|
||||
alice = await runtime.register_and_get_proxy(
|
||||
"Alice",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
|
@ -105,7 +105,7 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
bob = runtime.register_and_get_proxy(
|
||||
bob = await runtime.register_and_get_proxy(
|
||||
"Bob",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
|
@ -115,7 +115,7 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
charlie = runtime.register_and_get_proxy(
|
||||
charlie = await runtime.register_and_get_proxy(
|
||||
"Charlie",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
|
@ -126,9 +126,9 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
),
|
||||
)
|
||||
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
|
||||
1. 👧 {alice.id.name}: {alice.metadata['description']}
|
||||
2. 👱🏼♂️ {bob.id.name}: {bob.metadata['description']}
|
||||
3. 👨🏾🦳 {charlie.id.name}: {charlie.metadata['description']}
|
||||
1. 👧 {alice.id.name}: {(await alice.metadata)['description']}
|
||||
2. 👱🏼♂️ {bob.id.name}: {(await bob.metadata)['description']}
|
||||
3. 👨🏾🦳 {charlie.id.name}: {(await charlie.metadata)['description']}
|
||||
|
||||
Each participant decides on its own whether to respond to the latest message.
|
||||
|
||||
|
@ -139,7 +139,7 @@ You can greet the chat room by typing your first message below.
|
|||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
chat_room(runtime, app)
|
||||
await chat_room(runtime, app)
|
||||
_run_context = runtime.start()
|
||||
await app.run_async()
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ def make_move(
|
|||
return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[newMove.from_square]} to {SQUARE_NAMES[newMove.to_square]}."
|
||||
|
||||
|
||||
def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
"""Create agents for a chess game and return the group chat."""
|
||||
|
||||
# Create the board.
|
||||
|
@ -156,7 +156,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
),
|
||||
]
|
||||
|
||||
black = runtime.register_and_get(
|
||||
black = await runtime.register_and_get(
|
||||
"PlayerBlack",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Player playing black.",
|
||||
|
@ -173,7 +173,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
tools=black_tools,
|
||||
),
|
||||
)
|
||||
white = runtime.register_and_get(
|
||||
white = await runtime.register_and_get(
|
||||
"PlayerWhite",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Player playing white.",
|
||||
|
@ -192,7 +192,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
)
|
||||
# Create a group chat manager for the chess game to orchestrate a turn-based
|
||||
# conversation between the two agents.
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"ChessGame",
|
||||
lambda: GroupChatManager(
|
||||
description="A chess game between two agents.",
|
||||
|
@ -204,7 +204,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
chess_game(runtime)
|
||||
await chess_game(runtime)
|
||||
# Publish an initial message to trigger the group chat manager to start orchestration.
|
||||
await runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default")
|
||||
while True:
|
||||
|
|
|
@ -19,15 +19,15 @@ from common.utils import get_chat_completion_client_from_envs
|
|||
from utils import TextualChatApp, TextualUserAgent
|
||||
|
||||
|
||||
def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
runtime.register(
|
||||
async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
await runtime.register(
|
||||
"User",
|
||||
lambda: TextualUserAgent(
|
||||
description="A user looking for illustration.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
descriptor = runtime.register_and_get_proxy(
|
||||
descriptor = await runtime.register_and_get_proxy(
|
||||
"Descriptor",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An AI agent that provides a description of the image.",
|
||||
|
@ -46,7 +46,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo", max_tokens=500),
|
||||
),
|
||||
)
|
||||
illustrator = runtime.register_and_get_proxy(
|
||||
illustrator = await runtime.register_and_get_proxy(
|
||||
"Illustrator",
|
||||
lambda: ImageGenerationAgent(
|
||||
description="An AI agent that generates images.",
|
||||
|
@ -55,7 +55,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
memory=BufferedChatMemory(buffer_size=1),
|
||||
),
|
||||
)
|
||||
critic = runtime.register_and_get_proxy(
|
||||
critic = await runtime.register_and_get_proxy(
|
||||
"Critic",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An AI agent that provides feedback on images given user's requirements.",
|
||||
|
@ -74,7 +74,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A chat manager that handles group chat.",
|
||||
|
@ -86,9 +86,9 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
|||
|
||||
app.welcoming_notice = f"""You are now in a group chat with the following agents:
|
||||
|
||||
1. 🤖 {descriptor.metadata['name']}: {descriptor.metadata.get('description')}
|
||||
2. 🤖 {illustrator.metadata['name']}: {illustrator.metadata.get('description')}
|
||||
3. 🤖 {critic.metadata['name']}: {critic.metadata.get('description')}
|
||||
1. 🤖 {(await descriptor.metadata)['name']}: {(await descriptor.metadata).get('description')}
|
||||
2. 🤖 {(await illustrator.metadata)['name']}: {(await illustrator.metadata).get('description')}
|
||||
3. 🤖 {(await critic.metadata)['name']}: {(await critic.metadata).get('description')}
|
||||
|
||||
Provide a prompt for the illustrator to generate an image.
|
||||
"""
|
||||
|
@ -97,7 +97,7 @@ Provide a prompt for the illustrator to generate an image.
|
|||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
illustrator_critics(runtime, app)
|
||||
await illustrator_critics(runtime, app)
|
||||
_run_context = runtime.start()
|
||||
await app.run_async()
|
||||
|
||||
|
|
|
@ -105,15 +105,15 @@ async def create_image(
|
|||
return f"Image created and saved to {filename}."
|
||||
|
||||
|
||||
def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
|
||||
user_agent = runtime.register_and_get(
|
||||
async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
|
||||
user_agent = await runtime.register_and_get(
|
||||
"Customer",
|
||||
lambda: TextualUserAgent(
|
||||
description="A customer looking for help.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
developer = runtime.register_and_get(
|
||||
developer = await runtime.register_and_get(
|
||||
"Developer",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A Python software developer.",
|
||||
|
@ -153,7 +153,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
|
|||
),
|
||||
)
|
||||
|
||||
product_manager = runtime.register_and_get(
|
||||
product_manager = await runtime.register_and_get(
|
||||
"ProductManager",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A product manager. "
|
||||
|
@ -182,7 +182,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
|
|||
tool_approver=user_agent,
|
||||
),
|
||||
)
|
||||
ux_designer = runtime.register_and_get(
|
||||
ux_designer = await runtime.register_and_get(
|
||||
"UserExperienceDesigner",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A user experience designer for creating user interfaces.",
|
||||
|
@ -215,7 +215,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
|
|||
),
|
||||
)
|
||||
|
||||
illustrator = runtime.register_and_get(
|
||||
illustrator = await runtime.register_and_get(
|
||||
"Illustrator",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An illustrator for creating images.",
|
||||
|
@ -240,7 +240,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
|
|||
tool_approver=user_agent,
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A group chat manager.",
|
||||
|
@ -279,7 +279,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: #
|
|||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
software_consultancy(runtime, app)
|
||||
await software_consultancy(runtime, app)
|
||||
# Start the runtime.
|
||||
_run_context = runtime.start()
|
||||
# Start the app.
|
||||
|
|
|
@ -27,8 +27,8 @@ async def build_app(runtime: AgentRuntime) -> None:
|
|||
api_version="2024-02-01",
|
||||
)
|
||||
|
||||
runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
|
||||
runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
|
||||
await runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
|
||||
await runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
|
||||
|
||||
runtime.get("GraphicDesigner")
|
||||
runtime.get("Auditor")
|
||||
await runtime.get("GraphicDesigner")
|
||||
await runtime.get("Auditor")
|
||||
|
|
|
@ -30,7 +30,7 @@ class Printer(TypeRoutedAgent):
|
|||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await build_app(runtime)
|
||||
runtime.register("Printer", lambda: Printer())
|
||||
await runtime.register("Printer", lambda: Printer())
|
||||
|
||||
ctx = runtime.start()
|
||||
|
||||
|
|
|
@ -180,8 +180,10 @@ async def main(task: str, temp_dir: str) -> None:
|
|||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register the agents.
|
||||
runtime.register("coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo")))
|
||||
runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
|
||||
await runtime.register(
|
||||
"coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"))
|
||||
)
|
||||
await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish the task message.
|
||||
|
|
|
@ -251,14 +251,14 @@ Code: <Your code>
|
|||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"ReviewerAgent",
|
||||
lambda: ReviewerAgent(
|
||||
description="Code Reviewer",
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"CoderAgent",
|
||||
lambda: CoderAgent(
|
||||
description="Coder",
|
||||
|
|
|
@ -113,7 +113,7 @@ async def main() -> None:
|
|||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register the participants.
|
||||
agent1 = runtime.register_and_get(
|
||||
agent1 = await runtime.register_and_get(
|
||||
"DataScientist",
|
||||
lambda: GroupChatParticipant(
|
||||
description="A data scientist",
|
||||
|
@ -121,7 +121,7 @@ async def main() -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
),
|
||||
)
|
||||
agent2 = runtime.register_and_get(
|
||||
agent2 = await runtime.register_and_get(
|
||||
"Engineer",
|
||||
lambda: GroupChatParticipant(
|
||||
description="An engineer",
|
||||
|
@ -129,7 +129,7 @@ async def main() -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
),
|
||||
)
|
||||
agent3 = runtime.register_and_get(
|
||||
agent3 = await runtime.register_and_get(
|
||||
"Artist",
|
||||
lambda: GroupChatParticipant(
|
||||
description="An artist",
|
||||
|
@ -139,7 +139,7 @@ async def main() -> None:
|
|||
)
|
||||
|
||||
# Register the group chat manager.
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: RoundRobinGroupChatManager(
|
||||
description="A group chat manager",
|
||||
|
|
|
@ -112,7 +112,7 @@ class AggregatorAgent(TypeRoutedAgent):
|
|||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
# TODO: use different models for each agent.
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"ReferenceAgent1",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 1",
|
||||
|
@ -120,7 +120,7 @@ async def main() -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=0.1),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"ReferenceAgent2",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 2",
|
||||
|
@ -128,7 +128,7 @@ async def main() -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=0.5),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"ReferenceAgent3",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 3",
|
||||
|
@ -136,7 +136,7 @@ async def main() -> None:
|
|||
model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=1.0),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"AggregatorAgent",
|
||||
lambda: AggregatorAgent(
|
||||
description="Aggregator Agent",
|
||||
|
|
|
@ -211,7 +211,7 @@ async def main(question: str) -> None:
|
|||
# Register the solver agents.
|
||||
# Create a sparse connection: each solver agent has two neighbors.
|
||||
# NOTE: to create a dense connection, each solver agent should be connected to all other solver agents.
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"MathSolver1",
|
||||
lambda: MathSolver(
|
||||
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
|
@ -219,7 +219,7 @@ async def main(question: str) -> None:
|
|||
max_round=3,
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"MathSolver2",
|
||||
lambda: MathSolver(
|
||||
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
|
@ -227,7 +227,7 @@ async def main(question: str) -> None:
|
|||
max_round=3,
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"MathSolver3",
|
||||
lambda: MathSolver(
|
||||
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
|
@ -235,7 +235,7 @@ async def main(question: str) -> None:
|
|||
max_round=3,
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
await runtime.register(
|
||||
"MathSolver4",
|
||||
lambda: MathSolver(
|
||||
get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
|
||||
|
@ -244,7 +244,7 @@ async def main(question: str) -> None:
|
|||
),
|
||||
)
|
||||
# Register the aggregator agent.
|
||||
runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
|
||||
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
|
|
|
@ -130,7 +130,7 @@ async def main() -> None:
|
|||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_agent = runtime.register_and_get(
|
||||
tool_agent = await runtime.register_and_get(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolEnabledAgent(
|
||||
description="Tool Use Agent",
|
||||
|
|
|
@ -191,8 +191,8 @@ async def main() -> None:
|
|||
)
|
||||
]
|
||||
# Register agents.
|
||||
runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools))
|
||||
runtime.register(
|
||||
await runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools))
|
||||
await runtime.register(
|
||||
"tool_use_agent",
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
|
|
|
@ -32,7 +32,7 @@ async def main() -> None:
|
|||
# Create the runtime.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
# Register agents.
|
||||
tool_agent = runtime.register_and_get(
|
||||
tool_agent = await runtime.register_and_get(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolEnabledAgent(
|
||||
description="Tool Use Agent",
|
||||
|
|
|
@ -123,7 +123,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
|
||||
self._agent_factories: Dict[
|
||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||
] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._intervention_handler = intervention_handler
|
||||
self._known_namespaces: set[str] = set()
|
||||
|
@ -173,7 +175,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
if sender is not None and sender.namespace != recipient.namespace:
|
||||
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
|
||||
|
||||
self._process_seen_namespace(recipient.namespace)
|
||||
await self._process_seen_namespace(recipient.namespace)
|
||||
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {content}")
|
||||
|
@ -227,7 +229,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
|
||||
assert explicit_namespace is not None or sender_namespace is not None
|
||||
namespace = cast(str, explicit_namespace or sender_namespace)
|
||||
self._process_seen_namespace(namespace)
|
||||
await self._process_seen_namespace(namespace)
|
||||
|
||||
self._message_queue.append(
|
||||
PublishMessageEnvelope(
|
||||
|
@ -238,17 +240,17 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
)
|
||||
)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent_id in self._instantiated_agents:
|
||||
state[str(agent_id)] = dict(self._get_agent(agent_id).save_state())
|
||||
state[str(agent_id)] = dict((await self._get_agent(agent_id)).save_state())
|
||||
return state
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent_id_str in state:
|
||||
agent_id = AgentId.from_str(agent_id_str)
|
||||
if agent_id.name in self._known_agent_names:
|
||||
self._get_agent(agent_id).load_state(state[str(agent_id)])
|
||||
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
recipient = message_envelope.recipient
|
||||
|
@ -269,7 +271,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
recipient_agent = self._get_agent(recipient)
|
||||
recipient_agent = await self._get_agent(recipient)
|
||||
response = await recipient_agent.on_message(
|
||||
message_envelope.message,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
|
@ -297,7 +299,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name:
|
||||
continue
|
||||
|
||||
sender_agent = self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
||||
sender_agent = (
|
||||
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
||||
)
|
||||
sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
|
@ -312,7 +316,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
# )
|
||||
# )
|
||||
|
||||
agent = self._get_agent(agent_id)
|
||||
agent = await self._get_agent(agent_id)
|
||||
future = agent.on_message(
|
||||
message_envelope.message,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
|
@ -430,19 +434,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
def start(self) -> RunContext:
|
||||
return RunContext(self)
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return self._get_agent(agent).metadata
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return (await self._get_agent(agent)).metadata
|
||||
|
||||
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
return self._get_agent(agent).save_state()
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
return (await self._get_agent(agent)).save_state()
|
||||
|
||||
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
self._get_agent(agent).load_state(state)
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
(await self._get_agent(agent)).load_state(state)
|
||||
|
||||
def register(
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None:
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
|
@ -450,28 +454,30 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
await self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
def _invoke_agent_factory(
|
||||
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
token = agent_instantiation_context.set((self, agent_id))
|
||||
with agent_instantiation_context((self, agent_id)):
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
|
||||
agent_instantiation_context.reset(token)
|
||||
return agent
|
||||
|
||||
return agent
|
||||
|
||||
def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
self._process_seen_namespace(agent_id.namespace)
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
await self._process_seen_namespace(agent_id.namespace)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
|
@ -480,25 +486,25 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
|
||||
agent_factory = self._agent_factories[agent_id.name]
|
||||
|
||||
agent = self._invoke_agent_factory(agent_factory, agent_id)
|
||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
return self._get_agent(AgentId(name=name, namespace=namespace)).id
|
||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
return (await self._get_agent(AgentId(name=name, namespace=namespace))).id
|
||||
|
||||
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
id = self.get(name, namespace=namespace)
|
||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
id = await self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
def _process_seen_namespace(self, namespace: str) -> None:
|
||||
async def _process_seen_namespace(self, namespace: str) -> None:
|
||||
if namespace in self._known_namespaces:
|
||||
return
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
await self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_typ
|
|||
from ..core._agent import Agent
|
||||
from ..core._agent_id import AgentId
|
||||
from ..core._agent_metadata import AgentMetadata
|
||||
from ..core._agent_runtime import AgentRuntime, agent_instantiation_context
|
||||
from ..core._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
|
||||
from ..core._cancellation_token import CancellationToken
|
||||
from ..core._serialization import MESSAGE_TYPE_REGISTRY
|
||||
from ..core.exceptions import CantHandleException
|
||||
|
@ -46,7 +46,7 @@ class ClosureAgent(Agent):
|
|||
self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]]
|
||||
) -> None:
|
||||
try:
|
||||
runtime, id = agent_instantiation_context.get()
|
||||
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
|
|
|
@ -7,7 +7,7 @@ from ._agent_id import AgentId
|
|||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_props import AgentChildren
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime, agent_instantiation_context
|
||||
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime, agent_instantiation_context
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer
|
||||
|
@ -22,6 +22,7 @@ __all__ = [
|
|||
"CancellationToken",
|
||||
"AgentChildren",
|
||||
"agent_instantiation_context",
|
||||
"AGENT_INSTANTIATION_CONTEXT_VAR",
|
||||
"MESSAGE_TYPE_REGISTRY",
|
||||
"TypeSerializer",
|
||||
"TypeDeserializer",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Mapping
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
|
@ -21,7 +21,7 @@ class AgentProxy:
|
|||
return self._agent
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
def metadata(self) -> Awaitable[AgentMetadata]:
|
||||
"""Metadata of the agent."""
|
||||
return self._runtime.agent_metadata(self._agent)
|
||||
|
||||
|
@ -39,14 +39,14 @@ class AgentProxy:
|
|||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
return self._runtime.agent_save_state(self._agent)
|
||||
return await self._runtime.agent_save_state(self._agent)
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
self._runtime.agent_load_state(self._agent, state)
|
||||
await self._runtime.agent_load_state(self._agent, state)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
|
@ -13,7 +14,18 @@ from ._cancellation_token import CancellationToken
|
|||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context")
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar(
|
||||
"AGENT_INSTANTIATION_CONTEXT_VAR"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def agent_instantiation_context(ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]:
|
||||
token = AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AGENT_INSTANTIATION_CONTEXT_VAR.reset(token)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -68,23 +80,23 @@ class AgentRuntime(Protocol):
|
|||
"""
|
||||
|
||||
@overload
|
||||
def register(
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def register(
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None: ...
|
||||
|
||||
def register(
|
||||
async def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
) -> None:
|
||||
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
|
||||
|
||||
|
@ -110,7 +122,7 @@ class AgentRuntime(Protocol):
|
|||
|
||||
...
|
||||
|
||||
def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
async def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
"""Get an agent by name and namespace.
|
||||
|
||||
Args:
|
||||
|
@ -122,7 +134,7 @@ class AgentRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
"""Get a proxy for an agent by name and namespace.
|
||||
|
||||
Args:
|
||||
|
@ -135,27 +147,27 @@ class AgentRuntime(Protocol):
|
|||
...
|
||||
|
||||
@overload
|
||||
def register_and_get(
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
def register_and_get(
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId: ...
|
||||
|
||||
def register_and_get(
|
||||
async def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentId:
|
||||
|
@ -169,31 +181,31 @@ class AgentRuntime(Protocol):
|
|||
Returns:
|
||||
AgentId: The agent id.
|
||||
"""
|
||||
self.register(name, agent_factory)
|
||||
return self.get(name, namespace=namespace)
|
||||
await self.register(name, agent_factory)
|
||||
return await self.get(name, namespace=namespace)
|
||||
|
||||
@overload
|
||||
def register_and_get_proxy(
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy: ...
|
||||
|
||||
@overload
|
||||
def register_and_get_proxy(
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy: ...
|
||||
|
||||
def register_and_get_proxy(
|
||||
async def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
) -> AgentProxy:
|
||||
|
@ -207,10 +219,10 @@ class AgentRuntime(Protocol):
|
|||
Returns:
|
||||
AgentProxy: The agent proxy.
|
||||
"""
|
||||
self.register(name, agent_factory)
|
||||
return self.get_proxy(name, namespace=namespace)
|
||||
await self.register(name, agent_factory)
|
||||
return await self.get_proxy(name, namespace=namespace)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`.
|
||||
|
||||
The structure of the state is implementation defined and can be any JSON serializable object.
|
||||
|
@ -220,7 +232,7 @@ class AgentRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`.
|
||||
|
||||
Args:
|
||||
|
@ -228,7 +240,7 @@ class AgentRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
"""Get the metadata for an agent.
|
||||
|
||||
Args:
|
||||
|
@ -239,7 +251,7 @@ class AgentRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
"""Save the state of a single agent.
|
||||
|
||||
The structure of the state is implementation defined and can be any JSON serializable object.
|
||||
|
@ -252,7 +264,7 @@ class AgentRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of a single agent.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, Mapping, Sequence
|
|||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime, agent_instantiation_context
|
||||
from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ class BaseAgent(ABC, Agent):
|
|||
|
||||
def __init__(self, description: str, subscriptions: Sequence[str]) -> None:
|
||||
try:
|
||||
runtime, id = agent_instantiation_context.get()
|
||||
runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get()
|
||||
except LookupError as e:
|
||||
raise RuntimeError(
|
||||
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
|
||||
|
|
|
@@ -12,6 +12,7 @@ from typing import (
     Any,
     AsyncIterable,
     AsyncIterator,
+    Awaitable,
     Callable,
     ClassVar,
     DefaultDict,

@@ -188,7 +189,9 @@ class WorkerAgentRuntime(AgentRuntime):
         self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
         # (namespace, type) -> List[AgentId]
         self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
-        self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
+        self._agent_factories: Dict[
+            str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
+        ] = {}
         # If empty, then all namespaces are valid for that agent type
         self._valid_namespaces: Dict[str, Sequence[str]] = {}
         self._instantiated_agents: Dict[AgentId, Agent] = {}

@@ -249,7 +252,7 @@ class WorkerAgentRuntime(AgentRuntime):
             (namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
         ]:
             logger.info("Sending message to %s", agent_id)
-            agent = self._get_agent(agent_id)
+            agent = await self._get_agent(agent_id)
             try:
                 await agent.on_message(message, CancellationToken())
                 logger.info("%s handled event %s", agent_id, message)

@@ -321,7 +324,7 @@ class WorkerAgentRuntime(AgentRuntime):

         assert explicit_namespace is not None or sender_namespace is not None
         actual_namespace = cast(str, explicit_namespace or sender_namespace)
-        self._process_seen_namespace(actual_namespace)
+        await self._process_seen_namespace(actual_namespace)
         message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
         serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
         message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))

@@ -332,25 +335,25 @@ class WorkerAgentRuntime(AgentRuntime):

         await asyncio.create_task(write_message())

-    def save_state(self) -> Mapping[str, Any]:
+    async def save_state(self) -> Mapping[str, Any]:
         raise NotImplementedError("Saving state is not yet implemented.")

-    def load_state(self, state: Mapping[str, Any]) -> None:
+    async def load_state(self, state: Mapping[str, Any]) -> None:
         raise NotImplementedError("Loading state is not yet implemented.")

-    def agent_metadata(self, agent: AgentId) -> AgentMetadata:
+    async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
         raise NotImplementedError("Agent metadata is not yet implemented.")

-    def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
+    async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
         raise NotImplementedError("Agent save_state is not yet implemented.")

-    def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
+    async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
         raise NotImplementedError("Agent load_state is not yet implemented.")

-    def register(
+    async def register(
         self,
         name: str,
-        agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
+        agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
     ) -> None:
         if name in self._agent_factories:
             raise ValueError(f"Agent with name {name} already exists.")

@@ -358,29 +361,32 @@ class WorkerAgentRuntime(AgentRuntime):

         # For all already prepared namespaces we need to prepare this agent
         for namespace in self._known_namespaces:
-            self._get_agent(AgentId(name=name, namespace=namespace))
+            await self._get_agent(AgentId(name=name, namespace=namespace))

-        # TODO do we need to convert register to async?
-        asyncio.create_task(self.send_register_agent_type(name))
+        await self.send_register_agent_type(name)

-    def _invoke_agent_factory(
+    async def _invoke_agent_factory(
         self,
-        agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
+        agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
         agent_id: AgentId,
     ) -> T:
-        if len(inspect.signature(agent_factory).parameters) == 0:
-            factory_one = cast(Callable[[], T], agent_factory)
-            agent = factory_one()
-        elif len(inspect.signature(agent_factory).parameters) == 2:
-            factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
-            agent = factory_two(self, agent_id)
-        else:
-            raise ValueError("Agent factory must take 0 or 2 arguments.")
+        with agent_instantiation_context((self, agent_id)):
+            if len(inspect.signature(agent_factory).parameters) == 0:
+                factory_one = cast(Callable[[], T], agent_factory)
+                agent = factory_one()
+            elif len(inspect.signature(agent_factory).parameters) == 2:
+                factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
+                agent = factory_two(self, agent_id)
+            else:
+                raise ValueError("Agent factory must take 0 or 2 arguments.")

+        if inspect.isawaitable(agent):
+            return cast(T, await agent)
+
         return agent

-    def _get_agent(self, agent_id: AgentId) -> Agent:
-        self._process_seen_namespace(agent_id.namespace)
+    async def _get_agent(self, agent_id: AgentId) -> Agent:
+        await self._process_seen_namespace(agent_id.namespace)
         if agent_id in self._instantiated_agents:
             return self._instantiated_agents[agent_id]

@@ -389,9 +395,7 @@ class WorkerAgentRuntime(AgentRuntime):

         agent_factory = self._agent_factories[agent_id.name]

-        token = agent_instantiation_context.set((self, agent_id))
-        agent = self._invoke_agent_factory(agent_factory, agent_id)
-        agent_instantiation_context.reset(token)
+        agent = await self._invoke_agent_factory(agent_factory, agent_id)

         for message_type in agent.metadata["subscriptions"]:
             self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)

@@ -399,19 +403,19 @@ class WorkerAgentRuntime(AgentRuntime):
         self._instantiated_agents[agent_id] = agent
         return agent

-    def get(self, name: str, *, namespace: str = "default") -> AgentId:
-        return self._get_agent(AgentId(name=name, namespace=namespace)).id
+    async def get(self, name: str, *, namespace: str = "default") -> AgentId:
+        return (await self._get_agent(AgentId(name=name, namespace=namespace))).id

-    def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
-        id = self.get(name, namespace=namespace)
+    async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
+        id = await self.get(name, namespace=namespace)
         return AgentProxy(id, self)

     # Hydrate the agent instances in a namespace. The primary reason for this is
     # to ensure message type subscriptions are set up.
-    def _process_seen_namespace(self, namespace: str) -> None:
+    async def _process_seen_namespace(self, namespace: str) -> None:
         if namespace in self._known_namespaces:
             return

         self._known_namespaces.add(namespace)
         for name in self._known_agent_names:
-            self._get_agent(AgentId(name=name, namespace=namespace))
+            await self._get_agent(AgentId(name=name, namespace=namespace))
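The runtime hunks above let an agent factory return either an agent or an awaitable of one, with inspect.isawaitable deciding whether the result still needs to be awaited. A minimal, self-contained sketch of that dispatch follows; the helper and agent names are illustrative stand-ins, not part of the agnext API.

# Illustrative sketch only: invoke_factory and EchoAgent are stand-ins; only the
# isawaitable dispatch mirrors the WorkerAgentRuntime change above.
import asyncio
import inspect
from typing import Awaitable, Callable, TypeVar, Union, cast

T = TypeVar("T")

async def invoke_factory(factory: Callable[[], Union[T, Awaitable[T]]]) -> T:
    produced = factory()
    # A sync factory hands back the agent directly; an async factory hands back
    # a coroutine that still has to be awaited.
    if inspect.isawaitable(produced):
        return cast(T, await produced)
    return cast(T, produced)

class EchoAgent:
    def __init__(self, name: str) -> None:
        self.name = name

async def build_echo_agent() -> EchoAgent:
    await asyncio.sleep(0)  # stands in for awaited setup (loading a model, opening a client, ...)
    return EchoAgent("built asynchronously")

async def main() -> None:
    sync_built = await invoke_factory(lambda: EchoAgent("built synchronously"))
    async_built = await invoke_factory(build_echo_agent)
    print(sync_built.name, async_built.name)

asyncio.run(main())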
@@ -15,19 +15,19 @@ async def main() -> None:
     runtime = SingleThreadedAgentRuntime()

     # Register agents.
-    coder = runtime.register_and_get_proxy(
+    coder = await runtime.register_and_get_proxy(
         "Coder",
         lambda: Coder(model_client=create_completion_client_from_env()),
     )

-    executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
+    executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))

-    user_proxy = runtime.register_and_get_proxy(
+    user_proxy = await runtime.register_and_get_proxy(
         "UserProxy",
         lambda: UserProxy(description="The current user interacting with you."),
     )

-    runtime.register(
+    await runtime.register(
         "orchestrator",
         lambda: LedgerOrchestrator(
             model_client=create_completion_client_from_env(), agents=[coder, executor, user_proxy]
@@ -15,19 +15,19 @@ async def main() -> None:
     runtime = SingleThreadedAgentRuntime()

     # Register agents.
-    coder = runtime.register_and_get_proxy(
+    coder = await runtime.register_and_get_proxy(
         "Coder",
         lambda: Coder(model_client=create_completion_client_from_env()),
     )

-    executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
+    executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))

-    user_proxy = runtime.register_and_get_proxy(
+    user_proxy = await runtime.register_and_get_proxy(
         "UserProxy",
         lambda: UserProxy(),
     )

-    runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
+    await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))

     run_context = runtime.start()
     await runtime.send_message(RequestReplyMessage(), user_proxy.id)
@@ -18,16 +18,16 @@ async def main() -> None:
     client = create_completion_client_from_env()

     # Register agents.
-    file_surfer = runtime.register_and_get_proxy(
+    file_surfer = await runtime.register_and_get_proxy(
         "file_surfer",
         lambda: FileSurfer(model_client=client),
     )
-    user_proxy = runtime.register_and_get_proxy(
+    user_proxy = await runtime.register_and_get_proxy(
         "UserProxy",
         lambda: UserProxy(),
     )

-    runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
+    await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))

     run_context = runtime.start()
     await runtime.send_message(RequestReplyMessage(), user_proxy.id)
@@ -13,10 +13,10 @@ from team_one.utils import LogHandler
 async def main() -> None:
     runtime = SingleThreadedAgentRuntime()

-    fake1 = runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
-    fake2 = runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
-    fake3 = runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
-    runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
+    fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
+    fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
+    fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
+    await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))

     task_message = UserMessage(content="Test Message", source="User")
     run_context = runtime.start()
@@ -19,16 +19,16 @@ async def main() -> None:
     client = create_completion_client_from_env()

     # Register agents.
-    coder = runtime.register_and_get_proxy(
+    coder = await runtime.register_and_get_proxy(
         "Coder",
         lambda: Coder(model_client=client),
     )
-    user_proxy = runtime.register_and_get_proxy(
+    user_proxy = await runtime.register_and_get_proxy(
         "UserProxy",
         lambda: UserProxy(),
     )

-    runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
+    await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))

     run_context = runtime.start()
     await runtime.send_message(RequestReplyMessage(), user_proxy.id)
@@ -21,17 +21,17 @@ async def main() -> None:
     client = create_completion_client_from_env()

     # Register agents.
-    web_surfer = runtime.register_and_get_proxy(
+    web_surfer = await runtime.register_and_get_proxy(
         "WebSurfer",
         lambda: MultimodalWebSurfer(),
     )

-    user_proxy = runtime.register_and_get_proxy(
+    user_proxy = await runtime.register_and_get_proxy(
         "UserProxy",
         lambda: UserProxy(),
     )

-    runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
+    await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))

     run_context = runtime.start()

@@ -32,7 +32,7 @@ dependencies = [
     "youtube-transcript-api",
     "SpeechRecognition",
     "pathvalidate",
-    "playwright"
+    "playwright",
 ]

 [tool.hatch.envs.default]

@@ -45,7 +45,8 @@ dependencies = [
     "aiofiles",
     "types-aiofiles",
     "types-requests",
-    "azure-identity"
+    "types-pillow",
+    "azure-identity",
 ]

 [tool.hatch.envs.default.extra-scripts]

@@ -71,7 +72,13 @@ line-length = 120
 fix = true
 exclude = ["build", "dist", "page_script.js"]
 target-version = "py310"
-include = ["src/**", "examples/*.py"]
+include = [
+    "src/**",
+    "examples/*.py",
+    "../../benchmarks/HumanEval/Templates/TeamOne/scenario.py",
+    "../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py",
+    "../../benchmarks/GAIA/TeamOne/TwoAgents/scenario.py",
+]

 [tool.ruff.format]
 docstring-code-format = true

@@ -81,7 +88,11 @@ select = ["E", "F", "W", "B", "Q", "I", "ASYNC"]
 ignore = ["F401", "E501"]

 [tool.mypy]
-files = ["src", "examples", "tests"]
+files = [
+    "src",
+    "tests",
+    "examples",
+]

 strict = true
 python_version = "3.10"

@@ -100,7 +111,14 @@ disallow_untyped_decorators = true
 disallow_any_unimported = true

 [tool.pyright]
-include = ["src", "tests", "examples"]
+include = [
+    "src",
+    "tests",
+    "examples",
+    "../../benchmarks/HumanEval/Templates/TeamOne/scenario.py",
+    "../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py",
+    "../../benchmarks/GAIA/Templates/TeamOne/scenario.py",
+]
 typeCheckingMode = "strict"
 reportUnnecessaryIsInstance = false
 reportMissingTypeStubs = false
@@ -69,7 +69,7 @@ class BaseOrchestrator(TypeRoutedAgent):
         logger.info(
             OrchestrationEvent(
                 source=f"{self.metadata['name']} (thought)",
-                message=f"Next speaker {next_agent.metadata['name']}" "",
+                message=f"Next speaker {(await next_agent.metadata)['name']}" "",
             )
         )

@@ -68,7 +68,7 @@ def _draw_roi(
     luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11
     text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255)

-    roi = [(rect["left"], rect["top"]), (rect["right"], rect["bottom"])]
+    roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"]))

     label_location = (rect["right"], rect["top"])
     label_anchor = "rb"
@@ -79,16 +79,16 @@ class LedgerOrchestrator(BaseOrchestrator):
     def _get_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
         return self._ledger_prompt.format(task=task, team=team, names=names)

-    def _get_team_description(self) -> str:
+    async def _get_team_description(self) -> str:
         team_description = ""
         for agent in self._agents:
-            name = agent.metadata["name"]
-            description = agent.metadata["description"]
+            name = (await agent.metadata)["name"]
+            description = (await agent.metadata)["description"]
             team_description += f"{name}: {description}\n"
         return team_description

-    def _get_team_names(self) -> List[str]:
-        return [agent.metadata["name"] for agent in self._agents]
+    async def _get_team_names(self) -> List[str]:
+        return [(await agent.metadata)["name"] for agent in self._agents]

     def _set_task_str(self, message: LLMMessage) -> None:
         if len(self._chat_history) == 1:

@@ -112,7 +112,7 @@ class LedgerOrchestrator(BaseOrchestrator):
         return False

     async def _plan(self) -> str:
-        team_description = self._get_team_description()
+        team_description = await self._get_team_description()

         # 1. GATHER FACTS
         # create a closed book task and generate a response and update the chat history

@@ -144,8 +144,8 @@ class LedgerOrchestrator(BaseOrchestrator):
     async def update_ledger(self) -> Dict[str, Any]:
         max_json_retries = 10

-        team_description = self._get_team_description()
-        names = self._get_team_names()
+        team_description = await self._get_team_description()
+        names = await self._get_team_names()
         ledger_prompt = self._get_ledger_prompt(self.task_str, team_description, names)
         ledger_user_message = UserMessage(content=ledger_prompt, source=self.metadata["name"])


@@ -234,7 +234,7 @@ class LedgerOrchestrator(BaseOrchestrator):

         next_agent_name = ledger_dict["next_speaker"]["answer"]
         for agent in self._agents:
-            if agent.metadata["name"] == next_agent_name:
+            if (await agent.metadata)["name"] == next_agent_name:
                 # broadcast a new message
                 instruction = ledger_dict["instruction_or_question"]["answer"]
                 user_message = UserMessage(content=instruction, source=self.metadata["name"])
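In these orchestrator hunks every metadata lookup on an agent proxy becomes (await agent.metadata)[...], since metadata is now awaitable. A small sketch of that access pattern follows, using a stand-in proxy rather than agnext's AgentProxy, and showing one way to gather several lookups concurrently.

# Illustrative sketch only: FakeProxy is a stand-in with an awaitable metadata
# property; it is not the real agnext proxy.
import asyncio
from typing import Dict, List

class FakeProxy:
    def __init__(self, name: str, description: str) -> None:
        self._md = {"name": name, "description": description}

    @property
    async def metadata(self) -> Dict[str, str]:
        # In a real runtime this lookup may cross a process or network boundary,
        # which is why it has to be awaited.
        await asyncio.sleep(0)
        return dict(self._md)

async def team_description(agents: List[FakeProxy]) -> str:
    # Fetch all metadata concurrently instead of awaiting once per loop iteration.
    all_md = await asyncio.gather(*(agent.metadata for agent in agents))
    return "\n".join(f"{md['name']}: {md['description']}" for md in all_md)

async def main() -> None:
    team = [FakeProxy("Coder", "writes code"), FakeProxy("Executor", "runs code")]
    print(await team_description(team))

asyncio.run(main())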
@@ -1,6 +1,6 @@
 import pytest
 from pytest_mock import MockerFixture
-from agnext.core import AgentRuntime, agent_instantiation_context, AgentId
+from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId

 from test_utils import NoopAgent


@@ -11,7 +11,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
     runtime = mocker.Mock(spec=AgentRuntime)

     # Shows how to set the context for the agent instantiation in a test context
-    agent_instantiation_context.set((runtime, AgentId("name", "namespace")))
+    AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace")))

     agent = NoopAgent()
     assert agent.runtime == runtime
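The updated test seeds AGENT_INSTANTIATION_CONTEXT_VAR directly so a BaseAgent can be constructed without a running runtime. When a test sets a ContextVar like this, pairing the set with a reset keeps state from leaking into later tests; a generic sketch of that pattern follows, assuming pytest and a local stand-in ContextVar rather than the agnext one.

# Generic ContextVar set/reset hygiene for tests; MY_CONTEXT is a stand-in for a
# library-provided variable such as AGENT_INSTANTIATION_CONTEXT_VAR.
from contextvars import ContextVar
from typing import Iterator, Tuple

import pytest

MY_CONTEXT: ContextVar[Tuple[str, str]] = ContextVar("MY_CONTEXT")

@pytest.fixture
def seeded_context() -> Iterator[Tuple[str, str]]:
    token = MY_CONTEXT.set(("runtime", "agent-id"))
    try:
        yield MY_CONTEXT.get()
    finally:
        # Undo the set so later tests see the variable unset again.
        MY_CONTEXT.reset(token)

def test_reads_seeded_value(seeded_context: Tuple[str, str]) -> None:
    assert MY_CONTEXT.get() == ("runtime", "agent-id")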
@@ -57,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
 async def test_cancellation_with_token() -> None:
     runtime = SingleThreadedAgentRuntime()

-    long_running = runtime.register_and_get("long_running", LongRunningAgent)
+    long_running = await runtime.register_and_get("long_running", LongRunningAgent)
     token = CancellationToken()
     response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token))
     assert not response.done()

@@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None:
     await response

     assert response.done()
-    long_running_agent: LongRunningAgent = runtime._get_agent(long_running)  # type: ignore
+    long_running_agent: LongRunningAgent = await runtime._get_agent(long_running)  # type: ignore
     assert long_running_agent.called
     assert long_running_agent.cancelled


@@ -83,8 +83,8 @@ async def test_cancellation_with_token() -> None:
 async def test_nested_cancellation_only_outer_called() -> None:
     runtime = SingleThreadedAgentRuntime()

-    long_running = runtime.register_and_get("long_running", LongRunningAgent)
-    nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
+    long_running = await runtime.register_and_get("long_running", LongRunningAgent)
+    nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))

     token = CancellationToken()
     response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))

@@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
     await response

     assert response.done()
-    nested_agent: NestingLongRunningAgent = runtime._get_agent(nested)  # type: ignore
+    nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested)  # type: ignore
     assert nested_agent.called
     assert nested_agent.cancelled
-    long_running_agent: LongRunningAgent = runtime._get_agent(long_running)  # type: ignore
+    long_running_agent: LongRunningAgent = await runtime._get_agent(long_running)  # type: ignore
     assert long_running_agent.called is False
     assert long_running_agent.cancelled is False


@@ -111,8 +111,8 @@ async def test_nested_cancellation_only_outer_called() -> None:
 async def test_nested_cancellation_inner_called() -> None:
     runtime = SingleThreadedAgentRuntime()

-    long_running = runtime.register_and_get("long_running", LongRunningAgent )
-    nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
+    long_running = await runtime.register_and_get("long_running", LongRunningAgent )
+    nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))

     token = CancellationToken()
     response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token))

@@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None:
     await response

     assert response.done()
-    nested_agent: NestingLongRunningAgent = runtime._get_agent(nested)  # type: ignore
+    nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested)  # type: ignore
     assert nested_agent.called
     assert nested_agent.cancelled
-    long_running_agent: LongRunningAgent = runtime._get_agent(long_running)  # type: ignore
+    long_running_agent: LongRunningAgent = await runtime._get_agent(long_running)  # type: ignore
     assert long_running_agent.called
     assert long_running_agent.cancelled
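These cancellation tests wrap runtime.send_message in asyncio.create_task so the in-flight call can be cancelled through a CancellationToken and the agent's flags inspected afterwards. A bare-asyncio sketch of the same start/cancel/inspect shape follows, using task.cancel() as a stand-in for the token; no agnext types are involved.

# A plain-asyncio version of the "start a task, cancel it, then assert on what the
# callee observed" structure used in the tests above. SlowWorker is illustrative.
import asyncio

class SlowWorker:
    def __init__(self) -> None:
        self.called = False
        self.cancelled = False

    async def run(self) -> str:
        self.called = True
        try:
            await asyncio.sleep(60)  # stands in for a long-running message handler
            return "done"
        except asyncio.CancelledError:
            self.cancelled = True
            raise

async def main() -> None:
    worker = SlowWorker()
    task = asyncio.create_task(worker.run())
    await asyncio.sleep(0)  # let the task start running
    assert not task.done()

    task.cancel()  # analogous to token.cancel() in the tests
    try:
        await task
    except asyncio.CancelledError:
        pass

    assert worker.called and worker.cancelled

asyncio.run(main())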
@@ -28,7 +28,7 @@ async def test_register_receives_publish() -> None:
         namespace = id.namespace
         await queue.put((namespace, message.content))

-    runtime.register("name", lambda: ClosureAgent("My agent", log_message))
+    await runtime.register("name", lambda: ClosureAgent("My agent", log_message))
     run_context = runtime.start()
     await runtime.publish_message(Message("first message"), namespace="default")
     await runtime.publish_message(Message("second message"), namespace="default")
@@ -19,7 +19,7 @@ async def test_intervention_count_messages() -> None:

     handler = DebugInterventionHandler()
     runtime = SingleThreadedAgentRuntime(intervention_handler=handler)
-    loopback = runtime.register_and_get("name", LoopbackAgent)
+    loopback = await runtime.register_and_get("name", LoopbackAgent)
     run_context = runtime.start()

     _response = await runtime.send_message(MessageType(), recipient=loopback)

@@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None:
     await run_context.stop()

     assert handler.num_messages == 1
-    loopback_agent: LoopbackAgent = runtime._get_agent(loopback)  # type: ignore
+    loopback_agent: LoopbackAgent = await runtime._get_agent(loopback)  # type: ignore
     assert loopback_agent.num_calls == 1

 @pytest.mark.asyncio

@@ -40,7 +40,7 @@ async def test_intervention_drop_send() -> None:
     handler = DropSendInterventionHandler()
     runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

-    loopback = runtime.register_and_get("name", LoopbackAgent)
+    loopback = await runtime.register_and_get("name", LoopbackAgent)
     run_context = runtime.start()

     with pytest.raises(MessageDroppedException):

@@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None:

     await run_context.stop()

-    loopback_agent: LoopbackAgent = runtime._get_agent(loopback)  # type: ignore
+    loopback_agent: LoopbackAgent = await runtime._get_agent(loopback)  # type: ignore
     assert loopback_agent.num_calls == 0


@@ -62,7 +62,7 @@ async def test_intervention_drop_response() -> None:
     handler = DropResponseInterventionHandler()
     runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

-    loopback = runtime.register_and_get("name", LoopbackAgent)
+    loopback = await runtime.register_and_get("name", LoopbackAgent)
     run_context = runtime.start()

     with pytest.raises(MessageDroppedException):

@@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None:
     handler = ExceptionInterventionHandler()
     runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

-    long_running = runtime.register_and_get("name", LoopbackAgent)
+    long_running = await runtime.register_and_get("name", LoopbackAgent)
     run_context = runtime.start()

     with pytest.raises(InterventionException):

@@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None:

     await run_context.stop()

-    long_running_agent: LoopbackAgent = runtime._get_agent(long_running)  # type: ignore
+    long_running_agent: LoopbackAgent = await runtime._get_agent(long_running)  # type: ignore
     assert long_running_agent.num_calls == 0

 @pytest.mark.asyncio

@@ -108,12 +108,12 @@ async def test_intervention_raise_exception_on_respond() -> None:
     handler = ExceptionInterventionHandler()
     runtime = SingleThreadedAgentRuntime(intervention_handler=handler)

-    long_running = runtime.register_and_get("name", LoopbackAgent)
+    long_running = await runtime.register_and_get("name", LoopbackAgent)
     run_context = runtime.start()
     with pytest.raises(InterventionException):
         _response = await runtime.send_message(MessageType(), recipient=long_running)

     await run_context.stop()

-    long_running_agent: LoopbackAgent = runtime._get_agent(long_running)  # type: ignore
+    long_running_agent: LoopbackAgent = await runtime._get_agent(long_running)  # type: ignore
     assert long_running_agent.num_calls == 1
@@ -14,31 +14,31 @@ async def test_agent_names_must_be_unique() -> None:
         assert agent.id == id
         return agent

-    agent1 = runtime.register_and_get("name1", agent_factory)
+    agent1 = await runtime.register_and_get("name1", agent_factory)
     assert agent1 == AgentId("name1", "default")

     with pytest.raises(ValueError):
-        _agent1 = runtime.register_and_get("name1", NoopAgent)
+        _agent1 = await runtime.register_and_get("name1", NoopAgent)

-    _agent1 = runtime.register_and_get("name3", NoopAgent)
+    _agent1 = await runtime.register_and_get("name3", NoopAgent)


 @pytest.mark.asyncio
 async def test_register_receives_publish() -> None:
     runtime = SingleThreadedAgentRuntime()

-    runtime.register("name", LoopbackAgent)
+    await runtime.register("name", LoopbackAgent)
     run_context = runtime.start()
     await runtime.publish_message(MessageType(), namespace="default")

     await run_context.stop_when_idle()

     # Agent in default namespace should have received the message
-    long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name"))  # type: ignore
+    long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name"))  # type: ignore
     assert long_running_agent.num_calls == 1

     # Agent in other namespace should not have received the message
-    other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other"))  # type: ignore
+    other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other"))  # type: ignore
     assert other_long_running_agent.num_calls == 0


@@ -54,7 +54,7 @@ async def test_register_receives_publish_cascade() -> None:

     # Register agents
     for i in range(num_agents):
-        runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))
+        await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds))

     run_context = runtime.start()


@@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None:

     # Check that each agent received the correct number of messages.
     for i in range(num_agents):
-        agent: CascadingAgent = runtime._get_agent(runtime.get(f"name{i}"))  # type: ignore
+        agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}"))  # type: ignore
         assert agent.num_calls == total_num_calls_expected
@@ -5,8 +5,8 @@ from agnext.application import SingleThreadedAgentRuntime
 from agnext.core import BaseAgent, CancellationToken


-class StatefulAgent(BaseAgent):  # type: ignore
-    def __init__(self) -> None:  # type: ignore
+class StatefulAgent(BaseAgent):
+    def __init__(self) -> None:
         super().__init__("A stateful agent", [])
         self.state = 0


@@ -14,7 +14,7 @@ class StatefulAgent(BaseAgent): # type: ignore
     def subscriptions(self) -> Sequence[type]:
         return []

-    async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:  # type: ignore
+    async def on_message(self, message: Any, cancellation_token: CancellationToken) -> None:
         raise NotImplementedError

     def save_state(self) -> Mapping[str, Any]:

@@ -28,8 +28,8 @@ class StatefulAgent(BaseAgent): # type: ignore
 async def test_agent_can_save_state() -> None:
     runtime = SingleThreadedAgentRuntime()

-    agent1_id = runtime.register_and_get("name1", StatefulAgent)
-    agent1: StatefulAgent = runtime._get_agent(agent1_id)  # type: ignore
+    agent1_id = await runtime.register_and_get("name1", StatefulAgent)
+    agent1: StatefulAgent = await runtime._get_agent(agent1_id)  # type: ignore
     assert agent1.state == 0
     agent1.state = 1
     assert agent1.state == 1

@@ -46,19 +46,19 @@ async def test_agent_can_save_state() -> None:
 async def test_runtime_can_save_state() -> None:
     runtime = SingleThreadedAgentRuntime()

-    agent1_id = runtime.register_and_get("name1", StatefulAgent)
-    agent1: StatefulAgent = runtime._get_agent(agent1_id)  # type: ignore
+    agent1_id = await runtime.register_and_get("name1", StatefulAgent)
+    agent1: StatefulAgent = await runtime._get_agent(agent1_id)  # type: ignore
     assert agent1.state == 0
     agent1.state = 1
     assert agent1.state == 1

-    runtime_state = runtime.save_state()
+    runtime_state = await runtime.save_state()

     runtime2 = SingleThreadedAgentRuntime()
-    agent2_id = runtime2.register_and_get("name1", StatefulAgent)
-    agent2: StatefulAgent = runtime2._get_agent(agent2_id)  # type: ignore
+    agent2_id = await runtime2.register_and_get("name1", StatefulAgent)
+    agent2: StatefulAgent = await runtime2._get_agent(agent2_id)  # type: ignore

-    runtime2.load_state(runtime_state)
+    await runtime2.load_state(runtime_state)
     assert agent2.state == 1

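With save_state and load_state now coroutines on the runtime, the state tests await a full snapshot/restore round trip. A toy sketch of that round trip follows, assuming a simple dict-of-dicts snapshot format rather than the real runtime's; ToyRuntime and ToyAgent are illustrative only.

# Toy classes mirroring the awaited save/load round trip exercised by the tests
# above; this is not the real runtime's state format.
import asyncio
from typing import Any, Dict, Mapping

class ToyAgent:
    def __init__(self) -> None:
        self.state = 0

class ToyRuntime:
    def __init__(self) -> None:
        self.agents: Dict[str, ToyAgent] = {}

    async def save_state(self) -> Mapping[str, Any]:
        # One entry per registered agent; async so persistence could be awaited later.
        return {name: {"state": agent.state} for name, agent in self.agents.items()}

    async def load_state(self, state: Mapping[str, Any]) -> None:
        for name, saved in state.items():
            self.agents.setdefault(name, ToyAgent()).state = saved["state"]

async def main() -> None:
    first = ToyRuntime()
    first.agents["name1"] = ToyAgent()
    first.agents["name1"].state = 1

    snapshot = await first.save_state()

    second = ToyRuntime()
    second.agents["name1"] = ToyAgent()
    await second.load_state(snapshot)
    assert second.agents["name1"].state == 1

asyncio.run(main())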