Google AI — add support for embedding model and ability to count tokens (#1786)
## Issue

Closes #1785
Closes #1784

## Change

* `GoogleAiChatModel` now implements `TokenCountEstimator`
* new `GoogleAiEmbeddingModel` class
* new `GoogleAiGeminiTokenizer` class

## General checklist

- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green

<!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. -->

- [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent
1dc3ce6e96
commit
cb5bdd92e7
|
@ -0,0 +1,11 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Data;

import java.util.List;

// Request payload for the Gemini `models/{model}:countTokens` endpoint.
// Callers populate either `contents` (to count tokens of plain message contents)
// or `generateContentRequest` (to count a full generation request, e.g. one that
// also carries tool specifications) — see GoogleAiGeminiTokenizer for both usages.
@Data
class GeminiCountTokensRequest {
    // Conversation contents whose tokens should be counted.
    List<GeminiContent> contents;
    // A complete generation request to count; used when tools etc. must be included.
    GeminiGenerateContentRequest generateContentRequest;
}
|
|
@ -0,0 +1,8 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Data;

// Response payload of the Gemini `models/{model}:countTokens` endpoint.
@Data
class GeminiCountTokensResponse {
    // Total number of tokens the model counted for the submitted request.
    Integer totalTokens;
}
|
|
@ -8,6 +8,7 @@ import java.util.List;
|
|||
@Data
|
||||
@Builder
|
||||
class GeminiGenerateContentRequest {
|
||||
private String model;
|
||||
private List<GeminiContent> contents;
|
||||
private GeminiTool tools;
|
||||
private GeminiToolConfig toolConfig;
|
||||
|
|
|
@ -1,24 +1,76 @@
|
|||
package dev.langchain4j.model.googleai;
|
||||
|
||||
//import io.reactivex.rxjava3.core.Observable;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.logging.HttpLoggingInterceptor;
|
||||
import org.slf4j.Logger;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.Retrofit;
|
||||
import retrofit2.converter.gson.GsonConverterFactory;
|
||||
import retrofit2.http.Body;
|
||||
import retrofit2.http.POST;
|
||||
import retrofit2.http.Path;
|
||||
import retrofit2.http.Header;
|
||||
import retrofit2.http.Headers;
|
||||
|
||||
import java.time.Duration;
|
||||
//import retrofit2.http.Streaming;
|
||||
|
||||
interface GeminiService {
|
||||
String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
|
||||
String API_KEY_HEADER_NAME = "x-goog-api-key";
|
||||
String USER_AGENT = "User-Agent: LangChain4j";
|
||||
|
||||
static GeminiService getGeminiService(Logger logger, Duration timeout) {
|
||||
Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
|
||||
.baseUrl(GEMINI_AI_ENDPOINT)
|
||||
.addConverterFactory(GsonConverterFactory.create());
|
||||
|
||||
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
|
||||
.callTimeout(timeout);
|
||||
|
||||
if (logger != null) {
|
||||
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug);
|
||||
logging.redactHeader(API_KEY_HEADER_NAME);
|
||||
logging.setLevel(HttpLoggingInterceptor.Level.BODY);
|
||||
|
||||
clientBuilder.addInterceptor(logging);
|
||||
}
|
||||
|
||||
retrofitBuilder.client(clientBuilder.build());
|
||||
Retrofit retrofit = retrofitBuilder.build();
|
||||
|
||||
return retrofit.create(GeminiService.class);
|
||||
}
|
||||
|
||||
@POST("models/{model}:generateContent")
|
||||
@Headers("User-Agent: LangChain4j")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GeminiGenerateContentResponse> generateContent(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GeminiGenerateContentRequest request);
|
||||
|
||||
@POST("models/{model}:countTokens")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GeminiCountTokensResponse> countTokens(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GeminiCountTokensRequest countTokensRequest);
|
||||
|
||||
@POST("models/{model}:embedContent")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GoogleAiEmbeddingResponse> embed(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GoogleAiEmbeddingRequest embeddingRequest);
|
||||
|
||||
@POST("models/{model}:batchEmbedContents")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GoogleAiBatchEmbeddingResponse> batchEmbed(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest);
|
||||
|
||||
/*
|
||||
@Streaming
|
||||
@POST("models/{model}:streamGenerateContent")
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Data;

import java.util.List;

// Request payload for the Gemini `models/{model}:batchEmbedContents` endpoint.
@Data
class GoogleAiBatchEmbeddingRequest {
    // Individual embedding requests bundled into one HTTP call.
    // GoogleAiEmbeddingModel caps each batch at 100 requests
    // (MAX_NUMBER_OF_SEGMENTS_PER_BATCH).
    List<GoogleAiEmbeddingRequest> requests;
}
|
|
@ -0,0 +1,10 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Data;

import java.util.List;

// Response payload of the Gemini `models/{model}:batchEmbedContents` endpoint.
@Data
class GoogleAiBatchEmbeddingResponse {
    // One embedding vector per request, in the same order as the batch request.
    // NOTE(review): same-order correspondence is assumed by GoogleAiEmbeddingModel.embedAll —
    // confirm against the Google AI API documentation.
    List<GoogleAiEmbeddingResponseValues> embeddings;
}
|
|
@ -0,0 +1,195 @@
|
|||
package dev.langchain4j.model.googleai;

import com.google.gson.Gson;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

/**
 * An {@link EmbeddingModel} backed by the Google AI Gemini embedding endpoints
 * ({@code models/{model}:embedContent} and {@code models/{model}:batchEmbedContents}).
 */
@Experimental
@Slf4j
public class GoogleAiEmbeddingModel implements EmbeddingModel {
    // The batchEmbedContents endpoint accepts a limited number of requests per call,
    // so embedAll() splits its input into batches of this size.
    private static final int MAX_NUMBER_OF_SEGMENTS_PER_BATCH = 100;

    private final GeminiService geminiService;

    private final Gson GSON = new Gson();

    private final String modelName;
    private final String apiKey;
    private final Integer maxRetries;
    private final TaskType taskType;
    private final String titleMetadataKey;
    private final Integer outputDimensionality;

    /**
     * @param modelName               name of the embedding model (mandatory), e.g. "text-embedding-004"
     * @param apiKey                  Google AI API key (mandatory)
     * @param maxRetries              how many times a failing call is retried (default: 3)
     * @param taskType                optional task type hint sent with every request
     * @param titleMetadataKey        metadata key holding the document title, used only for
     *                                the RETRIEVAL_DOCUMENT task type (default: "title")
     * @param outputDimensionality    optional requested dimension of the embedding vectors
     * @param timeout                 HTTP call timeout (default: 60 seconds)
     * @param logRequestsAndResponses whether to log HTTP traffic (default: false)
     */
    @Builder
    public GoogleAiEmbeddingModel(
        String modelName,
        String apiKey,
        Integer maxRetries,
        TaskType taskType,
        String titleMetadataKey,
        Integer outputDimensionality,
        Duration timeout,
        Boolean logRequestsAndResponses
    ) {
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.apiKey = ensureNotBlank(apiKey, "apiKey");
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.taskType = taskType;
        this.titleMetadataKey = getOrDefault(titleMetadataKey, "title");
        this.outputDimensionality = outputDimensionality;

        // Null-safe unboxing: logRequestsAndResponses may be null when not set via the builder.
        boolean logTraffic = logRequestsAndResponses != null && logRequestsAndResponses;
        this.geminiService = GeminiService.getGeminiService(
            logTraffic ? log : null,
            getOrDefault(timeout, Duration.ofSeconds(60)));
    }

    @Override
    public Response<Embedding> embed(TextSegment textSegment) {
        GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment);

        // FIX: the retry now wraps the actual HTTP execution. Previously only the
        // creation of the retrofit2.Call (which performs no I/O) was retried, so a
        // failing request was in fact never retried.
        GoogleAiEmbeddingResponse geminiResponse = withRetry(
            () -> executeCall(this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), "embed"),
            this.maxRetries);

        return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues()));
    }

    @Override
    public Response<Embedding> embed(String text) {
        return embed(TextSegment.from(text));
    }

    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List<GoogleAiEmbeddingRequest> embeddingRequests = textSegments.stream()
            .map(this::getGoogleAiEmbeddingRequest)
            .collect(Collectors.toList());

        List<Embedding> allEmbeddings = new ArrayList<>();
        int numberOfEmbeddings = embeddingRequests.size();

        // Split the requests into batches of at most MAX_NUMBER_OF_SEGMENTS_PER_BATCH.
        for (int startIndex = 0; startIndex < numberOfEmbeddings; startIndex += MAX_NUMBER_OF_SEGMENTS_PER_BATCH) {
            int lastIndex = Math.min(startIndex + MAX_NUMBER_OF_SEGMENTS_PER_BATCH, numberOfEmbeddings);

            GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest();
            batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex));

            // FIX: this.maxRetries was not passed to withRetry here, unlike in embed().
            GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(
                () -> executeCall(this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest), "embedAll"),
                this.maxRetries);

            allEmbeddings.addAll(geminiResponse.getEmbeddings().stream()
                .map(values -> Embedding.from(values.getValues()))
                .collect(Collectors.toList()));
        }

        return Response.from(allEmbeddings);
    }

    /**
     * Synchronously executes a retrofit call, translating HTTP errors and I/O failures
     * into RuntimeExceptions. Shared by embed() and embedAll() (was duplicated before).
     *
     * @param call          the prepared retrofit call
     * @param operationName "embed" or "embedAll", used in error messages
     * @return the non-null response body
     */
    private <T> T executeCall(Call<T> call, String operationName) {
        try {
            retrofit2.Response<T> executed = call.execute();

            if (executed.code() >= 300) {
                try (ResponseBody errorBody = executed.errorBody()) {
                    GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();

                    throw new RuntimeException(
                        String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
                }
            }

            T body = executed.body();
            if (body == null) {
                throw new RuntimeException("Gemini embedding response was null (" + operationName + ")");
            }
            return body;
        } catch (IOException e) {
            throw new RuntimeException("An error occurred when calling the Gemini API endpoint (" + operationName + ").", e);
        }
    }

    /**
     * Maps a TextSegment to an embedContent request. For the RETRIEVAL_DOCUMENT task
     * type, the document title is looked up in the segment's metadata under
     * {@code titleMetadataKey} and forwarded to the API.
     */
    private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
        GeminiPart geminiPart = GeminiPart.builder()
            .text(textSegment.text())
            .build();

        GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);

        String title = null;
        if (TaskType.RETRIEVAL_DOCUMENT.equals(this.taskType)) {
            if (textSegment.metadata() != null && textSegment.metadata().getString(this.titleMetadataKey) != null) {
                title = textSegment.metadata().getString(this.titleMetadataKey);
            }
        }

        return new GoogleAiEmbeddingRequest(
            "models/" + this.modelName,
            content,
            this.taskType,
            title,
            this.outputDimensionality
        );
    }

    @Override
    public int dimension() {
        // 768 is used as the fallback when no outputDimensionality was requested
        // — TODO confirm this matches the configured model's actual dimension.
        return getOrDefault(this.outputDimensionality, 768);
    }

    /** Task types supported by the Google AI embedding API. */
    public enum TaskType {
        RETRIEVAL_QUERY,
        RETRIEVAL_DOCUMENT,
        SEMANTIC_SIMILARITY,
        CLASSIFICATION,
        CLUSTERING,
        QUESTION_ANSWERING,
        FACT_VERIFICATION
    }
}
|
|
@ -0,0 +1,14 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Builder;
import lombok.Data;

// Request payload for the Gemini `models/{model}:embedContent` endpoint.
@Builder
@Data
class GoogleAiEmbeddingRequest {
    // Fully-qualified model name, e.g. "models/text-embedding-004"
    // (GoogleAiEmbeddingModel prefixes the configured name with "models/").
    String model;
    // The content (text parts) to embed.
    GeminiContent content;
    // Optional task type hint (e.g. RETRIEVAL_DOCUMENT).
    GoogleAiEmbeddingModel.TaskType taskType;
    // Optional document title; only set for the RETRIEVAL_DOCUMENT task type.
    String title;
    // Optional requested dimension of the resulting embedding vector.
    Integer outputDimensionality;
}
|
|
@ -0,0 +1,8 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Data;

// Response payload of the Gemini `models/{model}:embedContent` endpoint.
@Data
class GoogleAiEmbeddingResponse {
    // The embedding vector produced for the submitted content.
    GoogleAiEmbeddingResponseValues embedding;
}
|
|
@ -0,0 +1,10 @@
|
|||
package dev.langchain4j.model.googleai;

import lombok.Data;

import java.util.List;

// A single embedding vector as returned by the Gemini embedding endpoints.
@Data
class GoogleAiEmbeddingResponseValues {
    // The components of the embedding vector.
    List<Float> values;
}
|
|
@ -7,24 +7,20 @@ import dev.langchain4j.data.message.AiMessage;
|
|||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.Capability;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.TokenCountEstimator;
|
||||
import dev.langchain4j.model.chat.listener.*;
|
||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import lombok.Builder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.logging.HttpLoggingInterceptor;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.Retrofit;
|
||||
import retrofit2.converter.gson.GsonConverterFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
|
@ -43,7 +39,6 @@ import static dev.langchain4j.internal.Utils.copyIfNotNull;
|
|||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
|
||||
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
|
||||
import static dev.langchain4j.model.googleai.GeminiService.API_KEY_HEADER_NAME;
|
||||
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
|
||||
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
|
||||
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
|
||||
|
@ -53,9 +48,7 @@ import static java.util.Collections.emptyList;
|
|||
|
||||
@Experimental
|
||||
@Slf4j
|
||||
public class GoogleAiGeminiChatModel implements ChatLanguageModel {
|
||||
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
|
||||
|
||||
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
private static final Gson GSON = new Gson();
|
||||
|
||||
private final GeminiService geminiService;
|
||||
|
@ -68,7 +61,6 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
|
|||
private final Integer topK;
|
||||
private final Double topP;
|
||||
private final Integer maxOutputTokens;
|
||||
private final Duration timeout;
|
||||
private final List<String> stopSequences;
|
||||
|
||||
private final Integer candidateCount;
|
||||
|
@ -77,14 +69,15 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
|
|||
|
||||
private final GeminiFunctionCallingConfig toolConfig;
|
||||
|
||||
private final Boolean logRequestsAndResponses;
|
||||
|
||||
private final boolean allowCodeExecution;
|
||||
private final boolean includeCodeExecutionOutput;
|
||||
|
||||
private final Boolean logRequestsAndResponses;
|
||||
private final List<GeminiSafetySetting> safetySettings;
|
||||
private final List<ChatModelListener> listeners;
|
||||
|
||||
private final GoogleAiGeminiTokenizer geminiTokenizer;
|
||||
|
||||
@Builder
|
||||
public GoogleAiGeminiChatModel(String apiKey, String modelName,
|
||||
Integer maxRetries,
|
||||
|
@ -111,22 +104,30 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
|
|||
this.candidateCount = getOrDefault(candidateCount, 1);
|
||||
this.stopSequences = getOrDefault(stopSequences, emptyList());
|
||||
|
||||
this.timeout = getOrDefault(timeout, ofSeconds(60));
|
||||
|
||||
this.toolConfig = toolConfig;
|
||||
|
||||
this.allowCodeExecution = allowCodeExecution != null ? allowCodeExecution : false;
|
||||
this.includeCodeExecutionOutput = includeCodeExecutionOutput != null ? includeCodeExecutionOutput : false;
|
||||
this.logRequestsAndResponses = getOrDefault(logRequestsAndResponses, false);
|
||||
|
||||
this.safetySettings = copyIfNotNull(safetySettings);
|
||||
|
||||
this.responseFormat = responseFormat;
|
||||
|
||||
this.logRequestsAndResponses = getOrDefault(logRequestsAndResponses, false);
|
||||
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
|
||||
this.geminiService = getGeminiService();
|
||||
this.geminiService = GeminiService.getGeminiService(
|
||||
getOrDefault(logRequestsAndResponses, false) ? this.log : null,
|
||||
getOrDefault(timeout, ofSeconds(60))
|
||||
);
|
||||
|
||||
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
|
||||
.modelName(this.modelName)
|
||||
.apiKey(this.apiKey)
|
||||
.timeout(getOrDefault(timeout, ofSeconds(60)))
|
||||
.maxRetries(this.maxRetries)
|
||||
.logRequestsAndResponses(this.logRequestsAndResponses)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static String computeMimeType(ResponseFormat responseFormat) {
|
||||
|
@ -318,6 +319,11 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int estimateTokenCount(List<ChatMessage> messages) {
|
||||
return geminiTokenizer.estimateTokenCountInMessages(messages);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Capability> supportedCapabilities() {
|
||||
Set<Capability> capabilities = new HashSet<>();
|
||||
|
@ -328,29 +334,6 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
|
|||
return capabilities;
|
||||
}
|
||||
|
||||
private GeminiService getGeminiService() {
|
||||
Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
|
||||
.baseUrl(GEMINI_AI_ENDPOINT)
|
||||
.addConverterFactory(GsonConverterFactory.create());
|
||||
|
||||
if (this.logRequestsAndResponses) {
|
||||
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(log::debug);
|
||||
logging.redactHeader(API_KEY_HEADER_NAME);
|
||||
logging.setLevel(HttpLoggingInterceptor.Level.BODY);
|
||||
|
||||
OkHttpClient okHttpClient = new OkHttpClient.Builder()
|
||||
.addInterceptor(logging)
|
||||
.callTimeout(this.timeout)
|
||||
.build();
|
||||
|
||||
retrofitBuilder.client(okHttpClient);
|
||||
}
|
||||
|
||||
Retrofit retrofit = retrofitBuilder.build();
|
||||
|
||||
return retrofit.create(GeminiService.class);
|
||||
}
|
||||
|
||||
public static class GoogleAiGeminiChatModelBuilder {
|
||||
public GoogleAiGeminiChatModelBuilder toolConfig(GeminiMode mode, String... allowedFunctionNames) {
|
||||
this.toolConfig = new GeminiFunctionCallingConfig(mode, Arrays.asList(allowedFunctionNames));
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
package dev.langchain4j.model.googleai;

import com.google.gson.Gson;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.Tokenizer;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;

import java.io.IOException;
import java.time.Duration;
import java.util.LinkedList;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static java.util.Collections.singletonList;

/**
 * A {@link Tokenizer} that delegates token counting to the Google AI Gemini
 * {@code models/{model}:countTokens} endpoint — every estimate performs an HTTP call.
 */
@Slf4j
public class GoogleAiGeminiTokenizer implements Tokenizer {
    private static final Gson GSON = new Gson();

    private final GeminiService geminiService;
    private final String modelName;
    private final String apiKey;
    private final Integer maxRetries;

    /**
     * @param modelName               name of the model whose tokenizer is used (mandatory)
     * @param apiKey                  Google AI API key (mandatory)
     * @param logRequestsAndResponses whether to log HTTP traffic (default: false)
     * @param timeout                 HTTP call timeout (default: 60 seconds)
     * @param maxRetries              how many times a failing call is retried (default: 3)
     */
    @Builder
    GoogleAiGeminiTokenizer(
        String modelName,
        String apiKey,
        Boolean logRequestsAndResponses,
        Duration timeout,
        Integer maxRetries
    ) {
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.apiKey = ensureNotBlank(apiKey, "apiKey");
        this.maxRetries = getOrDefault(maxRetries, 3);
        // FIX: `logRequestsAndResponses ? ...` auto-unboxed the Boolean and threw an
        // NPE when the builder left it null; Boolean.TRUE.equals(...) is null-safe.
        this.geminiService = GeminiService.getGeminiService(
            Boolean.TRUE.equals(logRequestsAndResponses) ? log : null,
            getOrDefault(timeout, Duration.ofSeconds(60)));
    }

    @Override
    public int estimateTokenCountInText(String text) {
        return estimateTokenCountInMessages(singletonList(UserMessage.from(text)));
    }

    @Override
    public int estimateTokenCountInMessage(ChatMessage message) {
        return estimateTokenCountInMessages(singletonList(message));
    }

    @Override
    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
        List<ToolExecutionRequest> allToolRequests = new LinkedList<>();
        toolExecutionRequests.forEach(allToolRequests::add);

        // Tool execution requests are counted by wrapping them in a single AiMessage.
        return estimateTokenCountInMessage(AiMessage.from(allToolRequests));
    }

    @Override
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        List<ChatMessage> allMessages = new LinkedList<>();
        messages.forEach(allMessages::add);

        List<GeminiContent> geminiContentList = fromMessageToGContent(allMessages, null);

        GeminiCountTokensRequest countTokensRequest = new GeminiCountTokensRequest();
        countTokensRequest.setContents(geminiContentList);

        return estimateTokenCount(countTokensRequest);
    }

    @Override
    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
        List<ToolSpecification> allTools = new LinkedList<>();
        toolSpecifications.forEach(allTools::add);

        GeminiContent dummyContent = GeminiContent.builder().parts(
            singletonList(GeminiPart.builder()
                .text("Dummy content") // This string contains 2 tokens
                .build())
        ).build();

        GeminiCountTokensRequest countTokensRequestWithDummyContent = new GeminiCountTokensRequest();
        countTokensRequestWithDummyContent.setGenerateContentRequest(GeminiGenerateContentRequest.builder()
            .model("models/" + this.modelName)
            .contents(singletonList(dummyContent))
            .tools(FunctionMapper.fromToolSepcsToGTool(allTools, false))
            .build());

        // The API doesn't allow us to make a request to count the tokens of the tool specifications only.
        // Instead, we take the approach of adding a dummy content in the request, and subtract the tokens for the dummy request.
        // The string "Dummy content" accounts for 2 tokens. So let's subtract 2 from the overall count.
        return estimateTokenCount(countTokensRequestWithDummyContent) - 2;
    }

    /**
     * Calls the countTokens endpoint and returns the total token count.
     * HTTP errors and I/O failures are translated into RuntimeExceptions.
     */
    private int estimateTokenCount(GeminiCountTokensRequest countTokensRequest) {
        Call<GeminiCountTokensResponse> responseCall =
            withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries);

        GeminiCountTokensResponse countTokensResponse;
        try {
            retrofit2.Response<GeminiCountTokensResponse> executed = responseCall.execute();
            countTokensResponse = executed.body();

            if (executed.code() >= 300) {
                try (ResponseBody errorBody = executed.errorBody()) {
                    GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();

                    throw new RuntimeException(
                        String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
                }
            }
        } catch (IOException e) {
            throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e);
        }

        // FIX: guard against a null body / null totalTokens, which previously caused
        // an unhelpful NullPointerException.
        if (countTokensResponse == null || countTokensResponse.getTotalTokens() == null) {
            throw new RuntimeException("Gemini token counting response was null or incomplete");
        }
        return countTokensResponse.getTotalTokens();
    }
}
|
|
@ -0,0 +1,113 @@
|
|||
package dev.langchain4j.model.googleai;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

// Integration tests for GoogleAiEmbeddingModel. They perform real HTTP calls and
// require the GOOGLE_AI_GEMINI_API_KEY environment variable to be set.
public class GoogleAiEmbeddingModelIT {

    private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");

    @Test
    void should_embed_one_text() {
        // given
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("embedding-001")
                .maxRetries(3)
                .logRequestsAndResponses(true)
                .build();

        // when
        Response<Embedding> response = model.embed("Hello world!");

        // then: embedding-001 produces 768-dimensional vectors by default
        Embedding embedding = response.content();
        assertThat(embedding).isNotNull();
        assertThat(embedding.vector()).isNotNull();
        assertThat(embedding.vector()).hasSize(768);
    }

    @Test
    void should_use_metadata() {
        // given: RETRIEVAL_DOCUMENT task type reads the title from segment metadata
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("embedding-001")
                .maxRetries(3)
                .logRequestsAndResponses(true)
                .titleMetadataKey("title")
                .taskType(GoogleAiEmbeddingModel.TaskType.RETRIEVAL_DOCUMENT)
                .build();

        // when
        TextSegment segmentWithTitle = TextSegment.from(
                "What is the capital of France?",
                Metadata.from("title", "document title")
        );
        Response<Embedding> response = model.embed(segmentWithTitle);

        // then
        Embedding embedding = response.content();
        assertThat(embedding).isNotNull();
        assertThat(embedding.vector()).isNotNull();
    }

    @Test
    void should_embed_in_batch() {
        // given
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("embedding-001")
                .maxRetries(3)
                .logRequestsAndResponses(true)
                .outputDimensionality(512)
                .build();

        // when
        List<TextSegment> segments = Arrays.asList(
                TextSegment.from("What is the capital of France?"),
                TextSegment.from("What is the capital of Germany?")
        );

        Response<List<Embedding>> response = model.embedAll(segments);

        // then: the requested outputDimensionality (512) is honored for every vector
        List<Embedding> embeddings = response.content();
        assertThat(embeddings).isNotNull();
        assertThat(embeddings).hasSize(2);
        assertThat(embeddings.get(0).vector()).isNotNull();
        assertThat(embeddings.get(0).vector()).hasSize(512);
        assertThat(embeddings.get(1).vector()).isNotNull();
        assertThat(embeddings.get(1).vector()).hasSize(512);
    }

    @Test
    void should_embed_more_than_100() {
        // given: more segments than the 100-per-batch API limit
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("text-embedding-004")
                .maxRetries(3)
                .build();

        // when
        List<TextSegment> segments = new ArrayList<>();
        for (int segmentIndex = 0; segmentIndex < 300; segmentIndex++) {
            segments.add(TextSegment.from("What is the capital of France? "));
        }

        Response<List<Embedding>> response = model.embedAll(segments);

        // then: all 300 segments were embedded across multiple batches
        assertThat(response.content()).hasSize(300);
    }
}
|
|
@ -589,10 +589,10 @@ public class GoogleAiGeminiChatModelIT {
|
|||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.logRequestsAndResponses(true)
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
|
||||
.build())
|
||||
.responseFormat(ResponseFormat.builder()
|
||||
.type(JSON)
|
||||
.jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
|
||||
.build())
|
||||
// Equivalent to:
|
||||
// .responseFormat(ResponseFormat.builder()
|
||||
// .type(JSON)
|
||||
|
@ -674,6 +674,35 @@ public class GoogleAiGeminiChatModelIT {
|
|||
assertThat(chatResponse.aiMessage().hasToolExecutionRequests()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_count_tokens() {
|
||||
// given
|
||||
GoogleAiGeminiChatModel gemini = GoogleAiGeminiChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
// when
|
||||
int countedTokens = gemini.estimateTokenCount("What is the capital of France?");
|
||||
|
||||
// then
|
||||
assertThat(countedTokens).isGreaterThan(0);
|
||||
|
||||
// when
|
||||
List<ChatMessage> messageList = Arrays.asList(
|
||||
SystemMessage.from("You are a helpful geography teacher"),
|
||||
UserMessage.from("What is the capital of Germany?"),
|
||||
AiMessage.from("Berlin"),
|
||||
UserMessage.from("Thank you!"),
|
||||
AiMessage.from("You're welcome!")
|
||||
);
|
||||
int listOfMsgTokenCount = gemini.estimateTokenCount(messageList);
|
||||
|
||||
// then
|
||||
assertThat(listOfMsgTokenCount).isGreaterThan(0);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void afterEach() throws InterruptedException {
|
||||
String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
package dev.langchain4j.model.googleai;

import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import org.junit.jupiter.api.Test;

import java.util.Arrays;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

/**
 * Integration tests for {@code GoogleAiGeminiTokenizer}, exercising token-count
 * estimation for plain text, single messages, message lists, tool execution
 * requests, and tool specifications.
 *
 * <p>Requires the {@code GOOGLE_AI_GEMINI_API_KEY} environment variable, since the
 * tokenizer delegates counting to the Google AI Gemini service.
 *
 * <p>NOTE(review): the exact expected counts (4, 14, 29, 114) are tied to the
 * {@code gemini-1.5-flash} tokenizer and may drift if the service's tokenization
 * changes — confirm before tightening or relaxing these assertions.
 */
public class GoogleAiGeminiTokenizerTest {
    private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");

    /**
     * Builds the tokenizer under test. Every test uses the exact same
     * configuration, so the builder chain lives here instead of being
     * duplicated in each test method.
     */
    private static GoogleAiGeminiTokenizer tokenizer() {
        return GoogleAiGeminiTokenizer.builder()
            .logRequestsAndResponses(true)
            .modelName("gemini-1.5-flash")
            .apiKey(GOOGLE_AI_GEMINI_API_KEY)
            .build();
    }

    @Test
    void should_estimate_token_count_for_text() {
        // given
        GoogleAiGeminiTokenizer tokenizer = tokenizer();

        // when
        int count = tokenizer.estimateTokenCountInText("Hello world!");

        // then
        assertThat(count).isEqualTo(4);
    }

    @Test
    void should_estimate_token_count_for_a_message() {
        // given
        GoogleAiGeminiTokenizer tokenizer = tokenizer();

        // when
        int count = tokenizer.estimateTokenCountInMessage(UserMessage.from("Hello World!"));

        // then
        assertThat(count).isEqualTo(4);
    }

    @Test
    void should_estimate_token_count_for_list_of_messages() {
        // given
        GoogleAiGeminiTokenizer tokenizer = tokenizer();

        // when
        int count = tokenizer.estimateTokenCountInMessages(
            Arrays.asList(
                UserMessage.from("Hello World!"),
                AiMessage.from("Hi! How can I help you today?")
            )
        );

        // then
        assertThat(count).isEqualTo(14);
    }

    @Test
    void should_estimate_token_count_for_tool_exec_reqs() {
        // given
        GoogleAiGeminiTokenizer tokenizer = tokenizer();

        // when
        int count = tokenizer.estimateTokenCountInToolExecutionRequests(
            Arrays.asList(
                ToolExecutionRequest.builder()
                    .name("weatherForecast")
                    .arguments("{ \"location\": \"Paris\" }")
                    .build(),
                ToolExecutionRequest.builder()
                    .name("weatherForecast")
                    .arguments("{ \"location\": \"London\" }")
                    .build()
            )
        );

        // then
        assertThat(count).isEqualTo(29);
    }

    @Test
    void should_estimate_token_count_for_tool_specs() {
        // given
        GoogleAiGeminiTokenizer tokenizer = tokenizer();

        // when
        int count = tokenizer.estimateTokenCountInToolSpecifications(
            Arrays.asList(
                ToolSpecification.builder()
                    .name("weatherForecast")
                    .description("Get the weather forecast for a given location on a give date")
                    .addParameter("location", JsonSchemaProperty.STRING, JsonSchemaProperty.description("the location"))
                    .addParameter("date", JsonSchemaProperty.STRING, JsonSchemaProperty.description("the date"))
                    .build(),
                ToolSpecification.builder()
                    .name("convertFahrenheitToCelsius")
                    .description("Convert a temperature in Fahrenheit to Celsius")
                    .addParameter("fahrenheit", JsonSchemaProperty.NUMBER, JsonSchemaProperty.description("the temperature in Fahrenheit"))
                    .build()
            )
        );

        // then
        assertThat(count).isEqualTo(114);
    }
}
|
|
@ -114,7 +114,7 @@
|
|||
<artifactId>libraries-bom</artifactId>
|
||||
<scope>import</scope>
|
||||
<type>pom</type>
|
||||
<version>26.45.0</version>
|
||||
<version>26.46.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
|
|
@ -9,6 +9,7 @@ import dev.langchain4j.agent.tool.ToolSpecification;
|
|||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.TokenCountEstimator;
|
||||
import dev.langchain4j.model.chat.listener.*;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.vertexai.spi.VertexAiGeminiChatModelBuilderFactory;
|
||||
|
@ -82,6 +83,7 @@ public class VertexAiGeminiChatModel implements ChatLanguageModel, Closeable {
|
|||
Integer maxOutputTokens,
|
||||
Integer topK,
|
||||
Float topP,
|
||||
Integer seed,
|
||||
Integer maxRetries,
|
||||
String responseMimeType,
|
||||
Schema responseSchema,
|
||||
|
@ -106,6 +108,9 @@ public class VertexAiGeminiChatModel implements ChatLanguageModel, Closeable {
|
|||
if (topP != null) {
|
||||
generationConfigBuilder.setTopP(topP);
|
||||
}
|
||||
if (seed != null) {
|
||||
generationConfigBuilder.setSeed(seed);
|
||||
}
|
||||
if (responseMimeType != null) {
|
||||
generationConfigBuilder.setResponseMimeType(responseMimeType);
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
<dependency>
|
||||
<groupId>com.google.cloud</groupId>
|
||||
<artifactId>google-cloud-aiplatform</artifactId>
|
||||
<version>3.49.0</version>
|
||||
<version>3.50.0</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>commons-logging</groupId>
|
||||
|
|
|
@ -132,7 +132,7 @@ public class VertexAiEmbeddingModel extends DimensionAwareEmbeddingModel {
|
|||
embeddingInstance.setTaskType(taskType);
|
||||
if (this.taskType.equals(TaskType.RETRIEVAL_DOCUMENT)) {
|
||||
// Title metadata is used for calculating embeddings for document retrieval
|
||||
embeddingInstance.setTitle(segment.metadata(titleMetadataKey));
|
||||
embeddingInstance.setTitle(segment.metadata().getString(titleMetadataKey));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue