diff --git a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java index cb982b744..076454dc1 100644 --- a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java +++ b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java @@ -3,14 +3,14 @@ package dev.langchain4j.store.embedding.chroma; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.AbstractEmbeddingStoreIT; import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; import org.junit.jupiter.api.Disabled; import static dev.langchain4j.internal.Utils.randomUUID; @Disabled("needs Chroma running locally") -class ChromaEmbeddingStoreIT extends AbstractEmbeddingStoreIT { +class ChromaEmbeddingStoreIT extends EmbeddingStoreIT { /** * First ensure you have Chroma running locally. If not, then: @@ -19,12 +19,12 @@ class ChromaEmbeddingStoreIT extends AbstractEmbeddingStoreIT { * - Wait until Chroma is ready to serve (may take a few minutes) */ - private final EmbeddingStore embeddingStore = ChromaEmbeddingStore.builder() + EmbeddingStore embeddingStore = ChromaEmbeddingStore.builder() .baseUrl("http://localhost:8000") .collectionName(randomUUID()) .build(); - private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @Override protected EmbeddingStore embeddingStore() { diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java new file mode 100644 index 000000000..e40913bcf --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java @@ -0,0 +1,38 @@ +package dev.langchain4j.store.embedding; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; + +/** + * A minimum set of tests that each implementation of {@link EmbeddingStore} must pass. + */ +public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { + + @Test + void should_add_embedding_with_segment_with_metadata() { + + TextSegment segment = TextSegment.from("hello", Metadata.from("test-key", "test-value")); + Embedding embedding = embeddingModel().embed(segment.text()).content(); + + String id = embeddingStore().add(embedding, segment); + assertThat(id).isNotBlank(); + + awaitUntilPersisted(); + + List> relevant = embeddingStore().findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/AbstractEmbeddingStoreIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java similarity index 89% rename from langchain4j-core/src/test/java/dev/langchain4j/store/embedding/AbstractEmbeddingStoreIT.java rename to langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java index 853aaacd6..f18da044b 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/AbstractEmbeddingStoreIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java @@ -1,6 +1,5 @@ package dev.langchain4j.store.embedding; -import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -14,10 +13,7 @@ import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; -/** - * A minimum set of tests that each implementation of {@link EmbeddingStore} must pass. - */ -public abstract class AbstractEmbeddingStoreIT { +public abstract class EmbeddingStoreWithoutMetadataIT { protected abstract EmbeddingStore embeddingStore(); @@ -30,7 +26,7 @@ public abstract class AbstractEmbeddingStoreIT { protected void ensureStoreIsEmpty() { Embedding embedding = embeddingModel().embed("hello").content(); - assertThat(embeddingStore().findRelevant(embedding, Integer.MAX_VALUE)).isEmpty(); + assertThat(embeddingStore().findRelevant(embedding, 1000)).isEmpty(); } @Test @@ -94,27 +90,6 @@ public abstract class AbstractEmbeddingStoreIT { assertThat(match.embedded()).isEqualTo(segment); } - @Test - void should_add_embedding_with_segment_with_metadata() { - - TextSegment segment = TextSegment.from("hello", Metadata.from("test-key", "test-value")); - Embedding embedding = embeddingModel().embed(segment.text()).content(); - - String id = embeddingStore().add(embedding, segment); - assertThat(id).isNotBlank(); - - awaitUntilPersisted(); - - List> relevant = embeddingStore().findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isEqualTo(segment); - } - @Test void should_add_multiple_embeddings() { diff --git a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java index 1247f6c9e..d9b438e51 100644 --- a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java +++ b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java @@ -3,15 +3,15 @@ package dev.langchain4j.store.embedding.elasticsearch; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.AbstractEmbeddingStoreIT; import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; import lombok.SneakyThrows; import org.junit.jupiter.api.Disabled; import static dev.langchain4j.internal.Utils.randomUUID; @Disabled("needs Elasticsearch to be running locally") -class ElasticsearchEmbeddingStoreIT extends AbstractEmbeddingStoreIT { +class ElasticsearchEmbeddingStoreIT extends EmbeddingStoreIT { /** * First start elasticsearch locally: @@ -19,13 +19,13 @@ class ElasticsearchEmbeddingStoreIT extends AbstractEmbeddingStoreIT { * docker run -d -p 9200:9200 -p 9300:9300 -e discovery.type=single-node -e xpack.security.enabled=false docker.elastic.co/elasticsearch/elasticsearch:8.9.0 */ - private final EmbeddingStore embeddingStore = ElasticsearchEmbeddingStore.builder() + EmbeddingStore embeddingStore = ElasticsearchEmbeddingStore.builder() .serverUrl("http://localhost:9200") .indexName(randomUUID()) .dimension(384) .build(); - private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @Override protected EmbeddingStore embeddingStore() { diff --git a/langchain4j-milvus/pom.xml b/langchain4j-milvus/pom.xml index 6539d257c..20e9daab0 100644 --- a/langchain4j-milvus/pom.xml +++ b/langchain4j-milvus/pom.xml @@ -24,18 +24,13 @@ io.milvus milvus-sdk-java - 2.3.1 + 2.3.3 io.netty netty-codec - - - org.apache.commons - commons-text - @@ -43,10 +38,13 @@ netty-codec ${netty.version} + - org.apache.commons - commons-text - 1.10.0 + dev.langchain4j + langchain4j-core + tests + test-jar + test diff --git a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/Mapper.java b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/Mapper.java index 74e8c066e..d8938be53 100644 --- a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/Mapper.java +++ b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/Mapper.java @@ -6,6 +6,7 @@ import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.RelevanceScore; import io.milvus.client.MilvusServiceClient; import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.exception.ParamException; import io.milvus.response.QueryResultsWrapper; import io.milvus.response.SearchResultsWrapper; @@ -47,10 +48,15 @@ class Mapper { boolean queryForVectorOnSearch) { List> matches = new ArrayList<>(); - List rowIds = (List) resultsWrapper.getFieldWrapper(ID_FIELD_NAME).getFieldData(); Map idToEmbedding = new HashMap<>(); if (queryForVectorOnSearch) { - idToEmbedding.putAll(queryEmbeddings(milvusClient, collectionName, rowIds, consistencyLevel)); + try { + List rowIds = (List) resultsWrapper.getFieldWrapper(ID_FIELD_NAME).getFieldData(); + idToEmbedding.putAll(queryEmbeddings(milvusClient, collectionName, rowIds, consistencyLevel)); + } catch (ParamException e) { + // There is no way to check if the result is empty or not. + // If the result is empty, the exception will be thrown. + } } for (int i = 0; i < resultsWrapper.getRowRecords().size(); i++) { diff --git a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java index 50944777f..f268ecb52 100644 --- a/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java +++ b/langchain4j-milvus/src/main/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStore.java @@ -131,7 +131,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore { retrieveEmbeddingsOnSearch ); - return matches.stream().filter(match -> match.score() >= minScore).collect(toList()); + return matches.stream() + .filter(match -> match.score() >= minScore) + .collect(toList()); } private void addInternal(String id, Embedding embedding, TextSegment textSegment) { diff --git a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java index fac3db8d3..447f5df12 100644 --- a/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java +++ b/langchain4j-milvus/src/test/java/dev/langchain4j/store/embedding/milvus/MilvusEmbeddingStoreIT.java @@ -4,12 +4,12 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.RelevanceScore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import java.util.List; @@ -20,7 +20,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; @Disabled("needs Milvus running locally") -class MilvusEmbeddingStoreIT { +class MilvusEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { /** * First run Milvus locally: @@ -33,221 +33,44 @@ class MilvusEmbeddingStoreIT { .port(19530) .collectionName("collection_" + randomUUID().replace("-", "")) .dimension(384) + .retrieveEmbeddingsOnSearch(true) .build(); EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - @Test - void should_add_embedding() { + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isNull(); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; } @Test - void should_add_embedding_with_id() { - - String id = randomUUID(); - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - - embeddingStore.add(id, embedding); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isNull(); - assertThat(match.embedded()).isNull(); - } - - @Test - void should_add_embedding_with_segment() { - - TextSegment segment = TextSegment.from(randomUUID()); - Embedding embedding = embeddingModel.embed(segment.text()).content(); - - String id = embeddingStore.add(embedding, segment); - assertThat(id).isNotNull(); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isNull(); - assertThat(match.embedded()).isEqualTo(segment); - } - - @Test - void should_add_multiple_embeddings() { - - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - - List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - assertThat(ids).hasSize(2); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isNull(); - assertThat(firstMatch.embedded()).isNull(); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isNull(); - assertThat(secondMatch.embedded()).isNull(); - } - - @Test - void should_add_multiple_embeddings_with_segments() { - - TextSegment firstSegment = TextSegment.from(randomUUID()); - Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); - TextSegment secondSegment = TextSegment.from(randomUUID()); - Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); - - List ids = embeddingStore.addAll( - asList(firstEmbedding, secondEmbedding), - asList(firstSegment, secondSegment) - ); - assertThat(ids).hasSize(2); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isNull(); - assertThat(firstMatch.embedded()).isEqualTo(firstSegment); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isNull(); - assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } - - @Test - void should_find_with_min_score() { - - String firstId = randomUUID(); - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(firstId, firstEmbedding); - - String secondId = randomUUID(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(secondId, secondEmbedding); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(firstId); - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(secondId); - - List> relevant2 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - 0.01 - ); - assertThat(relevant2).hasSize(2); - assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant3 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - ); - assertThat(relevant3).hasSize(2); - assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant4 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() + 0.01 - ); - assertThat(relevant4).hasSize(1); - assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } - - @Test - void should_return_correct_score() { - - Embedding embedding = embeddingModel.embed("hello").content(); - - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - Embedding referenceEmbedding = embeddingModel.embed("hi").content(); - - List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), - withPercentage(1) - ); - } - - @Test - void should_retrieve_embeddings_when_searching() { + void should_not_retrieve_embeddings_when_searching() { EmbeddingStore embeddingStore = MilvusEmbeddingStore.builder() .host("localhost") .port(19530) .collectionName("collection_" + randomUUID().replace("-", "")) .dimension(384) - .retrieveEmbeddingsOnSearch(true) + .retrieveEmbeddingsOnSearch(false) .build(); - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - - List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - assertThat(ids).hasSize(2); + Embedding firstEmbedding = embeddingModel.embed("hello").content(); + Embedding secondEmbedding = embeddingModel.embed("hi").content(); + embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isNull(); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isNull(); + assertThat(relevant.get(0).embedding()).isNull(); + assertThat(relevant.get(1).embedding()).isNull(); } @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_API_KEY", matches = ".+") void should_use_cloud_instance() { EmbeddingStore embeddingStore = MilvusEmbeddingStore.builder()