Vertex AI: embed in batches of 5
parent eef1796963
commit 43917ee474
@@ -19,6 +19,7 @@ import java.util.List;
 import static com.google.cloud.aiplatform.util.ValueConverter.EMPTY_VALUE;
 import static dev.langchain4j.internal.Json.toJson;
 import static dev.langchain4j.internal.RetryUtils.withRetry;
+import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
 import static java.util.stream.Collectors.toList;
 
@@ -28,6 +29,8 @@ import static java.util.stream.Collectors.toList;
  */
 public class VertexAiEmbeddingModel implements EmbeddingModel {
 
+    private static final int BATCH_SIZE = 5; // Vertex AI has a limit of up to 5 input texts per request
+
     private final PredictionServiceSettings settings;
     private final EndpointName endpointName;
     private final Integer maxRetries;
@@ -51,38 +54,37 @@ public class VertexAiEmbeddingModel implements EmbeddingModel {
                 ensureNotBlank(publisher, "publisher"),
                 ensureNotBlank(modelName, "modelName")
         );
-        this.maxRetries = maxRetries == null ? 3 : maxRetries;
+        this.maxRetries = getOrDefault(maxRetries, 3);
     }
 
     @Override
-    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
-        List<String> texts = textSegments.stream()
-                .map(TextSegment::text)
-                .collect(toList());
+    public Response<List<Embedding>> embedAll(List<TextSegment> segments) {
 
-        return embedTexts(texts);
-    }
-
-    private Response<List<Embedding>> embedTexts(List<String> texts) {
         try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
 
-            List<Value> instances = new ArrayList<>();
-            for (String text : texts) {
-                Value.Builder instanceBuilder = Value.newBuilder();
-                JsonFormat.parser().merge(toJson(new VertexAiEmbeddingInstance(text)), instanceBuilder);
-                instances.add(instanceBuilder.build());
-            }
-
-            PredictResponse response = withRetry(() -> client.predict(endpointName, instances, EMPTY_VALUE), maxRetries);
-
-            List<Embedding> embeddings = response.getPredictionsList().stream()
-                    .map(VertexAiEmbeddingModel::toVector)
-                    .map(Embedding::from)
-                    .collect(toList());
-
+            List<Embedding> embeddings = new ArrayList<>();
             int inputTokenCount = 0;
-            for (Value value : response.getPredictionsList()) {
-                inputTokenCount += extractTokenCount(value);
+
+            for (int i = 0; i < segments.size(); i += BATCH_SIZE) {
+
+                List<TextSegment> batch = segments.subList(i, Math.min(i + BATCH_SIZE, segments.size()));
+
+                List<Value> instances = new ArrayList<>();
+                for (TextSegment segment : batch) {
+                    Value.Builder instanceBuilder = Value.newBuilder();
+                    JsonFormat.parser().merge(toJson(new VertexAiEmbeddingInstance(segment.text())), instanceBuilder);
+                    instances.add(instanceBuilder.build());
+                }
+
+                PredictResponse response = withRetry(() -> client.predict(endpointName, instances, EMPTY_VALUE), maxRetries);
+
+                embeddings.addAll(response.getPredictionsList().stream()
+                        .map(VertexAiEmbeddingModel::toEmbedding)
+                        .collect(toList()));
+
+                for (Value prediction : response.getPredictionsList()) {
+                    inputTokenCount += extractTokenCount(prediction);
+                }
             }
 
             return Response.from(
@@ -94,8 +96,9 @@ public class VertexAiEmbeddingModel implements EmbeddingModel {
         }
     }
 
-    private static List<Float> toVector(Value prediction) {
-        return prediction.getStructValue()
+    private static Embedding toEmbedding(Value prediction) {
+
+        List<Float> vector = prediction.getStructValue()
                 .getFieldsMap()
                 .get("embeddings")
                 .getStructValue()
@@ -105,10 +108,12 @@ public class VertexAiEmbeddingModel implements EmbeddingModel {
                 .stream()
                 .map(v -> (float) v.getNumberValue())
                 .collect(toList());
+
+        return Embedding.from(vector);
     }
 
-    private static int extractTokenCount(Value value) {
-        return (int) value.getStructValue()
+    private static int extractTokenCount(Value prediction) {
+        return (int) prediction.getStructValue()
                 .getFieldsMap()
                 .get("embeddings")
                 .getStructValue()
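The two hunks above are cut off before the end of the struct traversal. For orientation only, here is a hedged sketch of how such a prediction Value can be unpacked with the protobuf Struct/Value API. The field names "values", "statistics" and "token_count", the sample JSON payload, and the class name PredictionParsingSketch are assumptions based on the published Vertex AI text-embedding response shape, not taken from this diff.

import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;

import java.util.List;

import static java.util.stream.Collectors.toList;

public class PredictionParsingSketch {

    public static void main(String[] args) throws Exception {
        // Invented payload shaped like one element of predictions[]:
        // the vector under embeddings.values, token statistics under embeddings.statistics
        String json = "{\"embeddings\": {\"values\": [0.1, 0.2, 0.3], \"statistics\": {\"token_count\": 2}}}";

        Value.Builder prediction = Value.newBuilder();
        JsonFormat.parser().merge(json, prediction);

        // Same getStructValue()/getFieldsMap() chain as the methods above
        List<Float> vector = prediction.getStructValue()
                .getFieldsMap().get("embeddings").getStructValue()
                .getFieldsMap().get("values").getListValue()
                .getValuesList().stream()
                .map(v -> (float) v.getNumberValue())
                .collect(toList());

        int tokenCount = (int) prediction.getStructValue()
                .getFieldsMap().get("embeddings").getStructValue()
                .getFieldsMap().get("statistics").getStructValue()
                .getFieldsMap().get("token_count").getNumberValue();

        System.out.println(vector + ", token_count = " + tokenCount);
    }
}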
@@ -2,6 +2,7 @@ package dev.langchain4j.model.vertexai;
 
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Disabled;
@@ -18,7 +19,8 @@ class VertexAiEmbeddingModelIT {
     @Test
     @Disabled("To run this test, you must have provide your own endpoint, project and location")
     void testEmbeddingModel() {
-        VertexAiEmbeddingModel vertexAiEmbeddingModel = VertexAiEmbeddingModel.builder()
+
+        EmbeddingModel embeddingModel = VertexAiEmbeddingModel.builder()
                 .endpoint("us-central1-aiplatform.googleapis.com:443")
                 .project("langchain4j")
                 .location("us-central1")
@@ -28,22 +30,22 @@ class VertexAiEmbeddingModelIT {
                 .build();
 
         List<TextSegment> segments = asList(
-                TextSegment.from("hello world"),
-                TextSegment.from("how are you?")
+                TextSegment.from("one"),
+                TextSegment.from("two"),
+                TextSegment.from("three"),
+                TextSegment.from("four"),
+                TextSegment.from("five"),
+                TextSegment.from("six")
         );
 
-        Response<List<Embedding>> response = vertexAiEmbeddingModel.embedAll(segments);
+        Response<List<Embedding>> response = embeddingModel.embedAll(segments);
 
         List<Embedding> embeddings = response.content();
-        assertThat(embeddings).hasSize(2);
+        assertThat(embeddings).hasSize(6);
 
-        Embedding embedding1 = embeddings.get(0);
-        assertThat(embedding1.vector()).hasSize(768);
-        System.out.println(Arrays.toString(embedding1.vector()));
-
-        Embedding embedding2 = embeddings.get(1);
-        assertThat(embedding2.vector()).hasSize(768);
-        System.out.println(Arrays.toString(embedding2.vector()));
+        Embedding embedding = embeddings.get(0);
+        assertThat(embedding.vector()).hasSize(768);
+        System.out.println(Arrays.toString(embedding.vector()));
 
         TokenUsage tokenUsage = response.tokenUsage();
         assertThat(tokenUsage.inputTokenCount()).isEqualTo(6);