mirror of https://github.com/microsoft/autogen.git
Proxy PR for Long Context Capability 1513 (#1591)
* Add new capability to handle long context * Make print conditional * Remove superfluous comment * Fix msg order * Allow user to specify max_tokens * Add ability to specify max_tokens per message; improve name * Improve doc and readability * Add tests * Improve documentation and add tests per Erik and Chi's feedback * Update notebook * Update doc string of add to agents * Improve doc string * improve notebook * Update github workflows for context handling * Update docstring * update notebook to use raw config list. * Update contrib-openai.yml remove _target * Fix code formatting * Fix workflow file * Update .github/workflows/contrib-openai.yml --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
parent a3c3317faa
commit 47d6c7567e
.github/workflows/contrib-openai.yml
@@ -260,3 +260,42 @@ jobs:
      with:
        file: ./coverage.xml
        flags: unittests
  ContextHandling:
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.11"]
    runs-on: ${{ matrix.os }}
    environment: openai1
    steps:
      # checkout to pr branch
      - name: Checkout
        uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies
        run: |
          docker --version
          python -m pip install --upgrade pip wheel
          pip install -e .
          python -c "import autogen"
          pip install coverage pytest
      - name: Coverage
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
          AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
          OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
          BING_API_KEY: ${{ secrets.BING_API_KEY }}
        run: |
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
@@ -253,3 +253,40 @@ jobs:
      with:
        file: ./coverage.xml
        flags: unittests

  ContextHandling:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-2019]
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies for all tests
        run: |
          python -m pip install --upgrade pip wheel
          pip install pytest
      - name: Install packages and dependencies for Context Handling
        run: |
          pip install -e .
      - name: Set AUTOGEN_USE_DOCKER based on OS
        shell: bash
        run: |
          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
          fi
      - name: Coverage
        run: |
          pip install coverage>=5.3
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
@@ -1,5 +0,0 @@
-from .teachability import Teachability
-from .agent_capability import AgentCapability
-
-
-__all__ = ["Teachability", "AgentCapability"]
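With the re-exports removed from the capabilities package `__init__`, capabilities are imported through their modules. These are the two import styles that the notebook and the tests in this commit use:

# As in the notebook added by this commit:
from autogen.agentchat.contrib.capabilities import context_handling

# As in the tests added by this commit:
from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory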
@@ -0,0 +1,108 @@
import sys
from termcolor import colored
from typing import Dict, Optional, List
from autogen import ConversableAgent
from autogen import token_count_utils


class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes it on every invocation.

    This capability class enables various strategies to transform the chat history, such as:
    - Truncate messages: truncate each message to a maximum number of tokens.
    - Limit the number of messages: truncate the chat history to a maximum number of (recent) messages.
    - Limit the number of tokens: truncate the chat history to the recent messages that fit in
      a maximum number of tokens.
    Note that the system message, because of its special significance, is always kept as is.

    The three strategies can be combined. When all of these parameters are specified,
    they are applied in the following order:
    1. First, truncate each message to a maximum number of tokens.
    2. Second, limit the number of messages to keep.
    3. Third, limit the total number of tokens in the chat history.

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: ConversableAgent):
        """
        Adds the TransformChatHistory capability to the given agent.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Args:
            messages: List of messages to process.

        Returns:
            List of messages with the first system message and the last max_messages messages,
            truncated to fit within the configured token limits.
        """
        processed_messages = []
        # Copy the message dicts so that the in-place truncation below does not
        # modify the caller's messages (a plain list.copy() would share the dicts).
        messages = [msg.copy() for msg in messages]
        rest_messages = messages

        # Check whether the first message is a system message and append it to the processed messages.
        if len(messages) > 0:
            if messages[0]["role"] == "system":
                msg = messages[0]
                processed_messages.append(msg)
                rest_messages = messages[1:]

        processed_messages_tokens = 0
        # Truncate each non-system message to at most max_tokens_per_message tokens.
        # Iterating over rest_messages (not messages) keeps the system message intact,
        # as the class docstring promises.
        for msg in rest_messages:
            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)

        # Iterate through the rest of the messages and append them to the processed messages.
        for msg in rest_messages[-self.max_messages :]:
            msg_tokens = token_count_utils.count_token(msg["content"])
            if processed_messages_tokens + msg_tokens > self.max_tokens:
                break
            processed_messages.append(msg)
            processed_messages_tokens += msg_tokens

        total_tokens = 0
        for msg in messages:
            total_tokens += token_count_utils.count_token(msg["content"])

        num_truncated = len(messages) - len(processed_messages)
        if num_truncated > 0 or total_tokens > processed_messages_tokens:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
        return processed_messages


def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that its number of tokens is at most max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        Truncated string.
    """
    truncated_string = ""
    for char in text:
        truncated_string += char
        # Use >= rather than == so the loop still stops if the count skips past max_tokens.
        if token_count_utils.count_token(truncated_string) >= max_tokens:
            break
    return truncated_string
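A minimal sketch of how the three limits compose. The history here is synthetic, and `_transform_messages` is the private hook that `add_to_agent` registers on the agent; calling it directly is only for illustration:

from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory

capability = TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)

# A system message followed by 25 long user messages.
history = [{"role": "system", "content": "You are a helpful assistant."}]
history += [{"role": "user", "content": "word " * 500} for _ in range(25)]

# Step 1 truncates each non-system message to at most 50 tokens, step 2 keeps at
# most the 10 most recent of them, and step 3 stops adding messages once 1000
# tokens are used. The system message is passed through untouched.
shortened = capability._transform_messages(history)
assert len(shortened) <= 1 + 10  # the system message plus at most max_messages others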
@@ -194,7 +194,7 @@ class ConversableAgent(Agent):

         # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
         # New hookable methods should be added to this list as required to support new agent capabilities.
-        self.hook_lists = {self.process_last_message: []}  # This is currently the only hookable method.
+        self.hook_lists = {self.process_last_message: [], self.process_all_messages: []}

     def register_reply(
         self,

@@ -1528,6 +1528,10 @@ class ConversableAgent(Agent):
         if messages is None:
             messages = self._oai_messages[sender]

+        # Call the hookable method that gives registered hooks a chance to process all messages.
+        # Message modifications do not affect the incoming messages or self._oai_messages.
+        messages = self.process_all_messages(messages)
+
         # Call the hookable method that gives registered hooks a chance to process the last message.
         # Message modifications do not affect the incoming messages or self._oai_messages.
         messages = self.process_last_message(messages)

@@ -1584,6 +1588,10 @@ class ConversableAgent(Agent):
         if messages is None:
             messages = self._oai_messages[sender]

+        # Call the hookable method that gives registered hooks a chance to process all messages.
+        # Message modifications do not affect the incoming messages or self._oai_messages.
+        messages = self.process_all_messages(messages)
+
         # Call the hookable method that gives registered hooks a chance to process the last message.
         # Message modifications do not affect the incoming messages or self._oai_messages.
         messages = self.process_last_message(messages)

@@ -2211,6 +2219,21 @@ class ConversableAgent(Agent):
         assert hook not in hook_list, f"{hook} is already registered as a hook."
         hook_list.append(hook)

+    def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
+        """
+        Calls any registered capability hooks to process all messages, potentially modifying the messages.
+        """
+        hook_list = self.hook_lists[self.process_all_messages]
+        # If no hooks are registered, or if there are no messages to process, return the original message list.
+        if len(hook_list) == 0 or messages is None:
+            return messages
+
+        # Call each hook (in order of registration) to process the messages.
+        processed_messages = messages
+        for hook in hook_list:
+            processed_messages = hook(processed_messages)
+        return processed_messages
+
     def process_last_message(self, messages):
         """
         Calls any registered capability hooks to use and potentially modify the text of the last message,
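Because `process_all_messages` is now hookable, other capabilities can register their own whole-history transforms the same way `TransformChatHistory` does. A minimal sketch with a hypothetical hook (the hook name and body are illustrative, not part of this commit):

from typing import Dict, List

import autogen


def drop_empty_messages(messages: List[Dict]) -> List[Dict]:
    # Hypothetical hook: strip messages whose content is empty before the LLM sees them.
    return [m for m in messages if m.get("content")]


# llm_config=False creates an agent without an LLM backend, enough to demo the hook.
agent = autogen.ConversableAgent("my_agent", llm_config=False)
agent.register_hook(hookable_method=agent.process_all_messages, hook=drop_empty_messages)

# Hooks run in their order of registration, each receiving the previous hook's output.
cleaned = agent.process_all_messages([{"role": "user", "content": ""}, {"role": "user", "content": "hi"}])
assert cleaned == [{"role": "user", "content": "hi"}]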
@@ -0,0 +1,605 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Handling A Long Context via `TransformChatHistory`\n",
    "\n",
    "This notebook illustrates how you can use the `TransformChatHistory` capability to give any `ConversableAgent` the ability to handle a long context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Uncomment to install pyautogen if you don't have it already\n",
    "#! pip install pyautogen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autogen\n",
    "from autogen.agentchat.contrib.capabilities import context_handling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To add this ability to any agent, define the capability and then use `add_to_agent`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "plot and save a graph of x^2 from -10 to 10\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Here's a Python code snippet to plot and save a graph of x^2 from -10 to 10:\n",
      "\n",
      "```python\n",
      "import matplotlib.pyplot as plt\n",
      "import numpy as np\n",
      "\n",
      "# Generate x values from -10 to 10\n",
      "x = np.linspace(-10, 10, 100)\n",
      "\n",
      "# Evaluate y values using x^2\n",
      "y = x**2\n",
      "\n",
      "# Plot the graph\n",
      "plt.plot(x, y)\n",
      "\n",
      "# Set labels and title\n",
      "plt.xlabel('x')\n",
      "plt.ylabel('y')\n",
      "plt.title('Graph of y = x^2')\n",
      "\n",
      "# Save the graph as an image file (e.g., PNG)\n",
      "plt.savefig('x_squared_graph.png')\n",
      "\n",
      "# Show the graph\n",
      "plt.show()\n",
      "```\n",
      "\n",
      "Please make sure to have the `matplotlib` library installed in your Python environment. After executing the code, the graph will be saved as \"x_squared_graph.png\" in the current directory. You can change the filename if desired.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 0 (execution succeeded)\n",
      "Code output: \n",
      "Figure(640x480)\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
      "\n",
      "```python\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "```\n",
      "\n",
      "This will save the graph as a PNG file named \"graph.png\" in your current working directory.\n",
      "\n",
      "Now, you can check the saved graph in your current directory. Let me know if you need any further assistance.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 1 (execution failed)\n",
      "Code output: \n",
      "Traceback (most recent call last):\n",
      "  File \"\", line 2, in <module>\n",
      "    plt.savefig('graph.png')\n",
      "NameError: name 'plt' is not defined\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "I apologize for the mistake. It seems that the `plt` module was not imported correctly. Please make sure that you have the `matplotlib` library installed. You can install it using the following command:\n",
      "\n",
      "```sh\n",
      "pip install matplotlib\n",
      "```\n",
      "\n",
      "Once you have the library installed, please try running the code again. If you are still facing any issues, please let me know.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 0 (execution succeeded)\n",
      "Code output: \n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: matplotlib in /home/vscode/.local/lib/python3.10/site-packages (3.8.2)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (3.1.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (4.47.2)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (2.8.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/site-packages (from matplotlib) (23.2)\n",
      "Requirement already satisfied: pillow>=8 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (10.2.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n",
      "Requirement already satisfied: kiwisolver>=1.3.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.4.5)\n",
      "Requirement already satisfied: numpy<2,>=1.21 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.26.3)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.2.0)\n",
      "Requirement already satisfied: six>=1.5 in /home/vscode/.local/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Thank you for installing the `matplotlib` library. Let's try running the code again:\n",
      "\n",
      "```python\n",
      "import matplotlib.pyplot as plt\n",
      "import numpy as np\n",
      "\n",
      "# Generate x values from -10 to 10\n",
      "x = np.linspace(-10, 10, 100)\n",
      "\n",
      "# Compute the y values (x^2)\n",
      "y = x**2\n",
      "\n",
      "# Plot the graph\n",
      "plt.plot(x, y)\n",
      "\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "\n",
      "# Show the graph\n",
      "plt.show()\n",
      "```\n",
      "\n",
      "Please give it a try and let me know if it works for you.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 0 (execution succeeded)\n",
      "Code output: \n",
      "Figure(640x480)\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "It seems that the code executed successfully and generated a figure of x^2 from -10 to 10. To save the graph as an image file, you can add the following code snippet:\n",
      "\n",
      "```python\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "```\n",
      "\n",
      "This will save the graph as a PNG file named \"graph.png\" in the current directory. You can change the file name or modify the file format as needed.\n",
      "Remember to replace `plt.savefig('graph.png')` with your desired file name if necessary.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 1 (execution failed)\n",
      "Code output: \n",
      "Traceback (most recent call last):\n",
      "  File \"\", line 2, in <module>\n",
      "    plt.savefig('graph.png')\n",
      "NameError: name 'plt' is not defined\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 1 messages.\u001b[0m\n",
      "\u001b[33mTruncated 15 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "My apologies for the mistake again. It seems that the `plt` module was not imported correctly. Please make sure that you have imported the `matplotlib.pyplot` module.\n",
      "\n",
      "Here is the correct code snippet:\n",
      "\n",
      "```python\n",
      "import matplotlib.pyplot as plt\n",
      "import numpy as np\n",
      "\n",
      "# Generate x values from -10 to 10\n",
      "x = np.linspace(-10, 10, 100)\n",
      "y = x ** 2\n",
      "\n",
      "# Create a plot of x^2\n",
      "plt.plot(x, y)\n",
      "\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "```\n",
      "\n",
      "Please try running this code. It should generate the graph and save it as `graph.png` in the current working directory.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 0 (execution succeeded)\n",
      "Code output: \n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 3 messages.\u001b[0m\n",
      "\u001b[33mTruncated 84 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
      "\n",
      "```python\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "```\n",
      "\n",
      "This will save the graph as a PNG image file named \"graph.png\" in the same directory as your Python script. You can change the filename and format (e.g., 'graph.jpg', 'graph.pdf') as needed.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 1 (execution failed)\n",
      "Code output: \n",
      "Traceback (most recent call last):\n",
      "  File \"\", line 2, in <module>\n",
      "    plt.savefig('graph.png')\n",
      "NameError: name 'plt' is not defined\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 5 messages.\u001b[0m\n",
      "\u001b[33mTruncated 184 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "I apologize for the confusion. It looks like the code didn't execute as expected. \n",
      "\n",
      "To save the graph in a file, try the following code:\n",
      "\n",
      "```python\n",
      "import matplotlib.pyplot as plt\n",
      "import numpy as np\n",
      "\n",
      "# Generate x values from -10 to 10\n",
      "x = np.linspace(-10, 10, 100)\n",
      "\n",
      "# Calculate the y values (square of x)\n",
      "y = x ** 2\n",
      "\n",
      "# Plot the graph\n",
      "plt.plot(x, y)\n",
      "\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "\n",
      "# Show the graph\n",
      "plt.show()\n",
      "```\n",
      "\n",
      "Make sure you have the `matplotlib` library installed. Let me know if you encounter any further issues.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 0 (execution succeeded)\n",
      "Code output: \n",
      "Figure(640x480)\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 7 messages.\u001b[0m\n",
      "\u001b[33mTruncated 284 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
      "\n",
      "```python\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 9 messages.\u001b[0m\n",
      "\u001b[33mTruncated 353 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Sorry for the mistake again. The code snippet provided was incomplete. To correctly save the graph as an image file, you need to add a line of code to close the plot after saving. Here's the updated code:\n",
      "\n",
      "```python\n",
      "# Save the graph as a file\n",
      "plt.savefig('graph.png')\n",
      "plt.close()\n",
      "```\n",
      "\n",
      "Please try running the complete code again and it should save the graph as \"graph.png\" in the current directory. Let me know if you encounter any further issues.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 1 (execution failed)\n",
      "Code output: \n",
      "Traceback (most recent call last):\n",
      "  File \"\", line 2, in <module>\n",
      "    plt.savefig('graph.png')\n",
      "NameError: name 'plt' is not defined\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 11 messages.\u001b[0m\n",
      "\u001b[33mTruncated 453 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "I apologize for the mistake once again. It seems that the matplotlib module is not installed on your system. To install it, you can use pip, the package manager for Python.\n",
      "\n",
      "Here are the steps to install matplotlib:\n",
      "\n",
      "1. Open a command prompt or terminal.\n",
      "2. Type the following command and press Enter:\n",
      "\n",
      "```sh\n",
      "pip install matplotlib\n",
      "```\n",
      "\n",
      "3. Wait for the installation to complete.\n",
      "\n",
      "Once you have installed matplotlib, you can try running the code again to save the graph as an image file.\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "assistant = autogen.AssistantAgent(\n",
    "    \"assistant\",\n",
    "    llm_config={\n",
    "        \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
    "    },\n",
    ")\n",
    "\n",
    "# Instantiate the capability to manage chat history\n",
    "manage_chat_history = context_handling.TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
    "# Add the capability to the assistant\n",
    "manage_chat_history.add_to_agent(assistant)\n",
    "\n",
    "user_proxy = autogen.UserProxyAgent(\n",
    "    \"user_proxy\",\n",
    "    human_input_mode=\"NEVER\",\n",
    "    is_termination_msg=lambda x: \"TERMINATE\" in x.get(\"content\", \"\"),\n",
    "    code_execution_config={\n",
    "        \"work_dir\": \"coding\",\n",
    "        \"use_docker\": False,\n",
    "    },\n",
    "    max_consecutive_auto_reply=10,\n",
    ")\n",
    "\n",
    "user_proxy.initiate_chat(assistant, message=\"plot and save a graph of x^2 from -10 to 10\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Why is this important?\n",
    "This capability is especially useful if you expect the agent histories to become exceptionally large and exceed the context length offered by your LLM.\n",
    "In the example below, we define two agents: one without this capability and one with it.\n",
    "\n",
    "The agent with this capability can handle a longer chat history without crashing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "plot and save a graph of x^2 from -10 to 10\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "Encountered an error with the base assistant\n",
      "Error code: 400 - {'error': {'message': \"This model's maximum context length is 4097 tokens. However, your messages resulted in 1009487 tokens. Please reduce the length of the messages.\", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}\n",
      "\n",
      "\n",
      "\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "plot and save a graph of x^2 from -10 to 10\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 1991 messages.\u001b[0m\n",
      "\u001b[33mTruncated 49800 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Here is the Python code to plot and save a graph of x^2 from -10 to 10:\n",
      "\n",
      "```python\n",
      "import matplotlib.pyplot as plt\n",
      "import numpy as np\n",
      "\n",
      "# Generate values for x from -10 to 10\n",
      "x = np.linspace(-10, 10, 100)\n",
      "\n",
      "# Calculate y values for x^2\n",
      "y = x ** 2\n",
      "\n",
      "# Plot the graph\n",
      "plt.plot(x, y)\n",
      "plt.xlabel('x')\n",
      "plt.ylabel('y = x^2')\n",
      "plt.title('Graph of y = x^2')\n",
      "\n",
      "# Save the graph as a PNG file\n",
      "plt.savefig('graph.png')\n",
      "\n",
      "# Close the plot\n",
      "plt.close()\n",
      "\n",
      "print('Graph saved as graph.png')\n",
      "```\n",
      "\n",
      "Please make sure you have the `matplotlib` library installed before running this code. You can install it by running `pip install matplotlib` in your terminal.\n",
      "\n",
      "After executing the code, a graph of y = x^2 will be saved as `graph.png` in your current working directory.\n",
      "\n",
      "Let me know if you need any further assistance!\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "exitcode: 0 (execution succeeded)\n",
      "Code output: \n",
      "Graph saved as graph.png\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 1993 messages.\u001b[0m\n",
      "\u001b[33mTruncated 49850 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "Great! The code executed successfully and the graph has been saved as \"graph.png\". You can now view the graph to see the plot of x^2 from -10 to 10. If you have any more questions or need further assistance, feel free to ask.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33muser_proxy\u001b[0m (to assistant):\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mTruncated 1995 messages.\u001b[0m\n",
      "\u001b[33mTruncated 49900 tokens.\u001b[0m\n",
      "\u001b[33massistant\u001b[0m (to user_proxy):\n",
      "\n",
      "TERMINATE\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "assistant_base = autogen.AssistantAgent(\n",
    "    \"assistant\",\n",
    "    llm_config={\n",
    "        \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
    "    },\n",
    ")\n",
    "\n",
    "assistant_with_context_handling = autogen.AssistantAgent(\n",
    "    \"assistant\",\n",
    "    llm_config={\n",
    "        \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
    "    },\n",
    ")\n",
    "# Add the context handling capability only to the second assistant\n",
    "manage_chat_history = context_handling.TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
    "manage_chat_history.add_to_agent(assistant_with_context_handling)\n",
    "\n",
    "user_proxy = autogen.UserProxyAgent(\n",
    "    \"user_proxy\",\n",
    "    human_input_mode=\"NEVER\",\n",
    "    is_termination_msg=lambda x: \"TERMINATE\" in x.get(\"content\", \"\"),\n",
    "    code_execution_config={\n",
    "        \"work_dir\": \"coding\",\n",
    "        \"use_docker\": False,\n",
    "    },\n",
    "    max_consecutive_auto_reply=2,\n",
    ")\n",
    "\n",
    "# Create a very long chat history that is bound to cause a crash for gpt-3.5\n",
    "for i in range(1000):\n",
    "    # define a fake, very long message\n",
    "    assistant_msg = {\"role\": \"assistant\", \"content\": \"test \" * 1000}\n",
    "    user_msg = {\"role\": \"user\", \"content\": \"\"}\n",
    "\n",
    "    assistant_base.send(assistant_msg, user_proxy, request_reply=False, silent=True)\n",
    "    assistant_with_context_handling.send(assistant_msg, user_proxy, request_reply=False, silent=True)\n",
    "    user_proxy.send(user_msg, assistant_base, request_reply=False, silent=True)\n",
    "    user_proxy.send(user_msg, assistant_with_context_handling, request_reply=False, silent=True)\n",
    "\n",
    "try:\n",
    "    user_proxy.initiate_chat(assistant_base, message=\"plot and save a graph of x^2 from -10 to 10\", clear_history=False)\n",
    "except Exception as e:\n",
    "    print(\"Encountered an error with the base assistant\")\n",
    "    print(e)\n",
    "    print(\"\\n\\n\")\n",
    "\n",
    "try:\n",
    "    user_proxy.initiate_chat(\n",
    "        assistant_with_context_handling, message=\"plot and save a graph of x^2 from -10 to 10\", clear_history=False\n",
    "    )\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
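A back-of-the-envelope check that the baseline failure shown in the notebook output is expected, using only the numbers the notebook itself reports:

messages_per_side = 1000   # loop iterations in the notebook cell
tokens_per_message = 1000  # "test " * 1000 is roughly 1000 tokens
approx_total = messages_per_side * tokens_per_message  # ~1,000,000 tokens

# The API reported 1009487 tokens against a 4097-token context length,
# the same order of magnitude as this estimate.
assert approx_total > 4097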
@@ -0,0 +1,107 @@
import pytest
import os
import sys
import autogen
from autogen import token_count_utils
from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory
from autogen import AssistantAgent, UserProxyAgent


sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai  # noqa: E402

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC  # noqa: E402

try:
    from openai import OpenAI
except ImportError:
    skip = True
else:
    skip = skip_openai


def test_transform_chat_history():
    """
    Test the TransformChatHistory capability.

    In particular, test the following methods:
    - _transform_messages
    - truncate_str_to_tokens
    """
    messages = [
        {"role": "system", "content": "System message"},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "This is another test message"},
    ]

    # check that the number of messages does not exceed max_messages
    transform_chat_history = TransformChatHistory(max_messages=1)
    transformed_messages = transform_chat_history._transform_messages(messages)
    assert len(transformed_messages) == 2  # System message and the last message

    # check that the number of tokens per message does not exceed max_tokens_per_message
    transform_chat_history = TransformChatHistory(max_tokens_per_message=5)
    transformed_messages = transform_chat_history._transform_messages(messages)
    for message in transformed_messages:
        if message["role"] == "system":
            continue
        assert token_count_utils.count_token(message["content"]) <= 5

    # check that the total number of tokens does not exceed max_tokens
    transform_chat_history = TransformChatHistory(max_tokens=5)
    transformed_messages = transform_chat_history._transform_messages(messages)

    token_count = 0
    for message in transformed_messages:
        if message["role"] == "system":
            continue
        token_count += token_count_utils.count_token(message["content"])
    assert token_count <= 5


@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
def test_transform_chat_history_with_agents():
    """
    This test creates a GPT-3.5 agent with this capability and tests the add_to_agent method,
    including whether it prevents a crash when chat histories become excessively long.
    """
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        KEY_LOC,
        filter_dict={
            "model": "gpt-3.5-turbo",
        },
    )
    assistant = AssistantAgent("assistant", llm_config={"config_list": config_list}, max_consecutive_auto_reply=1)
    context_handling = TransformChatHistory(max_messages=10, max_tokens_per_message=5, max_tokens=1000)
    context_handling.add_to_agent(assistant)
    user = UserProxyAgent(
        "user",
        code_execution_config={"work_dir": "coding"},
        human_input_mode="NEVER",
        is_termination_msg=lambda x: "TERMINATE" in x.get("content", ""),
        max_consecutive_auto_reply=1,
    )

    # Create a very long chat history that is bound to cause a crash for gpt-3.5
    for i in range(1000):
        assistant_msg = {"role": "assistant", "content": "test " * 1000}
        user_msg = {"role": "user", "content": ""}

        assistant.send(assistant_msg, user, request_reply=False)
        user.send(user_msg, assistant, request_reply=False)

    try:
        user.initiate_chat(
            assistant, message="Plot a chart of nvidia and tesla stock prices for the last 5 years", clear_history=False
        )
    except Exception as e:
        assert False, f"Chat initiation failed with error {str(e)}"


if __name__ == "__main__":
    test_transform_chat_history()
    test_transform_chat_history_with_agents()