Fix #757: Gemini: allow SystemMessage(s), merge them into the first UserMessage, warn in logs (#812)
## Context See https://github.com/langchain4j/langchain4j/issues/757 ## Change All `SystemMessage`s from the input are now merged together into the first `UserMessage`. Warning about this is given (once) in the log. ## Checklist Before submitting this PR, please check the following points: - [X] I have added unit and integration tests for my change - [X] All unit and integration tests in the module I have added/changed are green - [X] All unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules are green - [X] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) (only when a new module is added)
This commit is contained in:
parent
b3c0dad47d
commit
016f0b60ea
|
@ -140,6 +140,12 @@ Caused by: io.grpc.StatusRuntimeException:
|
|||
`projects/{YOUR_PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-ultra`
|
||||
```
|
||||
|
||||
## Warning
|
||||
|
||||
Please note that Gemini does not support `SystemMessage`s.
|
||||
If there are `SystemMessage`s provided to the `generate()` methods, they will be merged into the first
|
||||
`UserMessage` (before the content of the `UserMessage`).
|
||||
|
||||
## Apply for early access
|
||||
|
||||
[Early access for Gemma](https://docs.google.com/forms/d/e/1FAIpQLSe0grG6mRFW6dNF3Rb1h_YvKqUp2GaXiglZBgA2Os5iTLWlcg/viewform)
|
||||
|
|
|
@ -74,6 +74,17 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>tinylog-impl</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<dependencyManagement>
|
||||
|
|
|
@ -1,26 +1,75 @@
|
|||
package dev.langchain4j.model.vertexai;

import dev.langchain4j.data.message.*;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.stream.Collectors.toList;

/**
 * Maps langchain4j {@link ChatMessage}s to Vertex AI Gemini
 * {@link com.google.cloud.vertexai.api.Content} protos.
 * <p>
 * Gemini does not support {@code SystemMessage}s, so all of them are merged
 * into the first {@link UserMessage} (prepended before its original contents).
 * A warning about this is logged once per JVM.
 */
@Slf4j
class ContentsMapper {

    // Set to true after the "SystemMessage not supported" warning has been
    // logged once; volatile so the flag is visible across threads.
    private static volatile boolean warned = false;

    /**
     * Converts the given messages into Vertex AI {@code Content} objects.
     * Any {@code SystemMessage}s are first merged into the first
     * {@code UserMessage} (see class Javadoc).
     *
     * @param messages the conversation messages; may contain {@code SystemMessage}s
     * @return one {@code Content} per remaining (non-system) message, in order
     */
    static List<com.google.cloud.vertexai.api.Content> map(List<ChatMessage> messages) {

        List<SystemMessage> systemMessages = messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .collect(toList());

        if (!systemMessages.isEmpty()) {
            if (!warned) {
                log.warn("Gemini does not support SystemMessage(s). " +
                        "All SystemMessage(s) will be merged into the first UserMessage.");
                warned = true;
            }
            messages = mergeSystemMessagesIntoUserMessage(messages, systemMessages);
        }
        // NOTE(review): if the input contains only SystemMessage(s) and no
        // UserMessage, the system text is dropped entirely — confirm whether
        // a synthetic UserMessage should be created in that case.

        return messages.stream()
                .map(message -> com.google.cloud.vertexai.api.Content.newBuilder()
                        .setRole(RoleMapper.map(message.type()))
                        .addAllParts(PartsMapper.map(message))
                        .build())
                .collect(toList());
    }

    /**
     * Removes all {@code SystemMessage}s from {@code messages} and prepends
     * their text (as {@link TextContent}s) to the contents of the first
     * {@code UserMessage} encountered. Messages after the first user message
     * are passed through unchanged.
     *
     * @param messages       the full message list (system messages included)
     * @param systemMessages the system messages extracted from {@code messages}
     * @return a new list without system messages, with the first user message enriched
     */
    private static List<ChatMessage> mergeSystemMessagesIntoUserMessage(List<ChatMessage> messages,
                                                                        List<SystemMessage> systemMessages) {
        // AtomicBoolean (not a plain local) because the flag is mutated from
        // inside the stream lambda; it marks that the merge already happened.
        AtomicBoolean injected = new AtomicBoolean(false);
        return messages.stream()
                .filter(message -> !(message instanceof SystemMessage))
                .map(message -> {
                    if (injected.get()) {
                        return message;
                    }

                    if (message instanceof UserMessage) {
                        UserMessage userMessage = (UserMessage) message;

                        // System texts go first, then the user's original contents.
                        List<Content> allContents = new ArrayList<>();
                        allContents.addAll(systemMessages.stream()
                                .map(systemMessage -> TextContent.from(systemMessage.text()))
                                .collect(toList()));
                        allContents.addAll(userMessage.contents());

                        injected.set(true);

                        // Preserve the user name when present, since
                        // UserMessage.from(...) has distinct overloads.
                        if (userMessage.name() != null) {
                            return UserMessage.from(userMessage.name(), allContents);
                        } else {
                            return UserMessage.from(allContents);
                        }
                    }

                    return message;
                })
                .collect(toList());
    }
}
|
||||
|
|
|
@ -15,15 +15,19 @@ import dev.langchain4j.model.output.Response;
|
|||
import dev.langchain4j.model.output.TokenUsage;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.readBytes;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
import static dev.langchain4j.model.output.FinishReason.STOP;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class VertexAiGeminiChatModelIT {
|
||||
|
@ -65,17 +69,65 @@ class VertexAiGeminiChatModelIT {
|
|||
assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_deny_system_message() {
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void should_merge_system_messages_into_user_message(List<ChatMessage> messages) {
|
||||
|
||||
// given
|
||||
SystemMessage systemMessage = SystemMessage.from("Be polite");
|
||||
UserMessage userMessage = UserMessage.from("Tell me a joke");
|
||||
// when
|
||||
Response<AiMessage> response = model.generate(messages);
|
||||
|
||||
// when-then
|
||||
assertThatThrownBy(() -> model.generate(systemMessage, userMessage))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("SystemMessage is currently not supported by Gemini");
|
||||
// then
|
||||
assertThat(response.content().text()).containsIgnoringCase("liebe");
|
||||
}
|
||||
|
||||
static Stream<Arguments> should_merge_system_messages_into_user_message() {
|
||||
return Stream.<Arguments>builder()
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from("I love you")
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
UserMessage.from("I love you"),
|
||||
SystemMessage.from("Translate in German")
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in Italian"),
|
||||
UserMessage.from("I love you"),
|
||||
SystemMessage.from("No, translate in German!")
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from(asList(
|
||||
TextContent.from("I love you"),
|
||||
TextContent.from("I see you")
|
||||
))
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from(asList(
|
||||
TextContent.from("I see you"),
|
||||
TextContent.from("I love you")
|
||||
))
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from("I see you"),
|
||||
AiMessage.from("Ich sehe dich"),
|
||||
UserMessage.from("I love you")
|
||||
)
|
||||
))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -104,7 +156,7 @@ class VertexAiGeminiChatModelIT {
|
|||
assertThat(tokenUsage.totalTokenCount())
|
||||
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
|
||||
|
||||
assertThat(response.finishReason()).isEqualTo(LENGTH);
|
||||
assertThat(response.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -12,11 +12,15 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
|||
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.readBytes;
|
||||
import static dev.langchain4j.model.output.FinishReason.LENGTH;
|
||||
|
@ -27,7 +31,6 @@ import static java.util.Arrays.asList;
|
|||
import static java.util.Collections.singletonList;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
class VertexAiGeminiStreamingChatModelIT {
|
||||
|
||||
|
@ -44,45 +47,18 @@ class VertexAiGeminiStreamingChatModelIT {
|
|||
.build();
|
||||
|
||||
@Test
|
||||
void should_stream_answer() throws Exception {
|
||||
void should_stream_answer() {
|
||||
|
||||
// given
|
||||
String userMessage = "What is the capital of Germany?";
|
||||
|
||||
// when
|
||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
||||
|
||||
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
|
||||
|
||||
private final StringBuilder answerBuilder = new StringBuilder();
|
||||
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
System.out.println("onNext: '" + token + "'");
|
||||
answerBuilder.append(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
System.out.println("onComplete: '" + response.content() + "'");
|
||||
futureAnswer.complete(answerBuilder.toString());
|
||||
futureResponse.complete(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
futureAnswer.completeExceptionally(error);
|
||||
futureResponse.completeExceptionally(error);
|
||||
}
|
||||
});
|
||||
|
||||
String answer = futureAnswer.get(30, SECONDS);
|
||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(answer).contains("Berlin");
|
||||
assertThat(response.content().text()).isEqualTo(answer);
|
||||
assertThat(response.content().text()).contains("Berlin");
|
||||
|
||||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(7);
|
||||
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0);
|
||||
|
@ -92,21 +68,71 @@ class VertexAiGeminiStreamingChatModelIT {
|
|||
assertThat(response.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_deny_system_message() {
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void should_merge_system_messages_into_user_message(List<ChatMessage> messages) {
|
||||
|
||||
// given
|
||||
SystemMessage systemMessage = SystemMessage.from("Be polite");
|
||||
UserMessage userMessage = UserMessage.from("Tell me a joke");
|
||||
// when
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(messages, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// when-then
|
||||
assertThatThrownBy(() -> model.generate(asList(systemMessage, userMessage), null))
|
||||
.isExactlyInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("SystemMessage is currently not supported by Gemini");
|
||||
// then
|
||||
assertThat(response.content().text()).containsIgnoringCase("liebe");
|
||||
}
|
||||
|
||||
static Stream<Arguments> should_merge_system_messages_into_user_message() {
|
||||
return Stream.<Arguments>builder()
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from("I love you")
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
UserMessage.from("I love you"),
|
||||
SystemMessage.from("Translate in German")
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in Italian"),
|
||||
UserMessage.from("I love you"),
|
||||
SystemMessage.from("No, translate in German!")
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from(asList(
|
||||
TextContent.from("I love you"),
|
||||
TextContent.from("I see you")
|
||||
))
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from(asList(
|
||||
TextContent.from("I see you"),
|
||||
TextContent.from("I love you")
|
||||
))
|
||||
)
|
||||
))
|
||||
.add(Arguments.of(
|
||||
asList(
|
||||
SystemMessage.from("Translate in German"),
|
||||
UserMessage.from("I see you"),
|
||||
AiMessage.from("Ich sehe dich"),
|
||||
UserMessage.from("I love you")
|
||||
)
|
||||
))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_respect_maxOutputTokens() throws Exception {
|
||||
void should_respect_maxOutputTokens() {
|
||||
|
||||
// given
|
||||
StreamingChatLanguageModel model = VertexAiGeminiStreamingChatModel.builder()
|
||||
|
@ -119,50 +145,23 @@ class VertexAiGeminiStreamingChatModelIT {
|
|||
String userMessage = "Tell me a joke";
|
||||
|
||||
// when
|
||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
||||
|
||||
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
|
||||
|
||||
private final StringBuilder answerBuilder = new StringBuilder();
|
||||
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
System.out.println("onNext: '" + token + "'");
|
||||
answerBuilder.append(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
System.out.println("onComplete: '" + response.content() + "'");
|
||||
futureAnswer.complete(answerBuilder.toString());
|
||||
futureResponse.complete(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
futureAnswer.completeExceptionally(error);
|
||||
futureResponse.completeExceptionally(error);
|
||||
}
|
||||
});
|
||||
|
||||
String answer = futureAnswer.get(30, SECONDS);
|
||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(answer).isNotBlank();
|
||||
assertThat(response.content().text()).isEqualTo(answer);
|
||||
assertThat(response.content().text()).isNotBlank();
|
||||
|
||||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(4);
|
||||
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(1);
|
||||
assertThat(response.tokenUsage().totalTokenCount())
|
||||
.isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount());
|
||||
|
||||
assertThat(response.finishReason()).isEqualTo(LENGTH);
|
||||
assertThat(response.finishReason()).isEqualTo(STOP);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_allow_custom_generativeModel_and_generationConfig() throws Exception {
|
||||
void should_allow_custom_generativeModel_and_generationConfig() {
|
||||
|
||||
// given
|
||||
VertexAI vertexAi = new VertexAI(System.getenv("GCP_PROJECT_ID"), System.getenv("GCP_LOCATION"));
|
||||
|
@ -174,26 +173,9 @@ class VertexAiGeminiStreamingChatModelIT {
|
|||
String userMessage = "What is the capital of Germany?";
|
||||
|
||||
// when
|
||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
||||
|
||||
model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
|
||||
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
futureResponse.complete(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
futureResponse.completeExceptionally(error);
|
||||
}
|
||||
});
|
||||
|
||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
||||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
|
||||
model.generate(userMessage, handler);
|
||||
Response<AiMessage> response = handler.get();
|
||||
|
||||
// then
|
||||
assertThat(response.content().text()).contains("Berlin");
|
||||
|
@ -347,17 +329,17 @@ class VertexAiGeminiStreamingChatModelIT {
|
|||
|
||||
// given
|
||||
VertexAiGeminiStreamingChatModel model = VertexAiGeminiStreamingChatModel.builder()
|
||||
.project(System.getenv("GCP_PROJECT_ID"))
|
||||
.location(System.getenv("GCP_LOCATION"))
|
||||
.modelName("gemini-pro")
|
||||
.build();
|
||||
.project(System.getenv("GCP_PROJECT_ID"))
|
||||
.location(System.getenv("GCP_LOCATION"))
|
||||
.modelName("gemini-pro")
|
||||
.build();
|
||||
|
||||
ToolSpecification weatherToolSpec = ToolSpecification.builder()
|
||||
.name("getWeatherForecast")
|
||||
.description("Get the weather forecast for a location")
|
||||
.addParameter("location", JsonSchemaProperty.STRING,
|
||||
JsonSchemaProperty.description("the location to get the weather forecast for"))
|
||||
.build();
|
||||
.name("getWeatherForecast")
|
||||
.description("Get the weather forecast for a location")
|
||||
.addParameter("location", JsonSchemaProperty.STRING,
|
||||
JsonSchemaProperty.description("the location to get the weather forecast for"))
|
||||
.build();
|
||||
|
||||
List<ChatMessage> allMessages = new ArrayList<>();
|
||||
|
||||
|
@ -381,7 +363,7 @@ class VertexAiGeminiStreamingChatModelIT {
|
|||
|
||||
// when (feeding the function return value back)
|
||||
ToolExecutionResultMessage toolExecResMsg = ToolExecutionResultMessage.from(toolExecutionRequest,
|
||||
"{\"location\":\"Paris\",\"forecast\":\"sunny\", \"temperature\": 20}");
|
||||
"{\"location\":\"Paris\",\"forecast\":\"sunny\", \"temperature\": 20}");
|
||||
allMessages.add(toolExecResMsg);
|
||||
|
||||
handler = new TestStreamingResponseHandler<>();
|
||||
|
|
Loading…
Reference in New Issue