From 8e4254fc20606fd78f9ebbfc5c08bbdd06e3d8b6 Mon Sep 17 00:00:00 2001
From: LangChain4j
Date: Tue, 12 Dec 2023 16:45:16 +0100
Subject: [PATCH] make OpenAI tokenizer more precise (#346)

This PR is a rework of `OpenAiTokenizer`. Added `OpenAiTokenizerIT` with lots
of tests to ensure that `OpenAiTokenizer` calculates token usage very close to
OpenAI. In most cases the calculation matches OpenAI 1:1; in some corner cases
the difference is within 5%.

---
 .../azure/AzureOpenAiStreamingChatModel.java  |   17 +-
 .../AzureOpenAiStreamingChatModelIT.java      |    2 +-
 .../langchain4j/internal/GsonJsonCodec.java   |   28 +-
 .../java/dev/langchain4j/internal/Json.java   |   49 +-
 .../java/dev/langchain4j/model/Tokenizer.java |    6 +-
 .../openai/OpenAiStreamingChatModel.java      |   13 +-
 .../model/openai/OpenAiTokenizer.java         |  276 ++-
 .../model/openai/OpenAiChatModelIT.java       |   25 +-
 .../openai/OpenAiStreamingChatModelIT.java    |   33 +-
 .../model/openai/OpenAiTokenizerIT.java       | 1739 +++++++++++++++++
 .../model/openai/OpenAiTokenizerTest.java     |  118 +-
 .../AzureOpenAiStreamingChatModel.java        |   17 +-
 .../service/StreamingAiServicesIT.java        |   15 +-
 13 files changed, 2093 insertions(+), 245 deletions(-)
 create mode 100644 langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java

diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java
index defd5c556..af8b46590 100644
--- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java
+++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java
@@ -20,6 +20,7 @@
 import java.time.Duration;
 import java.util.List;

 import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNullOrEmpty;
 import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO;
 import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient;
 import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.toFunctions;
@@ -117,7 +118,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel

     @Override
     public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
-        generate(messages, singletonList(toolSpecification), toolSpecification, handler);
+        generate(messages, null, toolSpecification, handler);
     }

     private void generate(List<ChatMessage> messages,
@@ -136,18 +137,18 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel

         Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);

-        if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
+        if (toolThatMustBeExecuted != null) {
+            options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted)));
+            options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
+            if (tokenizer != null) {
+                inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
+            }
+        } else if (!isNullOrEmpty(toolSpecifications)) {
             options.setFunctions(toFunctions(toolSpecifications));
             if (tokenizer != null) {
                 inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
             }
         }
-        if (toolThatMustBeExecuted != null) {
-            options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
-            if (tokenizer != null) {
-                inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
-            }
-        }

         AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);

diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java
index 6f25121af..ec8023f1d 100644
--- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java
+++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModelIT.java
@@ -124,7 +124,7 @@ class AzureOpenAiStreamingChatModelIT {
         assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
         assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");

-        assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(50);
+        assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53);
         assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
         assertThat(response.tokenUsage().totalTokenCount())
                 .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());

diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java
index 8b1db52c5..e092c4ec4 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/GsonJsonCodec.java
@@ -1,23 +1,21 @@
 package dev.langchain4j.internal;

-import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
-import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
-import com.google.gson.Gson;
-import com.google.gson.GsonBuilder;
-import com.google.gson.JsonDeserializer;
-import com.google.gson.JsonPrimitive;
-import com.google.gson.JsonSerializer;
+import com.google.gson.*;
+import com.google.gson.reflect.TypeToken;
 import com.google.gson.stream.JsonWriter;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStreamWriter;
+
+import java.io.*;
+import java.lang.reflect.Type;
 import java.nio.charset.StandardCharsets;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
+import java.util.Map;
+
+import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
+import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;

 class GsonJsonCodec implements Json.JsonCodec {
+
     private static final Gson GSON = new GsonBuilder()
             .setPrettyPrinting()
             .registerTypeAdapter(
@@ -40,6 +38,9 @@ class GsonJsonCodec implements Json.JsonCodec {
             )
             .create();

+    public static final Type MAP_TYPE = new TypeToken<Map<String, Object>>() {
+    }.getType();
+
     @Override
     public String toJson(Object o) {
         return GSON.toJson(o);
@@ -47,6 +48,9 @@ class GsonJsonCodec implements Json.JsonCodec {

     @Override
     public <T> T fromJson(String json, Class<T> type) {
+        if (type == Map.class) {
+            return GSON.fromJson(json, MAP_TYPE);
+        }
         return GSON.fromJson(json, type);
     }

diff --git a/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java b/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java
index 9d2f2bcb4..7678eb366 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/internal/Json.java
@@ -1,44 +1,43 @@
 package dev.langchain4j.internal;

-import dev.langchain4j.spi.json.JsonCodecFactory;
 import dev.langchain4j.spi.ServiceHelper;
+import dev.langchain4j.spi.json.JsonCodecFactory;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.Collection;

 public class Json {

-    private static final JsonCodec CODEC = loadCodec();
+    private static final JsonCodec CODEC = loadCodec();

-    private static JsonCodec loadCodec() {
-        Collection<JsonCodecFactory> factories = ServiceHelper.loadFactories(JsonCodecFactory.class);
-        for (JsonCodecFactory factory : factories) {
-            return factory.create();
+    private static JsonCodec loadCodec() {
+        Collection<JsonCodecFactory> factories = ServiceHelper.loadFactories(JsonCodecFactory.class);
+        for (JsonCodecFactory factory : factories) {
+            return factory.create();
+        }
+        // fallback to default
+        return new GsonJsonCodec();
     }
-        // fallback to default
-        return new GsonJsonCodec();
-    }

+    public static String toJson(Object o) {
+        return CODEC.toJson(o);
+    }

+    public static <T> T fromJson(String json, Class<T> type) {
+        return CODEC.fromJson(json, type);
+    }

-    public static String toJson(Object o) {
-        return CODEC.toJson(o);
-    }

+    public static InputStream toInputStream(Object o, Class<?> type) throws IOException {
+        return CODEC.toInputStream(o, type);
+    }

-    public static <T> T fromJson(String json, Class<T> type) {
-        return CODEC.fromJson(json, type);
-    }

+    public interface JsonCodec {

-    public static InputStream toInputStream(Object o, Class<?> type) throws IOException {
-        return CODEC.toInputStream(o, type);
-    }

+        String toJson(Object o);

-    public interface JsonCodec {
-        String toJson(Object o);
-
-        <T> T fromJson(String json, Class<T> type);
-        InputStream toInputStream(Object o, Class<?> type) throws IOException;
-
-    }

+        <T> T fromJson(String json, Class<T> type);

+        InputStream toInputStream(Object o, Class<?> type) throws IOException;
+    }
 }

diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/Tokenizer.java b/langchain4j-core/src/main/java/dev/langchain4j/model/Tokenizer.java
index a66c8528d..ad91de730 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/model/Tokenizer.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/model/Tokenizer.java
@@ -29,12 +29,12 @@ public interface Tokenizer {
         return estimateTokenCountInToolSpecifications(toolSpecifications);
     }

-    default int estimateTokenCountInToolSpecification(ToolSpecification toolSpecification) {
+    int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications);
+
+    default int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
         return estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
     }

-    int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications);
-
     int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests);

     default int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {

diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java
index fb2a44604..d33d3d9c0 100644
--- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java
+++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiStreamingChatModel.java
@@ -20,6 +20,7 @@
 import java.time.Duration;
 import java.util.List;

 import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNullOrEmpty;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
 import static java.time.Duration.ofSeconds;
@@ -93,7 +94,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok

     @Override
     public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
-        generate(messages, singletonList(toolSpecification), toolSpecification, handler);
+        generate(messages, null, toolSpecification, handler);
     }

     private void generate(List<ChatMessage> messages,
@@ -114,14 +115,14 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok

         int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);

-        if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
+        if (toolThatMustBeExecuted != null) {
+            requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
+            requestBuilder.toolChoice(toolThatMustBeExecuted.name());
+            inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
+        } else if (!isNullOrEmpty(toolSpecifications)) {
             requestBuilder.tools(toTools(toolSpecifications));
             inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
         }
-        if (toolThatMustBeExecuted != null) {
-            requestBuilder.toolChoice(toolThatMustBeExecuted.name());
-            inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
-        }

         ChatCompletionRequest request = requestBuilder.build();

diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java
index 2d9863327..153652d43 100644
--- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java
+++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiTokenizer.java
@@ -7,7 +7,6 @@
 import dev.langchain4j.agent.tool.ToolParameters;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
-import dev.langchain4j.data.message.ToolExecutionResultMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.Tokenizer;

@@ -17,10 +16,17 @@
 import java.util.Optional;
 import java.util.function.Supplier;

 import static dev.langchain4j.internal.Exceptions.illegalArgument;
+import static dev.langchain4j.internal.Json.fromJson;
+import static dev.langchain4j.internal.Utils.isNullOrBlank;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.roleFrom;
-import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_0301;
+import static dev.langchain4j.model.openai.OpenAiModelName.*;
+import static java.util.Collections.singletonList;

+/**
+ * This class can be used to estimate the cost (in tokens) before calling OpenAI or when using streaming.
+ * Magic numbers present in this class were found empirically while testing.
+ * There are integration tests in place that ensure the calculations here are very close to those of OpenAI.
+ */
 public class OpenAiTokenizer implements Tokenizer {

     private final String modelName;
@@ -40,43 +46,15 @@ public class OpenAiTokenizer implements Tokenizer {
                 .countTokensOrdinary(text);
     }

-    //Estimate the number of tokens in the parameters of a tool
-    private int estimateTokenCountInToolParameters(ToolParameters parameters) {
-        //Return early if there are no parameters
-        if (parameters == null) return 0;
-
-        int tokenCount = 0;
-        Map<String, Map<String, Object>> properties = parameters.properties();
-        for (String property : properties.keySet()) {
-            for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
-                if ("type".equals(entry.getKey())) {
-                    tokenCount += 3; // found experimentally while playing with OpenAI API
-                    tokenCount += estimateTokenCountInText(entry.getValue().toString());
-                } else if ("description".equals(entry.getKey())) {
-                    tokenCount += 3; // found experimentally while playing with OpenAI API
-                    tokenCount += estimateTokenCountInText(entry.getValue().toString());
-                } else if ("enum".equals(entry.getKey())) {
-                    tokenCount -= 3; // found experimentally while playing with OpenAI API
-                    for (Object enumValue : (Object[]) entry.getValue()) {
-                        tokenCount += 3; // found experimentally while playing with OpenAI API
-                        tokenCount += estimateTokenCountInText(enumValue.toString());
-                    }
-                }
-            }
-        }
-        return tokenCount;
-    }
-
     @Override
     public int estimateTokenCountInMessage(ChatMessage message) {
-        int tokenCount = 0;
+        int tokenCount = 1; // 1 token for role
         tokenCount += extraTokensPerMessage();
         tokenCount += estimateTokenCountInText(message.text());
-        tokenCount += estimateTokenCountInText(roleFrom(message).toString().toLowerCase());

         if (message instanceof UserMessage) {
             UserMessage userMessage = (UserMessage) message;
-            if (userMessage.name() != null) {
+            if (userMessage.name() != null && !modelName.equals(GPT_4_VISION_PREVIEW)) {
                 tokenCount += extraTokensPerName();
                 tokenCount += estimateTokenCountInText(userMessage.name());
             }
@@ -84,53 +62,37 @@ public class OpenAiTokenizer implements Tokenizer {

         if (message instanceof AiMessage) {
             AiMessage aiMessage = (AiMessage) message;
-            if (aiMessage.hasToolExecutionRequests()) {
-                for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
-                    tokenCount += 4; // found experimentally while playing with OpenAI API
-                    tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
+            if (aiMessage.toolExecutionRequests() != null) {
+                if (modelName.contains("1106")) {
+                    tokenCount += 6;
+                } else {
+                    tokenCount += 3;
+                }
+                if (aiMessage.toolExecutionRequests().size() == 1) {
+                    tokenCount -= 1;
+                    ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
+                    tokenCount += estimateTokenCountInText(toolExecutionRequest.name()) * 2;
                     tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
+                } else {
+                    tokenCount += 15;
+                    for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
+                        tokenCount += 7;
+                        tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
+
+                        Map<String, Object> arguments = fromJson(toolExecutionRequest.arguments(), Map.class);
+                        for (Map.Entry<String, Object> argument : arguments.entrySet()) {
+                            tokenCount += 2;
+                            tokenCount += estimateTokenCountInText(argument.getKey().toString());
+                            tokenCount += estimateTokenCountInText(argument.getValue().toString());
+                        }
+                    }
                 }
             }
         }

-        if (message instanceof ToolExecutionResultMessage) {
-            ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
-            tokenCount += -1; // found experimentally while playing with OpenAI API
-            tokenCount += estimateTokenCountInText(toolExecutionResultMessage.toolName());
-        }
-
         return tokenCount;
     }

-    @Override
-    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
-        // see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-
-        int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
-        for (ChatMessage message : messages) {
-            tokenCount += estimateTokenCountInMessage(message);
-        }
-        return tokenCount;
-    }
-
-    @Override
-    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
-        int tokenCount = 0;
-        for (ToolSpecification toolSpecification : toolSpecifications) {
-            tokenCount += estimateTokenCountInText(toolSpecification.name());
-            tokenCount += estimateTokenCountInText(toolSpecification.description());
-            tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
-            tokenCount += 12; // found experimentally while playing with OpenAI API
-        }
-        tokenCount += 12; // found experimentally while playing with OpenAI API
-        return tokenCount;
-    }
-
-    @Override
-    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
-        return 0; // TODO
-    }
-
     private int extraTokensPerMessage() {
         if (modelName.equals(GPT_3_5_TURBO_0301)) {
             return 4;
@@ -147,6 +109,88 @@ public class OpenAiTokenizer implements Tokenizer {
         }
     }

+    @Override
+    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
+        // see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+
+        int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
+        for (ChatMessage message : messages) {
+            tokenCount += estimateTokenCountInMessage(message);
+        }
+        return tokenCount;
+    }
+
+    @Override
+    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
+        int tokenCount = 16;
+        for (ToolSpecification toolSpecification : toolSpecifications) {
+            tokenCount += 6;
+            tokenCount += estimateTokenCountInText(toolSpecification.name());
+            if (toolSpecification.description() != null) {
+                tokenCount += 2;
+                tokenCount += estimateTokenCountInText(toolSpecification.description());
+            }
+            tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
+        }
+        return tokenCount;
+    }
+
+    private int estimateTokenCountInToolParameters(ToolParameters parameters) {
+        if (parameters == null) {
+            return 0;
+        }
+
+        int tokenCount = 3;
+        Map<String, Map<String, Object>> properties = parameters.properties();
+        if (modelName.contains("1106")) {
+            tokenCount += properties.size() - 1;
+        }
+        for (String property : properties.keySet()) {
+            if (modelName.contains("1106")) {
+                tokenCount += 2;
+            } else {
+                tokenCount += 3;
+            }
+            tokenCount += estimateTokenCountInText(property);
+            for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
+                if ("type".equals(entry.getKey())) {
+                    if ("array".equals(entry.getValue()) && modelName.contains("1106")) {
+                        tokenCount += 1;
+                    }
+                    // TODO object
+                } else if ("description".equals(entry.getKey())) {
+                    tokenCount += 2;
+                    tokenCount += estimateTokenCountInText(entry.getValue().toString());
+                    if (modelName.contains("1106") && parameters.required().contains(property)) {
+                        tokenCount += 1;
+                    }
+                } else if ("enum".equals(entry.getKey())) {
+                    if (modelName.contains("1106")) {
+                        tokenCount -= 2;
+                    } else {
+                        tokenCount -= 3;
+                    }
+                    for (Object enumValue : (Object[]) entry.getValue()) {
+                        tokenCount += 3;
+                        tokenCount += estimateTokenCountInText(enumValue.toString());
+                    }
+                }
+            }
+        }
+        return tokenCount;
+    }
+
+    @Override
+    public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
+        int tokenCount = estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
+        tokenCount += 4;
+        tokenCount += estimateTokenCountInText(toolSpecification.name());
+        if (modelName.contains("1106")) {
+            tokenCount += 3;
+        }
+        return tokenCount;
+    }
+
     public List<Integer> encode(String text) {
         return encoding.orElseThrow(unknownModelException())
                 .encodeOrdinary(text);
@@ -165,4 +209,92 @@ public class OpenAiTokenizer implements Tokenizer {
     private Supplier<IllegalArgumentException> unknownModelException() {
         return () -> illegalArgument("Model '%s' is unknown to jtokkit", modelName);
     }
+
+    @Override
+    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
+
+        int tokenCount = 0;
+
+        int toolsCount = 0;
+        int toolsWithArgumentsCount = 0;
+        int toolsWithoutArgumentsCount = 0;
+
+        int totalArgumentsCount = 0;
+
+        for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
+            tokenCount += 4;
+            tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
+            tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
+
+            int argumentCount = countArguments(toolExecutionRequest.arguments());
+            if (argumentCount == 0) {
+                toolsWithoutArgumentsCount++;
+            } else {
+                toolsWithArgumentsCount++;
+            }
+            totalArgumentsCount += argumentCount;
+
+            toolsCount++;
+        }
+
+        if (modelName.equals(GPT_3_5_TURBO_1106)) {
+            tokenCount += 16;
+            tokenCount += 3 * toolsWithoutArgumentsCount;
+            tokenCount += toolsCount;
+            if (totalArgumentsCount > 0) {
+                tokenCount -= 1;
+                tokenCount -= 2 * totalArgumentsCount;
+                tokenCount += 2 * toolsWithArgumentsCount;
+                tokenCount += toolsCount;
+            }
+        }
+
+        if (modelName.equals(GPT_4_1106_PREVIEW)) {
+            tokenCount += 3;
+            if (toolsCount > 1) {
+                tokenCount += 18;
+                tokenCount += 15 * toolsCount;
+                tokenCount += totalArgumentsCount;
+                tokenCount -= 3 * toolsWithoutArgumentsCount;
+            }
+        }
+
+        return tokenCount;
+    }
+
+    @Override
+    public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
+
+        if (modelName.equals(GPT_4_1106_PREVIEW)) {
+            int argumentsCount = countArguments(toolExecutionRequest.arguments());
+            if (argumentsCount == 0) {
+                return 1;
+            } else {
+                return estimateTokenCountInText(toolExecutionRequest.arguments());
+            }
+        }
+
+        int tokenCount = estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
+        tokenCount -= 4;
+        tokenCount -= estimateTokenCountInText(toolExecutionRequest.name());
+
+        if (modelName.equals(GPT_3_5_TURBO_1106)) {
+            int argumentsCount = countArguments(toolExecutionRequest.arguments());
+            if (argumentsCount == 0) {
+                return 1;
+            }
+            tokenCount -= 19;
+            tokenCount += 2 * argumentsCount;
+        }
+
+        return tokenCount;
+    }
+
+    static int countArguments(String arguments) {
+        if (isNullOrBlank(arguments)) {
+            return 0;
+        }
+        Map<String, Object> argumentsMap = fromJson(arguments, Map.class);
+        return argumentsMap.size();
+    }
 }
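For orientation, here is a minimal usage sketch of the reworked tokenizer (not part of the patch; the class name TokenEstimationExample is hypothetical, the tokenizer APIs are the ones defined above):

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.openai.OpenAiTokenizer;

import java.util.List;

import static dev.langchain4j.data.message.SystemMessage.systemMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.Arrays.asList;

public class TokenEstimationExample {

    public static void main(String[] args) {
        // Estimate the prompt size locally, before sending anything to OpenAI.
        Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO);

        List<ChatMessage> messages = asList(
                systemMessage("Be helpful"),
                userMessage("tell me a joke")
        );

        // Accounts for the per-message overhead and the 3 reply-priming tokens,
        // so it should track OpenAI's reported prompt token count very closely.
        int inputTokens = tokenizer.estimateTokenCountInMessages(messages);
        System.out.println("estimated input tokens: " + inputTokens);
    }
}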
diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java
index ba044c36e..c475cc157 100644
--- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java
@@ -15,6 +15,7 @@
 import java.util.List;

 import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_3_5_TURBO_1106;
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
+import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
 import static dev.langchain4j.data.message.UserMessage.userMessage;
 import static dev.langchain4j.model.output.FinishReason.*;
 import static java.util.Arrays.asList;
@@ -41,17 +42,17 @@ class OpenAiChatModelIT {
     void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {

         // given
-        UserMessage userMessage = userMessage("hello, how are you?");
+        UserMessage userMessage = userMessage("What is the capital of Germany?");

         // when
         Response<AiMessage> response = model.generate(userMessage);

         // then
-        assertThat(response.content().text()).isNotBlank();
+        assertThat(response.content().text()).contains("Berlin");

         TokenUsage tokenUsage = response.tokenUsage();
-        assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
-        assertThat(tokenUsage.outputTokenCount()).isGreaterThan(1);
+        assertThat(tokenUsage.inputTokenCount()).isEqualTo(14);
+        assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
         assertThat(tokenUsage.totalTokenCount())
                 .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

@@ -67,7 +68,7 @@ class OpenAiChatModelIT {
                 .maxTokens(3)
                 .build();

-        UserMessage userMessage = userMessage("hello, how are you?");
+        UserMessage userMessage = userMessage("What is the capital of Germany?");

         // when
         Response<AiMessage> response = model.generate(userMessage);
@@ -76,7 +77,7 @@ class OpenAiChatModelIT {
         assertThat(response.content().text()).isNotBlank();

         TokenUsage tokenUsage = response.tokenUsage();
-        assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
+        assertThat(tokenUsage.inputTokenCount()).isEqualTo(14);
         assertThat(tokenUsage.outputTokenCount()).isEqualTo(3);
         assertThat(tokenUsage.totalTokenCount())
                 .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -113,8 +114,7 @@ class OpenAiChatModelIT {
         assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

         // given
-        ToolExecutionResultMessage toolExecutionResultMessage
-                = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
+        ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
         List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

         // when
@@ -135,7 +135,7 @@ class OpenAiChatModelIT {
     }

     @Test
-    void should_execute_concrete_tool_then_answer() {
+    void should_execute_tool_forcefully_then_answer() {

         // given
         UserMessage userMessage = userMessage("2+2=?");
@@ -162,8 +162,7 @@ class OpenAiChatModelIT {
         assertThat(response.finishReason()).isEqualTo(STOP); // not sure if a bug in OpenAI or stop is expected here

         // given
-        ToolExecutionResultMessage toolExecutionResultMessage
-                = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
+        ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
         List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

         // when
@@ -221,8 +220,8 @@ class OpenAiChatModelIT {
         assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

         // given
-        ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
-        ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
+        ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4");
+        ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6");
         List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);

diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java
index c850cf682..bc8502789 100644
--- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java
@@ -10,6 +10,7 @@
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
+import org.assertj.core.data.Percentage;
 import org.junit.jupiter.api.Test;

 import java.util.List;
@@ -19,6 +20,7 @@
 import java.util.concurrent.TimeoutException;

 import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_3_5_TURBO_1106;
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
+import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
 import static dev.langchain4j.data.message.UserMessage.userMessage;
 import static dev.langchain4j.model.output.FinishReason.STOP;
 import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
@@ -26,6 +28,7 @@
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.data.Percentage.withPercentage;

 class OpenAiStreamingChatModelIT {

@@ -43,6 +46,8 @@ class OpenAiStreamingChatModelIT {
             .addParameter("second", INTEGER)
             .build();

+    Percentage tokenizerPrecision = withPercentage(5);
+
     @Test
     void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException {
@@ -133,15 +138,15 @@ class OpenAiStreamingChatModelIT {
         assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");

         TokenUsage tokenUsage = response.tokenUsage();
-        assertThat(tokenUsage.inputTokenCount()).isEqualTo(50); // TODO should be 53?
-        assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 22?
+        assertThat(tokenUsage.inputTokenCount()).isCloseTo(53, tokenizerPrecision);
+        assertThat(tokenUsage.outputTokenCount()).isCloseTo(22, tokenizerPrecision);
         assertThat(tokenUsage.totalTokenCount())
                 .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

         assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

         // given
-        ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
+        ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");

         List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

@@ -175,7 +180,7 @@ class OpenAiStreamingChatModelIT {
         assertThat(secondAiMessage.toolExecutionRequests()).isNull();

         TokenUsage secondTokenUsage = secondResponse.tokenUsage();
-        assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(43);
+        assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(41, tokenizerPrecision);
         assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
         assertThat(secondTokenUsage.totalTokenCount())
                 .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
@@ -184,7 +189,7 @@ class OpenAiStreamingChatModelIT {
     }

     @Test
-    void should_execute_concrete_tool_then_stream_answer() throws Exception {
+    void should_execute_tool_forcefully_then_stream_answer() throws Exception {

         // given
         UserMessage userMessage = userMessage("2+2=?");
@@ -227,15 +232,15 @@ class OpenAiStreamingChatModelIT {
         assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");

         TokenUsage tokenUsage = response.tokenUsage();
-        assertThat(tokenUsage.inputTokenCount()).isEqualTo(89); // TODO should be 53?
-        assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 22?
+        assertThat(tokenUsage.inputTokenCount()).isCloseTo(59, tokenizerPrecision);
+        assertThat(tokenUsage.outputTokenCount()).isCloseTo(16, tokenizerPrecision);
         assertThat(tokenUsage.totalTokenCount())
                 .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

         assertThat(response.finishReason()).isEqualTo(STOP); // not sure if a bug in OpenAI or stop is expected here

         // given
-        ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
+        ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");

         List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

@@ -269,7 +274,7 @@ class OpenAiStreamingChatModelIT {
         assertThat(secondAiMessage.toolExecutionRequests()).isNull();

         TokenUsage secondTokenUsage = secondResponse.tokenUsage();
-        assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(43);
+        assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(41, tokenizerPrecision);
         assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
         assertThat(secondTokenUsage.totalTokenCount())
                 .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
@@ -332,16 +337,16 @@ class OpenAiStreamingChatModelIT {
         assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"first\": 3, \"second\": 3}");

         TokenUsage tokenUsage = response.tokenUsage();
-        assertThat(tokenUsage.inputTokenCount()).isEqualTo(55); // TODO should be 57?
-        assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 51?
+        assertThat(tokenUsage.inputTokenCount()).isCloseTo(57, tokenizerPrecision);
+        assertThat(tokenUsage.outputTokenCount()).isCloseTo(51, tokenizerPrecision);
         assertThat(tokenUsage.totalTokenCount())
                 .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

         assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

         // given
-        ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
-        ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
+        ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4");
+        ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6");

         List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);

@@ -375,7 +380,7 @@ class OpenAiStreamingChatModelIT {
         assertThat(secondAiMessage.toolExecutionRequests()).isNull();

         TokenUsage secondTokenUsage = secondResponse.tokenUsage();
-        assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(66); // TODO should be 83?
+        assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(83, tokenizerPrecision);
         assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
         assertThat(secondTokenUsage.totalTokenCount())
                 .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());

diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java
new file mode 100644
index 000000000..0ffd5b217
--- /dev/null
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerIT.java
@@ -0,0 +1,1739 @@
+package dev.langchain4j.model.openai;
+
+import dev.ai4j.openai4j.chat.ChatCompletionModel;
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.Tokenizer;
+import dev.langchain4j.model.output.Response;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.opentest4j.AssertionFailedError;
+
+import java.util.List;
+import java.util.stream.Stream;
+
+import static dev.ai4j.openai4j.chat.ChatCompletionModel.*;
+import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
+import static dev.langchain4j.data.message.AiMessage.aiMessage;
+import static dev.langchain4j.data.message.SystemMessage.systemMessage;
+import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
+import static dev.langchain4j.data.message.UserMessage.userMessage;
+import static java.util.Arrays.asList;
+import static java.util.Arrays.stream;
+import static java.util.Collections.singletonList;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.data.Percentage.withPercentage;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+// TODO use exact model for Tokenizer (the one returned by LLM)
+class OpenAiTokenizerIT {
+
+    @ParameterizedTest
+    @MethodSource
+    void should_count_tokens_in_messages(List<ChatMessage> messages, ChatCompletionModel modelName) {
+
+        // given
+        OpenAiChatModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequests(true) + .logResponses(true) + .build(); + + int expectedTokenCount = model.generate(messages).tokenUsage().inputTokenCount(); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + // when + int tokenCount = tokenizer.estimateTokenCountInMessages(messages); + + // then + assertThat(tokenCount).isEqualTo(expectedTokenCount); + } + + static Stream should_count_tokens_in_messages() { + return stream(ChatCompletionModel.values()) + // I don't have access to these models + .filter(model -> model != GPT_4_32K && model != GPT_4_32K_0314 && model != GPT_4_32K_0613) + .flatMap(model -> Stream.of( + arguments(singletonList(systemMessage("Be friendly.")), model), + arguments(singletonList(systemMessage("You are a helpful assistant, help the user!")), model), + + arguments(singletonList(userMessage("Hi")), model), + arguments(singletonList(userMessage("Hello, how are you?")), model), + + arguments(singletonList(userMessage("Stan", "Hi")), model), + arguments(singletonList(userMessage("Klaus", "Hi")), model), + arguments(singletonList(userMessage("Giovanni", "Hi")), model), + + arguments(singletonList(aiMessage("Hi")), model), + arguments(singletonList(aiMessage("Hello, how can I help you?")), model), + + arguments(asList( + systemMessage("Be helpful"), + userMessage("hi") + ), model), + + arguments(asList( + systemMessage("Be helpful"), + userMessage("hi"), + aiMessage("Hello, how can I help you?"), + userMessage("tell me a joke") + ), model), + + arguments(asList( + systemMessage("Be helpful"), + userMessage("hi"), + aiMessage("Hello, how can I help you?"), + userMessage("tell me a joke"), + aiMessage("Why don't scientists trust atoms?\n\nBecause they make up everything!"), + userMessage("tell me another one, this one is not funny") + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_messages_with_single_tool(List messages, ChatCompletionModel modelName) { + + // given + OpenAiChatModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequests(true) + .logResponses(true) + .build(); + + int expectedTokenCount = model.generate(messages).tokenUsage().inputTokenCount(); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + // when + int tokenCount = tokenizer.estimateTokenCountInMessages(messages); + + // then + assertThat(tokenCount).isCloseTo(expectedTokenCount, withPercentage(4)); + } + + static Stream should_count_tokens_in_messages_with_single_tool() { + return stream(ChatCompletionModel.values()) + // I don't have access to these models + .filter(model -> model != GPT_4_32K && model != GPT_4_32K_0314 && model != GPT_4_32K_0613 + && model != GPT_4_0314 // does not support tools + && model != GPT_4_VISION_PREVIEW // does not support tools (yet) + ) + .flatMap(model -> Stream.of( + + // various tool "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") // 1 token + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") // 2 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + 
aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("get_current_time") // 3 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 1 argument, various argument "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{\"target_city\":\"Berlin\"}") // 2 tokens + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{\"target_city_name\":\"Berlin\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 1 argument, various argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{\"city\":\"Munich\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 8 tokens + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 1 argument, various numeric argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{\"city_id\": 189647}") // 2 tokens + .build() + ), + toolExecutionResultMessage("b", null, "23:59") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{\"city_id\": 189647852}") // 3 tokens + .build() + ), + toolExecutionResultMessage("c", null, "23:59") + ), model), + + // 2 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\"}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + + // 3 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\",\"format\":\"24\"}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59") + ), model), + + // various result lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23") // 1 token + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("b") + .name("current_time") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("b", null, "23:59") // 3 tokens + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("c") + .name("current_time") + .arguments("{}") + 
.build() + ), + toolExecutionResultMessage("c", null, "23:59:59") // 5 tokens + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_messages_with_multiple_tools(List messages, + ChatCompletionModel modelName) { + // given + OpenAiChatModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequests(true) + .logResponses(true) + .build(); + + int expectedTokenCount = model.generate(messages).tokenUsage().inputTokenCount(); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + // when + int tokenCount = tokenizer.estimateTokenCountInMessages(messages); + + // then + assertThat(tokenCount).isCloseTo(expectedTokenCount, withPercentage(4)); + } + + static Stream should_count_tokens_in_messages_with_multiple_tools() { + return stream(ChatCompletionModel.values()) + // only these models support parallel tool calling + .filter(model -> model == GPT_3_5_TURBO_1106 || model == GPT_4_1106_PREVIEW) + .flatMap(model -> Stream.of( + + // various tool "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") // 1 token + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") // 1 token + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") // 2 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") // 2 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("get_current_time") // 3 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("get_current_temperature") // 3 tokens + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 1 argument, various argument "name" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"target_city\":\"Berlin\"}") // 2 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"target_city\":\"Berlin\"}") // 2 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"target_city_name\":\"Berlin\"}") // 3 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"target_city_name\":\"Berlin\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + 
), model), + + // 1 argument, various argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\"}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Munich\"}") // 3 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Munich\"}") // 3 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 8 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 8 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 1 argument, various numeric argument "value" lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189}") // 1 token + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city_id\": 189}") // 1 token + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189647}") // 2 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city_id\": 189647}") // 2 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city_id\": 189647852}") // 3 tokens + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city_id\": 189647852}") // 3 tokens + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 2 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // 3 arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\",\"format\":\"24\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{\"city\":\"Berlin\",\"country\":\"Germany\",\"unit\":\"C\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2") + ), model), + + // various result 
lengths + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23"), // 1 token + toolExecutionResultMessage("b", null, "17") // 1 token + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59"), // 3 tokens + toolExecutionResultMessage("b", null, "17.5") // 3 tokens + ), model), + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("current_time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("current_temperature") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "23:59:59"), // 5 tokens + toolExecutionResultMessage("b", null, "17.5 grad C") // 5 tokens + ), model), + + // 3 tools without arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3") + ), model), + + // 3 tools with arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{\"city\":\"Berlin\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3") + ), model), + + // 4 tools without arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .id("d") + .name("UV") + .arguments("{}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3"), + toolExecutionResultMessage("d", null, "4") + ), model), + + // 4 tools with arguments + arguments(asList( + aiMessage( + ToolExecutionRequest.builder() + .id("a") + .name("time") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("b") + .name("temperature") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("c") + .name("weather") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .id("d") + .name("UV") + .arguments("{\"city\":\"Berlin\"}") + .build() + ), + toolExecutionResultMessage("a", null, "1"), + toolExecutionResultMessage("b", null, "2"), + toolExecutionResultMessage("c", null, "3"), + toolExecutionResultMessage("d", null, "4") + ), model) + 
)); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_tool_specifications(List toolSpecifications, + ChatCompletionModel modelName) { + // given + OpenAiChatModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .maxTokens(1) // we don't need outputs, let's not waste tokens + .logRequests(true) + .logResponses(true) + .build(); + + List dummyMessages = singletonList(userMessage("hi")); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + int expectedTokenCount = model.generate(dummyMessages, toolSpecifications).tokenUsage().inputTokenCount() + - tokenizer.estimateTokenCountInMessages(dummyMessages); + + // when + int tokenCount = tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications); + + // then + assertThat(tokenCount).isCloseTo(expectedTokenCount, withPercentage(2)); + } + + static Stream should_count_tokens_in_tool_specifications() { + return stream(ChatCompletionModel.values()) + // I don't have access to these models + .filter(model -> model != GPT_4_32K && model != GPT_4_32K_0314 && model != GPT_4_32K_0613 + && model != GPT_4_0314 // does not support tools + && model != GPT_4_VISION_PREVIEW) // does not support tools (yet) + .flatMap(model -> Stream.of( + + // "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("time") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("get_current_time") // 3 tokens + .build()), model), + + // "description" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("time") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("returns current time in 24-hour format") // 8 tokens + .build()), model), + + // 1 parameter with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .build()), model), + + // 1 parameter with "description" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", description("city")) // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", description("target city name")) // 3 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", description("city for which time should be returned")) // 7 tokens + .build()), model), + + // 1 parameter with varying "type" + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", STRING) 
+ .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city", INTEGER) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("cities", ARRAY, items(INTEGER)) + .build()), model), + + // 1 parameter with "enum" of various numbers of values + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("C")) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("C", "K")) + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("C", "K", "F")) + .build()), model), + + // 1 parameter with "enum" of various name lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("celsius", "kelvin", "fahrenheit")) // 2 tokens each + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("unit", enums("CELSIUS", "KELVIN", "FAHRENHEIT")) // 3-5 tokens + .build()), model), + + // 2 parameters with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .addParameter("country") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .addParameter("target_country") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .addParameter("target_country_name") // 3 tokens + .build()), model), + + // 3 parameters with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .addParameter("country") // 1 token + .addParameter("format", enums("12H", "24H")) // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .addParameter("target_country") // 2 tokens + .addParameter("time_format", enums("12H", "24H")) // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .addParameter("target_country_name") // 3 tokens + .addParameter("current_time_format", enums("12H", "24H")) // 3 tokens + .build()), model), + + // 1 optional parameter with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("city") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() +
.name("current_time") + .description("current time") + .addOptionalParameter("target_city_name") // 3 tokens + .build()), model), + + // 2 optional parameters with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("city") // 1 token + .addOptionalParameter("country") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city") // 2 tokens + .addOptionalParameter("target_country") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addOptionalParameter("target_city_name") // 3 tokens + .addOptionalParameter("target_country_name") // 3 tokens + .build()), model), + + // 1 mandatory, 1 optional parameters with "name" of various lengths + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city") // 1 token + .addOptionalParameter("country") // 1 token + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city") // 2 tokens + .addOptionalParameter("target_country") // 2 tokens + .build()), model), + arguments(singletonList(ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("target_city_name") // 3 tokens + .addOptionalParameter("target_country_name") // 3 tokens + .build()), model), + + // 2 tools + arguments(asList( + ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build(), + ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build() + ), model), + + // 3 tools + arguments(asList( + ToolSpecification.builder() + .name("current_time") + .description("current time") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build(), + ToolSpecification.builder() + .name("current_temperature") + .description("current temperature") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build(), + ToolSpecification.builder() + .name("current_weather") + .description("current weather") + .addParameter("city_name", description("city name")) + .addOptionalParameter("country_name", description("optional country name")) + .build() + ), model) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_tool_execution_request(UserMessage userMessage, + ToolSpecification toolSpecification, + ToolExecutionRequest expectedToolExecutionRequest, + ChatCompletionModel modelName) { + // given + OpenAiChatModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .logRequests(true) + .logResponses(true) + .build(); + + Response response = model.generate(singletonList(userMessage), singletonList(toolSpecification)); + + List toolExecutionRequests = response.content().toolExecutionRequests(); + // we 
need to ensure that the model generated the expected tool execution request, + // then we can use the output token count as a reference + assertThat(toolExecutionRequests).hasSize(1); + ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(expectedToolExecutionRequest.name()); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace(expectedToolExecutionRequest.arguments()); + + int expectedTokenCount = response.tokenUsage().outputTokenCount(); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + // when + int tokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); + + // then + try { + assertThat(tokenCount).isEqualTo(expectedTokenCount); + } catch (AssertionFailedError e) { + if (modelName == GPT_3_5_TURBO_1106) { + // GPT_3_5_TURBO_1106 sometimes reports an incorrect token count + // see https://community.openai.com/t/inconsistent-token-billing-for-tool-calls-in-gpt-3-5-turbo-1106 + // TODO remove once they fix it + e.printStackTrace(); + // the discrepancy follows a pattern, so we verify whether this is the known issue or an error in our own calculation + Tokenizer tokenizer2 = new OpenAiTokenizer(GPT_3_5_TURBO.toString()); + int tokenCount2 = tokenizer2.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); + assertThat(tokenCount2).isEqualTo(expectedTokenCount - 3); + } else { + throw e; + } + } + } + + @ParameterizedTest + @MethodSource("should_count_tokens_in_tool_execution_request") + void should_count_tokens_in_forceful_tool_specification_and_execution_request(UserMessage userMessage, + ToolSpecification toolSpecification, + ToolExecutionRequest expectedToolExecutionRequest, + ChatCompletionModel modelName) { + // given + OpenAiChatModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .logRequests(true) + .logResponses(true) + .build(); + + Response response = model.generate(singletonList(userMessage), toolSpecification); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + int expectedTokenCountInSpecification = response.tokenUsage().inputTokenCount() + - tokenizer.estimateTokenCountInMessages(singletonList(userMessage)); + + // when + int tokenCountInSpecification = tokenizer.estimateTokenCountInForcefulToolSpecification(toolSpecification); + + // then + assertThat(tokenCountInSpecification).isEqualTo(expectedTokenCountInSpecification); + + // given + List toolExecutionRequests = response.content().toolExecutionRequests(); + // we need to ensure that the model generated the expected tool execution request, + // then we can use the output token count as a reference + assertThat(toolExecutionRequests).hasSize(1); + ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(expectedToolExecutionRequest.name()); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace(expectedToolExecutionRequest.arguments()); + + int expectedTokenCountInToolRequest = response.tokenUsage().outputTokenCount(); + + // when + int tokenCountInToolRequest = tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest); + + // then + assertThat(tokenCountInToolRequest).isEqualTo(expectedTokenCountInToolRequest); + } + + static Stream should_count_tokens_in_tool_execution_request() { + return stream(ChatCompletionModel.values()) + // I don't have access to these models + .filter(model -> model !=
GPT_4_32K && model != GPT_4_32K_0314 && model != GPT_4_32K_0613 + && model != GPT_4_0314 // does not support tools + && model != GPT_4_VISION_PREVIEW) // does not support tools (yet) + .flatMap(model -> Stream.of( + + // no arguments, different lengths of "name" + arguments( + userMessage("What is the time now?"), + ToolSpecification.builder() + .name("time") // 1 token + .description("returns current time") + .build(), + ToolExecutionRequest.builder() + .name("time") // 1 token + .arguments("{}") + .build(), + model + ), + arguments( + userMessage("What is the time now?"), + ToolSpecification.builder() + .name("current_time") // 2 tokens + .description("returns current time") + .build(), + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + .arguments("{}") + .build(), + model + ), + arguments( + userMessage("What is the time now?"), + ToolSpecification.builder() + .name("get_current_time") // 3 tokens + .description("returns current time") + .build(), + ToolExecutionRequest.builder() + .name("get_current_time") // 3 tokens + .arguments("{}") + .build(), + model + ), + + // one argument, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Munich now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Munich\"}") // 7 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Pietramontecorvino now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 12 tokens + .build(), + model + ), + + // two arguments, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\"}") // 9 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Munich now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Munich\"}") // 11 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Pietramontecorvino now?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\"}") // 16 tokens + .build(), + model + ), + + // three arguments, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin now in 24-hour format?"), + ToolSpecification.builder() + .name("current_time") + 
.description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\",\"format\":\"24\"}") // 13 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Munich now in 24-hour format?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Munich\",\"format\":\"24\"}") // 15 tokens + .build(), + model + ), + arguments( + userMessage("What is the time in Pietramontecorvino now in 24-hour format?"), + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\",\"format\":\"24\"}") // 20 tokens + .build(), + model + ) + )); + } + + @ParameterizedTest + @MethodSource + void should_count_tokens_in_multiple_tool_execution_requests(UserMessage userMessage, + List toolSpecifications, + List expectedToolExecutionRequests, + ChatCompletionModel modelName) { + // given + OpenAiChatModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(modelName.toString()) + .logRequests(true) + .logResponses(true) + .build(); + + Response response = model.generate(singletonList(userMessage), toolSpecifications); + + List toolExecutionRequests = response.content().toolExecutionRequests(); + // we need to ensure that model generated expected tool execution requests, + // then we can use output token count as a reference + assertThat(toolExecutionRequests).hasSize(expectedToolExecutionRequests.size()); + for (int i = 0; i < toolExecutionRequests.size(); i++) { + ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(i); + ToolExecutionRequest expectedToolExecutionRequest = expectedToolExecutionRequests.get(i); + assertThat(toolExecutionRequest.name()).isEqualTo(expectedToolExecutionRequest.name()); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace(expectedToolExecutionRequest.arguments()); + } + + int expectedTokenCount = response.tokenUsage().outputTokenCount(); + + Tokenizer tokenizer = new OpenAiTokenizer(modelName.toString()); + + // when + int tokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(toolExecutionRequests); + + // then + assertThat(tokenCount).isEqualTo(expectedTokenCount); + } + + static Stream should_count_tokens_in_multiple_tool_execution_requests() { + return stream(ChatCompletionModel.values()) + // only these models support parallel tool calling + .filter(model -> model == GPT_3_5_TURBO_1106 || model == GPT_4_1106_PREVIEW) + .flatMap(model -> Stream.of( + + // no arguments, different lengths of "name" + arguments( + userMessage("What is the time and date now?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") // 1 token + .arguments("{}") + .build(), + 
ToolExecutionRequest.builder() + .name("date") // 1 token + .arguments("{}") + .build() + ), + model + ), + arguments( + userMessage("What is the time and date now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("current_date") // 2 tokens + .arguments("{}") + .build() + ), + model + ), + arguments( + userMessage("What is the time and date now?"), + asList( + ToolSpecification.builder() + .name("get_current_time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("get_current_date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("get_current_time") // 3 tokens + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("get_current_date") // 3 tokens + .arguments("{}") + .build() + ), + model + ), + + // no arguments, 3 tools + arguments( + userMessage("What is the time and date and location?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .build(), + ToolSpecification.builder() + .name("date") + .description("returns current date") + .build(), + ToolSpecification.builder() + .name("location") + .description("returns current location") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("date") + .arguments("{}") + .build(), + ToolExecutionRequest.builder() + .name("location") + .arguments("{}") + .build() + ), + model + ), + + // no arguments, 1 argument + arguments( + userMessage("What is the time in Munich and date now?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("date") + .description("returns current date") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Munich\"}") + .build(), + ToolExecutionRequest.builder() + .name("date") + .arguments("{}") + .build() + ), + model + ), + + // one argument, 2 different tools, different lengths of "arguments" + arguments( + userMessage("What is the time and date in Berlin now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build() + ), + model + ), + arguments( + userMessage("What is the time and date in Pietramontecorvino now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 12 tokens + .build(), 
+ ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"city\":\"Pietramontecorvino\"}") // 12 tokens + .build() + ), + model + ), + + // different tools, different lengths of argument values + arguments( + userMessage("What is the time in Berlin and date in Munich now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"city\":\"Munich\"}") // 7 tokens + .build() + ), + model + ), + + // different tools, different lengths of "name", different lengths of argument values + arguments( + userMessage("What is the time in Berlin and date in Munich now?"), + asList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .addParameter("city") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") // 1 token + .arguments("{\"city\":\"Berlin\"}") // 5 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") // 2 tokens + .arguments("{\"city\":\"Munich\"}") // 7 tokens + .build() + ), + model + ), + + // one argument, 4 tool requests + arguments( + userMessage("What is the time in Berlin, Munich, London and Paris now?"), + singletonList( + ToolSpecification.builder() + .name("time") + .description("returns current time") + .addParameter("city") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Berlin\"}") + .build(), + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Munich\"}") + .build(), + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"London\"}") + .build(), + ToolExecutionRequest.builder() + .name("time") + .arguments("{\"city\":\"Paris\"}") + .build() + ), + model + ), + + // two arguments, different lengths of "arguments" + arguments( + userMessage("What is the time and date in Berlin now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .addParameter("country") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\"}") // 9 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") // 2 tokens + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\"}") // 9 tokens + .build() + ), + model + ), + arguments( + userMessage("What is the time and date in Pietramontecorvino now?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .addParameter("country") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") +
.arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\"}") // 16 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\"}") // 16 tokens + .build() + ), + model + ), + + // three arguments, different lengths of "arguments" + arguments( + userMessage("What is the time in Berlin and Pietramontecorvino in 24-hour format?"), + singletonList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\",\"format\":\"24\"}") // 13 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_time") // 2 tokens + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\",\"format\":\"24\"}") // 20 tokens + .build() + ), + model + ), + + // three tool execution requests, different tools and lengths of "arguments" + arguments( + userMessage("What is the time in Berlin and Pietramontecorvino in 24-hour format now and date in Munich?"), + asList( + ToolSpecification.builder() + .name("current_time") + .description("returns current time") + .addParameter("city") + .addParameter("country") + .addParameter("format", enums("12", "24")) + .build(), + ToolSpecification.builder() + .name("current_date") + .description("returns current date") + .addParameter("city") + .addParameter("country") + .build() + ), + asList( + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Germany\",\"city\":\"Berlin\",\"format\": \"24\"}") // 14 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_time") + .arguments("{\"country\":\"Italy\",\"city\":\"Pietramontecorvino\",\"format\": \"24\"}") // 21 tokens + .build(), + ToolExecutionRequest.builder() + .name("current_date") + .arguments("{\"country\":\"Germany\",\"city\":\"Munich\"}") // 11 tokens + .build() + ), + model + ) + )); + } +} \ No newline at end of file diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java index 7a9f36ee3..3c5130eff 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiTokenizerTest.java @@ -1,93 +1,22 @@ package dev.langchain4j.model.openai; -import dev.langchain4j.agent.tool.P; -import dev.langchain4j.agent.tool.Tool; -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.data.message.ChatMessage; +import dev.ai4j.openai4j.chat.ChatCompletionModel; +import dev.langchain4j.model.Tokenizer; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.EnumSource; import java.util.ArrayList; import java.util.List; -import java.util.Random; -import java.util.stream.Stream; -import static dev.langchain4j.data.message.AiMessage.aiMessage; -import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; -import static dev.langchain4j.data.message.UserMessage.userMessage; import static 
dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; +import static dev.langchain4j.model.openai.OpenAiTokenizer.countArguments; import static org.assertj.core.api.Assertions.assertThat; class OpenAiTokenizerTest { OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO); - @ParameterizedTest - @MethodSource - void should_count_tokens_in_messages(List messages, int expectedTokenCount) { - int tokenCount = tokenizer.estimateTokenCountInMessages(messages); - assertThat(tokenCount).isEqualTo(expectedTokenCount); - } - - static Stream should_count_tokens_in_messages() { - // expected token count was taken from real OpenAI responses (usage.prompt_tokens) - return Stream.of( - Arguments.of(singletonList(userMessage("hello")), 8), - Arguments.of(singletonList(userMessage("Klaus", "hello")), 11), - Arguments.of(asList(userMessage("hello"), aiMessage("hi there")), 14), - Arguments.of(asList( - userMessage("How much is 2 plus 2?"), - aiMessage(ToolExecutionRequest.builder() - .name("calculator") - .arguments("{\"a\":2, \"b\":2}") - .build()) - ), 35), - Arguments.of(asList( - userMessage("How much is 2 plus 2?"), - aiMessage(ToolExecutionRequest.builder() - .name("calculator") - .arguments("{\"a\":2, \"b\":2}") - .build()), - toolExecutionResultMessage("a", "calculator", "4") - ), 40) - ); - } - - static class Tools { - - @Tool - int add(int a, int b) { - return a + b; - } - - @Tool("calculates the square root of the provided number") - double squareRoot(@P("number to operate on") double number) { - return Math.sqrt(number); - } - - @Tool - int temperature(String location, TemperatureUnit temperatureUnit) { - return 0; - } - - @Tool() - int randomInt() {return new Random().nextInt();} - } - - enum TemperatureUnit { - F, C - } - - @Test - void should_count_tokens_in_tools() { - int tokenCount = tokenizer.estimateTokenCountInTools(new Tools()); - assertThat(tokenCount).isEqualTo(107); // found experimentally while playing with OpenAI API - } - @Test void should_encode_and_decode_text() { String originalText = "This is a text which will be encoded and decoded back."; @@ -140,10 +69,45 @@ class OpenAiTokenizerTest { assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(100 * 15); } - public static List repeat(String s, int n) { + @Test + void should_count_arguments() { + assertThat(countArguments(null)).isEqualTo(0); + assertThat(countArguments("")).isEqualTo(0); + assertThat(countArguments(" ")).isEqualTo(0); + assertThat(countArguments("{}")).isEqualTo(0); + assertThat(countArguments("{ }")).isEqualTo(0); + + assertThat(countArguments("{\"one\":1}")).isEqualTo(1); + assertThat(countArguments("{\"one\": 1}")).isEqualTo(1); + assertThat(countArguments("{\"one\" : 1}")).isEqualTo(1); + + assertThat(countArguments("{\"one\":1,\"two\":2}")).isEqualTo(2); + assertThat(countArguments("{\"one\": 1,\"two\": 2}")).isEqualTo(2); + assertThat(countArguments("{\"one\" : 1,\"two\" : 2}")).isEqualTo(2); + + assertThat(countArguments("{\"one\":1,\"two\":2,\"three\":3}")).isEqualTo(3); + assertThat(countArguments("{\"one\": 1,\"two\": 2,\"three\": 3}")).isEqualTo(3); + assertThat(countArguments("{\"one\" : 1,\"two\" : 2,\"three\" : 3}")).isEqualTo(3); + } + + @ParameterizedTest + @EnumSource(ChatCompletionModel.class) + void should_support_all_models(ChatCompletionModel model) { + + // given + Tokenizer tokenizer = new OpenAiTokenizer(model.toString()); + + // when + int tokenCount = 
tokenizer.estimateTokenCountInText("a"); + + // then + assertThat(tokenCount).isEqualTo(1); + } + + static List repeat(String strings, int n) { List result = new ArrayList<>(); for (int i = 0; i < n; i++) { - result.add(s); + result.add(strings); } return result; } diff --git a/langchain4j-quarkus/src/main/java/dev/langchain4j/model/quarkus/AzureOpenAiStreamingChatModel.java b/langchain4j-quarkus/src/main/java/dev/langchain4j/model/quarkus/AzureOpenAiStreamingChatModel.java index f8eeb9eb0..8b4a7994a 100644 --- a/langchain4j-quarkus/src/main/java/dev/langchain4j/model/quarkus/AzureOpenAiStreamingChatModel.java +++ b/langchain4j-quarkus/src/main/java/dev/langchain4j/model/quarkus/AzureOpenAiStreamingChatModel.java @@ -20,6 +20,7 @@ import java.time.Duration; import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions; import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages; @@ -103,7 +104,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel @Override public void generate(List messages, ToolSpecification toolSpecification, StreamingResponseHandler handler) { - generate(messages, singletonList(toolSpecification), toolSpecification, handler); + generate(messages, null, toolSpecification, handler); } private void generate(List messages, @@ -122,18 +123,18 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages); - if (toolSpecifications != null && !toolSpecifications.isEmpty()) { + if (toolThatMustBeExecuted != null) { + requestBuilder.functions(toFunctions(singletonList(toolThatMustBeExecuted))); + requestBuilder.functionCall(toolThatMustBeExecuted.name()); + if (tokenizer != null) { + inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted); + } + } else if (!isNullOrEmpty(toolSpecifications)) { requestBuilder.functions(toFunctions(toolSpecifications)); if (tokenizer != null) { inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications); } } - if (toolThatMustBeExecuted != null) { - requestBuilder.functionCall(toolThatMustBeExecuted.name()); - if (tokenizer != null) { - inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted); - } - } ChatCompletionRequest request = requestBuilder.build(); diff --git a/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java b/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java index 0f6f8a857..786b6963b 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/StreamingAiServicesIT.java @@ -13,6 +13,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; +import org.assertj.core.data.Percentage; import org.junit.jupiter.api.Test; import java.util.List; @@ -34,6 +35,8 @@ public class StreamingAiServicesIT { .logResponses(true) .build(); + Percentage tokenizerPrecision = withPercentage(5); + interface Assistant { TokenStream chat(String userMessage); @@ 
-166,8 +169,8 @@ public class StreamingAiServicesIT { assertThat(response.content().text()).isEqualTo(answer); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(181); // TODO should be around 182? - assertThat(tokenUsage.outputTokenCount()).isCloseTo(27, withPercentage(5)); // TODO + assertThat(tokenUsage.inputTokenCount()).isCloseTo(72 + 110, tokenizerPrecision); + assertThat(tokenUsage.outputTokenCount()).isCloseTo(21 + 28, tokenizerPrecision); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -237,8 +240,8 @@ public class StreamingAiServicesIT { assertThat(response.content().text()).isEqualTo(answer); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(351); // TODO should be around 348? - assertThat(tokenUsage.outputTokenCount()).isCloseTo(52, withPercentage(5)); // TODO + assertThat(tokenUsage.inputTokenCount()).isCloseTo(79 + 117 + 152, tokenizerPrecision); + assertThat(tokenUsage.outputTokenCount()).isCloseTo(21 + 20 + 53, tokenizerPrecision); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); @@ -331,8 +334,8 @@ public class StreamingAiServicesIT { assertThat(response.content().text()).isEqualTo(answer); TokenUsage tokenUsage = response.tokenUsage(); - assertThat(tokenUsage.inputTokenCount()).isEqualTo(221); // TODO should be around 239? - assertThat(tokenUsage.outputTokenCount()).isCloseTo(57, withPercentage(5)); // TODO + assertThat(tokenUsage.inputTokenCount()).isCloseTo(79 + 160, tokenizerPrecision); + assertThat(tokenUsage.outputTokenCount()).isCloseTo(54 + 58, tokenizerPrecision); assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
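--

For illustration, a minimal sketch of how the reworked tokenizer can be used to budget prompt tokens before a request is sent. It relies only on the API exercised in the tests above (OpenAiTokenizer, estimateTokenCountInMessages, estimateTokenCountInToolSpecifications); the TokenBudgetSketch wrapper class itself is hypothetical.

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.openai.OpenAiTokenizer;

import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.Collections.singletonList;

// Hypothetical example class; only the tokenizer calls below come from this patch.
class TokenBudgetSketch {

    public static void main(String[] args) {
        Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO);

        List<ChatMessage> messages = singletonList(userMessage("What is the time in Berlin now?"));

        List<ToolSpecification> tools = singletonList(ToolSpecification.builder()
                .name("current_time")
                .description("returns current time")
                .addParameter("city")
                .build());

        // The prompt cost of a request is the message tokens plus the tool-specification
        // tokens; the integration tests above check this estimate against OpenAI's reported
        // usage.prompt_tokens (exact in most cases, within a few percent in corner cases).
        int estimatedInputTokens = tokenizer.estimateTokenCountInMessages(messages)
                + tokenizer.estimateTokenCountInToolSpecifications(tools);

        System.out.println("estimated prompt tokens: " + estimatedInputTokens);
    }
}

When a single tool call is forced, estimateTokenCountInForcefulToolSpecification(toolSpecification) replaces estimateTokenCountInToolSpecifications in the sum, matching how the streaming chat models in this patch compute their input token count.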