Add a Tokenizer to Azure OpenAI (#1222)
Today the Azure OpenAI module does not have its own Tokenizer. It actually uses the OpenAI one during its test phase (`langchain4j-open-ai` dependency in `test` scope). This PR gives Azure OpenAI its own Tokenizer — not just for test purposes (this PR removes the dependency on `langchain4j-open-ai`), but also so that Azure OpenAI can be used for RAG without needing the `OpenAITokenizer`. For that I had to change all the `ModelName` classes (`AzureOpenAiChatModelName`, `AzureOpenAiEmbeddingModelName`...) to add the model name (e.g. `gpt-35-turbo-0301`), the model type (e.g. `gpt-3.5-turbo`, notice the dot `.` in `3.5`) and the model version. The `AzureOpenAiTokenizer` is a copy/paste of `OpenAiTokenizer`. All `AzureOpenAiTokenizerTest` tests pass. ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [x] There are no breaking changes - [x] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
This commit is contained in:
parent
45a4386ca0
commit
22d0a5fbb5
|
@ -41,9 +41,8 @@
|
|||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<scope>test</scope>
|
||||
<groupId>com.knuddels</groupId>
|
||||
<artifactId>jtokkit</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -6,37 +6,60 @@ package dev.langchain4j.model.azure;
|
|||
*/
|
||||
public enum AzureOpenAiChatModelName {

    GPT_3_5_TURBO("gpt-35-turbo", "gpt-3.5-turbo"), // alias for the latest gpt-3.5-turbo model
    GPT_3_5_TURBO_0301("gpt-35-turbo-0301", "gpt-3.5-turbo", "0301"), // 4k context, functions
    GPT_3_5_TURBO_0613("gpt-35-turbo-0613", "gpt-3.5-turbo", "0613"), // 4k context, functions
    GPT_3_5_TURBO_1106("gpt-35-turbo-1106", "gpt-3.5-turbo", "1106"), // 16k context, functions

    GPT_3_5_TURBO_16K("gpt-35-turbo-16k", "gpt-3.5-turbo-16k"), // alias for the latest gpt-3.5-turbo-16k model
    GPT_3_5_TURBO_16K_0613("gpt-35-turbo-16k-0613", "gpt-3.5-turbo-16k", "0613"), // 16k context, functions

    GPT_4("gpt-4", "gpt-4"), // alias for the latest gpt-4
    GPT_4_0613("gpt-4-0613", "gpt-4", "0613"), // 8k context, functions
    GPT_4_0125_PREVIEW("gpt-4-0125-preview", "gpt-4", "0125-preview"), // 8k context
    GPT_4_1106_PREVIEW("gpt-4-1106-preview", "gpt-4", "1106-preview"), // 8k context

    GPT_4_TURBO("gpt-4-turbo", "gpt-4-turbo"), // alias for the latest gpt-4-turbo model
    GPT_4_TURBO_2024_04_09("gpt-4-turbo-2024-04-09", "gpt-4-turbo", "2024-04-09"),

    GPT_4_32K("gpt-4-32k", "gpt-4-32k"), // alias for the latest gpt-4-32k model
    GPT_4_32K_0613("gpt-4-32k-0613", "gpt-4-32k", "0613"), // 32k context, functions

    GPT_4_VISION_PREVIEW("gpt-4-vision-preview", "gpt-4-vision", "preview"),

    GPT_4_O("gpt-4o", "gpt-4o"); // alias for the latest gpt-4o model

    /** The Azure deployment/model name, e.g. "gpt-35-turbo-0301". */
    private final String modelName;
    /** Model type following the com.knuddels.jtokkit.api.ModelType naming convention, e.g. "gpt-3.5-turbo". */
    private final String modelType;
    /** Model version, e.g. "0301"; {@code null} for alias constants that track the latest version. */
    private final String modelVersion;

    // Alias constants have no version: delegate to the main constructor with a null version.
    AzureOpenAiChatModelName(String modelName, String modelType) {
        this(modelName, modelType, null);
    }

    AzureOpenAiChatModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        this.modelVersion = modelVersion;
    }

    /**
     * Returns the Azure deployment/model name, e.g. "gpt-35-turbo-0301".
     */
    public String modelName() {
        return modelName;
    }

    /**
     * Returns the model type following the com.knuddels.jtokkit.api.ModelType naming convention,
     * e.g. "gpt-3.5-turbo" (notice the dot in "3.5").
     */
    public String modelType() {
        return modelType;
    }

    /**
     * Returns the model version, e.g. "0301", or {@code null} for alias constants.
     */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}
|
||||
|
|
|
@ -2,23 +2,46 @@ package dev.langchain4j.model.azure;
|
|||
|
||||
/**
 * Azure OpenAI embedding model deployments.
 * Each constant carries the Azure deployment/model name, the model type (following the
 * com.knuddels.jtokkit.api.ModelType naming convention) and, when applicable, the model version.
 */
public enum AzureOpenAiEmbeddingModelName {

    TEXT_EMBEDDING_3_SMALL("text-embedding-3-small", "text-embedding-3-small"), // alias for the latest text-embedding-3-small model
    TEXT_EMBEDDING_3_SMALL_1("text-embedding-3-small-1", "text-embedding-3-small", "1"),
    TEXT_EMBEDDING_3_LARGE("text-embedding-3-large", "text-embedding-3-large"),
    TEXT_EMBEDDING_3_LARGE_1("text-embedding-3-large-1", "text-embedding-3-large", "1"),

    TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", "text-embedding-ada-002"), // alias for the latest text-embedding-ada-002 model
    TEXT_EMBEDDING_ADA_002_1("text-embedding-ada-002-1", "text-embedding-ada-002", "1"),
    TEXT_EMBEDDING_ADA_002_2("text-embedding-ada-002-2", "text-embedding-ada-002", "2");

    /** The Azure deployment/model name, e.g. "text-embedding-ada-002-2". */
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention
    private final String modelType;
    /** Model version, e.g. "2"; {@code null} for alias constants that track the latest version. */
    private final String modelVersion;

    // Alias constants have no version: delegate to the main constructor with a null version.
    AzureOpenAiEmbeddingModelName(String modelName, String modelType) {
        this(modelName, modelType, null);
    }

    AzureOpenAiEmbeddingModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        // Bug fix: this previously assigned null, silently discarding the version argument.
        this.modelVersion = modelVersion;
    }

    /**
     * Returns the Azure deployment/model name, e.g. "text-embedding-ada-002-2".
     */
    public String modelName() {
        return modelName;
    }

    /**
     * Returns the model type following the com.knuddels.jtokkit.api.ModelType naming convention.
     */
    public String modelType() {
        return modelType;
    }

    /**
     * Returns the model version, e.g. "2", or {@code null} for alias constants.
     */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}
|
||||
|
|
|
@ -2,17 +2,40 @@ package dev.langchain4j.model.azure;
|
|||
|
||||
/**
 * Azure OpenAI image model deployments.
 * Each constant carries the Azure deployment/model name, the model type (following the
 * com.knuddels.jtokkit.api.ModelType naming convention) and, when applicable, the model version.
 */
public enum AzureOpenAiImageModelName {

    DALL_E_3("dall-e-3", "dall-e-3"), // alias for the latest dall-e-3 model
    DALL_E_3_30("dall-e-3-30", "dall-e-3", "30");

    /** The Azure deployment/model name, e.g. "dall-e-3-30". */
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention
    private final String modelType;
    /** Model version, e.g. "30"; {@code null} for alias constants that track the latest version. */
    private final String modelVersion;

    // Alias constants have no version: delegate to the main constructor with a null version.
    AzureOpenAiImageModelName(String modelName, String modelType) {
        this(modelName, modelType, null);
    }

    AzureOpenAiImageModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        this.modelVersion = modelVersion;
    }

    /**
     * Returns the Azure deployment/model name, e.g. "dall-e-3-30".
     */
    public String modelName() {
        return modelName;
    }

    /**
     * Returns the model type following the com.knuddels.jtokkit.api.ModelType naming convention.
     */
    public String modelType() {
        return modelType;
    }

    /**
     * Returns the model version, e.g. "30", or {@code null} for alias constants.
     */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}
|
||||
|
|
|
@ -2,20 +2,43 @@ package dev.langchain4j.model.azure;
|
|||
|
||||
/**
 * Azure OpenAI (completion) language model deployments.
 * Each constant carries the Azure deployment/model name, the model type (following the
 * com.knuddels.jtokkit.api.ModelType naming convention) and, when applicable, the model version.
 */
public enum AzureOpenAiLanguageModelName {

    // NOTE(review): jtokkit also defines a "gpt-3.5-turbo-instruct" model type; confirm that
    // mapping the instruct deployments to plain "gpt-3.5-turbo" is intended.
    GPT_3_5_TURBO_INSTRUCT("gpt-35-turbo-instruct", "gpt-3.5-turbo"), // alias for the latest gpt-3.5-turbo-instruct model
    GPT_3_5_TURBO_INSTRUCT_0914("gpt-35-turbo-instruct-0914", "gpt-3.5-turbo", "0914"),

    TEXT_DAVINCI_002("davinci-002", "text-davinci-002"),
    TEXT_DAVINCI_002_1("davinci-002-1", "text-davinci-002", "1");

    /** The Azure deployment/model name, e.g. "gpt-35-turbo-instruct-0914". */
    private final String modelName;
    // Model type follows the com.knuddels.jtokkit.api.ModelType naming convention
    private final String modelType;
    /** Model version, e.g. "0914"; {@code null} for alias constants that track the latest version. */
    private final String modelVersion;

    // Alias constants have no version: delegate to the main constructor with a null version.
    AzureOpenAiLanguageModelName(String modelName, String modelType) {
        this(modelName, modelType, null);
    }

    AzureOpenAiLanguageModelName(String modelName, String modelType, String modelVersion) {
        this.modelName = modelName;
        this.modelType = modelType;
        this.modelVersion = modelVersion;
    }

    /**
     * Returns the Azure deployment/model name, e.g. "gpt-35-turbo-instruct-0914".
     */
    public String modelName() {
        return modelName;
    }

    /**
     * Returns the model type following the com.knuddels.jtokkit.api.ModelType naming convention.
     */
    public String modelType() {
        return modelType;
    }

    /**
     * Returns the model version, e.g. "0914", or {@code null} for alias constants.
     */
    public String modelVersion() {
        return modelVersion;
    }

    @Override
    public String toString() {
        return modelName;
    }
}
|
||||
|
|
|
@ -0,0 +1,404 @@
|
|||
package dev.langchain4j.model.azure;
|
||||
|
||||
import com.knuddels.jtokkit.Encodings;
|
||||
import com.knuddels.jtokkit.api.Encoding;
|
||||
import com.knuddels.jtokkit.api.IntArrayList;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolParameters;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.Content;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.TextContent;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.model.Tokenizer;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.Json.fromJson;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_0301;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO_1106;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_1106_PREVIEW;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_0125_PREVIEW;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_TURBO;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_4_VISION_PREVIEW;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
/**
|
||||
* This class can be used to estimate the cost (in tokens) before calling OpenAI or when using streaming.
|
||||
* Magic numbers present in this class were found empirically while testing.
|
||||
* There are integration tests in place that are making sure that the calculations here are very close to that of OpenAI.
|
||||
*/
|
||||
public class AzureOpenAiTokenizer implements Tokenizer {
|
||||
|
||||
private final String modelName;
|
||||
private final Optional<Encoding> encoding;
|
||||
|
||||
/**
|
||||
* Creates an instance of the {@code AzureOpenAiTokenizer} for the "gpt-3.5-turbo" model.
|
||||
* It should be suitable for most OpenAI models, as most of them use the same cl100k_base encoding (except for GPT-4o).
|
||||
*/
|
||||
public AzureOpenAiTokenizer() {
|
||||
this(GPT_3_5_TURBO.modelType());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiChatModelName}.
|
||||
*/
|
||||
public AzureOpenAiTokenizer(AzureOpenAiChatModelName modelName) {
|
||||
this(modelName.modelType());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiEmbeddingModelName}.
|
||||
*/
|
||||
public AzureOpenAiTokenizer(AzureOpenAiEmbeddingModelName modelName) {
|
||||
this(modelName.modelType());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of the {@code AzureOpenAiTokenizer} for a given {@link AzureOpenAiLanguageModelName}.
|
||||
*/
|
||||
public AzureOpenAiTokenizer(AzureOpenAiLanguageModelName modelName) {
|
||||
this(modelName.modelType());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of the {@code AzureOpenAiTokenizer} for a given model name.
|
||||
*/
|
||||
public AzureOpenAiTokenizer(String modelName) {
|
||||
this.modelName = ensureNotBlank(modelName, "modelName");
|
||||
// If the model is unknown, we should NOT fail fast during the creation of AzureOpenAiTokenizer.
|
||||
// Doing so would cause the failure of every OpenAI***Model that uses this tokenizer.
|
||||
// This is done to account for situations when a new OpenAI model is available,
|
||||
// but JTokkit does not yet support it.
|
||||
this.encoding = Encodings.newLazyEncodingRegistry().getEncodingForModel(modelName);
|
||||
}
|
||||
|
||||
public int estimateTokenCountInText(String text) {
|
||||
return encoding.orElseThrow(unknownModelException())
|
||||
.countTokensOrdinary(text);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInMessage(ChatMessage message) {
|
||||
int tokenCount = 1; // 1 token for role
|
||||
tokenCount += extraTokensPerMessage();
|
||||
|
||||
if (message instanceof SystemMessage) {
|
||||
tokenCount += estimateTokenCountIn((SystemMessage) message);
|
||||
} else if (message instanceof UserMessage) {
|
||||
tokenCount += estimateTokenCountIn((UserMessage) message);
|
||||
} else if (message instanceof AiMessage) {
|
||||
tokenCount += estimateTokenCountIn((AiMessage) message);
|
||||
} else if (message instanceof ToolExecutionResultMessage) {
|
||||
tokenCount += estimateTokenCountIn((ToolExecutionResultMessage) message);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unknown message type: " + message);
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
private int estimateTokenCountIn(SystemMessage systemMessage) {
|
||||
return estimateTokenCountInText(systemMessage.text());
|
||||
}
|
||||
|
||||
private int estimateTokenCountIn(UserMessage userMessage) {
|
||||
int tokenCount = 0;
|
||||
|
||||
for (Content content : userMessage.contents()) {
|
||||
if (content instanceof TextContent) {
|
||||
tokenCount += estimateTokenCountInText(((TextContent) content).text());
|
||||
} else if (content instanceof ImageContent) {
|
||||
tokenCount += 85; // TODO implement for HIGH/AUTO detail level
|
||||
} else {
|
||||
throw illegalArgument("Unknown content type: " + content);
|
||||
}
|
||||
}
|
||||
|
||||
if (userMessage.name() != null && !modelName.equals(GPT_4_VISION_PREVIEW.toString())) {
|
||||
tokenCount += extraTokensPerName();
|
||||
tokenCount += estimateTokenCountInText(userMessage.name());
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
private int estimateTokenCountIn(AiMessage aiMessage) {
|
||||
int tokenCount = 0;
|
||||
|
||||
if (aiMessage.text() != null) {
|
||||
tokenCount += estimateTokenCountInText(aiMessage.text());
|
||||
}
|
||||
|
||||
if (aiMessage.toolExecutionRequests() != null) {
|
||||
if (isOneOfLatestModels()) {
|
||||
tokenCount += 6;
|
||||
} else {
|
||||
tokenCount += 3;
|
||||
}
|
||||
if (aiMessage.toolExecutionRequests().size() == 1) {
|
||||
tokenCount -= 1;
|
||||
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name()) * 2;
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
|
||||
} else {
|
||||
tokenCount += 15;
|
||||
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
|
||||
tokenCount += 7;
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
|
||||
|
||||
Map<?, ?> arguments = fromJson(toolExecutionRequest.arguments(), Map.class);
|
||||
for (Map.Entry<?, ?> argument : arguments.entrySet()) {
|
||||
tokenCount += 2;
|
||||
tokenCount += estimateTokenCountInText(argument.getKey().toString());
|
||||
tokenCount += estimateTokenCountInText(argument.getValue().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
private int estimateTokenCountIn(ToolExecutionResultMessage toolExecutionResultMessage) {
|
||||
return estimateTokenCountInText(toolExecutionResultMessage.text());
|
||||
}
|
||||
|
||||
private int extraTokensPerMessage() {
|
||||
if (modelName.equals(GPT_3_5_TURBO_0301.modelName())) {
|
||||
return 4;
|
||||
} else {
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
|
||||
private int extraTokensPerName() {
|
||||
if (modelName.equals(GPT_3_5_TURBO_0301.toString())) {
|
||||
return -1; // if there's a name, the role is omitted
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
|
||||
// see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
|
||||
int tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
|
||||
for (ChatMessage message : messages) {
|
||||
tokenCount += estimateTokenCountInMessage(message);
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
|
||||
int tokenCount = 16;
|
||||
for (ToolSpecification toolSpecification : toolSpecifications) {
|
||||
tokenCount += 6;
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.name());
|
||||
if (toolSpecification.description() != null) {
|
||||
tokenCount += 2;
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.description());
|
||||
}
|
||||
tokenCount += estimateTokenCountInToolParameters(toolSpecification.parameters());
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
private int estimateTokenCountInToolParameters(ToolParameters parameters) {
|
||||
if (parameters == null) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tokenCount = 3;
|
||||
Map<String, Map<String, Object>> properties = parameters.properties();
|
||||
if (isOneOfLatestModels()) {
|
||||
tokenCount += properties.size() - 1;
|
||||
}
|
||||
for (String property : properties.keySet()) {
|
||||
if (isOneOfLatestModels()) {
|
||||
tokenCount += 2;
|
||||
} else {
|
||||
tokenCount += 3;
|
||||
}
|
||||
tokenCount += estimateTokenCountInText(property);
|
||||
for (Map.Entry<String, Object> entry : properties.get(property).entrySet()) {
|
||||
if ("type".equals(entry.getKey())) {
|
||||
if ("array".equals(entry.getValue()) && isOneOfLatestModels()) {
|
||||
tokenCount += 1;
|
||||
}
|
||||
// TODO object
|
||||
} else if ("description".equals(entry.getKey())) {
|
||||
tokenCount += 2;
|
||||
tokenCount += estimateTokenCountInText(entry.getValue().toString());
|
||||
if (isOneOfLatestModels() && parameters.required().contains(property)) {
|
||||
tokenCount += 1;
|
||||
}
|
||||
} else if ("enum".equals(entry.getKey())) {
|
||||
if (isOneOfLatestModels()) {
|
||||
tokenCount -= 2;
|
||||
} else {
|
||||
tokenCount -= 3;
|
||||
}
|
||||
for (Object enumValue : (Object[]) entry.getValue()) {
|
||||
tokenCount += 3;
|
||||
tokenCount += estimateTokenCountInText(enumValue.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
|
||||
int tokenCount = estimateTokenCountInToolSpecifications(singletonList(toolSpecification));
|
||||
tokenCount += 4;
|
||||
tokenCount += estimateTokenCountInText(toolSpecification.name());
|
||||
if (isOneOfLatestModels()) {
|
||||
tokenCount += 3;
|
||||
}
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
public List<Integer> encode(String text) {
|
||||
return encoding.orElseThrow(unknownModelException())
|
||||
.encodeOrdinary(text).boxed();
|
||||
}
|
||||
|
||||
public List<Integer> encode(String text, int maxTokensToEncode) {
|
||||
return encoding.orElseThrow(unknownModelException())
|
||||
.encodeOrdinary(text, maxTokensToEncode).getTokens().boxed();
|
||||
}
|
||||
|
||||
public String decode(List<Integer> tokens) {
|
||||
|
||||
IntArrayList intArrayList = new IntArrayList();
|
||||
for (Integer token : tokens) {
|
||||
intArrayList.add(token);
|
||||
}
|
||||
|
||||
return encoding.orElseThrow(unknownModelException())
|
||||
.decode(intArrayList);
|
||||
}
|
||||
|
||||
private Supplier<IllegalArgumentException> unknownModelException() {
|
||||
return () -> illegalArgument("Model '%s' is unknown to jtokkit", modelName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
|
||||
|
||||
int tokenCount = 0;
|
||||
|
||||
int toolsCount = 0;
|
||||
int toolsWithArgumentsCount = 0;
|
||||
int toolsWithoutArgumentsCount = 0;
|
||||
|
||||
int totalArgumentsCount = 0;
|
||||
|
||||
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
|
||||
tokenCount += 4;
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.name());
|
||||
tokenCount += estimateTokenCountInText(toolExecutionRequest.arguments());
|
||||
|
||||
int argumentCount = countArguments(toolExecutionRequest.arguments());
|
||||
if (argumentCount == 0) {
|
||||
toolsWithoutArgumentsCount++;
|
||||
} else {
|
||||
toolsWithArgumentsCount++;
|
||||
}
|
||||
totalArgumentsCount += argumentCount;
|
||||
|
||||
toolsCount++;
|
||||
}
|
||||
|
||||
if (modelName.equals(GPT_3_5_TURBO_1106.toString()) || isOneOfLatestGpt4Models()) {
|
||||
tokenCount += 16;
|
||||
tokenCount += 3 * toolsWithoutArgumentsCount;
|
||||
tokenCount += toolsCount;
|
||||
if (totalArgumentsCount > 0) {
|
||||
tokenCount -= 1;
|
||||
tokenCount -= 2 * totalArgumentsCount;
|
||||
tokenCount += 2 * toolsWithArgumentsCount;
|
||||
tokenCount += toolsCount;
|
||||
}
|
||||
}
|
||||
|
||||
if (modelName.equals(GPT_4_1106_PREVIEW.toString())) {
|
||||
tokenCount += 3;
|
||||
if (toolsCount > 1) {
|
||||
tokenCount += 18;
|
||||
tokenCount += 15 * toolsCount;
|
||||
tokenCount += totalArgumentsCount;
|
||||
tokenCount -= 3 * toolsWithoutArgumentsCount;
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
|
||||
|
||||
if (isOneOfLatestGpt4Models()) {
|
||||
int argumentsCount = countArguments(toolExecutionRequest.arguments());
|
||||
if (argumentsCount == 0) {
|
||||
return 1;
|
||||
} else {
|
||||
return estimateTokenCountInText(toolExecutionRequest.arguments());
|
||||
}
|
||||
}
|
||||
|
||||
int tokenCount = estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
|
||||
tokenCount -= 4;
|
||||
tokenCount -= estimateTokenCountInText(toolExecutionRequest.name());
|
||||
|
||||
if (modelName.equals(GPT_3_5_TURBO_1106.toString())) {
|
||||
int argumentsCount = countArguments(toolExecutionRequest.arguments());
|
||||
if (argumentsCount == 0) {
|
||||
return 1;
|
||||
}
|
||||
tokenCount -= 19;
|
||||
tokenCount += 2 * argumentsCount;
|
||||
}
|
||||
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
static int countArguments(String arguments) {
|
||||
if (isNullOrBlank(arguments)) {
|
||||
return 0;
|
||||
}
|
||||
Map<?, ?> argumentsMap = fromJson(arguments, Map.class);
|
||||
return argumentsMap.size();
|
||||
}
|
||||
|
||||
private boolean isOneOfLatestModels() {
|
||||
return isOneOfLatestGpt3Models() || isOneOfLatestGpt4Models();
|
||||
}
|
||||
|
||||
private boolean isOneOfLatestGpt3Models() {
|
||||
return modelName.equals(GPT_3_5_TURBO_1106.toString())
|
||||
|| modelName.equals(GPT_3_5_TURBO.toString());
|
||||
}
|
||||
|
||||
private boolean isOneOfLatestGpt4Models() {
|
||||
return modelName.equals(GPT_4_TURBO.toString())
|
||||
|| modelName.equals(GPT_4_1106_PREVIEW.toString())
|
||||
|| modelName.equals(GPT_4_0125_PREVIEW.toString());
|
||||
}
|
||||
}
|
|
@ -8,7 +8,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.language.LanguageModel;
|
||||
import dev.langchain4j.model.language.StreamingLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
|
@ -18,7 +17,7 @@ import org.slf4j.LoggerFactory;
|
|||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import static dev.langchain4j.model.openai.OpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
import static dev.langchain4j.model.output.FinishReason.CONTENT_FILTER;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
@ -41,7 +40,7 @@ public class AzureOpenAIResponsibleAIIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -63,7 +62,7 @@ public class AzureOpenAIResponsibleAIIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -100,7 +99,7 @@ public class AzureOpenAIResponsibleAIIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName("gpt-35-turbo-instruct")
|
||||
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.temperature(0.0)
|
||||
.maxTokens(20)
|
||||
.logRequestsAndResponses(true)
|
||||
|
@ -127,7 +126,7 @@ public class AzureOpenAIResponsibleAIIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -172,7 +171,7 @@ public class AzureOpenAIResponsibleAIIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName("gpt-35-turbo-instruct")
|
||||
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.temperature(0.0)
|
||||
.maxTokens(20)
|
||||
.logRequestsAndResponses(true)
|
||||
|
|
|
@ -9,7 +9,6 @@ import dev.langchain4j.agent.tool.ToolParameters;
|
|||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.assertj.core.data.Percentage;
|
||||
|
@ -46,7 +45,7 @@ public class AzureOpenAiChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -75,7 +74,7 @@ public class AzureOpenAiChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.maxTokens(3)
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
@ -105,7 +104,7 @@ public class AzureOpenAiChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -170,7 +169,7 @@ public class AzureOpenAiChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -205,7 +204,7 @@ public class AzureOpenAiChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -282,7 +281,7 @@ public class AzureOpenAiChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.responseFormat(new ChatCompletionsJsonResponseFormat())
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
|
|
@ -3,7 +3,6 @@ package dev.langchain4j.model.azure;
|
|||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -15,7 +14,7 @@ import org.slf4j.LoggerFactory;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.model.openai.OpenAiEmbeddingModelName.TEXT_EMBEDDING_ADA_002;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiEmbeddingModelName.TEXT_EMBEDDING_ADA_002;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class AzureOpenAiEmbeddingModelIT {
|
||||
|
@ -26,7 +25,7 @@ public class AzureOpenAiEmbeddingModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName("text-embedding-ada-002")
|
||||
.tokenizer(new OpenAiTokenizer(TEXT_EMBEDDING_ADA_002))
|
||||
.tokenizer(new AzureOpenAiTokenizer(TEXT_EMBEDDING_ADA_002))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package dev.langchain4j.model.azure;
|
||||
|
||||
import dev.langchain4j.model.language.LanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -10,7 +9,7 @@ import org.junit.jupiter.params.provider.EnumSource;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import static dev.langchain4j.model.openai.OpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
@ -23,7 +22,7 @@ class AzureOpenAiLanguageModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName("gpt-35-turbo-instruct")
|
||||
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.temperature(0.0)
|
||||
.maxTokens(20)
|
||||
.logRequestsAndResponses(true)
|
||||
|
|
|
@ -12,7 +12,6 @@ import dev.langchain4j.data.message.UserMessage;
|
|||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import org.assertj.core.data.Percentage;
|
||||
|
@ -58,7 +57,7 @@ class AzureOpenAiStreamingChatModelIT {
|
|||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.useAsyncClient(useAsyncClient)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -123,7 +122,7 @@ class AzureOpenAiStreamingChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
model.generate("What is the capital of France?", new StreamingResponseHandler<AiMessage>() {
|
||||
|
@ -201,7 +200,7 @@ class AzureOpenAiStreamingChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
@ -306,7 +305,7 @@ class AzureOpenAiStreamingChatModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName(deploymentName)
|
||||
.tokenizer(new OpenAiTokenizer(gptVersion))
|
||||
.tokenizer(new AzureOpenAiTokenizer(gptVersion))
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ package dev.langchain4j.model.azure;
|
|||
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.language.StreamingLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -10,7 +9,7 @@ import org.slf4j.LoggerFactory;
|
|||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import static dev.langchain4j.model.openai.OpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
import static dev.langchain4j.model.azure.AzureOpenAiLanguageModelName.GPT_3_5_TURBO_INSTRUCT;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
|
@ -24,7 +23,7 @@ class AzureOpenAiStreamingLanguageModelIT {
|
|||
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
|
||||
.deploymentName("gpt-35-turbo-instruct")
|
||||
.tokenizer(new OpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.tokenizer(new AzureOpenAiTokenizer(GPT_3_5_TURBO_INSTRUCT))
|
||||
.temperature(0.0)
|
||||
.maxTokens(20)
|
||||
.logRequestsAndResponses(true)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,145 @@
|
|||
package dev.langchain4j.model.azure;

import dev.langchain4j.model.Tokenizer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import java.util.ArrayList;
import java.util.List;

import static dev.langchain4j.model.azure.AzureOpenAiChatModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.AzureOpenAiTokenizer.countArguments;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Unit tests for {@link AzureOpenAiTokenizer}: encode/decode round-trips,
 * token-count estimation on texts of increasing size, tool-argument counting,
 * and a smoke test that every Azure OpenAI model name enum value is supported.
 */
class AzureOpenAiTokenizerTest {

    // Built from the model TYPE (e.g. "gpt-3.5-turbo"), not the deployment name.
    AzureOpenAiTokenizer tokenizer = new AzureOpenAiTokenizer(GPT_3_5_TURBO.modelType());

    @Test
    void should_encode_and_decode_text() {
        String input = "This is a text which will be encoded and decoded back.";

        List<Integer> tokenIds = tokenizer.encode(input);
        String roundTripped = tokenizer.decode(tokenIds);

        // Encoding followed by decoding must reproduce the original text exactly.
        assertThat(roundTripped).isEqualTo(input);
    }

    @Test
    void should_encode_with_truncation_and_decode_text() {
        String input = "This is a text which will be encoded with truncation and decoded back.";

        // Truncated encoding keeps only the first 10 tokens...
        List<Integer> truncated = tokenizer.encode(input, 10);
        assertThat(truncated).hasSize(10);

        // ...so decoding yields only the matching prefix of the original text.
        String roundTripped = tokenizer.decode(truncated);
        assertThat(roundTripped).isEqualTo("This is a text which will be encoded with trunc");
    }

    @Test
    void should_count_tokens_in_short_texts() {
        assertThat(tokenizer.estimateTokenCountInText("Hello")).isEqualTo(1);
        assertThat(tokenizer.estimateTokenCountInText("Hello!")).isEqualTo(2);
        assertThat(tokenizer.estimateTokenCountInText("Hello, how are you?")).isEqualTo(6);

        // Edge cases: empty string and various whitespace-only inputs.
        assertThat(tokenizer.estimateTokenCountInText("")).isEqualTo(0);
        assertThat(tokenizer.estimateTokenCountInText("\n")).isEqualTo(1);
        assertThat(tokenizer.estimateTokenCountInText("\n\n")).isEqualTo(1);
        assertThat(tokenizer.estimateTokenCountInText("\n \n\n")).isEqualTo(2);
    }

    @Test
    void should_count_tokens_in_average_text() {
        // The base sentence is 15 tokens; counts scale linearly with repetition.
        String single = "Hello, how are you doing? What do you want to talk about?";
        assertThat(tokenizer.estimateTokenCountInText(single)).isEqualTo(15);

        String doubled = String.join(" ", repeat(single, 2));
        assertThat(tokenizer.estimateTokenCountInText(doubled)).isEqualTo(2 * 15);

        String tripled = String.join(" ", repeat(single, 3));
        assertThat(tokenizer.estimateTokenCountInText(tripled)).isEqualTo(3 * 15);
    }

    @Test
    void should_count_tokens_in_large_text() {
        String sentence = "Hello, how are you doing? What do you want to talk about?";

        // Same linearity check as above, on much larger inputs.
        for (int copies : new int[]{10, 50, 100}) {
            String text = String.join(" ", repeat(sentence, copies));
            assertThat(tokenizer.estimateTokenCountInText(text)).isEqualTo(copies * 15);
        }
    }

    @Test
    void should_count_arguments() {
        // Blank or empty JSON objects contain no arguments.
        assertThat(countArguments(null)).isEqualTo(0);
        assertThat(countArguments("")).isEqualTo(0);
        assertThat(countArguments(" ")).isEqualTo(0);
        assertThat(countArguments("{}")).isEqualTo(0);
        assertThat(countArguments("{ }")).isEqualTo(0);

        // One argument, with varying whitespace around the colon.
        assertThat(countArguments("{\"one\":1}")).isEqualTo(1);
        assertThat(countArguments("{\"one\": 1}")).isEqualTo(1);
        assertThat(countArguments("{\"one\" : 1}")).isEqualTo(1);

        // Two arguments.
        assertThat(countArguments("{\"one\":1,\"two\":2}")).isEqualTo(2);
        assertThat(countArguments("{\"one\": 1,\"two\": 2}")).isEqualTo(2);
        assertThat(countArguments("{\"one\" : 1,\"two\" : 2}")).isEqualTo(2);

        // Three arguments.
        assertThat(countArguments("{\"one\":1,\"two\":2,\"three\":3}")).isEqualTo(3);
        assertThat(countArguments("{\"one\": 1,\"two\": 2,\"three\": 3}")).isEqualTo(3);
        assertThat(countArguments("{\"one\" : 1,\"two\" : 2,\"three\" : 3}")).isEqualTo(3);
    }

    @ParameterizedTest
    @EnumSource(AzureOpenAiChatModelName.class)
    void should_support_all_chat_models(AzureOpenAiChatModelName modelName) {

        // given
        Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName);

        // when
        int tokenCount = tokenizer.estimateTokenCountInText("a");

        // then
        assertThat(tokenCount).isEqualTo(1);
    }

    @ParameterizedTest
    @EnumSource(AzureOpenAiEmbeddingModelName.class)
    void should_support_all_embedding_models(AzureOpenAiEmbeddingModelName modelName) {

        // given
        Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName);

        // when
        int tokenCount = tokenizer.estimateTokenCountInText("a");

        // then
        assertThat(tokenCount).isEqualTo(1);
    }

    @ParameterizedTest
    @EnumSource(AzureOpenAiLanguageModelName.class)
    void should_support_all_language_models(AzureOpenAiLanguageModelName modelName) {

        // given
        Tokenizer tokenizer = new AzureOpenAiTokenizer(modelName);

        // when
        int tokenCount = tokenizer.estimateTokenCountInText("a");

        // then
        assertThat(tokenCount).isEqualTo(1);
    }

    /** Returns a list containing {@code n} copies of {@code s}. */
    static List<String> repeat(String s, int n) {
        List<String> copies = new ArrayList<>(n);
        while (copies.size() < n) {
            copies.add(s);
        }
        return copies;
    }
}
|
|
@ -297,3 +297,10 @@ AZURE_OPENAI_ENDPOINT=$(
|
|||
|
||||
echo "AZURE_OPENAI_KEY=$AZURE_OPENAI_KEY"
|
||||
echo "AZURE_OPENAI_ENDPOINT=$AZURE_OPENAI_ENDPOINT"
|
||||
|
||||
# Once you finish the tests, you can delete the resource group with the following command:
|
||||
echo "Deleting the resource group..."
|
||||
echo "------------------------------"
|
||||
az group delete \
|
||||
--name "$RESOURCE_GROUP" \
|
||||
--yes
|
||||
|
|
Loading…
Reference in New Issue