From aa0e48816657640eda75879f1c29c0348643575c Mon Sep 17 00:00:00 2001 From: Bjarne-Kinkel <89375557+Bjarne-Kinkel@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:07:50 +0100 Subject: [PATCH] Adding GoogleAiGeminiStreamingChatModel (#1951) ## Issue Closes #1903 ## Change Implemented the GoogleAiGeminiStreamingChatModel. This PR depends on #1950. ## General checklist - [X] There are no breaking changes - [x] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [x] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable) --- .../language-models/google-ai-gemini.md | 29 +- .../integrations/language-models/index.md | 4 +- .../model/googleai/BaseGeminiChatModel.java | 192 +++++ .../model/googleai/GeminiService.java | 41 + .../GeminiStreamingResponseBuilder.java | 111 +++ .../googleai/GoogleAiGeminiChatModel.java | 300 +++----- .../GoogleAiGeminiStreamingChatModel.java | 93 +++ .../googleai/GoogleAiGeminiChatModelIT.java | 32 +- .../GoogleAiGeminiStreamingChatModelIT.java | 708 ++++++++++++++++++ ...eAiGeminiStreamingChatModelListenerIT.java | 61 ++ 10 files changed, 1348 insertions(+), 223 deletions(-) create mode 100644 langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/BaseGeminiChatModel.java create mode 100644 
langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiStreamingResponseBuilder.java create mode 100644 langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModel.java create mode 100644 langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelIT.java create mode 100644 langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelListenerIT.java diff --git a/docs/docs/integrations/language-models/google-ai-gemini.md b/docs/docs/integrations/language-models/google-ai-gemini.md index b7abae195..ce6da6894 100644 --- a/docs/docs/integrations/language-models/google-ai-gemini.md +++ b/docs/docs/integrations/language-models/google-ai-gemini.md @@ -83,9 +83,34 @@ ChatLanguageModel gemini = GoogleAiGeminiChatModel.builder() ``` ## GoogleAiGeminiStreamingChatModel +The `GoogleAiGeminiStreamingChatModel` allows streaming the text of a response token by token. The response must be managed by a `StreamingResponseHandler`. +```java +StreamingChatLanguageModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(System.getenv("GEMINI_AI_KEY")) + .modelName("gemini-1.5-flash") + .build(); -No streaming chat model is available yet. -Please open a feature request if you're interested in a streaming model or if you want to contribute to implementing it. 
+CompletableFuture> futureResponse = new CompletableFuture<>(); + + gemini.generate("Tell me a joke about Java", new StreamingResponseHandler() { + @Override + public void onNext(String token) { + System.out.print(token); + } + + @Override + public void onComplete(Response response) { + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse.completeExceptionally(error); + } +}); + + futureResponse.join(); +``` ## Tools diff --git a/docs/docs/integrations/language-models/index.md b/docs/docs/integrations/language-models/index.md index cdf19dcd1..05ceb0e5c 100644 --- a/docs/docs/integrations/language-models/index.md +++ b/docs/docs/integrations/language-models/index.md @@ -11,8 +11,8 @@ sidebar_position: 0 | [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | text, image | ✅ | | | | | [ChatGLM](/integrations/language-models/chatglm) | | | | text | | | | | | [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | | text, image, audio | ✅ | | | | -| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | | -| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | | +| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | | +| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | | | [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | | | [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | text | | | ✅ | | | [Hugging Face](/integrations/language-models/hugging-face) | | | | text | | | | | diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/BaseGeminiChatModel.java 
b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/BaseGeminiChatModel.java new file mode 100644 index 000000000..cb44600cb --- /dev/null +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/BaseGeminiChatModel.java @@ -0,0 +1,192 @@ +package dev.langchain4j.model.googleai; + +import dev.langchain4j.Experimental; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.listener.ChatModelErrorContext; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelRequestContext; +import dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.chat.listener.ChatModelResponseContext; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.ResponseFormatType; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.output.Response; +import lombok.extern.slf4j.Slf4j; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent; +import static dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema; +import static java.time.Duration.ofSeconds; +import static java.util.Collections.emptyList; + +@Experimental +@Slf4j +abstract class BaseGeminiChatModel { + protected final GeminiService geminiService; + protected final String apiKey; + protected final String modelName; + protected final Double temperature; + 
protected final Integer topK; + protected final Double topP; + protected final Integer maxOutputTokens; + protected final List stopSequences; + protected final ResponseFormat responseFormat; + protected final GeminiFunctionCallingConfig toolConfig; + protected final boolean allowCodeExecution; + protected final boolean includeCodeExecutionOutput; + protected final List safetySettings; + protected final List listeners; + protected final Integer maxRetries; + + protected BaseGeminiChatModel( + String apiKey, + String modelName, + Double temperature, + Integer topK, + Double topP, + Integer maxOutputTokens, + Duration timeout, + ResponseFormat responseFormat, + List stopSequences, + GeminiFunctionCallingConfig toolConfig, + Boolean allowCodeExecution, + Boolean includeCodeExecutionOutput, + Boolean logRequestsAndResponses, + List safetySettings, + List listeners, + Integer maxRetries + ) { + this.apiKey = ensureNotBlank(apiKey, "apiKey"); + this.modelName = ensureNotBlank(modelName, "modelName"); + this.temperature = temperature; + this.topK = topK; + this.topP = topP; + this.maxOutputTokens = maxOutputTokens; + this.stopSequences = getOrDefault(stopSequences, emptyList()); + this.toolConfig = toolConfig; + this.allowCodeExecution = getOrDefault(allowCodeExecution, false); + this.includeCodeExecutionOutput = getOrDefault(includeCodeExecutionOutput, false); + this.safetySettings = copyIfNotNull(safetySettings); + this.responseFormat = responseFormat; + this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); + this.maxRetries = getOrDefault(maxRetries, 3); + this.geminiService = new GeminiService( + getOrDefault(logRequestsAndResponses, false) ? 
log : null, + getOrDefault(timeout, ofSeconds(60)) + ); + } + + protected GeminiGenerateContentRequest createGenerateContentRequest( + List messages, + List toolSpecifications, + ResponseFormat responseFormat + ) { + GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString()); + List geminiContentList = fromMessageToGContent(messages, systemInstruction); + + GeminiSchema schema = null; + if (responseFormat != null && responseFormat.jsonSchema() != null) { + schema = fromJsonSchemaToGSchema(responseFormat.jsonSchema()); + } + + return GeminiGenerateContentRequest.builder() + .contents(geminiContentList) + .systemInstruction(!systemInstruction.getParts().isEmpty() ? systemInstruction : null) + .generationConfig(GeminiGenerationConfig.builder() + .candidateCount(1) // Multiple candidates aren't supported by langchain4j + .maxOutputTokens(this.maxOutputTokens) + .responseMimeType(computeMimeType(responseFormat)) + .responseSchema(schema) + .stopSequences(this.stopSequences) + .temperature(this.temperature) + .topK(this.topK) + .topP(this.topP) + .build()) + .safetySettings(this.safetySettings) + .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution)) + .toolConfig(new GeminiToolConfig(this.toolConfig)) + .build(); + } + + protected ChatModelRequest createChatModelRequest( + List messages, + List toolSpecifications + ) { + return ChatModelRequest.builder() + .model(modelName) + .temperature(temperature) + .topP(topP) + .maxTokens(maxOutputTokens) + .messages(messages) + .toolSpecifications(toolSpecifications) + .build(); + } + + protected static String computeMimeType(ResponseFormat responseFormat) { + if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) { + return "text/plain"; + } + + if (ResponseFormatType.JSON.equals(responseFormat.type()) && + responseFormat.jsonSchema() != null && + responseFormat.jsonSchema().rootElement() != null && + responseFormat.jsonSchema().rootElement() 
instanceof JsonEnumSchema) { + return "text/x.enum"; + } + + return "application/json"; + } + + protected void notifyListenersOnRequest(ChatModelRequestContext context) { + listeners.forEach((listener) -> { + try { + listener.onRequest(context); + } catch (Exception e) { + log.warn("Exception while calling model listener (onRequest)", e); + } + }); + } + + protected void notifyListenersOnResponse(Response response, ChatModelRequest request, + ConcurrentHashMap attributes) { + ChatModelResponse chatModelResponse = ChatModelResponse.builder() + .model(modelName) + .tokenUsage(response.tokenUsage()) + .finishReason(response.finishReason()) + .aiMessage(response.content()) + .build(); + ChatModelResponseContext context = new ChatModelResponseContext( + chatModelResponse, request, attributes); + listeners.forEach((listener) -> { + try { + listener.onResponse(context); + } catch (Exception e) { + log.warn("Exception while calling model listener (onResponse)", e); + } + }); + } + + protected void notifyListenersOnError(Exception exception, ChatModelRequest request, + ConcurrentHashMap attributes) { + listeners.forEach((listener) -> { + try { + ChatModelErrorContext context = new ChatModelErrorContext( + exception, request, null, attributes); + listener.onError(context); + } catch (Exception e) { + log.warn("Exception while calling model listener (onError)", e); + } + }); + } +} + diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java index 51d187d6d..29aa5212f 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiService.java @@ -10,6 +10,8 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import 
java.util.stream.Collectors; +import java.util.stream.Stream; class GeminiService { private static final String GEMINI_AI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta"; @@ -48,6 +50,11 @@ class GeminiService { return sendRequest(url, apiKey, request, GoogleAiBatchEmbeddingResponse.class); } + Stream generateContentStream(String modelName, String apiKey, GeminiGenerateContentRequest request) { + String url = String.format("%s/models/%s:streamGenerateContent?alt=sse", GEMINI_AI_ENDPOINT, modelName); + return streamRequest(url, apiKey, request, GeminiGenerateContentResponse.class); + } + private T sendRequest(String url, String apiKey, Object requestBody, Class responseType) { String jsonBody = gson.toJson(requestBody); HttpRequest request = buildHttpRequest(url, apiKey, jsonBody); @@ -72,6 +79,40 @@ class GeminiService { } } + private Stream streamRequest(String url, String apiKey, Object requestBody, Class responseType) { + String jsonBody = gson.toJson(requestBody); + HttpRequest httpRequest = buildHttpRequest(url, apiKey, jsonBody); + + logRequest(jsonBody); + + try { + HttpResponse> httpResponse = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofLines()); + + if (httpResponse.statusCode() >= 300) { + String errorBody = httpResponse.body() + .collect(Collectors.joining("\n")); + + throw new RuntimeException(String.format("HTTP error (%d): %s", httpResponse.statusCode(), errorBody)); + } + + Stream responseStream = httpResponse.body() + .filter(line -> line.startsWith("data: ")) + .map(line -> line.substring(6)) // Remove "data: " prefix + .map(jsonString -> gson.fromJson(jsonString, responseType)); + + if (logger != null) { + responseStream = responseStream.peek(response -> logger.debug("Partial response from Gemini:\n{}", response)); + } + + return responseStream; + } catch (IOException e) { + throw new RuntimeException("An error occurred while streaming the request", e); + } catch (InterruptedException e) { + 
Thread.currentThread().interrupt(); + throw new RuntimeException("Streaming the request was interrupted", e); + } + } + private HttpRequest buildHttpRequest(String url, String apiKey, String jsonBody) { return HttpRequest.newBuilder() .uri(URI.create(url)) diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiStreamingResponseBuilder.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiStreamingResponseBuilder.java new file mode 100644 index 000000000..c397fc4ad --- /dev/null +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GeminiStreamingResponseBuilder.java @@ -0,0 +1,111 @@ +package dev.langchain4j.model.googleai; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason; +import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage; + +/** + * A builder class for constructing streaming responses from Gemini AI model. + * This class accumulates partial responses and builds a final response. + */ +class GeminiStreamingResponseBuilder { + private final boolean includeCodeExecutionOutput; + private final StringBuilder contentBuilder; + private final List functionCalls; + private TokenUsage tokenUsage; + private FinishReason finishReason; + + /** + * Constructs a new GeminiStreamingResponseBuilder. 
+ * + * @param includeCodeExecutionOutput whether to include code execution output in the response + */ + public GeminiStreamingResponseBuilder(boolean includeCodeExecutionOutput) { + this.includeCodeExecutionOutput = includeCodeExecutionOutput; + this.contentBuilder = new StringBuilder(); + this.functionCalls = new ArrayList<>(); + } + + /** + * Appends a partial response to the builder. + * + * @param partialResponse the partial response from Gemini AI + * @return an Optional containing the text of the partial response, or empty if no valid text + */ + public Optional append(GeminiGenerateContentResponse partialResponse) { + if (partialResponse == null) { + return Optional.empty(); + } + + GeminiCandidate firstCandidate = partialResponse.getCandidates().get(0); + + updateFinishReason(firstCandidate); + updateTokenUsage(partialResponse.getUsageMetadata()); + + GeminiContent content = firstCandidate.getContent(); + if (content == null || content.getParts() == null) { + return Optional.empty(); + } + + AiMessage message = fromGPartsToAiMessage(content.getParts(), this.includeCodeExecutionOutput); + updateContentAndFunctionCalls(message); + + return Optional.ofNullable(message.text()); + } + + /** + * Builds the final response from all accumulated partial responses. 
+ * + * @return a Response object containing the final AiMessage, token usage, and finish reason + */ + public Response build() { + AiMessage aiMessage = createAiMessage(); + return Response.from(aiMessage, tokenUsage, finishReason); + } + + private void updateTokenUsage(GeminiUsageMetadata tokenCounts) { + this.tokenUsage = new TokenUsage( + tokenCounts.getPromptTokenCount(), + tokenCounts.getCandidatesTokenCount(), + tokenCounts.getTotalTokenCount() + ); + } + + private void updateFinishReason(GeminiCandidate candidate) { + if (candidate.getFinishReason() != null) { + this.finishReason = fromGFinishReasonToFinishReason(candidate.getFinishReason()); + } + } + + private void updateContentAndFunctionCalls(AiMessage message) { + Optional.ofNullable(message.text()).ifPresent(contentBuilder::append); + if (message.hasToolExecutionRequests()) { + functionCalls.addAll(message.toolExecutionRequests()); + } + } + + private AiMessage createAiMessage() { + String text = contentBuilder.toString(); + boolean hasText = !text.isEmpty() && !text.isBlank(); + boolean hasFunctionCall = !functionCalls.isEmpty(); + + if (hasText && hasFunctionCall) { + return new AiMessage(text, functionCalls); + } else if (hasText) { + return new AiMessage(text); + } else if (hasFunctionCall) { + return new AiMessage(functionCalls); + } + + throw new RuntimeException("Gemini has responded neither with text nor with a function call."); + } +} diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java index 253464abf..55e2aed06 100644 --- a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModel.java @@ -1,6 +1,5 @@ package dev.langchain4j.model.googleai; -import com.google.gson.Gson; import 
dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; @@ -8,11 +7,12 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.Capability; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; -import dev.langchain4j.model.chat.listener.*; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelRequestContext; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.request.ResponseFormat; import dev.langchain4j.model.chat.request.ResponseFormatType; -import dev.langchain4j.model.chat.request.json.JsonEnumSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; @@ -23,135 +23,63 @@ import lombok.extern.slf4j.Slf4j; import java.time.Duration; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; -import java.util.ArrayList; import java.util.Map; import java.util.Set; -import java.util.HashSet; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import static dev.langchain4j.internal.RetryUtils.withRetry; -import static dev.langchain4j.internal.Utils.copyIfNotNull; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA; -import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromMessageToGContent; import static dev.langchain4j.model.googleai.FinishReasonMapper.fromGFinishReasonToFinishReason; import static dev.langchain4j.model.googleai.PartsAndContentsMapper.fromGPartsToAiMessage; -import static 
dev.langchain4j.model.googleai.SchemaMapper.fromJsonSchemaToGSchema; import static java.time.Duration.ofSeconds; -import static java.util.Collections.emptyList; @Experimental @Slf4j -public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEstimator { - private final GeminiService geminiService; - - private final String apiKey; - private final String modelName; - - private final Integer maxRetries; - private final Double temperature; - private final Integer topK; - private final Double topP; - private final Integer maxOutputTokens; - private final List stopSequences; - - private final Integer candidateCount; - - private final ResponseFormat responseFormat; - - private final GeminiFunctionCallingConfig toolConfig; - - private final boolean allowCodeExecution; - private final boolean includeCodeExecutionOutput; - - private final Boolean logRequestsAndResponses; - private final List safetySettings; - private final List listeners; - +public class GoogleAiGeminiChatModel extends BaseGeminiChatModel implements ChatLanguageModel, TokenCountEstimator { private final GoogleAiGeminiTokenizer geminiTokenizer; @Builder - public GoogleAiGeminiChatModel(String apiKey, String modelName, - Integer maxRetries, - Double temperature, Integer topK, Double topP, - Integer maxOutputTokens, Integer candidateCount, - Duration timeout, - ResponseFormat responseFormat, - List stopSequences, GeminiFunctionCallingConfig toolConfig, - Boolean allowCodeExecution, Boolean includeCodeExecutionOutput, - Boolean logRequestsAndResponses, - List safetySettings, - List listeners + public GoogleAiGeminiChatModel( + String apiKey, String modelName, + Integer maxRetries, + Double temperature, Integer topK, Double topP, + Integer maxOutputTokens, Duration timeout, + ResponseFormat responseFormat, + List stopSequences, GeminiFunctionCallingConfig toolConfig, + Boolean allowCodeExecution, Boolean includeCodeExecutionOutput, + Boolean logRequestsAndResponses, + List safetySettings, + List 
listeners ) { - this.apiKey = ensureNotBlank(apiKey, "apiKey"); - this.modelName = ensureNotBlank(modelName, "modelName"); - - this.maxRetries = getOrDefault(maxRetries, 3); - - // using Gemini's default values - this.temperature = getOrDefault(temperature, 1.0); - this.topK = getOrDefault(topK, 64); - this.topP = getOrDefault(topP, 0.95); - this.maxOutputTokens = getOrDefault(maxOutputTokens, 8192); - this.candidateCount = getOrDefault(candidateCount, 1); - this.stopSequences = getOrDefault(stopSequences, emptyList()); - - this.toolConfig = toolConfig; - - this.allowCodeExecution = allowCodeExecution != null ? allowCodeExecution : false; - this.includeCodeExecutionOutput = includeCodeExecutionOutput != null ? includeCodeExecutionOutput : false; - this.logRequestsAndResponses = getOrDefault(logRequestsAndResponses, false); - - this.safetySettings = copyIfNotNull(safetySettings); - - this.responseFormat = responseFormat; - - this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); - - this.geminiService = new GeminiService( - getOrDefault(logRequestsAndResponses, false) ? 
log : null, - getOrDefault(timeout, ofSeconds(60)) - ); + super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout, + responseFormat, stopSequences, toolConfig, allowCodeExecution, + includeCodeExecutionOutput, logRequestsAndResponses, safetySettings, + listeners, maxRetries); this.geminiTokenizer = GoogleAiGeminiTokenizer.builder() - .modelName(this.modelName) - .apiKey(this.apiKey) - .timeout(getOrDefault(timeout, ofSeconds(60))) - .maxRetries(this.maxRetries) - .logRequestsAndResponses(this.logRequestsAndResponses) - .build(); - } - - private static String computeMimeType(ResponseFormat responseFormat) { - if (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) { - return "text/plain"; - } - - if (ResponseFormatType.JSON.equals(responseFormat.type()) && - responseFormat.jsonSchema() != null && - responseFormat.jsonSchema().rootElement() != null && - responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema) { - - return "text/x.enum"; - } - - return "application/json"; + .modelName(this.modelName) + .apiKey(this.apiKey) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .maxRetries(this.maxRetries) + .logRequestsAndResponses(getOrDefault(logRequestsAndResponses, false)) + .build(); } @Override public Response generate(List messages) { ChatRequest request = ChatRequest.builder() - .messages(messages) - .build(); + .messages(messages) + .build(); ChatResponse response = chat(request); return Response.from(response.aiMessage(), - response.tokenUsage(), - response.finishReason()); + response.tokenUsage(), + response.finishReason()); } @Override @@ -162,128 +90,86 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst @Override public Response generate(List messages, List toolSpecifications) { ChatRequest request = ChatRequest.builder() - .messages(messages) - .toolSpecifications(toolSpecifications) - .build(); + .messages(messages) + .toolSpecifications(toolSpecifications) + 
.build(); ChatResponse response = chat(request); return Response.from(response.aiMessage(), - response.tokenUsage(), - response.finishReason()); + response.tokenUsage(), + response.finishReason()); } @Override public ChatResponse chat(ChatRequest chatRequest) { - GeminiContent systemInstruction = new GeminiContent(GeminiRole.MODEL.toString()); - List geminiContentList = fromMessageToGContent(chatRequest.messages(), systemInstruction); - List toolSpecifications = chatRequest.toolSpecifications(); + GeminiGenerateContentRequest request = createGenerateContentRequest( + chatRequest.messages(), + chatRequest.toolSpecifications(), + getOrDefault(chatRequest.responseFormat(), this.responseFormat) + ); - ResponseFormat format = chatRequest.responseFormat() != null ? chatRequest.responseFormat() : this.responseFormat; - GeminiSchema schema = null; + ChatModelRequest chatModelRequest = createChatModelRequest( + chatRequest.messages(), + chatRequest.toolSpecifications() + ); - String responseMimeType = computeMimeType(format); - - if (format != null && format.jsonSchema() != null) { - schema = fromJsonSchemaToGSchema(format.jsonSchema()); - } - - GeminiGenerateContentRequest request = GeminiGenerateContentRequest.builder() - .contents(geminiContentList) - .systemInstruction(!systemInstruction.getParts().isEmpty() ? 
systemInstruction : null) - .generationConfig(GeminiGenerationConfig.builder() - .candidateCount(this.candidateCount) - .maxOutputTokens(this.maxOutputTokens) - .responseMimeType(responseMimeType) - .responseSchema(schema) - .stopSequences(this.stopSequences) - .temperature(this.temperature) - .topK(this.topK) - .topP(this.topP) - .build()) - .safetySettings(this.safetySettings) - .tools(FunctionMapper.fromToolSepcsToGTool(toolSpecifications, this.allowCodeExecution)) - .toolConfig(new GeminiToolConfig(this.toolConfig)) - .build(); - - ChatModelRequest chatModelRequest = ChatModelRequest.builder() - .model(modelName) - .temperature(temperature) - .topP(topP) - .maxTokens(maxOutputTokens) - .messages(chatRequest.messages()) - .toolSpecifications(chatRequest.toolSpecifications()) - .build(); ConcurrentHashMap listenerAttributes = new ConcurrentHashMap<>(); - ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes); - listeners.forEach((listener) -> { - try { - listener.onRequest(chatModelRequestContext); - } catch (Exception e) { - log.warn("Exception while calling model listener (onRequest)", e); - } - }); + notifyListenersOnRequest(new ChatModelRequestContext(chatModelRequest, listenerAttributes)); - GeminiGenerateContentResponse geminiResponse; try { - geminiResponse = withRetry(() -> this.geminiService.generateContent(this.modelName, this.apiKey, request), this.maxRetries); - } catch (RuntimeException e) { - ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext( - e, chatModelRequest, null, listenerAttributes + GeminiGenerateContentResponse geminiResponse = withRetry( + () -> this.geminiService.generateContent(this.modelName, this.apiKey, request), + this.maxRetries ); - listeners.forEach((listener) -> { - try { - listener.onError(chatModelErrorContext); - } catch (Exception ex) { - log.warn("Exception while calling model listener (onError)", ex); - } - }); + return 
processResponse(geminiResponse, chatModelRequest, listenerAttributes); + } catch (RuntimeException e) { + notifyListenersOnError(e, chatModelRequest, listenerAttributes); throw e; } + } - if (geminiResponse != null) { - GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0); //TODO handle n - GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata(); - - AiMessage aiMessage; - - FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason()); - if (firstCandidate.getContent() == null) { - aiMessage = AiMessage.from("No text was returned by the model. " + - "The model finished generating because of the following reason: " + finishReason); - } else { - aiMessage = fromGPartsToAiMessage(firstCandidate.getContent().getParts(), this.includeCodeExecutionOutput); - } - - TokenUsage tokenUsage = new TokenUsage(tokenCounts.getPromptTokenCount(), - tokenCounts.getCandidatesTokenCount(), - tokenCounts.getTotalTokenCount()); - - ChatModelResponse chatModelResponse = ChatModelResponse.builder() - .model(modelName) - .tokenUsage(tokenUsage) - .finishReason(finishReason) - .aiMessage(aiMessage) - .build(); - ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext( - chatModelResponse, chatModelRequest, listenerAttributes); - listeners.forEach((listener) -> { - try { - listener.onResponse(chatModelResponseContext); - } catch (Exception e) { - log.warn("Exception while calling model listener (onResponse)", e); - } - }); - - return ChatResponse.builder() - .aiMessage(aiMessage) - .finishReason(finishReason) - .tokenUsage(tokenUsage) - .build(); - } else { + private ChatResponse processResponse( + GeminiGenerateContentResponse geminiResponse, + ChatModelRequest chatModelRequest, + ConcurrentHashMap listenerAttributes + ) { + if (geminiResponse == null) { throw new RuntimeException("Gemini response was null"); } + + GeminiCandidate firstCandidate = geminiResponse.getCandidates().get(0); + 
GeminiUsageMetadata tokenCounts = geminiResponse.getUsageMetadata(); + + FinishReason finishReason = fromGFinishReasonToFinishReason(firstCandidate.getFinishReason()); + AiMessage aiMessage = createAiMessage(firstCandidate, finishReason); + TokenUsage tokenUsage = createTokenUsage(tokenCounts); + + Response response = Response.from(aiMessage, tokenUsage, finishReason); + notifyListenersOnResponse(response, chatModelRequest, listenerAttributes); + + return ChatResponse.builder() + .aiMessage(aiMessage) + .finishReason(finishReason) + .tokenUsage(tokenUsage) + .build(); + } + + private AiMessage createAiMessage(GeminiCandidate candidate, FinishReason finishReason) { + if (candidate.getContent() == null) { + return AiMessage.from("No text was returned by the model. " + + "The model finished generating because of the following reason: " + finishReason); + } + return fromGPartsToAiMessage(candidate.getContent().getParts(), this.includeCodeExecutionOutput); + } + + private TokenUsage createTokenUsage(GeminiUsageMetadata tokenCounts) { + return new TokenUsage( + tokenCounts.getPromptTokenCount(), + tokenCounts.getCandidatesTokenCount(), + tokenCounts.getTotalTokenCount() + ); } @Override @@ -309,9 +195,9 @@ public class GoogleAiGeminiChatModel implements ChatLanguageModel, TokenCountEst public GoogleAiGeminiChatModelBuilder safetySettings(Map safetySettingMap) { this.safetySettings = safetySettingMap.entrySet().stream() - .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue()) - ).collect(Collectors.toList()); + .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue()) + ).collect(Collectors.toList()); return this; } } -} \ No newline at end of file +} diff --git a/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModel.java b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModel.java new file mode 100644 index 000000000..dd60903ca --- 
/dev/null +++ b/langchain4j-google-ai-gemini/src/main/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModel.java @@ -0,0 +1,93 @@ +package dev.langchain4j.model.googleai; + +import dev.langchain4j.Experimental; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelRequestContext; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.output.Response; +import lombok.Builder; +import lombok.extern.slf4j.Slf4j; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; + +import static dev.langchain4j.internal.RetryUtils.withRetry; + +@Experimental +@Slf4j +public class GoogleAiGeminiStreamingChatModel extends BaseGeminiChatModel implements StreamingChatLanguageModel { + @Builder + public GoogleAiGeminiStreamingChatModel( + String apiKey, String modelName, + Double temperature, Integer topK, Double topP, + Integer maxOutputTokens, Duration timeout, + ResponseFormat responseFormat, + List stopSequences, GeminiFunctionCallingConfig toolConfig, + Boolean allowCodeExecution, Boolean includeCodeExecutionOutput, + Boolean logRequestsAndResponses, + List safetySettings, + List listeners, + Integer maxRetries + ) { + super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout, + responseFormat, stopSequences, toolConfig, allowCodeExecution, + includeCodeExecutionOutput, logRequestsAndResponses, safetySettings, + listeners, maxRetries); + } + + @Override + public void generate(List messages, 
StreamingResponseHandler handler) { + generate(messages, Collections.emptyList(), handler); + } + + @Override + public void generate(List messages, ToolSpecification toolSpecification, StreamingResponseHandler handler) { + throw new RuntimeException("This method is not supported: Gemini AI cannot be forced to execute a tool."); + } + + @Override + public void generate(List messages, List toolSpecifications, StreamingResponseHandler handler) { + GeminiGenerateContentRequest request = createGenerateContentRequest(messages, toolSpecifications, this.responseFormat); + ChatModelRequest chatModelRequest = createChatModelRequest(messages, toolSpecifications); + + ConcurrentHashMap listenerAttributes = new ConcurrentHashMap<>(); + ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(chatModelRequest, listenerAttributes); + notifyListenersOnRequest(chatModelRequestContext); + + processGenerateContentRequest(request, handler, chatModelRequest, listenerAttributes); + } + + private void processGenerateContentRequest(GeminiGenerateContentRequest request, StreamingResponseHandler handler, + ChatModelRequest chatModelRequest, ConcurrentHashMap listenerAttributes) { + GeminiStreamingResponseBuilder responseBuilder = new GeminiStreamingResponseBuilder(this.includeCodeExecutionOutput); + + try { + Stream contentStream = withRetry( + () -> this.geminiService.generateContentStream(this.modelName, this.apiKey, request), + maxRetries); + + contentStream.forEach(partialResponse -> { + Optional text = responseBuilder.append(partialResponse); + text.ifPresent(handler::onNext); + }); + + Response fullResponse = responseBuilder.build(); + handler.onComplete(fullResponse); + + notifyListenersOnResponse(fullResponse, chatModelRequest, listenerAttributes); + } catch (RuntimeException exception) { + notifyListenersOnError(exception, chatModelRequest, listenerAttributes); + handler.onError(exception); + } + } +} diff --git 
a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelIT.java b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelIT.java index 06264b5b4..80209663a 100644 --- a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelIT.java +++ b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiChatModelIT.java @@ -18,12 +18,12 @@ import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.request.ResponseFormat; -import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; -import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.model.chat.request.json.JsonArraySchema; -import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.model.chat.request.json.JsonSchemaElement; import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; @@ -34,7 +34,13 @@ import org.junit.jupiter.api.Test; import org.junitpioneer.jupiter.RetryingTest; import java.time.Duration; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import static dev.langchain4j.internal.Utils.readBytes; @@ -44,7 +50,9 @@ import static 
dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HA import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH; import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; public class GoogleAiGeminiChatModelIT { @@ -114,8 +122,8 @@ public class GoogleAiGeminiChatModelIT { assertThat(jsonText).contains("\"John\""); assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25); - assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(6); - assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(31); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7); + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32); } @Test @@ -435,10 +443,10 @@ public class GoogleAiGeminiChatModelIT { .jsonSchema(JsonSchema.builder() .rootElement(JsonObjectSchema.builder() .addStringProperty("name") - .addProperty("address", JsonObjectSchema.builder() - .addStringProperty("city") - .required("city") - .build()) + .addProperty("address", JsonObjectSchema.builder() + .addStringProperty("city") + .required("city") + .build()) .required("name", "address") .additionalProperties(false) .build()) diff --git a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelIT.java b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelIT.java new file mode 100644 index 000000000..563d9a77b --- /dev/null +++ b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelIT.java @@ -0,0 +1,708 @@ +package dev.langchain4j.model.googleai; + +import com.google.gson.Gson; +import dev.langchain4j.agent.tool.P; +import dev.langchain4j.agent.tool.Tool; 
+import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.AudioContent; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ImageContent; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.TextFileContent; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.output.JsonSchemas; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junitpioneer.jupiter.RetryingTest; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; + +import static 
dev.langchain4j.internal.Utils.readBytes; +import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON; +import static dev.langchain4j.model.googleai.GeminiHarmBlockThreshold.BLOCK_LOW_AND_ABOVE; +import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HARASSMENT; +import static dev.langchain4j.model.googleai.GeminiHarmCategory.HARM_CATEGORY_HATE_SPEECH; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +public class GoogleAiGeminiStreamingChatModelIT { + + private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY"); + + private static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png"; + private static final String MD_FILE_URL = "https://raw.githubusercontent.com/langchain4j/langchain4j/main/docs/docs/intro.md"; + + @Test + void should_answer_simple_question() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .build(); + + UserMessage userMessage = UserMessage.from("What is the capital of France?"); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(userMessage, handler); + Response response = handler.get(); + + // then + String text = response.content().text(); + assertThat(text).containsIgnoringCase("Paris"); + + assertThat(response.finishReason()).isEqualTo(FinishReason.STOP); + } + + @Test + void should_configure_generation() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .temperature(2.0) + .topP(0.5) + .topK(10) + .maxOutputTokens(10) + 
.build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate("How much is 3+4? Reply with just the answer", handler); + Response response = handler.get(); + + // then + String text = response.content().text(); + assertThat(text.trim()).isEqualTo("7"); + } + + @Test + void should_answer_in_json() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-pro") + .responseFormat(ResponseFormat.JSON) + .logRequestsAndResponses(true) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(UserMessage.from("What is the firstname of the John Doe?\n" + + "Reply in JSON following with the following format: {\"firstname\": string}"), handler); + Response response = handler.get(); + + // then + String jsonText = response.content().text(); + + assertThat(jsonText).contains("\"firstname\""); + assertThat(jsonText).contains("\"John\""); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(25); + assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(7); + assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(32); + } + + @Test + void should_support_multiturn_chatting() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .build(); + + List messages = new ArrayList<>(); + messages.add(UserMessage.from("Hello, my name is Guillaume")); + messages.add(AiMessage.from("Hi, how can I assist you today?")); + messages.add(UserMessage.from("What is my name? 
Reply with just my name.")); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(messages, handler); + Response response = handler.get(); + + // then + assertThat(response.content().text()).contains("Guillaume"); + } + + @Test + void should_support_sending_images_as_base64() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .build(); + + UserMessage userMessage = UserMessage.userMessage( + ImageContent.from(new String(Base64.getEncoder().encode(readBytes(CAT_IMAGE_URL))), "image/png"), + TextContent.from("Describe this image in a single word") + ); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(userMessage, handler); + Response response = handler.get(); + + // then + assertThat(response.content().text()).containsIgnoringCase("cat"); + } + + @Test + void should_support_text_file() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .build(); + + UserMessage userMessage = UserMessage.userMessage( + TextFileContent.from(new String(Base64.getEncoder().encode(readBytes(MD_FILE_URL))), "text/markdown"), + TextContent.from("What project does this markdown file mention?") + ); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(userMessage, handler); + Response response = handler.get(); + + // then + assertThat(response.content().text()).containsIgnoringCase("LangChain4j"); + } + + @Test + void should_support_audio_file() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .build(); + + UserMessage userMessage = UserMessage.from( + 
AudioContent.from( + new String(Base64.getEncoder().encode( //TODO use local file + readBytes("https://storage.googleapis.com/cloud-samples-data/generative-ai/audio/pixel.mp3"))), + "audio/mp3"), + TextContent.from("Give a summary of the audio") + ); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(userMessage, handler); + Response response = handler.get(); + + // then + assertThat(response.content().text()).containsIgnoringCase("Pixel"); + } + + @Test + @Disabled + void should_support_video_file() { + // ToDo waiting for the normal GoogleAiGeminiChatModel to implement the test + } + + @Test + void should_respect_system_instruction() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .build(); + + List chatMessages = List.of( + SystemMessage.from("Translate from English into French"), + UserMessage.from("apple") + ); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(chatMessages, handler); + Response response = handler.get(); + + // then + assertThat(response.content().text()).containsIgnoringCase("pomme"); + } + + @Test + void should_execute_python_code() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .allowCodeExecution(true) + .includeCodeExecutionOutput(true) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate( + UserMessage.from("Calculate `fibonacci(13)`. 
Write code in Python and execute it to get the result."), + handler + ); + Response response = handler.get(); + + // then + String text = response.content().text(); + System.out.println("text = " + text); + + assertThat(text).containsIgnoringCase("233"); + } + + @Test + void should_support_JSON_array_in_tools() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .build(); + + // when + List allMessages = new ArrayList<>(); + allMessages.add(UserMessage.from("Return a JSON list containing the first 10 fibonacci numbers.")); + + + TestStreamingResponseHandler handler1 = new TestStreamingResponseHandler<>(); + gemini.generate( + allMessages, + List.of(ToolSpecification.builder() + .name("getFirstNFibonacciNumbers") + .description("Get the first n fibonacci numbers") + .parameters(JsonObjectSchema.builder() + .addNumberProperty("n") + .build()) + .build()), + handler1 + ); + Response response1 = handler1.get(); + + // then + assertThat(response1.content().hasToolExecutionRequests()).isTrue(); + assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("getFirstNFibonacciNumbers"); + assertThat(response1.content().toolExecutionRequests().get(0).arguments()).contains("\"n\":10"); + + allMessages.add(response1.content()); + + // when + ToolExecutionResultMessage forecastResult = + ToolExecutionResultMessage.from(null, "getFirstNFibonacciNumbers", "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]"); + allMessages.add(forecastResult); + + // then + TestStreamingResponseHandler handler2 = new TestStreamingResponseHandler<>(); + gemini.generate(allMessages, handler2); + Response response2 = handler2.get(); + + // then + assertThat(response2.content().text()).contains("[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]"); + } + + // Test is flaky, because Gemini doesn't 100% always ask for parallel tool calls + // and sometimes requests 
more information + @RetryingTest(5) + void should_support_parallel_tool_execution() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .build(); + + // when + List allMessages = new ArrayList<>(); + allMessages.add(UserMessage.from("Which warehouse has more stock, ABC or XYZ?")); + + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate( + allMessages, + List.of(ToolSpecification.builder() + .name("getWarehouseStock") + .description("Retrieve the amount of stock available in a warehouse designated by its name") + .parameters(JsonObjectSchema.builder() + .addStringProperty("name", "The name of the warehouse") + .build()) + .build()), + handler + ); + Response response = handler.get(); + + + // then + assertThat(response.content().hasToolExecutionRequests()).isTrue(); + + List executionRequests = response.content().toolExecutionRequests(); + assertThat(executionRequests).hasSize(2); + + String allArgs = executionRequests.stream() + .map(ToolExecutionRequest::arguments) + .collect(Collectors.joining(" ")); + assertThat(allArgs).contains("ABC"); + assertThat(allArgs).contains("XYZ"); + } + + @RetryingTest(5) + void should_support_safety_settings() { + // given + List safetySettings = List.of( + new GeminiSafetySetting(HARM_CATEGORY_HATE_SPEECH, BLOCK_LOW_AND_ABOVE), + new GeminiSafetySetting(HARM_CATEGORY_HARASSMENT, BLOCK_LOW_AND_ABOVE) + ); + + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .safetySettings(safetySettings) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate("You're a dumb f*cking idiot bastard!", handler); + Response response = handler.get(); + + // then 
+ assertThat(response.finishReason()).isEqualTo(FinishReasonMapper.fromGFinishReasonToFinishReason(GeminiFinishReason.SAFETY)); + } + + @Test + void should_comply_to_response_schema() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchema.builder() + .rootElement(JsonObjectSchema.builder() + .addStringProperty("name") + .addProperty("address", JsonObjectSchema.builder() + .addStringProperty("city") + .required("city") + .build()) + .required("name", "address") + .additionalProperties(false) + .build()) + .build()) + .build()) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(List.of( + SystemMessage.from( + "Your role is to extract information related to a person," + + "like their name, their address, the city the live in."), + UserMessage.from( + "In the town of Liverpool, lived Tommy Skybridge, a young little boy." 
+ ) + ), handler); + Response response = handler.get(); + + System.out.println("response = " + response); + + // then + assertThat(response.content().text().trim()) + .isEqualTo("{\"address\": {\"city\": \"Liverpool\"}, \"name\": \"Tommy Skybridge\"}"); + } + + @Test + void should_handle_enum_type() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchema.builder() + .rootElement(JsonObjectSchema.builder() + .properties(new LinkedHashMap() {{ + put("sentiment", JsonEnumSchema.builder() + .enumValues("POSITIVE", "NEGATIVE") + .build()); + }}) + .required("sentiment") + .additionalProperties(false) + .build()) + .build()) + .build()) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(List.of( + SystemMessage.from( + "Your role is to analyze the sentiment of the text you receive."), + UserMessage.from( + "This is super exciting news, congratulations!" 
+ ) + ), handler); + Response response = handler.get(); + + System.out.println("response = " + response); + + // then + assertThat(response.content().text().trim()) + .isEqualTo("{\"sentiment\": \"POSITIVE\"}"); + } + + @Test + void should_handle_enum_output_mode() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchema.builder() + .rootElement(JsonEnumSchema.builder() + .enumValues("POSITIVE", "NEUTRAL", "NEGATIVE") + .build()) + .build()) + .build()) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(List.of( + SystemMessage.from( + "Your role is to analyze the sentiment of the text you receive."), + UserMessage.from( + "This is super exciting news, congratulations!" + ) + ), handler); + Response response = handler.get(); + + // then + assertThat(response.content().text().trim()) + .isEqualTo("POSITIVE"); + } + + @Test + void should_allow_array_as_response_schema() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchema.builder() + .rootElement(JsonArraySchema.builder() + .items(new JsonIntegerSchema()) + .build()) + .build()) + .build()) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(List.of( + SystemMessage.from( + "Your role is to return a list of 6-faces dice rolls"), + UserMessage.from( + "Give me 3 dice rolls" + ) + ), handler); + Response response = handler.get(); + + System.out.println("response = " + response); + + // then + Integer[] diceRolls = new 
Gson().fromJson(response.content().text(), Integer[].class); + assertThat(diceRolls.length).isEqualTo(3); + } + + private class Color { + private String name; + private int red; + private int green; + private int blue; + private boolean muted; + } + + @Test + void should_deserialize_to_POJO() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchemas.jsonSchemaFrom(Color.class).get()) + .build()) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + gemini.generate(List.of( + SystemMessage.from( + "Your role is to extract information from the description of a color"), + UserMessage.from( + "Cobalt blue is a blend of a lot of blue, a bit of green, and almost no red." + ) + ), handler); + Response response = handler.get(); + + System.out.println("response = " + response); + + Color color = new Gson().fromJson(response.content().text(), Color.class); + + // then + assertThat(color.name).isEqualToIgnoringCase("Cobalt blue"); + assertThat(color.muted).isFalse(); + assertThat(color.red).isLessThanOrEqualTo(color.green); + assertThat(color.green).isLessThanOrEqualTo(color.blue); + } + + @Test + void should_support_tool_config() { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .build(); + + List chatMessages = List.of(UserMessage.from("Call toolOne")); + List listOfTools = Arrays.asList( + ToolSpecification.builder().name("toolOne").build(), + ToolSpecification.builder().name("toolTwo").build() + ); + + // when + TestStreamingResponseHandler handler1 = new TestStreamingResponseHandler<>(); + gemini.generate(chatMessages, listOfTools, handler1); + 
Response response1 = handler1.get(); + + // then + assertThat(response1.content().hasToolExecutionRequests()).isTrue(); + assertThat(response1.content().toolExecutionRequests().get(0).name()).isEqualTo("toolOne"); + + // given + gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-flash") + .logRequestsAndResponses(true) + .toolConfig(new GeminiFunctionCallingConfig(GeminiMode.ANY, List.of("toolTwo"))) + .build(); + + // when + TestStreamingResponseHandler handler2 = new TestStreamingResponseHandler<>(); + gemini.generate("Call toolOne", handler2); + Response response2 = handler2.get(); + + // then + assertThat(response2.content().hasToolExecutionRequests()).isFalse(); + } + + static class Transactions { + @Tool("returns amount of a given transaction") + double getTransactionAmount(@P("ID of a transaction") String id) { + System.out.printf("called getTransactionAmount(%s)%n", id); + switch (id) { + case "T001": + return 11.1; + case "T002": + return 22.2; + default: + throw new IllegalArgumentException("Unknown transaction ID: " + id); + } + } + } + + + interface StreamingAssistant { + TokenStream chat(String userMessage); + } + + @RetryingTest(10) + void should_work_with_tools_with_AiServices() throws ExecutionException, InterruptedException, TimeoutException { + // given + GoogleAiGeminiStreamingChatModel gemini = GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("gemini-1.5-pro") + .logRequestsAndResponses(true) + .timeout(Duration.ofMinutes(2)) + .temperature(0.0) + .topP(0.0) + .topK(1) + .build(); + + // when + Transactions spyTransactions = spy(new Transactions()); + + MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20); + StreamingAssistant assistant = AiServices.builder(StreamingAssistant.class) + .tools(spyTransactions) + .chatMemory(chatMemory) + .streamingChatLanguageModel(gemini) + .build(); + + // then + CompletableFuture> 
future1 = new CompletableFuture<>(); + assistant.chat("What is the amount of transaction T001?") + .onNext(System.out::println) + .onComplete(future1::complete) + .onError(future1::completeExceptionally) + .start(); + Response response1 = future1.get(30, TimeUnit.SECONDS); + + assertThat(response1.content().text()).containsIgnoringCase("11.1"); + verify(spyTransactions).getTransactionAmount("T001"); + + CompletableFuture> future2 = new CompletableFuture<>(); + assistant.chat("What is the amount of transaction T002?") + .onNext(System.out::println) + .onComplete(future2::complete) + .onError(future2::completeExceptionally) + .start(); + Response response2 = future2.get(30, TimeUnit.SECONDS); + + assertThat(response2.content().text()).containsIgnoringCase("22.2"); + verify(spyTransactions).getTransactionAmount("T002"); + + verifyNoMoreInteractions(spyTransactions); + } + + @AfterEach + void afterEach() throws InterruptedException { + String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI"); + if (ciDelaySeconds != null) { + Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L); + } + } +} diff --git a/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelListenerIT.java b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelListenerIT.java new file mode 100644 index 000000000..716b15dc4 --- /dev/null +++ b/langchain4j-google-ai-gemini/src/test/java/dev/langchain4j/model/googleai/GoogleAiGeminiStreamingChatModelListenerIT.java @@ -0,0 +1,61 @@ +package dev.langchain4j.model.googleai; + +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatModelListenerIT; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import org.junit.jupiter.api.AfterEach; + +import java.io.IOException; + +import static java.util.Collections.singletonList; + +class 
GoogleAiGeminiStreamingChatModelListenerIT extends StreamingChatModelListenerIT { + private static final String GOOGLE_AI_GEMINI_API_KEY = System.getenv("GOOGLE_AI_GEMINI_API_KEY"); + + @Override + protected StreamingChatLanguageModel createModel(ChatModelListener listener) { + return GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName(modelName()) + .temperature(temperature()) + .topP(topP()) + .maxOutputTokens(maxTokens()) + .listeners(singletonList(listener)) + .logRequestsAndResponses(true) + .build(); + } + + @Override + protected String modelName() { + return "gemini-1.5-flash"; + } + + @Override + protected boolean assertResponseId() { + return false; + } + + @Override + protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) { + return GoogleAiGeminiStreamingChatModel.builder() + .apiKey(GOOGLE_AI_GEMINI_API_KEY) + .modelName("banana") + .listeners(singletonList(listener)) + .logRequestsAndResponses(true) + .build(); + } + + @Override + protected Class expectedExceptionClass() { + return RuntimeException.class; + } + + + @AfterEach + void afterEach() throws InterruptedException { + String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GOOGLE_AI_GEMINI"); + if (ciDelaySeconds != null) { + Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L); + } + } +}