Google AI — add support for embedding model and ability to count tokens (#1786)

## Issue
Closes #1785
Closes #1784

## Change
* `GoogleAiChatModel` now implements `TokenCountEstimator`
* new `GoogleAiEmbeddingModel` class
* new `GoogleAiGeminiTokenizer` class

## General checklist

- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
Guillaume Laforge 2024-09-18 12:41:29 +02:00 committed by GitHub
parent 1dc3ce6e96
commit cb5bdd92e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 752 additions and 48 deletions

View File

@ -0,0 +1,11 @@
package dev.langchain4j.model.googleai;
import lombok.Data;
import java.util.List;
/**
 * Request body for the Gemini {@code models/{model}:countTokens} endpoint.
 * Callers populate one of the two fields: plain {@code contents} to count message tokens,
 * or a full {@code generateContentRequest} when tool/function declarations must be
 * included in the count (see {@code GoogleAiGeminiTokenizer}).
 */
@Data
class GeminiCountTokensRequest {
    // Messages (as Gemini content parts) whose tokens should be counted.
    List<GeminiContent> contents;
    // Alternative payload: a complete generateContent request, used to also count tool specs.
    GeminiGenerateContentRequest generateContentRequest;
}

View File

@ -0,0 +1,8 @@
package dev.langchain4j.model.googleai;
import lombok.Data;
/**
 * Response body of the Gemini {@code countTokens} endpoint:
 * the total number of tokens in the submitted request.
 */
@Data
class GeminiCountTokensResponse {
    Integer totalTokens;
}

View File

@ -8,6 +8,7 @@ import java.util.List;
@Data
@Builder
class GeminiGenerateContentRequest {
private String model;
private List<GeminiContent> contents;
private GeminiTool tools;
private GeminiToolConfig toolConfig;

View File

@ -1,24 +1,76 @@
package dev.langchain4j.model.googleai;
//import io.reactivex.rxjava3.core.Observable;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.slf4j.Logger;
import retrofit2.Call;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;
import retrofit2.http.Body;
import retrofit2.http.POST;
import retrofit2.http.Path;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import java.time.Duration;
//import retrofit2.http.Streaming;
interface GeminiService {
String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
String API_KEY_HEADER_NAME = "x-goog-api-key";
String USER_AGENT = "User-Agent: LangChain4j";
/**
 * Builds a retrofit-backed {@link GeminiService} client for the Google AI Gemini endpoint.
 *
 * @param logger  when non-null, HTTP request/response bodies are logged at debug level
 *                through an OkHttp logging interceptor (the API key header is redacted);
 *                when null, no logging interceptor is installed
 * @param timeout overall call timeout applied to the underlying OkHttp client
 */
static GeminiService getGeminiService(Logger logger, Duration timeout) {
    Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
            .baseUrl(GEMINI_AI_ENDPOINT)
            .addConverterFactory(GsonConverterFactory.create());
    OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
            .callTimeout(timeout);
    if (logger != null) {
        HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug);
        logging.redactHeader(API_KEY_HEADER_NAME); // never leak the API key into logs
        logging.setLevel(HttpLoggingInterceptor.Level.BODY);
        clientBuilder.addInterceptor(logging);
    }
    retrofitBuilder.client(clientBuilder.build());
    Retrofit retrofit = retrofitBuilder.build();
    return retrofit.create(GeminiService.class);
}
@POST("models/{model}:generateContent")
@Headers("User-Agent: LangChain4j")
@Headers(USER_AGENT)
Call<GeminiGenerateContentResponse> generateContent(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiGenerateContentRequest request);
@POST("models/{model}:countTokens")
@Headers(USER_AGENT)
Call<GeminiCountTokensResponse> countTokens(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiCountTokensRequest countTokensRequest);
@POST("models/{model}:embedContent")
@Headers(USER_AGENT)
Call<GoogleAiEmbeddingResponse> embed(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GoogleAiEmbeddingRequest embeddingRequest);
@POST("models/{model}:batchEmbedContents")
@Headers(USER_AGENT)
Call<GoogleAiBatchEmbeddingResponse> batchEmbed(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest);
/*
@Streaming
@POST("models/{model}:streamGenerateContent")

View File

@ -0,0 +1,10 @@
package dev.langchain4j.model.googleai;
import lombok.Data;
import java.util.List;
/**
 * Request body for the Gemini {@code batchEmbedContents} endpoint:
 * a list of individual embedding requests sent in a single HTTP call
 * (the API accepts up to 100 per batch — see {@code GoogleAiEmbeddingModel}).
 */
@Data
class GoogleAiBatchEmbeddingRequest {
    List<GoogleAiEmbeddingRequest> requests;
}

View File

@ -0,0 +1,10 @@
package dev.langchain4j.model.googleai;
import lombok.Data;
import java.util.List;
/**
 * Response body of the Gemini {@code batchEmbedContents} endpoint:
 * one embedding vector per request, in request order.
 */
@Data
class GoogleAiBatchEmbeddingResponse {
    List<GoogleAiEmbeddingResponseValues> embeddings;
}

View File

@ -0,0 +1,195 @@
package dev.langchain4j.model.googleai;
import com.google.gson.Gson;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
/**
 * An {@link EmbeddingModel} backed by the Google AI Gemini embedding API
 * ({@code embedContent} / {@code batchEmbedContents} endpoints).
 *
 * <p>Large inputs to {@link #embedAll(List)} are automatically split into batches of at
 * most {@value #MAX_NUMBER_OF_SEGMENTS_PER_BATCH} segments, the limit imposed by the API.
 */
@Experimental
@Slf4j
public class GoogleAiEmbeddingModel implements EmbeddingModel {
    // batchEmbedContents accepts at most 100 requests per HTTP call.
    private static final int MAX_NUMBER_OF_SEGMENTS_PER_BATCH = 100;
    // Gson is thread-safe; share one instance (consistent with GoogleAiGeminiChatModel).
    private static final Gson GSON = new Gson();

    private final GeminiService geminiService;
    private final String modelName;
    private final String apiKey;
    private final Integer maxRetries;
    private final TaskType taskType;
    private final String titleMetadataKey;
    private final Integer outputDimensionality;

    /**
     * @param modelName               name of the embedding model (e.g. {@code text-embedding-004}); required
     * @param apiKey                  Google AI API key; required
     * @param maxRetries              number of retry attempts for API calls (default 3)
     * @param taskType                optional task type hint sent with each request
     * @param titleMetadataKey        metadata key holding the document title, used only for
     *                                {@link TaskType#RETRIEVAL_DOCUMENT} (default {@code "title"})
     * @param outputDimensionality    optional requested embedding dimension
     * @param timeout                 overall call timeout (default 60 seconds)
     * @param logRequestsAndResponses when true, HTTP traffic is logged at debug level
     */
    @Builder
    public GoogleAiEmbeddingModel(
        String modelName,
        String apiKey,
        Integer maxRetries,
        TaskType taskType,
        String titleMetadataKey,
        Integer outputDimensionality,
        Duration timeout,
        Boolean logRequestsAndResponses
    ) {
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.apiKey = ensureNotBlank(apiKey, "apiKey");
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.taskType = taskType;
        this.titleMetadataKey = getOrDefault(titleMetadataKey, "title");
        this.outputDimensionality = outputDimensionality;

        // Boolean.TRUE.equals(...) tolerates a null Boolean from the builder.
        this.geminiService = GeminiService.getGeminiService(
            Boolean.TRUE.equals(logRequestsAndResponses) ? log : null,
            getOrDefault(timeout, Duration.ofSeconds(60)));
    }

    @Override
    public Response<Embedding> embed(TextSegment textSegment) {
        GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment);

        // The retrofit Call is created INSIDE the retried lambda: a Call is single-use,
        // and previously only the Call creation (not the network execution) was retried.
        GoogleAiEmbeddingResponse geminiResponse = withRetry(
            () -> executeAndCheck(
                this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), "embed"),
            this.maxRetries);

        return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues()));
    }

    @Override
    public Response<Embedding> embed(String text) {
        return embed(TextSegment.from(text));
    }

    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List<GoogleAiEmbeddingRequest> embeddingRequests = textSegments.stream()
            .map(this::getGoogleAiEmbeddingRequest)
            .collect(Collectors.toList());

        List<Embedding> allEmbeddings = new ArrayList<>();
        int numberOfEmbeddings = embeddingRequests.size();

        // Slice the requests into API-sized batches.
        for (int startIndex = 0; startIndex < numberOfEmbeddings; startIndex += MAX_NUMBER_OF_SEGMENTS_PER_BATCH) {
            int lastIndex = Math.min(startIndex + MAX_NUMBER_OF_SEGMENTS_PER_BATCH, numberOfEmbeddings);

            GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest();
            batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex));

            // Now passes maxRetries explicitly (was inconsistent with embed()),
            // and retries the execution rather than only the Call creation.
            GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(
                () -> executeAndCheck(
                    this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest), "embedAll"),
                this.maxRetries);

            geminiResponse.getEmbeddings()
                .forEach(values -> allEmbeddings.add(Embedding.from(values.getValues())));
        }

        return Response.from(allEmbeddings);
    }

    /**
     * Executes a retrofit call synchronously, mapping I/O failures, HTTP error statuses
     * and null bodies to RuntimeExceptions.
     *
     * @param call          the (not yet executed) retrofit call
     * @param operationName short operation label used in error messages
     * @return the non-null response body
     */
    private <T> T executeAndCheck(Call<T> call, String operationName) {
        retrofit2.Response<T> executed;
        try {
            executed = call.execute();
        } catch (IOException e) {
            throw new RuntimeException(
                String.format("An error occurred when calling the Gemini API endpoint (%s).", operationName), e);
        }
        if (executed.code() >= 300) {
            try (ResponseBody errorBody = executed.errorBody()) {
                GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
                throw new RuntimeException(
                    String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
            } catch (IOException e) {
                throw new RuntimeException(
                    String.format("An error occurred when calling the Gemini API endpoint (%s).", operationName), e);
            }
        }
        T body = executed.body();
        if (body == null) {
            throw new RuntimeException(String.format("Gemini embedding response was null (%s)", operationName));
        }
        return body;
    }

    /**
     * Maps a text segment to an embedding request, attaching the segment's title metadata
     * when the task type is {@link TaskType#RETRIEVAL_DOCUMENT}.
     */
    private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
        GeminiPart geminiPart = GeminiPart.builder()
            .text(textSegment.text())
            .build();

        GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);

        String title = null;
        // Titles are only meaningful (and only accepted by the API) for document retrieval.
        if (TaskType.RETRIEVAL_DOCUMENT.equals(this.taskType)
                && textSegment.metadata() != null) {
            title = textSegment.metadata().getString(this.titleMetadataKey);
        }

        return new GoogleAiEmbeddingRequest(
            "models/" + this.modelName,
            content,
            this.taskType,
            title,
            this.outputDimensionality
        );
    }

    @Override
    public int dimension() {
        // 768 is the default dimension of the Google AI embedding models.
        return getOrDefault(this.outputDimensionality, 768);
    }

    /** Task types supported by the Gemini embedding API. */
    public enum TaskType {
        RETRIEVAL_QUERY,
        RETRIEVAL_DOCUMENT,
        SEMANTIC_SIMILARITY,
        CLASSIFICATION,
        CLUSTERING,
        QUESTION_ANSWERING,
        FACT_VERIFICATION
    }
}

View File

@ -0,0 +1,14 @@
package dev.langchain4j.model.googleai;
import lombok.Builder;
import lombok.Data;
/**
 * Request body for the Gemini {@code embedContent} endpoint
 * (also used as one entry of a {@code batchEmbedContents} request).
 */
@Builder
@Data
class GoogleAiEmbeddingRequest {
    // Fully-qualified model name, e.g. "models/text-embedding-004".
    String model;
    // The text to embed, wrapped as Gemini content parts.
    GeminiContent content;
    // Optional task type hint (retrieval, classification, ...).
    GoogleAiEmbeddingModel.TaskType taskType;
    // Optional document title; only sent for RETRIEVAL_DOCUMENT tasks.
    String title;
    // Optional requested embedding dimension.
    Integer outputDimensionality;
}

View File

@ -0,0 +1,8 @@
package dev.langchain4j.model.googleai;
import lombok.Data;
/**
 * Response body of the Gemini {@code embedContent} endpoint:
 * a single embedding vector.
 */
@Data
class GoogleAiEmbeddingResponse {
    GoogleAiEmbeddingResponseValues embedding;
}

View File

@ -0,0 +1,10 @@
package dev.langchain4j.model.googleai;
import lombok.Data;
import java.util.List;
/**
 * The numeric components of one embedding vector returned by the Gemini API.
 */
@Data
class GoogleAiEmbeddingResponseValues {
    List<Float> values;
}

View File

@ -7,24 +7,20 @@ import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.ResponseBody;
import okhttp3.logging.HttpLoggingInterceptor;
import retrofit2.Call;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;
import java.io.IOException;
import java.time.Duration;
@ -43,7 +39,6 @@ import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.googleai.GeminiService.API_KEY_HEADER_NAME;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
@ -53,9 +48,7 @@ import static java.util.Collections.emptyList;
@Experimental
@Slf4j
public class GoogleAiGeminiChatModel implements ChatLanguageModel {
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
private static final Gson GSON = new Gson();
private final GeminiService geminiService;
@ -68,7 +61,6 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
private final Integer topK;
private final Double topP;
private final Integer maxOutputTokens;
private final Duration timeout;
private final List<String> stopSequences;
private final Integer candidateCount;
@ -77,14 +69,15 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
private final GeminiFunctionCallingConfig toolConfig;
private final Boolean logRequestsAndResponses;
private final boolean allowCodeExecution;
private final boolean includeCodeExecutionOutput;
private final Boolean logRequestsAndResponses;
private final List<GeminiSafetySetting> safetySettings;
private final List<ChatModelListener> listeners;
private final GoogleAiGeminiTokenizer geminiTokenizer;
@Builder
public GoogleAiGeminiChatModel(String apiKey, String modelName,
Integer maxRetries,
@ -111,22 +104,30 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
this.candidateCount = getOrDefault(candidateCount, 1);
this.stopSequences = getOrDefault(stopSequences, emptyList());
this.timeout = getOrDefault(timeout, ofSeconds(60));
this.toolConfig = toolConfig;
this.allowCodeExecution = allowCodeExecution != null ? allowCodeExecution : false;
this.includeCodeExecutionOutput = includeCodeExecutionOutput != null ? includeCodeExecutionOutput : false;
this.logRequestsAndResponses = getOrDefault(logRequestsAndResponses, false);
this.safetySettings = copyIfNotNull(safetySettings);
this.responseFormat = responseFormat;
this.logRequestsAndResponses = getOrDefault(logRequestsAndResponses, false);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
this.geminiService = getGeminiService();
this.geminiService = GeminiService.getGeminiService(
getOrDefault(logRequestsAndResponses, false) ? this.log : null,
getOrDefault(timeout, ofSeconds(60))
);
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
.modelName(this.modelName)
.apiKey(this.apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.maxRetries(this.maxRetries)
.logRequestsAndResponses(this.logRequestsAndResponses)
.build();
}
private static String computeMimeType(ResponseFormat responseFormat) {
@ -318,6 +319,11 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
}
}
@Override
public int estimateTokenCount(List<ChatMessage> messages) {
return geminiTokenizer.estimateTokenCountInMessages(messages);
}
@Override
public Set<Capability> supportedCapabilities() {
Set<Capability> capabilities = new HashSet<>();
@ -328,29 +334,6 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel {
return capabilities;
}
private GeminiService getGeminiService() {
Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
.baseUrl(GEMINI_AI_ENDPOINT)
.addConverterFactory(GsonConverterFactory.create());
if (this.logRequestsAndResponses) {
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(log::debug);
logging.redactHeader(API_KEY_HEADER_NAME);
logging.setLevel(HttpLoggingInterceptor.Level.BODY);
OkHttpClient okHttpClient = new OkHttpClient.Builder()
.addInterceptor(logging)
.callTimeout(this.timeout)
.build();
retrofitBuilder.client(okHttpClient);
}
Retrofit retrofit = retrofitBuilder.build();
return retrofit.create(GeminiService.class);
}
public static class GoogleAiGeminiChatModelBuilder {
public GoogleAiGeminiChatModelBuilder toolConfig(GeminiMode mode, String... allowedFunctionNames) {
this.toolConfig = new GeminiFunctionCallingConfig(mode, Arrays.asList(allowedFunctionNames));

View File

@ -0,0 +1,128 @@
package dev.langchain4j.model.googleai;
import com.google.gson.Gson;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.Tokenizer;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;
import java.io.IOException;
import java.time.Duration;
import java.util.LinkedList;
import java.util.List;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static java.util.Collections.singletonList;
/**
 * A {@link Tokenizer} that delegates token counting to the Gemini
 * {@code models/{model}:countTokens} REST endpoint.
 *
 * <p>Note: every estimate performs a (retried) network call.
 */
@Slf4j
public class GoogleAiGeminiTokenizer implements Tokenizer {
    private static final Gson GSON = new Gson();
    // The string "Dummy content" is counted as 2 tokens by the Gemini tokenizer.
    private static final String DUMMY_CONTENT = "Dummy content";
    private static final int DUMMY_CONTENT_TOKEN_COUNT = 2;

    private final GeminiService geminiService;
    private final String modelName;
    private final String apiKey;
    private final Integer maxRetries;

    /**
     * @param modelName               Gemini model whose tokenizer is used; required
     * @param apiKey                  Google AI API key; required
     * @param logRequestsAndResponses when true, HTTP traffic is logged at debug level
     * @param timeout                 overall call timeout (default 60 seconds)
     * @param maxRetries              number of retry attempts for API calls (default 3)
     */
    @Builder
    GoogleAiGeminiTokenizer(
        String modelName,
        String apiKey,
        Boolean logRequestsAndResponses,
        Duration timeout,
        Integer maxRetries
    ) {
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.apiKey = ensureNotBlank(apiKey, "apiKey");
        this.maxRetries = getOrDefault(maxRetries, 3);
        // Boolean.TRUE.equals(...) avoids the NPE the previous unboxing ternary threw
        // when logRequestsAndResponses was left unset on the builder.
        this.geminiService = GeminiService.getGeminiService(
            Boolean.TRUE.equals(logRequestsAndResponses) ? log : null,
            getOrDefault(timeout, Duration.ofSeconds(60)));
    }

    @Override
    public int estimateTokenCountInText(String text) {
        return estimateTokenCountInMessages(singletonList(UserMessage.from(text)));
    }

    @Override
    public int estimateTokenCountInMessage(ChatMessage message) {
        return estimateTokenCountInMessages(singletonList(message));
    }

    @Override
    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
        List<ToolExecutionRequest> allToolRequests = new LinkedList<>();
        toolExecutionRequests.forEach(allToolRequests::add);
        // Wrap the tool execution requests in a single AiMessage and count that.
        return estimateTokenCountInMessage(AiMessage.from(allToolRequests));
    }

    @Override
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        List<ChatMessage> allMessages = new LinkedList<>();
        messages.forEach(allMessages::add);
        List<GeminiContent> geminiContentList = fromMessageToGContent(allMessages, null);

        GeminiCountTokensRequest countTokensRequest = new GeminiCountTokensRequest();
        countTokensRequest.setContents(geminiContentList);

        return estimateTokenCount(countTokensRequest);
    }

    @Override
    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
        List<ToolSpecification> allTools = new LinkedList<>();
        toolSpecifications.forEach(allTools::add);

        GeminiContent dummyContent = GeminiContent.builder().parts(
            singletonList(GeminiPart.builder()
                .text(DUMMY_CONTENT)
                .build())
        ).build();

        GeminiCountTokensRequest countTokensRequestWithDummyContent = new GeminiCountTokensRequest();
        countTokensRequestWithDummyContent.setGenerateContentRequest(GeminiGenerateContentRequest.builder()
            .model("models/" + this.modelName)
            .contents(singletonList(dummyContent))
            .tools(FunctionMapper.fromToolSepcsToGTool(allTools, false))
            .build());

        // The API doesn't allow counting the tokens of tool specifications alone.
        // Instead, a dummy content of known size is added to the request, and its
        // token count is subtracted from the total.
        return estimateTokenCount(countTokensRequestWithDummyContent) - DUMMY_CONTENT_TOKEN_COUNT;
    }

    /**
     * Calls the countTokens endpoint (with retries) and returns the total token count.
     * The retrofit Call is created inside the retried lambda: a Call is single-use,
     * and previously only the Call creation (not the network execution) was retried.
     */
    private int estimateTokenCount(GeminiCountTokensRequest countTokensRequest) {
        GeminiCountTokensResponse countTokensResponse = withRetry(
            () -> executeAndCheck(
                this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest)),
            this.maxRetries);
        return countTokensResponse.getTotalTokens();
    }

    /**
     * Executes the call synchronously, mapping I/O failures, HTTP error statuses
     * and null bodies to RuntimeExceptions.
     */
    private GeminiCountTokensResponse executeAndCheck(Call<GeminiCountTokensResponse> responseCall) {
        retrofit2.Response<GeminiCountTokensResponse> executed;
        try {
            executed = responseCall.execute();
        } catch (IOException e) {
            throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e);
        }
        if (executed.code() >= 300) {
            try (ResponseBody errorBody = executed.errorBody()) {
                GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
                throw new RuntimeException(
                    String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
            } catch (IOException e) {
                throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e);
            }
        }
        GeminiCountTokensResponse countTokensResponse = executed.body();
        if (countTokensResponse == null) {
            // Explicit failure instead of the NPE the previous code threw on a null body.
            throw new RuntimeException("Gemini token counting response was null");
        }
        return countTokensResponse;
    }
}

View File

@ -0,0 +1,113 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for {@link GoogleAiEmbeddingModel}.
 * Requires the GOOGLE_AI_GEMINI_API_KEY environment variable.
 */
public class GoogleAiEmbeddingModelIT {
    private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");

    @Test
    void should_embed_one_text() {
        // given
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
            .apiKey(GOOGLE_AI_GEMINI_API_KEY)
            .modelName("embedding-001")
            .maxRetries(3)
            .logRequestsAndResponses(true)
            .build();

        // when
        Response<Embedding> response = model.embed("Hello world!");

        // then
        Embedding embedding = response.content();
        assertThat(embedding).isNotNull();
        // embedding-001 produces 768-dimensional vectors by default
        assertThat(embedding.vector()).isNotNull();
        assertThat(embedding.vector()).hasSize(768);
    }

    @Test
    void should_use_metadata() {
        // given
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
            .apiKey(GOOGLE_AI_GEMINI_API_KEY)
            .modelName("embedding-001")
            .maxRetries(3)
            .logRequestsAndResponses(true)
            .titleMetadataKey("title")
            .taskType(GoogleAiEmbeddingModel.TaskType.RETRIEVAL_DOCUMENT)
            .build();
        TextSegment segmentWithTitle = TextSegment.from(
            "What is the capital of France?",
            Metadata.from("title", "document title")
        );

        // when
        Response<Embedding> response = model.embed(segmentWithTitle);

        // then
        Embedding embedding = response.content();
        assertThat(embedding).isNotNull();
        assertThat(embedding.vector()).isNotNull();
    }

    @Test
    void should_embed_in_batch() {
        // given
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
            .apiKey(GOOGLE_AI_GEMINI_API_KEY)
            .modelName("embedding-001")
            .maxRetries(3)
            .logRequestsAndResponses(true)
            .outputDimensionality(512)
            .build();
        List<TextSegment> segments = Arrays.asList(
            TextSegment.from("What is the capital of France?"),
            TextSegment.from("What is the capital of Germany?")
        );

        // when
        Response<List<Embedding>> response = model.embedAll(segments);

        // then
        List<Embedding> embeddings = response.content();
        assertThat(embeddings).isNotNull();
        assertThat(embeddings).hasSize(2);
        // the requested output dimensionality (512) must be honored for each vector
        for (Embedding embedding : embeddings) {
            assertThat(embedding.vector()).isNotNull();
            assertThat(embedding.vector()).hasSize(512);
        }
    }

    @Test
    void should_embed_more_than_100() {
        // given
        GoogleAiEmbeddingModel model = GoogleAiEmbeddingModel.builder()
            .apiKey(GOOGLE_AI_GEMINI_API_KEY)
            .modelName("text-embedding-004")
            .maxRetries(3)
            .build();
        // 300 segments force 3 batches of at most 100 requests each
        List<TextSegment> segments = new ArrayList<>();
        for (int counter = 0; counter < 300; counter++) {
            segments.add(TextSegment.from("What is the capital of France? "));
        }

        // when
        Response<List<Embedding>> response = model.embedAll(segments);

        // then
        assertThat(response.content()).hasSize(300);
    }
}

View File

@ -589,10 +589,10 @@ public class GoogleAiGeminiChatModelIT {
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
.build())
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
.build())
// Equivalent to:
// .responseFormat(ResponseFormat.builder()
// .type(JSON)
@ -674,6 +674,35 @@ public class GoogleAiGeminiChatModelIT {
assertThat(chatResponse.aiMessage().hasToolExecutionRequests()).isFalse();
}
@Test
void should_count_tokens() {
// given
GoogleAiGeminiChatModel gemini = GoogleAiGeminiChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
// when
int countedTokens = gemini.estimateTokenCount("What is the capital of France?");
// then
assertThat(countedTokens).isGreaterThan(0);
// when
List<ChatMessage> messageList = Arrays.asList(
SystemMessage.from("You are a helpful geography teacher"),
UserMessage.from("What is the capital of Germany?"),
AiMessage.from("Berlin"),
UserMessage.from("Thank you!"),
AiMessage.from("You're welcome!")
);
int listOfMsgTokenCount = gemini.estimateTokenCount(messageList);
// then
assertThat(listOfMsgTokenCount).isGreaterThan(0);
}
@AfterEach
void afterEach() throws InterruptedException {
String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");

View File

@ -0,0 +1,127 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
/**
 * Integration tests for {@link GoogleAiGeminiTokenizer}.
 * Requires the GOOGLE_AI_GEMINI_API_KEY environment variable.
 */
public class GoogleAiGeminiTokenizerTest {
    private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");

    /** All tests exercise the same model, so tokenizer construction is shared. */
    private static GoogleAiGeminiTokenizer createTokenizer() {
        return GoogleAiGeminiTokenizer.builder()
            .logRequestsAndResponses(true)
            .modelName("gemini-1.5-flash")
            .apiKey(GOOGLE_AI_GEMINI_API_KEY)
            .build();
    }

    @Test
    void should_estimate_token_count_for_text() {
        // given
        GoogleAiGeminiTokenizer tokenizer = createTokenizer();

        // when
        int tokenCount = tokenizer.estimateTokenCountInText("Hello world!");

        // then
        assertThat(tokenCount).isEqualTo(4);
    }

    @Test
    void should_estimate_token_count_for_a_message() {
        // given
        GoogleAiGeminiTokenizer tokenizer = createTokenizer();

        // when
        int tokenCount = tokenizer.estimateTokenCountInMessage(UserMessage.from("Hello World!"));

        // then
        assertThat(tokenCount).isEqualTo(4);
    }

    @Test
    void should_estimate_token_count_for_list_of_messages() {
        // given
        GoogleAiGeminiTokenizer tokenizer = createTokenizer();

        // when
        int tokenCount = tokenizer.estimateTokenCountInMessages(
            Arrays.asList(
                UserMessage.from("Hello World!"),
                AiMessage.from("Hi! How can I help you today?")
            )
        );

        // then
        assertThat(tokenCount).isEqualTo(14);
    }

    @Test
    void should_estimate_token_count_for_tool_exec_reqs() {
        // given
        GoogleAiGeminiTokenizer tokenizer = createTokenizer();

        // when
        int tokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(
            Arrays.asList(
                ToolExecutionRequest.builder()
                    .name("weatherForecast")
                    .arguments("{ \"location\": \"Paris\" }")
                    .build(),
                ToolExecutionRequest.builder()
                    .name("weatherForecast")
                    .arguments("{ \"location\": \"London\" }")
                    .build()
            )
        );

        // then
        assertThat(tokenCount).isEqualTo(29);
    }

    @Test
    void should_estimate_token_count_for_tool_specs() {
        // given
        GoogleAiGeminiTokenizer tokenizer = createTokenizer();

        // when
        int tokenCount = tokenizer.estimateTokenCountInToolSpecifications(
            Arrays.asList(
                ToolSpecification.builder()
                    .name("weatherForecast")
                    .description("Get the weather forecast for a given location on a give date")
                    .addParameter("location", JsonSchemaProperty.STRING, JsonSchemaProperty.description("the location"))
                    .addParameter("date", JsonSchemaProperty.STRING, JsonSchemaProperty.description("the date"))
                    .build(),
                ToolSpecification.builder()
                    .name("convertFahrenheitToCelsius")
                    .description("Convert a temperature in Fahrenheit to Celsius")
                    .addParameter("fahrenheit", JsonSchemaProperty.NUMBER, JsonSchemaProperty.description("the temperature in Fahrenheit"))
                    .build()
            )
        );

        // then
        assertThat(tokenCount).isEqualTo(114);
    }
}

View File

@ -114,7 +114,7 @@
<artifactId>libraries-bom</artifactId>
<scope>import</scope>
<type>pom</type>
<version>26.45.0</version>
<version>26.46.0</version>
</dependency>
</dependencies>
</dependencyManagement>

View File

@ -9,6 +9,7 @@ import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.vertexai.spi.VertexAiGeminiChatModelBuilderFactory;
@ -82,6 +83,7 @@ public class VertexAiGeminiChatModel implements ChatLanguageModel, Closeable {
Integer maxOutputTokens,
Integer topK,
Float topP,
Integer seed,
Integer maxRetries,
String responseMimeType,
Schema responseSchema,
@ -106,6 +108,9 @@ public class VertexAiGeminiChatModel implements ChatLanguageModel, Closeable {
if (topP != null) {
generationConfigBuilder.setTopP(topP);
}
if (seed != null) {
generationConfigBuilder.setSeed(seed);
}
if (responseMimeType != null) {
generationConfigBuilder.setResponseMimeType(responseMimeType);
}

View File

@ -24,7 +24,7 @@
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-aiplatform</artifactId>
<version>3.49.0</version>
<version>3.50.0</version>
<exclusions>
<exclusion>
<groupId>commons-logging</groupId>

View File

@ -132,7 +132,7 @@ public class VertexAiEmbeddingModel extends DimensionAwareEmbeddingModel {
embeddingInstance.setTaskType(taskType);
if (this.taskType.equals(TaskType.RETRIEVAL_DOCUMENT)) {
// Title metadata is used for calculating embeddings for document retrieval
embeddingInstance.setTitle(segment.metadata(titleMetadataKey));
embeddingInstance.setTitle(segment.metadata().getString(titleMetadataKey));
}
}