Milvus: improve insert performance

This commit is contained in:
LangChain4j 2024-06-06 16:40:26 +02:00
parent d29866bde4
commit 2c8ff58c02
4 changed files with 49 additions and 13 deletions

View File

@ -8,6 +8,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
@ -29,21 +30,29 @@ public abstract class EmbeddingStoreWithFilteringIT extends EmbeddingStoreIT {
List<Metadata> matchingMetadatas,
List<Metadata> notMatchingMetadatas) {
// given
List<Embedding> embeddings = new ArrayList<>();
List<TextSegment> segments = new ArrayList<>();
for (Metadata matchingMetadata : matchingMetadatas) {
TextSegment matchingSegment = TextSegment.from("matching", matchingMetadata);
Embedding matchingEmbedding = embeddingModel().embed(matchingSegment).content();
embeddingStore().add(matchingEmbedding, matchingSegment);
embeddings.add(matchingEmbedding);
segments.add(matchingSegment);
}
for (Metadata notMatchingMetadata : notMatchingMetadatas) {
TextSegment notMatchingSegment = TextSegment.from("not matching", notMatchingMetadata);
Embedding notMatchingEmbedding = embeddingModel().embed(notMatchingSegment).content();
embeddingStore().add(notMatchingEmbedding, notMatchingSegment);
embeddings.add(notMatchingEmbedding);
segments.add(notMatchingSegment);
}
TextSegment notMatchingSegmentWithoutMetadata = TextSegment.from("not matching, without metadata");
Embedding notMatchingWithoutMetadataEmbedding = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content();
embeddingStore().add(notMatchingWithoutMetadataEmbedding, notMatchingSegmentWithoutMetadata);
Embedding notMatchingEmbeddingWithoutMetadata = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content();
embeddings.add(notMatchingEmbeddingWithoutMetadata);
segments.add(notMatchingSegmentWithoutMetadata);
embeddingStore().addAll(embeddings, segments);
awaitUntilPersisted();
@ -1138,21 +1147,29 @@ public abstract class EmbeddingStoreWithFilteringIT extends EmbeddingStoreIT {
List<Metadata> matchingMetadatas,
List<Metadata> notMatchingMetadatas) {
// given
List<Embedding> embeddings = new ArrayList<>();
List<TextSegment> segments = new ArrayList<>();
for (Metadata matchingMetadata : matchingMetadatas) {
TextSegment matchingSegment = TextSegment.from("matching", matchingMetadata);
Embedding matchingEmbedding = embeddingModel().embed(matchingSegment).content();
embeddingStore().add(matchingEmbedding, matchingSegment);
embeddings.add(matchingEmbedding);
segments.add(matchingSegment);
}
for (Metadata notMatchingMetadata : notMatchingMetadatas) {
TextSegment notMatchingSegment = TextSegment.from("not matching", notMatchingMetadata);
Embedding notMatchingEmbedding = embeddingModel().embed(notMatchingSegment).content();
embeddingStore().add(notMatchingEmbedding, notMatchingSegment);
embeddings.add(notMatchingEmbedding);
segments.add(notMatchingSegment);
}
TextSegment notMatchingSegmentWithoutMetadata = TextSegment.from("matching");
Embedding notMatchingWithoutMetadataEmbedding = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content();
embeddingStore().add(notMatchingWithoutMetadataEmbedding, notMatchingSegmentWithoutMetadata);
Embedding notMatchingEmbeddingWithoutMetadata = embeddingModel().embed(notMatchingSegmentWithoutMetadata).content();
embeddings.add(notMatchingEmbeddingWithoutMetadata);
segments.add(notMatchingSegmentWithoutMetadata);
embeddingStore().addAll(embeddings, segments);
awaitUntilPersisted();

View File

@ -58,7 +58,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
private final MetricType metricType;
private final ConsistencyLevelEnum consistencyLevel;
private final boolean retrieveEmbeddingsOnSearch;
private final boolean autoFlushOnInsert;
public MilvusEmbeddingStore(
String host,
@ -73,6 +73,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
String password,
ConsistencyLevelEnum consistencyLevel,
Boolean retrieveEmbeddingsOnSearch,
Boolean autoFlushOnInsert,
String databaseName
) {
ConnectParam.Builder connectBuilder = ConnectParam
@ -92,6 +93,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
this.metricType = getOrDefault(metricType, COSINE);
this.consistencyLevel = getOrDefault(consistencyLevel, EVENTUALLY);
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
if (!hasCollection(this.milvusClient, this.collectionName)) {
createCollection(this.milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
@ -182,7 +184,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
insert(this.milvusClient, this.collectionName, fields);
flush(this.milvusClient, this.collectionName);
if (autoFlushOnInsert) {
flush(this.milvusClient, this.collectionName);
}
}
/**
@ -264,6 +268,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
private ConsistencyLevelEnum consistencyLevel;
private Boolean retrieveEmbeddingsOnSearch;
private String databaseName;
private Boolean autoFlushOnInsert;
/**
* @param host The host of the self-managed Milvus instance.
@ -386,6 +391,19 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
return this;
}
/**
* Configures whether the store flushes the Milvus collection after every insert.
* <p>
* When enabled, each insert performed by the {@code add(...)} or {@code addAll(...)} methods
* is followed by a flush of the collection. When disabled, no flush call is issued after
* the insert, which avoids one extra call per insert.
*
* @param autoFlushOnInsert Whether to automatically flush after each insert
* ({@code add(...)} or {@code addAll(...)} methods).
* Default value: false.
* More info can be found
* <a href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* @return builder
*/
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
this.autoFlushOnInsert = autoFlushOnInsert;
return this;
}
/**
* @param databaseName Name of the Milvus database to use.
* Default value: null. In this case the default Milvus database will be used.
@ -410,6 +428,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
password,
consistencyLevel,
retrieveEmbeddingsOnSearch,
autoFlushOnInsert,
databaseName
);
}

View File

@ -57,7 +57,8 @@ class MilvusEmbeddingStoreCloudIT extends EmbeddingStoreWithFilteringIT {
EmbeddingStore<TextSegment> embeddingStore = MilvusEmbeddingStore.builder()
.uri(System.getenv("MILVUS_URI"))
.token(System.getenv("MILVUS_API_KEY"))
.collectionName("test")
.collectionName(COLLECTION_NAME)
.consistencyLevel(STRONG)
.dimension(384)
.retrieveEmbeddingsOnSearch(retrieveEmbeddingsOnSearch)
.build();

View File

@ -62,8 +62,7 @@ class MilvusEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
.host(milvus.getHost())
.port(milvus.getMappedPort(19530))
.collectionName(COLLECTION_NAME)
.username("")
.password("")
.consistencyLevel(STRONG)
.dimension(384)
.retrieveEmbeddingsOnSearch(false)
.build();