diff --git a/langchain4j-google-ai-gemini/pom.xml b/langchain4j-google-ai-gemini/pom.xml
index c777c17c6..a2b3cd066 100644
--- a/langchain4j-google-ai-gemini/pom.xml
+++ b/langchain4j-google-ai-gemini/pom.xml
@@ -25,39 +25,6 @@
<artifactId>langchain4j-core</artifactId>
</dependency>
-
- <dependency>
- <groupId>com.squareup.retrofit2</groupId>
- <artifactId>retrofit</artifactId>
- </dependency>
- <dependency>
- <groupId>com.squareup.retrofit2</groupId>
- <artifactId>converter-gson</artifactId>
- </dependency>
- <dependency>
- <groupId>com.squareup.okhttp3</groupId>
- <artifactId>okhttp</artifactId>
- </dependency>
- <dependency>
- <groupId>com.squareup.okhttp3</groupId>
- <artifactId>logging-interceptor</artifactId>
- <version>4.12.0</version>
- <scope>compile</scope>
- </dependency>
-
<dependency>
<groupId>org.projectlombok</groupId>
diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java
index 4e6292e53..51d187d6d 100644
--- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java
+++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java
@@ -1,85 +1,96 @@
package dev.langchain4j.model.googleai;
-//import io.reactivex.rxjava3.core.Observable;
-import okhttp3.OkHttpClient;
-import okhttp3.logging.HttpLoggingInterceptor;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
import org.slf4j.Logger;
-import retrofit2.Call;
-import retrofit2.Retrofit;
-import retrofit2.converter.gson.GsonConverterFactory;
-import retrofit2.http.Body;
-import retrofit2.http.POST;
-import retrofit2.http.Path;
-import retrofit2.http.Header;
-import retrofit2.http.Headers;
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
import java.time.Duration;
-//import retrofit2.http.Streaming;
-interface GeminiService {
- String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/";
- String API_KEY_HEADER_NAME = "x-goog-api-key";
- String USER_AGENT = "User-Agent: LangChain4j";
+class GeminiService {
+ private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
+ private static final String API_KEY_HEADER_NAME = "x-goog-api-key";
- static GeminiService getGeminiService(Logger logger, Duration timeout) {
- Retrofit.Builder retrofitBuilder = new Retrofit.Builder()
- .baseUrl(GEMINI_AI_ENDPOINT)
- .addConverterFactory(GsonConverterFactory.create());
+ private final HttpClient httpClient;
+ private final Gson gson;
+ private final Logger logger;
- OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder()
- .callTimeout(timeout);
+ GeminiService(Logger logger, Duration timeout) {
+ this.logger = logger;
+ this.gson = new GsonBuilder().setPrettyPrinting().create();
- if (logger != null) {
- HttpLoggingInterceptor logging = new HttpLoggingInterceptor(logger::debug);
- logging.redactHeader(API_KEY_HEADER_NAME);
- logging.setLevel(HttpLoggingInterceptor.Level.BODY);
-
- clientBuilder.addInterceptor(logging);
- }
-
- retrofitBuilder.client(clientBuilder.build());
- Retrofit retrofit = retrofitBuilder.build();
-
- return retrofit.create(GeminiService.class);
+ this.httpClient = HttpClient.newBuilder()
+ .connectTimeout(timeout)
+ .build();
}
- @POST("models/{model}:generateContent")
- @Headers(USER_AGENT)
- Call<GeminiGenerateContentResponse> generateContent(
- @Path("model") String modelName,
- @Header(API_KEY_HEADER_NAME) String apiKey,
- @Body GeminiGenerateContentRequest request);
+ GeminiGenerateContentResponse generateContent(String modelName, String apiKey, GeminiGenerateContentRequest request) {
+ String url = String.format("%s/models/%s:generateContent", GEMINI_AI_ENDPOINT, modelName);
+ return sendRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
+ }
- @POST("models/{model}:countTokens")
- @Headers(USER_AGENT)
- Call<GeminiCountTokensResponse> countTokens(
- @Path("model") String modelName,
- @Header(API_KEY_HEADER_NAME) String apiKey,
- @Body GeminiCountTokensRequest countTokensRequest);
+ GeminiCountTokensResponse countTokens(String modelName, String apiKey, GeminiCountTokensRequest request) {
+ String url = String.format("%s/models/%s:countTokens", GEMINI_AI_ENDPOINT, modelName);
+ return sendRequest(url, apiKey, request, GeminiCountTokensResponse.class);
+ }
- @POST("models/{model}:embedContent")
- @Headers(USER_AGENT)
- Call<GoogleAiEmbeddingResponse> embed(
- @Path("model") String modelName,
- @Header(API_KEY_HEADER_NAME) String apiKey,
- @Body GoogleAiEmbeddingRequest embeddingRequest);
+ GoogleAiEmbeddingResponse embed(String modelName, String apiKey, GoogleAiEmbeddingRequest request) {
+ String url = String.format("%s/models/%s:embedContent", GEMINI_AI_ENDPOINT, modelName);
+ return sendRequest(url, apiKey, request, GoogleAiEmbeddingResponse.class);
+ }
- @POST("models/{model}:batchEmbedContents")
- @Headers(USER_AGENT)
- Call<GoogleAiBatchEmbeddingResponse> batchEmbed(
- @Path("model") String modelName,
- @Header(API_KEY_HEADER_NAME) String apiKey,
- @Body GoogleAiBatchEmbeddingRequest batchEmbeddingRequest);
+ GoogleAiBatchEmbeddingResponse batchEmbed(String modelName, String apiKey, GoogleAiBatchEmbeddingRequest request) {
+ String url = String.format("%s/models/%s:batchEmbedContents", GEMINI_AI_ENDPOINT, modelName);
+ return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
+ }
-/*
- @Streaming
- @POST("models/{model}:streamGenerateContent")
- @Headers("User-Agent: LangChain4j")
- Observable<GeminiGenerateContentResponse> streamGenerateContent(
- @Path("model") String modelName,
- @Header(API_KEY_HEADER_NAME) String apiKey,
- @Body GeminiGenerateContentRequest request);
-*/
+ private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
+ String jsonBody = gson.toJson(requestBody);
+ HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);
+ logRequest(jsonBody);
-}
+ try {
+ HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
+
+ if (response.statusCode() >= 300) {
+ throw new RuntimeException(String.format("HTTP error (%d): %s", response.statusCode(), response.body()));
+ }
+
+ logResponse(response.body());
+
+ return gson.fromJson(response.body(), responseType);
+ } catch (IOException e) {
+ throw new RuntimeException("An error occurred while sending the request", e);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("Sending the request was interrupted", e);
+ }
+ }
+
+ private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
+ return HttpRequest.newBuilder()
+ .uri(URI.create(url))
+ .header("Content-Type", "application/json")
+ .header("User-Agent", "LangChain4j")
+ .header(API_KEY_HEADER_NAME, apiKey)
+ .POST(HttpRequest.BodyPublishers.ofString(jsonBody))
+ .build();
+ }
+
+ private void logRequest(String jsonBody) {
+ if (logger != null) {
+ logger.debug("Sending request to Gemini:\n{}", jsonBody);
+ }
+ }
+
+ private void logResponse(String responseBody) {
+ if (logger != null) {
+ logger.debug("Response from Gemini:\n{}", responseBody);
+ }
+ }
+}
\ No newline at end of file
diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java
index 29b8eccfd..51466550c 100644
--- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java
+++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiEmbeddingModel.java
@@ -8,10 +8,7 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
-import okhttp3.ResponseBody;
-import retrofit2.Call;
-import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
@@ -29,8 +26,6 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
private final GeminiService geminiService;
- private final Gson GSON = new Gson();
-
private final String modelName;
private final String apiKey;
private final Integer maxRetries;
@@ -40,14 +35,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
@Builder
public GoogleAiEmbeddingModel(
- String modelName,
- String apiKey,
- Integer maxRetries,
- TaskType taskType,
- String titleMetadataKey,
- Integer outputDimensionality,
- Duration timeout,
- Boolean logRequestsAndResponses
+ String modelName,
+ String apiKey,
+ Integer maxRetries,
+ TaskType taskType,
+ String titleMetadataKey,
+ Integer outputDimensionality,
+ Duration timeout,
+ Boolean logRequestsAndResponses
) {
this.modelName = ensureNotBlank(modelName, "modelName");
@@ -64,33 +59,14 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
boolean logRequestsAndResponses1 = logRequestsAndResponses != null && logRequestsAndResponses;
- this.geminiService = GeminiService.getGeminiService(logRequestsAndResponses1 ? log : null, timeout1);
+ this.geminiService = new GeminiService(logRequestsAndResponses1 ? log : null, timeout1);
}
@Override
public Response<Embedding> embed(TextSegment textSegment) {
GoogleAiEmbeddingRequest embeddingRequest = getGoogleAiEmbeddingRequest(textSegment);
- Call<GoogleAiEmbeddingResponse> geminiEmbeddingResponseCall =
- withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);
-
- GoogleAiEmbeddingResponse geminiResponse;
- try {
- retrofit2.Response<GoogleAiEmbeddingResponse> executed = geminiEmbeddingResponseCall.execute();
- geminiResponse = executed.body();
-
- if (executed.code() >= 300) {
- try (ResponseBody errorBody = executed.errorBody()) {
- GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
-
- throw new RuntimeException(
- String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
- }
- }
- } catch (IOException e) {
-
- throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embed).", e);
- }
+ GoogleAiEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.embed(this.modelName, this.apiKey, embeddingRequest), this.maxRetries);
if (geminiResponse != null) {
return Response.from(Embedding.from(geminiResponse.getEmbedding().getValues()));
@@ -107,8 +83,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<GoogleAiEmbeddingRequest> embeddingRequests = textSegments.stream()
- .map(this::getGoogleAiEmbeddingRequest)
- .collect(Collectors.toList());
+ .map(this::getGoogleAiEmbeddingRequest)
+ .collect(Collectors.toList());
List<Embedding> allEmbeddings = new ArrayList<>();
int numberOfEmbeddings = embeddingRequests.size();
@@ -123,30 +99,12 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
GoogleAiBatchEmbeddingRequest batchEmbeddingRequest = new GoogleAiBatchEmbeddingRequest();
batchEmbeddingRequest.setRequests(embeddingRequests.subList(startIndex, lastIndex));
- Call<GoogleAiBatchEmbeddingResponse> geminiBatchEmbeddingResponseCall =
- withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));
-
- GoogleAiBatchEmbeddingResponse geminiResponse;
- try {
- retrofit2.Response<GoogleAiBatchEmbeddingResponse> executed = geminiBatchEmbeddingResponseCall.execute();
- geminiResponse = executed.body();
-
- if (executed.code() >= 300) {
- try (ResponseBody errorBody = executed.errorBody()) {
- GeminiError error = GSON.fromJson(errorBody.string(), GeminiErrorContainer.class).getError();
-
- throw new RuntimeException(
- String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
- }
- }
- } catch (IOException e) {
- throw new RuntimeException("An error occurred when calling the Gemini API endpoint (embedAll).", e);
- }
+ GoogleAiBatchEmbeddingResponse geminiResponse = withRetry(() -> this.geminiService.batchEmbed(this.modelName, this.apiKey, batchEmbeddingRequest));
if (geminiResponse != null) {
allEmbeddings.addAll(geminiResponse.getEmbeddings().stream()
- .map(values -> Embedding.from(values.getValues()))
- .collect(Collectors.toList()));
+ .map(values -> Embedding.from(values.getValues()))
+ .collect(Collectors.toList()));
} else {
throw new RuntimeException("Gemini embedding response was null (embedAll)");
}
@@ -157,8 +115,8 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
private GoogleAiEmbeddingRequest getGoogleAiEmbeddingRequest(TextSegment textSegment) {
GeminiPart geminiPart = GeminiPart.builder()
- .text(textSegment.text())
- .build();
+ .text(textSegment.text())
+ .build();
GeminiContent content = new GeminiContent(Collections.singletonList(geminiPart), null);
@@ -170,11 +128,11 @@ public class GoogleAiEmbeddingModel implements EmbeddingModel {
}
return new GoogleAiEmbeddingRequest(
- "models/" + this.modelName,
- content,
- this.taskType,
- title,
- this.outputDimensionality
+ "models/" + this.modelName,
+ content,
+ this.taskType,
+ title,
+ this.outputDimensionality
);
}
diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java
index bebe9c896..253464abf 100644
--- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java
+++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java
@@ -19,10 +19,7 @@ import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
-import okhttp3.ResponseBody;
-import retrofit2.Call;
-import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
@@ -49,8 +46,6 @@ import static java.util.Collections.emptyList;
@Experimental
@Slf4j
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
- private static final Gson GSON = new Gson();
-
private final GeminiService geminiService;
private final String apiKey;
@@ -116,18 +111,18 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
- this.geminiService = GeminiService.getGeminiService(
- getOrDefault(logRequestsAndResponses, false) ? this.log : null,
- getOrDefault(timeout, ofSeconds(60))
+ this.geminiService = new GeminiService(
+ getOrDefault(logRequestsAndResponses, false) ? log : null,
+ getOrDefault(timeout, ofSeconds(60))
);
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
- .modelName(this.modelName)
- .apiKey(this.apiKey)
- .timeout(getOrDefault(timeout, ofSeconds(60)))
- .maxRetries(this.maxRetries)
- .logRequestsAndResponses(this.logRequestsAndResponses)
- .build();
+ .modelName(this.modelName)
+ .apiKey(this.apiKey)
+ .timeout(getOrDefault(timeout, ofSeconds(60)))
+ .maxRetries(this.maxRetries)
+ .logRequestsAndResponses(this.logRequestsAndResponses)
+ .build();
}
private static String computeMimeType(ResponseFormat responseFormat) {
@@ -136,9 +131,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
}
if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
- responseFormat.jsonSchema() != null &&
- responseFormat.jsonSchema().rootElement() != null &&
- responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
+ responseFormat.jsonSchema() != null &&
+ responseFormat.jsonSchema().rootElement() != null &&
+ responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
return "text/x.enum";
}
@@ -149,14 +144,14 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
ChatRequest request = ChatRequest.builder()
- .messages(messages)
- .build();
+ .messages(messages)
+ .build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
- response.tokenUsage(),
- response.finishReason());
+ response.tokenUsage(),
+ response.finishReason());
}
@Override
@@ -167,15 +162,15 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ChatRequest request = ChatRequest.builder()
- .messages(messages)
- .toolSpecifications(toolSpecifications)
- .build();
+ .messages(messages)
+ .toolSpecifications(toolSpecifications)
+ .build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
- response.tokenUsage(),
- response.finishReason());
+ response.tokenUsage(),
+ response.finishReason());
}
@Override
@@ -194,31 +189,31 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
}
GeminiGenerateContentRequest request = GeminiGenerateContentRequest.builder()
- .contents(geminiContentList)
- .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
- .generationConfig(GeminiGenerationConfig.builder()
- .candidateCount(this.candidateCount)
- .maxOutputTokens(this.maxOutputTokens)
- .responseMimeType(responseMimeType)
- .responseSchema(schema)
- .stopSequences(this.stopSequences)
- .temperature(this.temperature)
- .topK(this.topK)
- .topP(this.topP)
- .build())
- .safetySettings(this.safetySettings)
- .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
- .toolConfig(new GeminiToolConfig(this.toolConfig))
- .build();
+ .contents(geminiContentList)
+ .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
+ .generationConfig(GeminiGenerationConfig.builder()
+ .candidateCount(this.candidateCount)
+ .maxOutputTokens(this.maxOutputTokens)
+ .responseMimeType(responseMimeType)
+ .responseSchema(schema)
+ .stopSequences(this.stopSequences)
+ .temperature(this.temperature)
+ .topK(this.topK)
+ .topP(this.topP)
+ .build())
+ .safetySettings(this.safetySettings)
+ .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
+ .toolConfig(new GeminiToolConfig(this.toolConfig))
+ .build();
ChatModelRequest chatModelRequest = ChatModelRequest.builder()
- .model(modelName)
- .temperature(temperature)
- .topP(topP)
- .maxTokens(maxOutputTokens)
- .messages(chatRequest.messages())
- .toolSpecifications(chatRequest.toolSpecifications())
- .build();
+ .model(modelName)
+ .temperature(temperature)
+ .topP(topP)
+ .maxTokens(maxOutputTokens)
+ .messages(chatRequest.messages())
+ .toolSpecifications(chatRequest.toolSpecifications())
+ .build();
ConcurrentHashMap