Add token counting to chat completion client #220 (#239)

* Add token counting to chat completion client

* fix mypy

* ignore pyright for object type

* format
This commit is contained in:
Eric Zhu 2024-07-19 18:44:22 -07:00 committed by GitHub
parent c425a447a7
commit 2041905acb
5 changed files with 168 additions and 2 deletions

View File

@@ -21,7 +21,8 @@ dependencies = [
"pydantic>=1.10,<3",
"types-aiofiles",
"grpcio",
"protobuf"
"protobuf",
"tiktoken"
]
[tool.hatch.envs.default]
@@ -30,6 +31,7 @@ dependencies = [
"pyright==1.1.368",
"mypy==1.10.0",
"ruff==0.4.8",
"tiktoken",
"types-Pillow",
"polars",
"chess",

View File

@@ -48,5 +48,9 @@ class ChatCompletionClient(Protocol):
def total_usage(self) -> RequestUsage: ...
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
    """Estimate the number of prompt tokens that `messages` plus `tools` would consume (no API call)."""
    ...
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
    """Return the model's token limit minus the estimated tokens of `messages` plus `tools`."""
    ...
@property
def capabilities(self) -> ModelCapabilities:
    """Static capability flags of the underlying model (e.g. vision, function_calling, json_output)."""
    ...

View File

@@ -77,6 +77,21 @@ _MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
},
}
# Maximum context-window sizes (in tokens) keyed by *resolved* model name.
# NOTE(review): values appear to follow OpenAI's model documentation at the
# time of writing — re-verify when adding new models.
_MODEL_TOKEN_LIMITS: Dict[str, int] = {
    "gpt-4o-2024-05-13": 128000,
    "gpt-4-turbo-2024-04-09": 128000,
    "gpt-4-0125-preview": 128000,
    "gpt-4-1106-preview": 128000,
    "gpt-4-1106-vision-preview": 128000,
    "gpt-4-0613": 8192,
    "gpt-4-32k-0613": 32768,
    "gpt-3.5-turbo-0125": 16385,
    "gpt-3.5-turbo-1106": 16385,
    "gpt-3.5-turbo-instruct": 4096,
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-16k-0613": 16385,
}
def resolve_model(model: str) -> str:
if model in _MODEL_POINTERS:
@@ -87,3 +102,8 @@ def resolve_model(model: str) -> str:
def get_capabilties(model: str) -> ModelCapabilities:
    """Look up the capability flags for ``model``, resolving aliases first.

    NOTE(review): the function name contains a typo ("capabilties"); it is
    kept as-is because callers already depend on it.
    """
    return _MODEL_CAPABILITIES[resolve_model(model)]
def get_token_limit(model: str) -> int:
    """Return the context-window size (in tokens) for ``model``.

    Aliases are resolved to a concrete model name first. Raises ``KeyError``
    when the resolved model has no entry in ``_MODEL_TOKEN_LIMITS``.
    """
    canonical_name = resolve_model(model)
    return _MODEL_TOKEN_LIMITS[canonical_name]

View File

@@ -1,4 +1,5 @@
import inspect
import json
import logging
import re
import warnings
@@ -15,6 +16,7 @@ from typing import (
cast,
)
import tiktoken
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -32,7 +34,7 @@ from openai.types.chat import (
from openai.types.shared_params import FunctionDefinition, FunctionParameters
from typing_extensions import Unpack
from ...application.logging import EVENT_LOGGER_NAME
from ...application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ...application.logging.events import LLMCallEvent
from .. import (
FunctionCall,
@@ -53,6 +55,7 @@ from ._types import (
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
# Structured event logger (receives LLMCallEvent records).
logger = logging.getLogger(EVENT_LOGGER_NAME)
# Human-readable diagnostics (e.g. token-counting fallback warnings below).
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

# Keyword-only parameter names of the OpenAI/Azure client constructors —
# presumably used to separate client-construction kwargs from create()
# kwargs in user-supplied configuration; confirm against the config-split
# logic elsewhere in this module.
openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)
@@ -518,6 +521,77 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
def total_usage(self) -> RequestUsage:
return self._total_usage
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
    """Estimate the number of prompt tokens `messages` plus `tools` will consume.

    Encodes each message (after conversion to the OpenAI wire format) and
    each tool schema with tiktoken. The per-message/per-field constants
    appear to follow OpenAI's published token-counting recipe, so the
    result is an estimate rather than an exact server-side count.
    Falls back to the cl100k_base encoding when tiktoken does not know
    the configured model.

    NOTE(review): the mutable default `tools=[]` is never mutated here,
    but a tuple default would be safer.
    """
    model = self._create_args["model"]
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown to tiktoken — use the encoding shared by gpt-3.5/gpt-4 families.
        trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    # Fixed per-message overhead and per-"name"-field overhead (OpenAI recipe).
    tokens_per_message = 3
    tokens_per_name = 1
    num_tokens = 0

    # Message tokens.
    for message in messages:
        num_tokens += tokens_per_message
        # Convert to the OpenAI wire format; one LLMMessage may map to
        # several message parts.
        oai_message = to_oai_type(message)
        for oai_message_part in oai_message:
            for key, value in oai_message_part.items():
                if value is None:
                    continue
                if not isinstance(value, str):
                    # Non-string parts (e.g. multimodal content blocks) are
                    # approximated by encoding their JSON serialization.
                    try:
                        value = json.dumps(value)
                    except TypeError:
                        trace_logger.warning(f"Could not convert {value} to string, skipping.")
                        continue
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>

    # Tool tokens.
    oai_tools = convert_tools(tools)
    for tool in oai_tools:
        function = tool["function"]
        tool_tokens = len(encoding.encode(function["name"]))
        if "description" in function:
            tool_tokens += len(encoding.encode(function["description"]))
        # Constant adjustments below mirror community-derived measurements
        # of how the server serializes tool schemas — approximate by nature.
        tool_tokens -= 2
        if "parameters" in function:
            parameters = function["parameters"]
            if "properties" in parameters:
                assert isinstance(parameters["properties"], dict)
                for propertiesKey in parameters["properties"]:  # pyright: ignore
                    assert isinstance(propertiesKey, str)
                    tool_tokens += len(encoding.encode(propertiesKey))
                    v = parameters["properties"][propertiesKey]  # pyright: ignore
                    for field in v:  # pyright: ignore
                        if field == "type":
                            tool_tokens += 2
                            tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
                        elif field == "description":
                            tool_tokens += 2
                            tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
                        elif field == "enum":
                            tool_tokens -= 3
                            for o in v["enum"]:  # pyright: ignore
                                tool_tokens += 3
                                tool_tokens += len(encoding.encode(o))  # pyright: ignore
                        else:
                            # Unsupported JSON-schema fields are ignored (estimate only).
                            trace_logger.warning(f"Not supported field {field}")
                tool_tokens += 11
                if len(parameters["properties"]) == 0:  # pyright: ignore
                    tool_tokens -= 2
        num_tokens += tool_tokens
    num_tokens += 12
    return num_tokens
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
    """Return how many tokens of the model's context window remain after
    accounting for `messages` and `tools` (may be negative if over budget)."""
    limit = _model_info.get_token_limit(self._create_args["model"])
    used = self.count_tokens(messages, tools)
    return limit - used
@property
def capabilities(self) -> ModelCapabilities:
    """Capability flags of the configured model.

    NOTE(review): assumes self._model_capabilities was populated during
    construction — not visible in this chunk.
    """
    return self._model_capabilities

View File

@@ -0,0 +1,66 @@
from typing import List
import pytest
from agnext.components import Image
from agnext.components.models import (
AssistantMessage,
AzureOpenAIChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
OpenAIChatCompletionClient,
SystemMessage,
UserMessage,
)
from agnext.components.tools import FunctionTool
@pytest.mark.asyncio
async def test_openai_chat_completion_client() -> None:
    """Constructing the OpenAI client with minimal arguments succeeds."""
    chat_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
    assert chat_client
@pytest.mark.asyncio
async def test_azure_openai_chat_completion_client() -> None:
    """Constructing the Azure client (which needs endpoint, API version and
    explicit capabilities) succeeds without contacting the service."""
    azure_client = AzureOpenAIChatCompletionClient(
        azure_endpoint="https://dummy.com",
        api_version="2020-08-04",
        api_key="api_key",
        model="gpt-4o",
        model_capabilities={"vision": True, "function_calling": True, "json_output": True},
    )
    assert azure_client
@pytest.mark.asyncio
async def test_openai_chat_completion_client_count_tokens() -> None:
    """count_tokens/remaining_tokens return truthy values for a representative
    mix of message types and tools — no API call is made."""
    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")

    multimodal_content = [
        "str1",
        Image.from_base64(
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
        ),
    ]
    messages: List[LLMMessage] = [
        SystemMessage(content="Hello"),
        UserMessage(content="Hello", source="user"),
        AssistantMessage(content="Hello", source="assistant"),
        UserMessage(content=multimodal_content, source="user"),
        FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1")]),
    ]

    def tool1(test: str, test2: str) -> str:
        return test + test2

    def tool2(test1: int, test2: List[int]) -> str:
        return str(test1) + str(test2)

    tools = [
        FunctionTool(tool1, description="example tool 1"),
        FunctionTool(tool2, description="example tool 2"),
    ]

    assert client.count_tokens(messages, tools=tools)
    assert client.remaining_tokens(messages, tools=tools)