From 2ae3983177b173d435d4fc3f5a8f8a21475c59c0 Mon Sep 17 00:00:00 2001 From: Bjarne-Kinkel <89375557+Bjarne-Kinkel@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:29:46 +0200 Subject: [PATCH] Google AI Gemini: replace OkHttp and Retrofit with Java 11 HttpClient (#1950) ## Issue Based on #1903 ## Change Replaced OkHttp and Retrofit inside the GeminiService with an implementation using the HttpClient (Java 11). ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change (already existing) - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable) --- langchain4j-google-ai-gemini/pom.xml | 33 ---- .../model/googleai/GeminiService.java | 147 ++++++++-------- .../googleai/GoogleAiEmbeddingModel.java | 86 +++------- .../googleai/GoogleAiGeminiChatModel.java | 159 +++++++----------- .../googleai/GoogleAiGeminiTokenizer.java | 56 ++---- 5 files changed, 181 insertions(+), 300 deletions(-) diff --git a/langchain4j-google-ai-gemini/pom.xml b/langchain4j-google-ai-gemini/pom.xml index c777c17c6..a2b3cd066 100644 --- a/langchain4j-google-ai-gemini/pom.xml +++ b/langchain4j-google-ai-gemini/pom.xml @@ -25,39 +25,6 @@ langchain4j-core - - - com.squareup.retrofit2 - retrofit - - - com.squareup.retrofit2 - converter-gson - - - com.squareup.okhttp3 - okhttp - - - com.squareup.okhttp3 - logging-interceptor - 4.12.0 - compile - - - - org.projectlombok diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java index 4e6292e53..51d187d6d 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java @@ -1,85 +1,96 @@ package dev.langchain4j.model.googleai; -//import io.reactivex.rxjava3.core.Observable; -import okhttp3.OkHttpClient; -import okhttp3.logging.HttpLoggingInterceptor; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; import org.slf4j.Logger; -import retrofit2.Call; -import retrofit2.Retrofit; -import retrofit2.converter.gson.GsonConverterFactory; -import retrofit2.http.Body; -import retrofit2.http.POST; -import retrofit2.http.Path; -import retrofit2.http.Header; -import retrofit2.http.Headers; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; -//import retrofit2.http.Streaming; -interface GeminiService { - String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/"; - String API_KEY_HEADER_NAME = "x-goog-api-key"; - String USER_AGENT = "User-Agent: LangChain4j"; +class GeminiService { + private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta"; + private static final String API_KEY_HEADER_NAME = "x-goog-api-key"; - static GeminiService getGeminiService(Logger logger, Duration timeout) { - Retrofit.Builder retrofitBuilder = new Retrofit.Builder() - .baseUrl(GEMINI_AI_ENDPOINT) - .addConverterFactory(GsonConverterFactory.create()); + private final HttpClient httpClient; + private final Gson gson; + private final Logger logger; - OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder() - .callTimeout(timeout); + GeminiService(Logger logger, Duration timeout) { + this.logger = logger; + this.gson = new GsonBuilder().setPrettyPrinting().create(); - if (logger != null) { - HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug); - logging.redactHeader(API_KEY_HEADER_NAME); - logging.setLevel(HttpLoggingInterceptor.Level.BODY); - - clientBuilder.addInterceptor(logging); - } - - retrofitBuilder.client(clientBuilder.build()); - Retrofit retrofit = retrofitBuilder.build(); - - return retrofit.create(GeminiService.class); + this.httpClient = HttpClient.newBuilder() + .connectTimeout(timeout) + .build(); } - @POST("models/{model}:generateContent") - @Headers(USER_AGENT) - Call generateContent( - @Path("model") String modelName, - @Header(API_KEY_HEADER_NAME) String apiKey, - @Body GeminiGenerateContentRequest request); + GeminiGenerateContentResponse generateContent(String modelName, String apiKey, GeminiGenerateContentRequest request) { + String url = String.format("%s/models/%s:generateContent", GEMINI_AI_ENDPOINT, modelName); + return sendRequest(url, apiKey, request, GeminiGenerateContentResponse.class); + } - @POST("models/{model}:countTokens") - @Headers(USER_AGENT) - Call countTokens( - @Path("model") String modelName, - @Header(API_KEY_HEADER_NAME) String apiKey, - @Body GeminiCountTokensRequest countTokensRequest); + GeminiCountTokensResponse countTokens(String modelName, String apiKey, GeminiCountTokensRequest request) { + String url = String.format("%s/models/%s:countTokens", GEMINI_AI_ENDPOINT, modelName); + return sendRequest(url, apiKey, request, GeminiCountTokensResponse.class); + } - @POST("models/{model}:embedContent") - @Headers(USER_AGENT) - Call embed( - @Path("model") String modelName, - @Header(API_KEY_HEADER_NAME) String apiKey, - @Body GoogleAiEmbeddingRequest embeddingRequest); + GoogleAiEmbeddingResponse embed(String modelName, String apiKey, GoogleAiEmbeddingRequest request) { + String url = String.format("%s/models/%s:embedContent", GEMINI_AI_ENDPOINT, modelName); + return sendRequest(url, apiKey, request, GoogleAiEmbeddingResponse.class); + } - @POST("models/{model}:batchEmbedContents") - @Headers(USER_AGENT) - Call batchEmbed( - @Path("model") String modelName, - @Header(API_KEY_HEADER_NAME) String apiKey, - @Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest); + GoogleAiBatchEmbeddingResponse batchEmbed(String modelName, String apiKey, GoogleAiBatchEmbeddingRequest request) { + String url = String.format("%s/models/%s:batchEmbedContents", GEMINI_AI_ENDPOINT, modelName); + return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class); + } -/* - @Streaming - @POST("models/{model}:streamGenerateContent") - @Headers("User-Agent: LangChain4j") - Observable streamGenerateContent( - @Path("model") String modelName, - @Header(API_KEY_HEADER_NAME) String apiKey, - @Body GeminiGenerateContentRequest request); -*/ + private T sendRequest(String url, String apiKey, Object requestBody, Class responseType) { + String jsonBody = gson.toJson(requestBody); + HttpRequest request = buildHttpRequest(url, apiKey, jsonBody); + logRequest(jsonBody); -} + try { + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + if (response.statusCode() >= 300) { + throw new RuntimeException(String.format("HTTP error (%d): %s", response.statusCode(), response.body())); + } + + logResponse(response.body()); + + return gson.fromJson(response.body(), responseType); + } catch (IOException e) { + throw new RuntimeException("An error occurred while sending the request", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Sending the request was interrupted", e); + } + } + + private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) { + return HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Content-Type", "application/json") + .header("User-Agent", "LangChain4j") + .header(API_KEY_HEADER_NAME, apiKey) + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) + .build(); + } + + private void logRequest(String jsonBody) { + if (logger != null) { + logger.debug("Sending request to Gemini:\n{}", jsonBody); + } + } + + private void logResponse(String responseBody) { + if (logger != null) { + logger.debug("Response from Gemini:\n{}", responseBody); + } + } +} \ No newline at end of file diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java index 29b8eccfd..51466550c 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java @@ -8,10 +8,7 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; import lombok.Builder; import lombok.extern.slf4j.Slf4j; -import okhttp3.ResponseBody; -import retrofit2.Call; -import java.io.IOException; import java.time.Duration; import java.util.ArrayList; import java.util.Collections; @@ -29,8 +26,6 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { private final GeminiService geminiService; - private final Gson GSON = new Gson(); - private final String modelName; private final String apiKey; private final Integer maxRetries; @@ -40,14 +35,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { @Builder public GoogleAiEmbeddingModel( - String modelName, - String apiKey, - Integer maxRetries, - TaskType taskType, - String titleMetadataKey, - Integer outputDimensionality, - Duration timeout, - Boolean logRequestsAndResponses + String modelName, + String apiKey, + Integer maxRetries, + TaskType taskType, + String titleMetadataKey, + Integer outputDimensionality, + Duration timeout, + Boolean logRequestsAndResponses ) { this.modelName = ensureNotBlank(modelName, "modelName"); @@ -64,33 +59,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { boolean logRequestsAndResponses1 = logRequestsAndResponses != null && logRequestsAndResponses; - this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses1 ? log : null, timeout1); + this.geminiService = new GeminiService(logRequestsAndResponses1 ? log : null, timeout1); } @Override public Response embed(TextSegment textSegment) { GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment); - Call geminiEmbeddingResponseCall = - withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries); - - GoogleAiEmbeddingResponse geminiResponse; - try { - retrofit2.Response executed = geminiEmbeddingResponseCall.execute(); - geminiResponse = executed.body(); - - if (executed.code() >= 300) { - try (ResponseBody errorBody = executed.errorBody()) { - GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError(); - - throw new RuntimeException( - String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage())); - } - } - } catch (IOException e) { - - throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embed).", e); - } + GoogleAiEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries); if (geminiResponse != null) { return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues())); @@ -107,8 +83,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { @Override public Response> embedAll(List textSegments) { List embeddingRequests = textSegments.stream() - .map(this::getGoogleAiEmbeddingRequest) - .collect(Collectors.toList()); + .map(this::getGoogleAiEmbeddingRequest) + .collect(Collectors.toList()); List allEmbeddings = new ArrayList<>(); int numberOfEmbeddings = embeddingRequests.size(); @@ -123,30 +99,12 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest(); batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex)); - Call geminiBatchEmbeddingResponseCall = - withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest)); - - GoogleAiBatchEmbeddingResponse geminiResponse; - try { - retrofit2.Response executed = geminiBatchEmbeddingResponseCall.execute(); - geminiResponse = executed.body(); - - if (executed.code() >= 300) { - try (ResponseBody errorBody = executed.errorBody()) { - GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError(); - - throw new RuntimeException( - String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage())); - } - } - } catch (IOException e) { - throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embedAll).", e); - } + GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest)); if (geminiResponse != null) { allEmbeddings.addAll(geminiResponse.getEmbeddings().stream() - .map(values -> Embedding.from(values.getValues())) - .collect(Collectors.toList())); + .map(values -> Embedding.from(values.getValues())) + .collect(Collectors.toList())); } else { throw new RuntimeException("Gemini embedding response was null (embedAll)"); } @@ -157,8 +115,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) { GeminiPart geminiPart = GeminiPart.builder() - .text(textSegment.text()) - .build(); + .text(textSegment.text()) + .build(); GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null); @@ -170,11 +128,11 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel { } return new GoogleAiEmbeddingRequest( - "models/" + this.modelName, - content, - this.taskType, - title, - this.outputDimensionality + "models/" + this.modelName, + content, + this.taskType, + title, + this.outputDimensionality ); } diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java index bebe9c896..253464abf 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java @@ -19,10 +19,7 @@ import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import lombok.Builder; import lombok.extern.slf4j.Slf4j; -import okhttp3.ResponseBody; -import retrofit2.Call; -import java.io.IOException; import java.time.Duration; import java.util.Arrays; import java.util.Collections; @@ -49,8 +46,6 @@ import static java.util.Collections.emptyList; @Experimental @Slf4j public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator { - private static final Gson GSON = new Gson(); - private final GeminiService geminiService; private final String apiKey; @@ -116,18 +111,18 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); - this.geminiService = GeminiService.getGeminiService( - getOrDefault(logRequestsAndResponses, false) ? this.log : null, - getOrDefault(timeout, ofSeconds(60)) + this.geminiService = new GeminiService( + getOrDefault(logRequestsAndResponses, false) ? log : null, + getOrDefault(timeout, ofSeconds(60)) ); this.geminiTokenizer = GoogleAiGeminiTokenizer.builder() - .modelName(this.modelName) - .apiKey(this.apiKey) - .timeout(getOrDefault(timeout, ofSeconds(60))) - .maxRetries(this.maxRetries) - .logRequestsAndResponses(this.logRequestsAndResponses) - .build(); + .modelName(this.modelName) + .apiKey(this.apiKey) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .maxRetries(this.maxRetries) + .logRequestsAndResponses(this.logRequestsAndResponses) + .build(); } private static String computeMimeType(ResponseFormat responseFormat) { @@ -136,9 +131,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst } if (ResponseFormatType.JSON.equals(responseFormat.type()) && - responseFormat.jsonSchema() != null && - responseFormat.jsonSchema().rootElement() != null && - responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) { + responseFormat.jsonSchema() != null && + responseFormat.jsonSchema().rootElement() != null && + responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) { return "text/x.enum"; } @@ -149,14 +144,14 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst @Override public Response generate(List messages) { ChatRequest request = ChatRequest.builder() - .messages(messages) - .build(); + .messages(messages) + .build(); ChatResponse response = chat(request); return Response.from(response.aiMessage(), - response.tokenUsage(), - response.finishReason()); + response.tokenUsage(), + response.finishReason()); } @Override @@ -167,15 +162,15 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst @Override public Response generate(List messages, List toolSpecifications) { ChatRequest request = ChatRequest.builder() - .messages(messages) - .toolSpecifications(toolSpecifications) - .build(); + .messages(messages) + .toolSpecifications(toolSpecifications) + .build(); ChatResponse response = chat(request); return Response.from(response.aiMessage(), - response.tokenUsage(), - response.finishReason()); + response.tokenUsage(), + response.finishReason()); } @Override @@ -194,31 +189,31 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst } GeminiGenerateContentRequest request = GeminiGenerateContentRequest.builder() - .contents(geminiContentList) - .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null) - .generationConfig(GeminiGenerationConfig.builder() - .candidateCount(this.candidateCount) - .maxOutputTokens(this.maxOutputTokens) - .responseMimeType(responseMimeType) - .responseSchema(schema) - .stopSequences(this.stopSequences) - .temperature(this.temperature) - .topK(this.topK) - .topP(this.topP) - .build()) - .safetySettings(this.safetySettings) - .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution)) - .toolConfig(new GeminiToolConfig(this.toolConfig)) - .build(); + .contents(geminiContentList) + .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null) + .generationConfig(GeminiGenerationConfig.builder() + .candidateCount(this.candidateCount) + .maxOutputTokens(this.maxOutputTokens) + .responseMimeType(responseMimeType) + .responseSchema(schema) + .stopSequences(this.stopSequences) + .temperature(this.temperature) + .topK(this.topK) + .topP(this.topP) + .build()) + .safetySettings(this.safetySettings) + .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution)) + .toolConfig(new GeminiToolConfig(this.toolConfig)) + .build(); ChatModelRequest chatModelRequest = ChatModelRequest.builder() - .model(modelName) - .temperature(temperature) - .topP(topP) - .maxTokens(maxOutputTokens) - .messages(chatRequest.messages()) - .toolSpecifications(chatRequest.toolSpecifications()) - .build(); + .model(modelName) + .temperature(temperature) + .topP(topP) + .maxTokens(maxOutputTokens) + .messages(chatRequest.messages()) + .toolSpecifications(chatRequest.toolSpecifications()) + .build(); ConcurrentHashMap listenerAttributes = new ConcurrentHashMap<>(); ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes); listeners.forEach((listener) -> { @@ -229,40 +224,12 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst } }); - Call responseCall = - withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries); - GeminiGenerateContentResponse geminiResponse; try { - retrofit2.Response executed = responseCall.execute(); - geminiResponse = executed.body(); - - if (executed.code() >= 300) { - try (ResponseBody errorBody = executed.errorBody()) { - GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError(); - - RuntimeException runtimeException = new RuntimeException( - String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage())); - - ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext( - runtimeException, chatModelRequest, null, listenerAttributes - ); - listeners.forEach((listener) -> { - try { - listener.onError(chatModelErrorContext); - } catch (Exception e) { - log.warn("Exception while calling model listener (onError)", e); - } - }); - - throw runtimeException; - } - } - } catch (IOException e) { - RuntimeException runtimeException = new RuntimeException("An error occurred when calling the Gemini API endpoint.", e); - + geminiResponse = withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries); + } catch (RuntimeException e) { ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext( - e, chatModelRequest, null, listenerAttributes + e, chatModelRequest, null, listenerAttributes ); listeners.forEach((listener) -> { try { @@ -272,7 +239,7 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst } }); - throw runtimeException; + throw e; } if (geminiResponse != null) { @@ -284,23 +251,23 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason()); if (firstCandidate.getContent() == null) { aiMessage = AiMessage.from("No text was returned by the model. " + - "The model finished generating because of the following reason: " + finishReason); + "The model finished generating because of the following reason: " + finishReason); } else { aiMessage = fromGPartsToAiMessage(firstCandidate.getContent().getParts(), this.includeCodeExecutionOutput); } TokenUsage tokenUsage = new TokenUsage(tokenCounts.getPromptTokenCount(), - tokenCounts.getCandidatesTokenCount(), - tokenCounts.getTotalTokenCount()); + tokenCounts.getCandidatesTokenCount(), + tokenCounts.getTotalTokenCount()); ChatModelResponse chatModelResponse = ChatModelResponse.builder() - .model(modelName) - .tokenUsage(tokenUsage) - .finishReason(finishReason) - .aiMessage(aiMessage) - .build(); + .model(modelName) + .tokenUsage(tokenUsage) + .finishReason(finishReason) + .aiMessage(aiMessage) + .build(); ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext( - chatModelResponse, chatModelRequest, listenerAttributes); + chatModelResponse, chatModelRequest, listenerAttributes); listeners.forEach((listener) -> { try { listener.onResponse(chatModelResponseContext); @@ -310,10 +277,10 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst }); return ChatResponse.builder() - .aiMessage(aiMessage) - .finishReason(finishReason) - .tokenUsage(tokenUsage) - .build(); + .aiMessage(aiMessage) + .finishReason(finishReason) + .tokenUsage(tokenUsage) + .build(); } else { throw new RuntimeException("Gemini response was null"); } @@ -342,8 +309,8 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst public GoogleAiGeminiChatModelBuilder safetySettings(Map safetySettingMap) { this.safetySettings = safetySettingMap.entrySet().stream() - .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue()) - ).collect(Collectors.toList()); + .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue()) + ).collect(Collectors.toList()); return this; } } diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiTokenizer.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiTokenizer.java index 4955754bd..b82e9a960 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiTokenizer.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiTokenizer.java @@ -9,10 +9,7 @@ import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.Tokenizer; import lombok.Builder; import lombok.extern.slf4j.Slf4j; -import okhttp3.ResponseBody; -import retrofit2.Call; -import java.io.IOException; import java.time.Duration; import java.util.LinkedList; import java.util.List; @@ -25,8 +22,6 @@ import static java.util.Collections.singletonList; @Slf4j public class GoogleAiGeminiTokenizer implements Tokenizer { - private static final Gson GSON = new Gson(); - private final GeminiService geminiService; private final String modelName; private final String apiKey; @@ -34,17 +29,19 @@ public class GoogleAiGeminiTokenizer implements Tokenizer { @Builder GoogleAiGeminiTokenizer( - String modelName, - String apiKey, - Boolean logRequestsAndResponses, - Duration timeout, - Integer maxRetries + String modelName, + String apiKey, + Boolean logRequestsAndResponses, + Duration timeout, + Integer maxRetries ) { this.modelName = ensureNotBlank(modelName, "modelName"); this.apiKey = ensureNotBlank(apiKey, "apiKey"); this.maxRetries = getOrDefault(maxRetries, 3); - this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses ? log : null, - timeout != null ? timeout : Duration.ofSeconds(60)); + this.geminiService = new GeminiService( + logRequestsAndResponses ? log : null, + timeout != null ? timeout : Duration.ofSeconds(60) + ); } @Override @@ -84,17 +81,17 @@ public class GoogleAiGeminiTokenizer implements Tokenizer { toolSpecifications.forEach(allTools::add); GeminiContent dummyContent = GeminiContent.builder().parts( - singletonList(GeminiPart.builder() - .text("Dummy content") // This string contains 2 tokens - .build()) + singletonList(GeminiPart.builder() + .text("Dummy content") // This string contains 2 tokens + .build()) ).build(); GeminiCountTokensRequest countTokensRequestWithDummyContent = new GeminiCountTokensRequest(); countTokensRequestWithDummyContent.setGenerateContentRequest(GeminiGenerateContentRequest.builder() - .model("models/" + this.modelName) - .contents(singletonList(dummyContent)) - .tools(FunctionMapper.fromToolSepcsToGTool(allTools, false)) - .build()); + .model("models/" + this.modelName) + .contents(singletonList(dummyContent)) + .tools(FunctionMapper.fromToolSepcsToGTool(allTools, false)) + .build()); // The API doesn't allow us to make a request to count the tokens of the tool specifications only. // Instead, we take the approach of adding a dummy content in the request, and subtract the tokens for the dummy request. @@ -103,26 +100,7 @@ public class GoogleAiGeminiTokenizer implements Tokenizer { } private int estimateTokenCount(GeminiCountTokensRequest countTokensRequest) { - Call responseCall = - withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries); - - GeminiCountTokensResponse countTokensResponse; - try { - retrofit2.Response executed = responseCall.execute(); - countTokensResponse = executed.body(); - - if (executed.code() >= 300) { - try (ResponseBody errorBody = executed.errorBody()) { - GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError(); - - throw new RuntimeException( - String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage())); - } - } - } catch (IOException e) { - throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e); - } - + GeminiCountTokensResponse countTokensResponse = withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), this.maxRetries); return countTokensResponse.getTotalTokens(); } }