Google AI Gemini: replace OkHttp and Retrofit with Java 11 HttpClient (#1950)

## Issue
Based on #1903

## Change
Replaced OkHttp and Retrofit inside the GeminiService with an
implementation using the HttpClient (Java 11).


## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change (already
existing)
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
Bjarne-Kinkel 2024-10-23 16:29:46 +02:00 committed by GitHub
parent 17d8384d75
commit 2ae3983177
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 181 additions and 300 deletions

View File

@ -25,39 +25,6 @@
<artifactId>langchain4j-core</artifactId>
</dependency>
<!-- Retrofit REST client -->
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-gson</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>logging-interceptor</artifactId>
<version>4.12.0</version>
<scope>compile</scope>
</dependency>
<!--
<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
<version>3.1.9</version>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>adapter-rxjava3</artifactId>
<version>2.11.0</version>
</dependency>
-->
<!-- Lombok for @Data and @Builder -->
<dependency>
<groupId>org.projectlombok</groupId>

View File

@ -1,85 +1,96 @@
package dev.langchain4j.model.googleai;
//import io.reactivex.rxjava3.core.Observable;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.slf4j.Logger;
import retrofit2.Call;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;
import retrofit2.http.Body;
import retrofit2.http.POST;
import retrofit2.http.Path;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
//import retrofit2.http.Streaming;
interface GeminiService {
String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
String API_KEY_HEADER_NAME = "x-goog-api-key";
String USER_AGENT = "User-Agent: LangChain4j";
class GeminiService {
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
private static final String API_KEY_HEADER_NAME = "x-goog-api-key";
static GeminiService getGeminiService(Logger logger, Duration timeout) {
Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
.baseUrl(GEMINI_AI_ENDPOINT)
.addConverterFactory(GsonConverterFactory.create());
private final HttpClient httpClient;
private final Gson gson;
private final Logger logger;
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
.callTimeout(timeout);
GeminiService(Logger logger, Duration timeout) {
this.logger = logger;
this.gson = new GsonBuilder().setPrettyPrinting().create();
if (logger != null) {
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug);
logging.redactHeader(API_KEY_HEADER_NAME);
logging.setLevel(HttpLoggingInterceptor.Level.BODY);
clientBuilder.addInterceptor(logging);
}
retrofitBuilder.client(clientBuilder.build());
Retrofit retrofit = retrofitBuilder.build();
return retrofit.create(GeminiService.class);
this.httpClient = HttpClient.newBuilder()
.connectTimeout(timeout)
.build();
}
@POST("models/{model}:generateContent")
@Headers(USER_AGENT)
Call<GeminiGenerateContentResponse> generateContent(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiGenerateContentRequest request);
GeminiGenerateContentResponse generateContent(String modelName, String apiKey, GeminiGenerateContentRequest request) {
String url = String.format("%s/models/%s:generateContent", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
}
@POST("models/{model}:countTokens")
@Headers(USER_AGENT)
Call<GeminiCountTokensResponse> countTokens(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiCountTokensRequest countTokensRequest);
GeminiCountTokensResponse countTokens(String modelName, String apiKey, GeminiCountTokensRequest request) {
String url = String.format("%s/models/%s:countTokens", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GeminiCountTokensResponse.class);
}
@POST("models/{model}:embedContent")
@Headers(USER_AGENT)
Call<GoogleAiEmbeddingResponse> embed(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GoogleAiEmbeddingRequest embeddingRequest);
GoogleAiEmbeddingResponse embed(String modelName, String apiKey, GoogleAiEmbeddingRequest request) {
String url = String.format("%s/models/%s:embedContent", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GoogleAiEmbeddingResponse.class);
}
@POST("models/{model}:batchEmbedContents")
@Headers(USER_AGENT)
Call<GoogleAiBatchEmbeddingResponse> batchEmbed(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest);
GoogleAiBatchEmbeddingResponse batchEmbed(String modelName, String apiKey, GoogleAiBatchEmbeddingRequest request) {
String url = String.format("%s/models/%s:batchEmbedContents", GEMINI_AI_ENDPOINT, modelName);
return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
}
/*
@Streaming
@POST("models/{model}:streamGenerateContent")
@Headers("User-Agent: LangChain4j")
Observable<GeminiGenerateContentResponse> streamGenerateContent(
@Path("model") String modelName,
@Header(API_KEY_HEADER_NAME) String apiKey,
@Body GeminiGenerateContentRequest request);
*/
private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
String jsonBody = gson.toJson(requestBody);
HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);
logRequest(jsonBody);
try {
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
if (response.statusCode() >= 300) {
throw new RuntimeException(String.format("HTTP error (%d): %s", response.statusCode(), response.body()));
}
logResponse(response.body());
return gson.fromJson(response.body(), responseType);
} catch (IOException e) {
throw new RuntimeException("An error occurred while sending the request", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Sending the request was interrupted", e);
}
}
private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
return HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Content-Type", "application/json")
.header("User-Agent", "LangChain4j")
.header(API_KEY_HEADER_NAME, apiKey)
.POST(HttpRequest.BodyPublishers.ofString(jsonBody))
.build();
}
private void logRequest(String jsonBody) {
if (logger != null) {
logger.debug("Sending request to Gemini:\n{}", jsonBody);
}
}
private void logResponse(String responseBody) {
if (logger != null) {
logger.debug("Response from Gemini:\n{}", responseBody);
}
}
}

View File

@ -8,10 +8,7 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
@ -29,8 +26,6 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
private final GeminiService geminiService;
private final Gson GSON = new Gson();
private final String modelName;
private final String apiKey;
private final Integer maxRetries;
@ -40,14 +35,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
@Builder
public GoogleAiEmbeddingModel(
String modelName,
String apiKey,
Integer maxRetries,
TaskType taskType,
String titleMetadataKey,
Integer outputDimensionality,
Duration timeout,
Boolean logRequestsAndResponses
String modelName,
String apiKey,
Integer maxRetries,
TaskType taskType,
String titleMetadataKey,
Integer outputDimensionality,
Duration timeout,
Boolean logRequestsAndResponses
) {
this.modelName = ensureNotBlank(modelName, "modelName");
@ -64,33 +59,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
boolean logRequestsAndResponses1 = logRequestsAndResponses != null && logRequestsAndResponses;
this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses1 ? log : null, timeout1);
this.geminiService = new GeminiService(logRequestsAndResponses1 ? log : null, timeout1);
}
@Override
public Response<Embedding> embed(TextSegment textSegment) {
GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment);
Call<GoogleAiEmbeddingResponse> geminiEmbeddingResponseCall =
withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);
GoogleAiEmbeddingResponse geminiResponse;
try {
retrofit2.Response<GoogleAiEmbeddingResponse> executed = geminiEmbeddingResponseCall.execute();
geminiResponse = executed.body();
if (executed.code() >= 300) {
try (ResponseBody errorBody = executed.errorBody()) {
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
throw new RuntimeException(
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
}
}
} catch (IOException e) {
throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embed).", e);
}
GoogleAiEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);
if (geminiResponse != null) {
return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues()));
@ -107,8 +83,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<GoogleAiEmbeddingRequest> embeddingRequests = textSegments.stream()
.map(this::getGoogleAiEmbeddingRequest)
.collect(Collectors.toList());
.map(this::getGoogleAiEmbeddingRequest)
.collect(Collectors.toList());
List<Embedding> allEmbeddings = new ArrayList<>();
int numberOfEmbeddings = embeddingRequests.size();
@ -123,30 +99,12 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest();
batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex));
Call<GoogleAiBatchEmbeddingResponse> geminiBatchEmbeddingResponseCall =
withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));
GoogleAiBatchEmbeddingResponse geminiResponse;
try {
retrofit2.Response<GoogleAiBatchEmbeddingResponse> executed = geminiBatchEmbeddingResponseCall.execute();
geminiResponse = executed.body();
if (executed.code() >= 300) {
try (ResponseBody errorBody = executed.errorBody()) {
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
throw new RuntimeException(
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
}
}
} catch (IOException e) {
throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embedAll).", e);
}
GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));
if (geminiResponse != null) {
allEmbeddings.addAll(geminiResponse.getEmbeddings().stream()
.map(values -> Embedding.from(values.getValues()))
.collect(Collectors.toList()));
.map(values -> Embedding.from(values.getValues()))
.collect(Collectors.toList()));
} else {
throw new RuntimeException("Gemini embedding response was null (embedAll)");
}
@ -157,8 +115,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
GeminiPart geminiPart = GeminiPart.builder()
.text(textSegment.text())
.build();
.text(textSegment.text())
.build();
GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);
@ -170,11 +128,11 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
}
return new GoogleAiEmbeddingRequest(
"models/" + this.modelName,
content,
this.taskType,
title,
this.outputDimensionality
"models/" + this.modelName,
content,
this.taskType,
title,
this.outputDimensionality
);
}

View File

@ -19,10 +19,7 @@ import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;
import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
@ -49,8 +46,6 @@ import static java.util.Collections.emptyList;
@Experimental
@Slf4j
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
private static final Gson GSON = new Gson();
private final GeminiService geminiService;
private final String apiKey;
@ -116,18 +111,18 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
this.geminiService = GeminiService.getGeminiService(
getOrDefault(logRequestsAndResponses, false) ? this.log : null,
getOrDefault(timeout, ofSeconds(60))
this.geminiService = new GeminiService(
getOrDefault(logRequestsAndResponses, false) ? log : null,
getOrDefault(timeout, ofSeconds(60))
);
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
.modelName(this.modelName)
.apiKey(this.apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.maxRetries(this.maxRetries)
.logRequestsAndResponses(this.logRequestsAndResponses)
.build();
.modelName(this.modelName)
.apiKey(this.apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.maxRetries(this.maxRetries)
.logRequestsAndResponses(this.logRequestsAndResponses)
.build();
}
private static String computeMimeType(ResponseFormat responseFormat) {
@ -136,9 +131,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
}
if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
responseFormat.jsonSchema() != null &&
responseFormat.jsonSchema().rootElement() != null &&
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
responseFormat.jsonSchema() != null &&
responseFormat.jsonSchema().rootElement() != null &&
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
return "text/x.enum";
}
@ -149,14 +144,14 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
ChatRequest request = ChatRequest.builder()
.messages(messages)
.build();
.messages(messages)
.build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
response.tokenUsage(),
response.finishReason());
response.tokenUsage(),
response.finishReason());
}
@Override
@ -167,15 +162,15 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ChatRequest request = ChatRequest.builder()
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
response.tokenUsage(),
response.finishReason());
response.tokenUsage(),
response.finishReason());
}
@Override
@ -194,31 +189,31 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
}
GeminiGenerateContentRequest request = GeminiGenerateContentRequest.builder()
.contents(geminiContentList)
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
.generationConfig(GeminiGenerationConfig.builder()
.candidateCount(this.candidateCount)
.maxOutputTokens(this.maxOutputTokens)
.responseMimeType(responseMimeType)
.responseSchema(schema)
.stopSequences(this.stopSequences)
.temperature(this.temperature)
.topK(this.topK)
.topP(this.topP)
.build())
.safetySettings(this.safetySettings)
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
.toolConfig(new GeminiToolConfig(this.toolConfig))
.build();
.contents(geminiContentList)
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
.generationConfig(GeminiGenerationConfig.builder()
.candidateCount(this.candidateCount)
.maxOutputTokens(this.maxOutputTokens)
.responseMimeType(responseMimeType)
.responseSchema(schema)
.stopSequences(this.stopSequences)
.temperature(this.temperature)
.topK(this.topK)
.topP(this.topP)
.build())
.safetySettings(this.safetySettings)
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
.toolConfig(new GeminiToolConfig(this.toolConfig))
.build();
ChatModelRequest chatModelRequest = ChatModelRequest.builder()
.model(modelName)
.temperature(temperature)
.topP(topP)
.maxTokens(maxOutputTokens)
.messages(chatRequest.messages())
.toolSpecifications(chatRequest.toolSpecifications())
.build();
.model(modelName)
.temperature(temperature)
.topP(topP)
.maxTokens(maxOutputTokens)
.messages(chatRequest.messages())
.toolSpecifications(chatRequest.toolSpecifications())
.build();
ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
listeners.forEach((listener) -> {
@ -229,40 +224,12 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
}
});
Call<GeminiGenerateContentResponse> responseCall =
withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries);
GeminiGenerateContentResponse geminiResponse;
try {
retrofit2.Response<GeminiGenerateContentResponse> executed = responseCall.execute();
geminiResponse = executed.body();
if (executed.code() >= 300) {
try (ResponseBody errorBody = executed.errorBody()) {
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
RuntimeException runtimeException = new RuntimeException(
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(
runtimeException, chatModelRequest, null, listenerAttributes
);
listeners.forEach((listener) -> {
try {
listener.onError(chatModelErrorContext);
} catch (Exception e) {
log.warn("Exception while calling model listener (onError)", e);
}
});
throw runtimeException;
}
}
} catch (IOException e) {
RuntimeException runtimeException = new RuntimeException("An error occurred when calling the Gemini API endpoint.", e);
geminiResponse = withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries);
} catch (RuntimeException e) {
ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(
e, chatModelRequest, null, listenerAttributes
e, chatModelRequest, null, listenerAttributes
);
listeners.forEach((listener) -> {
try {
@ -272,7 +239,7 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
}
});
throw runtimeException;
throw e;
}
if (geminiResponse != null) {
@ -284,23 +251,23 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
if (firstCandidate.getContent() == null) {
aiMessage = AiMessage.from("No text was returned by the model. " +
"The model finished generating because of the following reason: " + finishReason);
"The model finished generating because of the following reason: " + finishReason);
} else {
aiMessage = fromGPartsToAiMessage(firstCandidate.getContent().getParts(), this.includeCodeExecutionOutput);
}
TokenUsage tokenUsage = new TokenUsage(tokenCounts.getPromptTokenCount(),
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount());
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount());
ChatModelResponse chatModelResponse = ChatModelResponse.builder()
.model(modelName)
.tokenUsage(tokenUsage)
.finishReason(finishReason)
.aiMessage(aiMessage)
.build();
.model(modelName)
.tokenUsage(tokenUsage)
.finishReason(finishReason)
.aiMessage(aiMessage)
.build();
ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext(
chatModelResponse, chatModelRequest, listenerAttributes);
chatModelResponse, chatModelRequest, listenerAttributes);
listeners.forEach((listener) -> {
try {
listener.onResponse(chatModelResponseContext);
@ -310,10 +277,10 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
});
return ChatResponse.builder()
.aiMessage(aiMessage)
.finishReason(finishReason)
.tokenUsage(tokenUsage)
.build();
.aiMessage(aiMessage)
.finishReason(finishReason)
.tokenUsage(tokenUsage)
.build();
} else {
throw new RuntimeException("Gemini response was null");
}
@ -342,8 +309,8 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> safetySettingMap) {
this.safetySettings = safetySettingMap.entrySet().stream()
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
).collect(Collectors.toList());
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
).collect(Collectors.toList());
return this;
}
}

View File

@ -9,10 +9,7 @@ import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.Tokenizer;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import retrofit2.Call;
import java.io.IOException;
import java.time.Duration;
import java.util.LinkedList;
import java.util.List;
@ -25,8 +22,6 @@ import static java.util.Collections.singletonList;
@Slf4j
public class GoogleAiGeminiTokenizer implements Tokenizer {
private static final Gson GSON = new Gson();
private final GeminiService geminiService;
private final String modelName;
private final String apiKey;
@ -34,17 +29,19 @@ public class GoogleAiGeminiTokenizer implements Tokenizer {
@Builder
GoogleAiGeminiTokenizer(
String modelName,
String apiKey,
Boolean logRequestsAndResponses,
Duration timeout,
Integer maxRetries
String modelName,
String apiKey,
Boolean logRequestsAndResponses,
Duration timeout,
Integer maxRetries
) {
this.modelName = ensureNotBlank(modelName, "modelName");
this.apiKey = ensureNotBlank(apiKey, "apiKey");
this.maxRetries = getOrDefault(maxRetries, 3);
this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses ? log : null,
timeout != null ? timeout : Duration.ofSeconds(60));
this.geminiService = new GeminiService(
logRequestsAndResponses ? log : null,
timeout != null ? timeout : Duration.ofSeconds(60)
);
}
@Override
@ -84,17 +81,17 @@ public class GoogleAiGeminiTokenizer implements Tokenizer {
toolSpecifications.forEach(allTools::add);
GeminiContent dummyContent = GeminiContent.builder().parts(
singletonList(GeminiPart.builder()
.text("Dummy content") // This string contains 2 tokens
.build())
singletonList(GeminiPart.builder()
.text("Dummy content") // This string contains 2 tokens
.build())
).build();
GeminiCountTokensRequest countTokensRequestWithDummyContent = new GeminiCountTokensRequest();
countTokensRequestWithDummyContent.setGenerateContentRequest(GeminiGenerateContentRequest.builder()
.model("models/" + this.modelName)
.contents(singletonList(dummyContent))
.tools(FunctionMapper.fromToolSepcsToGTool(allTools, false))
.build());
.model("models/" + this.modelName)
.contents(singletonList(dummyContent))
.tools(FunctionMapper.fromToolSepcsToGTool(allTools, false))
.build());
// The API doesn't allow us to make a request to count the tokens of the tool specifications only.
// Instead, we take the approach of adding a dummy content in the request, and subtract the tokens for the dummy request.
@ -103,26 +100,7 @@ public class GoogleAiGeminiTokenizer implements Tokenizer {
}
private int estimateTokenCount(GeminiCountTokensRequest countTokensRequest) {
Call<GeminiCountTokensResponse> responseCall =
withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries);
GeminiCountTokensResponse countTokensResponse;
try {
retrofit2.Response<GeminiCountTokensResponse> executed = responseCall.execute();
countTokensResponse = executed.body();
if (executed.code() >= 300) {
try (ResponseBody errorBody = executed.errorBody()) {
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
throw new RuntimeException(
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
}
}
} catch (IOException e) {
throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e);
}
GeminiCountTokensResponse countTokensResponse = withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries);
return countTokensResponse.getTotalTokens();
}
}