Adding GoogleAiGeminiStreamingChatModel (#1951)
## Issue

Closes #1903

## Change

Implemented the `GoogleAiGeminiStreamingChatModel`. This PR depends on #1950.

## General checklist

- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
- [x] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
Parent: e98a2e4f62
Commit: aa0e488166
Documentation: the Google AI Gemini page replaces its streaming placeholder with a usage section.

````diff
@@ -83,9 +83,34 @@ ChatLanguageModel gemini = GoogleAiGeminiChatModel.builder()
 ```
 
 ## GoogleAiGeminiStreamingChatModel
 
-No streaming chat model is available yet.
-Please open a feature request if you're interested in a streaming model or if you want to contribute to implementing it.
+The `GoogleAiGeminiStreamingChatModel` allows streaming the text of a response token by token. The response must be managed by a `StreamingResponseHandler`.
+
+```java
+StreamingChatLanguageModel gemini = GoogleAiGeminiStreamingChatModel.builder()
+        .apiKey(System.getenv("GEMINI_AI_KEY"))
+        .modelName("gemini-1.5-flash")
+        .build();
+
+CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
+
+gemini.generate("Tell me a joke about Java", new StreamingResponseHandler<AiMessage>() {
+
+    @Override
+    public void onNext(String token) {
+        System.out.print(token);
+    }
+
+    @Override
+    public void onComplete(Response<AiMessage> response) {
+        futureResponse.complete(response);
+    }
+
+    @Override
+    public void onError(Throwable error) {
+        futureResponse.completeExceptionally(error);
+    }
+});
+
+futureResponse.join();
+```
 
 ## Tools
````
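Beyond the raw handler shown above, the streaming model can also back an `AiServices` proxy: a service method returning `TokenStream` streams through it. A minimal sketch with a hypothetical `Assistant` interface (the new integration test below imports `AiServices` and `TokenStream`, so the pairing is exercised by this PR):

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;

// Hypothetical service interface, for illustration only
interface Assistant {
    TokenStream chat(String userMessage);
}

// 'gemini' is the GoogleAiGeminiStreamingChatModel built as in the docs example
Assistant assistant = AiServices.create(Assistant.class, gemini);

assistant.chat("Tell me a joke about Java")
        .onNext(System.out::print)                    // each partial token
        .onComplete(response -> System.out.println()) // final Response<AiMessage>
        .onError(Throwable::printStackTrace)
        .start();                                     // nothing is sent until start()
```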
Comparison table: the Google AI Gemini row gains a checkmark in its first capability column (streaming, which this PR adds).

```diff
@@ -11,8 +11,8 @@ sidebar_position: 0
 | [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | text, image | ✅ | | | |
 | [ChatGLM](/integrations/language-models/chatglm) | | | | text | | | | |
 | [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | | text, image, audio | ✅ | | | |
 | [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | |
-| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
+| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
 | [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
 | [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | text | | | ✅ | |
 | [Hugging Face](/integrations/language-models/hugging-face) | | | | text | | | | |
```
New file: `BaseGeminiChatModel.java` (`@@ -0,0 +1,192 @@`), a shared base class holding the configuration, request building, and listener notification common to the synchronous and streaming models:

```java
package dev.langchain4j.model.googleai;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;

@Experimental
@Slf4j
abstract class BaseGeminiChatModel {
    protected final GeminiService geminiService;
    protected final String apiKey;
    protected final String modelName;
    protected final Double temperature;
    protected final Integer topK;
    protected final Double topP;
    protected final Integer maxOutputTokens;
    protected final List<String> stopSequences;
    protected final ResponseFormat responseFormat;
    protected final GeminiFunctionCallingConfig toolConfig;
    protected final boolean allowCodeExecution;
    protected final boolean includeCodeExecutionOutput;
    protected final List<GeminiSafetySetting> safetySettings;
    protected final List<ChatModelListener> listeners;
    protected final Integer maxRetries;

    protected BaseGeminiChatModel(
        String apiKey,
        String modelName,
        Double temperature,
        Integer topK,
        Double topP,
        Integer maxOutputTokens,
        Duration timeout,
        ResponseFormat responseFormat,
        List<String> stopSequences,
        GeminiFunctionCallingConfig toolConfig,
        Boolean allowCodeExecution,
        Boolean includeCodeExecutionOutput,
        Boolean logRequestsAndResponses,
        List<GeminiSafetySetting> safetySettings,
        List<ChatModelListener> listeners,
        Integer maxRetries
    ) {
        this.apiKey = ensureNotBlank(apiKey, "apiKey");
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.temperature = temperature;
        this.topK = topK;
        this.topP = topP;
        this.maxOutputTokens = maxOutputTokens;
        this.stopSequences = getOrDefault(stopSequences, emptyList());
        this.toolConfig = toolConfig;
        this.allowCodeExecution = getOrDefault(allowCodeExecution, false);
        this.includeCodeExecutionOutput = getOrDefault(includeCodeExecutionOutput, false);
        this.safetySettings = copyIfNotNull(safetySettings);
        this.responseFormat = responseFormat;
        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.geminiService = new GeminiService(
            getOrDefault(logRequestsAndResponses, false) ? log : null,
            getOrDefault(timeout, ofSeconds(60))
        );
    }

    protected GeminiGenerateContentRequest createGenerateContentRequest(
        List<ChatMessage> messages,
        List<ToolSpecification> toolSpecifications,
        ResponseFormat responseFormat
    ) {
        GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString());
        List<GeminiContent> geminiContentList = fromMessageToGContent(messages, systemInstruction);

        GeminiSchema schema = null;
        if (responseFormat != null && responseFormat.jsonSchema() != null) {
            schema = fromJsonSchemaToGSchema(responseFormat.jsonSchema());
        }

        return GeminiGenerateContentRequest.builder()
            .contents(geminiContentList)
            .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
            .generationConfig(GeminiGenerationConfig.builder()
                .candidateCount(1) // Multiple candidates aren't supported by langchain4j
                .maxOutputTokens(this.maxOutputTokens)
                .responseMimeType(computeMimeType(responseFormat))
                .responseSchema(schema)
                .stopSequences(this.stopSequences)
                .temperature(this.temperature)
                .topK(this.topK)
                .topP(this.topP)
                .build())
            .safetySettings(this.safetySettings)
            .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
            .toolConfig(new GeminiToolConfig(this.toolConfig))
            .build();
    }

    protected ChatModelRequest createChatModelRequest(
        List<ChatMessage> messages,
        List<ToolSpecification> toolSpecifications
    ) {
        return ChatModelRequest.builder()
            .model(modelName)
            .temperature(temperature)
            .topP(topP)
            .maxTokens(maxOutputTokens)
            .messages(messages)
            .toolSpecifications(toolSpecifications)
            .build();
    }

    protected static String computeMimeType(ResponseFormat responseFormat) {
        if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) {
            return "text/plain";
        }

        if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
            responseFormat.jsonSchema() != null &&
            responseFormat.jsonSchema().rootElement() != null &&
            responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
            return "text/x.enum";
        }

        return "application/json";
    }

    protected void notifyListenersOnRequest(ChatModelRequestContext context) {
        listeners.forEach((listener) -> {
            try {
                listener.onRequest(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onRequest)", e);
            }
        });
    }

    protected void notifyListenersOnResponse(Response<AiMessage> response, ChatModelRequest request,
                                             ConcurrentHashMap<Object, Object> attributes) {
        ChatModelResponse chatModelResponse = ChatModelResponse.builder()
            .model(modelName)
            .tokenUsage(response.tokenUsage())
            .finishReason(response.finishReason())
            .aiMessage(response.content())
            .build();
        ChatModelResponseContext context = new ChatModelResponseContext(
            chatModelResponse, request, attributes);
        listeners.forEach((listener) -> {
            try {
                listener.onResponse(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onResponse)", e);
            }
        });
    }

    protected void notifyListenersOnError(Exception exception, ChatModelRequest request,
                                          ConcurrentHashMap<Object, Object> attributes) {
        listeners.forEach((listener) -> {
            try {
                ChatModelErrorContext context = new ChatModelErrorContext(
                    exception, request, null, attributes);
                listener.onError(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onError)", e);
            }
        });
    }
}
```
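One detail worth highlighting: `computeMimeType` is what switches Gemini between plain-text, JSON, and enum output modes. A minimal illustration of the mapping (`ResponseFormat.JSON` is exercised by the new tests; `ResponseFormat.TEXT` is assumed to be the matching text constant):

```java
// Sketch: expected results of BaseGeminiChatModel.computeMimeType(...)
computeMimeType(null);                 // "text/plain" (no format requested)
computeMimeType(ResponseFormat.TEXT);  // "text/plain"
computeMimeType(ResponseFormat.JSON);  // "application/json" (plain JSON mode)

// A JSON ResponseFormat whose schema root element is a JsonEnumSchema maps
// to "text/x.enum", Gemini's enum output mode, instead of "application/json".
```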
Changed file: `GeminiService.java`, which gains streaming support via Gemini's `streamGenerateContent` SSE endpoint:

```diff
@@ -10,6 +10,8 @@ import java.net.http.HttpClient;
 import java.net.http.HttpRequest;
 import java.net.http.HttpResponse;
 import java.time.Duration;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 class GeminiService {
     private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";
@@ -48,6 +50,11 @@ class GeminiService {
         return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
     }
 
+    Stream<GeminiGenerateContentResponse> generateContentStream(String modelName, String apiKey, GeminiGenerateContentRequest request) {
+        String url = String.format("%s/models/%s:streamGenerateContent?alt=sse", GEMINI_AI_ENDPOINT, modelName);
+        return streamRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
+    }
+
     private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
         String jsonBody = gson.toJson(requestBody);
         HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);
@@ -72,6 +79,40 @@ class GeminiService {
         }
     }
 
+    private <T> Stream<T> streamRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
+        String jsonBody = gson.toJson(requestBody);
+        HttpRequest httpRequest = buildHttpRequest(url, apiKey, jsonBody);
+
+        logRequest(jsonBody);
+
+        try {
+            HttpResponse<Stream<String>> httpResponse = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofLines());
+
+            if (httpResponse.statusCode() >= 300) {
+                String errorBody = httpResponse.body()
+                    .collect(Collectors.joining("\n"));
+
+                throw new RuntimeException(String.format("HTTP error (%d): %s", httpResponse.statusCode(), errorBody));
+            }
+
+            Stream<T> responseStream = httpResponse.body()
+                .filter(line -> line.startsWith("data: "))
+                .map(line -> line.substring(6)) // Remove "data: " prefix
+                .map(jsonString -> gson.fromJson(jsonString, responseType));
+
+            if (logger != null) {
+                responseStream = responseStream.peek(response -> logger.debug("Partial response from Gemini:\n{}", response));
+            }
+
+            return responseStream;
+        } catch (IOException e) {
+            throw new RuntimeException("An error occurred while streaming the request", e);
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException("Streaming the request was interrupted", e);
+        }
+    }
+
     private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
         return HttpRequest.newBuilder()
             .uri(URI.create(url))
```
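With `?alt=sse`, the endpoint answers with server-sent events: one JSON payload per `data: ` line, with blank lines separating events. A self-contained sketch of the same line handling, using made-up payloads and an illustrative `SsePayload` class (the real code deserializes into `GeminiGenerateContentResponse`):

```java
import com.google.gson.Gson;
import java.util.stream.Stream;

public class SseParsingSketch {

    // Illustrative payload type; stands in for GeminiGenerateContentResponse
    static class SsePayload {
        String text;
    }

    public static void main(String[] args) {
        Gson gson = new Gson();

        // Made-up SSE body: blank lines separate events and carry no data
        Stream<String> lines = Stream.of(
                "data: {\"text\":\"Hello\"}",
                "",
                "data: {\"text\":\" world\"}");

        lines.filter(line -> line.startsWith("data: "))
                .map(line -> line.substring(6))                      // strip the "data: " prefix
                .map(json -> gson.fromJson(json, SsePayload.class))
                .forEach(payload -> System.out.print(payload.text)); // prints "Hello world"
    }
}
```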
New file: `GeminiStreamingResponseBuilder.java` (`@@ -0,0 +1,111 @@`):

```java
package dev.langchain4j.model.googleai;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;

/**
 * A builder class for constructing streaming responses from the Gemini AI model.
 * This class accumulates partial responses and builds a final response.
 */
class GeminiStreamingResponseBuilder {
    private final boolean includeCodeExecutionOutput;
    private final StringBuilder contentBuilder;
    private final List<ToolExecutionRequest> functionCalls;
    private TokenUsage tokenUsage;
    private FinishReason finishReason;

    /**
     * Constructs a new GeminiStreamingResponseBuilder.
     *
     * @param includeCodeExecutionOutput whether to include code execution output in the response
     */
    public GeminiStreamingResponseBuilder(boolean includeCodeExecutionOutput) {
        this.includeCodeExecutionOutput = includeCodeExecutionOutput;
        this.contentBuilder = new StringBuilder();
        this.functionCalls = new ArrayList<>();
    }

    /**
     * Appends a partial response to the builder.
     *
     * @param partialResponse the partial response from Gemini AI
     * @return an Optional containing the text of the partial response, or empty if there is no valid text
     */
    public Optional<String> append(GeminiGenerateContentResponse partialResponse) {
        if (partialResponse == null) {
            return Optional.empty();
        }

        GeminiCandidate firstCandidate = partialResponse.getCandidates().get(0);

        updateFinishReason(firstCandidate);
        updateTokenUsage(partialResponse.getUsageMetadata());

        GeminiContent content = firstCandidate.getContent();
        if (content == null || content.getParts() == null) {
            return Optional.empty();
        }

        AiMessage message = fromGPartsToAiMessage(content.getParts(), this.includeCodeExecutionOutput);
        updateContentAndFunctionCalls(message);

        return Optional.ofNullable(message.text());
    }

    /**
     * Builds the final response from all accumulated partial responses.
     *
     * @return a Response object containing the final AiMessage, token usage, and finish reason
     */
    public Response<AiMessage> build() {
        AiMessage aiMessage = createAiMessage();
        return Response.from(aiMessage, tokenUsage, finishReason);
    }

    private void updateTokenUsage(GeminiUsageMetadata tokenCounts) {
        this.tokenUsage = new TokenUsage(
            tokenCounts.getPromptTokenCount(),
            tokenCounts.getCandidatesTokenCount(),
            tokenCounts.getTotalTokenCount()
        );
    }

    private void updateFinishReason(GeminiCandidate candidate) {
        if (candidate.getFinishReason() != null) {
            this.finishReason = fromGFinishReasonToFinishReason(candidate.getFinishReason());
        }
    }

    private void updateContentAndFunctionCalls(AiMessage message) {
        Optional.ofNullable(message.text()).ifPresent(contentBuilder::append);
        if (message.hasToolExecutionRequests()) {
            functionCalls.addAll(message.toolExecutionRequests());
        }
    }

    private AiMessage createAiMessage() {
        String text = contentBuilder.toString();
        boolean hasText = !text.isEmpty() && !text.isBlank();
        boolean hasFunctionCall = !functionCalls.isEmpty();

        if (hasText && hasFunctionCall) {
            return new AiMessage(text, functionCalls);
        } else if (hasText) {
            return new AiMessage(text);
        } else if (hasFunctionCall) {
            return new AiMessage(functionCalls);
        }

        throw new RuntimeException("Gemini has responded neither with text nor with a function call.");
    }
}
```
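The intended call pattern: `append` each partial response, forward whatever text it yields, then `build` once the stream is exhausted. A condensed sketch of that loop, mirroring `GoogleAiGeminiStreamingChatModel.processGenerateContentRequest` shown further down:

```java
// contentStream: Stream<GeminiGenerateContentResponse> from GeminiService;
// handler: the caller's StreamingResponseHandler<AiMessage>
GeminiStreamingResponseBuilder responseBuilder = new GeminiStreamingResponseBuilder(false);

contentStream.forEach(partialResponse ->
        responseBuilder.append(partialResponse)   // accumulate text, tool calls, usage
                .ifPresent(handler::onNext));     // stream each non-empty text chunk

Response<AiMessage> fullResponse = responseBuilder.build();
handler.onComplete(fullResponse);
```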
Changed file: `GoogleAiGeminiChatModel.java`. The model now extends `BaseGeminiChatModel`, delegating configuration, request building, and listener notification to the shared base class.

Import changes:

```diff
@@ -1,6 +1,5 @@
 package dev.langchain4j.model.googleai;
 
-import com.google.gson.Gson;
 import dev.langchain4j.Experimental;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
@@ -8,11 +7,12 @@ import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.chat.Capability;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
-import dev.langchain4j.model.chat.listener.*;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
+import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
 import dev.langchain4j.model.chat.request.ChatRequest;
 import dev.langchain4j.model.chat.request.ResponseFormat;
 import dev.langchain4j.model.chat.request.ResponseFormatType;
-import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
@@ -23,135 +23,63 @@ import lombok.extern.slf4j.Slf4j;
 import java.time.Duration;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
-import java.util.ArrayList;
 import java.util.Map;
 import java.util.Set;
-import java.util.HashSet;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 import static dev.langchain4j.internal.RetryUtils.withRetry;
-import static dev.langchain4j.internal.Utils.copyIfNotNull;
 import static dev.langchain4j.internal.Utils.getOrDefault;
-import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
 import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
-import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
 import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
 import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
-import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
 import static java.time.Duration.ofSeconds;
-import static java.util.Collections.emptyList;
```

The class body loses the fields duplicated in the base class, the Gemini default values previously applied in the constructor (temperature 1.0, topK 64, topP 0.95, maxOutputTokens 8192), the `candidateCount` option (the base class now hard-codes a single candidate, since langchain4j doesn't support multiple candidates), the private `computeMimeType` helper, and the inline listener-notification loops. The new state of the changed sections (hunks `@@ -23,135 +23,63 @@` and `@@ -162,128 +90,86 @@`):

```java
@Experimental
@Slf4j
public class GoogleAiGeminiChatModel extends BaseGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {

    private final GoogleAiGeminiTokenizer geminiTokenizer;

    @Builder
    public GoogleAiGeminiChatModel(
            String apiKey, String modelName,
            Integer maxRetries,
            Double temperature, Integer topK, Double topP,
            Integer maxOutputTokens, Duration timeout,
            ResponseFormat responseFormat,
            List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
            Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
            Boolean logRequestsAndResponses,
            List<GeminiSafetySetting> safetySettings,
            List<ChatModelListener> listeners
    ) {
        super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
                responseFormat, stopSequences, toolConfig, allowCodeExecution,
                includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
                listeners, maxRetries);

        this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
                .modelName(this.modelName)
                .apiKey(this.apiKey)
                .timeout(getOrDefault(timeout, ofSeconds(60)))
                .maxRetries(this.maxRetries)
                .logRequestsAndResponses(getOrDefault(logRequestsAndResponses, false))
                .build();
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        ChatRequest request = ChatRequest.builder()
                .messages(messages)
                .build();

        ChatResponse response = chat(request);

        return Response.from(response.aiMessage(),
                response.tokenUsage(),
                response.finishReason());
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        ChatRequest request = ChatRequest.builder()
                .messages(messages)
                .toolSpecifications(toolSpecifications)
                .build();

        ChatResponse response = chat(request);

        return Response.from(response.aiMessage(),
                response.tokenUsage(),
                response.finishReason());
    }

    @Override
    public ChatResponse chat(ChatRequest chatRequest) {
        GeminiGenerateContentRequest request = createGenerateContentRequest(
                chatRequest.messages(),
                chatRequest.toolSpecifications(),
                getOrDefault(chatRequest.responseFormat(), this.responseFormat)
        );

        ChatModelRequest chatModelRequest = createChatModelRequest(
                chatRequest.messages(),
                chatRequest.toolSpecifications()
        );

        ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
        notifyListenersOnRequest(new ChatModelRequestContext(chatModelRequest, listenerAttributes));

        try {
            GeminiGenerateContentResponse geminiResponse = withRetry(
                    () -> this.geminiService.generateContent(this.modelName, this.apiKey, request),
                    this.maxRetries
            );

            return processResponse(geminiResponse, chatModelRequest, listenerAttributes);
        } catch (RuntimeException e) {
            notifyListenersOnError(e, chatModelRequest, listenerAttributes);
            throw e;
        }
    }

    private ChatResponse processResponse(
            GeminiGenerateContentResponse geminiResponse,
            ChatModelRequest chatModelRequest,
            ConcurrentHashMap<Object, Object> listenerAttributes
    ) {
        if (geminiResponse == null) {
            throw new RuntimeException("Gemini response was null");
        }

        GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0);
        GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata();

        FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
        AiMessage aiMessage = createAiMessage(firstCandidate, finishReason);
        TokenUsage tokenUsage = createTokenUsage(tokenCounts);

        Response<AiMessage> response = Response.from(aiMessage, tokenUsage, finishReason);
        notifyListenersOnResponse(response, chatModelRequest, listenerAttributes);

        return ChatResponse.builder()
                .aiMessage(aiMessage)
                .finishReason(finishReason)
                .tokenUsage(tokenUsage)
                .build();
    }

    private AiMessage createAiMessage(GeminiCandidate candidate, FinishReason finishReason) {
        if (candidate.getContent() == null) {
            return AiMessage.from("No text was returned by the model. " +
                    "The model finished generating because of the following reason: " + finishReason);
        }
        return fromGPartsToAiMessage(candidate.getContent().getParts(), this.includeCodeExecutionOutput);
    }

    private TokenUsage createTokenUsage(GeminiUsageMetadata tokenCounts) {
        return new TokenUsage(
                tokenCounts.getPromptTokenCount(),
                tokenCounts.getCandidatesTokenCount(),
                tokenCounts.getTotalTokenCount()
        );
    }
}
```

The custom `safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold>)` method on the generated builder (hunk `@@ -309,9 +195,9 @@`) appears only as unchanged context.
New file: `GoogleAiGeminiStreamingChatModel.java` (`@@ -0,0 +1,93 @@`):

```java
package dev.langchain4j.model.googleai;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;

import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

import static dev.langchain4j.internal.RetryUtils.withRetry;

@Experimental
@Slf4j
public class GoogleAiGeminiStreamingChatModel extends BaseGeminiChatModel implements StreamingChatLanguageModel {

    @Builder
    public GoogleAiGeminiStreamingChatModel(
            String apiKey, String modelName,
            Double temperature, Integer topK, Double topP,
            Integer maxOutputTokens, Duration timeout,
            ResponseFormat responseFormat,
            List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
            Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
            Boolean logRequestsAndResponses,
            List<GeminiSafetySetting> safetySettings,
            List<ChatModelListener> listeners,
            Integer maxRetries
    ) {
        super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
                responseFormat, stopSequences, toolConfig, allowCodeExecution,
                includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
                listeners, maxRetries);
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        generate(messages, Collections.emptyList(), handler);
    }

    @Override
    public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
        throw new RuntimeException("This method is not supported: Gemini AI cannot be forced to execute a tool.");
    }

    @Override
    public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
        GeminiGenerateContentRequest request = createGenerateContentRequest(messages, toolSpecifications, this.responseFormat);
        ChatModelRequest chatModelRequest = createChatModelRequest(messages, toolSpecifications);

        ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
        ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
        notifyListenersOnRequest(chatModelRequestContext);

        processGenerateContentRequest(request, handler, chatModelRequest, listenerAttributes);
    }

    private void processGenerateContentRequest(GeminiGenerateContentRequest request, StreamingResponseHandler<AiMessage> handler,
                                               ChatModelRequest chatModelRequest, ConcurrentHashMap<Object, Object> listenerAttributes) {
        GeminiStreamingResponseBuilder responseBuilder = new GeminiStreamingResponseBuilder(this.includeCodeExecutionOutput);

        try {
            Stream<GeminiGenerateContentResponse> contentStream = withRetry(
                    () -> this.geminiService.generateContentStream(this.modelName, this.apiKey, request),
                    maxRetries);

            contentStream.forEach(partialResponse -> {
                Optional<String> text = responseBuilder.append(partialResponse);
                text.ifPresent(handler::onNext);
            });

            Response<AiMessage> fullResponse = responseBuilder.build();
            handler.onComplete(fullResponse);

            notifyListenersOnResponse(fullResponse, chatModelRequest, listenerAttributes);
        } catch (RuntimeException exception) {
            notifyListenersOnError(exception, chatModelRequest, listenerAttributes);
            handler.onError(exception);
        }
    }
}
```
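Note that tool calls are never emitted through `onNext`: the response builder accumulates them, so they surface only on the final `AiMessage` passed to `onComplete`. A sketch of how a handler picks them up (wiring as in the docs example; the printing is illustrative):

```java
@Override
public void onComplete(Response<AiMessage> response) {
    AiMessage message = response.content();
    if (message.hasToolExecutionRequests()) {
        // Tool calls aggregated across all partial responses arrive here
        message.toolExecutionRequests().forEach(request ->
                System.out.println(request.name() + "(" + request.arguments() + ")"));
    }
}
```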
Changed file: `GoogleAiGeminiChatModelIT.java`. Wildcard imports are expanded and the expected token counts are updated:

```diff
@@ -18,12 +18,12 @@ import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.memory.chat.MessageWindowChatMemory;
 import dev.langchain4j.model.chat.request.ChatRequest;
 import dev.langchain4j.model.chat.request.ResponseFormat;
-import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
-import dev.langchain4j.model.chat.request.json.JsonSchema;
 import dev.langchain4j.model.chat.request.json.JsonArraySchema;
-import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
-import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
 import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
+import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
+import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
+import dev.langchain4j.model.chat.request.json.JsonSchema;
+import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.output.FinishReason;
 import dev.langchain4j.model.output.Response;
@@ -34,7 +34,13 @@ import org.junit.jupiter.api.Test;
 import org.junitpioneer.jupiter.RetryingTest;
 
 import java.time.Duration;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 import static dev.langchain4j.internal.Utils.readBytes;
@@ -44,7 +50,9 @@ import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HA
 import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
 import static java.util.Collections.singletonList;
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 public class GoogleAiGeminiChatModelIT {
 
@@ -114,8 +122,8 @@ public class GoogleAiGeminiChatModelIT {
         assertThat(jsonText).contains("\"John\"");
 
         assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
-        assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(6);
-        assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(31);
+        assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
+        assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
     }
 
     @Test
@@ -435,10 +443,10 @@ public class GoogleAiGeminiChatModelIT {
             .jsonSchema(JsonSchema.builder()
                 .rootElement(JsonObjectSchema.builder()
                     .addStringProperty("name")
                     .addProperty("address", JsonObjectSchema.builder()
                         .addStringProperty("city")
                         .required("city")
                         .build())
                     .required("name", "address")
                     .additionalProperties(false)
                     .build())
```
@ -0,0 +1,708 @@
|
||||||
|
package dev.langchain4j.model.googleai;
|
||||||
|
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import dev.langchain4j.agent.tool.P;
|
||||||
|
import dev.langchain4j.agent.tool.Tool;
|
||||||
|
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||||
|
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.data.message.AudioContent;
|
||||||
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
import dev.langchain4j.data.message.ImageContent;
|
||||||
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
|
import dev.langchain4j.data.message.TextContent;
|
||||||
|
import dev.langchain4j.data.message.TextFileContent;
|
||||||
|
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||||
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
|
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||||
|
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||||
|
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||||
|
import dev.langchain4j.model.output.FinishReason;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import dev.langchain4j.service.AiServices;
|
||||||
|
import dev.langchain4j.service.TokenStream;
|
||||||
|
import dev.langchain4j.service.output.JsonSchemas;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junitpioneer.jupiter.RetryingTest;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.TimeoutException;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.readBytes;
|
||||||
|
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
|
||||||
|
import static dev.langchain4j.model.googleai.GeminiHarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
|
||||||
|
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT;
|
||||||
|
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
|
||||||
|
public class GoogleAiGeminiStreamingChatModelIT {
|
||||||
|
|
||||||
|
private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
|
||||||
|
|
||||||
|
private static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
|
||||||
|
private static final String MD_FILE_URL = "https://raw.githubusercontent.com/langchain4j/langchain4j/main/docs/docs/intro.md";

    @Test
    void should_answer_simple_question() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .build();

        UserMessage userMessage = UserMessage.from("What is the capital of France?");

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(userMessage, handler);
        Response<AiMessage> response = handler.get();

        // then
        String text = response.content().text();
        assertThat(text).containsIgnoringCase("Paris");

        assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
    }

    @Test
    void should_configure_generation() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .temperature(2.0)
                .topP(0.5)
                .topK(10)
                .maxOutputTokens(10)
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate("How much is 3+4? Reply with just the answer", handler);
        Response<AiMessage> response = handler.get();

        // then
        String text = response.content().text();
        assertThat(text.trim()).isEqualTo("7");
    }

    @Test
    void should_answer_in_json() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-pro")
                .responseFormat(ResponseFormat.JSON)
                .logRequestsAndResponses(true)
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(UserMessage.from("What is the firstname of the John Doe?\n" +
                "Reply in JSON following with the following format: {\"firstname\": string}"), handler);
        Response<AiMessage> response = handler.get();

        // then
        String jsonText = response.content().text();

        assertThat(jsonText).contains("\"firstname\"");
        assertThat(jsonText).contains("\"John\"");
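
        // note: exact token counts depend on the model version and its tokenizer;
        // these values were observed with gemini-1.5-pro when the test was written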
        assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
        assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
        assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
    }

    @Test
    void should_support_multiturn_chatting() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .build();

        List<ChatMessage> messages = new ArrayList<>();
        messages.add(UserMessage.from("Hello, my name is Guillaume"));
        messages.add(AiMessage.from("Hi, how can I assist you today?"));
        messages.add(UserMessage.from("What is my name? Reply with just my name."));

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(messages, handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text()).contains("Guillaume");
    }

    @Test
    void should_support_sending_images_as_base64() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .build();
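
        // the image bytes are downloaded and sent inline, base64-encoded, with an explicit MIME type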
        UserMessage userMessage = UserMessage.userMessage(
                ImageContent.from(new String(Base64.getEncoder().encode(readBytes(CAT_IMAGE_URL))), "image/png"),
                TextContent.from("Describe this image in a single word")
        );

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(userMessage, handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text()).containsIgnoringCase("cat");
    }

    @Test
    void should_support_text_file() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .build();

        UserMessage userMessage = UserMessage.userMessage(
                TextFileContent.from(new String(Base64.getEncoder().encode(readBytes(MD_FILE_URL))), "text/markdown"),
                TextContent.from("What project does this markdown file mention?")
        );

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(userMessage, handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text()).containsIgnoringCase("LangChain4j");
    }

    @Test
    void should_support_audio_file() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .build();

        UserMessage userMessage = UserMessage.from(
                AudioContent.from(
                        new String(Base64.getEncoder().encode( // TODO: use a local file instead of a remote one
                                readBytes("https://storage.googleapis.com/cloud-samples-data/generative-ai/audio/pixel.mp3"))),
                        "audio/mp3"),
                TextContent.from("Give a summary of the audio")
        );

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(userMessage, handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text()).containsIgnoringCase("Pixel");
    }

    @Test
    @Disabled
    void should_support_video_file() {
        // TODO: implement once the corresponding test exists in GoogleAiGeminiChatModel
    }

    @Test
    void should_respect_system_instruction() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .build();

        List<ChatMessage> chatMessages = List.of(
                SystemMessage.from("Translate from English into French"),
                UserMessage.from("apple")
        );

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(chatMessages, handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text()).containsIgnoringCase("pomme");
    }

    @Test
    void should_execute_python_code() {
        // given
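        // allowCodeExecution turns on Gemini's server-side code-execution tool;
        // includeCodeExecutionOutput adds the generated code and its output to the response text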
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .allowCodeExecution(true)
                .includeCodeExecutionOutput(true)
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(
                UserMessage.from("Calculate `fibonacci(13)`. Write code in Python and execute it to get the result."),
                handler
        );
        Response<AiMessage> response = handler.get();

        // then
        String text = response.content().text();
        System.out.println("text = " + text);

        assertThat(text).containsIgnoringCase("233");
    }

    @Test
    void should_support_JSON_array_in_tools() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .build();

        // when
        List<ChatMessage> allMessages = new ArrayList<>();
        allMessages.add(UserMessage.from("Return a JSON list containing the first 10 fibonacci numbers."));

        TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
        gemini.generate(
                allMessages,
                List.of(ToolSpecification.builder()
                        .name("getFirstNFibonacciNumbers")
                        .description("Get the first n fibonacci numbers")
                        .parameters(JsonObjectSchema.builder()
                                .addNumberProperty("n")
                                .build())
                        .build()),
                handler1
        );
        Response<AiMessage> response1 = handler1.get();

        // then
        assertThat(response1.content().hasToolExecutionRequests()).isTrue();
        assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("getFirstNFibonacciNumbers");
        assertThat(response1.content().toolExecutionRequests().get(0).arguments()).contains("\"n\":10");

        allMessages.add(response1.content());

        // when
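        // the id is null because the Gemini API does not return tool-call ids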
        ToolExecutionResultMessage fibonacciResult =
                ToolExecutionResultMessage.from(null, "getFirstNFibonacciNumbers", "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
        allMessages.add(fibonacciResult);

        TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
        gemini.generate(allMessages, handler2);
        Response<AiMessage> response2 = handler2.get();

        // then
        assertThat(response2.content().text()).contains("[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
    }

    // Test is flaky, because Gemini doesn't always ask for parallel tool calls
    // and sometimes requests more information instead
    @RetryingTest(5)
    void should_support_parallel_tool_execution() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .build();

        // when
        List<ChatMessage> allMessages = new ArrayList<>();
        allMessages.add(UserMessage.from("Which warehouse has more stock, ABC or XYZ?"));

        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(
                allMessages,
                List.of(ToolSpecification.builder()
                        .name("getWarehouseStock")
                        .description("Retrieve the amount of stock available in a warehouse designated by its name")
                        .parameters(JsonObjectSchema.builder()
                                .addStringProperty("name", "The name of the warehouse")
                                .build())
                        .build()),
                handler
        );
        Response<AiMessage> response = handler.get();

        // then
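        // the model is expected to request both warehouse lookups in one response (parallel function calling)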
        assertThat(response.content().hasToolExecutionRequests()).isTrue();

        List<ToolExecutionRequest> executionRequests = response.content().toolExecutionRequests();
        assertThat(executionRequests).hasSize(2);

        String allArgs = executionRequests.stream()
                .map(ToolExecutionRequest::arguments)
                .collect(Collectors.joining(" "));
        assertThat(allArgs).contains("ABC");
        assertThat(allArgs).contains("XYZ");
    }

    @RetryingTest(5)
    void should_support_safety_settings() {
        // given
        List<GeminiSafetySetting> safetySettings = List.of(
                new GeminiSafetySetting(HARM_CATEGORY_HATE_SPEECH, BLOCK_LOW_AND_ABOVE),
                new GeminiSafetySetting(HARM_CATEGORY_HARASSMENT, BLOCK_LOW_AND_ABOVE)
        );

        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .safetySettings(safetySettings)
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate("You're a dumb f*cking idiot bastard!", handler);
        Response<AiMessage> response = handler.get();

        // then
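        // content blocked by the safety settings surfaces as a SAFETY finish reason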
        assertThat(response.finishReason()).isEqualTo(FinishReasonMapper.fromGFinishReasonToFinishReason(GeminiFinishReason.SAFETY));
    }

    @Test
    void should_comply_to_response_schema() {
        // given
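        // a responseFormat with a JSON schema constrains Gemini's output to match the schema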
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonObjectSchema.builder()
                                        .addStringProperty("name")
                                        .addProperty("address", JsonObjectSchema.builder()
                                                .addStringProperty("city")
                                                .required("city")
                                                .build())
                                        .required("name", "address")
                                        .additionalProperties(false)
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to extract information related to a person, " +
                                "like their name, their address, and the city they live in."),
                UserMessage.from(
                        "In the town of Liverpool, lived Tommy Skybridge, a young little boy."
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        // then
        assertThat(response.content().text().trim())
                .isEqualTo("{\"address\": {\"city\": \"Liverpool\"}, \"name\": \"Tommy Skybridge\"}");
    }

    @Test
    void should_handle_enum_type() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonObjectSchema.builder()
                                        .properties(new LinkedHashMap<String, JsonSchemaElement>() {{
                                            put("sentiment", JsonEnumSchema.builder()
                                                    .enumValues("POSITIVE", "NEGATIVE")
                                                    .build());
                                        }})
                                        .required("sentiment")
                                        .additionalProperties(false)
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to analyze the sentiment of the text you receive."),
                UserMessage.from(
                        "This is super exciting news, congratulations!"
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        // then
        assertThat(response.content().text().trim())
                .isEqualTo("{\"sentiment\": \"POSITIVE\"}");
    }

    @Test
    void should_handle_enum_output_mode() {
        // given
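        // with an enum as the schema root, the model replies with a bare enum value (no JSON wrapping)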
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonEnumSchema.builder()
                                        .enumValues("POSITIVE", "NEUTRAL", "NEGATIVE")
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to analyze the sentiment of the text you receive."),
                UserMessage.from(
                        "This is super exciting news, congratulations!"
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text().trim())
                .isEqualTo("POSITIVE");
    }

    @Test
    void should_allow_array_as_response_schema() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonArraySchema.builder()
                                        .items(new JsonIntegerSchema())
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to return a list of 6-sided dice rolls"),
                UserMessage.from(
                        "Give me 3 dice rolls"
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        // then
        Integer[] diceRolls = new Gson().fromJson(response.content().text(), Integer[].class);
        assertThat(diceRolls.length).isEqualTo(3);
    }

    private class Color {
        private String name;
        private int red;
        private int green;
        private int blue;
        private boolean muted;
    }

    @Test
    void should_deserialize_to_POJO() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to extract information from the description of a color"),
                UserMessage.from(
                        "Cobalt blue is a blend of a lot of blue, a bit of green, and almost no red."
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        Color color = new Gson().fromJson(response.content().text(), Color.class);

        // then
        assertThat(color.name).isEqualToIgnoringCase("Cobalt blue");
        assertThat(color.muted).isFalse();
        assertThat(color.red).isLessThanOrEqualTo(color.green);
        assertThat(color.green).isLessThanOrEqualTo(color.blue);
    }

    @Test
    void should_support_tool_config() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .build();

        List<ChatMessage> chatMessages = List.of(UserMessage.from("Call toolOne"));
        List<ToolSpecification> listOfTools = Arrays.asList(
                ToolSpecification.builder().name("toolOne").build(),
                ToolSpecification.builder().name("toolTwo").build()
        );

        // when
        TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
        gemini.generate(chatMessages, listOfTools, handler1);
        Response<AiMessage> response1 = handler1.get();

        // then
        assertThat(response1.content().hasToolExecutionRequests()).isTrue();
        assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("toolOne");

        // given
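        // GeminiMode.ANY forces a function call, but the allow-list contains only toolTwo,
        // so asking for toolOne should produce no tool execution request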
        gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .toolConfig(new GeminiFunctionCallingConfig(GeminiMode.ANY, List.of("toolTwo")))
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
        gemini.generate("Call toolOne", handler2);
        Response<AiMessage> response2 = handler2.get();

        // then
        assertThat(response2.content().hasToolExecutionRequests()).isFalse();
    }

    static class Transactions {
        @Tool("returns amount of a given transaction")
        double getTransactionAmount(@P("ID of a transaction") String id) {
            System.out.printf("called getTransactionAmount(%s)%n", id);
            switch (id) {
                case "T001":
                    return 11.1;
                case "T002":
                    return 22.2;
                default:
                    throw new IllegalArgumentException("Unknown transaction ID: " + id);
            }
        }
    }

    interface StreamingAssistant {
        TokenStream chat(String userMessage);
    }
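
    // TokenStream is the streaming return type for AiServices: onNext() receives tokens,
    // onComplete() delivers the final Response, and start() kicks off the request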

    @RetryingTest(10)
    void should_work_with_tools_with_AiServices() throws ExecutionException, InterruptedException, TimeoutException {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-pro")
                .logRequestsAndResponses(true)
                .timeout(Duration.ofMinutes(2))
                .temperature(0.0)
                .topP(0.0)
                .topK(1)
                .build();

        // when
        Transactions spyTransactions = spy(new Transactions());

        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20);
        StreamingAssistant assistant = AiServices.builder(StreamingAssistant.class)
                .tools(spyTransactions)
                .chatMemory(chatMemory)
                .streamingChatLanguageModel(gemini)
                .build();

        // then
        CompletableFuture<Response<AiMessage>> future1 = new CompletableFuture<>();
        assistant.chat("What is the amount of transaction T001?")
                .onNext(System.out::println)
                .onComplete(future1::complete)
                .onError(future1::completeExceptionally)
                .start();
        Response<AiMessage> response1 = future1.get(30, TimeUnit.SECONDS);

        assertThat(response1.content().text()).containsIgnoringCase("11.1");
        verify(spyTransactions).getTransactionAmount("T001");

        CompletableFuture<Response<AiMessage>> future2 = new CompletableFuture<>();
        assistant.chat("What is the amount of transaction T002?")
                .onNext(System.out::println)
                .onComplete(future2::complete)
                .onError(future2::completeExceptionally)
                .start();
        Response<AiMessage> response2 = future2.get(30, TimeUnit.SECONDS);

        assertThat(response2.content().text()).containsIgnoringCase("22.2");
        verify(spyTransactions).getTransactionAmount("T002");

        verifyNoMoreInteractions(spyTransactions);
    }

    @AfterEach
    void afterEach() throws InterruptedException {
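        // optional pause between tests to stay under Google AI rate limits on CI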
        String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
        if (ciDelaySeconds != null) {
            Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
        }
    }
}

@ -0,0 +1,61 @@
package dev.langchain4j.model.googleai;

import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import org.junit.jupiter.api.AfterEach;

import java.io.IOException;

import static java.util.Collections.singletonList;

class GoogleAiGeminiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {

    private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
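
    // the shared streaming-listener contract tests are inherited from StreamingChatModelListenerIT;
    // this class only supplies the Gemini-specific model setup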

    @Override
    protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
        return GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName(modelName())
                .temperature(temperature())
                .topP(topP())
                .maxOutputTokens(maxTokens())
                .listeners(singletonList(listener))
                .logRequestsAndResponses(true)
                .build();
    }

    @Override
    protected String modelName() {
        return "gemini-1.5-flash";
    }

    @Override
    protected boolean assertResponseId() {
        return false;
    }

    @Override
    protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
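        // "banana" is an intentionally invalid model name, used to exercise the listener's error path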
        return GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("banana")
                .listeners(singletonList(listener))
                .logRequestsAndResponses(true)
                .build();
    }

    @Override
    protected Class<? extends Exception> expectedExceptionClass() {
        return RuntimeException.class;
    }

    @AfterEach
    void afterEach() throws InterruptedException {
        String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
        if (ciDelaySeconds != null) {
            Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
        }
    }
}