This commit is contained in:
Wael Karkoub 2024-04-12 06:03:32 +01:00 committed by GitHub
parent 689950e58e
commit 78cb908f95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 1 deletions

View File

@ -126,9 +126,16 @@ class MessageTokenLimiter:
processed_messages_tokens = 0
# calculate tokens for all messages
total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)
total_tokens = sum(
_count_tokens(msg["content"]) for msg in temp_messages if isinstance(msg.get("content"), (str, list))
)
for msg in reversed(temp_messages):
# Some messages may not have content.
if not isinstance(msg.get("content"), (str, list)):
processed_messages.insert(0, msg)
continue
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
# If adding this message would exceed the token limit, truncate the last message to meet the total token

View File

@ -76,6 +76,29 @@ def test_limit_token_transform():
assert len(transformed_messages) <= len(messages)
def test_limit_token_transform_without_content():
"""Test the TokenLimitTransform with messages that don't have content."""
messages = [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
# check if token limit per message works nicely with total token limit.
token_limit_transform = MessageTokenLimiter(max_tokens=10, max_tokens_per_message=5)
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
assert len(transformed_messages) == len(messages)
def test_limit_token_transform_total_token_count():
"""Tests if the TokenLimitTransform truncates without dropping messages."""
messages = [{"role": "very very very very very"}]
token_limit_transform = MessageTokenLimiter(max_tokens=1)
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
assert len(transformed_messages) == 1
def test_max_message_history_length_transform():
"""
Test the MessageHistoryLimiter capability to limit the number of messages.