Adding GoogleAiGeminiStreamingChatModel (#1951)

## Issue
Closes #1903

## Change
Implemented `GoogleAiGeminiStreamingChatModel`, the streaming counterpart of `GoogleAiGeminiChatModel`. Logic shared between the two models was extracted into a new `BaseGeminiChatModel` base class.

This PR depends on #1950.
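
For reference, a minimal usage sketch of the new model (mirroring the updated docs; the `GEMINI_AI_KEY` environment variable is a placeholder):

```java
StreamingChatLanguageModel gemini = GoogleAiGeminiStreamingChatModel.builder()
        .apiKey(System.getenv("GEMINI_AI_KEY"))
        .modelName("gemini-1.5-flash")
        .build();

gemini.generate("Tell me a joke about Java", new StreamingResponseHandler<AiMessage>() {

    @Override
    public void onNext(String token) {
        System.out.print(token); // called once per streamed token
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        System.out.println("\nFinish reason: " + response.finishReason());
    }

    @Override
    public void onError(Throwable error) {
        error.printStackTrace();
    }
});
```

Tool execution requests are accumulated across partial responses by `GeminiStreamingResponseBuilder` and surfaced on the final `Response` passed to `onComplete`.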

## General checklist
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
Authored by Bjarne-Kinkel on 2024-10-31 10:07:50 +01:00, committed via GitHub (commit aa0e488166, parent e98a2e4f62).
10 changed files with 1348 additions and 223 deletions


@@ -83,9 +83,34 @@ ChatLanguageModel gemini = GoogleAiGeminiChatModel.builder()
```
## GoogleAiGeminiStreamingChatModel
The `GoogleAiGeminiStreamingChatModel` allows streaming the text of a response token by token. The response must be handled by a `StreamingResponseHandler`.
```java
StreamingChatLanguageModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(System.getenv("GEMINI_AI_KEY"))
.modelName("gemini-1.5-flash")
.build();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
gemini.generate("Tell me a joke about Java", new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
System.out.print(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureResponse.completeExceptionally(error);
}
});
futureResponse.join();
```
## Tools


@@ -11,8 +11,8 @@ sidebar_position: 0
| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | text, image | ✅ | | | |
| [ChatGLM](/integrations/language-models/chatglm) | | | | text | | | | |
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | | text, image, audio | ✅ | | | |
| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | |
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | text | | | ✅ | |
| [Hugging Face](/integrations/language-models/hugging-face) | | | | text | | | | |


@@ -0,0 +1,192 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
@Experimental
@Slf4j
abstract class BaseGeminiChatModel {
protected final GeminiService geminiService;
protected final String apiKey;
protected final String modelName;
protected final Double temperature;
protected final Integer topK;
protected final Double topP;
protected final Integer maxOutputTokens;
protected final List<String> stopSequences;
protected final ResponseFormat responseFormat;
protected final GeminiFunctionCallingConfig toolConfig;
protected final boolean allowCodeExecution;
protected final boolean includeCodeExecutionOutput;
protected final List<GeminiSafetySetting> safetySettings;
protected final List<ChatModelListener> listeners;
protected final Integer maxRetries;
protected BaseGeminiChatModel(
String apiKey,
String modelName,
Double temperature,
Integer topK,
Double topP,
Integer maxOutputTokens,
Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences,
GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution,
Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners,
Integer maxRetries
) {
this.apiKey = ensureNotBlank(apiKey, "apiKey");
this.modelName = ensureNotBlank(modelName, "modelName");
this.temperature = temperature;
this.topK = topK;
this.topP = topP;
this.maxOutputTokens = maxOutputTokens;
this.stopSequences = getOrDefault(stopSequences, emptyList());
this.toolConfig = toolConfig;
this.allowCodeExecution = getOrDefault(allowCodeExecution, false);
this.includeCodeExecutionOutput = getOrDefault(includeCodeExecutionOutput, false);
this.safetySettings = copyIfNotNull(safetySettings);
this.responseFormat = responseFormat;
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
this.maxRetries = getOrDefault(maxRetries, 3);
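// GeminiService logs requests/responses only when it is handed a non-null logger.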
this.geminiService = new GeminiService(
getOrDefault(logRequestsAndResponses, false) ? log : null,
getOrDefault(timeout, ofSeconds(60))
);
}
protected GeminiGenerateContentRequest createGenerateContentRequest(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ResponseFormat responseFormat
) {
GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString());
List<GeminiContent> geminiContentList = fromMessageToGContent(messages, systemInstruction);
GeminiSchema schema = null;
if (responseFormat != null && responseFormat.jsonSchema() != null) {
schema = fromJsonSchemaToGSchema(responseFormat.jsonSchema());
}
return GeminiGenerateContentRequest.builder()
.contents(geminiContentList)
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
.generationConfig(GeminiGenerationConfig.builder()
.candidateCount(1) // Multiple candidates aren't supported by langchain4j
.maxOutputTokens(this.maxOutputTokens)
.responseMimeType(computeMimeType(responseFormat))
.responseSchema(schema)
.stopSequences(this.stopSequences)
.temperature(this.temperature)
.topK(this.topK)
.topP(this.topP)
.build())
.safetySettings(this.safetySettings)
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
.toolConfig(new GeminiToolConfig(this.toolConfig))
.build();
}
protected ChatModelRequest createChatModelRequest(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications
) {
return ChatModelRequest.builder()
.model(modelName)
.temperature(temperature)
.topP(topP)
.maxTokens(maxOutputTokens)
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
}
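// Gemini selects its output mode via responseMimeType: plain text, JSON, or "text/x.enum" when the schema is a bare enum.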
protected static String computeMimeType(ResponseFormat responseFormat) {
if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) {
return "text/plain";
}
if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
responseFormat.jsonSchema() != null &&
responseFormat.jsonSchema().rootElement() != null &&
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
return "text/x.enum";
}
return "application/json";
}
protected void notifyListenersOnRequest(ChatModelRequestContext context) {
listeners.forEach((listener) -> {
try {
listener.onRequest(context);
} catch (Exception e) {
log.warn("Exception while calling model listener (onRequest)", e);
}
});
}
protected void notifyListenersOnResponse(Response<AiMessage> response, ChatModelRequest request,
ConcurrentHashMap<Object, Object> attributes) {
ChatModelResponse chatModelResponse = ChatModelResponse.builder()
.model(modelName)
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.aiMessage(response.content())
.build();
ChatModelResponseContext context = new ChatModelResponseContext(
chatModelResponse, request, attributes);
listeners.forEach((listener) -> {
try {
listener.onResponse(context);
} catch (Exception e) {
log.warn("Exception while calling model listener (onResponse)", e);
}
});
}
protected void notifyListenersOnError(Exception exception, ChatModelRequest request,
ConcurrentHashMap<Object, Object> attributes) {
listeners.forEach((listener) -> {
try {
ChatModelErrorContext context = new ChatModelErrorContext(
exception, request, null, attributes);
listener.onError(context);
} catch (Exception e) {
log.warn("Exception while calling model listener (onError)", e);
}
});
}
}


@@ -10,6 +10,8 @@ import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.stream.Collectors;
import java.util.stream.Stream;
class GeminiService {
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
@@ -48,6 +50,11 @@ class GeminiService {
return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
}
Stream<GeminiGenerateContentResponse> generateContentStream(String modelName, String apiKey, GeminiGenerateContentRequest request) {
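// "alt=sse" makes the endpoint return Server-Sent Events ("data: {json}" lines), which streamRequest() parses lazily below.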
String url = String.format("%s/models/%s:streamGenerateContent?alt=sse", GEMINI_AI_ENDPOINT, modelName);
return streamRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
}
private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
String jsonBody = gson.toJson(requestBody);
HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);
@@ -72,6 +79,40 @@
}
}
private <T> Stream<T> streamRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
String jsonBody = gson.toJson(requestBody);
HttpRequest httpRequest = buildHttpRequest(url, apiKey, jsonBody);
logRequest(jsonBody);
try {
HttpResponse<Stream<String>> httpResponse = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofLines());
if (httpResponse.statusCode() >= 300) {
String errorBody = httpResponse.body()
.collect(Collectors.joining("\n"));
throw new RuntimeException(String.format("HTTP error (%d): %s", httpResponse.statusCode(), errorBody));
}
Stream<T> responseStream = httpResponse.body()
.filter(line -> line.startsWith("data: "))
.map(line -> line.substring(6)) // Remove "data: " prefix
.map(jsonString -> gson.fromJson(jsonString, responseType));
if (logger != null) {
responseStream = responseStream.peek(response -> logger.debug("Partial response from Gemini:\n{}", response));
}
return responseStream;
} catch (IOException e) {
throw new RuntimeException("An error occurred while streaming the request", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Streaming the request was interrupted", e);
}
}
private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
return HttpRequest.newBuilder()
.uri(URI.create(url))


@@ -0,0 +1,111 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
/**
* A builder class for constructing streaming responses from Gemini AI model.
* This class accumulates partial responses and builds a final response.
*/
class GeminiStreamingResponseBuilder {
private final boolean includeCodeExecutionOutput;
private final StringBuilder contentBuilder;
private final List<ToolExecutionRequest> functionCalls;
private TokenUsage tokenUsage;
private FinishReason finishReason;
/**
* Constructs a new GeminiStreamingResponseBuilder.
*
* @param includeCodeExecutionOutput whether to include code execution output in the response
*/
public GeminiStreamingResponseBuilder(boolean includeCodeExecutionOutput) {
this.includeCodeExecutionOutput = includeCodeExecutionOutput;
this.contentBuilder = new StringBuilder();
this.functionCalls = new ArrayList<>();
}
/**
* Appends a partial response to the builder.
*
* @param partialResponse the partial response from Gemini AI
* @return an Optional containing the text of the partial response, or empty if no valid text
*/
public Optional<String> append(GeminiGenerateContentResponse partialResponse) {
if (partialResponse == null) {
return Optional.empty();
}
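// Only the first candidate is inspected: the request always pins candidateCount to 1.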
GeminiCandidate firstCandidate = partialResponse.getCandidates().get(0);
updateFinishReason(firstCandidate);
updateTokenUsage(partialResponse.getUsageMetadata());
GeminiContent content = firstCandidate.getContent();
if (content == null || content.getParts() == null) {
return Optional.empty();
}
AiMessage message = fromGPartsToAiMessage(content.getParts(), this.includeCodeExecutionOutput);
updateContentAndFunctionCalls(message);
return Optional.ofNullable(message.text());
}
/**
* Builds the final response from all accumulated partial responses.
*
* @return a Response object containing the final AiMessage, token usage, and finish reason
*/
public Response<AiMessage> build() {
AiMessage aiMessage = createAiMessage();
return Response.from(aiMessage, tokenUsage, finishReason);
}
private void updateTokenUsage(GeminiUsageMetadata tokenCounts) {
this.tokenUsage = new TokenUsage(
tokenCounts.getPromptTokenCount(),
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount()
);
}
private void updateFinishReason(GeminiCandidate candidate) {
if (candidate.getFinishReason() != null) {
this.finishReason = fromGFinishReasonToFinishReason(candidate.getFinishReason());
}
}
private void updateContentAndFunctionCalls(AiMessage message) {
Optional.ofNullable(message.text()).ifPresent(contentBuilder::append);
if (message.hasToolExecutionRequests()) {
functionCalls.addAll(message.toolExecutionRequests());
}
}
private AiMessage createAiMessage() {
String text = contentBuilder.toString();
boolean hasText = !text.isEmpty() && !text.isBlank();
boolean hasFunctionCall = !functionCalls.isEmpty();
if (hasText && hasFunctionCall) {
return new AiMessage(text, functionCalls);
} else if (hasText) {
return new AiMessage(text);
} else if (hasFunctionCall) {
return new AiMessage(functionCalls);
}
throw new RuntimeException("Gemini has responded neither with text nor with a function call.");
}
}


@@ -1,6 +1,5 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
@@ -8,11 +7,12 @@ import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
@@ -23,135 +23,63 @@ import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
import static java.time.Duration.ofSeconds;
@Experimental
@Slf4j
public class GoogleAiGeminiChatModel extends BaseGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final GoogleAiGeminiTokenizer geminiTokenizer;
@Builder
public GoogleAiGeminiChatModel(
String apiKey, String modelName,
Integer maxRetries,
Double temperature, Integer topK, Double topP,
Integer maxOutputTokens, Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners
) {
super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
responseFormat, stopSequences, toolConfig, allowCodeExecution,
includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
listeners, maxRetries);
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
.modelName(this.modelName)
.apiKey(this.apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.maxRetries(this.maxRetries)
.logRequestsAndResponses(getOrDefault(logRequestsAndResponses, false))
.build();
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
ChatRequest request = ChatRequest.builder()
.messages(messages)
.build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
response.tokenUsage(),
response.finishReason());
}
@Override
@@ -162,128 +90,86 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ChatRequest request = ChatRequest.builder()
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
response.tokenUsage(),
response.finishReason());
}
@Override
public ChatResponse chat(ChatRequest chatRequest) {
GeminiGenerateContentRequest request = createGenerateContentRequest(
chatRequest.messages(),
chatRequest.toolSpecifications(),
getOrDefault(chatRequest.responseFormat(), this.responseFormat)
);
ChatModelRequest chatModelRequest = createChatModelRequest(
chatRequest.messages(),
chatRequest.toolSpecifications()
);
ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
notifyListenersOnRequest(new ChatModelRequestContext(chatModelRequest, listenerAttributes));
try {
GeminiGenerateContentResponse geminiResponse = withRetry(
() -> this.geminiService.generateContent(this.modelName, this.apiKey, request),
this.maxRetries
);
return processResponse(geminiResponse, chatModelRequest, listenerAttributes);
} catch (RuntimeException e) {
notifyListenersOnError(e, chatModelRequest, listenerAttributes);
throw e;
}
}
private ChatResponse processResponse(
GeminiGenerateContentResponse geminiResponse,
ChatModelRequest chatModelRequest,
ConcurrentHashMap<Object, Object> listenerAttributes
) {
if (geminiResponse == null) {
throw new RuntimeException("Gemini response was null");
}
GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0);
GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata();
FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
AiMessage aiMessage = createAiMessage(firstCandidate, finishReason);
TokenUsage tokenUsage = createTokenUsage(tokenCounts);
Response<AiMessage> response = Response.from(aiMessage, tokenUsage, finishReason);
notifyListenersOnResponse(response, chatModelRequest, listenerAttributes);
return ChatResponse.builder()
.aiMessage(aiMessage)
.finishReason(finishReason)
.tokenUsage(tokenUsage)
.build();
}
private AiMessage createAiMessage(GeminiCandidate candidate, FinishReason finishReason) {
if (candidate.getContent() == null) {
return AiMessage.from("No text was returned by the model. " +
"The model finished generating because of the following reason: " + finishReason);
}
return fromGPartsToAiMessage(candidate.getContent().getParts(), this.includeCodeExecutionOutput);
}
private TokenUsage createTokenUsage(GeminiUsageMetadata tokenCounts) {
return new TokenUsage(
tokenCounts.getPromptTokenCount(),
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount()
);
}
@Override
@@ -309,9 +195,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> safetySettingMap) {
this.safetySettings = safetySettingMap.entrySet().stream()
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
).collect(Collectors.toList());
return this;
}
}
}


@@ -0,0 +1,93 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import static dev.langchain4j.internal.RetryUtils.withRetry;
@Experimental
@Slf4j
public class GoogleAiGeminiStreamingChatModel extends BaseGeminiChatModel implements StreamingChatLanguageModel {
@Builder
public GoogleAiGeminiStreamingChatModel(
String apiKey, String modelName,
Double temperature, Integer topK, Double topP,
Integer maxOutputTokens, Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners,
Integer maxRetries
) {
super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
responseFormat, stopSequences, toolConfig, allowCodeExecution,
includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
listeners, maxRetries);
}
@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
generate(messages, Collections.emptyList(), handler);
}
@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
throw new RuntimeException("This method is not supported: Gemini AI cannot be forced to execute a tool.");
}
@Override
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
GeminiGenerateContentRequest request = createGenerateContentRequest(messages, toolSpecifications, this.responseFormat);
ChatModelRequest chatModelRequest = createChatModelRequest(messages, toolSpecifications);
ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
notifyListenersOnRequest(chatModelRequestContext);
processGenerateContentRequest(request, handler, chatModelRequest, listenerAttributes);
}
private void processGenerateContentRequest(GeminiGenerateContentRequest request, StreamingResponseHandler<AiMessage> handler,
ChatModelRequest chatModelRequest, ConcurrentHashMap<Object, Object> listenerAttributes) {
GeminiStreamingResponseBuilder responseBuilder = new GeminiStreamingResponseBuilder(this.includeCodeExecutionOutput);
try {
Stream<GeminiGenerateContentResponse> contentStream = withRetry(
() -> this.geminiService.generateContentStream(this.modelName, this.apiKey, request),
maxRetries);
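// Note: withRetry only retries opening the stream; failures while consuming it fall through to the catch below.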
contentStream.forEach(partialResponse -> {
Optional<String> text = responseBuilder.append(partialResponse);
text.ifPresent(handler::onNext);
});
Response<AiMessage> fullResponse = responseBuilder.build();
handler.onComplete(fullResponse);
notifyListenersOnResponse(fullResponse, chatModelRequest, listenerAttributes);
} catch (RuntimeException exception) {
notifyListenersOnError(exception, chatModelRequest, listenerAttributes);
handler.onError(exception);
}
}
}


@@ -18,12 +18,12 @@ import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
@@ -34,7 +34,13 @@ import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.RetryingTest;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.readBytes;
@@ -44,7 +50,9 @@ import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT;
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
public class GoogleAiGeminiChatModelIT {
@@ -114,8 +122,8 @@ public class GoogleAiGeminiChatModelIT {
assertThat(jsonText).contains("\"John\"");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
}
@Test
@@ -435,10 +443,10 @@ public class GoogleAiGeminiChatModelIT {
.jsonSchema(JsonSchema.builder()
.rootElement(JsonObjectSchema.builder()
.addStringProperty("name")
.addProperty("address", JsonObjectSchema.builder()
.addStringProperty("city")
.required("city")
.build())
.required("name", "address")
.additionalProperties(false)
.build())


@@ -0,0 +1,708 @@
package dev.langchain4j.model.googleai;
import com.google.gson.Gson;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.AudioContent;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.TextFileContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.JsonSchemas;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.RetryingTest;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.model.googleai.GeminiHarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT;
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
public class GoogleAiGeminiStreamingChatModelIT {
private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
private static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
private static final String MD_FILE_URL = "https://raw.githubusercontent.com/langchain4j/langchain4j/main/docs/docs/intro.md";
@Test
void should_answer_simple_question() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.from("What is the capital of France?");
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
String text = response.content().text();
assertThat(text).containsIgnoringCase("Paris");
assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
}
@Test
void should_configure_generation() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.temperature(2.0)
.topP(0.5)
.topK(10)
.maxOutputTokens(10)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate("How much is 3+4? Reply with just the answer", handler);
Response<AiMessage> response = handler.get();
// then
String text = response.content().text();
assertThat(text.trim()).isEqualTo("7");
}
@Test
void should_answer_in_json() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-pro")
.responseFormat(ResponseFormat.JSON)
.logRequestsAndResponses(true)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(UserMessage.from("What is the firstname of the John Doe?\n" +
"Reply in JSON following with the following format: {\"firstname\": string}"), handler);
Response<AiMessage> response = handler.get();
// then
String jsonText = response.content().text();
assertThat(jsonText).contains("\"firstname\"");
assertThat(jsonText).contains("\"John\"");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
}
@Test
void should_support_multiturn_chatting() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
List<ChatMessage> messages = new ArrayList<>();
messages.add(UserMessage.from("Hello, my name is Guillaume"));
messages.add(AiMessage.from("Hi, how can I assist you today?"));
messages.add(UserMessage.from("What is my name? Reply with just my name."));
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(messages, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).contains("Guillaume");
}
@Test
void should_support_sending_images_as_base64() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.userMessage(
ImageContent.from(new String(Base64.getEncoder().encode(readBytes(CAT_IMAGE_URL))), "image/png"),
TextContent.from("Describe this image in a single word")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("cat");
}
@Test
void should_support_text_file() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.userMessage(
TextFileContent.from(new String(Base64.getEncoder().encode(readBytes(MD_FILE_URL))), "text/markdown"),
TextContent.from("What project does this markdown file mention?")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("LangChain4j");
}
@Test
void should_support_audio_file() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.from(
AudioContent.from(
new String(Base64.getEncoder().encode( //TODO use local file
readBytes("https://storage.googleapis.com/cloud-samples-data/generative-ai/audio/pixel.mp3"))),
"audio/mp3"),
TextContent.from("Give a summary of the audio")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("Pixel");
}
@Test
@Disabled
void should_support_video_file() {
// TODO: waiting for the regular GoogleAiGeminiChatModel to implement this test first
}
@Test
void should_respect_system_instruction() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
List<ChatMessage> chatMessages = List.of(
SystemMessage.from("Translate from English into French"),
UserMessage.from("apple")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(chatMessages, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("pomme");
}
@Test
void should_execute_python_code() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.allowCodeExecution(true)
.includeCodeExecutionOutput(true)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(
UserMessage.from("Calculate `fibonacci(13)`. Write code in Python and execute it to get the result."),
handler
);
Response<AiMessage> response = handler.get();
// then
String text = response.content().text();
System.out.println("text = " + text);
assertThat(text).containsIgnoringCase("233");
}
@Test
void should_support_JSON_array_in_tools() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
// when
List<ChatMessage> allMessages = new ArrayList<>();
allMessages.add(UserMessage.from("Return a JSON list containing the first 10 fibonacci numbers."));
TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
gemini.generate(
allMessages,
List.of(ToolSpecification.builder()
.name("getFirstNFibonacciNumbers")
.description("Get the first n fibonacci numbers")
.parameters(JsonObjectSchema.builder()
.addNumberProperty("n")
.build())
.build()),
handler1
);
Response<AiMessage> response1 = handler1.get();
// then
assertThat(response1.content().hasToolExecutionRequests()).isTrue();
assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("getFirstNFibonacciNumbers");
assertThat(response1.content().toolExecutionRequests().get(0).arguments()).contains("\"n\":10");
allMessages.add(response1.content());
// when
ToolExecutionResultMessage forecastResult =
ToolExecutionResultMessage.from(null, "getFirstNFibonacciNumbers", "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
allMessages.add(forecastResult);
// then
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
gemini.generate(allMessages, handler2);
Response<AiMessage> response2 = handler2.get();
// then
assertThat(response2.content().text()).contains("[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
}
// Test is flaky, because Gemini doesn't 100% always ask for parallel tool calls
// and sometimes requests more information
@RetryingTest(5)
void should_support_parallel_tool_execution() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
// when
List<ChatMessage> allMessages = new ArrayList<>();
allMessages.add(UserMessage.from("Which warehouse has more stock, ABC or XYZ?"));
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(
allMessages,
List.of(ToolSpecification.builder()
.name("getWarehouseStock")
.description("Retrieve the amount of stock available in a warehouse designated by its name")
.parameters(JsonObjectSchema.builder()
.addStringProperty("name", "The name of the warehouse")
.build())
.build()),
handler
);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().hasToolExecutionRequests()).isTrue();
List<ToolExecutionRequest> executionRequests = response.content().toolExecutionRequests();
assertThat(executionRequests).hasSize(2);
String allArgs = executionRequests.stream()
.map(ToolExecutionRequest::arguments)
.collect(Collectors.joining(" "));
assertThat(allArgs).contains("ABC");
assertThat(allArgs).contains("XYZ");
}
@RetryingTest(5)
void should_support_safety_settings() {
// given
List<GeminiSafetySetting> safetySettings = List.of(
new GeminiSafetySetting(HARM_CATEGORY_HATE_SPEECH, BLOCK_LOW_AND_ABOVE),
new GeminiSafetySetting(HARM_CATEGORY_HARASSMENT, BLOCK_LOW_AND_ABOVE)
);
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.safetySettings(safetySettings)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate("You're a dumb f*cking idiot bastard!", handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.finishReason()).isEqualTo(FinishReasonMapper.fromGFinishReasonToFinishReason(GeminiFinishReason.SAFETY));
}
@Test
void should_comply_to_response_schema() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonObjectSchema.builder()
.addStringProperty("name")
.addProperty("address", JsonObjectSchema.builder()
.addStringProperty("city")
.required("city")
.build())
.required("name", "address")
.additionalProperties(false)
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to extract information related to a person," +
"like their name, their address, the city the live in."),
UserMessage.from(
"In the town of Liverpool, lived Tommy Skybridge, a young little boy."
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
// then
assertThat(response.content().text().trim())
.isEqualTo("{\"address\": {\"city\": \"Liverpool\"}, \"name\": \"Tommy Skybridge\"}");
}
@Test
void should_handle_enum_type() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("sentiment", JsonEnumSchema.builder()
.enumValues("POSITIVE", "NEGATIVE")
.build());
}})
.required("sentiment")
.additionalProperties(false)
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to analyze the sentiment of the text you receive."),
UserMessage.from(
"This is super exciting news, congratulations!"
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
// then
assertThat(response.content().text().trim())
.isEqualTo("{\"sentiment\": \"POSITIVE\"}");
}
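// Checks that a bare enum can serve as the root of the response schema, so the model replies with one of the allowed values only.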
@Test
void should_handle_enum_output_mode() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonEnumSchema.builder()
.enumValues("POSITIVE", "NEUTRAL", "NEGATIVE")
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to analyze the sentiment of the text you receive."),
UserMessage.from(
"This is super exciting news, congratulations!"
)
), handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text().trim())
.isEqualTo("POSITIVE");
}
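// Checks that a JSON array can serve as the root of the response schema.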
@Test
void should_allow_array_as_response_schema() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonArraySchema.builder()
.items(new JsonIntegerSchema())
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to return a list of 6-faces dice rolls"),
UserMessage.from(
"Give me 3 dice rolls"
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
// then
Integer[] diceRolls = new Gson().fromJson(response.content().text(), Integer[].class);
assertThat(diceRolls).hasSize(3);
}
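// POJO that the response schema is derived from via JsonSchemas.jsonSchemaFrom(Color.class) in the test below.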
private static class Color {
private String name;
private int red;
private int green;
private int blue;
private boolean muted;
}
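// Checks that the model's JSON output deserializes straight into the POJO the schema was derived from.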
@Test
void should_deserialize_to_POJO() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to extract information from the description of a color"),
UserMessage.from(
"Cobalt blue is a blend of a lot of blue, a bit of green, and almost no red."
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
Color color = new Gson().fromJson(response.content().text(), Color.class);
// then
assertThat(color.name).isEqualToIgnoringCase("Cobalt blue");
assertThat(color.muted).isFalse();
assertThat(color.red).isLessThanOrEqualTo(color.green);
assertThat(color.green).isLessThanOrEqualTo(color.blue);
}
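// Checks tool configuration: the model calls "toolOne" when both tools are offered,
// and produces no tool call once the config restricts function calling to "toolTwo".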
@Test
void should_support_tool_config() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
List<ChatMessage> chatMessages = List.of(UserMessage.from("Call toolOne"));
List<ToolSpecification> listOfTools = Arrays.asList(
ToolSpecification.builder().name("toolOne").build(),
ToolSpecification.builder().name("toolTwo").build()
);
// when
TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
gemini.generate(chatMessages, listOfTools, handler1);
Response<AiMessage> response1 = handler1.get();
// then
assertThat(response1.content().hasToolExecutionRequests()).isTrue();
assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("toolOne");
// given
gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.toolConfig(new GeminiFunctionCallingConfig(GeminiMode.ANY, List.of("toolTwo")))
.build();
// when
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
gemini.generate("Call toolOne", handler2);
Response<AiMessage> response2 = handler2.get();
// then
assertThat(response2.content().hasToolExecutionRequests()).isFalse();
}
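// Tool used by the AiServices test below; wrapped in a Mockito spy to verify the exact invocations.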
static class Transactions {
@Tool("returns amount of a given transaction")
double getTransactionAmount(@P("ID of a transaction") String id) {
System.out.printf("called getTransactionAmount(%s)%n", id);
switch (id) {
case "T001":
return 11.1;
case "T002":
return 22.2;
default:
throw new IllegalArgumentException("Unknown transaction ID: " + id);
}
}
}
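// AI Service returning a TokenStream, so replies can be consumed token by token.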
interface StreamingAssistant {
TokenStream chat(String userMessage);
}
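// End-to-end check that AiServices wires the streaming model, chat memory, and tools together.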
@RetryingTest(10)
void should_work_with_tools_with_AiServices() throws ExecutionException, InterruptedException, TimeoutException {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-pro")
.logRequestsAndResponses(true)
.timeout(Duration.ofMinutes(2))
.temperature(0.0)
.topP(0.0)
.topK(1)
.build();
// when
Transactions spyTransactions = spy(new Transactions());
MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20);
StreamingAssistant assistant = AiServices.builder(StreamingAssistant.class)
.tools(spyTransactions)
.chatMemory(chatMemory)
.streamingChatLanguageModel(gemini)
.build();
// then
CompletableFuture<Response<AiMessage>> future1 = new CompletableFuture<>();
assistant.chat("What is the amount of transaction T001?")
.onNext(System.out::println)
.onComplete(future1::complete)
.onError(future1::completeExceptionally)
.start();
Response<AiMessage> response1 = future1.get(30, TimeUnit.SECONDS);
assertThat(response1.content().text()).containsIgnoringCase("11.1");
verify(spyTransactions).getTransactionAmount("T001");
CompletableFuture<Response<AiMessage>> future2 = new CompletableFuture<>();
assistant.chat("What is the amount of transaction T002?")
.onNext(System.out::println)
.onComplete(future2::complete)
.onError(future2::completeExceptionally)
.start();
Response<AiMessage> response2 = future2.get(30, TimeUnit.SECONDS);
assertThat(response2.content().text()).containsIgnoringCase("22.2");
verify(spyTransactions).getTransactionAmount("T002");
verifyNoMoreInteractions(spyTransactions);
}
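// Optional pause between tests to avoid rate limiting on CI.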
@AfterEach
void afterEach() throws InterruptedException {
String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
if (ciDelaySeconds != null) {
Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
}
}
}

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import org.junit.jupiter.api.AfterEach;
import static java.util.Collections.singletonList;
class GoogleAiGeminiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
@Override
protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
return GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName(modelName())
.temperature(temperature())
.topP(topP())
.maxOutputTokens(maxTokens())
.listeners(singletonList(listener))
.logRequestsAndResponses(true)
.build();
}
@Override
protected String modelName() {
return "gemini-1.5-flash";
}
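// Response IDs are not asserted for this provider.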
@Override
protected boolean assertResponseId() {
return false;
}
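// An intentionally invalid model name ("banana") makes the request fail, so the listener's error callback can be verified.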
@Override
protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
return GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("banana")
.listeners(singletonList(listener))
.logRequestsAndResponses(true)
.build();
}
@Override
protected Class<? extends Exception> expectedExceptionClass() {
return RuntimeException.class;
}
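// Optional pause between tests to avoid rate limiting on CI.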
@AfterEach
void afterEach() throws InterruptedException {
String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
if (ciDelaySeconds != null) {
Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
}
}
}