make OpenAI tokenizer more precise (#346)

This PR is a rework of `OpenAiTokenizer`.
It adds `OpenAiTokenizerIT` with an extensive set of tests ensuring that
`OpenAiTokenizer` estimates token usage very close to OpenAI's actual counts.
In most cases the calculation matches 1:1; in some corner cases the difference is
within 5%.
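
For reference, the integration tests essentially pin the estimator against the `usage.prompt_tokens` value reported by OpenAI. A minimal sketch of that kind of check (not code from this PR; assumes a configured `OpenAiChatModel` and a hypothetical wrapper class):

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;

import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;

class TokenizerPrecisionSketch { // hypothetical class, for illustration only

    void check(OpenAiChatModel model) { // `model` is assumed to be configured with GPT_3_5_TURBO
        List<ChatMessage> messages = singletonList(userMessage("hello"));
        Tokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO);

        // Local estimate, computed without calling OpenAI.
        int estimated = tokenizer.estimateTokenCountInMessages(messages);

        // Actual prompt token count, as reported by OpenAI.
        Response<AiMessage> response = model.generate(messages);
        int actual = response.tokenUsage().inputTokenCount();

        // 1:1 in most cases; corner cases stay within 5%.
        assertThat(estimated).isCloseTo(actual, withPercentage(5));
    }
}
```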
Authored by LangChain4j on 2023-12-12 16:45:16 +01:00; committed by GitHub
parent 054a36f59f
commit 8e4254fc20
13 changed files with 2093 additions and 245 deletions


@@ -20,6 +20,7 @@ import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.toFunctions;
@@ -117,7 +118,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
generate(messages, singletonList(toolSpecification), toolSpecification, handler);
generate(messages, null, toolSpecification, handler);
}
private void generate(List<ChatMessage> messages,
@@ -136,18 +137,18 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
if (toolThatMustBeExecuted != null) {
options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted)));
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
}
} else if (!isNullOrEmpty(toolSpecifications)) {
options.setFunctions(toFunctions(toolSpecifications));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
}
if (toolThatMustBeExecuted != null) {
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
}
}
AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);


@@ -124,7 +124,7 @@ class AzureOpenAiStreamingChatModelIT {
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(50);
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53);
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
assertThat(response.tokenUsage().totalTokenCount())
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());


@@ -1,23 +1,21 @@
package dev.langchain4j.internal;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializer;
import com.google.gson.*;
import com.google.gson.reflect.TypeToken;
import com.google.gson.stream.JsonWriter;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.*;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Map;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
class GsonJsonCodec implements Json.JsonCodec {
private static final Gson GSON = new GsonBuilder()
.setPrettyPrinting()
.registerTypeAdapter(
@@ -40,6 +38,9 @@ class GsonJsonCodec implements Json.JsonCodec {
)
.create();
public static final Type MAP_TYPE = new TypeToken<Map<String, String>>() {
}.getType();
@Override
public String toJson(Object o) {
return GSON.toJson(o);
@@ -47,6 +48,9 @@ class GsonJsonCodec implements Json.JsonCodec {
@Override
public <T> T fromJson(String json, Class<T> type) {
if (type == Map.class) {
return GSON.fromJson(json, MAP_TYPE);
}
return GSON.fromJson(json, type);
}
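
Aside: the `Map.class` special case above is needed because a raw `Class` token gives Gson no generic type information; the `TypeToken`-based `MAP_TYPE` carries `Map<String, String>` through erasure. A standalone sketch (plain Gson, not this PR's code; class name is hypothetical):

```java
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;

import java.lang.reflect.Type;
import java.util.Map;

public class GsonMapTypeSketch { // hypothetical demo class
    public static void main(String[] args) {
        Gson gson = new Gson();
        String json = "{\"first\": \"2\", \"second\": \"2\"}";

        // With the raw Class token, Gson can only produce Map<?, ?>.
        Map<?, ?> raw = gson.fromJson(json, Map.class);

        // A TypeToken preserves the generic Map<String, String> type,
        // which is what the MAP_TYPE constant above provides.
        Type mapType = new TypeToken<Map<String, String>>() {}.getType();
        Map<String, String> typed = gson.fromJson(json, mapType);

        System.out.println(raw.get("first") + " / " + typed.get("first"));
    }
}
```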


@@ -1,44 +1,43 @@
package dev.langchain4j.internal;
import dev.langchain4j.spi.json.JsonCodecFactory;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.json.JsonCodecFactory;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
public class Json {
private static final JsonCodec CODEC = loadCodec();
private static final JsonCodec CODEC = loadCodec();
private static JsonCodec loadCodec() {
Collection<JsonCodecFactory> factories = ServiceHelper.loadFactories(JsonCodecFactory.class);
for (JsonCodecFactory factory : factories) {
return factory.create();
private static JsonCodec loadCodec() {
Collection<JsonCodecFactory> factories = ServiceHelper.loadFactories(JsonCodecFactory.class);
for (JsonCodecFactory factory : factories) {
return factory.create();
}
// fallback to default
return new GsonJsonCodec();
}
// fallback to default
return new GsonJsonCodec();
}
public static String toJson(Object o) {
return CODEC.toJson(o);
}
public static <T> T fromJson(String json, Class<T> type) {
return CODEC.fromJson(json, type);
}
public static String toJson(Object o) {
return CODEC.toJson(o);
}
public static InputStream toInputStream(Object o, Class<?> type) throws IOException {
return CODEC.toInputStream(o, type);
}
public static <T> T fromJson(String json, Class<T> type) {
return CODEC.fromJson(json, type);
}
public interface JsonCodec {
public static InputStream toInputStream(Object o, Class<?> type) throws IOException {
return CODEC.toInputStream(o, type);
}
String toJson(Object o);
public interface JsonCodec {
String toJson(Object o);
<T> T fromJson(String json, Class<T> type);
InputStream toInputStream(Object o, Class<?> type) throws IOException;
}
<T> T fromJson(String json, Class<T> type);
InputStream toInputStream(Object o, Class<?> type) throws IOException;
}
}


@@ -29,12 +29,12 @@ public interface Tokenizer {
return estimateTokenCountInToolSpecifications(toolSpecifications);
}
default int estimateTokenCountInToolSpecification(ToolSpecification toolSpecification) {
int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications);
default int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
return estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
}
int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications);
int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests);
default int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {


@@ -20,6 +20,7 @@ import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.time.Duration.ofSeconds;
@@ -93,7 +94,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
generate(messages, singletonList(toolSpecification), toolSpecification, handler);
generate(messages, null, toolSpecification, handler);
}
private void generate(List<ChatMessage> messages,
@@ -114,14 +115,14 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
if (toolThatMustBeExecuted != null) {
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
if (toolThatMustBeExecuted != null) {
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
}
ChatCompletionRequest request = requestBuilder.build();


@@ -7,7 +7,6 @@ import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.Tokenizer;
@@ -17,10 +16,17 @@ import java.util.Optional;
import java.util.function.Supplier;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Json.fromJson;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.roleFrom;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_0301;
import static dev.langchain4j.model.openai.OpenAiModelName.*;
import static java.util.Collections.singletonList;
/**
* This class can be used to estimate the cost (in tokens) before calling OpenAI or when using streaming.
* Magic numbers present in this class were found empirically while testing.
* There are integration tests in place that are making sure that the calculations here are very close to that of OpenAI.
*/
public class OpenAiTokenizer implements Tokenizer {
private final String modelName;
@@ -40,43 +46,15 @@ public class OpenAiTokenizer implements Tokenizer {
.countTokensOrdinary(text);
}
//Estimate the number of tokens in the parameters of a tool
private int estimateTokenCountInToolParameters(ToolParameters parameters) {
//Return early if there are no parameters
if (parameters == null) return 0;
int tokenCount = 0;
Map<String, Map<String, Object>> properties = parameters.properties();
for (String property : properties.keySet()) {
for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
if ("type".equals(entry.getKey())) {
tokenCount += 3; // found experimentally while playing with OpenAI API
tokenCount += estimateTokenCountInText(entry.getValue().toString());
} else if ("description".equals(entry.getKey())) {
tokenCount += 3; // found experimentally while playing with OpenAI API
tokenCount += estimateTokenCountInText(entry.getValue().toString());
} else if ("enum".equals(entry.getKey())) {
tokenCount -= 3; // found experimentally while playing with OpenAI API
for (Object enumValue : (Object[]) entry.getValue()) {
tokenCount += 3; // found experimentally while playing with OpenAI API
tokenCount += estimateTokenCountInText(enumValue.toString());
}
}
}
}
return tokenCount;
}
@Override
public int estimateTokenCountInMessage(ChatMessage message) {
int tokenCount = 0;
int tokenCount = 1; // 1 token for role
tokenCount += extraTokensPerMessage();
tokenCount += estimateTokenCountInText(message.text());
tokenCount += estimateTokenCountInText(roleFrom(message).toString().toLowerCase());
if (message instanceof UserMessage) {
UserMessage userMessage = (UserMessage) message;
if (userMessage.name() != null) {
if (userMessage.name() != null && !modelName.equals(GPT_4_VISION_PREVIEW)) {
tokenCount += extraTokensPerName();
tokenCount += estimateTokenCountInText(userMessage.name());
}
@@ -84,53 +62,37 @@ public class OpenAiTokenizer implements Tokenizer {
if (message instanceof AiMessage) {
AiMessage aiMessage = (AiMessage) message;
if (aiMessage.hasToolExecutionRequests()) {
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
tokenCount += 4; // found experimentally while playing with OpenAI API
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
if (aiMessage.toolExecutionRequests() != null) {
if (modelName.contains("1106")) {
tokenCount += 6;
} else {
tokenCount += 3;
}
if (aiMessage.toolExecutionRequests().size() == 1) {
tokenCount -= 1;
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
tokenCount += estimateTokenCountInText(toolExecutionRequest.name()) * 2;
tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
} else {
tokenCount += 15;
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
tokenCount += 7;
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
Map<?, ?> arguments = fromJson(toolExecutionRequest.arguments(), Map.class);
for (Map.Entry<?, ?> argument : arguments.entrySet()) {
tokenCount += 2;
tokenCount += estimateTokenCountInText(argument.getKey().toString());
tokenCount += estimateTokenCountInText(argument.getValue().toString());
}
}
}
}
}
if (message instanceof ToolExecutionResultMessage) {
ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
tokenCount += -1; // found experimentally while playing with OpenAI API
tokenCount += estimateTokenCountInText(toolExecutionResultMessage.toolName());
}
return tokenCount;
}
@Override
public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
// see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
for (ChatMessage message : messages) {
tokenCount += estimateTokenCountInMessage(message);
}
return tokenCount;
}
@Override
public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
int tokenCount = 0;
for (ToolSpecification toolSpecification : toolSpecifications) {
tokenCount += estimateTokenCountInText(toolSpecification.name());
tokenCount += estimateTokenCountInText(toolSpecification.description());
tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
tokenCount += 12; // found experimentally while playing with OpenAI API
}
tokenCount += 12; // found experimentally while playing with OpenAI API
return tokenCount;
}
@Override
public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
return 0; // TODO
}
private int extraTokensPerMessage() {
if (modelName.equals(GPT_3_5_TURBO_0301)) {
return 4;
@@ -147,6 +109,88 @@ public class OpenAiTokenizer implements Tokenizer {
}
}
@Override
public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
// see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
for (ChatMessage message : messages) {
tokenCount += estimateTokenCountInMessage(message);
}
return tokenCount;
}
@Override
public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
int tokenCount = 16;
for (ToolSpecification toolSpecification : toolSpecifications) {
tokenCount += 6;
tokenCount += estimateTokenCountInText(toolSpecification.name());
if (toolSpecification.description() != null) {
tokenCount += 2;
tokenCount += estimateTokenCountInText(toolSpecification.description());
}
tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
}
return tokenCount;
}
private int estimateTokenCountInToolParameters(ToolParameters parameters) {
if (parameters == null) {
return 0;
}
int tokenCount = 3;
Map<String, Map<String, Object>> properties = parameters.properties();
if (modelName.contains("1106")) {
tokenCount += properties.size() - 1;
}
for (String property : properties.keySet()) {
if (modelName.contains("1106")) {
tokenCount += 2;
} else {
tokenCount += 3;
}
tokenCount += estimateTokenCountInText(property);
for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
if ("type".equals(entry.getKey())) {
if ("array".equals(entry.getValue()) && modelName.contains("1106")) {
tokenCount += 1;
}
// TODO object
} else if ("description".equals(entry.getKey())) {
tokenCount += 2;
tokenCount += estimateTokenCountInText(entry.getValue().toString());
if (modelName.contains("1106") && parameters.required().contains(property)) {
tokenCount += 1;
}
} else if ("enum".equals(entry.getKey())) {
if (modelName.contains("1106")) {
tokenCount -= 2;
} else {
tokenCount -= 3;
}
for (Object enumValue : (Object[]) entry.getValue()) {
tokenCount += 3;
tokenCount += estimateTokenCountInText(enumValue.toString());
}
}
}
}
return tokenCount;
}
@Override
public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
int tokenCount = estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
tokenCount += 4;
tokenCount += estimateTokenCountInText(toolSpecification.name());
if (modelName.contains("1106")) {
tokenCount += 3;
}
return tokenCount;
}
public List<Integer> encode(String text) {
return encoding.orElseThrow(unknownModelException())
.encodeOrdinary(text);
@@ -165,4 +209,92 @@ public class OpenAiTokenizer implements Tokenizer {
private Supplier<IllegalArgumentException> unknownModelException() {
return () -> illegalArgument("Model '%s' is unknown to jtokkit", modelName);
}
@Override
public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
int tokenCount = 0;
int toolsCount = 0;
int toolsWithArgumentsCount = 0;
int toolsWithoutArgumentsCount = 0;
int totalArgumentsCount = 0;
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
tokenCount += 4;
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
int argumentCount = countArguments(toolExecutionRequest.arguments());
if (argumentCount == 0) {
toolsWithoutArgumentsCount++;
} else {
toolsWithArgumentsCount++;
}
totalArgumentsCount += argumentCount;
toolsCount++;
}
if (modelName.equals(GPT_3_5_TURBO_1106)) {
tokenCount += 16;
tokenCount += 3 * toolsWithoutArgumentsCount;
tokenCount += toolsCount;
if (totalArgumentsCount > 0) {
tokenCount -= 1;
tokenCount -= 2 * totalArgumentsCount;
tokenCount += 2 * toolsWithArgumentsCount;
tokenCount += toolsCount;
}
}
if (modelName.equals(GPT_4_1106_PREVIEW)) {
tokenCount += 3;
if (toolsCount > 1) {
tokenCount += 18;
tokenCount += 15 * toolsCount;
tokenCount += totalArgumentsCount;
tokenCount -= 3 * toolsWithoutArgumentsCount;
}
}
return tokenCount;
}
@Override
public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
if (modelName.equals(GPT_4_1106_PREVIEW)) {
int argumentsCount = countArguments(toolExecutionRequest.arguments());
if (argumentsCount == 0) {
return 1;
} else {
return estimateTokenCountInText(toolExecutionRequest.arguments());
}
}
int tokenCount = estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
tokenCount -= 4;
tokenCount -= estimateTokenCountInText(toolExecutionRequest.name());
if (modelName.equals(GPT_3_5_TURBO_1106)) {
int argumentsCount = countArguments(toolExecutionRequest.arguments());
if (argumentsCount == 0) {
return 1;
}
tokenCount -= 19;
tokenCount += 2 * argumentsCount;
}
return tokenCount;
}
static int countArguments(String arguments) {
if (isNullOrBlank(arguments)) {
return 0;
}
Map<?, ?> argumentsMap = fromJson(arguments, Map.class);
return argumentsMap.size();
}
}
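
As the class javadoc says, the point of these per-model magic numbers is to estimate prompt cost (in tokens) before calling OpenAI or while streaming. A hedged usage sketch combining the pieces the way the streaming models in this PR do (class name is hypothetical; the builder calls mirror the ones used in the tests below):

```java
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.openai.OpenAiTokenizer;

import java.util.List;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.Collections.singletonList;

public class ForcefulToolEstimateSketch { // hypothetical demo class
    public static void main(String[] args) {
        OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO);

        ToolSpecification calculator = ToolSpecification.builder()
                .name("calculator")
                .description("returns a sum of two numbers")
                .addParameter("first", INTEGER)
                .addParameter("second", INTEGER)
                .build();

        List<ChatMessage> messages = singletonList(userMessage("2+2=?"));

        // Message tokens plus the overhead of forcing this one tool,
        // mirroring the inputTokenCount logic in OpenAiStreamingChatModel.
        int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
        inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(calculator);

        System.out.println("estimated prompt tokens: " + inputTokenCount);
    }
}
```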


@@ -15,6 +15,7 @@ import java.util.List;
import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_3_5_TURBO_1106;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.Arrays.asList;
@@ -41,17 +42,17 @@ class OpenAiChatModelIT {
void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {
// given
UserMessage userMessage = userMessage("hello, how are you?");
UserMessage userMessage = userMessage("What is the capital of Germany?");
// when
Response<AiMessage> response = model.generate(userMessage);
// then
assertThat(response.content().text()).isNotBlank();
assertThat(response.content().text()).contains("Berlin");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(1);
assertThat(tokenUsage.inputTokenCount()).isEqualTo(14);
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -67,7 +68,7 @@ class OpenAiChatModelIT {
.maxTokens(3)
.build();
UserMessage userMessage = userMessage("hello, how are you?");
UserMessage userMessage = userMessage("What is the capital of Germany?");
// when
Response<AiMessage> response = model.generate(userMessage);
@@ -76,7 +77,7 @@ class OpenAiChatModelIT {
assertThat(response.content().text()).isNotBlank();
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
assertThat(tokenUsage.inputTokenCount()).isEqualTo(14);
assertThat(tokenUsage.outputTokenCount()).isEqualTo(3);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -113,8 +114,7 @@ class OpenAiChatModelIT {
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage
= ToolExecutionResultMessage.from(toolExecutionRequest, "4");
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
// when
@@ -135,7 +135,7 @@ class OpenAiChatModelIT {
}
@Test
void should_execute_concrete_tool_then_answer() {
void should_execute_tool_forcefully_then_answer() {
// given
UserMessage userMessage = userMessage("2+2=?");
@@ -162,8 +162,7 @@ class OpenAiChatModelIT {
assertThat(response.finishReason()).isEqualTo(STOP); // not sure if a bug in OpenAI or stop is expected here
// given
ToolExecutionResultMessage toolExecutionResultMessage
= ToolExecutionResultMessage.from(toolExecutionRequest, "4");
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
// when
@@ -221,8 +220,8 @@ class OpenAiChatModelIT {
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4");
ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);


@@ -10,6 +10,7 @@ import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.Test;
import java.util.List;
@@ -19,6 +20,7 @@ import java.util.concurrent.TimeoutException;
import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_3_5_TURBO_1106;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
@@ -26,6 +28,7 @@ import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
class OpenAiStreamingChatModelIT {
@@ -43,6 +46,8 @@ class OpenAiStreamingChatModelIT {
.addParameter("second", INTEGER)
.build();
Percentage tokenizerPrecision = withPercentage(5);
@Test
void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException {
@@ -133,15 +138,15 @@ class OpenAiStreamingChatModelIT {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(50); // TODO should be 53?
assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 22?
assertThat(tokenUsage.inputTokenCount()).isCloseTo(53, tokenizerPrecision);
assertThat(tokenUsage.outputTokenCount()).isCloseTo(22, tokenizerPrecision);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
@@ -175,7 +180,7 @@ class OpenAiStreamingChatModelIT {
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(43);
assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(41, tokenizerPrecision);
assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
@@ -184,7 +189,7 @@ class OpenAiStreamingChatModelIT {
}
@Test
void should_execute_concrete_tool_then_stream_answer() throws Exception {
void should_execute_tool_forcefully_then_stream_answer() throws Exception {
// given
UserMessage userMessage = userMessage("2+2=?");
@@ -227,15 +232,15 @@ class OpenAiStreamingChatModelIT {
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(89); // TODO should be 53?
assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 22?
assertThat(tokenUsage.inputTokenCount()).isCloseTo(59, tokenizerPrecision);
assertThat(tokenUsage.outputTokenCount()).isCloseTo(16, tokenizerPrecision);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(STOP); // not sure if a bug in OpenAI or stop is expected here
// given
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
@@ -269,7 +274,7 @@ class OpenAiStreamingChatModelIT {
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(43);
assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(41, tokenizerPrecision);
assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
@@ -332,16 +337,16 @@ class OpenAiStreamingChatModelIT {
assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"first\": 3, \"second\": 3}");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(55); // TODO should be 57?
assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 51?
assertThat(tokenUsage.inputTokenCount()).isCloseTo(57, tokenizerPrecision);
assertThat(tokenUsage.outputTokenCount()).isCloseTo(51, tokenizerPrecision);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4");
ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);
@@ -375,7 +380,7 @@ class OpenAiStreamingChatModelIT {
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(66); // TODO should be 83?
assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(83, tokenizerPrecision);
assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());


@@ -1,93 +1,22 @@
package dev.langchain4j.model.openai;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.ChatMessage;
import dev.ai4j.openai4j.chat.ChatCompletionModel;
import dev.langchain4j.model.Tokenizer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.EnumSource;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static dev.langchain4j.model.openai.OpenAiTokenizer.countArguments;
import static org.assertj.core.api.Assertions.assertThat;
class OpenAiTokenizerTest {
OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO);
@ParameterizedTest
@MethodSource
void should_count_tokens_in_messages(List<ChatMessage> messages, int expectedTokenCount) {
int tokenCount = tokenizer.estimateTokenCountInMessages(messages);
assertThat(tokenCount).isEqualTo(expectedTokenCount);
}
static Stream<Arguments> should_count_tokens_in_messages() {
// expected token count was taken from real OpenAI responses (usage.prompt_tokens)
return Stream.of(
Arguments.of(singletonList(userMessage("hello")), 8),
Arguments.of(singletonList(userMessage("Klaus", "hello")), 11),
Arguments.of(asList(userMessage("hello"), aiMessage("hi there")), 14),
Arguments.of(asList(
userMessage("How much is 2 plus 2?"),
aiMessage(ToolExecutionRequest.builder()
.name("calculator")
.arguments("{\"a\":2, \"b\":2}")
.build())
), 35),
Arguments.of(asList(
userMessage("How much is 2 plus 2?"),
aiMessage(ToolExecutionRequest.builder()
.name("calculator")
.arguments("{\"a\":2, \"b\":2}")
.build()),
toolExecutionResultMessage("a", "calculator", "4")
), 40)
);
}
static class Tools {
@Tool
int add(int a, int b) {
return a + b;
}
@Tool("calculates the square root of the provided number")
double squareRoot(@P("number to operate on") double number) {
return Math.sqrt(number);
}
@Tool
int temperature(String location, TemperatureUnit temperatureUnit) {
return 0;
}
@Tool()
int randomInt() {return new Random().nextInt();}
}
enum TemperatureUnit {
F, C
}
@Test
void should_count_tokens_in_tools() {
int tokenCount = tokenizer.estimateTokenCountInTools(new Tools());
assertThat(tokenCount).isEqualTo(107); // found experimentally while playing with OpenAI API
}
@Test
void should_encode_and_decode_text() {
String originalText = "This is a text which will be encoded and decoded back.";
@@ -140,10 +69,45 @@ class OpenAiTokenizerTest {
assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(100 * 15);
}
public static List<String> repeat(String s, int n) {
@Test
void should_count_arguments() {
assertThat(countArguments(null)).isEqualTo(0);
assertThat(countArguments("")).isEqualTo(0);
assertThat(countArguments(" ")).isEqualTo(0);
assertThat(countArguments("{}")).isEqualTo(0);
assertThat(countArguments("{ }")).isEqualTo(0);
assertThat(countArguments("{\"one\":1}")).isEqualTo(1);
assertThat(countArguments("{\"one\": 1}")).isEqualTo(1);
assertThat(countArguments("{\"one\" : 1}")).isEqualTo(1);
assertThat(countArguments("{\"one\":1,\"two\":2}")).isEqualTo(2);
assertThat(countArguments("{\"one\": 1,\"two\": 2}")).isEqualTo(2);
assertThat(countArguments("{\"one\" : 1,\"two\" : 2}")).isEqualTo(2);
assertThat(countArguments("{\"one\":1,\"two\":2,\"three\":3}")).isEqualTo(3);
assertThat(countArguments("{\"one\": 1,\"two\": 2,\"three\": 3}")).isEqualTo(3);
assertThat(countArguments("{\"one\" : 1,\"two\" : 2,\"three\" : 3}")).isEqualTo(3);
}
@ParameterizedTest
@EnumSource(ChatCompletionModel.class)
void should_support_all_models(ChatCompletionModel model) {
// given
Tokenizer tokenizer = new OpenAiTokenizer(model.toString());
// when
int tokenCount = tokenizer.estimateTokenCountInText("a");
// then
assertThat(tokenCount).isEqualTo(1);
}
static List<String> repeat(String strings, int n) {
List<String> result = new ArrayList<>();
for (int i = 0; i < n; i++) {
result.add(s);
result.add(strings);
}
return result;
}


@@ -20,6 +20,7 @@ import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
@@ -103,7 +104,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
generate(messages, singletonList(toolSpecification), toolSpecification, handler);
generate(messages, null, toolSpecification, handler);
}
private void generate(List<ChatMessage> messages,
@@ -122,18 +123,18 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
if (toolThatMustBeExecuted != null) {
requestBuilder.functions(toFunctions(singletonList(toolThatMustBeExecuted)));
requestBuilder.functionCall(toolThatMustBeExecuted.name());
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
}
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.functions(toFunctions(toolSpecifications));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
}
if (toolThatMustBeExecuted != null) {
requestBuilder.functionCall(toolThatMustBeExecuted.name());
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
}
}
ChatCompletionRequest request = requestBuilder.build();


@@ -13,6 +13,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.Test;
import java.util.List;
@@ -34,6 +35,8 @@ public class StreamingAiServicesIT {
.logResponses(true)
.build();
Percentage tokenizerPrecision = withPercentage(5);
interface Assistant {
TokenStream chat(String userMessage);
@@ -166,8 +169,8 @@ public class StreamingAiServicesIT {
assertThat(response.content().text()).isEqualTo(answer);
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(181); // TODO should be around 182?
assertThat(tokenUsage.outputTokenCount()).isCloseTo(27, withPercentage(5)); // TODO
assertThat(tokenUsage.inputTokenCount()).isCloseTo(72 + 110, tokenizerPrecision);
assertThat(tokenUsage.outputTokenCount()).isCloseTo(21 + 28, tokenizerPrecision);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -237,8 +240,8 @@ public class StreamingAiServicesIT {
assertThat(response.content().text()).isEqualTo(answer);
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(351); // TODO should be around 348?
assertThat(tokenUsage.outputTokenCount()).isCloseTo(52, withPercentage(5)); // TODO
assertThat(tokenUsage.inputTokenCount()).isCloseTo(79 + 117 + 152, tokenizerPrecision);
assertThat(tokenUsage.outputTokenCount()).isCloseTo(21 + 20 + 53, tokenizerPrecision);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
@@ -331,8 +334,8 @@ public class StreamingAiServicesIT {
assertThat(response.content().text()).isEqualTo(answer);
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(221); // TODO should be around 239?
assertThat(tokenUsage.outputTokenCount()).isCloseTo(57, withPercentage(5)); // TODO
assertThat(tokenUsage.inputTokenCount()).isCloseTo(79 + 160, tokenizerPrecision);
assertThat(tokenUsage.outputTokenCount()).isCloseTo(54 + 58, tokenizerPrecision);
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());