Vertex AI: embed in batches of 5

This commit is contained in:
deep-learning-dynamo 2023-10-09 15:38:41 +02:00
parent eef1796963
commit 43917ee474
2 changed files with 48 additions and 41 deletions

View File

@ -19,6 +19,7 @@ import java.util.List;
import static com.google.cloud.aiplatform.util.ValueConverter.EMPTY_VALUE;
import static dev.langchain4j.internal.Json.toJson;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.util.stream.Collectors.toList;
@ -28,6 +29,8 @@ import static java.util.stream.Collectors.toList;
*/
public class VertexAiEmbeddingModel implements EmbeddingModel {
private static final int BATCH_SIZE = 5; // Vertex AI has a limit of up to 5 input texts per request
private final PredictionServiceSettings settings;
private final EndpointName endpointName;
private final Integer maxRetries;
@ -51,38 +54,37 @@ public class VertexAiEmbeddingModel implements EmbeddingModel {
ensureNotBlank(publisher, "publisher"),
ensureNotBlank(modelName, "modelName")
);
this.maxRetries = maxRetries == null ? 3 : maxRetries;
this.maxRetries = getOrDefault(maxRetries, 3);
}
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<String> texts = textSegments.stream()
.map(TextSegment::text)
.collect(toList());
public Response<List<Embedding>> embedAll(List<TextSegment> segments) {
return embedTexts(texts);
}
private Response<List<Embedding>> embedTexts(List<String> texts) {
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
List<Value> instances = new ArrayList<>();
for (String text : texts) {
Value.Builder instanceBuilder = Value.newBuilder();
JsonFormat.parser().merge(toJson(new VertexAiEmbeddingInstance(text)), instanceBuilder);
instances.add(instanceBuilder.build());
}
PredictResponse response = withRetry(() -> client.predict(endpointName, instances, EMPTY_VALUE), maxRetries);
List<Embedding> embeddings = response.getPredictionsList().stream()
.map(VertexAiEmbeddingModel::toVector)
.map(Embedding::from)
.collect(toList());
List<Embedding> embeddings = new ArrayList<>();
int inputTokenCount = 0;
for (Value value : response.getPredictionsList()) {
inputTokenCount += extractTokenCount(value);
for (int i = 0; i < segments.size(); i += BATCH_SIZE) {
List<TextSegment> batch = segments.subList(i, Math.min(i + BATCH_SIZE, segments.size()));
List<Value> instances = new ArrayList<>();
for (TextSegment segment : batch) {
Value.Builder instanceBuilder = Value.newBuilder();
JsonFormat.parser().merge(toJson(new VertexAiEmbeddingInstance(segment.text())), instanceBuilder);
instances.add(instanceBuilder.build());
}
PredictResponse response = withRetry(() -> client.predict(endpointName, instances, EMPTY_VALUE), maxRetries);
embeddings.addAll(response.getPredictionsList().stream()
.map(VertexAiEmbeddingModel::toEmbedding)
.collect(toList()));
for (Value prediction : response.getPredictionsList()) {
inputTokenCount += extractTokenCount(prediction);
}
}
return Response.from(
@ -94,8 +96,9 @@ public class VertexAiEmbeddingModel implements EmbeddingModel {
}
}
private static List<Float> toVector(Value prediction) {
return prediction.getStructValue()
private static Embedding toEmbedding(Value prediction) {
List<Float> vector = prediction.getStructValue()
.getFieldsMap()
.get("embeddings")
.getStructValue()
@ -105,10 +108,12 @@ public class VertexAiEmbeddingModel implements EmbeddingModel {
.stream()
.map(v -> (float) v.getNumberValue())
.collect(toList());
return Embedding.from(vector);
}
private static int extractTokenCount(Value value) {
return (int) value.getStructValue()
private static int extractTokenCount(Value prediction) {
return (int) prediction.getStructValue()
.getFieldsMap()
.get("embeddings")
.getStructValue()

View File

@ -2,6 +2,7 @@ package dev.langchain4j.model.vertexai;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Disabled;
@ -18,7 +19,8 @@ class VertexAiEmbeddingModelIT {
@Test
@Disabled("To run this test, you must provide your own endpoint, project and location")
void testEmbeddingModel() {
VertexAiEmbeddingModel vertexAiEmbeddingModel = VertexAiEmbeddingModel.builder()
EmbeddingModel embeddingModel = VertexAiEmbeddingModel.builder()
.endpoint("us-central1-aiplatform.googleapis.com:443")
.project("langchain4j")
.location("us-central1")
@ -28,22 +30,22 @@ class VertexAiEmbeddingModelIT {
.build();
List<TextSegment> segments = asList(
TextSegment.from("hello world"),
TextSegment.from("how are you?")
TextSegment.from("one"),
TextSegment.from("two"),
TextSegment.from("three"),
TextSegment.from("four"),
TextSegment.from("five"),
TextSegment.from("six")
);
Response<List<Embedding>> response = vertexAiEmbeddingModel.embedAll(segments);
Response<List<Embedding>> response = embeddingModel.embedAll(segments);
List<Embedding> embeddings = response.content();
assertThat(embeddings).hasSize(2);
assertThat(embeddings).hasSize(6);
Embedding embedding1 = embeddings.get(0);
assertThat(embedding1.vector()).hasSize(768);
System.out.println(Arrays.toString(embedding1.vector()));
Embedding embedding2 = embeddings.get(1);
assertThat(embedding2.vector()).hasSize(768);
System.out.println(Arrays.toString(embedding2.vector()));
Embedding embedding = embeddings.get(0);
assertThat(embedding.vector()).hasSize(768);
System.out.println(Arrays.toString(embedding.vector()));
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(6);