Text Compression Transform (#2225)

* adds implementation

* handles optional import

* cleanup

* updates github workflows

* skip test if dependencies not installed

* skip test if dependencies not installed

* use cpu

* skip openai

* unskip openai

* adds protocol

* better docstr

* minor fixes

* updates optional dependencies docs

* wip

* update docstrings

* wip

* adds back llmlingua requirement

* finalized protocol

* improve docstr

* guide complete

* improve docstr

* fix FAQ

* added cache support

* improve cache key

* cache key fix + faq fix

* improve docs

* improve guide

* args -> params

* spelling
This commit is contained in:
Wael Karkoub 2024-05-06 15:16:49 +01:00 committed by GitHub
parent 5a3a8a5541
commit 372ac1e794
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 503 additions and 33 deletions

View File

@ -400,7 +400,7 @@ jobs:
pip install pytest-cov>=5
- name: Install packages and dependencies for Transform Messages
run: |
pip install -e .
pip install -e '.[long-context]'
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |

View File

@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, Protocol
IMPORT_ERROR: Optional[Exception] = None
try:
import llmlingua
except ImportError:
IMPORT_ERROR = ImportError(
"LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
)
PromptCompressor = object
else:
from llmlingua import PromptCompressor
class TextCompressor(Protocol):
"""Defines a protocol for text compression to optimize agent interactions."""
def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
"""This method takes a string as input and returns a dictionary containing the compressed text and other
relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
"""
...
class LLMLingua:
"""Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
and the specific configurations used for the PromptCompressor.
"""
def __init__(
self,
prompt_compressor_kwargs: Dict = dict(
model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2=True,
device_map="cpu",
),
structured_compression: bool = False,
) -> None:
"""
Args:
prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2 set to True, and device_map set to "cpu".
structured_compression (bool): A flag indicating whether to use structured compression. If True, the
structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
is used. Defaults to False.
dictionary.
Raises:
ImportError: If the llmlingua library is not installed.
"""
if IMPORT_ERROR:
raise IMPORT_ERROR
self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)
assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
self._compression_method = (
self._prompt_compressor.structured_compress_prompt
if structured_compression
else self._prompt_compressor.compress_prompt
)
def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
return self._compression_method([text], **compression_params)

View File

@ -1,4 +1,5 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
@ -6,6 +7,9 @@ import tiktoken
from termcolor import colored
from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from .text_compressors import LLMLingua, TextCompressor
class MessageTransform(Protocol):
@ -156,7 +160,7 @@ class MessageTokenLimiter:
assert self._min_tokens is not None
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not self._are_min_tokens_reached(messages):
if not _min_tokens_reached(messages, self._min_tokens):
return messages
temp_messages = copy.deepcopy(messages)
@ -205,19 +209,6 @@ class MessageTokenLimiter:
return logs_str, True
return "No tokens were truncated.", False
def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
"""
Returns True if no minimum tokens restrictions are applied.
Either if the total number of tokens in the messages is greater than or equal to the `min_theshold_tokens`,
or no minimum tokens threshold is set.
"""
if not self._min_tokens:
return True
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= self._min_tokens
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
@ -268,7 +259,7 @@ class MessageTokenLimiter:
return max_tokens if max_tokens is not None else sys.maxsize
def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
if min_tokens is None:
return 0
if min_tokens < 0:
@ -278,6 +269,154 @@ class MessageTokenLimiter:
return min_tokens
class TextMessageCompressor:
"""A transform for compressing text messages in a conversation history.
It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
processing and response generation by downstream models.
"""
def __init__(
self,
text_compressor: Optional[TextCompressor] = None,
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
):
"""
Args:
text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
protocol. If None, it defaults to LLMLingua.
min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
than or equal to 0 if not None. If None, no threshold-based compression is applied.
compression_args (dict): A dictionary of arguments for the compression method. Defaults to an empty
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
"""
if text_compressor is None:
text_compressor = LLMLingua()
self._validate_min_tokens(min_tokens)
self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._cache = cache
# Optimizing savings calculations to optimize log generation
self._recent_tokens_savings = 0
def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies compression to messages in a conversation history based on the specified configuration.
The function processes each message according to the `compression_args` and `min_tokens` settings, applying
the specified compression configuration and returning a new list of messages with reduced token counts
where possible.
Args:
messages (List[Dict]): A list of message dictionaries to be compressed.
Returns:
List[Dict]: A list of dictionaries with the message content compressed according to the configured
method and scope.
"""
# Make sure there is at least one message
if not messages:
return messages
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
return messages
total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
continue
if _is_content_text_empty(message["content"]):
continue
cached_content = self._cache_get(message["content"])
if cached_content is not None:
savings, compressed_content = cached_content
else:
savings, compressed_content = self._compress(message["content"])
self._cache_set(message["content"], compressed_content, savings)
message["content"] = compressed_content
total_savings += savings
self._recent_tokens_savings = total_savings
return processed_messages
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
if self._recent_tokens_savings > 0:
return f"{self._recent_tokens_savings} tokens saved with text compression.", True
else:
return "No tokens saved with text compression.", False
def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content
def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
tokens_saved += savings
return tokens_saved, content
def _compress_text(self, text: str) -> Tuple[int, str]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
savings = 0
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
return savings, compressed_text["compressed_prompt"]
def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
if self._cache:
cached_value = self._cache.get(self._cache_key(content))
if cached_value:
return cached_value
def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
self._cache.set(self._cache_key(content), value)
def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"
def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")
def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False

View File

@ -79,6 +79,7 @@ extra_require = {
"websockets": ["websockets>=12.0,<13"],
"jupyter-executor": jupyter_executor,
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
}
setuptools.setup(

View File

@ -1,5 +1,6 @@
import copy
from typing import Dict, List
from unittest.mock import MagicMock, patch
import pytest
@ -118,13 +119,82 @@ def test_message_token_limiter_get_logs(message_token_limiter, messages, expecte
assert logs_str == expected_logs
def test_text_compression():
"""Test the TextMessageCompressor transform."""
try:
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
text_compressor = TextMessageCompressor()
except ImportError:
pytest.skip("LLM Lingua is not installed.")
text = "Run this test with a long string. "
messages = [
{
"role": "assistant",
"content": [{"type": "text", "text": "".join([text] * 3)}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "".join([text] * 3)}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "".join([text] * 3)}],
},
]
transformed_messages = text_compressor.apply_transform([{"content": text}])
assert len(transformed_messages[0]["content"]) < len(text)
# Test compressing all messages
text_compressor = TextMessageCompressor()
transformed_messages = text_compressor.apply_transform(copy.deepcopy(messages))
for message in transformed_messages:
assert len(message["content"][0]["text"]) < len(messages[0]["content"][0]["text"])
def test_text_compression_cache():
try:
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
except ImportError:
pytest.skip("LLM Lingua is not installed.")
messages = get_long_messages()
mock_compressed_content = (1, {"content": "mock"})
with patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_get",
MagicMock(return_value=(1, {"content": "mock"})),
) as mocked_get, patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
) as mocked_set:
text_compressor = TextMessageCompressor()
text_compressor.apply_transform(messages)
text_compressor.apply_transform(messages)
assert mocked_get.call_count == len(messages)
assert mocked_set.call_count == len(messages)
# We already populated the cache with the mock content
# We need to test if we retrieve the correct content
text_compressor = TextMessageCompressor()
compressed_messages = text_compressor.apply_transform(messages)
for message in compressed_messages:
assert message["content"] == mock_compressed_content[1]
if __name__ == "__main__":
long_messages = get_long_messages()
short_messages = get_short_messages()
no_content_messages = get_no_content_messages()
message_history_limiter = MessageHistoryLimiter(max_messages=3)
message_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
message_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)
msg_history_limiter = MessageHistoryLimiter(max_messages=3)
msg_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)
# Test Parameters
message_history_limiter_apply_transform_parameters = {
@ -170,14 +240,14 @@ if __name__ == "__main__":
message_history_limiter_apply_transform_parameters["messages"],
message_history_limiter_apply_transform_parameters["expected_messages_len"],
):
test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len)
test_message_history_limiter_apply_transform(msg_history_limiter, messages, expected_messages_len)
for messages, expected_logs, expected_effect in zip(
message_history_limiter_get_logs_parameters["messages"],
message_history_limiter_get_logs_parameters["expected_logs"],
message_history_limiter_get_logs_parameters["expected_effect"],
):
test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect)
test_message_history_limiter_get_logs(msg_history_limiter, messages, expected_logs, expected_effect)
# Call the MessageTokenLimiter tests
@ -187,7 +257,7 @@ if __name__ == "__main__":
message_token_limiter_apply_transform_parameters["expected_messages_len"],
):
test_message_token_limiter_apply_transform(
message_token_limiter, messages, expected_token_count, expected_messages_len
msg_token_limiter, messages, expected_token_count, expected_messages_len
)
for messages, expected_token_count, expected_messages_len in zip(
@ -196,7 +266,7 @@ if __name__ == "__main__":
message_token_limiter_with_threshold_apply_transform_parameters["expected_messages_len"],
):
test_message_token_limiter_with_threshold_apply_transform(
message_token_limiter_with_threshold, messages, expected_token_count, expected_messages_len
msg_token_limiter_with_threshold, messages, expected_token_count, expected_messages_len
)
for messages, expected_logs, expected_effect in zip(
@ -204,4 +274,4 @@ if __name__ == "__main__":
message_token_limiter_get_logs_parameters["expected_logs"],
message_token_limiter_get_logs_parameters["expected_effect"],
):
test_message_token_limiter_get_logs(message_token_limiter, messages, expected_logs, expected_effect)
test_message_token_limiter_get_logs(msg_token_limiter, messages, expected_logs, expected_effect)

View File

@ -267,7 +267,7 @@ Migrating enhances flexibility, modularity, and customization in handling chat m
### How to migrate?
To ensure a smooth migration process, simply follow the detailed guide provided in [Handling Long Context Conversations with Transform Messages](/docs/topics/long_contexts.md).
To ensure a smooth migration process, simply follow the detailed guide provided in [Introduction to TransformMessages](/docs/topics/handling_long_contexts/intro_to_transform_messages.md).
### What should I do if I get the error "TypeError: Assistants.create() got an unexpected keyword argument 'file_ids'"?

View File

@ -85,7 +85,7 @@ To use Teachability, please install AutoGen with the [teachable] option.
pip install "pyautogen[teachable]"
```
Example notebook: [Chatting with a teachable agent](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
Example notebook: [Chatting with a teachable agent](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
## Large Multimodal Model (LMM) Agents
@ -115,9 +115,16 @@ Example notebooks:
To use a graph in `GroupChat`, particularly for graph visualization, please install AutoGen with the [graph] option.
```bash
pip install "pyautogen[graph]"
```
Example notebook: [Graph Modeling Language with using select_speaker](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_graph_modelling_language_using_select_speaker.ipynb)
Example notebook: [Graph Modeling Language with using select_speaker](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_graph_modelling_language_using_select_speaker.ipynb)
## Long Context Handling
AutoGen includes support for handling long textual contexts by leveraging the LLMLingua library for text compression. To enable this functionality, please install AutoGen with the `[long-context]` option:
```bash
pip install "pyautogen[long-context]"
```

View File

@ -0,0 +1,4 @@
{
"label": "Handling Long Contexts",
"collapsible": true
}

View File

@ -0,0 +1,171 @@
# Compressing Text with LLMLingua
Text compression is crucial for optimizing interactions with LLMs, especially when dealing with long prompts that can lead to higher costs and slower response times. LLMLingua is a tool designed to compress prompts effectively, enhancing the efficiency and cost-effectiveness of LLM operations.
This guide introduces LLMLingua's integration with AutoGen, demonstrating how to use this tool to compress text, thereby optimizing the usage of LLMs for various applications.
:::info Requirements
Install `pyautogen[long-context]` and `PyMuPDF`:
```bash
pip install "pyautogen[long-context]" PyMuPDF
```
For more information, please refer to the [installation guide](/docs/installation/).
:::
## Example 1: Compressing AutoGen Research Paper using LLMLingua
We will look at how we can use `TextMessageCompressor` to compress an AutoGen research paper using `LLMLingua`. Here's how you can initialize `TextMessageCompressor` with LLMLingua, a text compressor that adheres to the `TextCompressor` protocol.
```python
import tempfile
import fitz # PyMuPDF
import requests
from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
AUTOGEN_PAPER = "https://arxiv.org/pdf/2308.08155"
def extract_text_from_pdf():
# Download the PDF
response = requests.get(AUTOGEN_PAPER)
response.raise_for_status() # Ensure the download was successful
text = ""
# Save the PDF to a temporary file
with tempfile.TemporaryDirectory() as temp_dir:
with open(temp_dir + "temp.pdf", "wb") as f:
f.write(response.content)
# Open the PDF
with fitz.open(temp_dir + "temp.pdf") as doc:
# Read and extract text from each page
for page in doc:
text += page.get_text()
return text
# Example usage
pdf_text = extract_text_from_pdf()
llm_lingua = LLMLingua()
text_compressor = TextMessageCompressor(text_compressor=llm_lingua)
compressed_text = text_compressor.apply_transform([{"content": pdf_text}])
print(text_compressor.get_logs([], []))
```
```console
('19765 tokens saved with text compression.', True)
```
## Example 2: Integrating LLMLingua with `ConversableAgent`
Now, let's integrate `LLMLingua` into a conversational agent within AutoGen. This allows dynamic compression of prompts before they are sent to the LLM.
```python
import os
import autogen
from autogen.agentchat.contrib.capabilities import transform_messages
system_message = "You are a world class researcher."
config_list = [{"model": "gpt-4-turbo", "api_key": os.getenv("OPENAI_API_KEY")}]
# Define your agent; the user proxy and an assistant
researcher = autogen.ConversableAgent(
"assistant",
llm_config={"config_list": config_list},
max_consecutive_auto_reply=1,
system_message=system_message,
human_input_mode="NEVER",
)
user_proxy = autogen.UserProxyAgent(
"user_proxy",
human_input_mode="NEVER",
is_termination_msg=lambda x: "TERMINATE" in x.get("content", ""),
max_consecutive_auto_reply=1,
)
```
:::tip
Learn more about configuring LLMs for agents [here](/docs/topics/llm_configuration).
:::
```python
context_handling = transform_messages.TransformMessages(transforms=[text_compressor])
context_handling.add_to_agent(researcher)
message = "Summarize this research paper for me, include the important information" + pdf_text
result = user_proxy.initiate_chat(recipient=researcher, clear_history=True, message=message, silent=True)
print(result.chat_history[1]["content"])
```
```console
19953 tokens saved with text compression.
The paper describes AutoGen, a framework designed to facilitate the development of diverse large language model (LLM) applications through conversational multi-agent systems. The framework emphasizes customization and flexibility, enabling developers to define agent interaction behaviors in natural language or computer code.
Key components of AutoGen include:
1. **Conversable Agents**: These are customizable agents designed to operate autonomously or through human interaction. They are capable of initiating, maintaining, and responding within conversations, contributing effectively to multi-agent dialogues.
2. **Conversation Programming**: AutoGen introduces a programming paradigm centered around conversational interactions among agents. This approach simplifies the development of complex applications by streamlining how agents communicate and interact, focusing on conversational logic rather than traditional coding for
mats.
3. **Agent Customization and Flexibility**: Developers have the freedom to define the capabilities and behaviors of agents within the system, allowing for a wide range of applications across different domains.
4. **Application Versatility**: The paper outlines various use cases from mathematics and coding to decision-making and entertainment, demonstrating AutoGen's ability to cope with a broad spectrum of complexities and requirements.
5. **Hierarchical and Joint Chat Capabilities**: The system supports complex conversation patterns including hierarchical and multi-agent interactions, facilitating robust dialogues that can dynamically adjust based on the conversation context and the agents' roles.
6. **Open-source and Community Engagement**: AutoGen is presented as an open-source framework, inviting contributions and adaptations from the global development community to expand its capabilities and applications.
The framework's architecture is designed so that it can be seamlessly integrated into existing systems, providing a robust foundation for developing sophisticated multi-agent applications that leverage the capabilities of modern LLMs. The paper also discusses potential ethical considerations and future improvements, highlighting the importance of continual development in response to evolving tech landscapes and user needs.
```
## Example 3: Modifying LLMLingua's Compression Parameters
LLMLingua's flexibility allows for various configurations, such as customizing instructions for the LLM or setting specific token counts for compression. This example demonstrates how to set a target token count, enabling the use of models with smaller context sizes like gpt-3.5.
```python
config_list = [{"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")}]
researcher = autogen.ConversableAgent(
"assistant",
llm_config={"config_list": config_list},
max_consecutive_auto_reply=1,
system_message=system_message,
human_input_mode="NEVER",
)
text_compressor = TextMessageCompressor(
text_compressor=llm_lingua,
compression_params={"target_token": 13000},
cache=None,
)
context_handling = transform_messages.TransformMessages(transforms=[text_compressor])
context_handling.add_to_agent(researcher)
compressed_text = text_compressor.apply_transform([{"content": message}])
result = user_proxy.initiate_chat(recipient=researcher, clear_history=True, message=message, silent=True)
print(result.chat_history[1]["content"])
```
```console
25308 tokens saved with text compression.
Based on the extensive research paper information provided, it seems that the focus is on developing a framework called AutoGen for creating multi-agent conversations based on Large Language Models (LLMs) for a variety of applications such as math problem solving, coding, decision-making, and more.
The paper discusses the importance of incorporating diverse roles of LLMs, human inputs, and tools to enhance the capabilities of the conversable agents within the AutoGen framework. It also delves into the effectiveness of different systems in various scenarios, showcases the implementation of AutoGen in pilot studies, and compares its performance with other systems in tasks like math problem-solving, coding, and decision-making.
The paper also highlights the different features and components of AutoGen such as the AssistantAgent, UserProxyAgent, ExecutorAgent, and GroupChatManager, emphasizing its flexibility, ease of use, and modularity in managing multi-agent interactions. It presents case analyses to demonstrate the effectiveness of AutoGen in various applications and scenarios.
Furthermore, the paper includes manual evaluations, scenario testing, code examples, and detailed comparisons with other systems like ChatGPT, OptiGuide, MetaGPT, and more, to showcase the performance and capabilities of the AutoGen framework.
Overall, the research paper showcases the potential of AutoGen in facilitating dynamic multi-agent conversations, enhancing decision-making processes, and improving problem-solving tasks with the integration of LLMs, human inputs, and tools in a collaborative framework.
```

View File

@ -1,4 +1,4 @@
# Handling Long Context Conversations with Transform Messages
# Introduction to Transform Messages
Why do we need to handle long contexts? The problem arises from several constraints and requirements:
@ -14,6 +14,7 @@ The `TransformMessages` capability is designed to modify incoming messages befor
:::info Requirements
Install `pyautogen`:
```bash
pip install pyautogen
```
@ -99,9 +100,9 @@ pprint.pprint(processed_short_messages)
```console
[{'content': 'hello there, how are you?', 'role': 'user'},
{'content': [{'text': 'hello', 'type': 'text'}], 'role': 'assistant'}]
```
```
We can see that no transformation was applied, because the threshold of 10 total tokens was not reached.
We can see that no transformation was applied, because the threshold of 10 total tokens was not reached.
### Apply Transformations Using Agents
@ -318,7 +319,7 @@ result = user_proxy.initiate_chat(
```
````console
```console
user_proxy (to assistant):
What are the two API keys that I just provided
@ -340,4 +341,4 @@ user_proxy (to assistant):
--------------------------------------------------------------------------------
Redacted 2 OpenAI API keys.
````
```