Add a Tokenizer to Azure OpenAI (#1222)

Today the Azure OpenAI module does not have its own Tokenizer. It
actually uses the OpenAI one during its test phase
(`langchain4j-open-ai` dependency in `test` scope).

This PR gives Azure OpenAI its own Tokenizer. Not just for test
purposes (this PR removes the dependency on `langchain4j-open-ai`), but
also so that Azure OpenAI can be used for RAG without needing the
`OpenAiTokenizer`.

For that I had to change all the `ModelName` enums
(`AzureOpenAiChatModelName`, `AzureOpenAiEmbeddingModelName`, etc.) to carry
the model name (e.g. `gpt-35-turbo-0301`), the model type (e.g.
`gpt-3.5-turbo` — notice the dot `.` in `3.5`) and the model version.

The `AzureOpenAiTokenizer` is a copy/paste from `OpenAiTokenizer`.

All `AzureOpenAiTokenizerTest` tests pass.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
This commit is contained in:
Antonio Goncalves 2024-06-06 14:38:22 +02:00 committed by GitHub
parent 45a4386ca0
commit 22d0a5fbb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 2484 additions and 76 deletions

View File

@ -41,9 +41,8 @@
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<scope>test</scope>
<groupId>com.knuddels</groupId>
<artifactId>jtokkit</artifactId>
</dependency>
<dependency>

View File

@ -6,37 +6,60 @@ package dev.langchain4j.model.azure;
*/
public enum AzureOpenAiChatModelName {

    GPT_3_5_TURBO("gpt-35-turbo", "gpt-3.5-turbo"), // alias for the latest gpt-3.5-turbo model
    GPT_3_5_TURBO_0301("gpt-35-turbo-0301", "gpt-3.5-turbo", "0301"), // 4k context, functions
    GPT_3_5_TURBO_0613("gpt-35-turbo-0613", "gpt-3.5-turbo", "0613"), // 4k context, functions
    GPT_3_5_TURBO_1106("gpt-35-turbo-1106", "gpt-3.5-turbo", "1106"), // 16k context, functions

    GPT_3_5_TURBO_16K("gpt-35-turbo-16k", "gpt-3.5-turbo-16k"), // alias for the latest gpt-3.5-turbo-16k model
    GPT_3_5_TURBO_16K_0613("gpt-35-turbo-16k-0613", "gpt-3.5-turbo-16k", "0613"), // 16k context, functions

    GPT_4("gpt-4", "gpt-4"), // alias for the latest gpt-4
    GPT_4_0613("gpt-4-0613", "gpt-4", "0613"), // 8k context, functions
    GPT_4_0125_PREVIEW("gpt-4-0125-preview", "gpt-4", "0125-preview"), // 8k context
    GPT_4_1106_PREVIEW("gpt-4-1106-preview", "gpt-4", "1106-preview"), // 8k context

    GPT_4_TURBO("gpt-4-turbo", "gpt-4-turbo"), // alias for the latest gpt-4-turbo model
    GPT_4_TURBO_2024_04_09("gpt-4-turbo-2024-04-09", "gpt-4-turbo", "2024-04-09"), // alias for the latest gpt-4-turbo model

    GPT_4_32K("gpt-4-32k", "gpt-4-32k"), // alias for the latest gpt-32k model
    GPT_4_32K_0613("gpt-4-32k-0613", "gpt-4-32k", "0613"), // 32k context, functions

    GPT_4_VISION_PREVIEW("gpt-4-vision-preview", "gpt-4-vision", "preview"),

    GPT_4_O("gpt-4o", "gpt-4o"); // alias for the latest gpt-4o model

    // Azure deployment/model name, e.g. "gpt-35-turbo-0301" (no dot in "35").
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention, e.g. "gpt-3.5-turbo".
    private final String modelType;
    // Version suffix of the model, e.g. "0301"; null for un-versioned aliases.
    private final String modelVersion;

    AzureOpenAiChatModelName(String modelName, String modelType) {
        // Un-versioned alias: delegate to the full constructor with a null version.
        this(modelName, modelType, null);
    }

    AzureOpenAiChatModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        this.modelVersion = modelVersion;
    }

    /** Returns the Azure deployment/model name, e.g. {@code gpt-35-turbo-0301}. */
    public String modelName() {
        return modelName;
    }

    /** Returns the jtokkit-compatible model type, e.g. {@code gpt-3.5-turbo}. */
    public String modelType() {
        return modelType;
    }

    /** Returns the model version (e.g. {@code 0301}), or {@code null} for un-versioned aliases. */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}

View File

@ -2,23 +2,46 @@ package dev.langchain4j.model.azure;
/**
 * Azure OpenAI embedding model deployments.
 * <p>
 * Each constant carries the Azure deployment/model name, the jtokkit-compatible
 * model type, and (when applicable) the model version.
 */
public enum AzureOpenAiEmbeddingModelName {

    TEXT_EMBEDDING_3_SMALL("text-embedding-3-small", "text-embedding-3-small"), // alias for the latest text-embedding-3-small model
    TEXT_EMBEDDING_3_SMALL_1("text-embedding-3-small-1", "text-embedding-3-small", "1"),
    TEXT_EMBEDDING_3_LARGE("text-embedding-3-large", "text-embedding-3-large"),
    TEXT_EMBEDDING_3_LARGE_1("text-embedding-3-large-1", "text-embedding-3-large", "1"),

    TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", "text-embedding-ada-002"), // alias for the latest text-embedding-ada-002 model
    TEXT_EMBEDDING_ADA_002_1("text-embedding-ada-002-1", "text-embedding-ada-002", "1"),
    TEXT_EMBEDDING_ADA_002_2("text-embedding-ada-002-2", "text-embedding-ada-002", "2");

    // Azure deployment/model name, e.g. "text-embedding-ada-002-2".
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention
    private final String modelType;
    // Version suffix of the model, e.g. "2"; null for un-versioned aliases.
    private final String modelVersion;

    AzureOpenAiEmbeddingModelName(String modelName, String modelType) {
        // Un-versioned alias: delegate to the full constructor with a null version.
        this(modelName, modelType, null);
    }

    AzureOpenAiEmbeddingModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        // BUG FIX: previously assigned null here, discarding the version passed
        // for every versioned constant (e.g. TEXT_EMBEDDING_ADA_002_2's "2").
        this.modelVersion = modelVersion;
    }

    /** Returns the Azure deployment/model name, e.g. {@code text-embedding-ada-002-2}. */
    public String modelName() {
        return modelName;
    }

    /** Returns the jtokkit-compatible model type, e.g. {@code text-embedding-ada-002}. */
    public String modelType() {
        return modelType;
    }

    /** Returns the model version (e.g. {@code 2}), or {@code null} for un-versioned aliases. */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}

View File

@ -2,17 +2,40 @@ package dev.langchain4j.model.azure;
/**
 * Azure OpenAI image model deployments.
 * <p>
 * Each constant carries the Azure deployment/model name, the jtokkit-compatible
 * model type, and (when applicable) the model version.
 */
public enum AzureOpenAiImageModelName {

    DALL_E_3("dall-e-3", "dall-e-3"), // alias for the latest dall-e-3 model
    DALL_E_3_30("dall-e-3-30", "dall-e-3", "30");

    // Azure deployment/model name, e.g. "dall-e-3-30".
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention
    private final String modelType;
    // Version suffix of the model, e.g. "30"; null for un-versioned aliases.
    private final String modelVersion;

    AzureOpenAiImageModelName(String modelName, String modelType) {
        // Un-versioned alias: delegate to the full constructor with a null version.
        this(modelName, modelType, null);
    }

    AzureOpenAiImageModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        this.modelVersion = modelVersion;
    }

    /** Returns the Azure deployment/model name, e.g. {@code dall-e-3-30}. */
    public String modelName() {
        return modelName;
    }

    /** Returns the jtokkit-compatible model type, e.g. {@code dall-e-3}. */
    public String modelType() {
        return modelType;
    }

    /** Returns the model version (e.g. {@code 30}), or {@code null} for un-versioned aliases. */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}

View File

@ -2,20 +2,43 @@ package dev.langchain4j.model.azure;
/**
 * Azure OpenAI language (completion) model deployments.
 * <p>
 * Each constant carries the Azure deployment/model name, the jtokkit-compatible
 * model type, and (when applicable) the model version.
 */
public enum AzureOpenAiLanguageModelName {

    GPT_3_5_TURBO_INSTRUCT("gpt-35-turbo-instruct", "gpt-3.5-turbo"), // alias for the latest gpt-3.5-turbo-instruct model
    GPT_3_5_TURBO_INSTRUCT_0914("gpt-35-turbo-instruct-0914", "gpt-3.5-turbo", "0914"),

    TEXT_DAVINCI_002("davinci-002", "text-davinci-002"),
    TEXT_DAVINCI_002_1("davinci-002-1", "text-davinci-002", "1"); // removed stray ",;" after last constant

    // Azure deployment/model name, e.g. "gpt-35-turbo-instruct-0914".
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention
    private final String modelType;
    // Version suffix of the model, e.g. "0914"; null for un-versioned aliases.
    private final String modelVersion;

    AzureOpenAiLanguageModelName(String modelName, String modelType) {
        // Un-versioned alias: delegate to the full constructor with a null version.
        this(modelName, modelType, null);
    }

    AzureOpenAiLanguageModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        this.modelVersion = modelVersion;
    }

    /** Returns the Azure deployment/model name, e.g. {@code gpt-35-turbo-instruct-0914}. */
    public String modelName() {
        return modelName;
    }

    /** Returns the jtokkit-compatible model type, e.g. {@code gpt-3.5-turbo}. */
    public String modelType() {
        return modelType;
    }

    /** Returns the model version (e.g. {@code 0914}), or {@code null} for un-versioned aliases. */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}

View File

@ -0,0 +1,404 @@
package dev.langchain4j.model.azure;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.Tokenizer;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Json.fromJson;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_0301;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_1106;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_1106_PREVIEW;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_0125_PREVIEW;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_TURBO;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_VISION_PREVIEW;
import static java.util.Collections.singletonList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
/**
 * This class can be used to estimate the cost (in tokens) before calling OpenAI or when using streaming.
 * Magic numbers present in this class were found empirically while testing.
 * There are integration tests in place that are making sure that the calculations here are very close to that of OpenAI.
 */
public class AzureOpenAiTokenizer implements Tokenizer {

    // Model name this tokenizer was created for.
    // NOTE(review): when constructed from one of the AzureOpenAi*ModelName enums, this holds the
    // jtokkit model type (e.g. "gpt-3.5-turbo"), while the comparisons below check against Azure
    // deployment names (e.g. "gpt-35-turbo-1106"); those branches only match when the caller passed
    // such a deployment name directly — confirm intended behavior.
    private final String modelName;

    // Empty when jtokkit does not (yet) know the model name.
    private final Optional<Encoding> encoding;

    /**
     * Creates an instance of the {@code AzureOpenAiTokenizer} for the "gpt-3.5-turbo" model.
     * It should be suitable for most OpenAI models, as most of them use the same cl100k_base encoding (except for GPT-4o).
     */
    public AzureOpenAiTokenizer() {
        this(GPT_3_5_TURBO.modelType());
    }

    /**
     * Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiChatModelName}.
     */
    public AzureOpenAiTokenizer(AzureOpenAiChatModelName modelName) {
        this(modelName.modelType());
    }

    /**
     * Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiEmbeddingModelName}.
     */
    public AzureOpenAiTokenizer(AzureOpenAiEmbeddingModelName modelName) {
        this(modelName.modelType());
    }

    /**
     * Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiLanguageModelName}.
     */
    public AzureOpenAiTokenizer(AzureOpenAiLanguageModelName modelName) {
        this(modelName.modelType());
    }

    /**
     * Creates an instance of the {@code AzureOpenAiTokenizer} for a given model name.
     *
     * @param modelName the model name, must not be blank
     */
    public AzureOpenAiTokenizer(String modelName) {
        this.modelName = ensureNotBlank(modelName, "modelName");
        // If the model is unknown, we should NOT fail fast during the creation of AzureOpenAiTokenizer.
        // Doing so would cause the failure of every OpenAI***Model that uses this tokenizer.
        // This is done to account for situations when a new OpenAI model is available,
        // but JTokkit does not yet support it.
        this.encoding = Encodings.newLazyEncodingRegistry().getEncodingForModel(modelName);
    }

    @Override // was missing; all sibling Tokenizer methods carry it
    public int estimateTokenCountInText(String text) {
        return encoding.orElseThrow(unknownModelException())
                .countTokensOrdinary(text);
    }

    @Override
    public int estimateTokenCountInMessage(ChatMessage message) {
        int tokenCount = 1; // 1 token for role
        tokenCount += extraTokensPerMessage();

        if (message instanceof SystemMessage) {
            tokenCount += estimateTokenCountIn((SystemMessage) message);
        } else if (message instanceof UserMessage) {
            tokenCount += estimateTokenCountIn((UserMessage) message);
        } else if (message instanceof AiMessage) {
            tokenCount += estimateTokenCountIn((AiMessage) message);
        } else if (message instanceof ToolExecutionResultMessage) {
            tokenCount += estimateTokenCountIn((ToolExecutionResultMessage) message);
        } else {
            throw new IllegalArgumentException("Unknown message type: " + message);
        }

        return tokenCount;
    }

    private int estimateTokenCountIn(SystemMessage systemMessage) {
        return estimateTokenCountInText(systemMessage.text());
    }

    private int estimateTokenCountIn(UserMessage userMessage) {
        int tokenCount = 0;

        for (Content content : userMessage.contents()) {
            if (content instanceof TextContent) {
                tokenCount += estimateTokenCountInText(((TextContent) content).text());
            } else if (content instanceof ImageContent) {
                tokenCount += 85; // TODO implement for HIGH/AUTO detail level
            } else {
                throw illegalArgument("Unknown content type: " + content);
            }
        }

        if (userMessage.name() != null && !modelName.equals(GPT_4_VISION_PREVIEW.modelName())) {
            tokenCount += extraTokensPerName();
            tokenCount += estimateTokenCountInText(userMessage.name());
        }

        return tokenCount;
    }

    private int estimateTokenCountIn(AiMessage aiMessage) {
        int tokenCount = 0;

        if (aiMessage.text() != null) {
            tokenCount += estimateTokenCountInText(aiMessage.text());
        }

        if (aiMessage.toolExecutionRequests() != null) {
            if (isOneOfLatestModels()) {
                tokenCount += 6;
            } else {
                tokenCount += 3;
            }
            if (aiMessage.toolExecutionRequests().size() == 1) {
                tokenCount -= 1;
                ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
                tokenCount += estimateTokenCountInText(toolExecutionRequest.name()) * 2;
                tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
            } else {
                tokenCount += 15;
                for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                    tokenCount += 7;
                    tokenCount += estimateTokenCountInText(toolExecutionRequest.name());

                    Map<?, ?> arguments = fromJson(toolExecutionRequest.arguments(), Map.class);
                    for (Map.Entry<?, ?> argument : arguments.entrySet()) {
                        tokenCount += 2;
                        tokenCount += estimateTokenCountInText(argument.getKey().toString());
                        tokenCount += estimateTokenCountInText(argument.getValue().toString());
                    }
                }
            }
        }

        return tokenCount;
    }

    private int estimateTokenCountIn(ToolExecutionResultMessage toolExecutionResultMessage) {
        return estimateTokenCountInText(toolExecutionResultMessage.text());
    }

    // Per-message overhead; gpt-3.5-turbo-0301 used one extra token per message.
    private int extraTokensPerMessage() {
        if (modelName.equals(GPT_3_5_TURBO_0301.modelName())) {
            return 4;
        } else {
            return 3;
        }
    }

    // Per-name overhead; on gpt-3.5-turbo-0301 the role is omitted when a name is present.
    private int extraTokensPerName() {
        if (modelName.equals(GPT_3_5_TURBO_0301.modelName())) {
            return -1; // if there's a name, the role is omitted
        } else {
            return 1;
        }
    }

    @Override
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        // see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
        int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
        for (ChatMessage message : messages) {
            tokenCount += estimateTokenCountInMessage(message);
        }
        return tokenCount;
    }

    @Override
    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
        int tokenCount = 16;
        for (ToolSpecification toolSpecification : toolSpecifications) {
            tokenCount += 6;
            tokenCount += estimateTokenCountInText(toolSpecification.name());
            if (toolSpecification.description() != null) {
                tokenCount += 2;
                tokenCount += estimateTokenCountInText(toolSpecification.description());
            }
            tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
        }
        return tokenCount;
    }

    // Estimates the token overhead of a tool's JSON-schema parameters; magic numbers found empirically.
    private int estimateTokenCountInToolParameters(ToolParameters parameters) {
        if (parameters == null) {
            return 0;
        }

        int tokenCount = 3;
        Map<String, Map<String, Object>> properties = parameters.properties();
        if (isOneOfLatestModels()) {
            tokenCount += properties.size() - 1;
        }
        for (String property : properties.keySet()) {
            if (isOneOfLatestModels()) {
                tokenCount += 2;
            } else {
                tokenCount += 3;
            }
            tokenCount += estimateTokenCountInText(property);
            for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
                if ("type".equals(entry.getKey())) {
                    if ("array".equals(entry.getValue()) && isOneOfLatestModels()) {
                        tokenCount += 1;
                    }
                    // TODO object
                } else if ("description".equals(entry.getKey())) {
                    tokenCount += 2;
                    tokenCount += estimateTokenCountInText(entry.getValue().toString());
                    if (isOneOfLatestModels() && parameters.required().contains(property)) {
                        tokenCount += 1;
                    }
                } else if ("enum".equals(entry.getKey())) {
                    if (isOneOfLatestModels()) {
                        tokenCount -= 2;
                    } else {
                        tokenCount -= 3;
                    }
                    for (Object enumValue : (Object[]) entry.getValue()) {
                        tokenCount += 3;
                        tokenCount += estimateTokenCountInText(enumValue.toString());
                    }
                }
            }
        }
        return tokenCount;
    }

    @Override
    public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
        int tokenCount = estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
        tokenCount += 4;
        tokenCount += estimateTokenCountInText(toolSpecification.name());
        if (isOneOfLatestModels()) {
            tokenCount += 3;
        }
        return tokenCount;
    }

    /** Encodes the given text into token ids; throws if the model is unknown to jtokkit. */
    public List<Integer> encode(String text) {
        return encoding.orElseThrow(unknownModelException())
                .encodeOrdinary(text).boxed();
    }

    /** Encodes at most {@code maxTokensToEncode} tokens of the given text. */
    public List<Integer> encode(String text, int maxTokensToEncode) {
        return encoding.orElseThrow(unknownModelException())
                .encodeOrdinary(text, maxTokensToEncode).getTokens().boxed();
    }

    /** Decodes the given token ids back into text; throws if the model is unknown to jtokkit. */
    public String decode(List<Integer> tokens) {
        IntArrayList intArrayList = new IntArrayList();
        for (Integer token : tokens) {
            intArrayList.add(token);
        }
        return encoding.orElseThrow(unknownModelException())
                .decode(intArrayList);
    }

    private Supplier<IllegalArgumentException> unknownModelException() {
        return () -> illegalArgument("Model '%s' is unknown to jtokkit", modelName);
    }

    @Override
    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
        int tokenCount = 0;
        int toolsCount = 0;
        int toolsWithArgumentsCount = 0;
        int toolsWithoutArgumentsCount = 0;
        int totalArgumentsCount = 0;

        for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
            tokenCount += 4;
            tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
            tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());

            int argumentCount = countArguments(toolExecutionRequest.arguments());
            if (argumentCount == 0) {
                toolsWithoutArgumentsCount++;
            } else {
                toolsWithArgumentsCount++;
            }
            totalArgumentsCount += argumentCount;
            toolsCount++;
        }

        if (modelName.equals(GPT_3_5_TURBO_1106.modelName()) || isOneOfLatestGpt4Models()) {
            tokenCount += 16;
            tokenCount += 3 * toolsWithoutArgumentsCount;
            tokenCount += toolsCount;
            if (totalArgumentsCount > 0) {
                tokenCount -= 1;
                tokenCount -= 2 * totalArgumentsCount;
                tokenCount += 2 * toolsWithArgumentsCount;
                tokenCount += toolsCount;
            }
        }
        if (modelName.equals(GPT_4_1106_PREVIEW.modelName())) {
            tokenCount += 3;
            if (toolsCount > 1) {
                tokenCount += 18;
                tokenCount += 15 * toolsCount;
                tokenCount += totalArgumentsCount;
                tokenCount -= 3 * toolsWithoutArgumentsCount;
            }
        }
        return tokenCount;
    }

    @Override
    public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
        if (isOneOfLatestGpt4Models()) {
            int argumentsCount = countArguments(toolExecutionRequest.arguments());
            if (argumentsCount == 0) {
                return 1;
            } else {
                return estimateTokenCountInText(toolExecutionRequest.arguments());
            }
        }

        int tokenCount = estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
        tokenCount -= 4;
        tokenCount -= estimateTokenCountInText(toolExecutionRequest.name());
        if (modelName.equals(GPT_3_5_TURBO_1106.modelName())) {
            int argumentsCount = countArguments(toolExecutionRequest.arguments());
            if (argumentsCount == 0) {
                return 1;
            }
            tokenCount -= 19;
            tokenCount += 2 * argumentsCount;
        }
        return tokenCount;
    }

    // Counts the number of top-level keys in a JSON arguments string; 0 for null/blank.
    static int countArguments(String arguments) {
        if (isNullOrBlank(arguments)) {
            return 0;
        }
        Map<?, ?> argumentsMap = fromJson(arguments, Map.class);
        return argumentsMap.size();
    }

    private boolean isOneOfLatestModels() {
        return isOneOfLatestGpt3Models() || isOneOfLatestGpt4Models();
    }

    private boolean isOneOfLatestGpt3Models() {
        return modelName.equals(GPT_3_5_TURBO_1106.modelName())
                || modelName.equals(GPT_3_5_TURBO.modelName());
    }

    private boolean isOneOfLatestGpt4Models() {
        return modelName.equals(GPT_4_TURBO.modelName())
                || modelName.equals(GPT_4_1106_PREVIEW.modelName())
                || modelName.equals(GPT_4_0125_PREVIEW.modelName());
    }
}

View File

@ -8,7 +8,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
@ -18,7 +17,7 @@ import org.slf4j.LoggerFactory;
import java.util.concurrent.CompletableFuture;
import static dev.langchain4j.model.openai.OpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
import static dev.langchain4j.model.output.FinishReason.CONTENT_FILTER;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
@ -41,7 +40,7 @@ public class AzureOpenAIResponsibleAIIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -63,7 +62,7 @@ public class AzureOpenAIResponsibleAIIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -100,7 +99,7 @@ public class AzureOpenAIResponsibleAIIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("gpt-35-turbo-instruct")
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.temperature(0.0)
.maxTokens(20)
.logRequestsAndResponses(true)
@ -127,7 +126,7 @@ public class AzureOpenAIResponsibleAIIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -172,7 +171,7 @@ public class AzureOpenAIResponsibleAIIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("gpt-35-turbo-instruct")
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.temperature(0.0)
.maxTokens(20)
.logRequestsAndResponses(true)

View File

@ -9,7 +9,6 @@ import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
@ -46,7 +45,7 @@ public class AzureOpenAiChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -75,7 +74,7 @@ public class AzureOpenAiChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.maxTokens(3)
.logRequestsAndResponses(true)
.build();
@ -105,7 +104,7 @@ public class AzureOpenAiChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -170,7 +169,7 @@ public class AzureOpenAiChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -205,7 +204,7 @@ public class AzureOpenAiChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -282,7 +281,7 @@ public class AzureOpenAiChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.responseFormat(new ChatCompletionsJsonResponseFormat())
.logRequestsAndResponses(true)
.build();

View File

@ -3,7 +3,6 @@ package dev.langchain4j.model.azure;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
@ -15,7 +14,7 @@ import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import static dev.langchain4j.model.openai.OpenAiEmbeddingModelName.TEXT_EMBEDDING_ADA_002;
import static dev.langchain4j.model.azure.AzureOpenAiEmbeddingModelName.TEXT_EMBEDDING_ADA_002;
import static org.assertj.core.api.Assertions.assertThat;
public class AzureOpenAiEmbeddingModelIT {
@ -26,7 +25,7 @@ public class AzureOpenAiEmbeddingModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("text-embedding-ada-002")
.tokenizer(new OpenAiTokenizer(TEXT_EMBEDDING_ADA_002))
.tokenizer(new AzureOpenAiTokenizer(TEXT_EMBEDDING_ADA_002))
.logRequestsAndResponses(true)
.build();

View File

@ -1,7 +1,6 @@
package dev.langchain4j.model.azure;
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
@ -10,7 +9,7 @@ import org.junit.jupiter.params.provider.EnumSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static dev.langchain4j.model.openai.OpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static org.assertj.core.api.Assertions.assertThat;
@ -23,7 +22,7 @@ class AzureOpenAiLanguageModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("gpt-35-turbo-instruct")
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.temperature(0.0)
.maxTokens(20)
.logRequestsAndResponses(true)

View File

@ -12,7 +12,6 @@ import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.assertj.core.data.Percentage;
@ -58,7 +57,7 @@ class AzureOpenAiStreamingChatModelIT {
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.useAsyncClient(useAsyncClient)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -123,7 +122,7 @@ class AzureOpenAiStreamingChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
model.generate("What is the capital of France?", new StreamingResponseHandler<AiMessage>() {
@ -201,7 +200,7 @@ class AzureOpenAiStreamingChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();
@ -306,7 +305,7 @@ class AzureOpenAiStreamingChatModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(deploymentName)
.tokenizer(new OpenAiTokenizer(gptVersion))
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
.logRequestsAndResponses(true)
.build();

View File

@ -2,7 +2,6 @@ package dev.langchain4j.model.azure;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
@ -10,7 +9,7 @@ import org.slf4j.LoggerFactory;
import java.util.concurrent.CompletableFuture;
import static dev.langchain4j.model.openai.OpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.concurrent.TimeUnit.SECONDS;
@ -24,7 +23,7 @@ class AzureOpenAiStreamingLanguageModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("gpt-35-turbo-instruct")
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
.temperature(0.0)
.maxTokens(20)
.logRequestsAndResponses(true)

View File

@ -0,0 +1,145 @@
package dev.langchain4j.model.azure;
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.AzureOpenAiTokenizer.countArguments;
import static org.assertj.core.api.Assertions.assertThat;

import dev.langchain4j.model.Tokenizer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
class AzureOpenAiTokenizerTest {
AzureOpenAiTokenizer tokenizer = new AzureOpenAiTokenizer(GPT_3_5_TURBO.modelType());
@Test
void should_encode_and_decode_text() {
String originalText = "This is a text which will be encoded and decoded back.";
List<Integer> tokens = tokenizer.encode(originalText);
String decodedText = tokenizer.decode(tokens);
assertThat(decodedText).isEqualTo(originalText);
}
@Test
void should_encode_with_truncation_and_decode_text() {
String originalText = "This is a text which will be encoded with truncation and decoded back.";
List<Integer> tokens = tokenizer.encode(originalText, 10);
assertThat(tokens).hasSize(10);
String decodedText = tokenizer.decode(tokens);
assertThat(decodedText).isEqualTo("This is a text which will be encoded with trunc");
}
@Test
void should_count_tokens_in_short_texts() {
assertThat(tokenizer.estimateTokenCountInText("Hello")).isEqualTo(1);
assertThat(tokenizer.estimateTokenCountInText("Hello!")).isEqualTo(2);
assertThat(tokenizer.estimateTokenCountInText("Hello, how are you?")).isEqualTo(6);
assertThat(tokenizer.estimateTokenCountInText("")).isEqualTo(0);
assertThat(tokenizer.estimateTokenCountInText("\n")).isEqualTo(1);
assertThat(tokenizer.estimateTokenCountInText("\n\n")).isEqualTo(1);
assertThat(tokenizer.estimateTokenCountInText("\n \n\n")).isEqualTo(2);
}
@Test
void should_count_tokens_in_average_text() {
String text1 = "Hello, how are you doing? What do you want to talk about?";
assertThat(tokenizer.estimateTokenCountInText(text1)).isEqualTo(15);
String text2 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 2));
assertThat(tokenizer.estimateTokenCountInText(text2)).isEqualTo(2 * 15);
String text3 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 3));
assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(3 * 15);
}
@Test
void should_count_tokens_in_large_text() {
String text1 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 10));
assertThat(tokenizer.estimateTokenCountInText(text1)).isEqualTo(10 * 15);
String text2 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 50));
assertThat(tokenizer.estimateTokenCountInText(text2)).isEqualTo(50 * 15);
String text3 = String.join(" ", repeat("Hello, how are you doing? What do you want to talk about?", 100));
assertThat(tokenizer.estimateTokenCountInText(text3)).isEqualTo(100 * 15);
}
@Test
void should_count_arguments() {
assertThat(countArguments(null)).isEqualTo(0);
assertThat(countArguments("")).isEqualTo(0);
assertThat(countArguments(" ")).isEqualTo(0);
assertThat(countArguments("{}")).isEqualTo(0);
assertThat(countArguments("{ }")).isEqualTo(0);
assertThat(countArguments("{\"one\":1}")).isEqualTo(1);
assertThat(countArguments("{\"one\": 1}")).isEqualTo(1);
assertThat(countArguments("{\"one\" : 1}")).isEqualTo(1);
assertThat(countArguments("{\"one\":1,\"two\":2}")).isEqualTo(2);
assertThat(countArguments("{\"one\": 1,\"two\": 2}")).isEqualTo(2);
assertThat(countArguments("{\"one\" : 1,\"two\" : 2}")).isEqualTo(2);
assertThat(countArguments("{\"one\":1,\"two\":2,\"three\":3}")).isEqualTo(3);
assertThat(countArguments("{\"one\": 1,\"two\": 2,\"three\": 3}")).isEqualTo(3);
assertThat(countArguments("{\"one\" : 1,\"two\" : 2,\"three\" : 3}")).isEqualTo(3);
}
@ParameterizedTest
@EnumSource(AzureOpenAiChatModelName.class)
void should_support_all_chat_models(AzureOpenAiChatModelName modelName) {
// given
Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName);
// when
int tokenCount = tokenizer.estimateTokenCountInText("a");
// then
assertThat(tokenCount).isEqualTo(1);
}
@ParameterizedTest
@EnumSource(AzureOpenAiEmbeddingModelName.class)
void should_support_all_embedding_models(AzureOpenAiEmbeddingModelName modelName) {
// given
Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName);
// when
int tokenCount = tokenizer.estimateTokenCountInText("a");
// then
assertThat(tokenCount).isEqualTo(1);
}
@ParameterizedTest
@EnumSource(AzureOpenAiLanguageModelName.class)
void should_support_all_language_models(AzureOpenAiLanguageModelName modelName) {
// given
Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName);
// when
int tokenCount = tokenizer.estimateTokenCountInText("a");
// then
assertThat(tokenCount).isEqualTo(1);
}
static List<String> repeat(String strings, int n) {
List<String> result = new ArrayList<>();
for (int i = 0; i < n; i++) {
result.add(strings);
}
return result;
}
}

View File

@ -297,3 +297,10 @@ AZURE_OPENAI_ENDPOINT=$(
echo "AZURE_OPENAI_KEY=$AZURE_OPENAI_KEY"
echo "AZURE_OPENAI_ENDPOINT=$AZURE_OPENAI_ENDPOINT"
# Once you finish the tests, you can delete the resource group with the following command:
echo "Deleting the resource group..."
echo "------------------------------"
az group delete \
--name "$RESOURCE_GROUP" \
--yes