mirror of https://github.com/microsoft/autogen.git
* Add token counting to chat completion client
* fix mypy
* ignore pyright for object type
* format
This commit is contained in:
parent c425a447a7
commit 2041905acb

@@ -21,7 +21,8 @@ dependencies = [
     "pydantic>=1.10,<3",
     "types-aiofiles",
     "grpcio",
-    "protobuf"
+    "protobuf",
+    "tiktoken"
 ]

 [tool.hatch.envs.default]

@@ -30,6 +31,7 @@ dependencies = [
     "pyright==1.1.368",
     "mypy==1.10.0",
     "ruff==0.4.8",
+    "tiktoken",
     "types-Pillow",
     "polars",
     "chess",

@@ -48,5 +48,9 @@ class ChatCompletionClient(Protocol):

     def total_usage(self) -> RequestUsage: ...

+    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+
+    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+
     @property
     def capabilities(self) -> ModelCapabilities: ...

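For orientation (not part of the diff): because ChatCompletionClient is a Protocol, callers can budget prompts against any implementation. A minimal sketch, where fits_in_context and its 512-token reserve are hypothetical, not library API:

    # Hedged sketch: a hypothetical helper typed against the protocol.
    from typing import Sequence

    def fits_in_context(client: ChatCompletionClient, messages: Sequence[LLMMessage], reserve: int = 512) -> bool:
        # Keep `reserve` tokens free for the model's reply.
        return client.remaining_tokens(messages) >= reserve
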
@@ -77,6 +77,21 @@ _MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
     },
 }

+_MODEL_TOKEN_LIMITS: Dict[str, int] = {
+    "gpt-4o-2024-05-13": 128000,
+    "gpt-4-turbo-2024-04-09": 128000,
+    "gpt-4-0125-preview": 128000,
+    "gpt-4-1106-preview": 128000,
+    "gpt-4-1106-vision-preview": 128000,
+    "gpt-4-0613": 8192,
+    "gpt-4-32k-0613": 32768,
+    "gpt-3.5-turbo-0125": 16385,
+    "gpt-3.5-turbo-1106": 16385,
+    "gpt-3.5-turbo-instruct": 4096,
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k-0613": 16385,
+}
+

 def resolve_model(model: str) -> str:
     if model in _MODEL_POINTERS:

@@ -87,3 +102,8 @@ def resolve_model(model: str) -> str:
 def get_capabilties(model: str) -> ModelCapabilities:
     resolved_model = resolve_model(model)
     return _MODEL_CAPABILITIES[resolved_model]
+
+
+def get_token_limit(model: str) -> int:
+    resolved_model = resolve_model(model)
+    return _MODEL_TOKEN_LIMITS[resolved_model]

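A short sketch of the lookup these helpers enable (assuming, as the diff suggests, that _MODEL_POINTERS maps the "gpt-4o" alias to "gpt-4o-2024-05-13"):

    # resolve_model follows _MODEL_POINTERS to a dated snapshot, then
    # get_token_limit reads the new _MODEL_TOKEN_LIMITS table.
    assert resolve_model("gpt-4o") == "gpt-4o-2024-05-13"  # assumed alias mapping
    assert get_token_limit("gpt-4o") == 128000
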
@@ -1,4 +1,5 @@
 import inspect
+import json
 import logging
 import re
 import warnings

@@ -15,6 +16,7 @@ from typing import (
     cast,
 )

+import tiktoken
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam,

@@ -32,7 +34,7 @@ from openai.types.chat import (
 from openai.types.shared_params import FunctionDefinition, FunctionParameters
 from typing_extensions import Unpack

-from ...application.logging import EVENT_LOGGER_NAME
+from ...application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
 from ...application.logging.events import LLMCallEvent
 from .. import (
     FunctionCall,

@@ -53,6 +55,7 @@ from ._types import (
 from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration

 logger = logging.getLogger(EVENT_LOGGER_NAME)
+trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

 openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
 aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)

@@ -518,6 +521,77 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

+    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        model = self._create_args["model"]
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
+            encoding = tiktoken.get_encoding("cl100k_base")
+        tokens_per_message = 3
+        tokens_per_name = 1
+        num_tokens = 0
+
+        # Message tokens.
+        for message in messages:
+            num_tokens += tokens_per_message
+            oai_message = to_oai_type(message)
+            for oai_message_part in oai_message:
+                for key, value in oai_message_part.items():
+                    if value is None:
+                        continue
+                    if not isinstance(value, str):
+                        try:
+                            value = json.dumps(value)
+                        except TypeError:
+                            trace_logger.warning(f"Could not convert {value} to string, skipping.")
+                            continue
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":
+                        num_tokens += tokens_per_name
+        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+        # Tool tokens.
+        oai_tools = convert_tools(tools)
+        for tool in oai_tools:
+            function = tool["function"]
+            tool_tokens = len(encoding.encode(function["name"]))
+            if "description" in function:
+                tool_tokens += len(encoding.encode(function["description"]))
+            tool_tokens -= 2
+            if "parameters" in function:
+                parameters = function["parameters"]
+                if "properties" in parameters:
+                    assert isinstance(parameters["properties"], dict)
+                    for propertiesKey in parameters["properties"]:  # pyright: ignore
+                        assert isinstance(propertiesKey, str)
+                        tool_tokens += len(encoding.encode(propertiesKey))
+                        v = parameters["properties"][propertiesKey]  # pyright: ignore
+                        for field in v:  # pyright: ignore
+                            if field == "type":
+                                tool_tokens += 2
+                                tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
+                            elif field == "description":
+                                tool_tokens += 2
+                                tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
+                            elif field == "enum":
+                                tool_tokens -= 3
+                                for o in v["enum"]:  # pyright: ignore
+                                    tool_tokens += 3
+                                    tool_tokens += len(encoding.encode(o))  # pyright: ignore
+                            else:
+                                trace_logger.warning(f"Not supported field {field}")
+                    tool_tokens += 11
+                    if len(parameters["properties"]) == 0:  # pyright: ignore
+                        tool_tokens -= 2
+            num_tokens += tool_tokens
+        num_tokens += 12
+        return num_tokens
+
+    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        token_limit = _model_info.get_token_limit(self._create_args["model"])
+        return token_limit - self.count_tokens(messages, tools)
+
     @property
     def capabilities(self) -> ModelCapabilities:
         return self._model_capabilities

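The constants in count_tokens follow the widely circulated OpenAI cookbook heuristic: about 3 tokens of overhead per message, 1 extra when a "name" key is present, 3 tokens priming the assistant reply, plus empirical per-field offsets for each tool schema. A standalone sketch of the message-side arithmetic, using cl100k_base directly:

    # Illustrative only; mirrors the constants in count_tokens above.
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    msgs = [{"role": "user", "content": "Hello"}]

    n = 0
    for m in msgs:
        n += 3  # tokens_per_message overhead
        for key, value in m.items():
            n += len(enc.encode(value))
            if key == "name":
                n += 1  # tokens_per_name
    n += 3  # reply is primed with <|start|>assistant<|message|>
    # "user" and "Hello" each encode to a single token under cl100k_base,
    # so n == 3 + 1 + 1 + 3 == 8 for this one-message prompt.
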
@@ -0,0 +1,66 @@
+from typing import List
+
+import pytest
+from agnext.components import Image
+from agnext.components.models import (
+    AssistantMessage,
+    AzureOpenAIChatCompletionClient,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    OpenAIChatCompletionClient,
+    SystemMessage,
+    UserMessage,
+)
+from agnext.components.tools import FunctionTool
+
+
+@pytest.mark.asyncio
+async def test_openai_chat_completion_client() -> None:
+    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
+    assert client
+
+
+@pytest.mark.asyncio
+async def test_azure_openai_chat_completion_client() -> None:
+    client = AzureOpenAIChatCompletionClient(
+        model="gpt-4o",
+        api_key="api_key",
+        api_version="2020-08-04",
+        azure_endpoint="https://dummy.com",
+        model_capabilities={"vision": True, "function_calling": True, "json_output": True},
+    )
+    assert client
+
+
+@pytest.mark.asyncio
+async def test_openai_chat_completion_client_count_tokens() -> None:
+    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
+    messages: List[LLMMessage] = [
+        SystemMessage(content="Hello"),
+        UserMessage(content="Hello", source="user"),
+        AssistantMessage(content="Hello", source="assistant"),
+        UserMessage(
+            content=[
+                "str1",
+                Image.from_base64(
+                    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
+                ),
+            ],
+            source="user",
+        ),
+        FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1")]),
+    ]
+
+    def tool1(test: str, test2: str) -> str:
+        return test + test2
+
+    def tool2(test1: int, test2: List[int]) -> str:
+        return str(test1) + str(test2)
+
+    tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
+    num_tokens = client.count_tokens(messages, tools=tools)
+    assert num_tokens
+
+    remaining_tokens = client.remaining_tokens(messages, tools=tools)
+    assert remaining_tokens

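Worth noting: the count-token test runs fully offline; the dummy api_key is never used, because count_tokens and remaining_tokens only consult tiktoken and the static _MODEL_TOKEN_LIMITS table. One plausible extra assertion (an assumption, not part of the commit) ties the two results together, provided "gpt-4o" resolves to a 128000-token snapshot:

    # Hypothetical follow-up check: remaining = limit - counted.
    assert remaining_tokens == 128000 - num_tokens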