Adding GoogleAiGeminiStreamingChatModel (#1951)
## Issue

Closes #1903

## Change

Implemented the `GoogleAiGeminiStreamingChatModel`. This PR depends on #1950.

## General checklist

- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
- [x] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent e98a2e4f62
commit aa0e488166
@@ -83,9 +83,34 @@ ChatLanguageModel gemini = GoogleAiGeminiChatModel.builder()
```

## GoogleAiGeminiStreamingChatModel

The `GoogleAiGeminiStreamingChatModel` allows streaming the text of a response token by token. The response must be handled by a `StreamingResponseHandler`.

```java
StreamingChatLanguageModel gemini = GoogleAiGeminiStreamingChatModel.builder()
    .apiKey(System.getenv("GEMINI_AI_KEY"))
    .modelName("gemini-1.5-flash")
    .build();

CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

gemini.generate("Tell me a joke about Java", new StreamingResponseHandler<AiMessage>() {
    @Override
    public void onNext(String token) {
        System.out.print(token);
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        futureResponse.complete(response);
    }

    @Override
    public void onError(Throwable error) {
        futureResponse.completeExceptionally(error);
    }
});

futureResponse.join();
```
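
The same streaming model also plugs into AI Services; the following is a minimal sketch using a `TokenStream` (the `Assistant` interface and prompt here are illustrative, not part of this PR):

```java
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;

interface Assistant {
    TokenStream chat(String message);
}

Assistant assistant = AiServices.create(Assistant.class, gemini);

assistant.chat("Tell me a joke about Java")
    .onNext(System.out::print)                    // called for each streamed token
    .onComplete(response -> System.out.println()) // called once with the full Response<AiMessage>
    .onError(Throwable::printStackTrace)
    .start();
```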

## Tools
@@ -11,8 +11,8 @@ sidebar_position: 0
| [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | text, image | ✅ | | | |
| [ChatGLM](/integrations/language-models/chatglm) | | | | text | | | | |
| [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | | text, image, audio | ✅ | | | |
| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | |
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | |
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | text | | | ✅ | |
| [Hugging Face](/integrations/language-models/hugging-face) | | | | text | | | | |
@@ -0,0 +1,192 @@
package dev.langchain4j.model.googleai;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;

@Experimental
@Slf4j
abstract class BaseGeminiChatModel {
    protected final GeminiService geminiService;
    protected final String apiKey;
    protected final String modelName;
    protected final Double temperature;
    protected final Integer topK;
    protected final Double topP;
    protected final Integer maxOutputTokens;
    protected final List<String> stopSequences;
    protected final ResponseFormat responseFormat;
    protected final GeminiFunctionCallingConfig toolConfig;
    protected final boolean allowCodeExecution;
    protected final boolean includeCodeExecutionOutput;
    protected final List<GeminiSafetySetting> safetySettings;
    protected final List<ChatModelListener> listeners;
    protected final Integer maxRetries;

    protected BaseGeminiChatModel(
        String apiKey,
        String modelName,
        Double temperature,
        Integer topK,
        Double topP,
        Integer maxOutputTokens,
        Duration timeout,
        ResponseFormat responseFormat,
        List<String> stopSequences,
        GeminiFunctionCallingConfig toolConfig,
        Boolean allowCodeExecution,
        Boolean includeCodeExecutionOutput,
        Boolean logRequestsAndResponses,
        List<GeminiSafetySetting> safetySettings,
        List<ChatModelListener> listeners,
        Integer maxRetries
    ) {
        this.apiKey = ensureNotBlank(apiKey, "apiKey");
        this.modelName = ensureNotBlank(modelName, "modelName");
        this.temperature = temperature;
        this.topK = topK;
        this.topP = topP;
        this.maxOutputTokens = maxOutputTokens;
        this.stopSequences = getOrDefault(stopSequences, emptyList());
        this.toolConfig = toolConfig;
        this.allowCodeExecution = getOrDefault(allowCodeExecution, false);
        this.includeCodeExecutionOutput = getOrDefault(includeCodeExecutionOutput, false);
        this.safetySettings = copyIfNotNull(safetySettings);
        this.responseFormat = responseFormat;
        this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.geminiService = new GeminiService(
            getOrDefault(logRequestsAndResponses, false) ? log : null,
            getOrDefault(timeout, ofSeconds(60))
        );
    }

    protected GeminiGenerateContentRequest createGenerateContentRequest(
        List<ChatMessage> messages,
        List<ToolSpecification> toolSpecifications,
        ResponseFormat responseFormat
    ) {
        GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString());
        List<GeminiContent> geminiContentList = fromMessageToGContent(messages, systemInstruction);

        GeminiSchema schema = null;
        if (responseFormat != null && responseFormat.jsonSchema() != null) {
            schema = fromJsonSchemaToGSchema(responseFormat.jsonSchema());
        }

        return GeminiGenerateContentRequest.builder()
            .contents(geminiContentList)
            .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null)
            .generationConfig(GeminiGenerationConfig.builder()
                .candidateCount(1) // Multiple candidates aren't supported by langchain4j
                .maxOutputTokens(this.maxOutputTokens)
                .responseMimeType(computeMimeType(responseFormat))
                .responseSchema(schema)
                .stopSequences(this.stopSequences)
                .temperature(this.temperature)
                .topK(this.topK)
                .topP(this.topP)
                .build())
            .safetySettings(this.safetySettings)
            .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution))
            .toolConfig(new GeminiToolConfig(this.toolConfig))
            .build();
    }

    protected ChatModelRequest createChatModelRequest(
        List<ChatMessage> messages,
        List<ToolSpecification> toolSpecifications
    ) {
        return ChatModelRequest.builder()
            .model(modelName)
            .temperature(temperature)
            .topP(topP)
            .maxTokens(maxOutputTokens)
            .messages(messages)
            .toolSpecifications(toolSpecifications)
            .build();
    }

    protected static String computeMimeType(ResponseFormat responseFormat) {
        if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) {
            return "text/plain";
        }

        if (ResponseFormatType.JSON.equals(responseFormat.type()) &&
            responseFormat.jsonSchema() != null &&
            responseFormat.jsonSchema().rootElement() != null &&
            responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) {
            return "text/x.enum";
        }

        return "application/json";
    }

    protected void notifyListenersOnRequest(ChatModelRequestContext context) {
        listeners.forEach((listener) -> {
            try {
                listener.onRequest(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onRequest)", e);
            }
        });
    }

    protected void notifyListenersOnResponse(Response<AiMessage> response, ChatModelRequest request,
                                             ConcurrentHashMap<Object, Object> attributes) {
        ChatModelResponse chatModelResponse = ChatModelResponse.builder()
            .model(modelName)
            .tokenUsage(response.tokenUsage())
            .finishReason(response.finishReason())
            .aiMessage(response.content())
            .build();
        ChatModelResponseContext context = new ChatModelResponseContext(
            chatModelResponse, request, attributes);
        listeners.forEach((listener) -> {
            try {
                listener.onResponse(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onResponse)", e);
            }
        });
    }

    protected void notifyListenersOnError(Exception exception, ChatModelRequest request,
                                          ConcurrentHashMap<Object, Object> attributes) {
        listeners.forEach((listener) -> {
            try {
                ChatModelErrorContext context = new ChatModelErrorContext(
                    exception, request, null, attributes);
                listener.onError(context);
            } catch (Exception e) {
                log.warn("Exception while calling model listener (onError)", e);
            }
        });
    }
}
@@ -10,6 +10,8 @@ import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class GeminiService {
    private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta";

@@ -48,6 +50,11 @@
        return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class);
    }

    Stream<GeminiGenerateContentResponse> generateContentStream(String modelName, String apiKey, GeminiGenerateContentRequest request) {
        String url = String.format("%s/models/%s:streamGenerateContent?alt=sse", GEMINI_AI_ENDPOINT, modelName);
        return streamRequest(url, apiKey, request, GeminiGenerateContentResponse.class);
    }

    private <T> T sendRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
        String jsonBody = gson.toJson(requestBody);
        HttpRequest request = buildHttpRequest(url, apiKey, jsonBody);

@@ -72,6 +79,40 @@
        }
    }

    private <T> Stream<T> streamRequest(String url, String apiKey, Object requestBody, Class<T> responseType) {
        String jsonBody = gson.toJson(requestBody);
        HttpRequest httpRequest = buildHttpRequest(url, apiKey, jsonBody);

        logRequest(jsonBody);

        try {
            HttpResponse<Stream<String>> httpResponse = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofLines());

            if (httpResponse.statusCode() >= 300) {
                String errorBody = httpResponse.body()
                    .collect(Collectors.joining("\n"));

                throw new RuntimeException(String.format("HTTP error (%d): %s", httpResponse.statusCode(), errorBody));
            }

            Stream<T> responseStream = httpResponse.body()
                .filter(line -> line.startsWith("data: "))
                .map(line -> line.substring(6)) // Remove "data: " prefix
                .map(jsonString -> gson.fromJson(jsonString, responseType));

            if (logger != null) {
                responseStream = responseStream.peek(response -> logger.debug("Partial response from Gemini:\n{}", response));
            }

            return responseStream;
        } catch (IOException e) {
            throw new RuntimeException("An error occurred while streaming the request", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Streaming the request was interrupted", e);
        }
    }

    private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) {
        return HttpRequest.newBuilder()
            .uri(URI.create(url))
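
Note: the `alt=sse` variant of `streamGenerateContent` replies with Server-Sent Events, one `data: {json}` line per chunk, with blank keep-alive lines in between. A minimal, self-contained sketch of the line handling that `streamRequest` performs (the payloads below are hypothetical):

```java
import java.util.List;
import java.util.stream.Stream;

class SseLineHandlingSketch {

    public static void main(String[] args) {
        // Hypothetical chunks, shaped like Gemini's streamGenerateContent SSE output
        List<String> lines = List.of(
                "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hel\"}]}}]}",
                "",
                "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"lo\"}]}}]}"
        );

        Stream<String> jsonChunks = lines.stream()
                .filter(line -> line.startsWith("data: "))       // keep only event payload lines
                .map(line -> line.substring("data: ".length())); // strip the SSE prefix

        jsonChunks.forEach(System.out::println); // each chunk would then be deserialized with Gson
    }
}
```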
@@ -0,0 +1,111 @@
package dev.langchain4j.model.googleai;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;

/**
 * A builder class for constructing streaming responses from Gemini AI model.
 * This class accumulates partial responses and builds a final response.
 */
class GeminiStreamingResponseBuilder {
    private final boolean includeCodeExecutionOutput;
    private final StringBuilder contentBuilder;
    private final List<ToolExecutionRequest> functionCalls;
    private TokenUsage tokenUsage;
    private FinishReason finishReason;

    /**
     * Constructs a new GeminiStreamingResponseBuilder.
     *
     * @param includeCodeExecutionOutput whether to include code execution output in the response
     */
    public GeminiStreamingResponseBuilder(boolean includeCodeExecutionOutput) {
        this.includeCodeExecutionOutput = includeCodeExecutionOutput;
        this.contentBuilder = new StringBuilder();
        this.functionCalls = new ArrayList<>();
    }

    /**
     * Appends a partial response to the builder.
     *
     * @param partialResponse the partial response from Gemini AI
     * @return an Optional containing the text of the partial response, or empty if no valid text
     */
    public Optional<String> append(GeminiGenerateContentResponse partialResponse) {
        if (partialResponse == null) {
            return Optional.empty();
        }

        GeminiCandidate firstCandidate = partialResponse.getCandidates().get(0);

        updateFinishReason(firstCandidate);
        updateTokenUsage(partialResponse.getUsageMetadata());

        GeminiContent content = firstCandidate.getContent();
        if (content == null || content.getParts() == null) {
            return Optional.empty();
        }

        AiMessage message = fromGPartsToAiMessage(content.getParts(), this.includeCodeExecutionOutput);
        updateContentAndFunctionCalls(message);

        return Optional.ofNullable(message.text());
    }

    /**
     * Builds the final response from all accumulated partial responses.
     *
     * @return a Response object containing the final AiMessage, token usage, and finish reason
     */
    public Response<AiMessage> build() {
        AiMessage aiMessage = createAiMessage();
        return Response.from(aiMessage, tokenUsage, finishReason);
    }

    private void updateTokenUsage(GeminiUsageMetadata tokenCounts) {
        this.tokenUsage = new TokenUsage(
            tokenCounts.getPromptTokenCount(),
            tokenCounts.getCandidatesTokenCount(),
            tokenCounts.getTotalTokenCount()
        );
    }

    private void updateFinishReason(GeminiCandidate candidate) {
        if (candidate.getFinishReason() != null) {
            this.finishReason = fromGFinishReasonToFinishReason(candidate.getFinishReason());
        }
    }

    private void updateContentAndFunctionCalls(AiMessage message) {
        Optional.ofNullable(message.text()).ifPresent(contentBuilder::append);
        if (message.hasToolExecutionRequests()) {
            functionCalls.addAll(message.toolExecutionRequests());
        }
    }

    private AiMessage createAiMessage() {
        String text = contentBuilder.toString();
        boolean hasText = !text.isEmpty() && !text.isBlank();
        boolean hasFunctionCall = !functionCalls.isEmpty();

        if (hasText && hasFunctionCall) {
            return new AiMessage(text, functionCalls);
        } else if (hasText) {
            return new AiMessage(text);
        } else if (hasFunctionCall) {
            return new AiMessage(functionCalls);
        }

        throw new RuntimeException("Gemini has responded neither with text nor with a function call.");
    }
}
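
For orientation, a sketch of how a caller is expected to drive this builder; the `drain` helper, `chunks`, and `handler` names are illustrative, and the real wiring lives in `GoogleAiGeminiStreamingChatModel` further down in this diff:

```java
// Illustrative helper, not part of this PR: accumulate chunks, surface each chunk's text,
// then build the final response once the stream is exhausted.
static Response<AiMessage> drain(Stream<GeminiGenerateContentResponse> chunks,
                                 StreamingResponseHandler<AiMessage> handler) {
    GeminiStreamingResponseBuilder builder = new GeminiStreamingResponseBuilder(false);
    chunks.forEach(partial -> builder.append(partial) // tracks text, tool calls, usage, finish reason
            .ifPresent(handler::onNext));             // forwards any text in this chunk
    return builder.build();                           // final AiMessage + TokenUsage + FinishReason
}
```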
@@ -1,6 +1,5 @@
package dev.langchain4j.model.googleai;

import com.google.gson.Gson;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
@@ -8,11 +7,12 @@ import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
@@ -23,135 +23,63 @@ import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent;
import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason;
import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage;
import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.emptyList;

@Experimental
@Slf4j
public class GoogleAiGeminiChatModel extends BaseGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
    private final GoogleAiGeminiTokenizer geminiTokenizer;

    @Builder
    public GoogleAiGeminiChatModel(
        String apiKey, String modelName,
        Integer maxRetries,
        Double temperature, Integer topK, Double topP,
        Integer maxOutputTokens, Duration timeout,
        ResponseFormat responseFormat,
        List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
        Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
        Boolean logRequestsAndResponses,
        List<GeminiSafetySetting> safetySettings,
        List<ChatModelListener> listeners
    ) {
        super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
            responseFormat, stopSequences, toolConfig, allowCodeExecution,
            includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
            listeners, maxRetries);

        this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
            .modelName(this.modelName)
            .apiKey(this.apiKey)
            .timeout(getOrDefault(timeout, ofSeconds(60)))
            .maxRetries(this.maxRetries)
            .logRequestsAndResponses(getOrDefault(logRequestsAndResponses, false))
            .build();
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        ChatRequest request = ChatRequest.builder()
            .messages(messages)
            .build();

        ChatResponse response = chat(request);

        return Response.from(response.aiMessage(),
            response.tokenUsage(),
            response.finishReason());
    }

    @Override
@@ -162,128 +90,86 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        ChatRequest request = ChatRequest.builder()
            .messages(messages)
            .toolSpecifications(toolSpecifications)
            .build();

        ChatResponse response = chat(request);

        return Response.from(response.aiMessage(),
            response.tokenUsage(),
            response.finishReason());
    }

    @Override
    public ChatResponse chat(ChatRequest chatRequest) {
        GeminiGenerateContentRequest request = createGenerateContentRequest(
            chatRequest.messages(),
            chatRequest.toolSpecifications(),
            getOrDefault(chatRequest.responseFormat(), this.responseFormat)
        );

        ChatModelRequest chatModelRequest = createChatModelRequest(
            chatRequest.messages(),
            chatRequest.toolSpecifications()
        );

        ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
        notifyListenersOnRequest(new ChatModelRequestContext(chatModelRequest, listenerAttributes));

        try {
            GeminiGenerateContentResponse geminiResponse = withRetry(
                () -> this.geminiService.generateContent(this.modelName, this.apiKey, request),
                this.maxRetries
            );

            return processResponse(geminiResponse, chatModelRequest, listenerAttributes);
        } catch (RuntimeException e) {
            notifyListenersOnError(e, chatModelRequest, listenerAttributes);
            throw e;
        }
    }

    private ChatResponse processResponse(
        GeminiGenerateContentResponse geminiResponse,
        ChatModelRequest chatModelRequest,
        ConcurrentHashMap<Object, Object> listenerAttributes
    ) {
        if (geminiResponse == null) {
            throw new RuntimeException("Gemini response was null");
        }

        GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0);
        GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata();

        FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
        AiMessage aiMessage = createAiMessage(firstCandidate, finishReason);
        TokenUsage tokenUsage = createTokenUsage(tokenCounts);

        Response<AiMessage> response = Response.from(aiMessage, tokenUsage, finishReason);
        notifyListenersOnResponse(response, chatModelRequest, listenerAttributes);

        return ChatResponse.builder()
            .aiMessage(aiMessage)
            .finishReason(finishReason)
            .tokenUsage(tokenUsage)
            .build();
    }

    private AiMessage createAiMessage(GeminiCandidate candidate, FinishReason finishReason) {
        if (candidate.getContent() == null) {
            return AiMessage.from("No text was returned by the model. " +
                "The model finished generating because of the following reason: " + finishReason);
        }
        return fromGPartsToAiMessage(candidate.getContent().getParts(), this.includeCodeExecutionOutput);
    }

    private TokenUsage createTokenUsage(GeminiUsageMetadata tokenCounts) {
        return new TokenUsage(
            tokenCounts.getPromptTokenCount(),
            tokenCounts.getCandidatesTokenCount(),
            tokenCounts.getTotalTokenCount()
        );
    }

    @Override
@@ -309,8 +195,8 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {

    public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> safetySettingMap) {
        this.safetySettings = safetySettingMap.entrySet().stream()
            .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue())
            ).collect(Collectors.toList());
        return this;
    }
}
@@ -0,0 +1,93 @@
package dev.langchain4j.model.googleai;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;

import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

import static dev.langchain4j.internal.RetryUtils.withRetry;

@Experimental
@Slf4j
public class GoogleAiGeminiStreamingChatModel extends BaseGeminiChatModel implements StreamingChatLanguageModel {
    @Builder
    public GoogleAiGeminiStreamingChatModel(
        String apiKey, String modelName,
        Double temperature, Integer topK, Double topP,
        Integer maxOutputTokens, Duration timeout,
        ResponseFormat responseFormat,
        List<String> stopSequences, GeminiFunctionCallingConfig toolConfig,
        Boolean allowCodeExecution, Boolean includeCodeExecutionOutput,
        Boolean logRequestsAndResponses,
        List<GeminiSafetySetting> safetySettings,
        List<ChatModelListener> listeners,
        Integer maxRetries
    ) {
        super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout,
            responseFormat, stopSequences, toolConfig, allowCodeExecution,
            includeCodeExecutionOutput, logRequestsAndResponses, safetySettings,
            listeners, maxRetries);
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        generate(messages, Collections.emptyList(), handler);
    }

    @Override
    public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
        throw new RuntimeException("This method is not supported: Gemini AI cannot be forced to execute a tool.");
    }

    @Override
    public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
        GeminiGenerateContentRequest request = createGenerateContentRequest(messages, toolSpecifications, this.responseFormat);
        ChatModelRequest chatModelRequest = createChatModelRequest(messages, toolSpecifications);

        ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
        ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes);
        notifyListenersOnRequest(chatModelRequestContext);

        processGenerateContentRequest(request, handler, chatModelRequest, listenerAttributes);
    }

    private void processGenerateContentRequest(GeminiGenerateContentRequest request, StreamingResponseHandler<AiMessage> handler,
                                               ChatModelRequest chatModelRequest, ConcurrentHashMap<Object, Object> listenerAttributes) {
        GeminiStreamingResponseBuilder responseBuilder = new GeminiStreamingResponseBuilder(this.includeCodeExecutionOutput);

        try {
            Stream<GeminiGenerateContentResponse> contentStream = withRetry(
                () -> this.geminiService.generateContentStream(this.modelName, this.apiKey, request),
                maxRetries);

            contentStream.forEach(partialResponse -> {
                Optional<String> text = responseBuilder.append(partialResponse);
                text.ifPresent(handler::onNext);
            });

            Response<AiMessage> fullResponse = responseBuilder.build();
            handler.onComplete(fullResponse);

            notifyListenersOnResponse(fullResponse, chatModelRequest, listenerAttributes);
        } catch (RuntimeException exception) {
            notifyListenersOnError(exception, chatModelRequest, listenerAttributes);
            handler.onError(exception);
        }
    }
}
@@ -18,12 +18,12 @@ import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;

@@ -34,7 +34,13 @@ import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.RetryingTest;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.Utils.readBytes;

@@ -44,7 +50,9 @@ import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT;
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

public class GoogleAiGeminiChatModelIT {

@@ -114,8 +122,8 @@ public class GoogleAiGeminiChatModelIT {
        assertThat(jsonText).contains("\"John\"");

        assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
        assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
        assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
    }

    @Test

@@ -435,10 +443,10 @@
            .jsonSchema(JsonSchema.builder()
                .rootElement(JsonObjectSchema.builder()
                    .addStringProperty("name")
                    .addProperty("address", JsonObjectSchema.builder()
                        .addStringProperty("city")
                        .required("city")
                        .build())
                    .required("name", "address")
                    .additionalProperties(false)
                    .build())
@ -0,0 +1,708 @@
|
|||
package dev.langchain4j.model.googleai;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import dev.langchain4j.agent.tool.P;
|
||||
import dev.langchain4j.agent.tool.Tool;
|
||||
import dev.langchain4j.agent.tool.ToolExecutionRequest;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.AudioContent;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ImageContent;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.TextContent;
|
||||
import dev.langchain4j.data.message.TextFileContent;
|
||||
import dev.langchain4j.data.message.ToolExecutionResultMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||
import dev.langchain4j.model.output.FinishReason;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import dev.langchain4j.service.TokenStream;
|
||||
import dev.langchain4j.service.output.JsonSchemas;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junitpioneer.jupiter.RetryingTest;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.readBytes;
|
||||
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
|
||||
import static dev.langchain4j.model.googleai.GeminiHarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
|
||||
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT;
|
||||
import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
||||
public class GoogleAiGeminiStreamingChatModelIT {
|
||||
|
||||
private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");
|
||||
|
||||
private static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png";
|
||||
private static final String MD_FILE_URL = "https://raw.githubusercontent.com/langchain4j/langchain4j/main/docs/docs/intro.md";
|
||||
|
||||
@Test
|
||||
void should_answer_simple_question() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.from("What is the capital of France?");
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
String text = response.content().text();
|
||||
assertThat(text).containsIgnoringCase("Paris");
|
||||
|
||||
assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_configure_generation() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.temperature(2.0)
|
||||
.topP(0.5)
|
||||
.topK(10)
|
||||
.maxOutputTokens(10)
|
||||
.build();
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate("How much is 3+4? Reply with just the answer", handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
String text = response.content().text();
|
||||
assertThat(text.trim()).isEqualTo("7");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_answer_in_json() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-pro")
|
||||
.responseFormat(ResponseFormat.JSON)
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(UserMessage.from("What is the firstname of the John Doe?\n" +
|
||||
"Reply in JSON following with the following format: {\"firstname\": string}"), handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
String jsonText = response.content().text();
|
||||
|
||||
assertThat(jsonText).contains("\"firstname\"");
|
||||
assertThat(jsonText).contains("\"John\"");
|
||||
|
||||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25);
|
||||
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7);
|
||||
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_support_multiturn_chatting() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.build();
|
||||
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
messages.add(UserMessage.from("Hello, my name is Guillaume"));
|
||||
messages.add(AiMessage.from("Hi, how can I assist you today?"));
|
||||
messages.add(UserMessage.from("What is my name? Reply with just my name."));
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(messages, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).contains("Guillaume");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_support_sending_images_as_base64() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.userMessage(
|
||||
ImageContent.from(new String(Base64.getEncoder().encode(readBytes(CAT_IMAGE_URL))), "image/png"),
|
||||
TextContent.from("Describe this image in a single word")
|
||||
);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).containsIgnoringCase("cat");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_support_text_file() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.userMessage(
|
||||
TextFileContent.from(new String(Base64.getEncoder().encode(readBytes(MD_FILE_URL))), "text/markdown"),
|
||||
TextContent.from("What project does this markdown file mention?")
|
||||
);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).containsIgnoringCase("LangChain4j");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_support_audio_file() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.from(
|
||||
AudioContent.from(
|
||||
new String(Base64.getEncoder().encode( //TODO use local file
|
||||
readBytes("https://storage.googleapis.com/cloud-samples-data/generative-ai/audio/pixel.mp3"))),
|
||||
"audio/mp3"),
|
||||
TextContent.from("Give a summary of the audio")
|
||||
);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).containsIgnoringCase("Pixel");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void should_support_video_file() {
|
||||
// ToDo waiting for the normal GoogleAiGeminiChatModel to implement the test
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_respect_system_instruction() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.build();
|
||||
|
||||
List<ChatMessage> chatMessages = List.of(
|
||||
SystemMessage.from("Translate from English into French"),
|
||||
UserMessage.from("apple")
|
||||
);
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(chatMessages, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).containsIgnoringCase("pomme");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_execute_python_code() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.logRequestsAndResponses(true)
|
||||
.allowCodeExecution(true)
|
||||
.includeCodeExecutionOutput(true)
|
||||
.build();
|
||||
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(
|
||||
UserMessage.from("Calculate `fibonacci(13)`. Write code in Python and execute it to get the result."),
|
||||
handler
|
||||
);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
String text = response.content().text();
|
||||
System.out.println("text = " + text);
|
||||
|
||||
assertThat(text).containsIgnoringCase("233");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_support_JSON_array_in_tools() {
|
||||
// given
|
||||
GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
|
||||
.apiKey(GOOGLE_AI_GEMINI_API_KEY)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.logRequestsAndResponses(true)
|
||||
.build();
|
||||
|
||||
// when
|
||||
List<ChatMessage> allMessages = new ArrayList<>();
|
||||
allMessages.add(UserMessage.from("Return a JSON list containing the first 10 fibonacci numbers."));
|
||||
|
||||
|
||||
TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(
|
||||
allMessages,
|
||||
List.of(ToolSpecification.builder()
|
||||
.name("getFirstNFibonacciNumbers")
|
||||
.description("Get the first n fibonacci numbers")
|
||||
.parameters(JsonObjectSchema.builder()
|
||||
.addNumberProperty("n")
|
||||
.build())
|
||||
.build()),
|
||||
handler1
|
||||
);
|
||||
Response<AiMessage> response1 = handler1.get();
|
||||
|
||||
// then
|
||||
assertThat(response1.content().hasToolExecutionRequests()).isTrue();
|
||||
assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("getFirstNFibonacciNumbers");
|
||||
assertThat(response1.content().toolExecutionRequests().get(0).arguments()).contains("\"n\":10");
|
||||
|
||||
allMessages.add(response1.content());
|
||||
|
||||
// when
|
||||
ToolExecutionResultMessage forecastResult =
|
||||
ToolExecutionResultMessage.from(null, "getFirstNFibonacciNumbers", "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
|
||||
allMessages.add(forecastResult);
|
||||
|
||||
// then
|
||||
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
|
||||
gemini.generate(allMessages, handler2);
|
||||
Response<AiMessage> response2 = handler2.get();
|
||||
|
||||
// then
|
||||
assertThat(response2.content().text()).contains("[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]");
|
||||
}
|
||||
|
||||
    // Test is flaky: Gemini doesn't always ask for parallel tool calls
    // and sometimes requests more information instead
    @RetryingTest(5)
    void should_support_parallel_tool_execution() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .build();

        // when
        List<ChatMessage> allMessages = new ArrayList<>();
        allMessages.add(UserMessage.from("Which warehouse has more stock, ABC or XYZ?"));

        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(
                allMessages,
                List.of(ToolSpecification.builder()
                        .name("getWarehouseStock")
                        .description("Retrieve the amount of stock available in a warehouse designated by its name")
                        .parameters(JsonObjectSchema.builder()
                                .addStringProperty("name", "The name of the warehouse")
                                .build())
                        .build()),
                handler
        );
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().hasToolExecutionRequests()).isTrue();

        List<ToolExecutionRequest> executionRequests = response.content().toolExecutionRequests();
        assertThat(executionRequests).hasSize(2);

        String allArgs = executionRequests.stream()
                .map(ToolExecutionRequest::arguments)
                .collect(Collectors.joining(" "));
        assertThat(allArgs).contains("ABC");
        assertThat(allArgs).contains("XYZ");
    }

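    // Verifies that configured safety settings cause the stream to finish with
    // a SAFETY finish reason when the prompt contains harassing language.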
    @RetryingTest(5)
    void should_support_safety_settings() {
        // given
        List<GeminiSafetySetting> safetySettings = List.of(
                new GeminiSafetySetting(HARM_CATEGORY_HATE_SPEECH, BLOCK_LOW_AND_ABOVE),
                new GeminiSafetySetting(HARM_CATEGORY_HARASSMENT, BLOCK_LOW_AND_ABOVE)
        );

        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .safetySettings(safetySettings)
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate("You're a dumb f*cking idiot bastard!", handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.finishReason()).isEqualTo(FinishReasonMapper.fromGFinishReasonToFinishReason(GeminiFinishReason.SAFETY));
    }

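    // Verifies structured output: with a JSON response format and an explicit
    // JSON schema, the model returns an object matching the schema exactly.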
    @Test
    void should_comply_to_response_schema() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonObjectSchema.builder()
                                        .addStringProperty("name")
                                        .addProperty("address", JsonObjectSchema.builder()
                                                .addStringProperty("city")
                                                .required("city")
                                                .build())
                                        .required("name", "address")
                                        .additionalProperties(false)
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to extract information related to a person, " +
                                "like their name, their address, the city they live in."),
                UserMessage.from(
                        "In the town of Liverpool, lived Tommy Skybridge, a young little boy."
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        // then
        assertThat(response.content().text().trim())
                .isEqualTo("{\"address\": {\"city\": \"Liverpool\"}, \"name\": \"Tommy Skybridge\"}");
    }

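    // Verifies that an enum-typed property in the JSON schema constrains the
    // model's output to one of the declared values.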
    @Test
    void should_handle_enum_type() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonObjectSchema.builder()
                                        .properties(new LinkedHashMap<String, JsonSchemaElement>() {{
                                            put("sentiment", JsonEnumSchema.builder()
                                                    .enumValues("POSITIVE", "NEGATIVE")
                                                    .build());
                                        }})
                                        .required("sentiment")
                                        .additionalProperties(false)
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to analyze the sentiment of the text you receive."),
                UserMessage.from(
                        "This is super exciting news, congratulations!"
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        // then
        assertThat(response.content().text().trim())
                .isEqualTo("{\"sentiment\": \"POSITIVE\"}");
    }

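    // Verifies that an enum can also be used as the schema's root element,
    // in which case the model answers with the bare enum value.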
    @Test
    void should_handle_enum_output_mode() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonEnumSchema.builder()
                                        .enumValues("POSITIVE", "NEUTRAL", "NEGATIVE")
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to analyze the sentiment of the text you receive."),
                UserMessage.from(
                        "This is super exciting news, congratulations!"
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        // then
        assertThat(response.content().text().trim())
                .isEqualTo("POSITIVE");
    }

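    // Verifies that a top-level JSON array schema is honored: the response can
    // be parsed directly into an Integer[].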
    @Test
    void should_allow_array_as_response_schema() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchema.builder()
                                .rootElement(JsonArraySchema.builder()
                                        .items(new JsonIntegerSchema())
                                        .build())
                                .build())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to return a list of 6-sided dice rolls"),
                UserMessage.from(
                        "Give me 3 dice rolls"
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        // then
        Integer[] diceRolls = new Gson().fromJson(response.content().text(), Integer[].class);
        assertThat(diceRolls.length).isEqualTo(3);
    }

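    // Simple POJO used to derive a JSON schema via JsonSchemas.jsonSchemaFrom()
    // and to deserialize the model's structured answer with Gson.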
    private class Color {
        private String name;
        private int red;
        private int green;
        private int blue;
        private boolean muted;
    }

    @Test
    void should_deserialize_to_POJO() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .responseFormat(ResponseFormat.builder()
                        .type(JSON)
                        .jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get())
                        .build())
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        gemini.generate(List.of(
                SystemMessage.from(
                        "Your role is to extract information from the description of a color"),
                UserMessage.from(
                        "Cobalt blue is a blend of a lot of blue, a bit of green, and almost no red."
                )
        ), handler);
        Response<AiMessage> response = handler.get();

        System.out.println("response = " + response);

        Color color = new Gson().fromJson(response.content().text(), Color.class);

        // then
        assertThat(color.name).isEqualToIgnoringCase("Cobalt blue");
        assertThat(color.muted).isFalse();
        assertThat(color.red).isLessThanOrEqualTo(color.green);
        assertThat(color.green).isLessThanOrEqualTo(color.blue);
    }

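    // Verifies tool configuration: without a toolConfig the model may call any
    // declared tool; with GeminiMode.ANY restricted to "toolTwo", a request to
    // call toolOne must not produce a tool execution request.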
    @Test
    void should_support_tool_config() {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .build();

        List<ChatMessage> chatMessages = List.of(UserMessage.from("Call toolOne"));
        List<ToolSpecification> listOfTools = Arrays.asList(
                ToolSpecification.builder().name("toolOne").build(),
                ToolSpecification.builder().name("toolTwo").build()
        );

        // when
        TestStreamingResponseHandler<AiMessage> handler1 = new TestStreamingResponseHandler<>();
        gemini.generate(chatMessages, listOfTools, handler1);
        Response<AiMessage> response1 = handler1.get();

        // then
        assertThat(response1.content().hasToolExecutionRequests()).isTrue();
        assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("toolOne");

        // given
        gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-flash")
                .logRequestsAndResponses(true)
                .toolConfig(new GeminiFunctionCallingConfig(GeminiMode.ANY, List.of("toolTwo")))
                .build();

        // when
        TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
        gemini.generate("Call toolOne", handler2);
        Response<AiMessage> response2 = handler2.get();

        // then
        assertThat(response2.content().hasToolExecutionRequests()).isFalse();
    }

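    // Tool class exposed to the AiServices assistant below; its @Tool method is
    // offered to the model as a callable function.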
    static class Transactions {
        @Tool("returns amount of a given transaction")
        double getTransactionAmount(@P("ID of a transaction") String id) {
            System.out.printf("called getTransactionAmount(%s)%n", id);
            switch (id) {
                case "T001":
                    return 11.1;
                case "T002":
                    return 22.2;
                default:
                    throw new IllegalArgumentException("Unknown transaction ID: " + id);
            }
        }
    }

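    // AiServices interface returning a TokenStream, the streaming counterpart
    // of a plain String-returning assistant method.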
    interface StreamingAssistant {
        TokenStream chat(String userMessage);
    }

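    // Verifies end-to-end tool use through AiServices with streaming: the model
    // calls getTransactionAmount, the result is fed back automatically, and the
    // final streamed answer contains the amount. Retried because tool choice is
    // not fully deterministic.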
    @RetryingTest(10)
    void should_work_with_tools_with_AiServices() throws ExecutionException, InterruptedException, TimeoutException {
        // given
        GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("gemini-1.5-pro")
                .logRequestsAndResponses(true)
                .timeout(Duration.ofMinutes(2))
                .temperature(0.0)
                .topP(0.0)
                .topK(1)
                .build();

        // when
        Transactions spyTransactions = spy(new Transactions());

        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20);
        StreamingAssistant assistant = AiServices.builder(StreamingAssistant.class)
                .tools(spyTransactions)
                .chatMemory(chatMemory)
                .streamingChatLanguageModel(gemini)
                .build();

        // then
        CompletableFuture<Response<AiMessage>> future1 = new CompletableFuture<>();
        assistant.chat("What is the amount of transaction T001?")
                .onNext(System.out::println)
                .onComplete(future1::complete)
                .onError(future1::completeExceptionally)
                .start();
        Response<AiMessage> response1 = future1.get(30, TimeUnit.SECONDS);

        assertThat(response1.content().text()).containsIgnoringCase("11.1");
        verify(spyTransactions).getTransactionAmount("T001");

        CompletableFuture<Response<AiMessage>> future2 = new CompletableFuture<>();
        assistant.chat("What is the amount of transaction T002?")
                .onNext(System.out::println)
                .onComplete(future2::complete)
                .onError(future2::completeExceptionally)
                .start();
        Response<AiMessage> response2 = future2.get(30, TimeUnit.SECONDS);

        assertThat(response2.content().text()).containsIgnoringCase("22.2");
        verify(spyTransactions).getTransactionAmount("T002");

        verifyNoMoreInteractions(spyTransactions);
    }

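    // Optional delay between tests on CI, presumably to avoid rate limiting.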
    @AfterEach
    void afterEach() throws InterruptedException {
        String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
        if (ciDelaySeconds != null) {
            Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
        }
    }
}

@ -0,0 +1,61 @@
package dev.langchain4j.model.googleai;

import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import org.junit.jupiter.api.AfterEach;

import static java.util.Collections.singletonList;

class GoogleAiGeminiStreamingChatModelListenerIT extends StreamingChatModelListenerIT {

    private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY");

    @Override
    protected StreamingChatLanguageModel createModel(ChatModelListener listener) {
        return GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName(modelName())
                .temperature(temperature())
                .topP(topP())
                .maxOutputTokens(maxTokens())
                .listeners(singletonList(listener))
                .logRequestsAndResponses(true)
                .build();
    }

    @Override
    protected String modelName() {
        return "gemini-1.5-flash";
    }

    @Override
    protected boolean assertResponseId() {
        return false;
    }

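    // A nonexistent model name ("banana") forces the API call to fail, so the
    // listener's error path can be exercised.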
    @Override
    protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) {
        return GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(GOOGLE_AI_GEMINI_API_KEY)
                .modelName("banana")
                .listeners(singletonList(listener))
                .logRequestsAndResponses(true)
                .build();
    }

    @Override
    protected Class<? extends Exception> expectedExceptionClass() {
        return RuntimeException.class;
    }

    @AfterEach
    void afterEach() throws InterruptedException {
        String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI");
        if (ciDelaySeconds != null) {
            Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L);
        }
    }
}