mirror of https://github.com/microsoft/autogen.git
* create model context component, remove chat memory component, refactor samples #454 * Fix bugs in samples. * Fix * Update docs * add unit tests
This commit is contained in:
parent
82bb342fb3
commit
976a7d4d77
|
@ -28,7 +28,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -51,7 +51,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -85,7 +85,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -93,20 +93,24 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"Streamed responses:\n",
|
||||
"In the heart of the Whispering Woods lived Ember, a small dragon with scales of shimmering gold. Unlike other dragons, Ember breathed not fire but music, each note a whisper of ancient songs. The villagers, initially fearful, soon realized her gift brought harmony and joy.\n",
|
||||
"In a secluded valley where the sun painted the sky with hues of gold, a solitary dragon named Bremora stood guard. Her emerald scales shimmered with an ancient light as she watched over the village below. Unlike her fiery kin, Bremora had no desire for destruction; her soul was bound by a promise to protect.\n",
|
||||
"\n",
|
||||
"One night, as darkness threatened the land, Ember's melodies summoned the stars, casting a protective glow. The villagers danced beneath the celestial orchestra, their worries dissolving like morning mist.\n",
|
||||
"Generations ago, a wise elder had befriended Bremora, offering her companionship instead of fear. In gratitude, she vowed to shield the village from calamity. Years passed, and children grew up believing in the legends of a watchful dragon who brought them prosperity and peace.\n",
|
||||
"\n",
|
||||
"From that day on, Ember's song became a nightly ritual, a promise that light and harmony would always prevail. The dragon of the Whispering Woods was a symbol of peace, her golden scales a testament to the magic of gentleness.\n",
|
||||
"One summer, an ominous storm threatened the valley, with ravenous winds and torrents of rain. Bremora rose into the tempest, her mighty wings defying the chaos. She channeled her breath—not of fire, but of warmth and tranquility—calming the storm and saving her cherished valley.\n",
|
||||
"\n",
|
||||
"When dawn broke and the village emerged unscathed, the people looked to the sky. There, Bremora soared gracefully, a guardian spirit woven into their lives, silently promising her eternal vigilance.\n",
|
||||
"\n",
|
||||
"------------\n",
|
||||
"\n",
|
||||
"The complete response:\n",
|
||||
"In the heart of the Whispering Woods lived Ember, a small dragon with scales of shimmering gold. Unlike other dragons, Ember breathed not fire but music, each note a whisper of ancient songs. The villagers, initially fearful, soon realized her gift brought harmony and joy.\n",
|
||||
"In a secluded valley where the sun painted the sky with hues of gold, a solitary dragon named Bremora stood guard. Her emerald scales shimmered with an ancient light as she watched over the village below. Unlike her fiery kin, Bremora had no desire for destruction; her soul was bound by a promise to protect.\n",
|
||||
"\n",
|
||||
"One night, as darkness threatened the land, Ember's melodies summoned the stars, casting a protective glow. The villagers danced beneath the celestial orchestra, their worries dissolving like morning mist.\n",
|
||||
"Generations ago, a wise elder had befriended Bremora, offering her companionship instead of fear. In gratitude, she vowed to shield the village from calamity. Years passed, and children grew up believing in the legends of a watchful dragon who brought them prosperity and peace.\n",
|
||||
"\n",
|
||||
"From that day on, Ember's song became a nightly ritual, a promise that light and harmony would always prevail. The dragon of the Whispering Woods was a symbol of peace, her golden scales a testament to the magic of gentleness.\n"
|
||||
"One summer, an ominous storm threatened the valley, with ravenous winds and torrents of rain. Bremora rose into the tempest, her mighty wings defying the chaos. She channeled her breath—not of fire, but of warmth and tranquility—calming the storm and saving her cherished valley.\n",
|
||||
"\n",
|
||||
"When dawn broke and the village emerged unscathed, the people looked to the sky. There, Bremora soared gracefully, a guardian spirit woven into their lives, silently promising her eternal vigilance.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -260,7 +264,7 @@
|
|||
"metadata": {},
|
||||
"source": [
|
||||
"The `SimpleAgent` class is a subclass of the\n",
|
||||
"{py:class}`autogen_core.components.TypeRoutedAgent` class for the convenience of automatically routing messages to the appropriate handlers.\n",
|
||||
"{py:class}`autogen_core.components.RoutedAgent` class for the convenience of automatically routing messages to the appropriate handlers.\n",
|
||||
"It has a single handler, `handle_user_message`, which handles message from the user. It uses the `ChatCompletionClient` to generate a response to the message.\n",
|
||||
"It then returns the response to the user, following the direct communication model.\n",
|
||||
"\n",
|
||||
|
@ -273,46 +277,46 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Seattle offers a wide range of activities and attractions for visitors. Here are some fun things to do in the city:\n",
|
||||
"Seattle is a vibrant city with a wide range of activities and attractions. Here are some fun things to do in Seattle:\n",
|
||||
"\n",
|
||||
"1. **Space Needle**: Visit this iconic landmark for stunning panoramic views of the city and surrounding mountains.\n",
|
||||
"1. **Space Needle**: Visit this iconic observation tower for stunning views of the city and surrounding mountains.\n",
|
||||
"\n",
|
||||
"2. **Pike Place Market**: Explore this historic market where you can shop for fresh produce, local crafts, and enjoy delicious street food. Don't miss the famous fish-throwing!\n",
|
||||
"2. **Pike Place Market**: Explore this historic market where you can see the famous fish toss, buy local produce, and find unique crafts and eateries.\n",
|
||||
"\n",
|
||||
"3. **Chihuly Garden and Glass**: Admire the breathtaking glass art installations by artist Dale Chihuly, both indoors and in the beautiful outdoor garden.\n",
|
||||
"3. **Museum of Pop Culture (MoPOP)**: Dive into the world of contemporary culture, music, and science fiction at this interactive museum.\n",
|
||||
"\n",
|
||||
"4. **Museum of Pop Culture (MoPOP)**: Discover exhibits focused on music, science fiction, and pop culture, including artifacts from famous films and music legends.\n",
|
||||
"4. **Chihuly Garden and Glass**: Marvel at the beautiful glass art installations by artist Dale Chihuly, located right next to the Space Needle.\n",
|
||||
"\n",
|
||||
"5. **Seattle Aquarium**: Learn about marine life native to the Pacific Northwest and see fascinating exhibits, including sea otters and jellyfish.\n",
|
||||
"5. **Seattle Aquarium**: Discover the diverse marine life of the Pacific Northwest at this engaging aquarium.\n",
|
||||
"\n",
|
||||
"6. **Fremont Troll**: Take a photo with this quirky public art installation, a large troll sculpture located under the Aurora Bridge.\n",
|
||||
"6. **Seattle Art Museum**: Explore a vast collection of art from around the world, including contemporary and indigenous art.\n",
|
||||
"\n",
|
||||
"7. **Kerry Park**: Enjoy one of the best viewpoints of Seattle's skyline, especially at sunset or during the evening when the city lights up.\n",
|
||||
"7. **Kerry Park**: For one of the best views of the Seattle skyline, head to this small park on Queen Anne Hill.\n",
|
||||
"\n",
|
||||
"8. **Discovery Park**: Explore this large urban park with trails, beaches, and beautiful views of Puget Sound and the Olympic Mountains.\n",
|
||||
"8. **Ballard Locks**: Watch boats pass through the locks and observe the salmon ladder to see salmon migrating.\n",
|
||||
"\n",
|
||||
"9. **Seattle Art Museum**: Browse a diverse collection of art from around the world, including contemporary and Native American art.\n",
|
||||
"9. **Ferry to Bainbridge Island**: Take a scenic ferry ride across Puget Sound to enjoy charming shops, restaurants, and beautiful natural scenery.\n",
|
||||
"\n",
|
||||
"10. **Take a Ferry Ride**: Enjoy a scenic boat ride to nearby islands like Bainbridge Island or Vashon Island. The views of the Seattle skyline from the water are stunning.\n",
|
||||
"10. **Olympic Sculpture Park**: Stroll through this outdoor park with large-scale sculptures and stunning views of the waterfront and mountains.\n",
|
||||
"\n",
|
||||
"11. **Underground Tour**: Learn about Seattle's history on a guided tour of the underground passageways that played a significant role in the city’s development.\n",
|
||||
"11. **Underground Tour**: Discover Seattle's history on this quirky tour of the city's underground passageways in Pioneer Square.\n",
|
||||
"\n",
|
||||
"12. **Ballard Locks**: Visit the Hiram M. Chittenden Locks to see boats pass between Lake Washington and Puget Sound and watch salmon swim upstream in the fish ladder (seasonal).\n",
|
||||
"12. **Seattle Waterfront**: Enjoy the shops, restaurants, and attractions along the waterfront, including the Seattle Great Wheel and the aquarium.\n",
|
||||
"\n",
|
||||
"13. **Local Breweries**: Seattle is known for its craft beer scene; take a brewery tour or visit a taproom to sample local brews.\n",
|
||||
"13. **Discovery Park**: Explore the largest green space in Seattle, featuring trails, beaches, and views of Puget Sound.\n",
|
||||
"\n",
|
||||
"14. **Attend a Sports Game**: Catch a Seattle Seahawks (NFL), Seattle Mariners (MLB), or Seattle Sounders (MLS) game, depending on the season.\n",
|
||||
"14. **Food Tours**: Try out Seattle’s diverse culinary scene, including fresh seafood, international cuisines, and coffee culture (don’t miss the original Starbucks!).\n",
|
||||
"\n",
|
||||
"15. **Seattle Great Wheel**: Ride this Ferris wheel on the waterfront for beautiful views, especially at night when it’s illuminated.\n",
|
||||
"15. **Attend a Sports Game**: Catch a Seahawks (NFL), Mariners (MLB), or Sounders (MLS) game for a lively local experience.\n",
|
||||
"\n",
|
||||
"These activities showcase Seattle’s vibrant culture, unique attractions, and stunning natural beauty. Enjoy your visit!\n"
|
||||
"Whether you're interested in culture, nature, food, or history, Seattle has something for everyone to enjoy!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -322,7 +326,7 @@
|
|||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"await runtime.register(\n",
|
||||
" \"simple-agent\",\n",
|
||||
" \"simple_agent\",\n",
|
||||
" lambda: SimpleAgent(\n",
|
||||
" OpenAIChatCompletionClient(\n",
|
||||
" model=\"gpt-4o-mini\",\n",
|
||||
|
@ -334,11 +338,162 @@
|
|||
"runtime.start()\n",
|
||||
"# Send a message to the agent and get the response.\n",
|
||||
"message = Message(\"Hello, what are some fun things to do in Seattle?\")\n",
|
||||
"response = await runtime.send_message(message, AgentId(\"simple-agent\", \"default\"))\n",
|
||||
"response = await runtime.send_message(message, AgentId(\"simple_agent\", \"default\"))\n",
|
||||
"print(response.content)\n",
|
||||
"# Stop the runtime processing messages.\n",
|
||||
"await runtime.stop()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Manage Model Context\n",
|
||||
"\n",
|
||||
"The above `SimpleAgent` always responds with a fresh context that contains only\n",
|
||||
"the system message and the latest user's message.\n",
|
||||
"We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
|
||||
"to make the agent \"remember\" previous conversations.\n",
|
||||
"A model context supports storage and retrieval of Chat Completion messages.\n",
|
||||
"It is always used together with a model client to generate LLM-based responses.\n",
|
||||
"\n",
|
||||
"For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
|
||||
"is a most-recent-used (MRU) context that stores the most recent `buffer_size`\n",
|
||||
"number of messages. This is useful to avoid context overflow in many LLMs.\n",
|
||||
"\n",
|
||||
"Let's update the previous example to use\n",
|
||||
"{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
|
||||
"from autogen_core.components.models import AssistantMessage\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleAgentWithContext(RoutedAgent):\n",
|
||||
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
|
||||
" super().__init__(\"A simple agent\")\n",
|
||||
" self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
|
||||
" self._model_client = model_client\n",
|
||||
" self._model_context = BufferedChatCompletionContext(buffer_size=5)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
|
||||
" # Prepare input to the chat completion model.\n",
|
||||
" user_message = UserMessage(content=message.content, source=\"user\")\n",
|
||||
" # Add message to model context.\n",
|
||||
" await self._model_context.add_message(user_message)\n",
|
||||
" # Generate a response.\n",
|
||||
" response = await self._model_client.create(\n",
|
||||
" self._system_messages + (await self._model_context.get_messages()),\n",
|
||||
" cancellation_token=ctx.cancellation_token,\n",
|
||||
" )\n",
|
||||
" # Return with the model's response.\n",
|
||||
" assert isinstance(response.content, str)\n",
|
||||
" # Add message to model context.\n",
|
||||
" await self._model_context.add_message(AssistantMessage(content=response.content, source=self.metadata[\"type\"]))\n",
|
||||
" return Message(content=response.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now let's try to ask follow up questions after the first one."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Question: Hello, what are some fun things to do in Seattle?\n",
|
||||
"Response: Seattle offers a wide variety of fun activities and attractions for visitors. Here are some highlights:\n",
|
||||
"\n",
|
||||
"1. **Pike Place Market**: Explore this iconic market, where you can find fresh produce, unique crafts, and the famous fish-throwing vendors. Don’t forget to visit the original Starbucks!\n",
|
||||
"\n",
|
||||
"2. **Space Needle**: Enjoy breathtaking views of the city and Mount Rainier from the observation deck of this iconic structure. You can also dine at the SkyCity restaurant.\n",
|
||||
"\n",
|
||||
"3. **Chihuly Garden and Glass**: Admire the stunning glass art installations created by artist Dale Chihuly. The garden and exhibit are particularly beautiful, especially in good weather.\n",
|
||||
"\n",
|
||||
"4. **Museum of Pop Culture (MoPOP)**: Dive into the world of music, science fiction, and pop culture through interactive exhibits and memorabilia.\n",
|
||||
"\n",
|
||||
"5. **Seattle Aquarium**: Located on the waterfront, the aquarium features a variety of marine life native to the Pacific Northwest, including otters and diving birds.\n",
|
||||
"\n",
|
||||
"6. **Seattle Art Museum (SAM)**: Explore a diverse collection of art from around the world, including Native American art and contemporary pieces.\n",
|
||||
"\n",
|
||||
"7. **Ballard Locks**: Watch boats travel between the Puget Sound and Lake Union, and see salmon navigating the fish ladder during spawning season.\n",
|
||||
"\n",
|
||||
"8. **Fremont Troll**: Visit this quirky public art installation located under the Aurora Bridge, where you can take fun photos with the giant troll.\n",
|
||||
"\n",
|
||||
"9. **Kerry Park**: For a picturesque view of the Seattle skyline, head to Kerry Park on Queen Anne Hill, especially at sunset.\n",
|
||||
"\n",
|
||||
"10. **Take a Ferry Ride**: Enjoy the scenic views while taking a ferry to nearby Bainbridge Island or Vashon Island for a relaxing day trip.\n",
|
||||
"\n",
|
||||
"11. **Underground Tour**: Explore Seattle’s history on an entertaining underground tour in Pioneer Square, where you’ll learn about the city’s early days.\n",
|
||||
"\n",
|
||||
"12. **Attend a Sporting Event**: Depending on the season, catch a Seattle Seahawks (NFL) game, a Seattle Mariners (MLB) game, or a Seattle Sounders (MLS) match.\n",
|
||||
"\n",
|
||||
"13. **Explore Discovery Park**: Enjoy nature with hiking trails, beach access, and stunning views of the Puget Sound and Olympic Mountains.\n",
|
||||
"\n",
|
||||
"14. **West Seattle’s Alki Beach**: Relax at this beach with beautiful views of the Seattle skyline and enjoy beachside activities like biking or kayaking.\n",
|
||||
"\n",
|
||||
"15. **Dining and Craft Beer**: Seattle has a vibrant food scene and is known for its seafood, coffee culture, and craft breweries. Make sure to explore local restaurants and breweries.\n",
|
||||
"\n",
|
||||
"There’s something for everyone in Seattle, whether you’re interested in nature, art, history, or food!\n",
|
||||
"-----\n",
|
||||
"Question: What was the first thing you mentioned?\n",
|
||||
"Response: The first thing I mentioned was **Pike Place Market**, an iconic market in Seattle where you can find fresh produce, unique crafts, and experience the famous fish-throwing vendors. It's also home to the original Starbucks and various charming shops and eateries.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"await runtime.register(\n",
|
||||
" \"simple_agent_context\",\n",
|
||||
" lambda: SimpleAgentWithContext(\n",
|
||||
" OpenAIChatCompletionClient(\n",
|
||||
" model=\"gpt-4o-mini\",\n",
|
||||
" # api_key=\"sk-...\", # Optional if you have an OPENAI_API_KEY set in the environment.\n",
|
||||
" )\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"# Start the runtime processing messages.\n",
|
||||
"runtime.start()\n",
|
||||
"agent_id = AgentId(\"simple_agent_context\", \"default\")\n",
|
||||
"\n",
|
||||
"# First question.\n",
|
||||
"message = Message(\"Hello, what are some fun things to do in Seattle?\")\n",
|
||||
"print(f\"Question: {message.content}\")\n",
|
||||
"response = await runtime.send_message(message, agent_id)\n",
|
||||
"print(f\"Response: {response.content}\")\n",
|
||||
"print(\"-----\")\n",
|
||||
"\n",
|
||||
"# Second question.\n",
|
||||
"message = Message(\"What was the first thing you mentioned?\")\n",
|
||||
"print(f\"Question: {message.content}\")\n",
|
||||
"response = await runtime.send_message(message, agent_id)\n",
|
||||
"print(f\"Response: {response.content}\")\n",
|
||||
"\n",
|
||||
"# Stop the runtime processing messages.\n",
|
||||
"await runtime.stop()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"From the second response, you can see the agent now can recall its own previous responses."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -53,7 +53,6 @@ from `Agent and Multi-Agent Application <core-concepts/agent-and-multi-agent-app
|
|||
cookbook/type-routed-agent
|
||||
cookbook/azure-openai-with-aad-auth
|
||||
cookbook/termination-with-intervention
|
||||
cookbook/buffered-memory
|
||||
cookbook/extracting-results-with-an-agent
|
||||
cookbook/openai-assistant-agent
|
||||
cookbook/langgraph-agent
|
||||
|
|
|
@ -9,12 +9,14 @@ from autogen_core.components import (
|
|||
RoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.model_context import ChatCompletionContext
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import Tool
|
||||
|
||||
|
@ -30,7 +32,6 @@ from ..types import (
|
|||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
)
|
||||
from ..utils import convert_messages_to_llm_messages
|
||||
|
||||
|
||||
class ChatCompletionAgent(RoutedAgent):
|
||||
|
@ -41,7 +42,8 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
description (str): The description of the agent.
|
||||
system_messages (List[SystemMessage]): The system messages to use for
|
||||
the ChatCompletion API.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
model_context (ChatCompletionContext): The context manager for storing
|
||||
and retrieving ChatCompletion messages.
|
||||
model_client (ChatCompletionClient): The client to use for the
|
||||
ChatCompletion API.
|
||||
tools (Sequence[Tool], optional): The tools used by the agent. Defaults
|
||||
|
@ -61,7 +63,7 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
memory: ChatMemory[Message],
|
||||
model_context: ChatCompletionContext,
|
||||
model_client: ChatCompletionClient,
|
||||
tools: Sequence[Tool] = [],
|
||||
tool_approver: AgentId | None = None,
|
||||
|
@ -70,7 +72,7 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
self._description = description
|
||||
self._system_messages = system_messages
|
||||
self._client = model_client
|
||||
self._memory = memory
|
||||
self._model_context = model_context
|
||||
self._tools = tools
|
||||
self._tool_approver = tool_approver
|
||||
|
||||
|
@ -79,20 +81,20 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
"""Handle a text message. This method adds the message to the memory and
|
||||
does not generate any message."""
|
||||
# Add a user message.
|
||||
await self._memory.add_message(message)
|
||||
await self._model_context.add_message(UserMessage(content=message.content, source=message.source))
|
||||
|
||||
@message_handler()
|
||||
async def on_multi_modal_message(self, message: MultiModalMessage, ctx: MessageContext) -> None:
|
||||
"""Handle a multimodal message. This method adds the message to the memory
|
||||
and does not generate any message."""
|
||||
# Add a user message.
|
||||
await self._memory.add_message(message)
|
||||
await self._model_context.add_message(UserMessage(content=message.content, source=message.source))
|
||||
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, ctx: MessageContext) -> None:
|
||||
"""Handle a reset message. This method clears the memory."""
|
||||
# Reset the chat messages.
|
||||
await self._memory.clear()
|
||||
await self._model_context.clear()
|
||||
|
||||
@message_handler()
|
||||
async def on_respond_now(self, message: RespondNow, ctx: MessageContext) -> TextMessage | FunctionCallMessage:
|
||||
|
@ -123,9 +125,6 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
if len(self._tools) == 0:
|
||||
raise ValueError("No tools available")
|
||||
|
||||
# Add a tool call message.
|
||||
await self._memory.add_message(message)
|
||||
|
||||
# Execute the tool calls.
|
||||
results: List[FunctionExecutionResult] = []
|
||||
execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
|
||||
|
@ -160,9 +159,6 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
# Create a tool call result message.
|
||||
tool_call_result_msg = FunctionExecutionResultMessage(content=results)
|
||||
|
||||
# Add tool call result message.
|
||||
await self._memory.add_message(tool_call_result_msg)
|
||||
|
||||
# Return the results.
|
||||
return tool_call_result_msg
|
||||
|
||||
|
@ -172,12 +168,13 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
ctx: MessageContext,
|
||||
) -> TextMessage | FunctionCallMessage:
|
||||
# Get a response from the model.
|
||||
hisorical_messages = await self._memory.get_messages()
|
||||
response = await self._client.create(
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["type"]),
|
||||
self._system_messages + (await self._model_context.get_messages()),
|
||||
tools=self._tools,
|
||||
json_output=response_format == ResponseFormat.json_object,
|
||||
)
|
||||
# Add the response to the chat messages context.
|
||||
await self._model_context.add_message(AssistantMessage(content=response.content, source=self.metadata["type"]))
|
||||
|
||||
# If the agent has function executor, and the response is a list of
|
||||
# tool calls, iterate with itself until we get a response that is not a
|
||||
|
@ -193,13 +190,18 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
recipient=self.id,
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
if not isinstance(response, FunctionExecutionResultMessage):
|
||||
raise RuntimeError(f"Expect FunctionExecutionResultMessage but got {response}.")
|
||||
await self._model_context.add_message(response)
|
||||
# Make an assistant message from the response.
|
||||
hisorical_messages = await self._memory.get_messages()
|
||||
response = await self._client.create(
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["type"]),
|
||||
self._system_messages + (await self._model_context.get_messages()),
|
||||
tools=self._tools,
|
||||
json_output=response_format == ResponseFormat.json_object,
|
||||
)
|
||||
await self._model_context.add_message(
|
||||
AssistantMessage(content=response.content, source=self.metadata["type"])
|
||||
)
|
||||
|
||||
final_response: Message
|
||||
if isinstance(response.content, str):
|
||||
|
@ -211,9 +213,6 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
else:
|
||||
raise ValueError(f"Unexpected response: {response.content}")
|
||||
|
||||
# Add the response to the chat messages.
|
||||
await self._memory.add_message(final_response)
|
||||
|
||||
return final_response
|
||||
|
||||
async def _execute_function(
|
||||
|
@ -253,10 +252,10 @@ class ChatCompletionAgent(RoutedAgent):
|
|||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._memory.save_state(),
|
||||
"memory": self._model_context.save_state(),
|
||||
"system_messages": self._system_messages,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state(state["memory"])
|
||||
self._model_context.load_state(state["memory"])
|
||||
self._system_messages = state["system_messages"]
|
||||
|
|
|
@ -8,7 +8,8 @@ from autogen_core.components import (
|
|||
RoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.model_context import ChatCompletionContext
|
||||
from autogen_core.components.models import AssistantMessage, UserMessage
|
||||
|
||||
from ..types import (
|
||||
Message,
|
||||
|
@ -25,7 +26,8 @@ class ImageGenerationAgent(RoutedAgent):
|
|||
|
||||
Args:
|
||||
description (str): The description of the agent.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
model_context (ChatCompletionContext): The context manager for storing
|
||||
and retrieving ChatCompletion messages.
|
||||
client (openai.AsyncClient): The client to use for the OpenAI API.
|
||||
model (Literal["dall-e-2", "dall-e-3"], optional): The DALL-E model to use. Defaults to "dall-e-2".
|
||||
"""
|
||||
|
@ -33,23 +35,23 @@ class ImageGenerationAgent(RoutedAgent):
|
|||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
memory: ChatMemory[Message],
|
||||
model_context: ChatCompletionContext,
|
||||
client: openai.AsyncClient,
|
||||
model: Literal["dall-e-2", "dall-e-3"] = "dall-e-2",
|
||||
):
|
||||
super().__init__(description)
|
||||
self._client = client
|
||||
self._model = model
|
||||
self._memory = memory
|
||||
self._model_context = model_context
|
||||
|
||||
@message_handler
|
||||
async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None:
|
||||
"""Handle a text message. This method adds the message to the memory."""
|
||||
await self._memory.add_message(message)
|
||||
await self._model_context.add_message(UserMessage(content=message.content, source=message.source))
|
||||
|
||||
@message_handler
|
||||
async def on_reset(self, message: Reset, ctx: MessageContext) -> None:
|
||||
await self._memory.clear()
|
||||
await self._model_context.clear()
|
||||
|
||||
@message_handler
|
||||
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
|
||||
|
@ -61,14 +63,14 @@ class ImageGenerationAgent(RoutedAgent):
|
|||
await self.publish_message(response, topic_id=DefaultTopicId())
|
||||
|
||||
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
|
||||
messages = await self._memory.get_messages()
|
||||
messages = await self._model_context.get_messages()
|
||||
if len(messages) == 0:
|
||||
return MultiModalMessage(
|
||||
content=["I need more information to generate an image."], source=self.metadata["type"]
|
||||
)
|
||||
prompt = ""
|
||||
for m in messages:
|
||||
assert isinstance(m, TextMessage)
|
||||
assert isinstance(m.content, str)
|
||||
prompt += m.content + "\n"
|
||||
prompt.strip()
|
||||
response = await self._client.images.generate(model=self._model, prompt=prompt, response_format="b64_json")
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from ._buffered import BufferedChatMemory
|
||||
from ._head_and_tail import HeadAndTailChatMemory
|
||||
|
||||
__all__ = ["BufferedChatMemory", "HeadAndTailChatMemory"]
|
|
@ -1,47 +0,0 @@
|
|||
from typing import Any, List, Mapping
|
||||
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.models import FunctionExecutionResultMessage
|
||||
|
||||
from ..types import Message
|
||||
|
||||
|
||||
class BufferedChatMemory(ChatMemory[Message]):
|
||||
"""A buffered chat memory that keeps a view of the last n messages,
|
||||
where n is the buffer size. The buffer size is set at initialization.
|
||||
|
||||
Args:
|
||||
buffer_size (int): The size of the buffer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int, initial_messages: List[Message] | None = None) -> None:
|
||||
self._messages: List[Message] = initial_messages or []
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
async def add_message(self, message: Message) -> None:
|
||||
"""Add a message to the memory."""
|
||||
self._messages.append(message)
|
||||
|
||||
async def get_messages(self) -> List[Message]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
messages = self._messages[-self._buffer_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the message memory."""
|
||||
self._messages = []
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"messages": [message for message in self._messages],
|
||||
"buffer_size": self._buffer_size,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = state["messages"]
|
||||
self._buffer_size = state["buffer_size"]
|
|
@ -1,4 +1,3 @@
|
|||
from ._group_chat_manager import GroupChatManager
|
||||
from ._orchestrator_chat import OrchestratorChat
|
||||
|
||||
__all__ = ["GroupChatManager", "OrchestratorChat"]
|
||||
__all__ = ["GroupChatManager"]
|
||||
|
|
|
@ -3,8 +3,8 @@ from typing import Any, Callable, List, Mapping
|
|||
|
||||
from autogen_core.base import AgentId, AgentProxy, MessageContext
|
||||
from autogen_core.components import RoutedAgent, message_handler
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.models import ChatCompletionClient
|
||||
from autogen_core.components.model_context import ChatCompletionContext
|
||||
from autogen_core.components.models import ChatCompletionClient, UserMessage
|
||||
|
||||
from ..types import (
|
||||
Message,
|
||||
|
@ -26,7 +26,8 @@ class GroupChatManager(RoutedAgent):
|
|||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
participants (List[AgentId]): The list of participants in the group chat.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
model_context (ChatCompletionContext): The context manager for storing
|
||||
and retrieving ChatCompletion messages.
|
||||
model_client (ChatCompletionClient, optional): The client to use for the model.
|
||||
If provided, the agent will use the model to select the next speaker.
|
||||
If not provided, the agent will select the next speaker from the list of participants
|
||||
|
@ -45,14 +46,14 @@ class GroupChatManager(RoutedAgent):
|
|||
self,
|
||||
description: str,
|
||||
participants: List[AgentId],
|
||||
memory: ChatMemory[Message],
|
||||
model_context: ChatCompletionContext,
|
||||
model_client: ChatCompletionClient | None = None,
|
||||
termination_word: str = "TERMINATE",
|
||||
transitions: Mapping[AgentId, List[AgentId]] = {},
|
||||
on_message_received: Callable[[TextMessage | MultiModalMessage], None] | None = None,
|
||||
):
|
||||
super().__init__(description)
|
||||
self._memory = memory
|
||||
self._model_context = model_context
|
||||
self._client = model_client
|
||||
self._participants = participants
|
||||
self._participant_proxies = dict((p, AgentProxy(p, self.runtime)) for p in participants)
|
||||
|
@ -78,7 +79,7 @@ class GroupChatManager(RoutedAgent):
|
|||
@message_handler()
|
||||
async def on_reset(self, message: Reset, ctx: MessageContext) -> None:
|
||||
"""Handle a reset message. This method clears the memory."""
|
||||
await self._memory.clear()
|
||||
await self._model_context.clear()
|
||||
|
||||
@message_handler()
|
||||
async def on_new_message(self, message: TextMessage | MultiModalMessage, ctx: MessageContext) -> None:
|
||||
|
@ -94,7 +95,7 @@ class GroupChatManager(RoutedAgent):
|
|||
return
|
||||
|
||||
# Save the message to chat memory.
|
||||
await self._memory.add_message(message)
|
||||
await self._model_context.add_message(UserMessage(content=message.content, source=message.source))
|
||||
|
||||
# Get the last speaker.
|
||||
last_speaker_name = message.source
|
||||
|
@ -132,7 +133,7 @@ class GroupChatManager(RoutedAgent):
|
|||
else:
|
||||
# If a model client is provided, select the speaker based on the transitions and the model.
|
||||
speaker_index = await select_speaker(
|
||||
self._memory, self._client, [self._participant_proxies[c] for c in candidates]
|
||||
self._model_context, self._client, [self._participant_proxies[c] for c in candidates]
|
||||
)
|
||||
speaker = candidates[speaker_index]
|
||||
|
||||
|
@ -144,10 +145,10 @@ class GroupChatManager(RoutedAgent):
|
|||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._memory.save_state(),
|
||||
"chat_history": self._model_context.save_state(),
|
||||
"termination_word": self._termination_word,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state(state["memory"])
|
||||
self._model_context.load_state(state["chat_history"])
|
||||
self._termination_word = state["termination_word"]
|
||||
|
|
|
@ -4,20 +4,18 @@ import re
|
|||
from typing import Dict, List
|
||||
|
||||
from autogen_core.base import AgentProxy
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
||||
|
||||
from ..types import Message, TextMessage
|
||||
from autogen_core.components.model_context import ChatCompletionContext
|
||||
from autogen_core.components.models import ChatCompletionClient, SystemMessage, UserMessage
|
||||
|
||||
|
||||
async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
|
||||
async def select_speaker(context: ChatCompletionContext, client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client."""
|
||||
# TODO: Handle multi-modal messages.
|
||||
|
||||
# Construct formated current message history.
|
||||
history_messages: List[str] = []
|
||||
for msg in await memory.get_messages():
|
||||
assert isinstance(msg, TextMessage)
|
||||
for msg in await context.get_messages():
|
||||
assert isinstance(msg, UserMessage) and isinstance(msg.content, str)
|
||||
history_messages.append(f"{msg.source}: {msg.content}")
|
||||
history = "\n".join(history_messages)
|
||||
|
||||
|
|
|
@ -1,406 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Sequence, Tuple
|
||||
|
||||
from autogen_core.base import AgentId, AgentRuntime, MessageContext
|
||||
from autogen_core.components import RoutedAgent, message_handler
|
||||
|
||||
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
|
||||
|
||||
__all__ = ["OrchestratorChat"]
|
||||
|
||||
|
||||
class OrchestratorChat(RoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
runtime: AgentRuntime,
|
||||
orchestrator: AgentId,
|
||||
planner: AgentId,
|
||||
specialists: Sequence[AgentId],
|
||||
max_turns: int = 30,
|
||||
max_stalled_turns_before_retry: int = 2,
|
||||
max_retry_attempts: int = 1,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._orchestrator = orchestrator
|
||||
self._planner = planner
|
||||
self._specialists = specialists
|
||||
self._max_turns = max_turns
|
||||
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
|
||||
self._max_retry_attempts_before_educated_guess = max_retry_attempts
|
||||
|
||||
@property
|
||||
def children(self) -> Sequence[AgentId]:
|
||||
return list(self._specialists) + [self._orchestrator, self._planner]
|
||||
|
||||
@message_handler()
|
||||
async def on_text_message(
|
||||
self,
|
||||
message: TextMessage,
|
||||
ctx: MessageContext,
|
||||
) -> TextMessage:
|
||||
# A task is received.
|
||||
task = message.content
|
||||
|
||||
# Prepare the task.
|
||||
team, names, facts, plan = await self._prepare_task(task, message.source)
|
||||
|
||||
# Main loop.
|
||||
total_turns = 0
|
||||
retry_attempts = 0
|
||||
while total_turns < self._max_turns:
|
||||
# Reset all agents.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
await (await self.send_message(Reset(), agent))
|
||||
|
||||
# Create the task specs.
|
||||
task_specs = f"""
|
||||
We are working to address the following user request:
|
||||
|
||||
{task}
|
||||
|
||||
|
||||
To answer this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
Some additional points to consider:
|
||||
|
||||
{facts}
|
||||
|
||||
{plan}
|
||||
""".strip()
|
||||
|
||||
# Send the task specs to the orchestrator and specialists.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
await (await self.send_message(TextMessage(content=task_specs, source=self.metadata["type"]), agent))
|
||||
|
||||
# Inner loop.
|
||||
stalled_turns = 0
|
||||
while total_turns < self._max_turns:
|
||||
# Reflect on the task.
|
||||
data = await self._reflect_on_task(task, team, names, message.source)
|
||||
|
||||
# Check if the request is satisfied.
|
||||
if data["is_request_satisfied"]["answer"]:
|
||||
return TextMessage(
|
||||
content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}",
|
||||
source=self.metadata["type"],
|
||||
)
|
||||
|
||||
# Update stalled turns.
|
||||
if data["is_progress_being_made"]["answer"]:
|
||||
stalled_turns = max(0, stalled_turns - 1)
|
||||
else:
|
||||
stalled_turns += 1
|
||||
|
||||
# Handle retry.
|
||||
if stalled_turns > self._max_stalled_turns_before_retry:
|
||||
# In a retry, we need to rewrite the facts and the plan.
|
||||
|
||||
# Rewrite the facts.
|
||||
facts = await self._rewrite_facts(facts, message.source)
|
||||
|
||||
# Increment the retry attempts.
|
||||
retry_attempts += 1
|
||||
|
||||
# Check if we should just guess.
|
||||
if retry_attempts > self._max_retry_attempts_before_educated_guess:
|
||||
# Make an educated guess.
|
||||
educated_guess = await self._educated_guess(facts, message.source)
|
||||
if educated_guess["has_educated_guesses"]["answer"]:
|
||||
return TextMessage(
|
||||
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
|
||||
source=self.metadata["type"],
|
||||
)
|
||||
|
||||
# Come up with a new plan.
|
||||
plan = await self._rewrite_plan(team, message.source)
|
||||
|
||||
# Exit the inner loop.
|
||||
break
|
||||
|
||||
# Get the subtask.
|
||||
subtask = data["instruction_or_question"]["answer"]
|
||||
if subtask is None:
|
||||
subtask = ""
|
||||
|
||||
# Update agents.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
_ = await (
|
||||
await self.send_message(
|
||||
TextMessage(content=subtask, source=self.metadata["type"]),
|
||||
agent,
|
||||
)
|
||||
)
|
||||
|
||||
# Find the speaker.
|
||||
try:
|
||||
speaker = next(agent for agent in self._specialists if agent.type == data["next_speaker"]["answer"])
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
|
||||
|
||||
# Ask speaker to speak.
|
||||
speaker_response = await (await self.send_message(RespondNow(), speaker))
|
||||
assert speaker_response is not None
|
||||
|
||||
# Update all other agents with the speaker's response.
|
||||
for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
|
||||
await (
|
||||
await self.send_message(
|
||||
TextMessage(
|
||||
content=speaker_response.content,
|
||||
source=speaker_response.source,
|
||||
),
|
||||
agent,
|
||||
)
|
||||
)
|
||||
|
||||
# Increment the total turns.
|
||||
total_turns += 1
|
||||
|
||||
return TextMessage(
|
||||
content="The task was not addressed. The maximum number of turns was reached.",
|
||||
source=self.metadata["type"],
|
||||
)
|
||||
|
||||
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
|
||||
# Reset planner.
|
||||
await (await self.send_message(Reset(), self._planner))
|
||||
|
||||
# A reusable description of the team.
|
||||
team = "\n".join(
|
||||
[
|
||||
agent.type + ": " + (await self.runtime.agent_metadata(agent))["description"]
|
||||
for agent in self._specialists
|
||||
]
|
||||
)
|
||||
names = ", ".join([agent.type for agent in self._specialists])
|
||||
|
||||
# A place to store relevant facts.
|
||||
facts = ""
|
||||
|
||||
# A plance to store the plan.
|
||||
plan = ""
|
||||
|
||||
# Start by writing what we know
|
||||
closed_book_prompt = f"""Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.
|
||||
|
||||
Here is the request:
|
||||
|
||||
{task}
|
||||
|
||||
Here is the pre-survey:
|
||||
|
||||
1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
|
||||
2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
|
||||
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
|
||||
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
|
||||
|
||||
When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings:
|
||||
|
||||
1. GIVEN OR VERIFIED FACTS
|
||||
2. FACTS TO LOOK UP
|
||||
3. FACTS TO DERIVE
|
||||
4. EDUCATED GUESSES
|
||||
""".strip()
|
||||
|
||||
# Ask the planner to obtain prior knowledge about facts.
|
||||
await (await self.send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner))
|
||||
facts_response = await (await self.send_message(RespondNow(), self._planner))
|
||||
|
||||
facts = str(facts_response.content)
|
||||
|
||||
# Make an initial plan
|
||||
plan_prompt = f"""Fantastic. To address this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip()
|
||||
|
||||
# Send second messag eto the planner.
|
||||
await self.send_message(TextMessage(content=plan_prompt, source=sender), self._planner)
|
||||
plan_response = await (await self.send_message(RespondNow(), self._planner))
|
||||
plan = str(plan_response.content)
|
||||
|
||||
return team, names, facts, plan
|
||||
|
||||
async def _reflect_on_task(
|
||||
self,
|
||||
task: str,
|
||||
team: str,
|
||||
names: str,
|
||||
sender: str,
|
||||
) -> Any:
|
||||
step_prompt = f"""
|
||||
Recall we are working on the following request:
|
||||
|
||||
{task}
|
||||
|
||||
And we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
To make progress on the request, please answer the following questions, including necessary reasoning:
|
||||
|
||||
- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY addressed)
|
||||
- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a reasoning or action loop, or there is evidence of significant barriers to success such as the inability to read from a required file)
|
||||
- Who should speak next? (select from: {names})
|
||||
- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)
|
||||
|
||||
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
|
||||
|
||||
{{
|
||||
"is_request_satisfied": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"is_progress_being_made": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"next_speaker": {{
|
||||
"reason": string,
|
||||
"answer": string (select from: {names})
|
||||
}},
|
||||
"instruction_or_question": {{
|
||||
"reason": string,
|
||||
"answer": string
|
||||
}}
|
||||
}}
|
||||
""".strip()
|
||||
request = step_prompt
|
||||
while True:
|
||||
# Send a message to the orchestrator.
|
||||
await (await self.send_message(TextMessage(content=request, source=sender), self._orchestrator))
|
||||
# Request a response.
|
||||
step_response = await (
|
||||
await self.send_message(
|
||||
RespondNow(response_format=ResponseFormat.json_object),
|
||||
self._orchestrator,
|
||||
)
|
||||
)
|
||||
# TODO: use typed dictionary.
|
||||
try:
|
||||
result = json.loads(str(step_response.content))
|
||||
except json.JSONDecodeError as e:
|
||||
request = f"Invalid JSON: {str(e)}"
|
||||
continue
|
||||
if "is_request_satisfied" not in result:
|
||||
request = "Missing key: is_request_satisfied"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["is_request_satisfied"], dict)
|
||||
or "answer" not in result["is_request_satisfied"]
|
||||
or "reason" not in result["is_request_satisfied"]
|
||||
):
|
||||
request = "Invalid value for key: is_request_satisfied, expected 'answer' and 'reason'"
|
||||
continue
|
||||
if "is_progress_being_made" not in result:
|
||||
request = "Missing key: is_progress_being_made"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["is_progress_being_made"], dict)
|
||||
or "answer" not in result["is_progress_being_made"]
|
||||
or "reason" not in result["is_progress_being_made"]
|
||||
):
|
||||
request = "Invalid value for key: is_progress_being_made, expected 'answer' and 'reason'"
|
||||
continue
|
||||
if "next_speaker" not in result:
|
||||
request = "Missing key: next_speaker"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["next_speaker"], dict)
|
||||
or "answer" not in result["next_speaker"]
|
||||
or "reason" not in result["next_speaker"]
|
||||
):
|
||||
request = "Invalid value for key: next_speaker, expected 'answer' and 'reason'"
|
||||
continue
|
||||
elif result["next_speaker"]["answer"] not in names:
|
||||
request = f"Invalid value for key: next_speaker, expected 'answer' in {names}"
|
||||
continue
|
||||
if "instruction_or_question" not in result:
|
||||
request = "Missing key: instruction_or_question"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["instruction_or_question"], dict)
|
||||
or "answer" not in result["instruction_or_question"]
|
||||
or "reason" not in result["instruction_or_question"]
|
||||
):
|
||||
request = "Invalid value for key: instruction_or_question, expected 'answer' and 'reason'"
|
||||
continue
|
||||
return result
|
||||
|
||||
async def _rewrite_facts(self, facts: str, sender: str) -> str:
|
||||
new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).
|
||||
|
||||
{facts}
|
||||
""".strip()
|
||||
# Send a message to the orchestrator.
|
||||
await (await self.send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator))
|
||||
# Request a response.
|
||||
new_facts_response = await (await self.send_message(RespondNow(), self._orchestrator))
|
||||
return str(new_facts_response.content)
|
||||
|
||||
async def _educated_guess(self, facts: str, sender: str) -> Any:
|
||||
# Make an educated guess.
|
||||
educated_guess_promt = f"""Given the following information
|
||||
|
||||
{facts}
|
||||
|
||||
Please answer the following question, including necessary reasoning:
|
||||
- Do you have two or more congruent pieces of information that will allow you to make an educated guess for the original request? The educated guess MUST answer the question.
|
||||
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
|
||||
|
||||
{{
|
||||
"has_educated_guesses": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}}
|
||||
}}
|
||||
""".strip()
|
||||
request = educated_guess_promt
|
||||
while True:
|
||||
# Send a message to the orchestrator.
|
||||
await (
|
||||
await self.send_message(
|
||||
TextMessage(content=request, source=sender),
|
||||
self._orchestrator,
|
||||
)
|
||||
)
|
||||
# Request a response.
|
||||
response = await (
|
||||
await self.send_message(
|
||||
RespondNow(response_format=ResponseFormat.json_object),
|
||||
self._orchestrator,
|
||||
)
|
||||
)
|
||||
try:
|
||||
result = json.loads(str(response.content))
|
||||
except json.JSONDecodeError as e:
|
||||
request = f"Invalid JSON: {str(e)}"
|
||||
continue
|
||||
# TODO: use typed dictionary.
|
||||
if "has_educated_guesses" not in result:
|
||||
request = "Missing key: has_educated_guesses"
|
||||
continue
|
||||
if (
|
||||
not isinstance(result["has_educated_guesses"], dict)
|
||||
or "answer" not in result["has_educated_guesses"]
|
||||
or "reason" not in result["has_educated_guesses"]
|
||||
):
|
||||
request = "Invalid value for key: has_educated_guesses, expected 'answer' and 'reason'"
|
||||
continue
|
||||
return result
|
||||
|
||||
async def _rewrite_plan(self, team: str, sender: str) -> str:
|
||||
new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.
|
||||
|
||||
Team membership:
|
||||
{team}
|
||||
""".strip()
|
||||
# Send a message to the orchestrator.
|
||||
await (await self.send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator))
|
||||
# Request a response.
|
||||
new_plan_response = await (await self.send_message(RespondNow(), self._orchestrator))
|
||||
return str(new_plan_response.content)
|
|
@ -13,7 +13,8 @@ import aiofiles
|
|||
import openai
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, AgentRuntime, MessageContext
|
||||
from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext
|
||||
from openai import AsyncAssistantEventHandler
|
||||
from openai.types.beta.thread import ToolResources
|
||||
from openai.types.beta.threads import Message, Text, TextDelta
|
||||
|
@ -24,7 +25,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
|||
|
||||
from autogen_core.base import AgentInstantiationContext
|
||||
from common.agents import OpenAIAssistantAgent
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
from common.types import PublishNow, TextMessage
|
||||
|
||||
|
@ -189,6 +189,7 @@ async def assistant_chat(runtime: AgentRuntime) -> str:
|
|||
thread_id=thread.id,
|
||||
assistant_event_handler_factory=lambda: EventHandler(),
|
||||
),
|
||||
lambda: [DefaultSubscription()],
|
||||
)
|
||||
|
||||
await runtime.register(
|
||||
|
@ -199,18 +200,20 @@ async def assistant_chat(runtime: AgentRuntime) -> str:
|
|||
thread_id=thread.id,
|
||||
vector_store_id=vector_store.id,
|
||||
),
|
||||
lambda: [DefaultSubscription()],
|
||||
)
|
||||
# Create a group chat manager to facilitate a turn-based conversation.
|
||||
await runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A group chat manager.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_context=BufferedChatCompletionContext(buffer_size=10),
|
||||
participants=[
|
||||
AgentId("Assistant", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("User", AgentInstantiationContext.current_agent_id().key),
|
||||
],
|
||||
),
|
||||
lambda: [DefaultSubscription()],
|
||||
)
|
||||
return "User"
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing import Annotated, Literal
|
|||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentInstantiationContext, AgentRuntime
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.components.models import SystemMessage
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
|
||||
|
@ -21,7 +22,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
|||
|
||||
from autogen_core.base import AgentId
|
||||
from common.agents._chat_completion_agent import ChatCompletionAgent
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.patterns._group_chat_manager import GroupChatManager
|
||||
from common.types import TextMessage
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
@ -170,7 +170,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
"Think about your strategy and call make_move(thinking, move) to make a move."
|
||||
),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_context=BufferedChatCompletionContext(buffer_size=10),
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o"),
|
||||
tools=black_tools,
|
||||
),
|
||||
|
@ -188,7 +188,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
"Think about your strategy and call make_move(thinking, move) to make a move."
|
||||
),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_context=BufferedChatCompletionContext(buffer_size=10),
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o"),
|
||||
tools=white_tools,
|
||||
),
|
||||
|
@ -200,12 +200,13 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
|||
"ChessGame",
|
||||
lambda: GroupChatManager(
|
||||
description="A chess game between two agents.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_context=BufferedChatCompletionContext(buffer_size=10),
|
||||
participants=[
|
||||
AgentId("PlayerWhite", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("PlayerBlack", AgentInstantiationContext.current_agent_id().key),
|
||||
], # white goes first
|
||||
),
|
||||
lambda: [DefaultSubscription()],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -24,23 +24,22 @@ slow external system that the agent needs to interact with.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Mapping, Optional
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, CancellationToken, MessageContext
|
||||
from autogen_core.base.intervention import DefaultInterventionHandler
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
from autogen_core.components import FunctionCall
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, FunctionCall, RoutedAgent, message_handler
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
@ -49,7 +48,6 @@ from pydantic import BaseModel, Field
|
|||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from common.memory import BufferedChatMemory
|
||||
from common.types import TextMessage
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
|
@ -95,24 +93,24 @@ class SlowUserProxyAgent(RoutedAgent):
|
|||
description: str,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._memory = BufferedChatMemory(buffer_size=5)
|
||||
self._model_context = BufferedChatCompletionContext(buffer_size=5)
|
||||
self._name = name
|
||||
|
||||
@message_handler
|
||||
async def handle_message(self, message: AssistantTextMessage, ctx: MessageContext) -> None:
|
||||
await self._memory.add_message(message)
|
||||
await self._model_context.add_message(AssistantMessage(content=message.content, source=message.source))
|
||||
await self.publish_message(
|
||||
GetSlowUserMessage(content=message.content), topic_id=DefaultTopicId("scheduling_assistant_conversation")
|
||||
)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
state_to_save = {
|
||||
"memory": self._memory.save_state(),
|
||||
"memory": self._model_context.save_state(),
|
||||
}
|
||||
return state_to_save
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
|
||||
|
||||
class ScheduleMeetingInput(BaseModel):
|
||||
|
@ -148,8 +146,11 @@ class SchedulingAssistantAgent(RoutedAgent):
|
|||
initial_message: AssistantTextMessage | None = None,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._memory = BufferedChatMemory(
|
||||
buffer_size=5, initial_messages=[initial_message] if initial_message else None
|
||||
self._model_context = BufferedChatCompletionContext(
|
||||
buffer_size=5,
|
||||
initial_messages=[UserMessage(content=initial_message.content, source=initial_message.source)]
|
||||
if initial_message
|
||||
else None,
|
||||
)
|
||||
self._name = name
|
||||
self._model_client = model_client
|
||||
|
@ -164,20 +165,12 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
|
|||
|
||||
@message_handler
|
||||
async def handle_message(self, message: UserTextMessage, ctx: MessageContext) -> None:
|
||||
await self._memory.add_message(message)
|
||||
llm_messages: List[LLMMessage] = []
|
||||
memory_messages = await self._memory.get_messages()
|
||||
for m in memory_messages:
|
||||
assert isinstance(m, TextMessage), f"Expected TextMessage, but got {
|
||||
type(m)}"
|
||||
if m.source == self.metadata["type"]:
|
||||
llm_messages.append(AssistantMessage(content=m.content, source=self.metadata["type"]))
|
||||
else:
|
||||
llm_messages.append(UserMessage(content=m.content, source=m.source))
|
||||
llm_messages.append(UserMessage(content=message.content, source=message.source))
|
||||
await self._model_context.add_message(UserMessage(content=message.content, source=message.source))
|
||||
|
||||
tools = [ScheduleMeetingTool()]
|
||||
response = await self._model_client.create(self._system_messages + llm_messages, tools=tools)
|
||||
response = await self._model_client.create(
|
||||
self._system_messages + (await self._model_context.get_messages()), tools=tools
|
||||
)
|
||||
|
||||
if isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
|
||||
for call in response.content:
|
||||
|
@ -194,17 +187,17 @@ Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
|
|||
|
||||
assert isinstance(response.content, str)
|
||||
speech = AssistantTextMessage(content=response.content, source=self.metadata["type"])
|
||||
await self._memory.add_message(speech)
|
||||
await self._model_context.add_message(AssistantMessage(content=response.content, source=self.metadata["type"]))
|
||||
|
||||
await self.publish_message(speech, topic_id=DefaultTopicId("scheduling_assistant_conversation"))
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._memory.save_state(),
|
||||
"memory": self._model_context.save_state(),
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
|
||||
|
||||
|
||||
class NeedsUserInputHandler(DefaultInterventionHandler):
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
from ._base import ChatMemory
|
||||
|
||||
__all__ = ["ChatMemory"]
|
|
@ -1,19 +0,0 @@
|
|||
from typing import List, Mapping, Protocol, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ChatMemory(Protocol[T]):
|
||||
"""A protocol for defining the interface of a chat memory. A chat memory
|
||||
lets agents store and retrieve messages. It can be implemented with
|
||||
different memory recall strategies."""
|
||||
|
||||
async def add_message(self, message: T) -> None: ...
|
||||
|
||||
async def get_messages(self) -> List[T]: ...
|
||||
|
||||
async def clear(self) -> None: ...
|
||||
|
||||
def save_state(self) -> Mapping[str, T]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, T]) -> None: ...
|
|
@ -0,0 +1,5 @@
|
|||
from ._buffered_chat_completion_context import BufferedChatCompletionContext
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
|
||||
|
||||
__all__ = ["ChatCompletionContext", "BufferedChatCompletionContext", "HeadAndTailChatCompletionContext"]
|
|
@ -1,17 +1,11 @@
|
|||
# Buffered Memory
|
||||
|
||||
Here is an example of a custom memory implementation that keeps a view of the
|
||||
last N messages:
|
||||
|
||||
```python
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.models import FunctionExecutionResultMessage, LLMMessage
|
||||
from ..models import FunctionExecutionResultMessage, LLMMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class BufferedChatMemory(ChatMemory[LLMMessage]):
|
||||
"""A buffered chat memory that keeps a view of the last n messages,
|
||||
class BufferedChatCompletionContext(ChatCompletionContext):
|
||||
"""A buffered chat completion context that keeps a view of the last n messages,
|
||||
where n is the buffer size. The buffer size is set at initialization.
|
||||
|
||||
Args:
|
||||
|
@ -19,8 +13,8 @@ class BufferedChatMemory(ChatMemory[LLMMessage]):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int) -> None:
|
||||
self._messages: List[LLMMessage] = []
|
||||
def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
|
||||
self._messages: List[LLMMessage] = initial_messages or []
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
|
@ -0,0 +1,19 @@
|
|||
from typing import List, Mapping, Protocol
|
||||
|
||||
from ..models import LLMMessage
|
||||
|
||||
|
||||
class ChatCompletionContext(Protocol):
|
||||
"""A protocol for defining the interface of a chat completion context.
|
||||
A chat completion context lets agents store and retrieve LLM messages.
|
||||
It can be implemented with different recall strategies."""
|
||||
|
||||
async def add_message(self, message: LLMMessage) -> None: ...
|
||||
|
||||
async def get_messages(self) -> List[LLMMessage]: ...
|
||||
|
||||
async def clear(self) -> None: ...
|
||||
|
||||
def save_state(self) -> Mapping[str, LLMMessage]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, LLMMessage]) -> None: ...
|
|
@ -1,13 +1,12 @@
|
|||
from typing import Any, List, Mapping
|
||||
|
||||
from autogen_core.components.memory import ChatMemory
|
||||
from autogen_core.components.models import FunctionExecutionResultMessage
|
||||
|
||||
from ..types import FunctionCallMessage, Message, TextMessage
|
||||
from .._types import FunctionCall
|
||||
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
|
||||
from ._chat_completion_context import ChatCompletionContext
|
||||
|
||||
|
||||
class HeadAndTailChatMemory(ChatMemory[Message]):
|
||||
"""A chat memory that keeps a view of the first n and last m messages,
|
||||
class HeadAndTailChatCompletionContext(ChatCompletionContext):
|
||||
"""A chat completion context that keeps a view of the first n and last m messages,
|
||||
where n is the head size and m is the tail size. The head and tail sizes
|
||||
are set at initialization.
|
||||
|
||||
|
@ -17,19 +16,24 @@ class HeadAndTailChatMemory(ChatMemory[Message]):
|
|||
"""
|
||||
|
||||
def __init__(self, head_size: int, tail_size: int) -> None:
|
||||
self._messages: List[Message] = []
|
||||
self._messages: List[LLMMessage] = []
|
||||
self._head_size = head_size
|
||||
self._tail_size = tail_size
|
||||
|
||||
async def add_message(self, message: Message) -> None:
|
||||
async def add_message(self, message: LLMMessage) -> None:
|
||||
"""Add a message to the memory."""
|
||||
self._messages.append(message)
|
||||
|
||||
async def get_messages(self) -> List[Message]:
|
||||
async def get_messages(self) -> List[LLMMessage]:
|
||||
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
|
||||
head_messages = self._messages[: self._head_size]
|
||||
# Handle the last message is a function call message.
|
||||
if head_messages and isinstance(head_messages[-1], FunctionCallMessage):
|
||||
if (
|
||||
head_messages
|
||||
and isinstance(head_messages[-1], AssistantMessage)
|
||||
and isinstance(head_messages[-1].content, list)
|
||||
and all(isinstance(item, FunctionCall) for item in head_messages[-1].content)
|
||||
):
|
||||
# Remove the last message from the head.
|
||||
head_messages = head_messages[:-1]
|
||||
|
||||
|
@ -45,7 +49,7 @@ class HeadAndTailChatMemory(ChatMemory[Message]):
|
|||
# return all messages.
|
||||
return self._messages
|
||||
|
||||
placeholder_messages = [TextMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
||||
placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
||||
return head_messages + placeholder_messages + tail_messages
|
||||
|
||||
async def clear(self) -> None:
|
|
@ -0,0 +1,50 @@
|
|||
from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext, HeadAndTailChatCompletionContext
|
||||
from autogen_core.components.models import AssistantMessage, LLMMessage, UserMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_model_context() -> None:
|
||||
model_context = BufferedChatCompletionContext(buffer_size=2)
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
AssistantMessage(content="What can I do for you?", source="assistant"),
|
||||
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
|
||||
]
|
||||
await model_context.add_message(messages[0])
|
||||
await model_context.add_message(messages[1])
|
||||
await model_context.add_message(messages[2])
|
||||
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 2
|
||||
assert retrieved[0] == messages[1]
|
||||
assert retrieved[1] == messages[2]
|
||||
|
||||
await model_context.clear()
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_head_and_tail_model_context() -> None:
|
||||
model_context = HeadAndTailChatCompletionContext(head_size=1, tail_size=1)
|
||||
messages: List[LLMMessage] = [
|
||||
UserMessage(content="Hello!", source="user"),
|
||||
AssistantMessage(content="What can I do for you?", source="assistant"),
|
||||
UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
|
||||
AssistantMessage(content="Pike place, space needle, mt rainer", source="assistant"),
|
||||
UserMessage(content="More places?", source="user"),
|
||||
]
|
||||
for msg in messages:
|
||||
await model_context.add_message(msg)
|
||||
|
||||
retrived = await model_context.get_messages()
|
||||
assert len(retrived) == 3 # 1 head, 1 tail + 1 placeholder.
|
||||
assert retrived[0] == messages[0]
|
||||
assert retrived[2] == messages[-1]
|
||||
|
||||
await model_context.clear()
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 0
|
Loading…
Reference in New Issue