Fix #757: Gemini: allow SystemMessage(s), merge them into the first UserMessage, warn in logs (#812)

## Context
See https://github.com/langchain4j/langchain4j/issues/757

## Change
All `SystemMessage`s from the input are now merged together into the
first `UserMessage`.
A warning about this is logged (once).

## Checklist
Before submitting this PR, please check the following points:
- [X] I have added unit and integration tests for my change
- [X] All unit and integration tests in the module I have added/changed
are green
- [X] All unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules are green
- [X] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
(only when a new module is added)
This commit is contained in:
LangChain4j 2024-03-25 09:13:06 +01:00 committed by GitHub
parent b3c0dad47d
commit 016f0b60ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 228 additions and 128 deletions

View File

@ -140,6 +140,12 @@ Caused by: io.grpc.StatusRuntimeException:
`projects/{YOUR_PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-ultra`
```
## Warning
Please note that Gemini does not support `SystemMessage`s.
If there are `SystemMessage`s provided to the `generate()` methods, they will be merged into the first
`UserMessage` (before the content of the `UserMessage`).
## Apply for early access
[Early access for Gemma](https://docs.google.com/forms/d/e/1FAIpQLSe0grG6mRFW6dNF3Rb1h_YvKqUp2GaXiglZBgA2Os5iTLWlcg/viewform)

View File

@ -74,6 +74,17 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<dependencyManagement>

View File

@ -1,26 +1,75 @@
package dev.langchain4j.model.vertexai;
import com.google.cloud.vertexai.api.Content;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.*;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import static java.util.stream.Collectors.toList;
@Slf4j
class ContentsMapper {
static List<Content> map(List<ChatMessage> messages) {
private static volatile boolean warned = false;
static List<com.google.cloud.vertexai.api.Content> map(List<ChatMessage> messages) {
List<SystemMessage> systemMessages = messages.stream()
.filter(message -> message instanceof SystemMessage)
.map(message -> (SystemMessage) message)
.collect(toList());
if (!systemMessages.isEmpty()) {
if (!warned) {
log.warn("Gemini does not support SystemMessage(s). " +
"All SystemMessage(s) will be merged into the first UserMessage.");
warned = true;
}
messages = mergeSystemMessagesIntoUserMessage(messages, systemMessages);
}
// TODO what if only a single system message?
return messages.stream()
.peek(message -> {
if (message instanceof SystemMessage) {
throw new IllegalArgumentException("SystemMessage is currently not supported by Gemini");
}
})
.map(message -> Content.newBuilder()
.map(message -> com.google.cloud.vertexai.api.Content.newBuilder()
.setRole(RoleMapper.map(message.type()))
.addAllParts(PartsMapper.map(message))
.build())
.collect(toList());
}
private static List<ChatMessage> mergeSystemMessagesIntoUserMessage(List<ChatMessage> messages,
List<SystemMessage> systemMessages) {
AtomicBoolean injected = new AtomicBoolean(false);
return messages.stream()
.filter(message -> !(message instanceof SystemMessage))
.map(message -> {
if (injected.get()) {
return message;
}
if (message instanceof UserMessage) {
UserMessage userMessage = (UserMessage) message;
List<Content> allContents = new ArrayList<>();
allContents.addAll(systemMessages.stream()
.map(systemMessage -> TextContent.from(systemMessage.text()))
.collect(toList()));
allContents.addAll(userMessage.contents());
injected.set(true);
if (userMessage.name() != null) {
return UserMessage.from(userMessage.name(), allContents);
} else {
return UserMessage.from(allContents);
}
}
return message;
})
.collect(toList());
}
}

View File

@ -15,15 +15,19 @@ import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServices;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.stream.Stream;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
class VertexAiGeminiChatModelIT {
@ -65,17 +69,65 @@ class VertexAiGeminiChatModelIT {
assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
}
@Test
void should_deny_system_message() {
@ParameterizedTest
@MethodSource
void should_merge_system_messages_into_user_message(List<ChatMessage> messages) {
// given
SystemMessage systemMessage = SystemMessage.from("Be polite");
UserMessage userMessage = UserMessage.from("Tell me a joke");
// when
Response<AiMessage> response = model.generate(messages);
// when-then
assertThatThrownBy(() -> model.generate(systemMessage, userMessage))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("SystemMessage is currently not supported by Gemini");
// then
assertThat(response.content().text()).containsIgnoringCase("liebe");
}
static Stream<Arguments> should_merge_system_messages_into_user_message() {
return Stream.<Arguments>builder()
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from("I love you")
)
))
.add(Arguments.of(
asList(
UserMessage.from("I love you"),
SystemMessage.from("Translate in German")
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in Italian"),
UserMessage.from("I love you"),
SystemMessage.from("No, translate in German!")
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from(asList(
TextContent.from("I love you"),
TextContent.from("I see you")
))
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from(asList(
TextContent.from("I see you"),
TextContent.from("I love you")
))
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from("I see you"),
AiMessage.from("Ich sehe dich"),
UserMessage.from("I love you")
)
))
.build();
}
@Test
@ -104,7 +156,7 @@ class VertexAiGeminiChatModelIT {
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(LENGTH);
assertThat(response.finishReason()).isEqualTo(STOP);
}
@Test

View File

@ -12,11 +12,15 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import static dev.langchain4j.internal.Utils.readBytes;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
@ -27,7 +31,6 @@ import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
class VertexAiGeminiStreamingChatModelIT {
@ -44,45 +47,18 @@ class VertexAiGeminiStreamingChatModelIT {
.build();
@Test
void should_stream_answer() throws Exception {
void should_stream_answer() {
// given
String userMessage = "What is the capital of Germany?";
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response.content() + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(answer).contains("Berlin");
assertThat(response.content().text()).isEqualTo(answer);
assertThat(response.content().text()).contains("Berlin");
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(7);
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
@ -92,21 +68,71 @@ class VertexAiGeminiStreamingChatModelIT {
assertThat(response.finishReason()).isEqualTo(STOP);
}
@Test
void should_deny_system_message() {
@ParameterizedTest
@MethodSource
void should_merge_system_messages_into_user_message(List<ChatMessage> messages) {
// given
SystemMessage systemMessage = SystemMessage.from("Be polite");
UserMessage userMessage = UserMessage.from("Tell me a joke");
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(messages, handler);
Response<AiMessage> response = handler.get();
// when-then
assertThatThrownBy(() -> model.generate(asList(systemMessage, userMessage), null))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("SystemMessage is currently not supported by Gemini");
// then
assertThat(response.content().text()).containsIgnoringCase("liebe");
}
static Stream<Arguments> should_merge_system_messages_into_user_message() {
return Stream.<Arguments>builder()
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from("I love you")
)
))
.add(Arguments.of(
asList(
UserMessage.from("I love you"),
SystemMessage.from("Translate in German")
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in Italian"),
UserMessage.from("I love you"),
SystemMessage.from("No, translate in German!")
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from(asList(
TextContent.from("I love you"),
TextContent.from("I see you")
))
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from(asList(
TextContent.from("I see you"),
TextContent.from("I love you")
))
)
))
.add(Arguments.of(
asList(
SystemMessage.from("Translate in German"),
UserMessage.from("I see you"),
AiMessage.from("Ich sehe dich"),
UserMessage.from("I love you")
)
))
.build();
}
@Test
void should_respect_maxOutputTokens() throws Exception {
void should_respect_maxOutputTokens() {
// given
StreamingChatLanguageModel model = VertexAiGeminiStreamingChatModel.builder()
@ -119,50 +145,23 @@ class VertexAiGeminiStreamingChatModelIT {
String userMessage = "Tell me a joke";
// when
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
private final StringBuilder answerBuilder = new StringBuilder();
@Override
public void onNext(String token) {
System.out.println("onNext: '" + token + "'");
answerBuilder.append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("onComplete: '" + response.content() + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});
String answer = futureAnswer.get(30, SECONDS);
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(answer).isNotBlank();
assertThat(response.content().text()).isEqualTo(answer);
assertThat(response.content().text()).isNotBlank();
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(4);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(1);
assertThat(response.tokenUsage().totalTokenCount())
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
assertThat(response.finishReason()).isEqualTo(LENGTH);
assertThat(response.finishReason()).isEqualTo(STOP);
}
@Test
void should_allow_custom_generativeModel_and_generationConfig() throws Exception {
void should_allow_custom_generativeModel_and_generationConfig() {
// given
VertexAI vertexAi = new VertexAI(System.getenv("GCP_PROJECT_ID"), System.getenv("GCP_LOCATION"));
@ -174,26 +173,9 @@ class VertexAiGeminiStreamingChatModelIT {
String userMessage = "What is the capital of Germany?";
// when
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
}
@Override
public void onComplete(Response<AiMessage> response) {
futureResponse.complete(response);
}
@Override
public void onError(Throwable error) {
futureResponse.completeExceptionally(error);
}
});
Response<AiMessage> response = futureResponse.get(30, SECONDS);
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(userMessage, handler);
Response<AiMessage> response = handler.get();
// then
assertThat(response.content().text()).contains("Berlin");
@ -347,17 +329,17 @@ class VertexAiGeminiStreamingChatModelIT {
// given
VertexAiGeminiStreamingChatModel model = VertexAiGeminiStreamingChatModel.builder()
.project(System.getenv("GCP_PROJECT_ID"))
.location(System.getenv("GCP_LOCATION"))
.modelName("gemini-pro")
.build();
.project(System.getenv("GCP_PROJECT_ID"))
.location(System.getenv("GCP_LOCATION"))
.modelName("gemini-pro")
.build();
ToolSpecification weatherToolSpec = ToolSpecification.builder()
.name("getWeatherForecast")
.description("Get the weather forecast for a location")
.addParameter("location", JsonSchemaProperty.STRING,
JsonSchemaProperty.description("the location to get the weather forecast for"))
.build();
.name("getWeatherForecast")
.description("Get the weather forecast for a location")
.addParameter("location", JsonSchemaProperty.STRING,
JsonSchemaProperty.description("the location to get the weather forecast for"))
.build();
List<ChatMessage> allMessages = new ArrayList<>();
@ -381,7 +363,7 @@ class VertexAiGeminiStreamingChatModelIT {
// when (feeding the function return value back)
ToolExecutionResultMessage toolExecResMsg = ToolExecutionResultMessage.from(toolExecutionRequest,
"{\"location\":\"Paris\",\"forecast\":\"sunny\", \"temperature\": 20}");
"{\"location\":\"Paris\",\"forecast\":\"sunny\", \"temperature\": 20}");
allMessages.add(toolExecResMsg);
handler = new TestStreamingResponseHandler<>();