Proxy PR for Long Context Capability 1513 (#1591)

* Add new capability to handle long context

* Make print conditional

* Remove superfluous comment

* Fix msg order

* Allow user to specify max_tokens

* Add ability to specify max_tokens per message; improve name

* Improve doc and readability

* Add tests

* Improve documentation and add tests per Erik and Chi's feedback

* Update notebook

* Update doc string of add to agents

* Improve doc string

* improve notebook

* Update github workflows for context handling

* Update docstring

* update notebook to use raw config list.

* Update contrib-openai.yml remove _target

* Fix code formatting

* Fix workflow file

* Update .github/workflows/contrib-openai.yml

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
gagb 2024-02-08 10:26:00 -08:00 committed by GitHub
parent a3c3317faa
commit 47d6c7567e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 920 additions and 6 deletions

View File

@ -260,3 +260,42 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
ContextHandling:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.11"]
runs-on: ${{ matrix.os }}
environment: openai1
steps:
# checkout to pr branch
- name: Checkout
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies
run: |
docker --version
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
pip install coverage pytest
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
BING_API_KEY: ${{ secrets.BING_API_KEY }}
run: |
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests

View File

@ -253,3 +253,40 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
ContextHandling:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for Context Handling
run: |
pip install -e .
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests

View File

@ -1,5 +0,0 @@
from .teachability import Teachability
from .agent_capability import AgentCapability
__all__ = ["Teachability", "AgentCapability"]

View File

@ -0,0 +1,108 @@
import sys
from termcolor import colored
from typing import Dict, Optional, List
from autogen import ConversableAgent
from autogen import token_count_utils
class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes it on every invocation.

    This capability class enables various strategies to transform chat history, such as:
    - Truncate messages: Truncate each message to first maximum number of tokens.
    - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
    - Limit number of tokens: Truncate the chat history to number of recent N messages that fit in
      maximum number of tokens.
    Note that the system message, because of its special significance, is always kept as is.

    The three strategies can be combined. For example, when each of these parameters are specified
    they are used in the following order:
    1. First truncate messages to a maximum number of tokens
    2. Second, it limits the number of message to keep
    3. Third, it limits the total number of tokens in the chat history

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        # Treat "unspecified" as "unlimited" so the transformation logic can
        # compare against a single numeric bound instead of special-casing None.
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: ConversableAgent):
        """
        Adds TransformChatHistory capability to the given agent.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Args:
            messages: List of messages to process.

        Returns:
            List of messages with the first system message (untouched) and the most recent
            messages that satisfy the configured max_messages / max_tokens limits, with each
            kept message truncated to max_tokens_per_message tokens.
        """
        # Copy each message dict, not just the outer list: the per-message
        # truncation below rebinds msg["content"], and with only a shallow list
        # copy that would permanently mutate the agent's stored chat history.
        messages = [msg.copy() for msg in messages]

        # Token count of the untruncated history, for the report at the end.
        total_tokens = sum(token_count_utils.count_token(msg["content"]) for msg in messages)

        processed_messages = []
        rest_messages = messages
        # The system message, because of its special significance, is kept
        # as is and excluded from every truncation strategy below.
        if messages and messages[0]["role"] == "system":
            processed_messages.append(messages[0])
            rest_messages = messages[1:]

        # Strategy 1: truncate each non-system message to max_tokens_per_message tokens.
        for msg in rest_messages:
            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)

        # Strategies 2 and 3: of the last max_messages messages, keep the most
        # RECENT ones that still fit in the max_tokens budget. Walking the
        # candidates newest-first (and restoring chronological order afterwards)
        # ensures the budget is spent on recent context rather than old context.
        processed_messages_tokens = 0
        kept = []
        for msg in reversed(rest_messages[-self.max_messages :]):
            msg_tokens = token_count_utils.count_token(msg["content"])
            if processed_messages_tokens + msg_tokens > self.max_tokens:
                break
            kept.append(msg)
            processed_messages_tokens += msg_tokens
        processed_messages.extend(reversed(kept))

        # Report how much was dropped, counting both dropped messages and
        # tokens removed by per-message truncation.
        num_truncated = len(messages) - len(processed_messages)
        kept_tokens = sum(token_count_utils.count_token(msg["content"]) for msg in processed_messages)
        if num_truncated > 0 or total_tokens > kept_tokens:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {total_tokens - kept_tokens} tokens.", "yellow"))

        return processed_messages
def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that its number of tokens is at most max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        Truncated string.
    """
    truncated_string = ""
    for char in text:
        # Grow the prefix one character at a time and stop before the token
        # count exceeds the budget. Testing ">" (rather than the original
        # "== max_tokens") is robust to the token count jumping past the
        # limit when adding a character merges or splits tokens; the old
        # equality test could miss the boundary and return the full string.
        candidate = truncated_string + char
        if token_count_utils.count_token(candidate) > max_tokens:
            break
        truncated_string = candidate
    return truncated_string

View File

@ -194,7 +194,7 @@ class ConversableAgent(Agent):
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
self.hook_lists = {self.process_last_message: [], self.process_all_messages: []}
def register_reply(
self,
@ -1528,6 +1528,10 @@ class ConversableAgent(Agent):
if messages is None:
messages = self._oai_messages[sender]
# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
@ -1584,6 +1588,10 @@ class ConversableAgent(Agent):
if messages is None:
messages = self._oai_messages[sender]
# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
@ -2211,6 +2219,21 @@ class ConversableAgent(Agent):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)
def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
    """
    Calls any registered capability hooks to process all messages, potentially modifying the messages.
    """
    hooks = self.hook_lists[self.process_all_messages]

    # Nothing to do when no hooks are registered or there is no message list.
    if not hooks or messages is None:
        return messages

    # Apply each hook in registration order, feeding each one the output
    # of the previous hook.
    for hook in hooks:
        messages = hook(messages)
    return messages
def process_last_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,

View File

@ -0,0 +1,605 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handling A Long Context via `TransformChatHistory`\n",
"\n",
"This notebook illustrates how you can use the `TransformChatHistory` capability to give any `Conversable` agent an ability to handle a long context. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"## Uncomment to install pyautogen if you don't have it already\n",
"#! pip install pyautogen"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import autogen\n",
"from autogen.agentchat.contrib.capabilities import context_handling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To add this ability to any agent, define the capability and then use `add_to_agent`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"plot and save a graph of x^2 from -10 to 10\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Here's a Python code snippet to plot and save a graph of x^2 from -10 to 10:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Evaluate y values using x^2\n",
"y = x**2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"\n",
"# Set labels and title\n",
"plt.xlabel('x')\n",
"plt.ylabel('y')\n",
"plt.title('Graph of y = x^2')\n",
"\n",
"# Save the graph as an image file (e.g., PNG)\n",
"plt.savefig('x_squared_graph.png')\n",
"\n",
"# Show the graph\n",
"plt.show()\n",
"```\n",
"\n",
"Please make sure to have the `matplotlib` library installed in your Python environment. After executing the code, the graph will be saved as \"x_squared_graph.png\" in the current directory. You can change the filename if desired.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Figure(640x480)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"This will save the graph as a PNG file named \"graph.png\" in your current working directory.\n",
"\n",
"Now, you can check the saved graph in your current directory. Let me know if you need any further assistance.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"I apologize for the mistake. It seems that the `plt` module was not imported correctly. Please make sure that you have the `matplotlib` library installed. You can install it using the following command:\n",
"\n",
"```sh\n",
"pip install matplotlib\n",
"```\n",
"\n",
"Once you have the library installed, please try running the code again. If you are still facing any issues, please let me know.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: matplotlib in /home/vscode/.local/lib/python3.10/site-packages (3.8.2)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (3.1.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (4.47.2)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (2.8.2)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/site-packages (from matplotlib) (23.2)\n",
"Requirement already satisfied: pillow>=8 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (10.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.4.5)\n",
"Requirement already satisfied: numpy<2,>=1.21 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.26.3)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.2.0)\n",
"Requirement already satisfied: six>=1.5 in /home/vscode/.local/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Thank you for installing the `matplotlib` library. Let's try running the code again:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Compute the y values (x^2)\n",
"y = x**2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"\n",
"# Show the graph\n",
"plt.show()\n",
"```\n",
"\n",
"Please give it a try and let me know if it works for you.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Figure(640x480)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"It seems that the code executed successfully and generated a figure of x^2 from -10 to 10. To save the graph as an image file, you can add the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"This will save the graph as a PNG file named \"graph.png\" in the current directory. You can change the file name or modify the file format as needed.\n",
"Remember to replace `plt.savefig('graph.png')` with your desired file name if necessary.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1 messages.\u001b[0m\n",
"\u001b[33mTruncated 15 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"My apologies for the mistake again. It seems that the `plt` module was not imported correctly. Please make sure that you have imported the `matplotlib.pyplot` module.\n",
"\n",
"Here is the correct code snippet:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"y = x ** 2\n",
"\n",
"# Create a plot of x^2\n",
"plt.plot(x, y)\n",
"\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"Please try running this code. It should generate the graph and save it as `graph.png` in the current working directory.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 3 messages.\u001b[0m\n",
"\u001b[33mTruncated 84 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"This will save the graph as a PNG image file named \"graph.png\" in the same directory as your Python script. You can change the filename and format (e.g., 'graph.jpg', 'graph.pdf') as needed.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 5 messages.\u001b[0m\n",
"\u001b[33mTruncated 184 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"I apologize for the confusion. It looks like the code didn't execute as expected. \n",
"\n",
"To save the graph in a file, try the following code:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Calculate the y values (square of x)\n",
"y = x ** 2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"\n",
"# Show the graph\n",
"plt.show()\n",
"```\n",
"\n",
"Make sure you have the `matplotlib` library installed. Let me know if you encounter any further issues.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Figure(640x480)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 7 messages.\u001b[0m\n",
"\u001b[33mTruncated 284 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 9 messages.\u001b[0m\n",
"\u001b[33mTruncated 353 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Sorry for the mistake again. The code snippet provided was incomplete. To correctly save the graph as an image file, you need to add a line of code to close the plot after saving. Here's the updated code:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"plt.close()\n",
"```\n",
"\n",
"Please try running the complete code again and it should save the graph as \"graph.png\" in the current directory. Let me know if you encounter any further issues.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 11 messages.\u001b[0m\n",
"\u001b[33mTruncated 453 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"I apologize for the mistake once again. It seems that the matplotlib module is not installed on your system. To install it, you can use pip, the package manager for Python.\n",
"\n",
"Here are the steps to install matplotlib:\n",
"\n",
"1. Open a command prompt or terminal.\n",
"2. Type the following command and press Enter:\n",
"\n",
"```sh\n",
"pip install matplotlib\n",
"```\n",
"\n",
"3. Wait for the installation to complete.\n",
"\n",
"Once you have installed matplotlib, you can try running the code again to save the graph as an image file.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant = autogen.AssistantAgent(\n",
" \"assistant\",\n",
" llm_config={\n",
" \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
" },\n",
")\n",
"\n",
"# Instantiate the capability to manage chat history\n",
"manage_chat_history = context_handling.TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
"# Add the capability to the assistant\n",
"manage_chat_history.add_to_agent(assistant)\n",
"\n",
"user_proxy = autogen.UserProxyAgent(\n",
" \"user_proxy\",\n",
" human_input_mode=\"NEVER\",\n",
" is_termination_msg=lambda x: \"TERMINATE\" in x.get(\"content\", \"\"),\n",
" code_execution_config={\n",
" \"work_dir\": \"coding\",\n",
" \"use_docker\": False,\n",
" },\n",
" max_consecutive_auto_reply=10,\n",
")\n",
"\n",
"user_proxy.initiate_chat(assistant, message=\"plot and save a graph of x^2 from -10 to 10\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Why is this important?\n",
"This capability is especially useful if you expect the agent histories to become exceptionally large and exceed the context length offered by your LLM.\n",
"For example, in the example below, we will define two agents -- one without this ability and one with this ability.\n",
"\n",
"The agent with this ability will be able to handle longer chat history without crashing."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"plot and save a graph of x^2 from -10 to 10\n",
"\n",
"--------------------------------------------------------------------------------\n",
"Encountered an error with the base assistant\n",
"Error code: 400 - {'error': {'message': \"This model's maximum context length is 4097 tokens. However, your messages resulted in 1009487 tokens. Please reduce the length of the messages.\", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}\n",
"\n",
"\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"plot and save a graph of x^2 from -10 to 10\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1991 messages.\u001b[0m\n",
"\u001b[33mTruncated 49800 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Here is the Python code to plot and save a graph of x^2 from -10 to 10:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate values for x from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Calculate y values for x^2\n",
"y = x ** 2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"plt.xlabel('x')\n",
"plt.ylabel('y = x^2')\n",
"plt.title('Graph of y = x^2')\n",
"\n",
"# Save the graph as a PNG file\n",
"plt.savefig('graph.png')\n",
"\n",
"# Close the plot\n",
"plt.close()\n",
"\n",
"print('Graph saved as graph.png')\n",
"```\n",
"\n",
"Please make sure you have the `matplotlib` library installed before running this code. You can install it by running `pip install matplotlib` in your terminal.\n",
"\n",
"After executing the code, a graph of y = x^2 will be saved as `graph.png` in your current working directory.\n",
"\n",
"Let me know if you need any further assistance!\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Graph saved as graph.png\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1993 messages.\u001b[0m\n",
"\u001b[33mTruncated 49850 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and the graph has been saved as \"graph.png\". You can now view the graph to see the plot of x^2 from -10 to 10. If you have any more questions or need further assistance, feel free to ask.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1995 messages.\u001b[0m\n",
"\u001b[33mTruncated 49900 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant_base = autogen.AssistantAgent(\n",
" \"assistant\",\n",
" llm_config={\n",
" \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
" },\n",
")\n",
"\n",
"assistant_with_context_handling = autogen.AssistantAgent(\n",
" \"assistant\",\n",
" llm_config={\n",
" \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
" },\n",
")\n",
"# suppose this capability is not available\n",
"manage_chat_history = context_handling.TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
"manage_chat_history.add_to_agent(assistant_with_context_handling)\n",
"\n",
"user_proxy = autogen.UserProxyAgent(\n",
" \"user_proxy\",\n",
" human_input_mode=\"NEVER\",\n",
" is_termination_msg=lambda x: \"TERMINATE\" in x.get(\"content\", \"\"),\n",
" code_execution_config={\n",
" \"work_dir\": \"coding\",\n",
" \"use_docker\": False,\n",
" },\n",
" max_consecutive_auto_reply=2,\n",
")\n",
"\n",
"# suppose the chat history is large\n",
"# Create a very long chat history that is bound to cause a crash\n",
"# for gpt 3.5\n",
"long_history = []\n",
"for i in range(1000):\n",
" # define a fake, very long message\n",
" assistant_msg = {\"role\": \"assistant\", \"content\": \"test \" * 1000}\n",
" user_msg = {\"role\": \"user\", \"content\": \"\"}\n",
"\n",
" assistant_base.send(assistant_msg, user_proxy, request_reply=False, silent=True)\n",
" assistant_with_context_handling.send(assistant_msg, user_proxy, request_reply=False, silent=True)\n",
" user_proxy.send(user_msg, assistant_base, request_reply=False, silent=True)\n",
" user_proxy.send(user_msg, assistant_with_context_handling, request_reply=False, silent=True)\n",
"\n",
"try:\n",
" user_proxy.initiate_chat(assistant_base, message=\"plot and save a graph of x^2 from -10 to 10\", clear_history=False)\n",
"except Exception as e:\n",
" print(\"Encountered an error with the base assistant\")\n",
" print(e)\n",
" print(\"\\n\\n\")\n",
"\n",
"try:\n",
" user_proxy.initiate_chat(\n",
" assistant_with_context_handling, message=\"plot and save a graph of x^2 from -10 to 10\", clear_history=False\n",
" )\n",
"except Exception as e:\n",
" print(e)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,107 @@
import pytest
import os
import sys
import autogen
from autogen import token_count_utils
from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory
from autogen import AssistantAgent, UserProxyAgent

# from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
# Make the shared test conftest (one directory up) importable.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai  # noqa: E402

# Make test_assistant_agent (two directories up) importable for its config constants.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC  # noqa: E402

# Skip the OpenAI-backed test when the openai package is missing, or when the
# test run explicitly requested --skip-openai (propagated via conftest).
try:
    from openai import OpenAI
except ImportError:
    skip = True
else:
    skip = False or skip_openai
def test_transform_chat_history():
    """
    Test the TransformChatHistory capability.

    Exercises _transform_messages (and, indirectly, truncate_str_to_tokens)
    with each of the three truncation strategies applied in isolation:
    max_messages, max_tokens_per_message, and max_tokens.
    """
    chat = [
        {"role": "system", "content": "System message"},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "This is another test message"},
    ]

    # max_messages: only the system message plus the most recent message survive.
    capability = TransformChatHistory(max_messages=1)
    result = capability._transform_messages(chat)
    assert len(result) == 2  # System message and the last message

    # max_tokens_per_message: each non-system message is capped at 5 tokens.
    capability = TransformChatHistory(max_tokens_per_message=5)
    result = capability._transform_messages(chat)
    for entry in result:
        if entry["role"] != "system":
            assert token_count_utils.count_token(entry["content"]) <= 5

    # max_tokens: the combined non-system token count stays within the budget.
    capability = TransformChatHistory(max_tokens=5)
    result = capability._transform_messages(chat)
    consumed = sum(
        token_count_utils.count_token(entry["content"]) for entry in result if entry["role"] != "system"
    )
    assert consumed <= 5
@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
def test_transform_chat_history_with_agents():
    """
    End-to-end check of the capability via add_to_agent: a gpt-3.5 agent with
    TransformChatHistory attached should survive a chat history long enough to
    overflow the model's context window without crashing.
    """
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        KEY_LOC,
        filter_dict={
            "model": "gpt-3.5-turbo",
        },
    )
    assistant = AssistantAgent("assistant", llm_config={"config_list": config_list}, max_consecutive_auto_reply=1)
    capability = TransformChatHistory(max_messages=10, max_tokens_per_message=5, max_tokens=1000)
    capability.add_to_agent(assistant)

    user = UserProxyAgent(
        "user",
        code_execution_config={"work_dir": "coding"},
        human_input_mode="NEVER",
        is_termination_msg=lambda x: "TERMINATE" in x.get("content", ""),
        max_consecutive_auto_reply=1,
    )

    # Flood the shared history with enough filler messages that an unprotected
    # gpt-3.5 agent would exceed its context window.
    for _ in range(1000):
        assistant_msg = {"role": "assistant", "content": "test " * 1000}
        user_msg = {"role": "user", "content": ""}

        assistant.send(assistant_msg, user, request_reply=False)
        user.send(user_msg, assistant, request_reply=False)

    try:
        user.initiate_chat(
            assistant, message="Plot a chart of nvidia and tesla stock prices for the last 5 years", clear_history=False
        )
    except Exception as e:
        assert False, f"Chat initiation failed with error {str(e)}"
# Allow running this module directly, outside of pytest.
if __name__ == "__main__":
    test_transform_chat_history()
    test_transform_chat_history_with_agents()