make OpenAI tokenizer more precise (#346)
This PR is a rework of `OpenAiTokenizer`. Added `OpenAiTokenizerIT` with lots of tests to ensure that `OpenAiTokenizer` calculates token usage very close to OpenAI. In most cases calculation is 1:1, in some corner cases the difference is within 5%.
This commit is contained in:
parent
054a36f59f
commit
8e4254fc20
|
@ -20,6 +20,7 @@ import java.time.Duration;
|
|||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient;
|
||||
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.toFunctions;
|
||||
|
@ -117,7 +118,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
|
|||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
|
||||
generate(messages, singletonList(toolSpecification), toolSpecification, handler);
|
||||
generate(messages, null, toolSpecification, handler);
|
||||
}
|
||||
|
||||
private void generate(List<ChatMessage> messages,
|
||||
|
@ -136,18 +137,18 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
|
|||
|
||||
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
|
||||
|
||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
options.setFunctions(toFunctions(singletonList(toolThatMustBeExecuted)));
|
||||
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
|
||||
if (tokenizer != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
|
||||
}
|
||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
||||
options.setFunctions(toFunctions(toolSpecifications));
|
||||
if (tokenizer != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
|
||||
}
|
||||
}
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
|
||||
if (tokenizer != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
|
||||
}
|
||||
}
|
||||
|
||||
AzureOpenAiStreamingResponseBuilder responseBuilder = new AzureOpenAiStreamingResponseBuilder(inputTokenCount);
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ class AzureOpenAiStreamingChatModelIT {
|
|||
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
|
||||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||
|
||||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(50);
|
||||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(53);
|
||||
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
|
||||
assertThat(response.tokenUsage().totalTokenCount())
|
||||
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
|
||||
|
|
|
@ -1,23 +1,21 @@
|
|||
package dev.langchain4j.internal;
|
||||
|
||||
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
|
||||
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.google.gson.JsonDeserializer;
|
||||
import com.google.gson.JsonPrimitive;
|
||||
import com.google.gson.JsonSerializer;
|
||||
import com.google.gson.*;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import com.google.gson.stream.JsonWriter;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStreamWriter;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Type;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Map;
|
||||
|
||||
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
|
||||
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
|
||||
|
||||
class GsonJsonCodec implements Json.JsonCodec {
|
||||
|
||||
private static final Gson GSON = new GsonBuilder()
|
||||
.setPrettyPrinting()
|
||||
.registerTypeAdapter(
|
||||
|
@ -40,6 +38,9 @@ class GsonJsonCodec implements Json.JsonCodec {
|
|||
)
|
||||
.create();
|
||||
|
||||
public static final Type MAP_TYPE = new TypeToken<Map<String, String>>() {
|
||||
}.getType();
|
||||
|
||||
@Override
|
||||
public String toJson(Object o) {
|
||||
return GSON.toJson(o);
|
||||
|
@ -47,6 +48,9 @@ class GsonJsonCodec implements Json.JsonCodec {
|
|||
|
||||
@Override
|
||||
public <T> T fromJson(String json, Class<T> type) {
|
||||
if (type == Map.class) {
|
||||
return GSON.fromJson(json, MAP_TYPE);
|
||||
}
|
||||
return GSON.fromJson(json, type);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,44 +1,43 @@
|
|||
package dev.langchain4j.internal;
|
||||
|
||||
import dev.langchain4j.spi.json.JsonCodecFactory;
|
||||
import dev.langchain4j.spi.ServiceHelper;
|
||||
import dev.langchain4j.spi.json.JsonCodecFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Collection;
|
||||
|
||||
public class Json {
|
||||
|
||||
private static final JsonCodec CODEC = loadCodec();
|
||||
private static final JsonCodec CODEC = loadCodec();
|
||||
|
||||
private static JsonCodec loadCodec() {
|
||||
Collection<JsonCodecFactory> factories = ServiceHelper.loadFactories(JsonCodecFactory.class);
|
||||
for (JsonCodecFactory factory : factories) {
|
||||
return factory.create();
|
||||
private static JsonCodec loadCodec() {
|
||||
Collection<JsonCodecFactory> factories = ServiceHelper.loadFactories(JsonCodecFactory.class);
|
||||
for (JsonCodecFactory factory : factories) {
|
||||
return factory.create();
|
||||
}
|
||||
// fallback to default
|
||||
return new GsonJsonCodec();
|
||||
}
|
||||
// fallback to default
|
||||
return new GsonJsonCodec();
|
||||
}
|
||||
|
||||
public static String toJson(Object o) {
|
||||
return CODEC.toJson(o);
|
||||
}
|
||||
|
||||
public static <T> T fromJson(String json, Class<T> type) {
|
||||
return CODEC.fromJson(json, type);
|
||||
}
|
||||
|
||||
public static String toJson(Object o) {
|
||||
return CODEC.toJson(o);
|
||||
}
|
||||
public static InputStream toInputStream(Object o, Class<?> type) throws IOException {
|
||||
return CODEC.toInputStream(o, type);
|
||||
}
|
||||
|
||||
public static <T> T fromJson(String json, Class<T> type) {
|
||||
return CODEC.fromJson(json, type);
|
||||
}
|
||||
public interface JsonCodec {
|
||||
|
||||
public static InputStream toInputStream(Object o, Class<?> type) throws IOException {
|
||||
return CODEC.toInputStream(o, type);
|
||||
}
|
||||
String toJson(Object o);
|
||||
|
||||
public interface JsonCodec {
|
||||
String toJson(Object o);
|
||||
|
||||
<T> T fromJson(String json, Class<T> type);
|
||||
InputStream toInputStream(Object o, Class<?> type) throws IOException;
|
||||
|
||||
}
|
||||
<T> T fromJson(String json, Class<T> type);
|
||||
|
||||
InputStream toInputStream(Object o, Class<?> type) throws IOException;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,12 +29,12 @@ public interface Tokenizer {
|
|||
return estimateTokenCountInToolSpecifications(toolSpecifications);
|
||||
}
|
||||
|
||||
default int estimateTokenCountInToolSpecification(ToolSpecification toolSpecification) {
|
||||
int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications);
|
||||
|
||||
default int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
|
||||
return estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
|
||||
}
|
||||
|
||||
int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications);
|
||||
|
||||
int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests);
|
||||
|
||||
default int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.time.Duration;
|
|||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
|
@ -93,7 +94,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
|
||||
generate(messages, singletonList(toolSpecification), toolSpecification, handler);
|
||||
generate(messages, null, toolSpecification, handler);
|
||||
}
|
||||
|
||||
private void generate(List<ChatMessage> messages,
|
||||
|
@ -114,14 +115,14 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
|
|||
|
||||
int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
|
||||
|
||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
|
||||
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
|
||||
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
|
||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.tools(toTools(toolSpecifications));
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
|
||||
}
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
|
||||
}
|
||||
|
||||
ChatCompletionRequest request = requestBuilder.build();
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import dev.langchain4j.agent.tool.ToolParameters;
|
|||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
|
||||
|
@ -17,10 +16,17 @@ import java.util.Optional;
|
|||
import java.util.function.Supplier;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Json.fromJson;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.roleFrom;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_0301;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.*;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
/**
|
||||
* This class can be used to estimate the cost (in tokens) before calling OpenAI or when using streaming.
|
||||
* Magic numbers present in this class were found empirically while testing.
|
||||
* There are integration tests in place that are making sure that the calculations here are very close to that of OpenAI.
|
||||
*/
|
||||
public class OpenAiTokenizer implements Tokenizer {
|
||||
|
||||
private final String modelName;
|
||||
|
@ -40,43 +46,15 @@ public class OpenAiTokenizer implements Tokenizer {
|
|||
.countTokensOrdinary(text);
|
||||
}
|
||||
|
||||
//Estimate the number of tokens in the parameters of a tool
|
||||
private int estimateTokenCountInToolParameters(ToolParameters parameters) {
|
||||
//Return early if there are no parameters
|
||||
if (parameters == null) return 0;
|
||||
|
||||
int tokenCount = 0;
|
||||
Map<String, Map<String, Object>> properties = parameters.properties();
|
||||
for (String property : properties.keySet()) {
|
||||
for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
|
||||
if ("type".equals(entry.getKey())) {
|
||||
tokenCount += 3; // found experimentally while playing with OpenAI API
|
||||
tokenCount += estimateTokenCountInText(entry.getValue().toString());
|
||||
} else if ("description".equals(entry.getKey())) {
|
||||
tokenCount += 3; // found experimentally while playing with OpenAI API
|
||||
tokenCount += estimateTokenCountInText(entry.getValue().toString());
|
||||
} else if ("enum".equals(entry.getKey())) {
|
||||
tokenCount -= 3; // found experimentally while playing with OpenAI API
|
||||
for (Object enumValue : (Object[]) entry.getValue()) {
|
||||
tokenCount += 3; // found experimentally while playing with OpenAI API
|
||||
tokenCount += estimateTokenCountInText(enumValue.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInMessage(ChatMessage message) {
|
||||
int tokenCount = 0;
|
||||
int tokenCount = 1; // 1 token for role
|
||||
tokenCount += extraTokensPerMessage();
|
||||
tokenCount += estimateTokenCountInText(message.text());
|
||||
tokenCount += estimateTokenCountInText(roleFrom(message).toString().toLowerCase());
|
||||
|
||||
if (message instanceof UserMessage) {
|
||||
UserMessage userMessage = (UserMessage) message;
|
||||
if (userMessage.name() != null) {
|
||||
if (userMessage.name() != null && !modelName.equals(GPT_4_VISION_PREVIEW)) {
|
||||
tokenCount += extraTokensPerName();
|
||||
tokenCount += estimateTokenCountInText(userMessage.name());
|
||||
}
|
||||
|
@ -84,53 +62,37 @@ public class OpenAiTokenizer implements Tokenizer {
|
|||
|
||||
if (message instanceof AiMessage) {
|
||||
AiMessage aiMessage = (AiMessage) message;
|
||||
if (aiMessage.hasToolExecutionRequests()) {
|
||||
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
|
||||
tokenCount += 4; // found experimentally while playing with OpenAI API
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
|
||||
if (aiMessage.toolExecutionRequests() != null) {
|
||||
if (modelName.contains("1106")) {
|
||||
tokenCount += 6;
|
||||
} else {
|
||||
tokenCount += 3;
|
||||
}
|
||||
if (aiMessage.toolExecutionRequests().size() == 1) {
|
||||
tokenCount -= 1;
|
||||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name()) * 2;
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
|
||||
} else {
|
||||
tokenCount += 15;
|
||||
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
|
||||
tokenCount += 7;
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
|
||||
|
||||
Map<?, ?> arguments = fromJson(toolExecutionRequest.arguments(), Map.class);
|
||||
for (Map.Entry<?, ?> argument : arguments.entrySet()) {
|
||||
tokenCount += 2;
|
||||
tokenCount += estimateTokenCountInText(argument.getKey().toString());
|
||||
tokenCount += estimateTokenCountInText(argument.getValue().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (message instanceof ToolExecutionResultMessage) {
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
|
||||
tokenCount += -1; // found experimentally while playing with OpenAI API
|
||||
tokenCount += estimateTokenCountInText(toolExecutionResultMessage.toolName());
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
|
||||
// see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
|
||||
int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
|
||||
for (ChatMessage message : messages) {
|
||||
tokenCount += estimateTokenCountInMessage(message);
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
|
||||
int tokenCount = 0;
|
||||
for (ToolSpecification toolSpecification : toolSpecifications) {
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.name());
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.description());
|
||||
tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
|
||||
tokenCount += 12; // found experimentally while playing with OpenAI API
|
||||
}
|
||||
tokenCount += 12; // found experimentally while playing with OpenAI API
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
|
||||
return 0; // TODO
|
||||
}
|
||||
|
||||
private int extraTokensPerMessage() {
|
||||
if (modelName.equals(GPT_3_5_TURBO_0301)) {
|
||||
return 4;
|
||||
|
@ -147,6 +109,88 @@ public class OpenAiTokenizer implements Tokenizer {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
|
||||
// see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
|
||||
int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
|
||||
for (ChatMessage message : messages) {
|
||||
tokenCount += estimateTokenCountInMessage(message);
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
|
||||
int tokenCount = 16;
|
||||
for (ToolSpecification toolSpecification : toolSpecifications) {
|
||||
tokenCount += 6;
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.name());
|
||||
if (toolSpecification.description() != null) {
|
||||
tokenCount += 2;
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.description());
|
||||
}
|
||||
tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
private int estimateTokenCountInToolParameters(ToolParameters parameters) {
|
||||
if (parameters == null) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tokenCount = 3;
|
||||
Map<String, Map<String, Object>> properties = parameters.properties();
|
||||
if (modelName.contains("1106")) {
|
||||
tokenCount += properties.size() - 1;
|
||||
}
|
||||
for (String property : properties.keySet()) {
|
||||
if (modelName.contains("1106")) {
|
||||
tokenCount += 2;
|
||||
} else {
|
||||
tokenCount += 3;
|
||||
}
|
||||
tokenCount += estimateTokenCountInText(property);
|
||||
for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
|
||||
if ("type".equals(entry.getKey())) {
|
||||
if ("array".equals(entry.getValue()) && modelName.contains("1106")) {
|
||||
tokenCount += 1;
|
||||
}
|
||||
// TODO object
|
||||
} else if ("description".equals(entry.getKey())) {
|
||||
tokenCount += 2;
|
||||
tokenCount += estimateTokenCountInText(entry.getValue().toString());
|
||||
if (modelName.contains("1106") && parameters.required().contains(property)) {
|
||||
tokenCount += 1;
|
||||
}
|
||||
} else if ("enum".equals(entry.getKey())) {
|
||||
if (modelName.contains("1106")) {
|
||||
tokenCount -= 2;
|
||||
} else {
|
||||
tokenCount -= 3;
|
||||
}
|
||||
for (Object enumValue : (Object[]) entry.getValue()) {
|
||||
tokenCount += 3;
|
||||
tokenCount += estimateTokenCountInText(enumValue.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
|
||||
int tokenCount = estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
|
||||
tokenCount += 4;
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.name());
|
||||
if (modelName.contains("1106")) {
|
||||
tokenCount += 3;
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
public List<Integer> encode(String text) {
|
||||
return encoding.orElseThrow(unknownModelException())
|
||||
.encodeOrdinary(text);
|
||||
|
@ -165,4 +209,92 @@ public class OpenAiTokenizer implements Tokenizer {
|
|||
private Supplier<IllegalArgumentException> unknownModelException() {
|
||||
return () -> illegalArgument("Model '%s' is unknown to jtokkit", modelName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
|
||||
|
||||
int tokenCount = 0;
|
||||
|
||||
int toolsCount = 0;
|
||||
int toolsWithArgumentsCount = 0;
|
||||
int toolsWithoutArgumentsCount = 0;
|
||||
|
||||
int totalArgumentsCount = 0;
|
||||
|
||||
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
|
||||
tokenCount += 4;
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
|
||||
|
||||
int argumentCount = countArguments(toolExecutionRequest.arguments());
|
||||
if (argumentCount == 0) {
|
||||
toolsWithoutArgumentsCount++;
|
||||
} else {
|
||||
toolsWithArgumentsCount++;
|
||||
}
|
||||
totalArgumentsCount += argumentCount;
|
||||
|
||||
toolsCount++;
|
||||
}
|
||||
|
||||
if (modelName.equals(GPT_3_5_TURBO_1106)) {
|
||||
tokenCount += 16;
|
||||
tokenCount += 3 * toolsWithoutArgumentsCount;
|
||||
tokenCount += toolsCount;
|
||||
if (totalArgumentsCount > 0) {
|
||||
tokenCount -= 1;
|
||||
tokenCount -= 2 * totalArgumentsCount;
|
||||
tokenCount += 2 * toolsWithArgumentsCount;
|
||||
tokenCount += toolsCount;
|
||||
}
|
||||
}
|
||||
|
||||
if (modelName.equals(GPT_4_1106_PREVIEW)) {
|
||||
tokenCount += 3;
|
||||
if (toolsCount > 1) {
|
||||
tokenCount += 18;
|
||||
tokenCount += 15 * toolsCount;
|
||||
tokenCount += totalArgumentsCount;
|
||||
tokenCount -= 3 * toolsWithoutArgumentsCount;
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
|
||||
|
||||
if (modelName.equals(GPT_4_1106_PREVIEW)) {
|
||||
int argumentsCount = countArguments(toolExecutionRequest.arguments());
|
||||
if (argumentsCount == 0) {
|
||||
return 1;
|
||||
} else {
|
||||
return estimateTokenCountInText(toolExecutionRequest.arguments());
|
||||
}
|
||||
}
|
||||
|
||||
int tokenCount = estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
|
||||
tokenCount -= 4;
|
||||
tokenCount -= estimateTokenCountInText(toolExecutionRequest.name());
|
||||
|
||||
if (modelName.equals(GPT_3_5_TURBO_1106)) {
|
||||
int argumentsCount = countArguments(toolExecutionRequest.arguments());
|
||||
if (argumentsCount == 0) {
|
||||
return 1;
|
||||
}
|
||||
tokenCount -= 19;
|
||||
tokenCount += 2 * argumentsCount;
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
static int countArguments(String arguments) {
|
||||
if (isNullOrBlank(arguments)) {
|
||||
return 0;
|
||||
}
|
||||
Map<?, ?> argumentsMap = fromJson(arguments, Map.class);
|
||||
return argumentsMap.size();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import java.util.List;
|
|||
|
||||
import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_3_5_TURBO_1106;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.output.FinishReason.*;
|
||||
import static java.util.Arrays.asList;
|
||||
|
@ -41,17 +42,17 @@ class OpenAiChatModelIT {
|
|||
void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {
|
||||
|
||||
// given
|
||||
UserMessage userMessage = userMessage("hello, how are you?");
|
||||
UserMessage userMessage = userMessage("What is the capital of Germany?");
|
||||
|
||||
// when
|
||||
Response<AiMessage> response = model.generate(userMessage);
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).isNotBlank();
|
||||
assertThat(response.content().text()).contains("Berlin");
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
|
||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(1);
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(14);
|
||||
assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
|
@ -67,7 +68,7 @@ class OpenAiChatModelIT {
|
|||
.maxTokens(3)
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = userMessage("hello, how are you?");
|
||||
UserMessage userMessage = userMessage("What is the capital of Germany?");
|
||||
|
||||
// when
|
||||
Response<AiMessage> response = model.generate(userMessage);
|
||||
|
@ -76,7 +77,7 @@ class OpenAiChatModelIT {
|
|||
assertThat(response.content().text()).isNotBlank();
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(13);
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(14);
|
||||
assertThat(tokenUsage.outputTokenCount()).isEqualTo(3);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
@ -113,8 +114,7 @@ class OpenAiChatModelIT {
|
|||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage
|
||||
= ToolExecutionResultMessage.from(toolExecutionRequest, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||
|
||||
// when
|
||||
|
@ -135,7 +135,7 @@ class OpenAiChatModelIT {
|
|||
}
|
||||
|
||||
@Test
|
||||
void should_execute_concrete_tool_then_answer() {
|
||||
void should_execute_tool_forcefully_then_answer() {
|
||||
|
||||
// given
|
||||
UserMessage userMessage = userMessage("2+2=?");
|
||||
|
@ -162,8 +162,7 @@ class OpenAiChatModelIT {
|
|||
assertThat(response.finishReason()).isEqualTo(STOP); // not sure if a bug in OpenAI or stop is expected here
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage
|
||||
= ToolExecutionResultMessage.from(toolExecutionRequest, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||
|
||||
// when
|
||||
|
@ -221,8 +220,8 @@ class OpenAiChatModelIT {
|
|||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6");
|
||||
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import dev.langchain4j.model.StreamingResponseHandler;
|
|||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.assertj.core.data.Percentage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -19,6 +20,7 @@ import java.util.concurrent.TimeoutException;
|
|||
|
||||
import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_3_5_TURBO_1106;
|
||||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
|
||||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
|
||||
|
@ -26,6 +28,7 @@ import static java.util.Arrays.asList;
|
|||
import static java.util.Collections.singletonList;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
class OpenAiStreamingChatModelIT {
|
||||
|
||||
|
@ -43,6 +46,8 @@ class OpenAiStreamingChatModelIT {
|
|||
.addParameter("second", INTEGER)
|
||||
.build();
|
||||
|
||||
Percentage tokenizerPrecision = withPercentage(5);
|
||||
|
||||
@Test
|
||||
void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException {
|
||||
|
||||
|
@ -133,15 +138,15 @@ class OpenAiStreamingChatModelIT {
|
|||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(50); // TODO should be 53?
|
||||
assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 22?
|
||||
assertThat(tokenUsage.inputTokenCount()).isCloseTo(53, tokenizerPrecision);
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(22, tokenizerPrecision);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
|
||||
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||
|
||||
|
@ -175,7 +180,7 @@ class OpenAiStreamingChatModelIT {
|
|||
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
|
||||
|
||||
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
|
||||
assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(43);
|
||||
assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(41, tokenizerPrecision);
|
||||
assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||
assertThat(secondTokenUsage.totalTokenCount())
|
||||
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
|
||||
|
@ -184,7 +189,7 @@ class OpenAiStreamingChatModelIT {
|
|||
}
|
||||
|
||||
@Test
|
||||
void should_execute_concrete_tool_then_stream_answer() throws Exception {
|
||||
void should_execute_tool_forcefully_then_stream_answer() throws Exception {
|
||||
|
||||
// given
|
||||
UserMessage userMessage = userMessage("2+2=?");
|
||||
|
@ -227,15 +232,15 @@ class OpenAiStreamingChatModelIT {
|
|||
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(89); // TODO should be 53?
|
||||
assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 22?
|
||||
assertThat(tokenUsage.inputTokenCount()).isCloseTo(59, tokenizerPrecision);
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(16, tokenizerPrecision);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
assertThat(response.finishReason()).isEqualTo(STOP); // not sure if a bug in OpenAI or stop is expected here
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4");
|
||||
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
|
||||
|
||||
|
@ -269,7 +274,7 @@ class OpenAiStreamingChatModelIT {
|
|||
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
|
||||
|
||||
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
|
||||
assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(43);
|
||||
assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(41, tokenizerPrecision);
|
||||
assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||
assertThat(secondTokenUsage.totalTokenCount())
|
||||
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
|
||||
|
@ -332,16 +337,16 @@ class OpenAiStreamingChatModelIT {
|
|||
assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"first\": 3, \"second\": 3}");
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(55); // TODO should be 57?
|
||||
assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); // TODO should be 51?
|
||||
assertThat(tokenUsage.inputTokenCount()).isCloseTo(57, tokenizerPrecision);
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(51, tokenizerPrecision);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
|
||||
|
||||
// given
|
||||
ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "6");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4");
|
||||
ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6");
|
||||
|
||||
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2);
|
||||
|
||||
|
@ -375,7 +380,7 @@ class OpenAiStreamingChatModelIT {
|
|||
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
|
||||
|
||||
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
|
||||
assertThat(secondTokenUsage.inputTokenCount()).isEqualTo(66); // TODO should be 83?
|
||||
assertThat(secondTokenUsage.inputTokenCount()).isCloseTo(83, tokenizerPrecision);
|
||||
assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0);
|
||||
assertThat(secondTokenUsage.totalTokenCount())
|
||||
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,93 +1,22 @@
|
|||
package dev.langchain4j.model.openai;
|
||||
|
||||
import dev.langchain4j.agent.tool.P;
|
||||
import dev.langchain4j.agent.tool.Tool;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.ai4j.openai4j.chat.ChatCompletionModel;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.data.message.AiMessage.aiMessage;
|
||||
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
|
||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static dev.langchain4j.model.openai.OpenAiTokenizer.countArguments;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class OpenAiTokenizerTest {
|
||||
|
||||
OpenAiTokenizer tokenizer = new OpenAiTokenizer(GPT_3_5_TURBO);
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void should_count_tokens_in_messages(List<ChatMessage> messages, int expectedTokenCount) {
|
||||
int tokenCount = tokenizer.estimateTokenCountInMessages(messages);
|
||||
assertThat(tokenCount).isEqualTo(expectedTokenCount);
|
||||
}
|
||||
|
||||
static Stream<Arguments> should_count_tokens_in_messages() {
|
||||
// expected token count was taken from real OpenAI responses (usage.prompt_tokens)
|
||||
return Stream.of(
|
||||
Arguments.of(singletonList(userMessage("hello")), 8),
|
||||
Arguments.of(singletonList(userMessage("Klaus", "hello")), 11),
|
||||
Arguments.of(asList(userMessage("hello"), aiMessage("hi there")), 14),
|
||||
Arguments.of(asList(
|
||||
userMessage("How much is 2 plus 2?"),
|
||||
aiMessage(ToolExecutionRequest.builder()
|
||||
.name("calculator")
|
||||
.arguments("{\"a\":2, \"b\":2}")
|
||||
.build())
|
||||
), 35),
|
||||
Arguments.of(asList(
|
||||
userMessage("How much is 2 plus 2?"),
|
||||
aiMessage(ToolExecutionRequest.builder()
|
||||
.name("calculator")
|
||||
.arguments("{\"a\":2, \"b\":2}")
|
||||
.build()),
|
||||
toolExecutionResultMessage("a", "calculator", "4")
|
||||
), 40)
|
||||
);
|
||||
}
|
||||
|
||||
static class Tools {
|
||||
|
||||
@Tool
|
||||
int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
@Tool("calculates the square root of the provided number")
|
||||
double squareRoot(@P("number to operate on") double number) {
|
||||
return Math.sqrt(number);
|
||||
}
|
||||
|
||||
@Tool
|
||||
int temperature(String location, TemperatureUnit temperatureUnit) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Tool()
|
||||
int randomInt() {return new Random().nextInt();}
|
||||
}
|
||||
|
||||
enum TemperatureUnit {
|
||||
F, C
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_count_tokens_in_tools() {
|
||||
int tokenCount = tokenizer.estimateTokenCountInTools(new Tools());
|
||||
assertThat(tokenCount).isEqualTo(107); // found experimentally while playing with OpenAI API
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_encode_and_decode_text() {
|
||||
String originalText = "This is a text which will be encoded and decoded back.";
|
||||
|
@ -140,10 +69,45 @@ class OpenAiTokenizerTest {
|
|||
assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(100 * 15);
|
||||
}
|
||||
|
||||
public static List<String> repeat(String s, int n) {
|
||||
@Test
|
||||
void should_count_arguments() {
|
||||
assertThat(countArguments(null)).isEqualTo(0);
|
||||
assertThat(countArguments("")).isEqualTo(0);
|
||||
assertThat(countArguments(" ")).isEqualTo(0);
|
||||
assertThat(countArguments("{}")).isEqualTo(0);
|
||||
assertThat(countArguments("{ }")).isEqualTo(0);
|
||||
|
||||
assertThat(countArguments("{\"one\":1}")).isEqualTo(1);
|
||||
assertThat(countArguments("{\"one\": 1}")).isEqualTo(1);
|
||||
assertThat(countArguments("{\"one\" : 1}")).isEqualTo(1);
|
||||
|
||||
assertThat(countArguments("{\"one\":1,\"two\":2}")).isEqualTo(2);
|
||||
assertThat(countArguments("{\"one\": 1,\"two\": 2}")).isEqualTo(2);
|
||||
assertThat(countArguments("{\"one\" : 1,\"two\" : 2}")).isEqualTo(2);
|
||||
|
||||
assertThat(countArguments("{\"one\":1,\"two\":2,\"three\":3}")).isEqualTo(3);
|
||||
assertThat(countArguments("{\"one\": 1,\"two\": 2,\"three\": 3}")).isEqualTo(3);
|
||||
assertThat(countArguments("{\"one\" : 1,\"two\" : 2,\"three\" : 3}")).isEqualTo(3);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(ChatCompletionModel.class)
|
||||
void should_support_all_models(ChatCompletionModel model) {
|
||||
|
||||
// given
|
||||
Tokenizer tokenizer = new OpenAiTokenizer(model.toString());
|
||||
|
||||
// when
|
||||
int tokenCount = tokenizer.estimateTokenCountInText("a");
|
||||
|
||||
// then
|
||||
assertThat(tokenCount).isEqualTo(1);
|
||||
}
|
||||
|
||||
static List<String> repeat(String strings, int n) {
|
||||
List<String> result = new ArrayList<>();
|
||||
for (int i = 0; i < n; i++) {
|
||||
result.add(s);
|
||||
result.add(strings);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.time.Duration;
|
|||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
|
||||
|
@ -103,7 +104,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
|
|||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
|
||||
generate(messages, singletonList(toolSpecification), toolSpecification, handler);
|
||||
generate(messages, null, toolSpecification, handler);
|
||||
}
|
||||
|
||||
private void generate(List<ChatMessage> messages,
|
||||
|
@ -122,18 +123,18 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
|
|||
|
||||
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);
|
||||
|
||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.functions(toFunctions(singletonList(toolThatMustBeExecuted)));
|
||||
requestBuilder.functionCall(toolThatMustBeExecuted.name());
|
||||
if (tokenizer != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
|
||||
}
|
||||
} else if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.functions(toFunctions(toolSpecifications));
|
||||
if (tokenizer != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
|
||||
}
|
||||
}
|
||||
if (toolThatMustBeExecuted != null) {
|
||||
requestBuilder.functionCall(toolThatMustBeExecuted.name());
|
||||
if (tokenizer != null) {
|
||||
inputTokenCount += tokenizer.estimateTokenCountInToolSpecification(toolThatMustBeExecuted);
|
||||
}
|
||||
}
|
||||
|
||||
ChatCompletionRequest request = requestBuilder.build();
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
|||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.assertj.core.data.Percentage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -34,6 +35,8 @@ public class StreamingAiServicesIT {
|
|||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
Percentage tokenizerPrecision = withPercentage(5);
|
||||
|
||||
interface Assistant {
|
||||
|
||||
TokenStream chat(String userMessage);
|
||||
|
@ -166,8 +169,8 @@ public class StreamingAiServicesIT {
|
|||
assertThat(response.content().text()).isEqualTo(answer);
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(181); // TODO should be around 182?
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(27, withPercentage(5)); // TODO
|
||||
assertThat(tokenUsage.inputTokenCount()).isCloseTo(72 + 110, tokenizerPrecision);
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(21 + 28, tokenizerPrecision);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
|
@ -237,8 +240,8 @@ public class StreamingAiServicesIT {
|
|||
assertThat(response.content().text()).isEqualTo(answer);
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(351); // TODO should be around 348?
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(52, withPercentage(5)); // TODO
|
||||
assertThat(tokenUsage.inputTokenCount()).isCloseTo(79 + 117 + 152, tokenizerPrecision);
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(21 + 20 + 53, tokenizerPrecision);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
|
@ -331,8 +334,8 @@ public class StreamingAiServicesIT {
|
|||
assertThat(response.content().text()).isEqualTo(answer);
|
||||
|
||||
TokenUsage tokenUsage = response.tokenUsage();
|
||||
assertThat(tokenUsage.inputTokenCount()).isEqualTo(221); // TODO should be around 239?
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(57, withPercentage(5)); // TODO
|
||||
assertThat(tokenUsage.inputTokenCount()).isCloseTo(79 + 160, tokenizerPrecision);
|
||||
assertThat(tokenUsage.outputTokenCount()).isCloseTo(54 + 58, tokenizerPrecision);
|
||||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
|
|
Loading…
Reference in New Issue