Google AI Gemini: replace OkHttp and Retrofit with Java 11 HttpClient (#1950)
## Issue Based on #1903 ## Change Replaced OkHttp and Retrofit inside the GeminiService with an implementation using the HttpClient (Java 11). ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change (already existing) - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent
17d8384d75
commit
2ae3983177
|
@ -25,39 +25,6 @@
|
|||
<artifactId>langchain4j-core</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Retrofit REST client -->
|
||||
<dependency>
|
||||
<groupId>com.squareup.retrofit2</groupId>
|
||||
<artifactId>retrofit</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.squareup.retrofit2</groupId>
|
||||
<artifactId>converter-gson</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>okhttp</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>logging-interceptor</artifactId>
|
||||
<version>4.12.0</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<!--
|
||||
<dependency>
|
||||
<groupId>io.reactivex.rxjava3</groupId>
|
||||
<artifactId>rxjava</artifactId>
|
||||
<version>3.1.9</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.squareup.retrofit2</groupId>
|
||||
<artifactId>adapter-rxjava3</artifactId>
|
||||
<version>2.11.0</version>
|
||||
</dependency>
|
||||
-->
|
||||
|
||||
<!-- Lombok for @Data and @Builder -->
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
|
|
|
@ -1,85 +1,96 @@
|
|||
package dev.langchain4j.model.googleai;
|
||||
|
||||
//import io.reactivex.rxjava3.core.Observable;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.logging.HttpLoggingInterceptor;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import org.slf4j.Logger;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.Retrofit;
|
||||
import retrofit2.converter.gson.GsonConverterFactory;
|
||||
import retrofit2.http.Body;
|
||||
import retrofit2.http.POST;
|
||||
import retrofit2.http.Path;
|
||||
import retrofit2.http.Header;
|
||||
import retrofit2.http.Headers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.time.Duration;
|
||||
//import retrofit2.http.Streaming;
|
||||
|
||||
interface GeminiService {
|
||||
String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
|
||||
String API_KEY_HEADER_NAME = "x-goog-api-key";
|
||||
String USER_AGENT = "User-Agent: LangChain4j";
|
||||
class GeminiService {
|
||||
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
|
||||
private static final String API_KEY_HEADER_NAME = "x-goog-api-key";
|
||||
|
||||
static GeminiService getGeminiService(Logger logger, Duration timeout) {
|
||||
Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
|
||||
.baseUrl(GEMINI_AI_ENDPOINT)
|
||||
.addConverterFactory(GsonConverterFactory.create());
|
||||
private final HttpClient httpClient;
|
||||
private final Gson gson;
|
||||
private final Logger logger;
|
||||
|
||||
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
|
||||
.callTimeout(timeout);
|
||||
GeminiService(Logger logger, Duration timeout) {
|
||||
this.logger = logger;
|
||||
this.gson = new GsonBuilder().setPrettyPrinting().create();
|
||||
|
||||
if (logger != null) {
|
||||
HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug);
|
||||
logging.redactHeader(API_KEY_HEADER_NAME);
|
||||
logging.setLevel(HttpLoggingInterceptor.Level.BODY);
|
||||
|
||||
clientBuilder.addInterceptor(logging);
|
||||
}
|
||||
|
||||
retrofitBuilder.client(clientBuilder.build());
|
||||
Retrofit retrofit = retrofitBuilder.build();
|
||||
|
||||
return retrofit.create(GeminiService.class);
|
||||
this.httpClient = HttpClient.newBuilder()
|
||||
.connectTimeout(timeout)
|
||||
.build();
|
||||
}
|
||||
|
||||
@POST("models/{model}:generateContent")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GeminiGenerateContentResponse> generateContent(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GeminiGenerateContentRequest request);
|
||||
GeminiGenerateContentResponse generateContent(String modelName, String apiKey, GeminiGenerateContentRequest request) {
|
||||
String url = String.format("%s/models/%s:generateContent", GEMINI_AI_ENDPOINT, modelName);
|
||||
return sendRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
|
||||
}
|
||||
|
||||
@POST("models/{model}:countTokens")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GeminiCountTokensResponse> countTokens(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GeminiCountTokensRequest countTokensRequest);
|
||||
GeminiCountTokensResponse countTokens(String modelName, String apiKey, GeminiCountTokensRequest request) {
|
||||
String url = String.format("%s/models/%s:countTokens", GEMINI_AI_ENDPOINT, modelName);
|
||||
return sendRequest(url, apiKey, request, GeminiCountTokensResponse.class);
|
||||
}
|
||||
|
||||
@POST("models/{model}:embedContent")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GoogleAiEmbeddingResponse> embed(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GoogleAiEmbeddingRequest embeddingRequest);
|
||||
GoogleAiEmbeddingResponse embed(String modelName, String apiKey, GoogleAiEmbeddingRequest request) {
|
||||
String url = String.format("%s/models/%s:embedContent", GEMINI_AI_ENDPOINT, modelName);
|
||||
return sendRequest(url, apiKey, request, GoogleAiEmbeddingResponse.class);
|
||||
}
|
||||
|
||||
@POST("models/{model}:batchEmbedContents")
|
||||
@Headers(USER_AGENT)
|
||||
Call<GoogleAiBatchEmbeddingResponse> batchEmbed(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest);
|
||||
GoogleAiBatchEmbeddingResponse batchEmbed(String modelName, String apiKey, GoogleAiBatchEmbeddingRequest request) {
|
||||
String url = String.format("%s/models/%s:batchEmbedContents", GEMINI_AI_ENDPOINT, modelName);
|
||||
return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
|
||||
}
|
||||
|
||||
/*
|
||||
@Streaming
|
||||
@POST("models/{model}:streamGenerateContent")
|
||||
@Headers("User-Agent: LangChain4j")
|
||||
Observable<GeminiGenerateContentResponse> streamGenerateContent(
|
||||
@Path("model") String modelName,
|
||||
@Header(API_KEY_HEADER_NAME) String apiKey,
|
||||
@Body GeminiGenerateContentRequest request);
|
||||
*/
|
||||
private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
|
||||
String jsonBody = gson.toJson(requestBody);
|
||||
HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);
|
||||
|
||||
logRequest(jsonBody);
|
||||
|
||||
try {
|
||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
|
||||
if (response.statusCode() >= 300) {
|
||||
throw new RuntimeException(String.format("HTTP error (%d): %s", response.statusCode(), response.body()));
|
||||
}
|
||||
|
||||
logResponse(response.body());
|
||||
|
||||
return gson.fromJson(response.body(), responseType);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("An error occurred while sending the request", e);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException("Sending the request was interrupted", e);
|
||||
}
|
||||
}
|
||||
|
||||
private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
|
||||
return HttpRequest.newBuilder()
|
||||
.uri(URI.create(url))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("User-Agent", "LangChain4j")
|
||||
.header(API_KEY_HEADER_NAME, apiKey)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonBody))
|
||||
.build();
|
||||
}
|
||||
|
||||
private void logRequest(String jsonBody) {
|
||||
if (logger != null) {
|
||||
logger.debug("Sending request to Gemini:\n{}", jsonBody);
|
||||
}
|
||||
}
|
||||
|
||||
private void logResponse(String responseBody) {
|
||||
if (logger != null) {
|
||||
logger.debug("Response from Gemini:\n{}", responseBody);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,10 +8,7 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
|
|||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.ResponseBody;
|
||||
import retrofit2.Call;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
@ -29,8 +26,6 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
|
||||
private final GeminiService geminiService;
|
||||
|
||||
private final Gson GSON = new Gson();
|
||||
|
||||
private final String modelName;
|
||||
private final String apiKey;
|
||||
private final Integer maxRetries;
|
||||
|
@ -40,14 +35,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
|
||||
@Builder
|
||||
public GoogleAiEmbeddingModel(
|
||||
String modelName,
|
||||
String apiKey,
|
||||
Integer maxRetries,
|
||||
TaskType taskType,
|
||||
String titleMetadataKey,
|
||||
Integer outputDimensionality,
|
||||
Duration timeout,
|
||||
Boolean logRequestsAndResponses
|
||||
String modelName,
|
||||
String apiKey,
|
||||
Integer maxRetries,
|
||||
TaskType taskType,
|
||||
String titleMetadataKey,
|
||||
Integer outputDimensionality,
|
||||
Duration timeout,
|
||||
Boolean logRequestsAndResponses
|
||||
) {
|
||||
|
||||
this.modelName = ensureNotBlank(modelName, "modelName");
|
||||
|
@ -64,33 +59,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
|
||||
boolean logRequestsAndResponses1 = logRequestsAndResponses != null && logRequestsAndResponses;
|
||||
|
||||
this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses1 ? log : null, timeout1);
|
||||
this.geminiService = new GeminiService(logRequestsAndResponses1 ? log : null, timeout1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<Embedding> embed(TextSegment textSegment) {
|
||||
GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment);
|
||||
|
||||
Call<GoogleAiEmbeddingResponse> geminiEmbeddingResponseCall =
|
||||
withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);
|
||||
|
||||
GoogleAiEmbeddingResponse geminiResponse;
|
||||
try {
|
||||
retrofit2.Response<GoogleAiEmbeddingResponse> executed = geminiEmbeddingResponseCall.execute();
|
||||
geminiResponse = executed.body();
|
||||
|
||||
if (executed.code() >= 300) {
|
||||
try (ResponseBody errorBody = executed.errorBody()) {
|
||||
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
|
||||
|
||||
throw new RuntimeException(
|
||||
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
|
||||
throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embed).", e);
|
||||
}
|
||||
GoogleAiEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);
|
||||
|
||||
if (geminiResponse != null) {
|
||||
return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues()));
|
||||
|
@ -107,8 +83,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
@Override
|
||||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
|
||||
List<GoogleAiEmbeddingRequest> embeddingRequests = textSegments.stream()
|
||||
.map(this::getGoogleAiEmbeddingRequest)
|
||||
.collect(Collectors.toList());
|
||||
.map(this::getGoogleAiEmbeddingRequest)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<Embedding> allEmbeddings = new ArrayList<>();
|
||||
int numberOfEmbeddings = embeddingRequests.size();
|
||||
|
@ -123,30 +99,12 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest();
|
||||
batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex));
|
||||
|
||||
Call<GoogleAiBatchEmbeddingResponse> geminiBatchEmbeddingResponseCall =
|
||||
withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));
|
||||
|
||||
GoogleAiBatchEmbeddingResponse geminiResponse;
|
||||
try {
|
||||
retrofit2.Response<GoogleAiBatchEmbeddingResponse> executed = geminiBatchEmbeddingResponseCall.execute();
|
||||
geminiResponse = executed.body();
|
||||
|
||||
if (executed.code() >= 300) {
|
||||
try (ResponseBody errorBody = executed.errorBody()) {
|
||||
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
|
||||
|
||||
throw new RuntimeException(
|
||||
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embedAll).", e);
|
||||
}
|
||||
GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));
|
||||
|
||||
if (geminiResponse != null) {
|
||||
allEmbeddings.addAll(geminiResponse.getEmbeddings().stream()
|
||||
.map(values -> Embedding.from(values.getValues()))
|
||||
.collect(Collectors.toList()));
|
||||
.map(values -> Embedding.from(values.getValues()))
|
||||
.collect(Collectors.toList()));
|
||||
} else {
|
||||
throw new RuntimeException("Gemini embedding response was null (embedAll)");
|
||||
}
|
||||
|
@ -157,8 +115,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
|
||||
private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
|
||||
GeminiPart geminiPart = GeminiPart.builder()
|
||||
.text(textSegment.text())
|
||||
.build();
|
||||
.text(textSegment.text())
|
||||
.build();
|
||||
|
||||
GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);
|
||||
|
||||
|
@ -170,11 +128,11 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
|
|||
}
|
||||
|
||||
return new GoogleAiEmbeddingRequest(
|
||||
"models/" + this.modelName,
|
||||
content,
|
||||
this.taskType,
|
||||
title,
|
||||
this.outputDimensionality
|
||||
"models/" + this.modelName,
|
||||
content,
|
||||
this.taskType,
|
||||
title,
|
||||
this.outputDimensionality
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,10 +19,7 @@ import dev.langchain4j.model.output.Response;
|
|||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import lombok.Builder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.ResponseBody;
|
||||
import retrofit2.Call;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -49,8 +46,6 @@ import static java.util.Collections.emptyList;
|
|||
@Experimental
|
||||
@Slf4j
|
||||
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
private static final Gson GSON = new Gson();
|
||||
|
||||
private final GeminiService geminiService;
|
||||
|
||||
private final String apiKey;
|
||||
|
@ -116,18 +111,18 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
|
||||
this.geminiService = GeminiService.getGeminiService(
|
||||
getOrDefault(logRequestsAndResponses, false) ? this.log : null,
|
||||
getOrDefault(timeout, ofSeconds(60))
|
||||
this.geminiService = new GeminiService(
|
||||
getOrDefault(logRequestsAndResponses, false) ? log : null,
|
||||
getOrDefault(timeout, ofSeconds(60))
|
||||
);
|
||||
|
||||
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
|
||||
.modelName(this.modelName)
|
||||
.apiKey(this.apiKey)
|
||||
.timeout(getOrDefault(timeout, ofSeconds(60)))
|
||||
.maxRetries(this.maxRetries)
|
||||
.logRequestsAndResponses(this.logRequestsAndResponses)
|
||||
.build();
|
||||
.modelName(this.modelName)
|
||||
.apiKey(this.apiKey)
|
||||
.timeout(getOrDefault(timeout, ofSeconds(60)))
|
||||
.maxRetries(this.maxRetries)
|
||||
.logRequestsAndResponses(this.logRequestsAndResponses)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static String computeMimeType(ResponseFormat responseFormat) {
|
||||
|
@ -136,9 +131,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
}
|
||||
|
||||
if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
|
||||
responseFormat.jsonSchema() != null &&
|
||||
responseFormat.jsonSchema().rootElement() != null &&
|
||||
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
|
||||
responseFormat.jsonSchema() != null &&
|
||||
responseFormat.jsonSchema().rootElement() != null &&
|
||||
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
|
||||
|
||||
return "text/x.enum";
|
||||
}
|
||||
|
@ -149,14 +144,14 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
ChatRequest request = ChatRequest.builder()
|
||||
.messages(messages)
|
||||
.build();
|
||||
.messages(messages)
|
||||
.build();
|
||||
|
||||
ChatResponse response = chat(request);
|
||||
|
||||
return Response.from(response.aiMessage(),
|
||||
response.tokenUsage(),
|
||||
response.finishReason());
|
||||
response.tokenUsage(),
|
||||
response.finishReason());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -167,15 +162,15 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
|
||||
ChatRequest request = ChatRequest.builder()
|
||||
.messages(messages)
|
||||
.toolSpecifications(toolSpecifications)
|
||||
.build();
|
||||
.messages(messages)
|
||||
.toolSpecifications(toolSpecifications)
|
||||
.build();
|
||||
|
||||
ChatResponse response = chat(request);
|
||||
|
||||
return Response.from(response.aiMessage(),
|
||||
response.tokenUsage(),
|
||||
response.finishReason());
|
||||
response.tokenUsage(),
|
||||
response.finishReason());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -194,31 +189,31 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
}
|
||||
|
||||
GeminiGenerateContentRequest request = GeminiGenerateContentRequest.builder()
|
||||
.contents(geminiContentList)
|
||||
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
|
||||
.generationConfig(GeminiGenerationConfig.builder()
|
||||
.candidateCount(this.candidateCount)
|
||||
.maxOutputTokens(this.maxOutputTokens)
|
||||
.responseMimeType(responseMimeType)
|
||||
.responseSchema(schema)
|
||||
.stopSequences(this.stopSequences)
|
||||
.temperature(this.temperature)
|
||||
.topK(this.topK)
|
||||
.topP(this.topP)
|
||||
.build())
|
||||
.safetySettings(this.safetySettings)
|
||||
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
|
||||
.toolConfig(new GeminiToolConfig(this.toolConfig))
|
||||
.build();
|
||||
.contents(geminiContentList)
|
||||
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
|
||||
.generationConfig(GeminiGenerationConfig.builder()
|
||||
.candidateCount(this.candidateCount)
|
||||
.maxOutputTokens(this.maxOutputTokens)
|
||||
.responseMimeType(responseMimeType)
|
||||
.responseSchema(schema)
|
||||
.stopSequences(this.stopSequences)
|
||||
.temperature(this.temperature)
|
||||
.topK(this.topK)
|
||||
.topP(this.topP)
|
||||
.build())
|
||||
.safetySettings(this.safetySettings)
|
||||
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
|
||||
.toolConfig(new GeminiToolConfig(this.toolConfig))
|
||||
.build();
|
||||
|
||||
ChatModelRequest chatModelRequest = ChatModelRequest.builder()
|
||||
.model(modelName)
|
||||
.temperature(temperature)
|
||||
.topP(topP)
|
||||
.maxTokens(maxOutputTokens)
|
||||
.messages(chatRequest.messages())
|
||||
.toolSpecifications(chatRequest.toolSpecifications())
|
||||
.build();
|
||||
.model(modelName)
|
||||
.temperature(temperature)
|
||||
.topP(topP)
|
||||
.maxTokens(maxOutputTokens)
|
||||
.messages(chatRequest.messages())
|
||||
.toolSpecifications(chatRequest.toolSpecifications())
|
||||
.build();
|
||||
ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
|
||||
ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
|
||||
listeners.forEach((listener) -> {
|
||||
|
@ -229,40 +224,12 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
}
|
||||
});
|
||||
|
||||
Call<GeminiGenerateContentResponse> responseCall =
|
||||
withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries);
|
||||
|
||||
GeminiGenerateContentResponse geminiResponse;
|
||||
try {
|
||||
retrofit2.Response<GeminiGenerateContentResponse> executed = responseCall.execute();
|
||||
geminiResponse = executed.body();
|
||||
|
||||
if (executed.code() >= 300) {
|
||||
try (ResponseBody errorBody = executed.errorBody()) {
|
||||
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
|
||||
|
||||
RuntimeException runtimeException = new RuntimeException(
|
||||
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
|
||||
|
||||
ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(
|
||||
runtimeException, chatModelRequest, null, listenerAttributes
|
||||
);
|
||||
listeners.forEach((listener) -> {
|
||||
try {
|
||||
listener.onError(chatModelErrorContext);
|
||||
} catch (Exception e) {
|
||||
log.warn("Exception while calling model listener (onError)", e);
|
||||
}
|
||||
});
|
||||
|
||||
throw runtimeException;
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
RuntimeException runtimeException = new RuntimeException("An error occurred when calling the Gemini API endpoint.", e);
|
||||
|
||||
geminiResponse = withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries);
|
||||
} catch (RuntimeException e) {
|
||||
ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(
|
||||
e, chatModelRequest, null, listenerAttributes
|
||||
e, chatModelRequest, null, listenerAttributes
|
||||
);
|
||||
listeners.forEach((listener) -> {
|
||||
try {
|
||||
|
@ -272,7 +239,7 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
}
|
||||
});
|
||||
|
||||
throw runtimeException;
|
||||
throw e;
|
||||
}
|
||||
|
||||
if (geminiResponse != null) {
|
||||
|
@ -284,23 +251,23 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
|
||||
if (firstCandidate.getContent() == null) {
|
||||
aiMessage = AiMessage.from("No text was returned by the model. " +
|
||||
"The model finished generating because of the following reason: " + finishReason);
|
||||
"The model finished generating because of the following reason: " + finishReason);
|
||||
} else {
|
||||
aiMessage = fromGPartsToAiMessage(firstCandidate.getContent().getParts(), this.includeCodeExecutionOutput);
|
||||
}
|
||||
|
||||
TokenUsage tokenUsage = new TokenUsage(tokenCounts.getPromptTokenCount(),
|
||||
tokenCounts.getCandidatesTokenCount(),
|
||||
tokenCounts.getTotalTokenCount());
|
||||
tokenCounts.getCandidatesTokenCount(),
|
||||
tokenCounts.getTotalTokenCount());
|
||||
|
||||
ChatModelResponse chatModelResponse = ChatModelResponse.builder()
|
||||
.model(modelName)
|
||||
.tokenUsage(tokenUsage)
|
||||
.finishReason(finishReason)
|
||||
.aiMessage(aiMessage)
|
||||
.build();
|
||||
.model(modelName)
|
||||
.tokenUsage(tokenUsage)
|
||||
.finishReason(finishReason)
|
||||
.aiMessage(aiMessage)
|
||||
.build();
|
||||
ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext(
|
||||
chatModelResponse, chatModelRequest, listenerAttributes);
|
||||
chatModelResponse, chatModelRequest, listenerAttributes);
|
||||
listeners.forEach((listener) -> {
|
||||
try {
|
||||
listener.onResponse(chatModelResponseContext);
|
||||
|
@ -310,10 +277,10 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
});
|
||||
|
||||
return ChatResponse.builder()
|
||||
.aiMessage(aiMessage)
|
||||
.finishReason(finishReason)
|
||||
.tokenUsage(tokenUsage)
|
||||
.build();
|
||||
.aiMessage(aiMessage)
|
||||
.finishReason(finishReason)
|
||||
.tokenUsage(tokenUsage)
|
||||
.build();
|
||||
} else {
|
||||
throw new RuntimeException("Gemini response was null");
|
||||
}
|
||||
|
@ -342,8 +309,8 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
|
|||
|
||||
public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> safetySettingMap) {
|
||||
this.safetySettings = safetySettingMap.entrySet().stream()
|
||||
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
|
||||
).collect(Collectors.toList());
|
||||
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
|
||||
).collect(Collectors.toList());
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,10 +9,7 @@ import dev.langchain4j.data.message.UserMessage;
|
|||
import dev.langchain4j.model.Tokenizer;
|
||||
import lombok.Builder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.ResponseBody;
|
||||
import retrofit2.Call;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
@ -25,8 +22,6 @@ import static java.util.Collections.singletonList;
|
|||
|
||||
@Slf4j
|
||||
public class GoogleAiGeminiTokenizer implements Tokenizer {
|
||||
private static final Gson GSON = new Gson();
|
||||
|
||||
private final GeminiService geminiService;
|
||||
private final String modelName;
|
||||
private final String apiKey;
|
||||
|
@ -34,17 +29,19 @@ public class GoogleAiGeminiTokenizer implements Tokenizer {
|
|||
|
||||
@Builder
|
||||
GoogleAiGeminiTokenizer(
|
||||
String modelName,
|
||||
String apiKey,
|
||||
Boolean logRequestsAndResponses,
|
||||
Duration timeout,
|
||||
Integer maxRetries
|
||||
String modelName,
|
||||
String apiKey,
|
||||
Boolean logRequestsAndResponses,
|
||||
Duration timeout,
|
||||
Integer maxRetries
|
||||
) {
|
||||
this.modelName = ensureNotBlank(modelName, "modelName");
|
||||
this.apiKey = ensureNotBlank(apiKey, "apiKey");
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses ? log : null,
|
||||
timeout != null ? timeout : Duration.ofSeconds(60));
|
||||
this.geminiService = new GeminiService(
|
||||
logRequestsAndResponses ? log : null,
|
||||
timeout != null ? timeout : Duration.ofSeconds(60)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -84,17 +81,17 @@ public class GoogleAiGeminiTokenizer implements Tokenizer {
|
|||
toolSpecifications.forEach(allTools::add);
|
||||
|
||||
GeminiContent dummyContent = GeminiContent.builder().parts(
|
||||
singletonList(GeminiPart.builder()
|
||||
.text("Dummy content") // This string contains 2 tokens
|
||||
.build())
|
||||
singletonList(GeminiPart.builder()
|
||||
.text("Dummy content") // This string contains 2 tokens
|
||||
.build())
|
||||
).build();
|
||||
|
||||
GeminiCountTokensRequest countTokensRequestWithDummyContent = new GeminiCountTokensRequest();
|
||||
countTokensRequestWithDummyContent.setGenerateContentRequest(GeminiGenerateContentRequest.builder()
|
||||
.model("models/" + this.modelName)
|
||||
.contents(singletonList(dummyContent))
|
||||
.tools(FunctionMapper.fromToolSepcsToGTool(allTools, false))
|
||||
.build());
|
||||
.model("models/" + this.modelName)
|
||||
.contents(singletonList(dummyContent))
|
||||
.tools(FunctionMapper.fromToolSepcsToGTool(allTools, false))
|
||||
.build());
|
||||
|
||||
// The API doesn't allow us to make a request to count the tokens of the tool specifications only.
|
||||
// Instead, we take the approach of adding a dummy content in the request, and subtract the tokens for the dummy request.
|
||||
|
@ -103,26 +100,7 @@ public class GoogleAiGeminiTokenizer implements Tokenizer {
|
|||
}
|
||||
|
||||
private int estimateTokenCount(GeminiCountTokensRequest countTokensRequest) {
|
||||
Call<GeminiCountTokensResponse> responseCall =
|
||||
withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries);
|
||||
|
||||
GeminiCountTokensResponse countTokensResponse;
|
||||
try {
|
||||
retrofit2.Response<GeminiCountTokensResponse> executed = responseCall.execute();
|
||||
countTokensResponse = executed.body();
|
||||
|
||||
if (executed.code() >= 300) {
|
||||
try (ResponseBody errorBody = executed.errorBody()) {
|
||||
GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
|
||||
|
||||
throw new RuntimeException(
|
||||
String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e);
|
||||
}
|
||||
|
||||
GeminiCountTokensResponse countTokensResponse = withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries);
|
||||
return countTokensResponse.getTotalTokens();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue