Adding GoogleAiGeminiStreamingChatModel (#1951)

## Issue
Closes #1903

## Change
Implemented the GoogleAiGeminiStreamingChatModel.

This PR depends on #1950.

## General checklist
- [X] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
Bjarne-Kinkel 2024-10-31 10:07:50 +01:00 committed by GitHub
parent e98a2e4f62
commit aa0e488166
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1348 additions and 223 deletions

View File

@ -83,9 +83,34 @@ ChatLanguageModel gemini = GoogleAiGeminiChatModel.builder()
```
## GoogleAiGeminiStreamingChatModel
The `GoogleAiGeminiStreamingChatModel` allows streaming the text of a response token by token. The response must be managed by a `StreamingResponseHandler`.
```java
StreamingChatLanguageModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(System.getenv("GEMINI_AI_KEY"))
.modelName("gemini-1.5-flash")
.build();
No streaming chat model is available yet.
Please open a feature request if you're interested in a streaming model or if you want to contribute to implementing it.
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
gemini.generate("Tell me a joke about Java", new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
System.out.print(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureResponse.completeExceptionally(error);
}
});
futureResponse.join();
```
## Tools

View File

@ -11,8 +11,8 @@ sidebar_position: 0
| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | text, image | ✅ | | | |
| [ChatGLM](/integrations/language-models/chatglm) | | | | text | | | | |
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | | text, image, audio | ✅ | | | |
| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | |
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | |
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | text | | | ✅ | |
| [Hugging Face](/integrations/language-models/hugging-face) | | | | text | | | | |

View File

@ -0,0 +1,192 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
@Experimental
@Slf4j
abstract class BaseGeminiChatModel {
protected final GeminiService geminiService;
protected final String apiKey;
protected final String modelName;
protected final Double temperature;
protected final Integer topK;
protected final Double topP;
protected final Integer maxOutputTokens;
protected final List<String> stopSequences;
protected final ResponseFormat responseFormat;
protected final GeminiFunctionCallingConfig toolConfig;
protected final boolean allowCodeExecution;
protected final boolean includeCodeExecutionOutput;
protected final List<GeminiSafetySetting> safetySettings;
protected final List<ChatModelListener> listeners;
protected final Integer maxRetries;
protected BaseGeminiChatModel(
String apiKey,
String modelName,
Double temperature,
Integer topK,
Double topP,
Integer maxOutputTokens,
Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences,
GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution,
Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners,
Integer maxRetries
) {
this.apiKey = ensureNotBlank(apiKey, "apiKey");
this.modelName = ensureNotBlank(modelName, "modelName");
this.temperature = temperature;
this.topK = topK;
this.topP = topP;
this.maxOutputTokens = maxOutputTokens;
this.stopSequences = getOrDefault(stopSequences, emptyList());
this.toolConfig = toolConfig;
this.allowCodeExecution = getOrDefault(allowCodeExecution, false);
this.includeCodeExecutionOutput = getOrDefault(includeCodeExecutionOutput, false);
this.safetySettings = copyIfNotNull(safetySettings);
this.responseFormat = responseFormat;
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
this.maxRetries = getOrDefault(maxRetries, 3);
this.geminiService = new GeminiService(
getOrDefault(logRequestsAndResponses, false) ? log : null,
getOrDefault(timeout, ofSeconds(60))
);
}
protected GeminiGenerateContentRequest createGenerateContentRequest(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ResponseFormat responseFormat
) {
GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString());
List<GeminiContent> geminiContentList = fromMessageToGContent(messages, systemInstruction);
GeminiSchema schema = null;
if (responseFormat != null && responseFormat.jsonSchema() != null) {
schema = fromJsonSchemaToGSchema(responseFormat.jsonSchema());
}
return GeminiGenerateContentRequest.builder()
.contents(geminiContentList)
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
.generationConfig(GeminiGenerationConfig.builder()
.candidateCount(1) // Multiple candidates aren't supported by langchain4j
.maxOutputTokens(this.maxOutputTokens)
.responseMimeType(computeMimeType(responseFormat))
.responseSchema(schema)
.stopSequences(this.stopSequences)
.temperature(this.temperature)
.topK(this.topK)
.topP(this.topP)
.build())
.safetySettings(this.safetySettings)
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
.toolConfig(new GeminiToolConfig(this.toolConfig))
.build();
}
protected ChatModelRequest createChatModelRequest(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications
) {
return ChatModelRequest.builder()
.model(modelName)
.temperature(temperature)
.topP(topP)
.maxTokens(maxOutputTokens)
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
}
protected static String computeMimeType(ResponseFormat responseFormat) {
if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) {
return "text/plain";
}
if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
responseFormat.jsonSchema() != null &&
responseFormat.jsonSchema().rootElement() != null &&
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
return "text/x.enum";
}
return "application/json";
}
protected void notifyListenersOnRequest(ChatModelRequestContext context) {
listeners.forEach((listener) -> {
try {
listener.onRequest(context);
} catch (Exception e) {
log.warn("Exception while calling model listener (onRequest)", e);
}
});
}
protected void notifyListenersOnResponse(Response<AiMessage> response, ChatModelRequest request,
ConcurrentHashMap<Object, Object> attributes) {
ChatModelResponse chatModelResponse = ChatModelResponse.builder()
.model(modelName)
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.aiMessage(response.content())
.build();
ChatModelResponseContext context = new ChatModelResponseContext(
chatModelResponse, request, attributes);
listeners.forEach((listener) -> {
try {
listener.onResponse(context);
} catch (Exception e) {
log.warn("Exception while calling model listener (onResponse)", e);
}
});
}
protected void notifyListenersOnError(Exception exception, ChatModelRequest request,
ConcurrentHashMap<Object, Object> attributes) {
listeners.forEach((listener) -> {
try {
ChatModelErrorContext context = new ChatModelErrorContext(
exception, request, null, attributes);
listener.onError(context);
} catch (Exception e) {
log.warn("Exception while calling model listener (onError)", e);
}
});
}
}

View File

@ -10,6 +10,8 @@ import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.stream.Collectors;
import java.util.stream.Stream;
class GeminiService {
private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
@ -48,6 +50,11 @@ class GeminiService {
return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
}
Stream<GeminiGenerateContentResponse> generateContentStream(String modelName, String apiKey, GeminiGenerateContentRequest request) {
String url = String.format("%s/models/%s:streamGenerateContent?alt=sse", GEMINI_AI_ENDPOINT, modelName);
return streamRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
}
private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
String jsonBody = gson.toJson(requestBody);
HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);
@ -72,6 +79,40 @@ class GeminiService {
}
}
private <T> Stream<T> streamRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
String jsonBody = gson.toJson(requestBody);
HttpRequest httpRequest = buildHttpRequest(url, apiKey, jsonBody);
logRequest(jsonBody);
try {
HttpResponse<Stream<String>> httpResponse = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofLines());
if (httpResponse.statusCode() >= 300) {
String errorBody = httpResponse.body()
.collect(Collectors.joining("\n"));
throw new RuntimeException(String.format("HTTP error (%d): %s", httpResponse.statusCode(), errorBody));
}
Stream<T> responseStream = httpResponse.body()
.filter(line -> line.startsWith("data: "))
.map(line -> line.substring(6)) // Remove "data: " prefix
.map(jsonString -> gson.fromJson(jsonString, responseType));
if (logger != null) {
responseStream = responseStream.peek(response -> logger.debug("Partial response from Gemini:\n{}", response));
}
return responseStream;
} catch (IOException e) {
throw new RuntimeException("An error occurred while streaming the request", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Streaming the request was interrupted", e);
}
}
private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
return HttpRequest.newBuilder()
.uri(URI.create(url))

View File

@ -0,0 +1,111 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
/**
* A builder class for constructing streaming responses from Gemini AI model.
* This class accumulates partial responses and builds a final response.
*/
class GeminiStreamingResponseBuilder {
private final boolean includeCodeExecutionOutput;
private final StringBuilder contentBuilder;
private final List<ToolExecutionRequest> functionCalls;
private TokenUsage tokenUsage;
private FinishReason finishReason;
/**
* Constructs a new GeminiStreamingResponseBuilder.
*
* @param includeCodeExecutionOutput whether to include code execution output in the response
*/
public GeminiStreamingResponseBuilder(boolean includeCodeExecutionOutput) {
this.includeCodeExecutionOutput = includeCodeExecutionOutput;
this.contentBuilder = new StringBuilder();
this.functionCalls = new ArrayList<>();
}
/**
* Appends a partial response to the builder.
*
* @param partialResponse the partial response from Gemini AI
* @return an Optional containing the text of the partial response, or empty if no valid text
*/
public Optional<String> append(GeminiGenerateContentResponse partialResponse) {
if (partialResponse == null) {
return Optional.empty();
}
GeminiCandidate firstCandidate = partialResponse.getCandidates().get(0);
updateFinishReason(firstCandidate);
updateTokenUsage(partialResponse.getUsageMetadata());
GeminiContent content = firstCandidate.getContent();
if (content == null || content.getParts() == null) {
return Optional.empty();
}
AiMessage message = fromGPartsToAiMessage(content.getParts(), this.includeCodeExecutionOutput);
updateContentAndFunctionCalls(message);
return Optional.ofNullable(message.text());
}
/**
* Builds the final response from all accumulated partial responses.
*
* @return a Response object containing the final AiMessage, token usage, and finish reason
*/
public Response<AiMessage> build() {
AiMessage aiMessage = createAiMessage();
return Response.from(aiMessage, tokenUsage, finishReason);
}
private void updateTokenUsage(GeminiUsageMetadata tokenCounts) {
this.tokenUsage = new TokenUsage(
tokenCounts.getPromptTokenCount(),
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount()
);
}
private void updateFinishReason(GeminiCandidate candidate) {
if (candidate.getFinishReason() != null) {
this.finishReason = fromGFinishReasonToFinishReason(candidate.getFinishReason());
}
}
private void updateContentAndFunctionCalls(AiMessage message) {
Optional.ofNullable(message.text()).ifPresent(contentBuilder::append);
if (message.hasToolExecutionRequests()) {
functionCalls.addAll(message.toolExecutionRequests());
}
}
private AiMessage createAiMessage() {
String text = contentBuilder.toString();
boolean hasText = !text.isEmpty() && !text.isBlank();
boolean hasFunctionCall = !functionCalls.isEmpty();
if (hasText && hasFunctionCall) {
return new AiMessage(text, functionCalls);
} else if (hasText) {
return new AiMessage(text);
} else if (hasFunctionCall) {
return new AiMessage(functionCalls);
}
throw new RuntimeException("Gemini has responded neither with text nor with a function call.");
}
}

View File

@ -1,6 +1,5 @@
package dev.langchain4j.model.googleai;
import com.google.gson.Gson;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
@ -8,11 +7,12 @@ import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
@ -23,135 +23,63 @@ import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;
@Experimental
@Slf4j
public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final GeminiService geminiService;
private final String apiKey;
private final String modelName;
private final Integer maxRetries;
private final Double temperature;
private final Integer topK;
private final Double topP;
private final Integer maxOutputTokens;
private final List<String> stopSequences;
private final Integer candidateCount;
private final ResponseFormat responseFormat;
private final GeminiFunctionCallingConfig toolConfig;
private final boolean allowCodeExecution;
private final boolean includeCodeExecutionOutput;
private final Boolean logRequestsAndResponses;
private final List<GeminiSafetySetting> safetySettings;
private final List<ChatModelListener> listeners;
public class GoogleAiGeminiChatModel extends BaseGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final GoogleAiGeminiTokenizer geminiTokenizer;
@Builder
public GoogleAiGeminiChatModel(String apiKey, String modelName,
Integer maxRetries,
Double temperature, Integer topK, Double topP,
Integer maxOutputTokens, Integer candidateCount,
Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners
public GoogleAiGeminiChatModel(
String apiKey, String modelName,
Integer maxRetries,
Double temperature, Integer topK, Double topP,
Integer maxOutputTokens, Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners
) {
this.apiKey = ensureNotBlank(apiKey, "apiKey");
this.modelName = ensureNotBlank(modelName, "modelName");
this.maxRetries = getOrDefault(maxRetries, 3);
// using Gemini's default values
this.temperature = getOrDefault(temperature, 1.0);
this.topK = getOrDefault(topK, 64);
this.topP = getOrDefault(topP, 0.95);
this.maxOutputTokens = getOrDefault(maxOutputTokens, 8192);
this.candidateCount = getOrDefault(candidateCount, 1);
this.stopSequences = getOrDefault(stopSequences, emptyList());
this.toolConfig = toolConfig;
this.allowCodeExecution = allowCodeExecution != null ? allowCodeExecution : false;
this.includeCodeExecutionOutput = includeCodeExecutionOutput != null ? includeCodeExecutionOutput : false;
this.logRequestsAndResponses = getOrDefault(logRequestsAndResponses, false);
this.safetySettings = copyIfNotNull(safetySettings);
this.responseFormat = responseFormat;
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
this.geminiService = new GeminiService(
getOrDefault(logRequestsAndResponses, false) ? log : null,
getOrDefault(timeout, ofSeconds(60))
);
super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
responseFormat, stopSequences, toolConfig, allowCodeExecution,
includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
listeners, maxRetries);
this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
.modelName(this.modelName)
.apiKey(this.apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.maxRetries(this.maxRetries)
.logRequestsAndResponses(this.logRequestsAndResponses)
.build();
}
private static String computeMimeType(ResponseFormat responseFormat) {
if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) {
return "text/plain";
}
if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
responseFormat.jsonSchema() != null &&
responseFormat.jsonSchema().rootElement() != null &&
responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
return "text/x.enum";
}
return "application/json";
.modelName(this.modelName)
.apiKey(this.apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.maxRetries(this.maxRetries)
.logRequestsAndResponses(getOrDefault(logRequestsAndResponses, false))
.build();
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
ChatRequest request = ChatRequest.builder()
.messages(messages)
.build();
.messages(messages)
.build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
response.tokenUsage(),
response.finishReason());
response.tokenUsage(),
response.finishReason());
}
@Override
@ -162,128 +90,86 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ChatRequest request = ChatRequest.builder()
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
ChatResponse response = chat(request);
return Response.from(response.aiMessage(),
response.tokenUsage(),
response.finishReason());
response.tokenUsage(),
response.finishReason());
}
@Override
public ChatResponse chat(ChatRequest chatRequest) {
GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString());
List<GeminiContent> geminiContentList = fromMessageToGContent(chatRequest.messages(), systemInstruction);
List<ToolSpecification> toolSpecifications = chatRequest.toolSpecifications();
GeminiGenerateContentRequest request = createGenerateContentRequest(
chatRequest.messages(),
chatRequest.toolSpecifications(),
getOrDefault(chatRequest.responseFormat(), this.responseFormat)
);
ResponseFormat format = chatRequest.responseFormat() != null ? chatRequest.responseFormat() : this.responseFormat;
GeminiSchema schema = null;
ChatModelRequest chatModelRequest = createChatModelRequest(
chatRequest.messages(),
chatRequest.toolSpecifications()
);
String responseMimeType = computeMimeType(format);
if (format != null && format.jsonSchema() != null) {
schema = fromJsonSchemaToGSchema(format.jsonSchema());
}
GeminiGenerateContentRequest request = GeminiGenerateContentRequest.builder()
.contents(geminiContentList)
.systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
.generationConfig(GeminiGenerationConfig.builder()
.candidateCount(this.candidateCount)
.maxOutputTokens(this.maxOutputTokens)
.responseMimeType(responseMimeType)
.responseSchema(schema)
.stopSequences(this.stopSequences)
.temperature(this.temperature)
.topK(this.topK)
.topP(this.topP)
.build())
.safetySettings(this.safetySettings)
.tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
.toolConfig(new GeminiToolConfig(this.toolConfig))
.build();
ChatModelRequest chatModelRequest = ChatModelRequest.builder()
.model(modelName)
.temperature(temperature)
.topP(topP)
.maxTokens(maxOutputTokens)
.messages(chatRequest.messages())
.toolSpecifications(chatRequest.toolSpecifications())
.build();
ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
listeners.forEach((listener) -> {
try {
listener.onRequest(chatModelRequestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener (onRequest)", e);
}
});
notifyListenersOnRequest(new ChatModelRequestContext(chatModelRequest, listenerAttributes));
GeminiGenerateContentResponse geminiResponse;
try {
geminiResponse = withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries);
} catch (RuntimeException e) {
ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(
e, chatModelRequest, null, listenerAttributes
GeminiGenerateContentResponse geminiResponse = withRetry(
() -> this.geminiService.generateContent(this.modelName, this.apiKey, request),
this.maxRetries
);
listeners.forEach((listener) -> {
try {
listener.onError(chatModelErrorContext);
} catch (Exception ex) {
log.warn("Exception while calling model listener (onError)", ex);
}
});
return processResponse(geminiResponse, chatModelRequest, listenerAttributes);
} catch (RuntimeException e) {
notifyListenersOnError(e, chatModelRequest, listenerAttributes);
throw e;
}
}
if (geminiResponse != null) {
GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0); //TODO handle n
GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata();
AiMessage aiMessage;
FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
if (firstCandidate.getContent() == null) {
aiMessage = AiMessage.from("No text was returned by the model. " +
"The model finished generating because of the following reason: " + finishReason);
} else {
aiMessage = fromGPartsToAiMessage(firstCandidate.getContent().getParts(), this.includeCodeExecutionOutput);
}
TokenUsage tokenUsage = new TokenUsage(tokenCounts.getPromptTokenCount(),
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount());
ChatModelResponse chatModelResponse = ChatModelResponse.builder()
.model(modelName)
.tokenUsage(tokenUsage)
.finishReason(finishReason)
.aiMessage(aiMessage)
.build();
ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext(
chatModelResponse, chatModelRequest, listenerAttributes);
listeners.forEach((listener) -> {
try {
listener.onResponse(chatModelResponseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener (onResponse)", e);
}
});
return ChatResponse.builder()
.aiMessage(aiMessage)
.finishReason(finishReason)
.tokenUsage(tokenUsage)
.build();
} else {
private ChatResponse processResponse(
GeminiGenerateContentResponse geminiResponse,
ChatModelRequest chatModelRequest,
ConcurrentHashMap<Object, Object> listenerAttributes
) {
if (geminiResponse == null) {
throw new RuntimeException("Gemini response was null");
}
GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0);
GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata();
FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
AiMessage aiMessage = createAiMessage(firstCandidate, finishReason);
TokenUsage tokenUsage = createTokenUsage(tokenCounts);
Response<AiMessage> response = Response.from(aiMessage, tokenUsage, finishReason);
notifyListenersOnResponse(response, chatModelRequest, listenerAttributes);
return ChatResponse.builder()
.aiMessage(aiMessage)
.finishReason(finishReason)
.tokenUsage(tokenUsage)
.build();
}
private AiMessage createAiMessage(GeminiCandidate candidate, FinishReason finishReason) {
if (candidate.getContent() == null) {
return AiMessage.from("No text was returned by the model. " +
"The model finished generating because of the following reason: " + finishReason);
}
return fromGPartsToAiMessage(candidate.getContent().getParts(), this.includeCodeExecutionOutput);
}
private TokenUsage createTokenUsage(GeminiUsageMetadata tokenCounts) {
return new TokenUsage(
tokenCounts.getPromptTokenCount(),
tokenCounts.getCandidatesTokenCount(),
tokenCounts.getTotalTokenCount()
);
}
@Override
@ -309,9 +195,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst
public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> safetySettingMap) {
this.safetySettings = safetySettingMap.entrySet().stream()
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
).collect(Collectors.toList());
.map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
).collect(Collectors.toList());
return this;
}
}
}
}

View File

@ -0,0 +1,93 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import static dev.langchain4j.internal.RetryUtils.withRetry;
@Experimental
@Slf4j
public class GoogleAiGeminiStreamingChatModel extends BaseGeminiChatModel implements StreamingChatLanguageModel {
@Builder
public GoogleAiGeminiStreamingChatModel(
String apiKey, String modelName,
Double temperature, Integer topK, Double topP,
Integer maxOutputTokens, Duration timeout,
ResponseFormat responseFormat,
List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
Boolean logRequestsAndResponses,
List<GeminiSafetySetting> safetySettings,
List<ChatModelListener> listeners,
Integer maxRetries
) {
super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
responseFormat, stopSequences, toolConfig, allowCodeExecution,
includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
listeners, maxRetries);
}
@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
generate(messages, Collections.emptyList(), handler);
}
@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
throw new RuntimeException("This method is not supported: Gemini AI cannot be forced to execute a tool.");
}
@Override
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
GeminiGenerateContentRequest request = createGenerateContentRequest(messages, toolSpecifications, this.responseFormat);
ChatModelRequest chatModelRequest = createChatModelRequest(messages, toolSpecifications);
ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
notifyListenersOnRequest(chatModelRequestContext);
processGenerateContentRequest(request, handler, chatModelRequest, listenerAttributes);
}
private void processGenerateContentRequest(GeminiGenerateContentRequest request, StreamingResponseHandler<AiMessage> handler,
ChatModelRequest chatModelRequest, ConcurrentHashMap<Object, Object> listenerAttributes) {
GeminiStreamingResponseBuilder responseBuilder = new GeminiStreamingResponseBuilder(this.includeCodeExecutionOutput);
try {
Stream<GeminiGenerateContentResponse> contentStream = withRetry(
() -> this.geminiService.generateContentStream(this.modelName, this.apiKey, request),
maxRetries);
contentStream.forEach(partialResponse -> {
Optional<String> text = responseBuilder.append(partialResponse);
text.ifPresent(handler::onNext);
});
Response<AiMessage> fullResponse = responseBuilder.build();
handler.onComplete(fullResponse);
notifyListenersOnResponse(fullResponse, chatModelRequest, listenerAttributes);
} catch (RuntimeException exception) {
notifyListenersOnError(exception, chatModelRequest, listenerAttributes);
handler.onError(exception);
}
}
}

View File

@ -18,12 +18,12 @@ import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
@ -34,7 +34,13 @@ import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.RetryingTest;
import java.time.Duration;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.readBytes;
@ -44,7 +50,9 @@ import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HA
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
public class GoogleAiGeminiChatModelIT {
@ -114,8 +122,8 @@ public class GoogleAiGeminiChatModelIT {
assertThat(jsonText).contains("\"John\"");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(6);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(31);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
}
@Test
@ -435,10 +443,10 @@ public class GoogleAiGeminiChatModelIT {
.jsonSchema(JsonSchema.builder()
.rootElement(JsonObjectSchema.builder()
.addStringProperty("name")
.addProperty("address", JsonObjectSchema.builder()
.addStringProperty("city")
.required("city")
.build())
.addProperty("address", JsonObjectSchema.builder()
.addStringProperty("city")
.required("city")
.build())
.required("name", "address")
.additionalProperties(false)
.build())

View File

@ -0,0 +1,708 @@
package dev.langchain4j.model.googleai;
import com.google.gson.Gson;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.AudioContent;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.TextFileContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.JsonSchemas;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.RetryingTest;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.model.googleai.GeminiHarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT;
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
public class GoogleAiGeminiStreamingChatModelIT {
private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
private static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
private static final String MD_FILE_URL = "https://raw.githubusercontent.com/langchain4j/langchain4j/main/docs/docs/intro.md";
@Test
void should_answer_simple_question() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.from("What is the capital of France?");
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
String text = response.content().text();
assertThat(text).containsIgnoringCase("Paris");
assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
}
@Test
void should_configure_generation() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.temperature(2.0)
.topP(0.5)
.topK(10)
.maxOutputTokens(10)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate("How much is 3+4? Reply with just the answer", handler);
Response<AiMessage> response = handler.get();
// then
String text = response.content().text();
assertThat(text.trim()).isEqualTo("7");
}
@Test
void should_answer_in_json() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-pro")
.responseFormat(ResponseFormat.JSON)
.logRequestsAndResponses(true)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(UserMessage.from("What is the firstname of the John Doe?\n" +
"Reply in JSON following with the following format: {\"firstname\": string}"), handler);
Response<AiMessage> response = handler.get();
// then
String jsonText = response.content().text();
assertThat(jsonText).contains("\"firstname\"");
assertThat(jsonText).contains("\"John\"");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
}
@Test
void should_support_multiturn_chatting() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
List<ChatMessage> messages = new ArrayList<>();
messages.add(UserMessage.from("Hello, my name is Guillaume"));
messages.add(AiMessage.from("Hi, how can I assist you today?"));
messages.add(UserMessage.from("What is my name? Reply with just my name."));
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(messages, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).contains("Guillaume");
}
@Test
void should_support_sending_images_as_base64() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.userMessage(
ImageContent.from(new String(Base64.getEncoder().encode(readBytes(CAT_IMAGE_URL))), "image/png"),
TextContent.from("Describe this image in a single word")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("cat");
}
@Test
void should_support_text_file() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.userMessage(
TextFileContent.from(new String(Base64.getEncoder().encode(readBytes(MD_FILE_URL))), "text/markdown"),
TextContent.from("What project does this markdown file mention?")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("LangChain4j");
}
@Test
void should_support_audio_file() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
UserMessage userMessage = UserMessage.from(
AudioContent.from(
new String(Base64.getEncoder().encode( //TODO use local file
readBytes("https://storage.googleapis.com/cloud-samples-data/generative-ai/audio/pixel.mp3"))),
"audio/mp3"),
TextContent.from("Give a summary of the audio")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("Pixel");
}
@Test
@Disabled
void should_support_video_file() {
// ToDo waiting for the normal GoogleAiGeminiChatModel to implement the test
}
@Test
void should_respect_system_instruction() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.build();
List<ChatMessage> chatMessages = List.of(
SystemMessage.from("Translate from English into French"),
UserMessage.from("apple")
);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(chatMessages, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).containsIgnoringCase("pomme");
}
@Test
void should_execute_python_code() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.allowCodeExecution(true)
.includeCodeExecutionOutput(true)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(
UserMessage.from("Calculate `fibonacci(13)`. Write code in Python and execute it to get the result."),
handler
);
Response<AiMessage> response = handler.get();
// then
String text = response.content().text();
System.out.println("text = " + text);
assertThat(text).containsIgnoringCase("233");
}
@Test
void should_support_JSON_array_in_tools() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
// when
List<ChatMessage> allMessages = new ArrayList<>();
allMessages.add(UserMessage.from("Return a JSON list containing the first 10 fibonacci numbers."));
TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
gemini.generate(
allMessages,
List.of(ToolSpecification.builder()
.name("getFirstNFibonacciNumbers")
.description("Get the first n fibonacci numbers")
.parameters(JsonObjectSchema.builder()
.addNumberProperty("n")
.build())
.build()),
handler1
);
Response<AiMessage> response1 = handler1.get();
// then
assertThat(response1.content().hasToolExecutionRequests()).isTrue();
assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("getFirstNFibonacciNumbers");
assertThat(response1.content().toolExecutionRequests().get(0).arguments()).contains("\"n\":10");
allMessages.add(response1.content());
// when
ToolExecutionResultMessage forecastResult =
ToolExecutionResultMessage.from(null, "getFirstNFibonacciNumbers", "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
allMessages.add(forecastResult);
// then
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
gemini.generate(allMessages, handler2);
Response<AiMessage> response2 = handler2.get();
// then
assertThat(response2.content().text()).contains("[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
}
// Test is flaky, because Gemini doesn't 100% always ask for parallel tool calls
// and sometimes requests more information
@RetryingTest(5)
void should_support_parallel_tool_execution() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
// when
List<ChatMessage> allMessages = new ArrayList<>();
allMessages.add(UserMessage.from("Which warehouse has more stock, ABC or XYZ?"));
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(
allMessages,
List.of(ToolSpecification.builder()
.name("getWarehouseStock")
.description("Retrieve the amount of stock available in a warehouse designated by its name")
.parameters(JsonObjectSchema.builder()
.addStringProperty("name", "The name of the warehouse")
.build())
.build()),
handler
);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().hasToolExecutionRequests()).isTrue();
List<ToolExecutionRequest> executionRequests = response.content().toolExecutionRequests();
assertThat(executionRequests).hasSize(2);
String allArgs = executionRequests.stream()
.map(ToolExecutionRequest::arguments)
.collect(Collectors.joining(" "));
assertThat(allArgs).contains("ABC");
assertThat(allArgs).contains("XYZ");
}
@RetryingTest(5)
void should_support_safety_settings() {
// given
List<GeminiSafetySetting> safetySettings = List.of(
new GeminiSafetySetting(HARM_CATEGORY_HATE_SPEECH, BLOCK_LOW_AND_ABOVE),
new GeminiSafetySetting(HARM_CATEGORY_HARASSMENT, BLOCK_LOW_AND_ABOVE)
);
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.safetySettings(safetySettings)
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate("You're a dumb f*cking idiot bastard!", handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.finishReason()).isEqualTo(FinishReasonMapper.fromGFinishReasonToFinishReason(GeminiFinishReason.SAFETY));
}
@Test
void should_comply_to_response_schema() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonObjectSchema.builder()
.addStringProperty("name")
.addProperty("address", JsonObjectSchema.builder()
.addStringProperty("city")
.required("city")
.build())
.required("name", "address")
.additionalProperties(false)
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to extract information related to a person," +
"like their name, their address, the city the live in."),
UserMessage.from(
"In the town of Liverpool, lived Tommy Skybridge, a young little boy."
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
// then
assertThat(response.content().text().trim())
.isEqualTo("{\"address\": {\"city\": \"Liverpool\"}, \"name\": \"Tommy Skybridge\"}");
}
@Test
void should_handle_enum_type() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonObjectSchema.builder()
.properties(new LinkedHashMap<String, JsonSchemaElement>() {{
put("sentiment", JsonEnumSchema.builder()
.enumValues("POSITIVE", "NEGATIVE")
.build());
}})
.required("sentiment")
.additionalProperties(false)
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to analyze the sentiment of the text you receive."),
UserMessage.from(
"This is super exciting news, congratulations!"
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
// then
assertThat(response.content().text().trim())
.isEqualTo("{\"sentiment\": \"POSITIVE\"}");
}
@Test
void should_handle_enum_output_mode() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonEnumSchema.builder()
.enumValues("POSITIVE", "NEUTRAL", "NEGATIVE")
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to analyze the sentiment of the text you receive."),
UserMessage.from(
"This is super exciting news, congratulations!"
)
), handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text().trim())
.isEqualTo("POSITIVE");
}
@Test
void should_allow_array_as_response_schema() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchema.builder()
.rootElement(JsonArraySchema.builder()
.items(new JsonIntegerSchema())
.build())
.build())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to return a list of 6-faces dice rolls"),
UserMessage.from(
"Give me 3 dice rolls"
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
// then
Integer[] diceRolls = new Gson().fromJson(response.content().text(), Integer[].class);
assertThat(diceRolls.length).isEqualTo(3);
}
private class Color {
private String name;
private int red;
private int green;
private int blue;
private boolean muted;
}
@Test
void should_deserialize_to_POJO() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
.build())
.build();
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
gemini.generate(List.of(
SystemMessage.from(
"Your role is to extract information from the description of a color"),
UserMessage.from(
"Cobalt blue is a blend of a lot of blue, a bit of green, and almost no red."
)
), handler);
Response<AiMessage> response = handler.get();
System.out.println("response = " + response);
Color color = new Gson().fromJson(response.content().text(), Color.class);
// then
assertThat(color.name).isEqualToIgnoringCase("Cobalt blue");
assertThat(color.muted).isFalse();
assertThat(color.red).isLessThanOrEqualTo(color.green);
assertThat(color.green).isLessThanOrEqualTo(color.blue);
}
@Test
void should_support_tool_config() {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.build();
List<ChatMessage> chatMessages = List.of(UserMessage.from("Call toolOne"));
List<ToolSpecification> listOfTools = Arrays.asList(
ToolSpecification.builder().name("toolOne").build(),
ToolSpecification.builder().name("toolTwo").build()
);
// when
TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
gemini.generate(chatMessages, listOfTools, handler1);
Response<AiMessage> response1 = handler1.get();
// then
assertThat(response1.content().hasToolExecutionRequests()).isTrue();
assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("toolOne");
// given
gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-flash")
.logRequestsAndResponses(true)
.toolConfig(new GeminiFunctionCallingConfig(GeminiMode.ANY, List.of("toolTwo")))
.build();
// when
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
gemini.generate("Call toolOne", handler2);
Response<AiMessage> response2 = handler2.get();
// then
assertThat(response2.content().hasToolExecutionRequests()).isFalse();
}
static class Transactions {
@Tool("returns amount of a given transaction")
double getTransactionAmount(@P("ID of a transaction") String id) {
System.out.printf("called getTransactionAmount(%s)%n", id);
switch (id) {
case "T001":
return 11.1;
case "T002":
return 22.2;
default:
throw new IllegalArgumentException("Unknown transaction ID: " + id);
}
}
}
interface StreamingAssistant {
TokenStream chat(String userMessage);
}
@RetryingTest(10)
void should_work_with_tools_with_AiServices() throws ExecutionException, InterruptedException, TimeoutException {
// given
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("gemini-1.5-pro")
.logRequestsAndResponses(true)
.timeout(Duration.ofMinutes(2))
.temperature(0.0)
.topP(0.0)
.topK(1)
.build();
// when
Transactions spyTransactions = spy(new Transactions());
MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20);
StreamingAssistant assistant = AiServices.builder(StreamingAssistant.class)
.tools(spyTransactions)
.chatMemory(chatMemory)
.streamingChatLanguageModel(gemini)
.build();
// then
CompletableFuture<Response<AiMessage>> future1 = new CompletableFuture<>();
assistant.chat("What is the amount of transaction T001?")
.onNext(System.out::println)
.onComplete(future1::complete)
.onError(future1::completeExceptionally)
.start();
Response<AiMessage> response1 = future1.get(30, TimeUnit.SECONDS);
assertThat(response1.content().text()).containsIgnoringCase("11.1");
verify(spyTransactions).getTransactionAmount("T001");
CompletableFuture<Response<AiMessage>> future2 = new CompletableFuture<>();
assistant.chat("What is the amount of transaction T002?")
.onNext(System.out::println)
.onComplete(future2::complete)
.onError(future2::completeExceptionally)
.start();
Response<AiMessage> response2 = future2.get(30, TimeUnit.SECONDS);
assertThat(response2.content().text()).containsIgnoringCase("22.2");
verify(spyTransactions).getTransactionAmount("T002");
verifyNoMoreInteractions(spyTransactions);
}
@AfterEach
void afterEach() throws InterruptedException {
String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
if (ciDelaySeconds != null) {
Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
}
}
}

View File

@ -0,0 +1,61 @@
package dev.langchain4j.model.googleai;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import org.junit.jupiter.api.AfterEach;
import java.io.IOException;
import static java.util.Collections.singletonList;
class GoogleAiGeminiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {
private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
@Override
protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
return GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName(modelName())
.temperature(temperature())
.topP(topP())
.maxOutputTokens(maxTokens())
.listeners(singletonList(listener))
.logRequestsAndResponses(true)
.build();
}
@Override
protected String modelName() {
return "gemini-1.5-flash";
}
@Override
protected boolean assertResponseId() {
return false;
}
@Override
protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
return GoogleAiGeminiStreamingChatModel.builder()
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
.modelName("banana")
.listeners(singletonList(listener))
.logRequestsAndResponses(true)
.build();
}
@Override
protected Class<? extends Exception> expectedExceptionClass() {
return RuntimeException.class;
}
@AfterEach
void afterEach() throws InterruptedException {
String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_VERTEX_AI_GEMINI");
if (ciDelaySeconds != null) {
Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
}
}
}