New ScoringModel for Google Cloud Vertex AI Ranking API (#1820)
## Issue
Closes #1819

## Change
Add support for the Vertex AI Ranking API by implementing a `ScoringModel` for it.

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. -->
- [X] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)
- [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
commit 9e2ee938d1 (parent ef2f1ec470)
@ -0,0 +1,97 @@
---
sidebar_position: 3
---

# Google Cloud Vertex AI Ranking API

- [Google Cloud Vertex AI Ranking documentation](https://cloud.google.com/generative-ai-app-builder/docs/ranking)
- [Google Cloud Vertex AI Ranking API description](https://cloud.google.com/generative-ai-app-builder/docs/reference/rest/v1/projects.locations.rankingConfigs/rank)


### Introduction

The Google Cloud Vertex AI Ranking API is a powerful tool that enhances search results by refining the relevance of
retrieved documents to a given query. Unlike traditional search methods, it leverages advanced machine learning
algorithms to understand the semantic context of both the query and the documents, delivering more precise and relevant
results. By analyzing the semantic relationship between the query and each document, the API can reorder the candidate
documents based on their calculated relevance scores, ensuring that the most relevant results appear at the top of the
search results page.

### Maven Dependency

```xml
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-vertex-ai</artifactId>
    <version>0.34.0</version>
</dependency>
```

### Usage

To configure the model, you'll have to specify:
* the Google Cloud project ID,
* the project number,
* the location (e.g. `us-central1`, `europe-west1`),
* and the model you want to use.

> Note: You can find the project number in the Google Cloud console, or by running `gcloud projects describe your-project-id`.

You can score a single string or `TextSegment` against a query
with the `score(text, query)` and `score(segment, query)` methods.

It is also possible to score several `TextSegment`s (or strings converted with `TextSegment::from`) against the query
with the `scoreAll(segments, query)` method:

```java
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
    .projectId(System.getenv("GCP_PROJECT_ID"))
    .projectNumber(System.getenv("GCP_PROJECT_NUM"))
    .location(System.getenv("GCP_LOCATION"))
    .model("semantic-ranker-512")
    .build();

Response<List<Double>> score = scoringModel.scoreAll(Stream.of(
        "The sky appears blue due to a phenomenon called Rayleigh scattering. " +
            "Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " +
            "wavelengths than other colors, and is thus scattered more easily.",

        "A canvas stretched across the day,\n" +
            "Where sunlight learns to dance and play.\n" +
            "Blue, a hue of scattered light,\n" +
            "A gentle whisper, soft and bright."
    ).map(TextSegment::from).collect(Collectors.toList()),
    "Why is the sky blue?");

// [0.8199999928474426, 0.4300000071525574]
```
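
Because `scoreAll` returns the scores in the same order as the input segments, each score can be matched back to its segment by index. The following is a minimal sketch in plain Java, with no API call: the `BestSegmentExample` class, the `indexOfBest` helper and the shortened texts are illustrative assumptions, and the two scores are copied from the sample output above.

```java
import java.util.Arrays;
import java.util.List;

class BestSegmentExample {

    // Hypothetical helper: returns the index of the highest score.
    // Assumes a non-empty list whose order matches the input segments.
    static int indexOfBest(List<Double> scores) {
        int best = 0;
        for (int i = 1; i < scores.size(); i++) {
            if (scores.get(i) > scores.get(best)) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Stand-ins for the two segments and the scores returned for them.
        List<String> texts = Arrays.asList("Rayleigh scattering explanation", "Poem about the sky");
        List<Double> scores = Arrays.asList(0.8199999928474426, 0.4300000071525574);

        int best = indexOfBest(scores);
        System.out.println("Most relevant: " + texts.get(best));
    }
}
```

In real code, `scores` would come from `score.content()` and `texts` from the list of segments passed to `scoreAll`.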

If you pass `TextSegment`s whose metadata contains a `title` key, the Ranker model can take this title into account in its calculation.
To specify a custom title key, you can use the `titleMetadataKey()` builder method.

You can use scoring models with `AiServices` through a `RetrievalAugmentor`: its `contentAggregator()` builder method
takes a `ContentAggregator`, such as `ReRankingContentAggregator`, that can specify a scoring model:

```java
VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
    .projectId(System.getenv("GCP_PROJECT_ID"))
    .projectNumber(System.getenv("GCP_PROJECT_NUM"))
    .location(System.getenv("GCP_LOCATION"))
    .model("semantic-ranker-512")
    .build();

ContentAggregator contentAggregator = ReRankingContentAggregator.builder()
    .scoringModel(scoringModel)
    ...
    .build();

RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
    ...
    .contentAggregator(contentAggregator)
    .build();

return AiServices.builder(Assistant.class)
    .chatLanguageModel(...)
    .retrievalAugmentor(retrievalAugmentor)
    .build();
```
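
Conceptually, re-ranking pairs each retrieved segment with its score, drops the low-scoring ones, and orders the rest by descending score. The sketch below shows that idea in plain Java. It is not the `ReRankingContentAggregator` implementation: the `ReRankSketch` class, the `rerank` helper, the `minScore` threshold and the sample scores are all illustrative assumptions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ReRankSketch {

    // Hypothetical helper: keeps the texts whose score is at least minScore,
    // ordered from the highest score to the lowest.
    static List<String> rerank(List<String> texts, List<Double> scores, double minScore) {
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < texts.size(); i++) {
            if (scores.get(i) >= minScore) {
                indexes.add(i);
            }
        }
        // Sort the surviving indexes by descending score.
        indexes.sort((a, b) -> Double.compare(scores.get(b), scores.get(a)));

        List<String> result = new ArrayList<>();
        for (int i : indexes) {
            result.add(texts.get(i));
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> texts = Arrays.asList("trailer", "capacity", "trunk");
        List<Double> scores = Arrays.asList(0.10, 0.90, 0.50); // made-up scores

        System.out.println(rerank(texts, scores, 0.3)); // prints [capacity, trunk]
    }
}
```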
@ -21,6 +21,8 @@
            <artifactId>langchain4j-core</artifactId>
        </dependency>

        <!-- Google Vertex AI library -->

        <dependency>
            <groupId>com.google.cloud</groupId>
            <artifactId>google-cloud-aiplatform</artifactId>
@ -33,6 +35,21 @@
            </exclusions>
        </dependency>

        <!-- Ranking API -->

        <dependency>
            <groupId>com.google.cloud</groupId>
            <artifactId>google-cloud-discoveryengine</artifactId>
        </dependency>

        <!-- testing dependencies -->

        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-engine</artifactId>
@ -59,6 +76,18 @@
        </dependency>
    </dependencies>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>com.google.cloud</groupId>
                <artifactId>google-cloud-discoveryengine-bom</artifactId>
                <scope>import</scope>
                <type>pom</type>
                <version>0.45.0</version>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
@ -0,0 +1,141 @@
package dev.langchain4j.model.vertexai;

import com.google.cloud.discoveryengine.v1beta.*;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.util.Comparator.comparing;

/**
 * Implementation of a <code>ScoringModel</code> for the Google Cloud Vertex AI
 * <a href="https://cloud.google.com/generative-ai-app-builder/docs/ranking">Ranking API</a>.
 */
public class VertexAiScoringModel implements ScoringModel {

    private final String model;
    private final String projectId;
    private final String projectNumber;
    private final String location;
    private final String titleMetadataKey;

    /**
     * Constructor for the Vertex AI Ranker Scoring Model.
     *
     * @param projectId The Google Cloud Project ID.
     * @param projectNumber The Google Cloud Project Number.
     * @param location The Google Cloud Region.
     * @param model The model to use
     * @param titleMetadataKey The name of the key to use as a title.
     */
    public VertexAiScoringModel(String projectId, String projectNumber, String location, String model, String titleMetadataKey) {
        this.projectId = ensureNotBlank(projectId, "projectId");
        this.projectNumber = ensureNotBlank(projectNumber, "projectNumber");
        this.location = ensureNotBlank(location, "location");
        this.model = ensureNotBlank(model, "model");
        this.titleMetadataKey = titleMetadataKey != null ? titleMetadataKey : "title";
    }

    /**
     * Scores all provided {@link TextSegment}s against a given query.
     *
     * @param segments The list of {@link TextSegment}s to score.
     * @param query The query against which to score the segments.
     * @return the list of scores. The order of scores corresponds to the order of {@link TextSegment}s.
     */
    @Override
    public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
        AtomicInteger counter = new AtomicInteger();

        try (RankServiceClient rankServiceClient = RankServiceClient.create(
            RankServiceSettings.newBuilder().build())) {

            RankRequest.Builder rankingRequestBuilder = RankRequest.newBuilder();

            if (model != null && !model.isEmpty()) {
                rankingRequestBuilder.setModel(model);
            }

            rankingRequestBuilder
                .setRankingConfig(RankingConfigName.newBuilder()
                    .setProject(projectId)
                    .setLocation(location)
                    .setRankingConfig(
                        String.format("projects/%s/locations/%s/rankingConfigs/default_ranking_config.", projectNumber, location))
                    .build().getRankingConfig())
                .setQuery(query)
                .setIgnoreRecordDetailsInResponse(true)
                .addAllRecords(segments.stream()
                    .map(segment -> {
                        RankingRecord.Builder rankingBuilder = RankingRecord.newBuilder()
                            .setContent(segment.text());
                        // Ranker API takes into account titles in its score calculations
                        if (segment.metadata().getString(titleMetadataKey) != null) {
                            rankingBuilder.setTitle(segment.metadata().getString(titleMetadataKey));
                        }
                        // custom ID used to reorder the (sorted) results back into original segment order
                        rankingBuilder.setId(String.valueOf(counter.getAndIncrement()));
                        return rankingBuilder.build();
                    })
                    .collect(Collectors.toList()));

            RankResponse rankResponse = rankServiceClient.rank(rankingRequestBuilder.build());

            return Response.from(rankResponse.getRecordsList().stream()
                // the API returns results sorted by relevance score, so reorder them back to original order
                .sorted(comparing(rr -> Double.valueOf(rr.getId())))
                .map(RankingRecord::getScore)
                .map(Double::valueOf)
                .collect(Collectors.toList()));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String model;
        private String projectId;
        private String projectNumber;
        private String location;
        private String titleMetadataKey;

        public Builder model(String model) {
            this.model = ensureNotBlank(model, "model");
            return this;
        }

        public Builder projectId(String projectId) {
            this.projectId = projectId;
            return this;
        }

        public Builder projectNumber(String projectNumber) {
            this.projectNumber = projectNumber;
            return this;
        }

        public Builder location(String location) {
            this.location = location;
            return this;
        }

        public Builder titleMetadataKey(String titleMetadataKey) {
            this.titleMetadataKey = ensureNotBlank(titleMetadataKey, "titleMetadataKey");
            return this;
        }

        public VertexAiScoringModel build() {
            return new VertexAiScoringModel(projectId, projectNumber, location, model, titleMetadataKey);
        }
    }
}
@ -0,0 +1,106 @@
package dev.langchain4j.model.vertexai;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;
import dev.langchain4j.data.segment.TextSegment;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;

public class VertexAiScoringModelIT {
    @Test
    void should_rank_multiple() {
        // given
        VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
            .projectId(System.getenv("GCP_PROJECT_ID"))
            .projectNumber(System.getenv("GCP_PROJECT_NUM"))
            .location(System.getenv("GCP_LOCATION"))
            .model("semantic-ranker-512")
            .build();

        // when
        Response<List<Double>> score = scoringModel.scoreAll(Stream.of(
                "The sky appears blue due to a phenomenon called Rayleigh scattering. " +
                    "Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " +
                    "wavelengths than other colors, and is thus scattered more easily.",

                "A canvas stretched across the day,\n" +
                    "Where sunlight learns to dance and play.\n" +
                    "Blue, a hue of scattered light,\n" +
                    "A gentle whisper, soft and bright."
            ).map(TextSegment::from).collect(Collectors.toList()),
            "Why is the sky blue?");

        // then
        assertThat(score.content().get(0)).isGreaterThan(score.content().get(1));
    }

    @Test
    void should_rank_single() {
        // given
        VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
            .projectId(System.getenv("GCP_PROJECT_ID"))
            .projectNumber(System.getenv("GCP_PROJECT_NUM"))
            .location(System.getenv("GCP_LOCATION"))
            .model("semantic-ranker-512")
            .build();

        // when
        Response<Double> score = scoringModel.score(
            "The sky appears blue due to a phenomenon called Rayleigh scattering. " +
                "Sunlight is comprised of all the colors of the rainbow. Blue light has shorter " +
                "wavelengths than other colors, and is thus scattered more easily.",
            "Why is the sky blue?");

        // then
        assertThat(score.content()).isPositive();
    }

    @Test
    void should_use_text_segment_titles_into_account() {
        // given
        String customTitleKey = "customTitle";

        VertexAiScoringModel scoringModel = VertexAiScoringModel.builder()
            .projectId(System.getenv("GCP_PROJECT_ID"))
            .projectNumber(System.getenv("GCP_PROJECT_NUM"))
            .location(System.getenv("GCP_LOCATION"))
            .model("semantic-ranker-512")
            .titleMetadataKey(customTitleKey)
            .build();

        List<TextSegment> segments = Arrays.asList(
            new TextSegment(
                "Your Cymbal Starlight 2024 is not equipped to tow a trailer.",
                new Metadata().put(customTitleKey, "trailer")),
            new TextSegment(
                "The Cymbal Starlight 2024 has a cargo capacity of 13.5 cubic feet.",
                new Metadata().put(customTitleKey, "capacity")),
            new TextSegment(
                "The cargo area is located in the trunk of the vehicle.",
                new Metadata().put(customTitleKey, "trunk")),
            new TextSegment(
                "To access the cargo area, open the trunk lid using the trunk release lever located in the driver's footwell.",
                new Metadata().put(customTitleKey, "lever")),
            new TextSegment(
                "When loading cargo into the trunk, be sure to distribute the weight evenly.",
                new Metadata().put(customTitleKey, "weight")),
            new TextSegment(
                "Do not overload the trunk, as this could affect the vehicle's handling and stability.",
                new Metadata().put(customTitleKey, "overload"))
        );

        // when
        Response<List<Double>> score = scoringModel.scoreAll(segments, "What is the cargo capacity of the car?");

        // then
        double maxScore = score.content().stream().mapToDouble(Double::doubleValue).max().getAsDouble();

        assertThat(score.content().get(1)).isEqualTo(maxScore);
    }
}