OpenAI: added an option to set up a custom Tokenizer, increased default timeouts to 60 seconds
parent 50635b1499
commit d90715fab5
InternalOpenAiHelper.java

@@ -9,16 +9,12 @@ import dev.langchain4j.data.message.*;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.TokenUsage;
 
-import java.time.Duration;
 import java.util.Collection;
 import java.util.List;
 
 import static dev.ai4j.openai4j.chat.Role.*;
 import static dev.langchain4j.data.message.AiMessage.aiMessage;
-import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
-import static dev.langchain4j.model.openai.OpenAiModelName.GPT_4;
 import static dev.langchain4j.model.output.FinishReason.*;
-import static java.time.Duration.ofSeconds;
 import static java.util.stream.Collectors.toList;
 
 public class InternalOpenAiHelper {
@@ -28,16 +24,6 @@ public class InternalOpenAiHelper {
     static final String OPENAI_DEMO_API_KEY = "demo";
     static final String OPENAI_DEMO_URL = "http://langchain4j.dev/demo/openai/v1";
 
-    static Duration defaultTimeoutFor(String modelName) {
-        if (modelName.startsWith(GPT_3_5_TURBO)) {
-            return ofSeconds(7);
-        } else if (modelName.startsWith(GPT_4)) {
-            return ofSeconds(20);
-        }
-
-        return ofSeconds(10);
-    }
-
     public static List<Message> toOpenAiMessages(List<ChatMessage> messages) {
 
         return messages.stream()
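The per-model timeout heuristic removed above (7 s for gpt-3.5-turbo, 20 s for gpt-4, 10 s otherwise) is gone: every model below now defaults to a flat 60 seconds. Callers that prefer a tighter deadline can still pass one explicitly. A minimal sketch, assuming the Lombok-generated builder exposes the constructor's timeout parameter:

    OpenAiChatModel model = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .timeout(java.time.Duration.ofSeconds(20)) // opt back into a stricter deadline
            .build();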
OpenAiChatModel.java

@@ -20,10 +20,12 @@ import static dev.langchain4j.internal.RetryUtils.withRetry;
 import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
+import static java.time.Duration.ofSeconds;
 import static java.util.Collections.singletonList;
 
 /**
  * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
  * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
 public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
 
@@ -52,14 +54,15 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                            Integer maxRetries,
                            Proxy proxy,
                            Boolean logRequests,
-                           Boolean logResponses) {
+                           Boolean logResponses,
+                           Tokenizer tokenizer) {
 
         baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
             baseUrl = OPENAI_DEMO_URL;
         }
-        modelName = getOrDefault(modelName, GPT_3_5_TURBO);
-        timeout = getOrDefault(timeout, defaultTimeoutFor(modelName));
+
+        timeout = getOrDefault(timeout, ofSeconds(60));
 
         this.client = OpenAiClient.builder()
                 .openAiApiKey(apiKey)
@@ -72,7 +75,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
                 .logRequests(logRequests)
                 .logResponses(logResponses)
                 .build();
-        this.modelName = modelName;
+        this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
         this.temperature = getOrDefault(temperature, 0.7);
         this.topP = topP;
         this.stop = stop;
@@ -80,7 +83,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
         this.presencePenalty = presencePenalty;
         this.frequencyPenalty = frequencyPenalty;
         this.maxRetries = getOrDefault(maxRetries, 3);
-        this.tokenizer = new OpenAiTokenizer(this.modelName);
+        this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
     }
 
     @Override
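Because these constructors are annotated with Lombok's @Builder (see the lombok.Builder import in the OpenAiStreamingLanguageModel hunk further down), the new tokenizer parameter surfaces as a builder method. A usage sketch — MyTokenizer is a hypothetical Tokenizer implementation, and the builder method names are assumed from the Lombok-generated API:

    OpenAiChatModel model = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .tokenizer(new MyTokenizer()) // when omitted, falls back to new OpenAiTokenizer(modelName)
            .build();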
OpenAiEmbeddingModel.java

@@ -16,10 +16,8 @@ import java.time.Duration;
 import java.util.List;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
 import static java.time.Duration.ofSeconds;
 import static java.util.stream.Collectors.toList;
@@ -42,15 +40,15 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
                                Integer maxRetries,
                                Proxy proxy,
                                Boolean logRequests,
-                               Boolean logResponses) {
+                               Boolean logResponses,
+                               Tokenizer tokenizer) {
 
-        baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
+        baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
             baseUrl = OPENAI_DEMO_URL;
         }
-        modelName = modelName == null ? TEXT_EMBEDDING_ADA_002 : modelName;
-        timeout = timeout == null ? ofSeconds(15) : timeout;
-        maxRetries = maxRetries == null ? 3 : maxRetries;
+
+        timeout = getOrDefault(timeout, ofSeconds(60));
 
         this.client = OpenAiClient.builder()
                 .openAiApiKey(apiKey)
@@ -63,9 +61,9 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
                 .logRequests(logRequests)
                 .logResponses(logResponses)
                 .build();
-        this.modelName = modelName;
-        this.maxRetries = maxRetries;
-        this.tokenizer = new OpenAiTokenizer(this.modelName);
+        this.modelName = getOrDefault(modelName, TEXT_EMBEDDING_ADA_002);
+        this.maxRetries = getOrDefault(maxRetries, 3);
+        this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
     }
 
     @Override
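Note the pattern change running through all of these constructors: the `x == null ? fallback : x` ternaries give way to Utils.getOrDefault. The helper itself is not part of this diff, but its presumed shape is:

    static <T> T getOrDefault(T value, T defaultValue) {
        return value != null ? value : defaultValue;
    }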
OpenAiLanguageModel.java

@@ -14,9 +14,8 @@ import java.net.Proxy;
 import java.time.Duration;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_DAVINCI_003;
 import static java.time.Duration.ofSeconds;
 
@@ -42,16 +41,13 @@ public class OpenAiLanguageModel implements LanguageModel, TokenCountEstimator {
                              Integer maxRetries,
                              Proxy proxy,
                              Boolean logRequests,
-                             Boolean logResponses) {
+                             Boolean logResponses,
+                             Tokenizer tokenizer) {
 
-        baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
-        modelName = modelName == null ? TEXT_DAVINCI_003 : modelName;
-        temperature = temperature == null ? 0.7 : temperature;
-        timeout = timeout == null ? ofSeconds(15) : timeout;
-        maxRetries = maxRetries == null ? 3 : maxRetries;
+        timeout = getOrDefault(timeout, ofSeconds(60));
 
         this.client = OpenAiClient.builder()
-                .baseUrl(baseUrl)
+                .baseUrl(getOrDefault(baseUrl, OPENAI_URL))
                 .openAiApiKey(apiKey)
                 .callTimeout(timeout)
                 .connectTimeout(timeout)
@@ -61,10 +57,10 @@ public class OpenAiLanguageModel implements LanguageModel, TokenCountEstimator {
                 .logRequests(logRequests)
                 .logResponses(logResponses)
                 .build();
-        this.modelName = modelName;
-        this.temperature = temperature;
-        this.maxRetries = maxRetries;
-        this.tokenizer = new OpenAiTokenizer(this.modelName);
+        this.modelName = getOrDefault(modelName, TEXT_DAVINCI_003);
+        this.temperature = getOrDefault(temperature, 0.7);
+        this.maxRetries = getOrDefault(maxRetries, 3);
+        this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
     }
 
     @Override
OpenAiModerationModel.java

@@ -17,9 +17,8 @@ import java.time.Duration;
 import java.util.List;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
-import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
 import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_MODERATION_LATEST;
 import static java.time.Duration.ofSeconds;
 import static java.util.Collections.singletonList;
@@ -44,13 +43,12 @@ public class OpenAiModerationModel implements ModerationModel {
                                Boolean logRequests,
                                Boolean logResponses) {
 
-        baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
+        baseUrl = getOrDefault(baseUrl, OPENAI_URL);
         if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
             baseUrl = OPENAI_DEMO_URL;
         }
-        modelName = modelName == null ? TEXT_MODERATION_LATEST : modelName;
-        timeout = timeout == null ? ofSeconds(15) : timeout;
-        maxRetries = maxRetries == null ? 3 : maxRetries;
+
+        timeout = getOrDefault(timeout, ofSeconds(60));
 
         this.client = OpenAiClient.builder()
                 .openAiApiKey(apiKey)
@@ -63,8 +61,8 @@ public class OpenAiModerationModel implements ModerationModel {
                 .logRequests(logRequests)
                 .logResponses(logResponses)
                 .build();
-        this.modelName = modelName;
-        this.maxRetries = maxRetries;
+        this.modelName = getOrDefault(modelName, TEXT_MODERATION_LATEST);
+        this.maxRetries = getOrDefault(maxRetries, 3);
     }
 
     @Override
OpenAiStreamingChatModel.java

@@ -28,6 +28,7 @@ import static java.util.Collections.singletonList;
 /**
  * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
+ * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
  * You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
  */
 public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
 
@@ -54,9 +55,10 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
                                  Duration timeout,
                                  Proxy proxy,
                                  Boolean logRequests,
-                                 Boolean logResponses) {
+                                 Boolean logResponses,
+                                 Tokenizer tokenizer) {
 
-        timeout = getOrDefault(timeout, ofSeconds(5));
+        timeout = getOrDefault(timeout, ofSeconds(60));
 
         this.client = OpenAiClient.builder()
                 .baseUrl(getOrDefault(baseUrl, OPENAI_URL))
@@ -76,7 +78,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
         this.maxTokens = maxTokens;
         this.presencePenalty = presencePenalty;
         this.frequencyPenalty = frequencyPenalty;
-        this.tokenizer = new OpenAiTokenizer(this.modelName);
+        this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
     }
 
     @Override
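Per the javadoc line added above, streamed tokens are consumed through a StreamingResponseHandler that is passed into the model's streaming call. An illustrative sketch only — the callback names here are assumptions, not shown in this diff:

    StreamingResponseHandler handler = new StreamingResponseHandler() {

        @Override
        public void onNext(String token) {
            System.out.print(token); // assumed: invoked once per streamed token
        }

        @Override
        public void onError(Throwable error) {
            error.printStackTrace();
        }
    };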
OpenAiStreamingLanguageModel.java

@@ -13,6 +13,7 @@ import lombok.Builder;
 import java.net.Proxy;
 import java.time.Duration;
 
+import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_URL;
 import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_DAVINCI_003;
 import static java.time.Duration.ofSeconds;
@@ -38,15 +39,13 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
                                      Duration timeout,
                                      Proxy proxy,
                                      Boolean logRequests,
-                                     Boolean logResponses) {
+                                     Boolean logResponses,
+                                     Tokenizer tokenizer) {
 
-        baseUrl = baseUrl == null ? OPENAI_URL : baseUrl;
-        modelName = modelName == null ? TEXT_DAVINCI_003 : modelName;
-        temperature = temperature == null ? 0.7 : temperature;
-        timeout = timeout == null ? ofSeconds(60) : timeout;
+        timeout = getOrDefault(timeout, ofSeconds(60));
 
         this.client = OpenAiClient.builder()
-                .baseUrl(baseUrl)
+                .baseUrl(getOrDefault(baseUrl, OPENAI_URL))
                 .openAiApiKey(apiKey)
                 .callTimeout(timeout)
                 .connectTimeout(timeout)
@@ -56,9 +55,9 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
                 .logRequests(logRequests)
                 .logStreamingResponses(logResponses)
                 .build();
-        this.modelName = modelName;
-        this.temperature = temperature;
-        this.tokenizer = new OpenAiTokenizer(this.modelName);
+        this.modelName = getOrDefault(modelName, TEXT_DAVINCI_003);
+        this.temperature = getOrDefault(temperature, 0.7);
+        this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(this.modelName));
     }
 
     @Override
OpenAiTokenizer.java

@@ -16,6 +16,7 @@ import java.util.Optional;
 import java.util.function.Supplier;
 
 import static dev.langchain4j.internal.Exceptions.illegalArgument;
+import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
 import static dev.langchain4j.model.openai.InternalOpenAiHelper.roleFrom;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO_0301;
 
@@ -25,7 +26,7 @@ public class OpenAiTokenizer implements Tokenizer {
     private final Optional<Encoding> encoding;
 
     public OpenAiTokenizer(String modelName) {
-        this.modelName = modelName;
+        this.modelName = ensureNotBlank(modelName, "modelName");
         // If the model is unknown, we should NOT fail fast during the creation of OpenAiTokenizer.
         // Doing so would cause the failure of every OpenAI***Model that uses this tokenizer.
         // This is done to account for situations when a new OpenAI model is available,
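The constructor now rejects blank model names eagerly while still deferring encoding resolution, so construction stays cheap and tolerant of model names the underlying encoding registry does not know yet:

    new OpenAiTokenizer("gpt-3.5-turbo");      // known model: encoding can be resolved
    new OpenAiTokenizer("some-future-model");  // unknown model: constructed fine, token counting may fail later
    new OpenAiTokenizer(" ");                  // blank: presumably throws from ensureNotBlank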