OpenAI: added an option to set up a custom Tokenizer, increased default timeouts to 60 seconds

deep-learning-dynamo 2023-10-13 13:55:55 +02:00
parent 50635b1499
commit d90715fab5
8 changed files with 50 additions and 67 deletions
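For orientation, here is a minimal usage sketch of the two changes: a custom Tokenizer can now be supplied, and the default request timeout is a flat 60 seconds for every model. The builder methods are assumed to be the Lombok-generated counterparts of the constructor parameters shown in the diffs below (apiKey, timeout, and the new tokenizer); treat them as assumptions, not confirmed API.

import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;

import java.time.Duration;

public class CustomTokenizerExample {

    public static void main(String[] args) {
        // A Tokenizer can now be supplied explicitly; when omitted, the model
        // falls back to new OpenAiTokenizer(modelName), as before.
        Tokenizer tokenizer = new OpenAiTokenizer("gpt-3.5-turbo");

        OpenAiChatModel model = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .tokenizer(tokenizer)            // new in this commit (assumed builder method)
                .timeout(Duration.ofSeconds(30)) // optional; now defaults to 60 seconds
                .build();
    }
}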

InternalOpenAiHelper.java

@@ -9,16 +9,12 @@ import dev.langchain4j.data.message.*;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import java.time.Duration;
import java.util.Collection;
import java.util.List;
import static dev.ai4j.openai4j.chat.Role.*;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_4;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;
public class InternalOpenAiHelper {
@@ -28,16 +24,6 @@ public class InternalOpenAiHelper {
static final String OPENAI_DEMO_API_KEY = "demo";
static final String OPENAI_DEMO_URL = "http://langchain4j.dev/demo/openai/v1";
static Duration defaultTimeoutFor(String modelName) {
if (modelName.startsWith(GPT_3_5_TURBO)) {
return ofSeconds(7);
} else if (modelName.startsWith(GPT_4)) {
return ofSeconds(20);
}
return ofSeconds(10);
}
public static List<Message> toOpenAiMessages(List<ChatMessage> messages) {
return messages.stream()

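Note that the per-model defaultTimeoutFor helper above is deleted; all models now fall back to the same 60-second timeout through Utils.getOrDefault. A sketch of that helper's assumed semantics (plain null-coalescing, not the confirmed implementation):

// A sketch of the assumed semantics of dev.langchain4j.internal.Utils.getOrDefault:
// the supplied value wins unless it is null, otherwise the default is used.
final class UtilsSketch {

    static <T> T getOrDefault(T value, T defaultValue) {
        return value != null ? value : defaultValue;
    }
}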
OpenAiChatModel.java

@@ -20,10 +20,12 @@ import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
* You can find a description of the parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -52,14 +54,15 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
Tokenizer tokenizer) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
baseUrl = OPENAI_DEMO_URL;
}
modelName = getOrDefault(modelName, GPT_3_5_TURBO);
timeout = getOrDefault(timeout, defaultTimeoutFor(modelName));
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
@@ -72,7 +75,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
.logRequests(logRequests)
.logResponses(logResponses)
.build();
this.modelName = modelName;
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.stop = stop;
@@ -80,7 +83,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = new OpenAiTokenizer(this.modelName);
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
}
@Override

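The OPENAI_DEMO_API_KEY branch above reroutes the "demo" key to the langchain4j demo endpoint, which allows a quick smoke test without a real key. A hedged sketch; generate(String) is assumed from the ChatLanguageModel interface:

import dev.langchain4j.model.openai.OpenAiChatModel;

public class DemoKeyExample {

    public static void main(String[] args) {
        // "demo" is rerouted to OPENAI_DEMO_URL by the constructor logic above.
        OpenAiChatModel model = OpenAiChatModel.builder()
                .apiKey("demo")
                .build();

        System.out.println(model.generate("Say hello"));
    }
}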
OpenAiEmbeddingModel.java

@@ -16,10 +16,8 @@ import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;
@@ -42,15 +40,15 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
Tokenizer tokenizer) {
baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
baseUrl = OPENAI_DEMO_URL;
}
modelName = modelName == null ? TEXT_EMBEDDING_ADA_002 : modelName;
timeout = timeout == null ? ofSeconds(15) : timeout;
maxRetries = maxRetries == null ? 3 : maxRetries;
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
@@ -63,9 +61,9 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
.logRequests(logRequests)
.logResponses(logResponses)
.build();
this.modelName = modelName;
this.maxRetries = maxRetries;
this.tokenizer = new OpenAiTokenizer(this.modelName);
this.modelName = getOrDefault(modelName, TEXT_EMBEDDING_ADA_002);
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
}
@Override

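As the removed line above shows, OpenAiEmbeddingModel previously defaulted to a 15-second timeout; it now shares the 60-second default. Callers who want the old behavior can set it explicitly, assuming a Lombok-generated timeout builder method:

import dev.langchain4j.model.openai.OpenAiEmbeddingModel;

import java.time.Duration;

public class EmbeddingTimeoutExample {

    public static void main(String[] args) {
        OpenAiEmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .timeout(Duration.ofSeconds(15)) // restore the pre-commit default
                .build();
    }
}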
OpenAiLanguageModel.java

@@ -14,9 +14,8 @@ import java.net.Proxy;
import java.time.Duration;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_DAVINCI_003;
import static java.time.Duration.ofSeconds;
@@ -42,16 +41,13 @@ public class OpenAiLanguageModel implements LanguageModel, TokenCountEstimator {
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
Tokenizer tokenizer) {
baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
modelName = modelName == null ? TEXT_DAVINCI_003 : modelName;
temperature = temperature == null ? 0.7 : temperature;
timeout = timeout == null ? ofSeconds(15) : timeout;
maxRetries = maxRetries == null ? 3 : maxRetries;
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.baseUrl(baseUrl)
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.callTimeout(timeout)
.connectTimeout(timeout)
@@ -61,10 +57,10 @@ public class OpenAiLanguageModel implements LanguageModel, TokenCountEstimator {
.logRequests(logRequests)
.logResponses(logResponses)
.build();
this.modelName = modelName;
this.temperature = temperature;
this.maxRetries = maxRetries;
this.tokenizer = new OpenAiTokenizer(this.modelName);
this.modelName = getOrDefault(modelName, TEXT_DAVINCI_003);
this.temperature = getOrDefault(temperature, 0.7);
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
}
@Override

OpenAiModerationModel.java

@@ -17,9 +17,8 @@ import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_MODERATION_LATEST;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.singletonList;
@@ -44,13 +43,12 @@ public class OpenAiModerationModel implements ModerationModel {
Boolean logRequests,
Boolean logResponses) {
baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
baseUrl = OPENAI_DEMO_URL;
}
modelName = modelName == null ? TEXT_MODERATION_LATEST : modelName;
timeout = timeout == null ? ofSeconds(15) : timeout;
maxRetries = maxRetries == null ? 3 : maxRetries;
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
@@ -63,8 +61,8 @@ public class OpenAiModerationModel implements ModerationModel {
.logRequests(logRequests)
.logResponses(logResponses)
.build();
this.modelName = modelName;
this.maxRetries = maxRetries;
this.modelName = getOrDefault(modelName, TEXT_MODERATION_LATEST);
this.maxRetries = getOrDefault(maxRetries, 3);
}
@Override

OpenAiStreamingChatModel.java

@@ -28,6 +28,7 @@ import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
* The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
* You can find a description of the parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
@@ -54,9 +55,10 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
Duration timeout,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
Tokenizer tokenizer) {
timeout = getOrDefault(timeout, ofSeconds(5));
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
@@ -76,7 +78,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
this.maxTokens = maxTokens;
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.tokenizer = new OpenAiTokenizer(this.modelName);
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
}
@Override

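The Javadoc above points to StreamingResponseHandler for consuming the token stream. A sketch of that usage, assuming the generate(String, StreamingResponseHandler) signature from the StreamingChatLanguageModel interface:

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;

public class StreamingExample {

    public static void main(String[] args) {
        OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .build();

        model.generate("Tell me a joke", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // tokens arrive one by one
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\nToken usage: " + response.tokenUsage());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}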
OpenAiStreamingLanguageModel.java

@@ -13,6 +13,7 @@ import lombok.Builder;
import java.net.Proxy;
import java.time.Duration;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_DAVINCI_003;
import static java.time.Duration.ofSeconds;
@@ -38,15 +39,13 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
Duration timeout,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
Tokenizer tokenizer) {
baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
modelName = modelName == null ? TEXT_DAVINCI_003 : modelName;
temperature = temperature == null ? 0.7 : temperature;
timeout = timeout == null ? ofSeconds(60) : timeout;
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.baseUrl(baseUrl)
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.callTimeout(timeout)
.connectTimeout(timeout)
@@ -56,9 +55,9 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
.logRequests(logRequests)
.logStreamingResponses(logResponses)
.build();
this.modelName = modelName;
this.temperature = temperature;
this.tokenizer = new OpenAiTokenizer(this.modelName);
this.modelName = getOrDefault(modelName, TEXT_DAVINCI_003);
this.temperature = getOrDefault(temperature, 0.7);
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
}
@Override

OpenAiTokenizer.java

@@ -16,6 +16,7 @@ import java.util.Optional;
import java.util.function.Supplier;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.roleFrom;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_0301;
@@ -25,7 +26,7 @@ public class OpenAiTokenizer implements Tokenizer {
private final Optional<Encoding> encoding;
public OpenAiTokenizer(String modelName) {
this.modelName = modelName;
this.modelName = ensureNotBlank(modelName, "modelName");
// If the model is unknown, we should NOT fail fast during the creation of OpenAiTokenizer.
// Doing so would cause the failure of every OpenAI***Model that uses this tokenizer.
// This is done to account for situations when a new OpenAI model is available,
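The new ensureNotBlank call means a null or blank model name now fails at construction time rather than surfacing later, while unknown but non-blank names are still tolerated, per the comment above. A small sketch, assuming ensureNotBlank throws IllegalArgumentException as its name suggests:

import dev.langchain4j.model.openai.OpenAiTokenizer;

public class TokenizerValidationExample {

    public static void main(String[] args) {
        try {
            new OpenAiTokenizer(" "); // blank: now rejected immediately
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }

        // A valid, known model name still works as before.
        new OpenAiTokenizer("gpt-3.5-turbo");
    }
}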