diff --git a/docs/docs/integrations/embedding-stores/index.md b/docs/docs/integrations/embedding-stores/index.md index 36e6c1996..8b58ee1b7 100644 --- a/docs/docs/integrations/embedding-stores/index.md +++ b/docs/docs/integrations/embedding-stores/index.md @@ -4,25 +4,25 @@ hide_title: false sidebar_position: 0 --- -| Embedding Store | Storing Metadata | Filtering by Metadata | Removing Embeddings | -|---------------------------------------------------------------------------------------|------------------|-----------------------|---------------------| -| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | ✅ | -| [Astra DB](/integrations/embedding-stores/astra-db) | ✅ | | | -| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | ✅ | ✅ | -| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | | -| [Azure CosmosDB NoSQL](/integrations/embedding-stores/azure-cosmos-nosql) | ✅ | | | -| [Cassandra](/integrations/embedding-stores/cassandra) | ✅ | | | -| [Chroma](/integrations/embedding-stores/chroma) | ✅ | ✅ | ✅ | -| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | ✅ | -| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | | -| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | ✅ | -| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | | | -| [Neo4j](/integrations/embedding-stores/neo4j) | ✅ | | | -| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | | -| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ | -| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ | -| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | | -| [Redis](/integrations/embedding-stores/redis) | ✅ | | | -| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | | -| [Vespa](/integrations/embedding-stores/vespa) | | | | -| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ | +| Embedding Store | Storing Metadata | Filtering by Metadata | Removing Embeddings | +|---------------------------------------------------------------------------------------|------------------|----------------------------|---------------------| +| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | ✅ | +| [Astra DB](/integrations/embedding-stores/astra-db) | ✅ | | | +| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | ✅ | ✅ | +| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | | +| [Azure CosmosDB NoSQL](/integrations/embedding-stores/azure-cosmos-nosql) | ✅ | | | +| [Cassandra](/integrations/embedding-stores/cassandra) | ✅ | | | +| [Chroma](/integrations/embedding-stores/chroma) | ✅ | ✅ | ✅ | +| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | ✅ | +| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | | +| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | ✅ | +| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | Only native filter support | | +| [Neo4j](/integrations/embedding-stores/neo4j) | ✅ | | | +| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | | +| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ | +| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ | +| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | | +| [Redis](/integrations/embedding-stores/redis) | ✅ | | | +| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | | +| [Vespa](/integrations/embedding-stores/vespa) | | | | +| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ | diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java index 51d6b88d7..8a623bb8a 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverIT.java @@ -6,12 +6,13 @@ import com.azure.search.documents.indexes.SearchIndexClientBuilder; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.query.Query; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; +import org.awaitility.core.ThrowingRunnable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -31,40 +32,37 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverIT.class); - private EmbeddingModel embeddingModel; + private final EmbeddingModel embeddingModel; - private AzureAiSearchContentRetriever contentRetrieverWithVector; + private final AzureAiSearchContentRetriever contentRetrieverWithVector; private AzureAiSearchContentRetriever contentRetrieverWithFullText; - private AzureAiSearchContentRetriever contentRetrieverWithHybrid; + private final AzureAiSearchContentRetriever contentRetrieverWithHybrid; - private AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking; + private final AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking; - private int dimensions; + public AzureAiSearchContentRetrieverIT() { + embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - public AzureAiSearchContentRetrieverIT() { - embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - dimensions = embeddingModel.embed("test").content().vector().length; + SearchIndexClient searchIndexClient = new SearchIndexClientBuilder() + .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) + .credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY"))) + .buildClient(); - SearchIndexClient searchIndexClient = new SearchIndexClientBuilder() - .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) - .credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY"))) - .buildClient(); + searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME); - searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME); - - contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR); - contentRetrieverWithFullText = createFullTextSearchContentRetriever(); - contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID); - contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING); + contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR); + contentRetrieverWithFullText = createFullTextSearchContentRetriever(); + contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID); + contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING); } private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) { return AzureAiSearchContentRetriever.builder() .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) .apiKey(System.getenv("AZURE_SEARCH_KEY")) - .dimensions(dimensions) + .dimensions(embeddingModel.dimension()) .embeddingModel(embeddingModel) .queryType(azureAiSearchQueryType) .maxResults(3) @@ -105,7 +103,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering contentRetrieverWithVector.add(embedding, textSegment); } - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size())); String content = "fruit"; Query query = Query.from(content); @@ -140,7 +138,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering contentRetrieverWithVector.add(embedding, textSegment); } - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size())); String content = "house"; Query query = Query.from(content); @@ -189,7 +187,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering contentRetrieverWithFullText.add(content); } - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(contentRetrieverWithFullText.retrieve(Query.from("a"))).hasSize(contents.size())); Query query = Query.from("Alain"); List relevant = contentRetrieverWithFullText.retrieve(query); @@ -206,9 +204,9 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering @Test void testFullTextSearchWithSpecificSearchIndex() { - // This doesn't reuse the existing search index, but creates a specialized one only for full text search + // This doesn't reuse the existing search index, but creates a specialized one only for full text search contentRetrieverWithVector.deleteIndex(); - contentRetrieverWithFullText = AzureAiSearchContentRetriever.builder() + contentRetrieverWithFullText = AzureAiSearchContentRetriever.builder() .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) .apiKey(System.getenv("AZURE_SEARCH_KEY")) .embeddingModel(null) @@ -219,7 +217,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering .build(); testFullTextSearch(); clearStore(); - contentRetrieverWithFullText = createFullTextSearchContentRetriever(); + contentRetrieverWithFullText = createFullTextSearchContentRetriever(); } @Test @@ -247,7 +245,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering contentRetrieverWithHybrid.add(embedding, textSegment); } - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size())); Query query = Query.from("Algeria"); List relevant = contentRetrieverWithHybrid.retrieve(query); @@ -287,7 +285,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering contentRetrieverWithHybridAndReranking.add(embedding, textSegment); } - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size())); Query query = Query.from("A philosopher who was in the French Resistance"); List relevant = contentRetrieverWithHybridAndReranking.retrieve(query); @@ -318,18 +316,24 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering AzureAiSearchContentRetriever azureAiSearchContentRetriever = contentRetrieverWithVector; try { azureAiSearchContentRetriever.deleteIndex(); - azureAiSearchContentRetriever.createOrUpdateIndex(dimensions); + azureAiSearchContentRetriever.createOrUpdateIndex(embeddingModel.dimension()); } catch (RuntimeException e) { log.error("Failed to clean up the index. You should look at deleting it manually.", e); } } @Override - protected void awaitUntilPersisted() { + protected void awaitUntilAsserted(ThrowingRunnable assertion) { + super.awaitUntilAsserted(assertion); try { - Thread.sleep(1_000); + Thread.sleep(1000); // TODO figure out why this is needed and remove this hack } catch (InterruptedException e) { throw new RuntimeException(e); } } + + @Override + protected boolean assertEmbedding() { + return false; // TODO remove this hack after https://github.com/langchain4j/langchain4j/issues/1617 is closed + } } diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverRemovalIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverRemovalIT.java index c85745e1e..319a6f67e 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverRemovalIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/rag/content/retriever/azure/search/AzureAiSearchContentRetrieverRemovalIT.java @@ -1,58 +1,34 @@ package dev.langchain4j.rag.content.retriever.azure.search; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.search.documents.indexes.SearchIndexClient; -import com.azure.search.documents.indexes.SearchIndexClientBuilder; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; -import org.junit.jupiter.api.BeforeEach; +import org.awaitility.core.ThrowingRunnable; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME; +import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchQueryType.HYBRID; @EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+") public class AzureAiSearchContentRetrieverRemovalIT extends EmbeddingStoreWithRemovalIT { private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverRemovalIT.class); - private EmbeddingModel embeddingModel; + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - private AzureAiSearchContentRetriever contentRetrieverWithVector; - - public AzureAiSearchContentRetrieverRemovalIT() { - embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - - SearchIndexClient searchIndexClient = new SearchIndexClientBuilder() - .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) - .credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY"))) - .buildClient(); - - searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME); - - contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR); - } - - private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) { - return AzureAiSearchContentRetriever.builder() - .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) - .apiKey(System.getenv("AZURE_SEARCH_KEY")) - .dimensions(embeddingModel.dimension()) - .embeddingModel(embeddingModel) - .queryType(azureAiSearchQueryType) - .maxResults(3) - .minScore(0.0) - .build(); - } - - @BeforeEach - void setUp() { - clearStore(); - } + private final AzureAiSearchContentRetriever contentRetrieverWithVector = AzureAiSearchContentRetriever.builder() + .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) + .apiKey(System.getenv("AZURE_SEARCH_KEY")) + .indexName(randomUUID()) + .dimensions(embeddingModel.dimension()) + .embeddingModel(embeddingModel) + .queryType(HYBRID) + .build(); @Override protected EmbeddingStore embeddingStore() { @@ -64,14 +40,22 @@ public class AzureAiSearchContentRetrieverRemovalIT extends EmbeddingStoreWithRe return this.embeddingModel; } - protected void clearStore() { - log.debug("Deleting the search index"); - AzureAiSearchContentRetriever azureAiSearchContentRetriever = contentRetrieverWithVector; + @AfterEach + void afterEach() { try { - azureAiSearchContentRetriever.deleteIndex(); - azureAiSearchContentRetriever.createOrUpdateIndex(embeddingModel.dimension()); + contentRetrieverWithVector.deleteIndex(); } catch (RuntimeException e) { - log.error("Failed to clean up the index. You should look at deleting it manually.", e); + log.error("Failed to delete the index. You should look at deleting it manually.", e); + } + } + + @Override + protected void awaitUntilAsserted(ThrowingRunnable assertion) { + super.awaitUntilAsserted(assertion); + try { + Thread.sleep(1000); // TODO figure out why this is needed and remove this hack + } catch (InterruptedException e) { + throw new RuntimeException(e); } } } diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java index 7172ca4d7..06ae55a79 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreIT.java @@ -6,14 +6,13 @@ import com.azure.search.documents.indexes.SearchIndexClientBuilder; import com.azure.search.documents.indexes.models.SearchField; import com.azure.search.documents.indexes.models.SearchFieldDataType; import com.azure.search.documents.indexes.models.SearchIndex; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; -import org.junit.jupiter.api.BeforeEach; +import org.awaitility.core.ThrowingRunnable; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -22,10 +21,9 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; +import static dev.langchain4j.internal.Utils.randomUUID; import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_FIELD_ID; import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME; -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; @@ -34,31 +32,50 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStoreIT.class); - private EmbeddingModel embeddingModel; + private static final String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT"); + private static final String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY"); - private EmbeddingStore embeddingStore; + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - private int dimensions; + private final AzureAiSearchEmbeddingStore embeddingStore = AzureAiSearchEmbeddingStore.builder() + .endpoint(AZURE_SEARCH_ENDPOINT) + .apiKey(AZURE_SEARCH_KEY) + .indexName(randomUUID()) + .dimensions(embeddingModel.dimension()) + .build(); - private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT"); - - private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY"); - - public AzureAiSearchEmbeddingStoreIT() { - - embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - dimensions = embeddingModel.embed("test").content().vector().length; - - embeddingStore = AzureAiSearchEmbeddingStore.builder() - .endpoint(AZURE_SEARCH_ENDPOINT) - .apiKey(AZURE_SEARCH_KEY) - .dimensions(dimensions) - .build(); + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; } - @BeforeEach - void setUp() { - clearStore(); + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @AfterEach + void afterEach() { + try { + embeddingStore.deleteIndex(); + } catch (RuntimeException e) { + log.error("Failed to delete the index. You should look at deleting it manually.", e); + } + } + + @Override + protected void awaitUntilAsserted(ThrowingRunnable assertion) { + super.awaitUntilAsserted(assertion); + try { + Thread.sleep(1000); // TODO figure out why this is needed and remove this hack + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + protected boolean assertEmbedding() { + return false; // TODO remove this hack after https://github.com/langchain4j/langchain4j/issues/1617 is closed } @Test @@ -89,7 +106,7 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT try { new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT, - new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME", null); + new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME", null); fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { @@ -102,71 +119,16 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT @Test public void when_an_index_is_not_provided_the_default_name_is_used() { - AzureAiSearchEmbeddingStore store =new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT, - new AzureKeyCredential(AZURE_SEARCH_KEY), false, null, null, null); + + AzureAiSearchEmbeddingStore store = new AzureAiSearchEmbeddingStore( + AZURE_SEARCH_ENDPOINT, + new AzureKeyCredential(AZURE_SEARCH_KEY), + false, + null, + null, + null + ); assertEquals(DEFAULT_INDEX_NAME, store.searchClient.getIndexName()); } - - @Test - void test_add_embeddings_and_find_relevant() { - String content1 = "banana"; - String content2 = "computer"; - String content3 = "apple"; - String content4 = "pizza"; - String content5 = "strawberry"; - String content6 = "chess"; - List contents = asList(content1, content2, content3, content4, content5, content6); - - for (String content : contents) { - TextSegment textSegment = TextSegment.from(content); - Embedding embedding = embeddingModel.embed(content).content(); - embeddingStore.add(embedding, textSegment); - } - - awaitUntilPersisted(); - - Embedding relevantEmbedding = embeddingModel.embed("fruit").content(); - List> relevant = embeddingStore.findRelevant(relevantEmbedding, 3); - assertThat(relevant).hasSize(3); - assertThat(relevant.get(0).embedding()).isNotNull(); - assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5); - log.info("#1 relevant item: {}", relevant.get(0).embedded().text()); - assertThat(relevant.get(1).embedding()).isNotNull(); - assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5); - log.info("#2 relevant item: {}", relevant.get(1).embedded().text()); - assertThat(relevant.get(2).embedding()).isNotNull(); - assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5); - log.info("#3 relevant item: {}", relevant.get(2).embedded().text()); - } - - @Override - protected EmbeddingStore embeddingStore() { - return embeddingStore; - } - - @Override - protected EmbeddingModel embeddingModel() { - return embeddingModel; - } - - @Override - protected void clearStore() { - AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore = (AzureAiSearchEmbeddingStore) embeddingStore; - try { - azureAiSearchEmbeddingStore.deleteIndex(); - azureAiSearchEmbeddingStore.createOrUpdateIndex(dimensions); - } catch (RuntimeException e) { - log.error("Failed to clean up the index. You should look at deleting it manually.", e); - } - } - - @Override - protected void awaitUntilPersisted() { - try { - Thread.sleep(1_000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } } diff --git a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreRemovalIT.java b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreRemovalIT.java index bb8355e79..5e8aa9b47 100644 --- a/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreRemovalIT.java +++ b/langchain4j-azure-ai-search/src/test/java/dev/langchain4j/store/embedding/azure/search/AzureAiSearchEmbeddingStoreRemovalIT.java @@ -5,49 +5,27 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; -import org.junit.jupiter.api.BeforeEach; +import org.awaitility.core.ThrowingRunnable; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static dev.langchain4j.internal.Utils.randomUUID; + @EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+") public class AzureAiSearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT { private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStoreRemovalIT.class); - private EmbeddingModel embeddingModel; + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - private EmbeddingStore embeddingStore; - - private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT"); - - private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY"); - - public AzureAiSearchEmbeddingStoreRemovalIT() { - - embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - - embeddingStore = AzureAiSearchEmbeddingStore.builder() - .endpoint(AZURE_SEARCH_ENDPOINT) - .apiKey(AZURE_SEARCH_KEY) - .dimensions(embeddingModel.dimension()) - .build(); - } - - @BeforeEach - void setUp() { - clearStore(); - } - - private void clearStore() { - AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore = (AzureAiSearchEmbeddingStore) embeddingStore; - try { - azureAiSearchEmbeddingStore.deleteIndex(); - azureAiSearchEmbeddingStore.createOrUpdateIndex(embeddingModel.dimension()); - } catch (RuntimeException e) { - log.error("Failed to clean up the index. You should look at deleting it manually.", e); - } - } + private final AzureAiSearchEmbeddingStore embeddingStore = AzureAiSearchEmbeddingStore.builder() + .endpoint(System.getenv("AZURE_SEARCH_ENDPOINT")) + .apiKey(System.getenv("AZURE_SEARCH_KEY")) + .indexName(randomUUID()) + .dimensions(embeddingModel.dimension()) + .build(); @Override protected EmbeddingStore embeddingStore() { @@ -58,4 +36,23 @@ public class AzureAiSearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemo protected EmbeddingModel embeddingModel() { return embeddingModel; } + + @AfterEach + void afterEach() { + try { + embeddingStore.deleteIndex(); + } catch (RuntimeException e) { + log.error("Failed to delete the index. You should look at deleting it manually.", e); + } + } + + @Override + protected void awaitUntilAsserted(ThrowingRunnable assertion) { + super.awaitUntilAsserted(assertion); + try { + Thread.sleep(1000); // TODO figure out why this is needed and remove this hack + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } } diff --git a/langchain4j-azure-cosmos-mongo-vcore/pom.xml b/langchain4j-azure-cosmos-mongo-vcore/pom.xml index 33eb1ab91..28880c6e8 100644 --- a/langchain4j-azure-cosmos-mongo-vcore/pom.xml +++ b/langchain4j-azure-cosmos-mongo-vcore/pom.xml @@ -85,6 +85,13 @@ slf4j-tinylog test + + + org.awaitility + awaitility + test + + diff --git a/langchain4j-azure-cosmos-mongo-vcore/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/mongo/vcore/AzureCosmosDBMongoVCoreEmbeddingStoreIT.java b/langchain4j-azure-cosmos-mongo-vcore/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/mongo/vcore/AzureCosmosDBMongoVCoreEmbeddingStoreIT.java index a33c69bcd..a678c3ead 100644 --- a/langchain4j-azure-cosmos-mongo-vcore/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/mongo/vcore/AzureCosmosDBMongoVCoreEmbeddingStoreIT.java +++ b/langchain4j-azure-cosmos-mongo-vcore/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/mongo/vcore/AzureCosmosDBMongoVCoreEmbeddingStoreIT.java @@ -6,43 +6,30 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Filters; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; import org.bson.BsonDocument; import org.bson.codecs.configuration.CodecRegistry; import org.bson.codecs.pojo.PojoCodecProvider; import org.bson.conversions.Bson; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.util.List; - -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; import static org.bson.codecs.configuration.CodecRegistries.fromProviders; import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; -import static org.junit.jupiter.api.Assertions.assertThrows; @EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_ENDPOINT", matches = ".+") public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT { - private static final Logger log = LoggerFactory.getLogger(AzureCosmosDBMongoVCoreEmbeddingStoreIT.class); private static MongoClient client; private final EmbeddingModel embeddingModel; private final EmbeddingStore embeddingStore; - private final int dimensions; public AzureCosmosDBMongoVCoreEmbeddingStoreIT() { embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - dimensions = embeddingModel.embed("hello").content().vector().length; client = MongoClients.create( MongoClientSettings.builder() @@ -59,45 +46,13 @@ public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT { .createIndex(true) .kind("vector-hnsw") .numLists(2) - .dimensions(dimensions) + .dimensions(embeddingModel.dimension()) .m(16) .efConstruction(64) .efSearch(40) .build(); } - @Test - void testAddEmbeddingsAndFindRelevant() { - String content1 = "banana"; - String content2 = "computer"; - String content3 = "apple"; - String content4 = "pizza"; - String content5 = "strawberry"; - String content6 = "chess"; - List contents = asList(content1, content2, content3, content4, content5, content6); - - for (String content : contents) { - TextSegment textSegment = TextSegment.from(content); - Embedding embedding = embeddingModel.embed(content).content(); - embeddingStore.add(embedding, textSegment); - } - - awaitUntilPersisted(); - - Embedding relevantEmbedding = embeddingModel.embed("fruit").content(); - List> relevant = embeddingStore.findRelevant(relevantEmbedding, 3); - assertThat(relevant).hasSize(3); - assertThat(relevant.get(0).embedding()).isNotNull(); - assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5); - log.info("#1 relevant item: {}", relevant.get(0).embedded().text()); - assertThat(relevant.get(1).embedding()).isNotNull(); - assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5); - log.info("#2 relevant item: {}", relevant.get(1).embedded().text()); - assertThat(relevant.get(2).embedding()).isNotNull(); - assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5); - log.info("#3 relevant item: {}", relevant.get(2).embedded().text()); - } - @Override protected EmbeddingStore embeddingStore() { return embeddingStore; @@ -108,15 +63,6 @@ public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT { return embeddingModel; } - @Override - protected void awaitUntilPersisted() { - try { - Thread.sleep(1_000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - @Override protected void clearStore() { CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder() diff --git a/langchain4j-azure-cosmos-nosql/pom.xml b/langchain4j-azure-cosmos-nosql/pom.xml index 4ab348b86..c252a9484 100644 --- a/langchain4j-azure-cosmos-nosql/pom.xml +++ b/langchain4j-azure-cosmos-nosql/pom.xml @@ -70,6 +70,12 @@ test + + org.awaitility + awaitility + test + + \ No newline at end of file diff --git a/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java b/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java index 2f5b0cd1d..88c2cf7d5 100644 --- a/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java +++ b/langchain4j-azure-cosmos-nosql/src/test/java/dev/langchain4j/store/embedding/azure/cosmos/nosql/AzureCosmosDbNoSqlEmbeddingStoreIT.java @@ -3,50 +3,45 @@ package dev.langchain4j.store.embedding.azure.cosmos.nosql; import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.CosmosClient; import com.azure.cosmos.CosmosClientBuilder; -import com.azure.cosmos.CosmosDatabase; -import com.azure.cosmos.models.*; -import dev.langchain4j.data.embedding.Embedding; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosVectorDataType; +import com.azure.cosmos.models.CosmosVectorDistanceFunction; +import com.azure.cosmos.models.CosmosVectorEmbedding; +import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy; +import com.azure.cosmos.models.CosmosVectorIndexSpec; +import com.azure.cosmos.models.CosmosVectorIndexType; +import com.azure.cosmos.models.IncludedPath; +import com.azure.cosmos.models.IndexingMode; +import com.azure.cosmos.models.IndexingPolicy; +import com.azure.cosmos.models.PartitionKeyDefinition; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; - @EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_HOST", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_MASTER_KEY", matches = ".+") class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT { - protected static Logger logger = LoggerFactory.getLogger(AzureCosmosDbNoSqlEmbeddingStoreIT.class); - private static final String DATABASE_NAME = "test_database_langchain_java"; private static final String CONTAINER_NAME = "test_container"; - private CosmosClient client; - CosmosDatabase database; - private final EmbeddingModel embeddingModel; + private final EmbeddingStore embeddingStore; - private final int dimensions; - private final String HOST = System.getenv("AZURE_COSMOS_HOST"); - private final String KEY = System.getenv("AZURE_COSMOS_MASTER_KEY"); + private final EmbeddingModel embeddingModel; public AzureCosmosDbNoSqlEmbeddingStoreIT() { embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - dimensions = embeddingModel.embed("hello").content().vector().length; + int dimensions = embeddingModel.dimension(); - client = new CosmosClientBuilder() - .endpoint(HOST) - .key(KEY) + CosmosClient client = new CosmosClientBuilder() + .endpoint(System.getenv("AZURE_COSMOS_HOST")) + .key(System.getenv("AZURE_COSMOS_MASTER_KEY")) .consistencyLevel(ConsistencyLevel.EVENTUAL) .contentResponseOnWriteEnabled(true) .buildClient(); @@ -59,43 +54,6 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT { .cosmosVectorIndexes(populateVectorIndexSpec()) .containerProperties(populateContainerProperties()) .build(); - database = client.getDatabase(DATABASE_NAME); - } - - @Test - public void testAddEmbeddingsAndFindRelevant() { - String content1 = "banana"; - String content2 = "computer"; - String content3 = "apple"; - String content4 = "pizza"; - String content5 = "strawberry"; - String content6 = "chess"; - - List contents = asList(content1, content2, content3, content4, content5, content6); - - for (String content : contents) { - TextSegment textSegment = TextSegment.from(content); - Embedding embedding = embeddingModel.embed(content).content(); - embeddingStore.add(embedding, textSegment); - } - - awaitUntilPersisted(); - - Embedding relevantEmbedding = embeddingModel.embed("fruit").content(); - List> relevant = embeddingStore.findRelevant(relevantEmbedding, 3); - assertThat(relevant).hasSize(3); - assertThat(relevant.get(0).embedding()).isNotNull(); - assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5); - logger.info("#1 relevant item: {}", relevant.get(0).embedded().text()); - assertThat(relevant.get(1).embedding()).isNotNull(); - assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5); - logger.info("#2 relevant item: {}", relevant.get(1).embedded().text()); - assertThat(relevant.get(2).embedding()).isNotNull(); - assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5); - logger.info("#3 relevant item: {}", relevant.get(2).embedded().text()); - - safeDeleteDatabase(database); - safeClose(client); } @Override @@ -108,38 +66,6 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT { return embeddingModel; } - @Override - protected void awaitUntilPersisted() { - try { - Thread.sleep(1_000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - - @Override - protected void clearStore() { - } - - private void safeDeleteDatabase(CosmosDatabase database) { - if (database != null) { - try { - database.delete(); - } catch (Exception e) { - } - } - } - - private void safeClose(CosmosClient client) { - if (client != null) { - try { - client.close(); - } catch (Exception e) { - logger.error("failed to close client", e); - } - } - } - private CosmosVectorEmbeddingPolicy populateVectorEmbeddingPolicy(int dimensions) { CosmosVectorEmbeddingPolicy vectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); CosmosVectorEmbedding embedding = new CosmosVectorEmbedding(); @@ -174,6 +100,4 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT { collectionDefinition.setIndexingPolicy(indexingPolicy); return collectionDefinition; } - - } diff --git a/langchain4j-cassandra/pom.xml b/langchain4j-cassandra/pom.xml index e1b6e4d23..82f86dab0 100644 --- a/langchain4j-cassandra/pom.xml +++ b/langchain4j-cassandra/pom.xml @@ -112,6 +112,12 @@ test + + org.awaitility + awaitility + test + + diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/astradb/AstraDbEmbeddingStoreIT.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/astradb/AstraDbEmbeddingStoreIT.java index b5e59d9e0..6dbd685ce 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/astradb/AstraDbEmbeddingStoreIT.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/astradb/AstraDbEmbeddingStoreIT.java @@ -3,30 +3,24 @@ package dev.langchain4j.store.embedding.astradb; import com.dtsx.astra.sdk.AstraDB; import com.dtsx.astra.sdk.AstraDBAdmin; import com.dtsx.astra.sdk.AstraDBCollection; -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.openai.OpenAiEmbeddingModel; import dev.langchain4j.model.openai.OpenAiModelName; -import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; import io.stargate.sdk.data.domain.SimilarityMetric; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.MethodOrderer; import org.junit.jupiter.api.TestMethodOrder; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import java.util.List; import java.util.UUID; import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; import static org.assertj.core.api.Assertions.assertThat; -@Disabled("AstraDB is not available in the CI") @TestMethodOrder(MethodOrderer.OrderAnnotation.class) @EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*") @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*") @@ -81,25 +75,4 @@ class AstraDbEmbeddingStoreIT extends EmbeddingStoreIT { } return embeddingModel; } - - void testAddEmbeddingAndFindRelevant() { - - Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F}); - TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value")); - String id = embeddingStore.add(embedding, textSegment); - assertThat(id != null && !id.isEmpty()).isTrue(); - - Embedding referenceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F}); - List> embeddingMatches = embeddingStore.findRelevant(referenceEmbedding, 1); - assertThat(embeddingMatches).hasSize(1); - - EmbeddingMatch embeddingMatch = embeddingMatches.get(0); - assertThat(embeddingMatch.score()).isBetween(0d, 1d); - assertThat(embeddingMatch.embeddingId()).isEqualTo(id); - assertThat(embeddingMatch.embedding()).isEqualTo(embedding); - assertThat(embeddingMatch.embedded()).isEqualTo(textSegment); - } - - - } diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreAstraIT.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreAstraIT.java index c14b3d2f6..6d040f2c6 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreAstraIT.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreAstraIT.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import java.util.UUID; +import static com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric.COSINE; import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION; import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken; @@ -36,12 +37,10 @@ class CassandraEmbeddingStoreAstraIT extends CassandraEmbeddingStoreIT { .databaseRegion(TEST_REGION) .keyspace(KEYSPACE) .table(TEST_INDEX) - .dimension(embeddingModelDimension()) // openai model - .metric(CassandraSimilarityMetric.COSINE) + .dimension(embeddingModel().dimension()) + .metric(COSINE) .build(); } return embeddingStore; } - - } diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreDockerIT.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreDockerIT.java index a39107e26..5fd6fdf02 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreDockerIT.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreDockerIT.java @@ -1,12 +1,10 @@ package dev.langchain4j.store.embedding.cassandra; import com.datastax.oss.driver.api.core.CqlSession; -import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.testcontainers.DockerClientFactory; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Testcontainers; @@ -15,10 +13,11 @@ import org.testcontainers.utility.DockerImageName; import java.net.InetSocketAddress; import java.util.Collections; +import static com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric.COSINE; + /** * Work with Cassandra Embedding Store. */ -@Disabled("No Docker in the CI") @Testcontainers class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT { @@ -55,7 +54,7 @@ class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT { * Stop Cassandra Node */ @AfterAll - static void afterTests() throws Exception { + static void afterTests() { cassandraContainer.stop(); } @@ -69,11 +68,10 @@ class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT { .localDataCenter(DATACENTER) .keyspace(KEYSPACE) .table(TEST_INDEX) - .dimension(embeddingModelDimension()) - .metric(CassandraSimilarityMetric.COSINE) + .dimension(embeddingModel().dimension()) + .metric(COSINE) .build(); } return embeddingStore; } - } diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java index d24287639..d79d9d387 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStoreIT.java @@ -4,22 +4,17 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.model.openai.OpenAiEmbeddingModel; -import dev.langchain4j.model.openai.OpenAiModelName; -import dev.langchain4j.store.embedding.CosineSimilarity; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStoreIT; -import dev.langchain4j.store.embedding.RelevanceScore; import lombok.extern.slf4j.Slf4j; +import org.assertj.core.data.Percentage; import org.junit.jupiter.api.MethodOrderer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestMethodOrder; -import java.time.Duration; import java.util.List; -import static dev.langchain4j.internal.Utils.randomUUID; -import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; @@ -28,26 +23,17 @@ import static org.assertj.core.data.Percentage.withPercentage; abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT { protected static final String KEYSPACE = "langchain4j"; - protected static final String TEST_INDEX = "test_embedding_store"; CassandraEmbeddingStore embeddingStore; - EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder() - .apiKey(System.getenv("OPENAI_API_KEY")) - .modelName(OpenAiModelName.TEXT_EMBEDDING_ADA_002) - .timeout(Duration.ofSeconds(15)) - .build(); + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @Override protected EmbeddingModel embeddingModel() { return embeddingModel; } - protected int embeddingModelDimension() { - return 1536; - } - /** * It is required to clean the repository in between tests */ @@ -57,19 +43,16 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT { } @Override - public void awaitUntilPersisted() { - try { - Thread.sleep(1000); - } catch(Exception e) { - } + protected Percentage percentage() { + return withPercentage(6); // TODO figure out why difference is so big } @Test void should_retrieve_inserted_vector_by_ann() { - String sourceSentence = "Testing is doubting !"; - Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content(); + String sourceSentence = "Testing is doubting !"; + Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content(); TextSegment sourceTextSegment = TextSegment.from(sourceSentence); - String id = embeddingStore().add(sourceEmbedding, sourceTextSegment); + String id = embeddingStore().add(sourceEmbedding, sourceTextSegment); assertThat(id != null && !id.isEmpty()).isTrue(); List> embeddingMatches = embeddingStore.findRelevant(sourceEmbedding, 10); @@ -84,12 +67,12 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT { @Test void should_retrieve_inserted_vector_by_ann_and_metadata() { - String sourceSentence = "In GOD we trust, everything else we test!"; - Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content(); + String sourceSentence = "In GOD we trust, everything else we test!"; + Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content(); TextSegment sourceTextSegment = TextSegment.from(sourceSentence, new Metadata() .put("user", "GOD") .put("test", "false")); - String id = embeddingStore().add(sourceEmbedding, sourceTextSegment); + String id = embeddingStore().add(sourceEmbedding, sourceTextSegment); assertThat(id != null && !id.isEmpty()).isTrue(); // Should be found with no filter @@ -106,144 +89,4 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT { .findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "JOHN")); assertThat(matchesJohn).isEmpty(); } - - // metrics returned are 1.95% off we updated to "withPercentage(2)" - - @Test - void should_return_correct_score() { - Embedding embedding = embeddingModel().embed("hello").content(); - String id = embeddingStore().add(embedding); - assertThat(id).isNotBlank(); - Embedding referenceEmbedding = embeddingModel().embed("hi").content(); - List> relevant = embeddingStore().findRelevant(referenceEmbedding, 1); - assertThat(relevant).hasSize(1); - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), - withPercentage(2) - ); - } - - @Test - void should_find_with_min_score() { - String firstId = randomUUID(); - Embedding firstEmbedding = embeddingModel().embed("hello").content(); - embeddingStore().add(firstId, firstEmbedding); - String secondId = randomUUID(); - Embedding secondEmbedding = embeddingModel().embed("hi").content(); - embeddingStore().add(secondId, secondEmbedding); - List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(firstId); - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)), - withPercentage(2) - ); - assertThat(secondMatch.embeddingId()).isEqualTo(secondId); - - List> relevant2 = embeddingStore().findRelevant( - firstEmbedding, - 10, - secondMatch.score() - 0.01 - ); - assertThat(relevant2).hasSize(2); - assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant3 = embeddingStore().findRelevant( - firstEmbedding, - 10, - secondMatch.score() - ); - assertThat(relevant3).hasSize(2); - assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant4 = embeddingStore().findRelevant( - firstEmbedding, - 10, - secondMatch.score() + 0.01 - ); - assertThat(relevant4).hasSize(1); - assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } - - @Test - void should_add_multiple_embeddings_with_segments() { - - TextSegment firstSegment = TextSegment.from("hello"); - Embedding firstEmbedding = embeddingModel().embed(firstSegment.text()).content(); - - TextSegment secondSegment = TextSegment.from("hi"); - Embedding secondEmbedding = embeddingModel().embed(secondSegment.text()).content(); - - List ids = embeddingStore().addAll( - asList(firstEmbedding, secondEmbedding), - asList(firstSegment, secondSegment) - ); - assertThat(ids).hasSize(2); - assertThat(ids.get(0)).isNotBlank(); - assertThat(ids.get(1)).isNotBlank(); - assertThat(ids.get(0)).isNotEqualTo(ids.get(1)); - - awaitUntilPersisted(); - - List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isEqualTo(firstSegment); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)), - withPercentage(2) - ); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } - - - - @Test - void should_add_multiple_embeddings() { - - Embedding firstEmbedding = embeddingModel().embed("hello").content(); - Embedding secondEmbedding = embeddingModel().embed("hi").content(); - - List ids = embeddingStore().addAll(asList(firstEmbedding, secondEmbedding)); - assertThat(ids).hasSize(2); - assertThat(ids.get(0)).isNotBlank(); - assertThat(ids.get(1)).isNotBlank(); - assertThat(ids.get(0)).isNotEqualTo(ids.get(1)); - - awaitUntilPersisted(); - - List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(2)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isNull(); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)), - withPercentage(2) - ); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isNull(); - } - - } diff --git a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStoreDockerIT.java b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStoreDockerIT.java index cd8e3cb95..443892a1e 100644 --- a/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStoreDockerIT.java +++ b/langchain4j-cassandra/src/test/java/dev/langchain4j/store/memory/chat/cassandra/CassandraChatMemoryStoreDockerIT.java @@ -3,7 +3,6 @@ package dev.langchain4j.store.memory.chat.cassandra; import com.datastax.oss.driver.api.core.CqlSession; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.testcontainers.DockerClientFactory; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Testcontainers; @@ -14,7 +13,6 @@ import java.net.InetSocketAddress; /** * Test Cassandra Chat Memory Store with a Saas DB. */ -@Disabled("No Docker in the CI") @Testcontainers class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSupport { static final String DATACENTER = "datacenter1"; @@ -44,8 +42,8 @@ class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSuppo .addContactPoint(contactPoint) .withLocalDatacenter(DATACENTER) .build().execute( - "CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE + - " WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};"); + "CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE + + " WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};"); return new CassandraChatMemoryStore(CqlSession.builder() .addContactPoint(contactPoint) .withLocalDatacenter(DATACENTER) @@ -54,8 +52,7 @@ class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSuppo } @AfterAll - static void afterTests() throws Exception { + static void afterTests() { cassandraContainer.stop(); } - } diff --git a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java index ee116aadb..4d3574e91 100644 --- a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java +++ b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreIT.java @@ -1,10 +1,5 @@ package dev.langchain4j.store.embedding.chroma; -import static dev.langchain4j.internal.Utils.randomUUID; -import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; - import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -12,28 +7,37 @@ import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2Quantize import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; import dev.langchain4j.store.embedding.filter.Filter; -import java.util.List; -import java.util.UUID; -import java.util.stream.Stream; -import org.junit.Ignore; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.logical.Not; import org.junit.jupiter.params.provider.Arguments; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.stream.Collectors.toList; +import static org.assertj.core.api.Assertions.assertThat; + @Testcontainers class ChromaEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { @Container private static final ChromaDBContainer chroma = new ChromaDBContainer("chromadb/chroma:0.5.4"); - EmbeddingStore embeddingStore = ChromaEmbeddingStore - .builder() - .baseUrl(chroma.getEndpoint()) - .collectionName(randomUUID()) - .logRequests(true) - .logResponses(true) - .build(); + EmbeddingStore embeddingStore = ChromaEmbeddingStore.builder() + .baseUrl(chroma.getEndpoint()) + .collectionName(randomUUID()) + .logRequests(true) + .logResponses(true) + .build(); EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @@ -47,182 +51,65 @@ class ChromaEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { return embeddingModel; } - @Override - @Ignore("Chroma cannot filter by greater and less than of alphanumeric metadata, only int and float are supported") - protected void should_filter_by_greater_and_less_than_alphanumeric_metadata( - Filter metadataFilter, - List matchingMetadatas, - List notMatchingMetadatas - ) {} + // in chroma compare filter only works with numbers + protected static Stream should_filter_by_metadata() { + return EmbeddingStoreWithFilteringIT.should_filter_by_metadata() + .filter(arguments -> { + Filter filter = (Filter) arguments.get()[0]; + if (filter instanceof IsLessThan) { + return ((IsLessThan) filter).comparisonValue() instanceof Number; + } else if (filter instanceof IsLessThanOrEqualTo) { + return ((IsLessThanOrEqualTo) filter).comparisonValue() instanceof Number; + } else if (filter instanceof IsGreaterThan) { + return ((IsGreaterThan) filter).comparisonValue() instanceof Number; + } else if (filter instanceof IsGreaterThanOrEqualTo) { + return ((IsGreaterThanOrEqualTo) filter).comparisonValue() instanceof Number; + } else { + return true; + } + } + ); + } // Chroma filters by *not* as following: // If you filter by "key" not equals "a", then in fact all items with "key" != "a" value are returned, but no items // without "key" metadata! // Therefore, all default *not* tests coming from parent class have to be rewritten here. protected static Stream should_filter_by_metadata_not() { - return Stream - .builder() - // === NotEqual === + return EmbeddingStoreWithFilteringIT.should_filter_by_metadata_not() + .map(args -> { + Object[] arguments = args.get(); + Filter filter = (Filter) arguments[0]; - .add( - Arguments.of( - metadataKey("key").isNotEqualTo("a"), - asList( - new Metadata().put("key", "A"), - new Metadata().put("key", "b"), - new Metadata().put("key", "aa"), - new Metadata().put("key", "a a") - ), - asList( - new Metadata().put("key", "a"), - new Metadata().put("key2", "a"), - new Metadata().put("key", "a").put("key2", "b") - ), - false - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(TEST_UUID), - asList(new Metadata().put("key", UUID.randomUUID())), - asList( - new Metadata().put("key", TEST_UUID), - new Metadata().put("key2", TEST_UUID), - new Metadata().put("key", TEST_UUID).put("key2", UUID.randomUUID()) - ), - false - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(1), - asList( - new Metadata().put("key", -1), - new Metadata().put("key", 0), - new Metadata().put("key", 2), - new Metadata().put("key", 10) - ), - asList( - new Metadata().put("key", 1), - new Metadata().put("key2", 1), - new Metadata().put("key", 1).put("key2", 2) - ), - false - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(1.1f), - asList( - new Metadata().put("key", -1.1f), - new Metadata().put("key", 0.0f), - new Metadata().put("key", 1.11f), - new Metadata().put("key", 2.2f) - ), - asList( - new Metadata().put("key", 1.1f), - new Metadata().put("key2", 1.1f), - new Metadata().put("key", 1.1f).put("key2", 2.2f) - ), - false - ) - ) - // === NotIn === + String key = getMetadataKey(filter); - // NotIn: string - .add( - Arguments.of( - metadataKey("name").isNotIn("Klaus"), - asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Alice")), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name2", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42) - ), - false - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn(singletonList("Klaus")), - asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Alice")), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name2", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42) - ), - false - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn("Klaus", "Alice"), - asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Zoe")), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name2", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42) - ), - false - ) - ) - // NotIn: UUID - .add( - Arguments.of( - metadataKey("name").isNotIn(TEST_UUID), - asList(new Metadata().put("name", UUID.randomUUID())), - asList( - new Metadata().put("name", TEST_UUID), - new Metadata().put("name2", TEST_UUID), - new Metadata().put("name", TEST_UUID).put("age", 42) - ), - false - ) - ) - // NotIn: int - .add( - Arguments.of( - metadataKey("age").isNotIn(42), - asList(new Metadata().put("age", 666)), - asList( - new Metadata().put("age", 42), - new Metadata().put("age2", 42), - new Metadata().put("age", 42).put("name", "Klaus") - ), - false - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(42, 18), - asList(new Metadata().put("age", 666)), - asList( - new Metadata().put("age", 42), - new Metadata().put("age", 18), - new Metadata().put("age2", 42), - new Metadata().put("age", 42).put("name", "Klaus"), - new Metadata().put("age", 18).put("name", "Klaus") - ), - false - ) - ) - // NotIn: float - .add( - Arguments.of( - metadataKey("age").isNotIn(asList(42.0f, 18.0f)), - asList(new Metadata().put("age", 666.0f)), - asList( - new Metadata().put("age", 42.0f), - new Metadata().put("age", 18.0f), - new Metadata().put("age2", 42.0f), - new Metadata().put("age", 42.0f).put("name", "Klaus"), - new Metadata().put("age", 18.0f).put("name", "Klaus") - ), - false - ) - ) - .build(); + List matchingMetadatas = (List) arguments[1]; + List newMatchingMetadatas = matchingMetadatas.stream() + .filter(metadata -> metadata.containsKey(key)) + .collect(toList()); + + List notMatchingMetadatas = (List) arguments[2]; + List newNotMatchingMetadatas = new ArrayList<>(notMatchingMetadatas); + newNotMatchingMetadatas.addAll(matchingMetadatas.stream() + .filter(metadata -> !metadata.containsKey(key)) + .collect(toList())); + + assertThat(Stream.concat(newMatchingMetadatas.stream(), newNotMatchingMetadatas.stream())) + .containsExactlyInAnyOrderElementsOf(Stream.concat(matchingMetadatas.stream(), notMatchingMetadatas.stream()).collect(toList())); + + return Arguments.of(filter, newMatchingMetadatas, newNotMatchingMetadatas); + }); + } + + private static String getMetadataKey(Filter filter) { + try { + if (filter instanceof Not) { + filter = ((Not) filter).expression(); + } + Method method = filter.getClass().getMethod("key"); + return (String) method.invoke(filter); + } catch (Exception e) { + throw new RuntimeException(e); + } } } diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java index 616b26a0e..02e0d398f 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreIT.java @@ -1,14 +1,15 @@ package dev.langchain4j.store.embedding; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.data.Percentage.withPercentage; - import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; +import org.junit.jupiter.api.Test; + import java.util.List; import java.util.UUID; -import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Percentage.withPercentage; /** * A minimum set of tests that each implementation of {@link EmbeddingStore} must pass. @@ -20,6 +21,7 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { @Test void should_add_embedding_with_segment_with_metadata() { + Metadata metadata = createMetadata(); TextSegment segment = TextSegment.from("hello", metadata); @@ -28,22 +30,19 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { String id = embeddingStore().add(embedding, segment); assertThat(id).isNotBlank(); - { - // Not returned. - TextSegment altSegment = TextSegment.from("hello?"); - Embedding altEmbedding = embeddingModel().embed(altSegment.text()).content(); - embeddingStore().add(altEmbedding, altSegment); - } - - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1)); + // when List> relevant = embeddingStore().findRelevant(embedding, 1); - assertThat(relevant).hasSize(1); + // then + assertThat(relevant).hasSize(1); EmbeddingMatch match = relevant.get(0); assertThat(match.score()).isCloseTo(1, withPercentage(1)); assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); + if (assertEmbedding()) { + assertThat(match.embedding()).isEqualTo(embedding); + } assertThat(match.embedded().text()).isEqualTo(segment.text()); @@ -78,15 +77,14 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { assertThat(match.embedded().metadata().getDouble("double_123")).isEqualTo(1.23456789d); // new API - assertThat( - embeddingStore() - .search(EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(1).build()) - .matches() - ) - .isEqualTo(relevant); + assertThat(embeddingStore().search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()).matches()).isEqualTo(relevant); } protected Metadata createMetadata() { + Metadata metadata = new Metadata(); metadata.put("string_empty", ""); diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java index 30f6e0922..1bf5343a6 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithFilteringIT.java @@ -1,5 +1,18 @@ package dev.langchain4j.store.embedding; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.filter.Filter; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; + import static dev.langchain4j.store.embedding.filter.Filter.and; import static dev.langchain4j.store.embedding.filter.Filter.not; import static dev.langchain4j.store.embedding.filter.Filter.or; @@ -9,18 +22,6 @@ import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.store.embedding.filter.Filter; -import java.util.ArrayList; -import java.util.List; -import java.util.UUID; -import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - /** * A minimum set of tests that each implementation of {@link EmbeddingStore} that supports {@link Filter} must pass. */ @@ -28,1038 +29,9 @@ public abstract class EmbeddingStoreWithFilteringIT extends EmbeddingStoreIT { @ParameterizedTest @MethodSource - protected void should_filter_by_metadata( - Filter metadataFilter, - List matchingMetadatas, - List notMatchingMetadatas - ) { - shouldFilter(metadataFilter, matchingMetadatas, notMatchingMetadatas); - } - - static Stream should_filter_by_metadata() { - return Stream - .builder() - // === Equal === - - .add( - Arguments.of( - metadataKey("key").isEqualTo("a"), - asList(new Metadata().put("key", "a"), new Metadata().put("key", "a").put("key2", "b")), - asList( - new Metadata().put("key", "A"), - new Metadata().put("key", "b"), - new Metadata().put("key", "aa"), - new Metadata().put("key", "a a"), - new Metadata().put("key2", "a") - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isEqualTo(TEST_UUID), - asList(new Metadata().put("key", TEST_UUID), new Metadata().put("key", TEST_UUID).put("key2", "b")), - asList(new Metadata().put("key", UUID.randomUUID()), new Metadata().put("key2", TEST_UUID)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isEqualTo(1), - asList(new Metadata().put("key", 1), new Metadata().put("key", 1).put("key2", 0)), - asList(new Metadata().put("key", -1), new Metadata().put("key", 0), new Metadata().put("key2", 1)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isEqualTo(1L), - asList(new Metadata().put("key", 1L), new Metadata().put("key", 1L).put("key2", 0L)), - asList( - new Metadata().put("key", -1L), - new Metadata().put("key", 0L), - new Metadata().put("key2", 1L) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isEqualTo(1.23f), - asList(new Metadata().put("key", 1.23f), new Metadata().put("key", 1.23f).put("key2", 0f)), - asList( - new Metadata().put("key", -1.23f), - new Metadata().put("key", 1.22f), - new Metadata().put("key", 1.24f), - new Metadata().put("key2", 1.23f) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isEqualTo(1.23d), - asList(new Metadata().put("key", 1.23d), new Metadata().put("key", 1.23d).put("key2", 0d)), - asList( - new Metadata().put("key", -1.23d), - new Metadata().put("key", 1.22d), - new Metadata().put("key", 1.24d), - new Metadata().put("key2", 1.23d) - ) - ) - ) - // === GreaterThan == - - .add( - Arguments.of( - metadataKey("key").isGreaterThan(1), - asList(new Metadata().put("key", 2), new Metadata().put("key", 2).put("key2", 0)), - asList( - new Metadata().put("key", -2), - new Metadata().put("key", 0), - new Metadata().put("key", 1), - new Metadata().put("key2", 2) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isGreaterThan(1L), - asList(new Metadata().put("key", 2L), new Metadata().put("key", 2L).put("key2", 0L)), - asList( - new Metadata().put("key", -2L), - new Metadata().put("key", 0L), - new Metadata().put("key", 1L), - new Metadata().put("key2", 2L) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isGreaterThan(1.1f), - asList(new Metadata().put("key", 1.2f), new Metadata().put("key", 1.2f).put("key2", 1.0f)), - asList( - new Metadata().put("key", -1.2f), - new Metadata().put("key", 0.0f), - new Metadata().put("key", 1.1f), - new Metadata().put("key2", 1.2f) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isGreaterThan(1.1d), - asList(new Metadata().put("key", 1.2d), new Metadata().put("key", 1.2d).put("key2", 1.0d)), - asList( - new Metadata().put("key", -1.2d), - new Metadata().put("key", 0.0d), - new Metadata().put("key", 1.1d), - new Metadata().put("key2", 1.2d) - ) - ) - ) - // === GreaterThanOrEqual == - - .add( - Arguments.of( - metadataKey("key").isGreaterThanOrEqualTo(1), - asList( - new Metadata().put("key", 1), - new Metadata().put("key", 2), - new Metadata().put("key", 2).put("key2", 0) - ), - asList( - new Metadata().put("key", -2), - new Metadata().put("key", -1), - new Metadata().put("key", 0), - new Metadata().put("key2", 1), - new Metadata().put("key2", 2) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isGreaterThanOrEqualTo(1L), - asList( - new Metadata().put("key", 1L), - new Metadata().put("key", 2L), - new Metadata().put("key", 2L).put("key2", 0L) - ), - asList( - new Metadata().put("key", -2L), - new Metadata().put("key", -1L), - new Metadata().put("key", 0L), - new Metadata().put("key2", 1L), - new Metadata().put("key2", 2L) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isGreaterThanOrEqualTo(1.1f), - asList( - new Metadata().put("key", 1.1f), - new Metadata().put("key", 1.2f), - new Metadata().put("key", 1.2f).put("key2", 1.0f) - ), - asList( - new Metadata().put("key", -1.2f), - new Metadata().put("key", -1.1f), - new Metadata().put("key", 0.0f), - new Metadata().put("key2", 1.1f), - new Metadata().put("key2", 1.2f) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isGreaterThanOrEqualTo(1.1d), - asList( - new Metadata().put("key", 1.1d), - new Metadata().put("key", 1.2d), - new Metadata().put("key", 1.2d).put("key2", 1.0d) - ), - asList( - new Metadata().put("key", -1.2d), - new Metadata().put("key", -1.1d), - new Metadata().put("key", 0.0d), - new Metadata().put("key2", 1.1d), - new Metadata().put("key2", 1.2d) - ) - ) - ) - // === LessThan == - - .add( - Arguments.of( - metadataKey("key").isLessThan(1), - asList( - new Metadata().put("key", -2), - new Metadata().put("key", 0), - new Metadata().put("key", 0).put("key2", 2) - ), - asList(new Metadata().put("key", 1), new Metadata().put("key", 2), new Metadata().put("key2", 0)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isLessThan(1L), - asList( - new Metadata().put("key", -2L), - new Metadata().put("key", 0L), - new Metadata().put("key", 0L).put("key2", 2L) - ), - asList(new Metadata().put("key", 1L), new Metadata().put("key", 2L), new Metadata().put("key2", 0L)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isLessThan(1.1f), - asList( - new Metadata().put("key", -1.2f), - new Metadata().put("key", 1.0f), - new Metadata().put("key", 1.0f).put("key2", 1.2f) - ), - asList( - new Metadata().put("key", 1.1f), - new Metadata().put("key", 1.2f), - new Metadata().put("key2", 1.0f) - ) - ) - ) - .add( - Arguments.of( - metadataKey("key").isLessThan(1.1d), - asList( - new Metadata().put("key", -1.2d), - new Metadata().put("key", 1.0d), - new Metadata().put("key", 1.0d).put("key2", 1.2d) - ), - asList( - new Metadata().put("key", 1.1d), - new Metadata().put("key", 1.2d), - new Metadata().put("key2", 1.0d) - ) - ) - ) - // === LessThanOrEqual == - - .add( - Arguments.of( - metadataKey("key").isLessThanOrEqualTo(1), - asList( - new Metadata().put("key", -2), - new Metadata().put("key", 0), - new Metadata().put("key", 1), - new Metadata().put("key", 1).put("key2", 2) - ), - asList(new Metadata().put("key", 2), new Metadata().put("key2", 0)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isLessThanOrEqualTo(1L), - asList( - new Metadata().put("key", -2L), - new Metadata().put("key", 0L), - new Metadata().put("key", 1L), - new Metadata().put("key", 1L).put("key2", 2L) - ), - asList(new Metadata().put("key", 2L), new Metadata().put("key2", 0L)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isLessThanOrEqualTo(1.1f), - asList( - new Metadata().put("key", -1.2f), - new Metadata().put("key", 1.0f), - new Metadata().put("key", 1.1f), - new Metadata().put("key", 1.1f).put("key2", 1.2f) - ), - asList(new Metadata().put("key", 1.2f), new Metadata().put("key2", 1.0f)) - ) - ) - .add( - Arguments.of( - metadataKey("key").isLessThanOrEqualTo(1.1d), - asList( - new Metadata().put("key", -1.2d), - new Metadata().put("key", 1.0d), - new Metadata().put("key", 1.1d), - new Metadata().put("key", 1.1d).put("key2", 1.2d) - ), - asList(new Metadata().put("key", 1.2d), new Metadata().put("key2", 1.0d)) - ) - ) - // === In === - - // In: string - .add( - Arguments.of( - metadataKey("name").isIn("Klaus"), - asList(new Metadata().put("name", "Klaus"), new Metadata().put("name", "Klaus").put("age", 42)), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Alice"), - new Metadata().put("name2", "Klaus") - ) - ) - ) - .add( - Arguments.of( - metadataKey("name").isIn(singletonList("Klaus")), - asList(new Metadata().put("name", "Klaus"), new Metadata().put("name", "Klaus").put("age", 42)), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Alice"), - new Metadata().put("name2", "Klaus") - ) - ) - ) - .add( - Arguments.of( - metadataKey("name").isIn("Klaus", "Alice"), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42) - ), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Zoe"), - new Metadata().put("name2", "Klaus") - ) - ) - ) - .add( - Arguments.of( - metadataKey("name").isIn(asList("Klaus", "Alice")), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42) - ), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Zoe"), - new Metadata().put("name2", "Klaus") - ) - ) - ) - // In: UUID - .add( - Arguments.of( - metadataKey("name").isIn(TEST_UUID), - asList(new Metadata().put("name", TEST_UUID), new Metadata().put("name", TEST_UUID).put("age", 42)), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)) - ) - ) - .add( - Arguments.of( - metadataKey("name").isIn(singletonList(TEST_UUID)), - asList(new Metadata().put("name", TEST_UUID), new Metadata().put("name", TEST_UUID).put("age", 42)), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)) - ) - ) - .add( - Arguments.of( - metadataKey("name").isIn(TEST_UUID, TEST_UUID2), - asList( - new Metadata().put("name", TEST_UUID), - new Metadata().put("name", TEST_UUID).put("age", 42), - new Metadata().put("name", TEST_UUID2), - new Metadata().put("name", TEST_UUID2).put("age", 42) - ), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)) - ) - ) - .add( - Arguments.of( - metadataKey("name").isIn(asList(TEST_UUID, TEST_UUID2)), - asList( - new Metadata().put("name", TEST_UUID), - new Metadata().put("name", TEST_UUID).put("age", 42), - new Metadata().put("name", TEST_UUID2), - new Metadata().put("name", TEST_UUID2).put("age", 42) - ), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)) - ) - ) - // In: integer - .add( - Arguments.of( - metadataKey("age").isIn(42), - asList(new Metadata().put("age", 42), new Metadata().put("age", 42).put("name", "Klaus")), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(singletonList(42)), - asList(new Metadata().put("age", 42), new Metadata().put("age", 42).put("name", "Klaus")), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(42, 18), - asList( - new Metadata().put("age", 42), - new Metadata().put("age", 18), - new Metadata().put("age", 42).put("name", "Klaus"), - new Metadata().put("age", 18).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(asList(42, 18)), - asList( - new Metadata().put("age", 42), - new Metadata().put("age", 18), - new Metadata().put("age", 42).put("name", "Klaus"), - new Metadata().put("age", 18).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)) - ) - ) - // In: long - .add( - Arguments.of( - metadataKey("age").isIn(42L), - asList(new Metadata().put("age", 42L), new Metadata().put("age", 42L).put("name", "Klaus")), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(singletonList(42L)), - asList(new Metadata().put("age", 42L), new Metadata().put("age", 42L).put("name", "Klaus")), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(42L, 18L), - asList( - new Metadata().put("age", 42L), - new Metadata().put("age", 18L), - new Metadata().put("age", 42L).put("name", "Klaus"), - new Metadata().put("age", 18L).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(asList(42L, 18L)), - asList( - new Metadata().put("age", 42L), - new Metadata().put("age", 18L), - new Metadata().put("age", 42L).put("name", "Klaus"), - new Metadata().put("age", 18L).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)) - ) - ) - // In: float - .add( - Arguments.of( - metadataKey("age").isIn(42.0f), - asList(new Metadata().put("age", 42.0f), new Metadata().put("age", 42.0f).put("name", "Klaus")), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(singletonList(42.0f)), - asList(new Metadata().put("age", 42.0f), new Metadata().put("age", 42.0f).put("name", "Klaus")), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(42.0f, 18.0f), - asList( - new Metadata().put("age", 42.0f), - new Metadata().put("age", 18.0f), - new Metadata().put("age", 42.0f).put("name", "Klaus"), - new Metadata().put("age", 18.0f).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(asList(42.0f, 18.0f)), - asList( - new Metadata().put("age", 42.0f), - new Metadata().put("age", 18.0f), - new Metadata().put("age", 42.0f).put("name", "Klaus"), - new Metadata().put("age", 18.0f).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)) - ) - ) - // In: double - .add( - Arguments.of( - metadataKey("age").isIn(42.0d), - asList(new Metadata().put("age", 42.0d), new Metadata().put("age", 42.0d).put("name", "Klaus")), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(singletonList(42.0d)), - asList(new Metadata().put("age", 42.0d), new Metadata().put("age", 42.0d).put("name", "Klaus")), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(42.0d, 18.0d), - asList( - new Metadata().put("age", 42.0d), - new Metadata().put("age", 18.0d), - new Metadata().put("age", 42.0d).put("name", "Klaus"), - new Metadata().put("age", 18.0d).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)) - ) - ) - .add( - Arguments.of( - metadataKey("age").isIn(asList(42.0d, 18.0d)), - asList( - new Metadata().put("age", 42.0d), - new Metadata().put("age", 18.0d), - new Metadata().put("age", 42.0d).put("name", "Klaus"), - new Metadata().put("age", 18.0d).put("name", "Klaus") - ), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)) - ) - ) - // === Or === - - // Or: one key - .add( - Arguments.of( - or(metadataKey("name").isEqualTo("Klaus"), metadataKey("name").isEqualTo("Alice")), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42) - ), - singletonList(new Metadata().put("name", "Zoe")) - ) - ) - .add( - Arguments.of( - or(metadataKey("name").isEqualTo("Alice"), metadataKey("name").isEqualTo("Klaus")), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42), - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42) - ), - singletonList(new Metadata().put("name", "Zoe")) - ) - ) - // Or: multiple keys - .add( - Arguments.of( - or(metadataKey("name").isEqualTo("Klaus"), metadataKey("age").isEqualTo(42)), - asList( - // only Or.left is present and true - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("city", "Munich"), - // Or.left is true, Or.right is false - new Metadata().put("name", "Klaus").put("age", 666), - // only Or.right is present and true - new Metadata().put("age", 42), - new Metadata().put("age", 42).put("city", "Munich"), - // Or.right is true, Or.left is false - new Metadata().put("age", 42).put("name", "Alice"), - // Or.left and Or.right are both true - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") - ), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("age", 666), - new Metadata().put("name", "Alice").put("age", 666) - ) - ) - ) - .add( - Arguments.of( - or(metadataKey("age").isEqualTo(42), metadataKey("name").isEqualTo("Klaus")), - asList( - // only Or.left is present and true - new Metadata().put("age", 42), - new Metadata().put("age", 42).put("city", "Munich"), - // Or.left is true, Or.right is false - new Metadata().put("age", 42).put("name", "Alice"), - // only Or.right is present and true - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("city", "Munich"), - // Or.right is true, Or.left is false - new Metadata().put("name", "Klaus").put("age", 666), - // Or.left and Or.right are both true - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") - ), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("age", 666), - new Metadata().put("name", "Alice").put("age", 666) - ) - ) - ) - // Or: x2 - .add( - Arguments.of( - or( - metadataKey("name").isEqualTo("Klaus"), - or(metadataKey("age").isEqualTo(42), metadataKey("city").isEqualTo("Munich")) - ), - asList( - // only Or.left is present and true - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("country", "Germany"), - // Or.left is true, Or.right is false - new Metadata().put("name", "Klaus").put("age", 666), - new Metadata().put("name", "Klaus").put("city", "Frankfurt"), - new Metadata().put("name", "Klaus").put("age", 666).put("city", "Frankfurt"), - // only Or.right is present and true - new Metadata().put("age", 42), - new Metadata().put("age", 42).put("country", "Germany"), - new Metadata().put("city", "Munich"), - new Metadata().put("city", "Munich").put("country", "Germany"), - new Metadata().put("age", 42).put("city", "Munich"), - new Metadata().put("age", 42).put("city", "Munich").put("country", "Germany"), - // Or.right is true, Or.left is false - new Metadata().put("age", 42).put("name", "Alice"), - new Metadata().put("city", "Munich").put("name", "Alice"), - new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice"), - // Or.left and Or.right are both true - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), - new Metadata().put("name", "Klaus").put("city", "Munich"), - new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("age", 666), - new Metadata().put("city", "Frankfurt"), - new Metadata().put("name", "Alice").put("age", 666), - new Metadata().put("name", "Alice").put("city", "Frankfurt"), - new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt") - ) - ) - ) - .add( - Arguments.of( - or( - or(metadataKey("name").isEqualTo("Klaus"), metadataKey("age").isEqualTo(42)), - metadataKey("city").isEqualTo("Munich") - ), - asList( - // only Or.left is present and true - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("country", "Germany"), - new Metadata().put("age", 42), - new Metadata().put("age", 42).put("country", "Germany"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), - // Or.left is true, Or.right is false - new Metadata().put("name", "Klaus").put("city", "Frankfurt"), - new Metadata().put("age", 42).put("city", "Frankfurt"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), - // only Or.right is present and true - new Metadata().put("city", "Munich"), - new Metadata().put("city", "Munich").put("country", "Germany"), - // Or.right is true, Or.left is false - new Metadata().put("city", "Munich").put("name", "Alice"), - new Metadata().put("city", "Munich").put("age", 666), - // Or.left and Or.right are both true - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), - new Metadata().put("name", "Klaus").put("city", "Munich"), - new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("age", 666), - new Metadata().put("city", "Frankfurt"), - new Metadata().put("name", "Alice").put("age", 666), - new Metadata().put("name", "Alice").put("city", "Frankfurt"), - new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt") - ) - ) - ) - // === AND === - - .add( - Arguments.of( - and(metadataKey("name").isEqualTo("Klaus"), metadataKey("age").isEqualTo(42)), - asList( - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") - ), - asList( - // only And.left is present and true - new Metadata().put("name", "Klaus"), - // And.left is true, And.right is false - new Metadata().put("name", "Klaus").put("age", 666), - // only And.right is present and true - new Metadata().put("age", 42), - // And.right is true, And.left is false - new Metadata().put("age", 42).put("name", "Alice"), - // And.left, And.right are both false - new Metadata().put("age", 666).put("name", "Alice") - ) - ) - ) - .add( - Arguments.of( - and(metadataKey("age").isEqualTo(42), metadataKey("name").isEqualTo("Klaus")), - asList( - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") - ), - asList( - // only And.left is present and true - new Metadata().put("age", 42), - // And.left is true, And.right is false - new Metadata().put("age", 42).put("name", "Alice"), - // only And.right is present and true - new Metadata().put("name", "Klaus"), - // And.right is true, And.left is false - new Metadata().put("name", "Klaus").put("age", 666), - // And.left, And.right are both false - new Metadata().put("age", 666).put("name", "Alice") - ) - ) - ) - // And: x2 - .add( - Arguments.of( - and( - metadataKey("name").isEqualTo("Klaus"), - and(metadataKey("age").isEqualTo(42), metadataKey("city").isEqualTo("Munich")) - ), - asList( - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - // only And.left is present and true - new Metadata().put("name", "Klaus"), - // And.left is true, And.right is false - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("city", "Munich"), - new Metadata().put("name", "Klaus").put("age", 666).put("city", "Munich"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), - // only And.right is present and true - new Metadata().put("age", 42).put("city", "Munich"), - // And.right is true, And.left is false - new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice") - ) - ) - ) - .add( - Arguments.of( - and( - and(metadataKey("name").isEqualTo("Klaus"), metadataKey("age").isEqualTo(42)), - metadataKey("city").isEqualTo("Munich") - ), - asList( - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - // only And.left is present and true - new Metadata().put("name", "Klaus").put("age", 42), - // And.left is true, And.right is false - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), - // only And.right is present and true - new Metadata().put("city", "Munich"), - // And.right is true, And.left is false - new Metadata().put("city", "Munich").put("name", "Klaus"), - new Metadata().put("city", "Munich").put("name", "Klaus").put("age", 666), - new Metadata().put("city", "Munich").put("age", 42), - new Metadata().put("city", "Munich").put("age", 42).put("name", "Alice") - ) - ) - ) - // === AND + nested OR === - - .add( - Arguments.of( - and( - metadataKey("name").isEqualTo("Klaus"), - or(metadataKey("age").isEqualTo(42), metadataKey("city").isEqualTo("Munich")) - ), - asList( - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), - new Metadata().put("name", "Klaus").put("city", "Munich"), - new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - // only And.left is present and true - new Metadata().put("name", "Klaus"), - // And.left is true, And.right is false - new Metadata().put("name", "Klaus").put("age", 666), - new Metadata().put("name", "Klaus").put("city", "Frankfurt"), - // only And.right is present and true - new Metadata().put("age", 42), - new Metadata().put("city", "Munich"), - new Metadata().put("age", 42).put("city", "Munich"), - // And.right is true, And.left is false - new Metadata().put("age", 42).put("name", "Alice"), - new Metadata().put("city", "Munich").put("name", "Alice"), - new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice") - ) - ) - ) - .add( - Arguments.of( - and( - or(metadataKey("name").isEqualTo("Klaus"), metadataKey("age").isEqualTo(42)), - metadataKey("city").isEqualTo("Munich") - ), - asList( - new Metadata().put("name", "Klaus").put("city", "Munich"), - new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), - new Metadata().put("age", 42).put("city", "Munich"), - new Metadata().put("age", 42).put("city", "Munich").put("country", "Germany"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - // only And.left is present and true - new Metadata().put("name", "Klaus"), - new Metadata().put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42), - // And.left is true, And.right is false - new Metadata().put("name", "Klaus").put("city", "Frankfurt"), - new Metadata().put("age", 42).put("city", "Frankfurt"), - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), - // only And.right is present and true - new Metadata().put("city", "Munich"), - // And.right is true, And.left is false - new Metadata().put("city", "Munich").put("name", "Alice"), - new Metadata().put("city", "Munich").put("age", 666), - new Metadata().put("city", "Munich").put("name", "Alice").put("age", 666) - ) - ) - ) - // === OR + nested AND === - .add( - Arguments.of( - or( - metadataKey("name").isEqualTo("Klaus"), - and(metadataKey("age").isEqualTo(42), metadataKey("city").isEqualTo("Munich")) - ), - asList( - // only Or.left is present and true - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("country", "Germany"), - // Or.left is true, Or.right is false - new Metadata().put("name", "Klaus").put("age", 666), - new Metadata().put("name", "Klaus").put("city", "Frankfurt"), - new Metadata().put("name", "Klaus").put("age", 666).put("city", "Frankfurt"), - // only Or.right is present and true - new Metadata().put("age", 42).put("city", "Munich"), - new Metadata().put("age", 42).put("city", "Munich").put("country", "Germany"), - // Or.right is true, Or.left is false - new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice") - ), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("age", 666), - new Metadata().put("city", "Frankfurt"), - new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt") - ) - ) - ) - .add( - Arguments.of( - or( - and(metadataKey("name").isEqualTo("Klaus"), metadataKey("age").isEqualTo(42)), - metadataKey("city").isEqualTo("Munich") - ), - asList( - // only Or.left is present and true - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), - // Or.left is true, Or.right is false - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), - // only Or.right is present and true - new Metadata().put("city", "Munich"), - new Metadata().put("city", "Munich").put("country", "Germany"), - // Or.right is true, Or.left is true - new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), - new Metadata() - .put("name", "Klaus") - .put("age", 42) - .put("city", "Munich") - .put("country", "Germany") - ), - asList( - new Metadata().put("name", "Alice"), - new Metadata().put("age", 666), - new Metadata().put("city", "Frankfurt"), - new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt") - ) - ) - ) - .build(); - } - - @ParameterizedTest - @MethodSource - protected void should_filter_by_greater_and_less_than_alphanumeric_metadata( - Filter metadataFilter, - List matchingMetadatas, - List notMatchingMetadatas - ) { - shouldFilter(metadataFilter, matchingMetadatas, notMatchingMetadatas); - } - - static Stream should_filter_by_greater_and_less_than_alphanumeric_metadata() { - return Stream - .builder() - // === GreaterThan == - .add( - Arguments.of( - metadataKey("key").isGreaterThan("b"), - asList(new Metadata().put("key", "c"), new Metadata().put("key", "c").put("key2", "a")), - asList( - new Metadata().put("key", "a"), - new Metadata().put("key", "b"), - new Metadata().put("key2", "c") - ) - ) - ) - // === GreaterThanOrEqual == - .add( - Arguments.of( - metadataKey("key").isGreaterThanOrEqualTo("b"), - asList( - new Metadata().put("key", "b"), - new Metadata().put("key", "c"), - new Metadata().put("key", "c").put("key2", "a") - ), - asList(new Metadata().put("key", "a"), new Metadata().put("key2", "b")) - ) - ) - // === LessThan == - .add( - Arguments.of( - metadataKey("key").isLessThan("b"), - asList(new Metadata().put("key", "a"), new Metadata().put("key", "a").put("key2", "c")), - asList( - new Metadata().put("key", "b"), - new Metadata().put("key", "c"), - new Metadata().put("key2", "a") - ) - ) - ) - // === LessThanOrEqual == - .add( - Arguments.of( - metadataKey("key").isLessThanOrEqualTo("b"), - asList( - new Metadata().put("key", "a"), - new Metadata().put("key", "b"), - new Metadata().put("key", "b").put("key2", "c") - ), - asList(new Metadata().put("key", "c"), new Metadata().put("key2", "a")) - ) - ) - .build(); - } - - private void shouldFilter( - Filter metadataFilter, - List matchingMetadatas, - List notMatchingMetadatas - ) { - + protected void should_filter_by_metadata(Filter metadataFilter, + List matchingMetadatas, + List notMatchingMetadatas) { // given List embeddings = new ArrayList<>(); List segments = new ArrayList<>(); @@ -1078,34 +50,15 @@ public abstract class EmbeddingStoreWithFilteringIT extends EmbeddingStoreIT { segments.add(notMatchingSegment); } - TextSegment notMatchingSegmentWithoutMetadata = TextSegment.from("not matching, without metadata"); - Embedding notMatchingEmbeddingWithoutMetadata = embeddingModel() - .embed(notMatchingSegmentWithoutMetadata) - .content(); - embeddings.add(notMatchingEmbeddingWithoutMetadata); - segments.add(notMatchingSegmentWithoutMetadata); - embeddingStore().addAll(embeddings, segments); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(embeddings.size())); - Embedding queryEmbedding = embeddingModel().embed("matching").content(); - - EmbeddingSearchRequest request = EmbeddingSearchRequest - .builder() - .queryEmbedding(queryEmbedding) - .maxResults(100) - .build(); - assertThat(embeddingStore().search(request).matches()) - // +1 for notMatchingSegmentWithoutMetadata - .hasSize(matchingMetadatas.size() + notMatchingMetadatas.size() + 1); - - EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest - .builder() - .queryEmbedding(queryEmbedding) - .filter(metadataFilter) - .maxResults(100) - .build(); + EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddingModel().embed("matching").content()) + .filter(metadataFilter) + .maxResults(100) + .build(); // when List> matches = embeddingStore().search(embeddingSearchRequest).matches(); @@ -1116,14 +69,1203 @@ public abstract class EmbeddingStoreWithFilteringIT extends EmbeddingStoreIT { matches.forEach(match -> assertThat(match.score()).isCloseTo(1, withPercentage(0.01))); } + protected static Stream should_filter_by_metadata() { + return Stream.builder() + + + // === Equal === + + .add(Arguments.of( + metadataKey("key").isEqualTo("a"), + asList( + new Metadata().put("key", "a"), + new Metadata().put("key", "a").put("key2", "b") + ), + asList( + new Metadata().put("key", "A"), + new Metadata().put("key", "b"), + new Metadata().put("key", "aa"), + new Metadata().put("key", "a a"), + new Metadata().put("key2", "a"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isEqualTo(TEST_UUID), + asList( + new Metadata().put("key", TEST_UUID), + new Metadata().put("key", TEST_UUID).put("key2", "b") + ), + asList( + new Metadata().put("key", UUID.randomUUID()), + new Metadata().put("key2", TEST_UUID), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isEqualTo(1), + asList( + new Metadata().put("key", 1), + new Metadata().put("key", 1).put("key2", 0) + ), + asList( + new Metadata().put("key", -1), + new Metadata().put("key", 0), + new Metadata().put("key2", 1), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isEqualTo(1L), + asList( + new Metadata().put("key", 1L), + new Metadata().put("key", 1L).put("key2", 0L) + ), + asList( + new Metadata().put("key", -1L), + new Metadata().put("key", 0L), + new Metadata().put("key2", 1L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isEqualTo(1.23f), + asList( + new Metadata().put("key", 1.23f), + new Metadata().put("key", 1.23f).put("key2", 0f) + ), + asList( + new Metadata().put("key", -1.23f), + new Metadata().put("key", 1.22f), + new Metadata().put("key", 1.24f), + new Metadata().put("key2", 1.23f), + new Metadata() + ) + )).add(Arguments.of( + metadataKey("key").isEqualTo(1.23d), + asList( + new Metadata().put("key", 1.23d), + new Metadata().put("key", 1.23d).put("key2", 0d) + ), + asList( + new Metadata().put("key", -1.23d), + new Metadata().put("key", 1.22d), + new Metadata().put("key", 1.24d), + new Metadata().put("key2", 1.23d), + new Metadata() + ) + )) + + + // === GreaterThan == + + .add(Arguments.of( + metadataKey("key").isGreaterThan("b"), + asList( + new Metadata().put("key", "c"), + new Metadata().put("key", "c").put("key2", "a") + ), + asList( + new Metadata().put("key", "a"), + new Metadata().put("key", "b"), + new Metadata().put("key2", "c"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThan(1), + asList( + new Metadata().put("key", 2), + new Metadata().put("key", 2).put("key2", 0) + ), + asList( + new Metadata().put("key", -2), + new Metadata().put("key", 0), + new Metadata().put("key", 1), + new Metadata().put("key2", 2), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThan(1L), + asList( + new Metadata().put("key", 2L), + new Metadata().put("key", 2L).put("key2", 0L) + ), + asList( + new Metadata().put("key", -2L), + new Metadata().put("key", 0L), + new Metadata().put("key", 1L), + new Metadata().put("key2", 2L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThan(1.1f), + asList( + new Metadata().put("key", 1.2f), + new Metadata().put("key", 1.2f).put("key2", 1.0f) + ), + asList( + new Metadata().put("key", -1.2f), + new Metadata().put("key", 0.0f), + new Metadata().put("key", 1.1f), + new Metadata().put("key2", 1.2f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThan(1.1d), + asList( + new Metadata().put("key", 1.2d), + new Metadata().put("key", 1.2d).put("key2", 1.0d) + ), + asList( + new Metadata().put("key", -1.2d), + new Metadata().put("key", 0.0d), + new Metadata().put("key", 1.1d), + new Metadata().put("key2", 1.2d), + new Metadata() + ) + )) + + + // === GreaterThanOrEqual == + + .add(Arguments.of( + metadataKey("key").isGreaterThanOrEqualTo("b"), + asList( + new Metadata().put("key", "b"), + new Metadata().put("key", "c"), + new Metadata().put("key", "c").put("key2", "a") + ), + asList( + new Metadata().put("key", "a"), + new Metadata().put("key2", "b"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThanOrEqualTo(1), + asList( + new Metadata().put("key", 1), + new Metadata().put("key", 2), + new Metadata().put("key", 2).put("key2", 0) + ), + asList( + new Metadata().put("key", -2), + new Metadata().put("key", -1), + new Metadata().put("key", 0), + new Metadata().put("key2", 1), + new Metadata().put("key2", 2), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThanOrEqualTo(1L), + asList( + new Metadata().put("key", 1L), + new Metadata().put("key", 2L), + new Metadata().put("key", 2L).put("key2", 0L) + ), + asList( + new Metadata().put("key", -2L), + new Metadata().put("key", -1L), + new Metadata().put("key", 0L), + new Metadata().put("key2", 1L), + new Metadata().put("key2", 2L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThanOrEqualTo(1.1f), + asList( + new Metadata().put("key", 1.1f), + new Metadata().put("key", 1.2f), + new Metadata().put("key", 1.2f).put("key2", 1.0f) + ), + asList( + new Metadata().put("key", -1.2f), + new Metadata().put("key", -1.1f), + new Metadata().put("key", 0.0f), + new Metadata().put("key2", 1.1f), + new Metadata().put("key2", 1.2f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isGreaterThanOrEqualTo(1.1d), + asList( + new Metadata().put("key", 1.1d), + new Metadata().put("key", 1.2d), + new Metadata().put("key", 1.2d).put("key2", 1.0d) + ), + asList( + new Metadata().put("key", -1.2d), + new Metadata().put("key", -1.1d), + new Metadata().put("key", 0.0d), + new Metadata().put("key2", 1.1d), + new Metadata().put("key2", 1.2d), + new Metadata() + ) + )) + + + // === LessThan == + + .add(Arguments.of( + metadataKey("key").isLessThan("b"), + asList( + + new Metadata().put("key", "a"), + new Metadata().put("key", "a").put("key2", "c") + ), + asList( + new Metadata().put("key", "b"), + new Metadata().put("key", "c"), + new Metadata().put("key2", "a"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThan(1), + asList( + new Metadata().put("key", -2), + new Metadata().put("key", 0), + new Metadata().put("key", 0).put("key2", 2) + ), + asList( + new Metadata().put("key", 1), + new Metadata().put("key", 2), + new Metadata().put("key2", 0), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThan(1L), + asList( + new Metadata().put("key", -2L), + new Metadata().put("key", 0L), + new Metadata().put("key", 0L).put("key2", 2L) + ), + asList( + new Metadata().put("key", 1L), + new Metadata().put("key", 2L), + new Metadata().put("key2", 0L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThan(1.1f), + asList( + new Metadata().put("key", -1.2f), + new Metadata().put("key", 1.0f), + new Metadata().put("key", 1.0f).put("key2", 1.2f) + ), + asList( + new Metadata().put("key", 1.1f), + new Metadata().put("key", 1.2f), + new Metadata().put("key2", 1.0f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThan(1.1d), + asList( + new Metadata().put("key", -1.2d), + new Metadata().put("key", 1.0d), + new Metadata().put("key", 1.0d).put("key2", 1.2d) + ), + asList( + new Metadata().put("key", 1.1d), + new Metadata().put("key", 1.2d), + new Metadata().put("key2", 1.0d), + new Metadata() + ) + )) + + + // === LessThanOrEqual == + + .add(Arguments.of( + metadataKey("key").isLessThanOrEqualTo("b"), + asList( + + new Metadata().put("key", "a"), + new Metadata().put("key", "b"), + new Metadata().put("key", "b").put("key2", "c") + ), + asList( + new Metadata().put("key", "c"), + new Metadata().put("key2", "a"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThanOrEqualTo(1), + asList( + new Metadata().put("key", -2), + new Metadata().put("key", 0), + new Metadata().put("key", 1), + new Metadata().put("key", 1).put("key2", 2) + ), + asList( + new Metadata().put("key", 2), + new Metadata().put("key2", 0), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThanOrEqualTo(1L), + asList( + new Metadata().put("key", -2L), + new Metadata().put("key", 0L), + new Metadata().put("key", 1L), + new Metadata().put("key", 1L).put("key2", 2L) + ), + asList( + new Metadata().put("key", 2L), + new Metadata().put("key2", 0L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThanOrEqualTo(1.1f), + asList( + new Metadata().put("key", -1.2f), + new Metadata().put("key", 1.0f), + new Metadata().put("key", 1.1f), + new Metadata().put("key", 1.1f).put("key2", 1.2f) + ), + asList( + new Metadata().put("key", 1.2f), + new Metadata().put("key2", 1.0f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("key").isLessThanOrEqualTo(1.1d), + asList( + new Metadata().put("key", -1.2d), + new Metadata().put("key", 1.0d), + new Metadata().put("key", 1.1d), + new Metadata().put("key", 1.1d).put("key2", 1.2d) + ), + asList( + new Metadata().put("key", 1.2d), + new Metadata().put("key2", 1.0d), + new Metadata() + ) + )) + + + // === In === + + // In: string + .add(Arguments.of( + metadataKey("name").isIn("Klaus"), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42) + ), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Alice"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("name").isIn(singletonList("Klaus")), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42) + ), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Alice"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("name").isIn("Klaus", "Alice"), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Alice"), + new Metadata().put("name", "Alice").put("age", 42) + ), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Zoe"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("name").isIn(asList("Klaus", "Alice")), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Alice"), + new Metadata().put("name", "Alice").put("age", 42) + ), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Zoe"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ) + )) + + // In: UUID + .add(Arguments.of( + metadataKey("name").isIn(TEST_UUID), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42) + ), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("name").isIn(singletonList(TEST_UUID)), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42) + ), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("name").isIn(TEST_UUID, TEST_UUID2), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42), + new Metadata().put("name", TEST_UUID2), + new Metadata().put("name", TEST_UUID2).put("age", 42) + ), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("name").isIn(asList(TEST_UUID, TEST_UUID2)), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42), + new Metadata().put("name", TEST_UUID2), + new Metadata().put("name", TEST_UUID2).put("age", 42) + ), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ) + )) + + // In: integer + .add(Arguments.of( + metadataKey("age").isIn(42), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(singletonList(42)), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(42, 18), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 18), + new Metadata().put("age", 42).put("name", "Klaus"), + new Metadata().put("age", 18).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(asList(42, 18)), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 18), + new Metadata().put("age", 42).put("name", "Klaus"), + new Metadata().put("age", 18).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ) + )) + + // In: long + .add(Arguments.of( + metadataKey("age").isIn(42L), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 42L).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(singletonList(42L)), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 42L).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(42L, 18L), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 18L), + new Metadata().put("age", 42L).put("name", "Klaus"), + new Metadata().put("age", 18L).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(asList(42L, 18L)), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 18L), + new Metadata().put("age", 42L).put("name", "Klaus"), + new Metadata().put("age", 18L).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ) + )) + + // In: float + .add(Arguments.of( + metadataKey("age").isIn(42.0f), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(singletonList(42.0f)), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(42.0f, 18.0f), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 18.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus"), + new Metadata().put("age", 18.0f).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(asList(42.0f, 18.0f)), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 18.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus"), + new Metadata().put("age", 18.0f).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ) + )) + + // In: double + .add(Arguments.of( + metadataKey("age").isIn(42.0d), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(singletonList(42.0d)), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(42.0d, 18.0d), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 18.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus"), + new Metadata().put("age", 18.0d).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ) + )) + .add(Arguments.of( + metadataKey("age").isIn(asList(42.0d, 18.0d)), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 18.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus"), + new Metadata().put("age", 18.0d).put("name", "Klaus") + ), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ) + )) + + + // === Or === + + // Or: one key + .add(Arguments.of( + or( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("name").isEqualTo("Alice") + ), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Alice"), + new Metadata().put("name", "Alice").put("age", 42) + ), + asList( + new Metadata().put("name", "Zoe"), + new Metadata() + ) + )) + .add(Arguments.of( + or( + metadataKey("name").isEqualTo("Alice"), + metadataKey("name").isEqualTo("Klaus") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("name", "Alice").put("age", 42), + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42) + ), + asList( + new Metadata().put("name", "Zoe"), + new Metadata() + ) + )) + + // Or: multiple keys + .add(Arguments.of( + or( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("age").isEqualTo(42) + ), + asList( + // only Or.left is present and true + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("city", "Munich"), + + // Or.left is true, Or.right is false + new Metadata().put("name", "Klaus").put("age", 666), + + // only Or.right is present and true + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("city", "Munich"), + + // Or.right is true, Or.left is false + new Metadata().put("age", 42).put("name", "Alice"), + + // Or.left and Or.right are both true + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 666), + new Metadata().put("name", "Alice").put("age", 666), + new Metadata() + ) + )) + .add(Arguments.of( + or( + metadataKey("age").isEqualTo(42), + metadataKey("name").isEqualTo("Klaus") + ), + asList( + // only Or.left is present and true + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("city", "Munich"), + + // Or.left is true, Or.right is false + new Metadata().put("age", 42).put("name", "Alice"), + + // only Or.right is present and true + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("city", "Munich"), + + // Or.right is true, Or.left is false + new Metadata().put("name", "Klaus").put("age", 666), + + // Or.left and Or.right are both true + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 666), + new Metadata().put("name", "Alice").put("age", 666), + new Metadata() + ) + )) + + // Or: x2 + .add(Arguments.of( + or( + metadataKey("name").isEqualTo("Klaus"), + or( + metadataKey("age").isEqualTo(42), + metadataKey("city").isEqualTo("Munich") + ) + ), + asList( + // only Or.left is present and true + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("country", "Germany"), + + // Or.left is true, Or.right is false + new Metadata().put("name", "Klaus").put("age", 666), + new Metadata().put("name", "Klaus").put("city", "Frankfurt"), + new Metadata().put("name", "Klaus").put("age", 666).put("city", "Frankfurt"), + + // only Or.right is present and true + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("country", "Germany"), + new Metadata().put("city", "Munich"), + new Metadata().put("city", "Munich").put("country", "Germany"), + new Metadata().put("age", 42).put("city", "Munich"), + new Metadata().put("age", 42).put("city", "Munich").put("country", "Germany"), + + // Or.right is true, Or.left is false + new Metadata().put("age", 42).put("name", "Alice"), + new Metadata().put("city", "Munich").put("name", "Alice"), + new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice"), + + // Or.left and Or.right are both true + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), + new Metadata().put("name", "Klaus").put("city", "Munich"), + new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 666), + new Metadata().put("city", "Frankfurt"), + new Metadata().put("name", "Alice").put("age", 666), + new Metadata().put("name", "Alice").put("city", "Frankfurt"), + new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt"), + new Metadata() + ) + )) + .add(Arguments.of( + or( + or( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("age").isEqualTo(42) + ), + metadataKey("city").isEqualTo("Munich") + ), + asList( + // only Or.left is present and true + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("country", "Germany"), + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("country", "Germany"), + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), + + // Or.left is true, Or.right is false + new Metadata().put("name", "Klaus").put("city", "Frankfurt"), + new Metadata().put("age", 42).put("city", "Frankfurt"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), + + // only Or.right is present and true + new Metadata().put("city", "Munich"), + new Metadata().put("city", "Munich").put("country", "Germany"), + + // Or.right is true, Or.left is false + new Metadata().put("city", "Munich").put("name", "Alice"), + new Metadata().put("city", "Munich").put("age", 666), + + // Or.left and Or.right are both true + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), + new Metadata().put("name", "Klaus").put("city", "Munich"), + new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 666), + new Metadata().put("city", "Frankfurt"), + new Metadata().put("name", "Alice").put("age", 666), + new Metadata().put("name", "Alice").put("city", "Frankfurt"), + new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt"), + new Metadata() + ) + )) + + // === AND === + + .add(Arguments.of( + and( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("age").isEqualTo(42) + ), + asList( + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") + ), + asList( + // only And.left is present and true + new Metadata().put("name", "Klaus"), + + // And.left is true, And.right is false + new Metadata().put("name", "Klaus").put("age", 666), + + // only And.right is present and true + new Metadata().put("age", 42), + + // And.right is true, And.left is false + new Metadata().put("age", 42).put("name", "Alice"), + + // And.left, And.right are both false + new Metadata().put("age", 666).put("name", "Alice"), + + new Metadata() + ) + )) + .add(Arguments.of( + and( + metadataKey("age").isEqualTo(42), + metadataKey("name").isEqualTo("Klaus") + ), + asList( + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich") + ), + asList( + // only And.left is present and true + new Metadata().put("age", 42), + + // And.left is true, And.right is false + new Metadata().put("age", 42).put("name", "Alice"), + + // only And.right is present and true + new Metadata().put("name", "Klaus"), + + // And.right is true, And.left is false + new Metadata().put("name", "Klaus").put("age", 666), + + // And.left, And.right are both false + new Metadata().put("age", 666).put("name", "Alice"), + + new Metadata() + ) + )) + + // And: x2 + .add(Arguments.of( + and( + metadataKey("name").isEqualTo("Klaus"), + and( + metadataKey("age").isEqualTo(42), + metadataKey("city").isEqualTo("Munich") + ) + ), + asList( + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + // only And.left is present and true + new Metadata().put("name", "Klaus"), + + // And.left is true, And.right is false + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 666).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), + + // only And.right is present and true + new Metadata().put("age", 42).put("city", "Munich"), + + // And.right is true, And.left is false + new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice"), + + new Metadata() + ) + )) + .add(Arguments.of( + and( + and( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("age").isEqualTo(42) + ), + metadataKey("city").isEqualTo("Munich") + ), + asList( + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + // only And.left is present and true + new Metadata().put("name", "Klaus").put("age", 42), + + // And.left is true, And.right is false + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), + + // only And.right is present and true + new Metadata().put("city", "Munich"), + + // And.right is true, And.left is false + new Metadata().put("city", "Munich").put("name", "Klaus"), + new Metadata().put("city", "Munich").put("name", "Klaus").put("age", 666), + new Metadata().put("city", "Munich").put("age", 42), + new Metadata().put("city", "Munich").put("age", 42).put("name", "Alice"), + + new Metadata() + ) + )) + + // === AND + nested OR === + + .add(Arguments.of( + and( + metadataKey("name").isEqualTo("Klaus"), + or( + metadataKey("age").isEqualTo(42), + metadataKey("city").isEqualTo("Munich") + ) + ), + asList( + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), + new Metadata().put("name", "Klaus").put("city", "Munich"), + new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + // only And.left is present and true + new Metadata().put("name", "Klaus"), + + // And.left is true, And.right is false + new Metadata().put("name", "Klaus").put("age", 666), + new Metadata().put("name", "Klaus").put("city", "Frankfurt"), + + // only And.right is present and true + new Metadata().put("age", 42), + new Metadata().put("city", "Munich"), + new Metadata().put("age", 42).put("city", "Munich"), + + // And.right is true, And.left is false + new Metadata().put("age", 42).put("name", "Alice"), + new Metadata().put("city", "Munich").put("name", "Alice"), + new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice"), + + new Metadata() + ) + )) + .add(Arguments.of( + and( + or( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("age").isEqualTo(42) + ), + metadataKey("city").isEqualTo("Munich") + ), + asList( + new Metadata().put("name", "Klaus").put("city", "Munich"), + new Metadata().put("name", "Klaus").put("city", "Munich").put("country", "Germany"), + new Metadata().put("age", 42).put("city", "Munich"), + new Metadata().put("age", 42).put("city", "Munich").put("country", "Germany"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + // only And.left is present and true + new Metadata().put("name", "Klaus"), + new Metadata().put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42), + + // And.left is true, And.right is false + new Metadata().put("name", "Klaus").put("city", "Frankfurt"), + new Metadata().put("age", 42).put("city", "Frankfurt"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), + + // only And.right is present and true + new Metadata().put("city", "Munich"), + + // And.right is true, And.left is false + new Metadata().put("city", "Munich").put("name", "Alice"), + new Metadata().put("city", "Munich").put("age", 666), + new Metadata().put("city", "Munich").put("name", "Alice").put("age", 666), + + new Metadata() + ) + )) + + // === OR + nested AND === + .add(Arguments.of( + or( + metadataKey("name").isEqualTo("Klaus"), + and( + metadataKey("age").isEqualTo(42), + metadataKey("city").isEqualTo("Munich") + ) + ), + asList( + // only Or.left is present and true + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("country", "Germany"), + + // Or.left is true, Or.right is false + new Metadata().put("name", "Klaus").put("age", 666), + new Metadata().put("name", "Klaus").put("city", "Frankfurt"), + new Metadata().put("name", "Klaus").put("age", 666).put("city", "Frankfurt"), + + // only Or.right is present and true + new Metadata().put("age", 42).put("city", "Munich"), + new Metadata().put("age", 42).put("city", "Munich").put("country", "Germany"), + + // Or.right is true, Or.left is false + new Metadata().put("age", 42).put("city", "Munich").put("name", "Alice") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 666), + new Metadata().put("city", "Frankfurt"), + new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt"), + + new Metadata() + ) + )) + .add(Arguments.of( + or( + and( + metadataKey("name").isEqualTo("Klaus"), + metadataKey("age").isEqualTo(42) + ), + metadataKey("city").isEqualTo("Munich") + ), + asList( + // only Or.left is present and true + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Klaus").put("age", 42).put("country", "Germany"), + + // Or.left is true, Or.right is false + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Frankfurt"), + + // only Or.right is present and true + new Metadata().put("city", "Munich"), + new Metadata().put("city", "Munich").put("country", "Germany"), + + // Or.right is true, Or.left is true + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich"), + new Metadata().put("name", "Klaus").put("age", 42).put("city", "Munich").put("country", "Germany") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 666), + new Metadata().put("city", "Frankfurt"), + new Metadata().put("name", "Alice").put("age", 666).put("city", "Frankfurt"), + new Metadata() + ) + )) + + .build(); + } + @ParameterizedTest @MethodSource - protected void should_filter_by_metadata_not( - Filter metadataFilter, - List matchingMetadatas, - List notMatchingMetadatas, - boolean buildSegmentWithoutMetadata - ) { + void should_filter_by_metadata_not(Filter metadataFilter, + List matchingMetadatas, + List notMatchingMetadatas) { // given List embeddings = new ArrayList<>(); List segments = new ArrayList<>(); @@ -1142,415 +1284,467 @@ public abstract class EmbeddingStoreWithFilteringIT extends EmbeddingStoreIT { segments.add(notMatchingSegment); } - if (buildSegmentWithoutMetadata) { - TextSegment notMatchingSegmentWithoutMetadata = TextSegment.from("matching"); - Embedding notMatchingEmbeddingWithoutMetadata = embeddingModel() - .embed(notMatchingSegmentWithoutMetadata) - .content(); - embeddings.add(notMatchingEmbeddingWithoutMetadata); - segments.add(notMatchingSegmentWithoutMetadata); - } - embeddingStore().addAll(embeddings, segments); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(embeddings.size())); - Embedding queryEmbedding = embeddingModel().embed("matching").content(); - - assertThat(embeddingStore().findRelevant(queryEmbedding, 100)) - .hasSize(matchingMetadatas.size() + notMatchingMetadatas.size() + (buildSegmentWithoutMetadata ? 1 : 0)); - - EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest - .builder() - .queryEmbedding(queryEmbedding) - .filter(metadataFilter) - .maxResults(100) - .build(); + EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddingModel().embed("matching").content()) + .filter(metadataFilter) + .maxResults(100) + .build(); // when List> matches = embeddingStore().search(embeddingSearchRequest).matches(); // then - assertThat(matches).hasSize(matchingMetadatas.size() + (buildSegmentWithoutMetadata ? 1 : 0)); + assertThat(matches).hasSize(matchingMetadatas.size()); matches.forEach(match -> assertThat(match.embedded().text()).isEqualTo("matching")); matches.forEach(match -> assertThat(match.score()).isCloseTo(1, withPercentage(0.01))); } protected static Stream should_filter_by_metadata_not() { - return Stream - .builder() - // === Not === - .add( - Arguments.of( - not(metadataKey("name").isEqualTo("Klaus")), - asList(new Metadata().put("name", "Alice"), new Metadata().put("age", 42)), - asList(new Metadata().put("name", "Klaus"), new Metadata().put("name", "Klaus").put("age", 42)), - true - ) - ) - // === NotEqual === + return Stream.builder() - .add( - Arguments.of( - metadataKey("key").isNotEqualTo("a"), - asList( - new Metadata().put("key", "A"), - new Metadata().put("key", "b"), - new Metadata().put("key", "aa"), - new Metadata().put("key", "a a"), - new Metadata().put("key2", "a") - ), - asList(new Metadata().put("key", "a"), new Metadata().put("key", "a").put("key2", "b")), - true - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(TEST_UUID), - asList(new Metadata().put("key", UUID.randomUUID()), new Metadata().put("key2", TEST_UUID)), - asList( - new Metadata().put("key", TEST_UUID), - new Metadata().put("key", TEST_UUID).put("key2", UUID.randomUUID()) - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(1), - asList( - new Metadata().put("key", -1), - new Metadata().put("key", 0), - new Metadata().put("key", 2), - new Metadata().put("key", 10), - new Metadata().put("key2", 1) - ), - asList(new Metadata().put("key", 1), new Metadata().put("key", 1).put("key2", 2)), - true - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(1L), - asList( - new Metadata().put("key", -1L), - new Metadata().put("key", 0L), - new Metadata().put("key", 2L), - new Metadata().put("key", 10L), - new Metadata().put("key2", 1L) - ), - asList(new Metadata().put("key", 1L), new Metadata().put("key", 1L).put("key2", 2L)), - true - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(1.1f), - asList( - new Metadata().put("key", -1.1f), - new Metadata().put("key", 0.0f), - new Metadata().put("key", 1.11f), - new Metadata().put("key", 2.2f), - new Metadata().put("key2", 1.1f) - ), - asList(new Metadata().put("key", 1.1f), new Metadata().put("key", 1.1f).put("key2", 2.2f)), - true - ) - ) - .add( - Arguments.of( - metadataKey("key").isNotEqualTo(1.1), - asList( - new Metadata().put("key", -1.1), - new Metadata().put("key", 0.0), - new Metadata().put("key", 1.11), - new Metadata().put("key", 2.2), - new Metadata().put("key2", 1.1) - ), - asList(new Metadata().put("key", 1.1), new Metadata().put("key", 1.1).put("key2", 2.2)), - true - ) - ) - // === NotIn === + // === Not === + .add(Arguments.of( + not( + metadataKey("name").isEqualTo("Klaus") + ), + asList( + new Metadata().put("name", "Alice"), + new Metadata().put("age", 42), + new Metadata() + ), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42) + ) + )) - // NotIn: string - .add( - Arguments.of( - metadataKey("name").isNotIn("Klaus"), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Alice"), - new Metadata().put("name2", "Klaus") - ), - asList(new Metadata().put("name", "Klaus"), new Metadata().put("name", "Klaus").put("age", 42)), - true - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn(singletonList("Klaus")), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Alice"), - new Metadata().put("name2", "Klaus") - ), - asList(new Metadata().put("name", "Klaus"), new Metadata().put("name", "Klaus").put("age", 42)), - true - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn("Klaus", "Alice"), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Zoe"), - new Metadata().put("name2", "Klaus") - ), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42) - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn(asList("Klaus", "Alice")), - asList( - new Metadata().put("name", "Klaus Heisler"), - new Metadata().put("name", "Zoe"), - new Metadata().put("name2", "Klaus") - ), - asList( - new Metadata().put("name", "Klaus"), - new Metadata().put("name", "Klaus").put("age", 42), - new Metadata().put("name", "Alice"), - new Metadata().put("name", "Alice").put("age", 42) - ), - true - ) - ) - // NotIn: UUID - .add( - Arguments.of( - metadataKey("name").isNotIn(TEST_UUID), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)), - asList(new Metadata().put("name", TEST_UUID), new Metadata().put("name", TEST_UUID).put("age", 42)), - true - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn(singletonList(TEST_UUID)), - asList( - new Metadata().put("name", UUID.randomUUID()), - new Metadata().put("name", TEST_UUID2), - new Metadata().put("name2", TEST_UUID) - ), - asList(new Metadata().put("name", TEST_UUID), new Metadata().put("name", TEST_UUID).put("age", 42)), - true - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn(TEST_UUID, TEST_UUID2), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)), - asList( - new Metadata().put("name", TEST_UUID), - new Metadata().put("name", TEST_UUID).put("age", 42), - new Metadata().put("name", TEST_UUID2), - new Metadata().put("name", TEST_UUID2).put("age", 42) - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("name").isNotIn(asList(TEST_UUID, TEST_UUID2)), - asList(new Metadata().put("name", UUID.randomUUID()), new Metadata().put("name2", TEST_UUID)), - asList( - new Metadata().put("name", TEST_UUID), - new Metadata().put("name", TEST_UUID).put("age", 42), - new Metadata().put("name", TEST_UUID2), - new Metadata().put("name", TEST_UUID2).put("age", 42) - ), - true - ) - ) - // NotIn: int - .add( - Arguments.of( - metadataKey("age").isNotIn(42), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)), - asList(new Metadata().put("age", 42), new Metadata().put("age", 42).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(singletonList(42)), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)), - asList(new Metadata().put("age", 42), new Metadata().put("age", 42).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(42, 18), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)), - asList( - new Metadata().put("age", 42), - new Metadata().put("age", 18), - new Metadata().put("age", 42).put("name", "Klaus"), - new Metadata().put("age", 18).put("name", "Klaus") - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(asList(42, 18)), - asList(new Metadata().put("age", 666), new Metadata().put("age2", 42)), - asList( - new Metadata().put("age", 42), - new Metadata().put("age", 18), - new Metadata().put("age", 42).put("name", "Klaus"), - new Metadata().put("age", 18).put("name", "Klaus") - ), - true - ) - ) - // NotIn: long - .add( - Arguments.of( - metadataKey("age").isNotIn(42L), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)), - asList(new Metadata().put("age", 42L), new Metadata().put("age", 42L).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(singletonList(42L)), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)), - asList(new Metadata().put("age", 42L), new Metadata().put("age", 42L).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(42L, 18L), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)), - asList( - new Metadata().put("age", 42L), - new Metadata().put("age", 18L), - new Metadata().put("age", 42L).put("name", "Klaus"), - new Metadata().put("age", 18L).put("name", "Klaus") - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(asList(42L, 18L)), - asList(new Metadata().put("age", 666L), new Metadata().put("age2", 42L)), - asList( - new Metadata().put("age", 42L), - new Metadata().put("age", 18L), - new Metadata().put("age", 42L).put("name", "Klaus"), - new Metadata().put("age", 18L).put("name", "Klaus") - ), - true - ) - ) - // NotIn: float - .add( - Arguments.of( - metadataKey("age").isNotIn(42.0f), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)), - asList(new Metadata().put("age", 42.0f), new Metadata().put("age", 42.0f).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(singletonList(42.0f)), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)), - asList(new Metadata().put("age", 42.0f), new Metadata().put("age", 42.0f).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(42.0f, 18.0f), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)), - asList( - new Metadata().put("age", 42.0f), - new Metadata().put("age", 18.0f), - new Metadata().put("age", 42.0f).put("name", "Klaus"), - new Metadata().put("age", 18.0f).put("name", "Klaus") - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(asList(42.0f, 18.0f)), - asList(new Metadata().put("age", 666.0f), new Metadata().put("age2", 42.0f)), - asList( - new Metadata().put("age", 42.0f), - new Metadata().put("age", 18.0f), - new Metadata().put("age", 42.0f).put("name", "Klaus"), - new Metadata().put("age", 18.0f).put("name", "Klaus") - ), - true - ) - ) - // NotIn: double - .add( - Arguments.of( - metadataKey("age").isNotIn(42.0d), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)), - asList(new Metadata().put("age", 42.0d), new Metadata().put("age", 42.0d).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(singletonList(42.0d)), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)), - asList(new Metadata().put("age", 42.0d), new Metadata().put("age", 42.0d).put("name", "Klaus")), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(42.0d, 18.0d), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)), - asList( - new Metadata().put("age", 42.0d), - new Metadata().put("age", 18.0d), - new Metadata().put("age", 42.0d).put("name", "Klaus"), - new Metadata().put("age", 18.0d).put("name", "Klaus") - ), - true - ) - ) - .add( - Arguments.of( - metadataKey("age").isNotIn(asList(42.0d, 18.0d)), - asList(new Metadata().put("age", 666.0d), new Metadata().put("age2", 42.0d)), - asList( - new Metadata().put("age", 42.0d), - new Metadata().put("age", 18.0d), - new Metadata().put("age", 42.0d).put("name", "Klaus"), - new Metadata().put("age", 18.0d).put("name", "Klaus") - ), - true - ) - ) - .build(); + + // === NotEqual === + + .add(Arguments.of( + metadataKey("key").isNotEqualTo("a"), + asList( + new Metadata().put("key", "A"), + new Metadata().put("key", "b"), + new Metadata().put("key", "aa"), + new Metadata().put("key", "a a"), + new Metadata().put("key2", "a"), + new Metadata() + ), + asList( + new Metadata().put("key", "a"), + new Metadata().put("key", "a").put("key2", "b") + ) + )) + .add(Arguments.of( + metadataKey("key").isNotEqualTo(TEST_UUID), + asList( + new Metadata().put("key", UUID.randomUUID()), + new Metadata().put("key2", TEST_UUID), + new Metadata() + ), + asList( + new Metadata().put("key", TEST_UUID), + new Metadata().put("key", TEST_UUID).put("key2", UUID.randomUUID()) + ) + )) + .add(Arguments.of( + metadataKey("key").isNotEqualTo(1), + asList( + new Metadata().put("key", -1), + new Metadata().put("key", 0), + new Metadata().put("key", 2), + new Metadata().put("key", 10), + new Metadata().put("key2", 1), + new Metadata() + ), + asList( + new Metadata().put("key", 1), + new Metadata().put("key", 1).put("key2", 2) + ) + )) + .add(Arguments.of( + metadataKey("key").isNotEqualTo(1L), + asList( + new Metadata().put("key", -1L), + new Metadata().put("key", 0L), + new Metadata().put("key", 2L), + new Metadata().put("key", 10L), + new Metadata().put("key2", 1L), + new Metadata() + ), + asList( + new Metadata().put("key", 1L), + new Metadata().put("key", 1L).put("key2", 2L) + ) + )) + .add(Arguments.of( + metadataKey("key").isNotEqualTo(1.1f), + asList( + new Metadata().put("key", -1.1f), + new Metadata().put("key", 0.0f), + new Metadata().put("key", 1.11f), + new Metadata().put("key", 2.2f), + new Metadata().put("key2", 1.1f), + new Metadata() + ), + asList( + new Metadata().put("key", 1.1f), + new Metadata().put("key", 1.1f).put("key2", 2.2f) + ) + )) + .add(Arguments.of( + metadataKey("key").isNotEqualTo(1.1), + asList( + new Metadata().put("key", -1.1), + new Metadata().put("key", 0.0), + new Metadata().put("key", 1.11), + new Metadata().put("key", 2.2), + new Metadata().put("key2", 1.1), + new Metadata() + ), + asList( + new Metadata().put("key", 1.1), + new Metadata().put("key", 1.1).put("key2", 2.2) + ) + )) + + + // === NotIn === + + // NotIn: string + .add(Arguments.of( + metadataKey("name").isNotIn("Klaus"), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Alice"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42) + ) + )) + .add(Arguments.of( + metadataKey("name").isNotIn(singletonList("Klaus")), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Alice"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42) + ) + )) + .add(Arguments.of( + metadataKey("name").isNotIn("Klaus", "Alice"), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Zoe"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Alice"), + new Metadata().put("name", "Alice").put("age", 42) + ) + )) + .add(Arguments.of( + metadataKey("name").isNotIn(asList("Klaus", "Alice")), + asList( + new Metadata().put("name", "Klaus Heisler"), + new Metadata().put("name", "Zoe"), + new Metadata().put("name2", "Klaus"), + new Metadata() + ), + asList( + new Metadata().put("name", "Klaus"), + new Metadata().put("name", "Klaus").put("age", 42), + new Metadata().put("name", "Alice"), + new Metadata().put("name", "Alice").put("age", 42) + ) + )) + + // NotIn: UUID + .add(Arguments.of( + metadataKey("name").isNotIn(TEST_UUID), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42) + ) + )) + .add(Arguments.of( + metadataKey("name").isNotIn(singletonList(TEST_UUID)), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name", TEST_UUID2), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42) + ) + )) + .add(Arguments.of( + metadataKey("name").isNotIn(TEST_UUID, TEST_UUID2), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42), + new Metadata().put("name", TEST_UUID2), + new Metadata().put("name", TEST_UUID2).put("age", 42) + ) + )) + .add(Arguments.of( + metadataKey("name").isNotIn(asList(TEST_UUID, TEST_UUID2)), + asList( + new Metadata().put("name", UUID.randomUUID()), + new Metadata().put("name2", TEST_UUID), + new Metadata() + ), + asList( + new Metadata().put("name", TEST_UUID), + new Metadata().put("name", TEST_UUID).put("age", 42), + new Metadata().put("name", TEST_UUID2), + new Metadata().put("name", TEST_UUID2).put("age", 42) + ) + )) + + // NotIn: int + .add(Arguments.of( + metadataKey("age").isNotIn(42), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(singletonList(42)), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 42).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(42, 18), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 18), + new Metadata().put("age", 42).put("name", "Klaus"), + new Metadata().put("age", 18).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(asList(42, 18)), + asList( + new Metadata().put("age", 666), + new Metadata().put("age2", 42), + new Metadata() + ), + asList( + new Metadata().put("age", 42), + new Metadata().put("age", 18), + new Metadata().put("age", 42).put("name", "Klaus"), + new Metadata().put("age", 18).put("name", "Klaus") + ) + )) + + // NotIn: long + .add(Arguments.of( + metadataKey("age").isNotIn(42L), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 42L).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(singletonList(42L)), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 42L).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(42L, 18L), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 18L), + new Metadata().put("age", 42L).put("name", "Klaus"), + new Metadata().put("age", 18L).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(asList(42L, 18L)), + asList( + new Metadata().put("age", 666L), + new Metadata().put("age2", 42L), + new Metadata() + ), + asList( + new Metadata().put("age", 42L), + new Metadata().put("age", 18L), + new Metadata().put("age", 42L).put("name", "Klaus"), + new Metadata().put("age", 18L).put("name", "Klaus") + ) + )) + + // NotIn: float + .add(Arguments.of( + metadataKey("age").isNotIn(42.0f), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(singletonList(42.0f)), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(42.0f, 18.0f), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 18.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus"), + new Metadata().put("age", 18.0f).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(asList(42.0f, 18.0f)), + asList( + new Metadata().put("age", 666.0f), + new Metadata().put("age2", 42.0f), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0f), + new Metadata().put("age", 18.0f), + new Metadata().put("age", 42.0f).put("name", "Klaus"), + new Metadata().put("age", 18.0f).put("name", "Klaus") + ) + )) + + // NotIn: double + .add(Arguments.of( + metadataKey("age").isNotIn(42.0d), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(singletonList(42.0d)), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(42.0d, 18.0d), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 18.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus"), + new Metadata().put("age", 18.0d).put("name", "Klaus") + ) + )) + .add(Arguments.of( + metadataKey("age").isNotIn(asList(42.0d, 18.0d)), + asList( + new Metadata().put("age", 666.0d), + new Metadata().put("age2", 42.0d), + new Metadata() + ), + asList( + new Metadata().put("age", 42.0d), + new Metadata().put("age", 18.0d), + new Metadata().put("age", 42.0d).put("name", "Klaus"), + new Metadata().put("age", 18.0d).put("name", "Klaus") + ) + )) + + .build(); } } diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java index e9213c2fe..3ed1b5168 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithRemovalIT.java @@ -152,10 +152,17 @@ public abstract class EmbeddingStoreWithRemovalIT { awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).isEmpty()); } + protected void awaitUntilAsserted(ThrowingRunnable assertion) { + Awaitility.await() + .atMost(Duration.ofSeconds(60)) + .pollDelay(Duration.ofSeconds(0)) + .pollInterval(Duration.ofMillis(300)) + .untilAsserted(assertion); + } + protected List> getAllEmbeddings() { - EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest - .builder() + EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() .queryEmbedding(embeddingModel().embed("test").content()) .maxResults(1000) .build(); @@ -164,11 +171,4 @@ public abstract class EmbeddingStoreWithRemovalIT { return searchResult.matches(); } - - protected static void awaitUntilAsserted(ThrowingRunnable assertion) { - Awaitility.await() - .pollInterval(Duration.ofMillis(500)) - .atMost(Duration.ofSeconds(15)) - .untilAsserted(assertion); - } } diff --git a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java index 59fa39881..678c9fead 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.java @@ -3,9 +3,13 @@ package dev.langchain4j.store.embedding; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; +import org.assertj.core.data.Percentage; +import org.awaitility.Awaitility; +import org.awaitility.core.ThrowingRunnable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.time.Duration; import java.util.List; import static dev.langchain4j.internal.Utils.randomUUID; @@ -29,27 +33,30 @@ public abstract class EmbeddingStoreWithoutMetadataIT { } protected void ensureStoreIsEmpty() { - Embedding embedding = embeddingModel().embed("hello").content(); - assertThat(embeddingStore().findRelevant(embedding, 1000)).isEmpty(); + assertThat(getAllEmbeddings()).isEmpty(); } @Test void should_add_embedding() { + // given Embedding embedding = embeddingModel().embed("hello").content(); - String id = embeddingStore().add(embedding); - assertThat(id).isNotBlank(); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1)); + // when List> relevant = embeddingStore().findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); + // then + assertThat(id).isNotBlank(); + assertThat(relevant).hasSize(1); EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.score()).isCloseTo(1, percentage()); assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); + if (assertEmbedding()) { + assertThat(match.embedding()).isEqualTo(embedding); + } assertThat(match.embedded()).isNull(); // new API @@ -62,20 +69,24 @@ public abstract class EmbeddingStoreWithoutMetadataIT { @Test void should_add_embedding_with_id() { + // given String id = randomUUID(); Embedding embedding = embeddingModel().embed("hello").content(); - embeddingStore().add(id, embedding); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1)); + // when List> relevant = embeddingStore().findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); + // then + assertThat(relevant).hasSize(1); EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.score()).isCloseTo(1, percentage()); assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); + if (assertEmbedding()) { + assertThat(match.embedding()).isEqualTo(embedding); + } assertThat(match.embedded()).isNull(); // new API @@ -88,21 +99,25 @@ public abstract class EmbeddingStoreWithoutMetadataIT { @Test void should_add_embedding_with_segment() { + // given TextSegment segment = TextSegment.from("hello"); Embedding embedding = embeddingModel().embed(segment.text()).content(); - String id = embeddingStore().add(embedding, segment); - assertThat(id).isNotBlank(); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1)); + // when List> relevant = embeddingStore().findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); + // then + assertThat(id).isNotBlank(); + assertThat(relevant).hasSize(1); EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.score()).isCloseTo(1, percentage()); assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); + if (assertEmbedding()) { + assertThat(match.embedding()).isEqualTo(embedding); + } assertThat(match.embedded()).isEqualTo(segment); // new API @@ -115,34 +130,41 @@ public abstract class EmbeddingStoreWithoutMetadataIT { @Test void should_add_multiple_embeddings() { + // given Embedding firstEmbedding = embeddingModel().embed("hello").content(); Embedding secondEmbedding = embeddingModel().embed("hi").content(); - List ids = embeddingStore().addAll(asList(firstEmbedding, secondEmbedding)); + + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2)); + + // when + List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); + + // then assertThat(ids).hasSize(2); assertThat(ids.get(0)).isNotBlank(); assertThat(ids.get(1)).isNotBlank(); assertThat(ids.get(0)).isNotEqualTo(ids.get(1)); - awaitUntilPersisted(); - - List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.score()).isCloseTo(1, percentage()); assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + if (assertEmbedding()) { + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + } assertThat(firstMatch.embedded()).isNull(); EmbeddingMatch secondMatch = relevant.get(1); assertThat(secondMatch.score()).isCloseTo( RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)), - withPercentage(1) + percentage() ); assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding)) - .isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it + if (assertEmbedding()) { + assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding)) + .isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it + } assertThat(secondMatch.embedded()).isNull(); // new API @@ -155,6 +177,7 @@ public abstract class EmbeddingStoreWithoutMetadataIT { @Test void should_add_multiple_embeddings_with_segments() { + // given TextSegment firstSegment = TextSegment.from("hello"); Embedding firstEmbedding = embeddingModel().embed(firstSegment.text()).content(); @@ -165,30 +188,37 @@ public abstract class EmbeddingStoreWithoutMetadataIT { asList(firstEmbedding, secondEmbedding), asList(firstSegment, secondSegment) ); + + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2)); + + // when + List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); + + // then assertThat(ids).hasSize(2); assertThat(ids.get(0)).isNotBlank(); assertThat(ids.get(1)).isNotBlank(); assertThat(ids.get(0)).isNotEqualTo(ids.get(1)); - awaitUntilPersisted(); - - List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.score()).isCloseTo(1, percentage()); assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + if (assertEmbedding()) { + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + } assertThat(firstMatch.embedded()).isEqualTo(firstSegment); EmbeddingMatch secondMatch = relevant.get(1); assertThat(secondMatch.score()).isCloseTo( RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)), - withPercentage(1) + percentage() ); assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding)) - .isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it + if (assertEmbedding()) { + assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding)) + .isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it + } assertThat(secondMatch.embedded()).isEqualTo(secondSegment); // new API @@ -201,6 +231,7 @@ public abstract class EmbeddingStoreWithoutMetadataIT { @Test void should_find_with_min_score() { + // given String firstId = randomUUID(); Embedding firstEmbedding = embeddingModel().embed("hello").content(); embeddingStore().add(firstId, firstEmbedding); @@ -209,33 +240,41 @@ public abstract class EmbeddingStoreWithoutMetadataIT { Embedding secondEmbedding = embeddingModel().embed("hi").content(); embeddingStore().add(secondId, secondEmbedding); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2)); + // when List> relevant = embeddingStore().findRelevant(firstEmbedding, 10); + + // then assertThat(relevant).hasSize(2); EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.score()).isCloseTo(1, percentage()); assertThat(firstMatch.embeddingId()).isEqualTo(firstId); EmbeddingMatch secondMatch = relevant.get(1); assertThat(secondMatch.score()).isCloseTo( RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)), - withPercentage(1) + percentage() ); assertThat(secondMatch.embeddingId()).isEqualTo(secondId); + // new API assertThat(embeddingStore().search(EmbeddingSearchRequest.builder() .queryEmbedding(firstEmbedding) .maxResults(10) .build()).matches()).isEqualTo(relevant); + // when List> relevant2 = embeddingStore().findRelevant( firstEmbedding, 10, secondMatch.score() - 0.01 ); + + // then assertThat(relevant2).hasSize(2); assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); + // new API assertThat(embeddingStore().search(EmbeddingSearchRequest.builder() .queryEmbedding(firstEmbedding) @@ -243,14 +282,18 @@ public abstract class EmbeddingStoreWithoutMetadataIT { .minScore(secondMatch.score() - 0.01) .build()).matches()).isEqualTo(relevant2); + // when List> relevant3 = embeddingStore().findRelevant( firstEmbedding, 10, secondMatch.score() ); + + // then assertThat(relevant3).hasSize(2); assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); + // new API assertThat(embeddingStore().search(EmbeddingSearchRequest.builder() .queryEmbedding(firstEmbedding) @@ -258,13 +301,17 @@ public abstract class EmbeddingStoreWithoutMetadataIT { .minScore(secondMatch.score()) .build()).matches()).isEqualTo(relevant3); + // when List> relevant4 = embeddingStore().findRelevant( firstEmbedding, 10, secondMatch.score() + 0.01 ); + + // then assertThat(relevant4).hasSize(1); assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); + // new API assertThat(embeddingStore().search(EmbeddingSearchRequest.builder() .queryEmbedding(firstEmbedding) @@ -276,22 +323,25 @@ public abstract class EmbeddingStoreWithoutMetadataIT { @Test void should_return_correct_score() { + // given Embedding embedding = embeddingModel().embed("hello").content(); String id = embeddingStore().add(embedding); assertThat(id).isNotBlank(); - awaitUntilPersisted(); + awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1)); Embedding referenceEmbedding = embeddingModel().embed("hi").content(); + // when List> relevant = embeddingStore().findRelevant(referenceEmbedding, 1); - assertThat(relevant).hasSize(1); + // then + assertThat(relevant).hasSize(1); EmbeddingMatch match = relevant.get(0); assertThat(match.score()).isCloseTo( RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), - withPercentage(1) + percentage() ); // new API @@ -301,7 +351,31 @@ public abstract class EmbeddingStoreWithoutMetadataIT { .build()).matches()).isEqualTo(relevant); } - protected void awaitUntilPersisted() { - // not waiting by default + protected void awaitUntilAsserted(ThrowingRunnable assertion) { + Awaitility.await() + .atMost(Duration.ofSeconds(60)) + .pollDelay(Duration.ofSeconds(0)) + .pollInterval(Duration.ofMillis(300)) + .untilAsserted(assertion); + } + + protected List> getAllEmbeddings() { + + EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddingModel().embed("test").content()) + .maxResults(1000) + .build(); + + EmbeddingSearchResult searchResult = embeddingStore().search(embeddingSearchRequest); + + return searchResult.matches(); + } + + protected boolean assertEmbedding() { + return true; + } + + protected Percentage percentage() { + return withPercentage(1); } } diff --git a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/AbstractElasticsearchEmbeddingStoreIT.java b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/AbstractElasticsearchEmbeddingStoreIT.java index efcd43735..40f3665a5 100644 --- a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/AbstractElasticsearchEmbeddingStoreIT.java +++ b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/AbstractElasticsearchEmbeddingStoreIT.java @@ -5,7 +5,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; -import lombok.SneakyThrows; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -42,6 +41,7 @@ abstract class AbstractElasticsearchEmbeddingStoreIT extends EmbeddingStoreWithF } abstract ElasticsearchConfiguration withConfiguration(); + void optionallyCreateIndex(String indexName) throws IOException { } @@ -80,10 +80,4 @@ abstract class AbstractElasticsearchEmbeddingStoreIT extends EmbeddingStoreWithF protected void ensureStoreIsEmpty() { // TODO fix } - - @Override - @SneakyThrows - protected void awaitUntilPersisted() { - elasticsearchClientHelper.client.indices().refresh(rr -> rr.index(indexName).ignoreUnavailable(true)); - } } diff --git a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchKnnEmbeddingStoreIT.java b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreKnnIT.java similarity index 96% rename from langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchKnnEmbeddingStoreIT.java rename to langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreKnnIT.java index da57f107b..8f2058c5b 100644 --- a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchKnnEmbeddingStoreIT.java +++ b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreKnnIT.java @@ -4,7 +4,7 @@ import co.elastic.clients.transport.endpoints.BooleanResponse; import java.io.IOException; -class ElasticsearchKnnEmbeddingStoreIT extends AbstractElasticsearchEmbeddingStoreIT { +class ElasticsearchEmbeddingStoreKnnIT extends AbstractElasticsearchEmbeddingStoreIT { @Override ElasticsearchConfiguration withConfiguration() { diff --git a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreScriptIT.java similarity index 68% rename from langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java rename to langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreScriptIT.java index 2febf0e88..b83c6e8f8 100644 --- a/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreIT.java +++ b/langchain4j-elasticsearch/src/test/java/dev/langchain4j/store/embedding/elasticsearch/ElasticsearchEmbeddingStoreScriptIT.java @@ -1,6 +1,7 @@ package dev.langchain4j.store.embedding.elasticsearch; -class ElasticsearchEmbeddingStoreIT extends AbstractElasticsearchEmbeddingStoreIT { +class ElasticsearchEmbeddingStoreScriptIT extends AbstractElasticsearchEmbeddingStoreIT { + @Override ElasticsearchConfiguration withConfiguration() { return ElasticsearchConfigurationScript.builder().build(); diff --git a/langchain4j-infinispan/pom.xml b/langchain4j-infinispan/pom.xml index a13bb2065..078a3e79e 100644 --- a/langchain4j-infinispan/pom.xml +++ b/langchain4j-infinispan/pom.xml @@ -101,6 +101,13 @@ infinispan-server-testdriver-core test + + + org.awaitility + awaitility + test + + diff --git a/langchain4j-mongodb-atlas/pom.xml b/langchain4j-mongodb-atlas/pom.xml index b7526e45c..c1ab76fe7 100644 --- a/langchain4j-mongodb-atlas/pom.xml +++ b/langchain4j-mongodb-atlas/pom.xml @@ -97,6 +97,13 @@ slf4j-tinylog test + + + org.awaitility + awaitility + test + + diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java index d80eb8911..34bc4a493 100644 --- a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreCloudIT.java @@ -6,11 +6,10 @@ import com.mongodb.client.MongoClients; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Filters; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; -import lombok.SneakyThrows; import org.bson.codecs.configuration.CodecRegistry; import org.bson.codecs.pojo.PojoCodecProvider; import org.bson.conversions.Bson; @@ -78,10 +77,4 @@ class MongoDbEmbeddingStoreCloudIT extends EmbeddingStoreIT { Bson filter = Filters.exists("embedding"); collection.deleteMany(filter); } - - @Override - @SneakyThrows - protected void awaitUntilPersisted() { - Thread.sleep(3000); - } } diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java index a2efc2fe6..b7a28e91c 100644 --- a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreLocalIT.java @@ -1,13 +1,17 @@ package dev.langchain4j.store.embedding.mongodb; -import com.mongodb.*; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.ServerApi; +import com.mongodb.ServerApiVersion; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.MongoCollection; import com.mongodb.client.model.Filters; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; import lombok.SneakyThrows; @@ -91,10 +95,4 @@ class MongoDbEmbeddingStoreLocalIT extends EmbeddingStoreIT { Bson filter = Filters.exists("embedding"); collection.deleteMany(filter); } - - @Override - @SneakyThrows - protected void awaitUntilPersisted() { - Thread.sleep(2000); - } } diff --git a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreFilterIT.java b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreNativeFilterIT.java similarity index 91% rename from langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreFilterIT.java rename to langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreNativeFilterIT.java index 083438029..0970fd9bb 100644 --- a/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreFilterIT.java +++ b/langchain4j-mongodb-atlas/src/test/java/dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStoreNativeFilterIT.java @@ -1,14 +1,18 @@ package dev.langchain4j.store.embedding.mongodb; -import com.mongodb.*; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.ServerApi; +import com.mongodb.ServerApiVersion; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.model.Filters; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; import lombok.SneakyThrows; @@ -23,7 +27,7 @@ import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; -class MongoDbEmbeddingStoreFilterIT { +class MongoDbEmbeddingStoreNativeFilterIT { static MongoDBAtlasContainer mongodb = new MongoDBAtlasContainer(); @@ -68,6 +72,8 @@ class MongoDbEmbeddingStoreFilterIT { @Test void should_find_relevant_with_filter() { + + // given TextSegment segment = TextSegment.from("this segment should be found", Metadata.from("test-key", "test-value")); Embedding embedding = embeddingModel.embed(segment.text()).content(); @@ -75,8 +81,7 @@ class MongoDbEmbeddingStoreFilterIT { Embedding filterEmbedding = embeddingModel.embed(filterSegment.text()).content(); List ids = embeddingStore.addAll(asList(embedding, filterEmbedding), asList(segment, filterSegment)); - assertThat(ids) - .hasSize(2); + assertThat(ids).hasSize(2); TextSegment refSegment = TextSegment.from("find a segment"); Embedding refEmbedding = embeddingModel.embed(refSegment.text()).content(); @@ -84,7 +89,8 @@ class MongoDbEmbeddingStoreFilterIT { awaitUntilPersisted(); List> relevant = embeddingStore.findRelevant(refEmbedding, 2); - // Only segment should be found, filterSegment should be filtered + + // then assertThat(relevant).hasSize(1); EmbeddingMatch match = relevant.get(0); @@ -95,7 +101,7 @@ class MongoDbEmbeddingStoreFilterIT { } @SneakyThrows - protected void awaitUntilPersisted() { + private void awaitUntilPersisted() { Thread.sleep(2000); } } diff --git a/langchain4j-neo4j/pom.xml b/langchain4j-neo4j/pom.xml index 7fdec4a1b..6b778b101 100644 --- a/langchain4j-neo4j/pom.xml +++ b/langchain4j-neo4j/pom.xml @@ -104,6 +104,13 @@ mockito-junit-jupiter test + + + org.awaitility + awaitility + test + + diff --git a/langchain4j-neo4j/src/test/java/dev/langchain4j/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java b/langchain4j-neo4j/src/test/java/dev/langchain4j/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java index 6dfe9e55d..1af381e38 100644 --- a/langchain4j-neo4j/src/test/java/dev/langchain4j/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java +++ b/langchain4j-neo4j/src/test/java/dev/langchain4j/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java @@ -426,12 +426,6 @@ class Neo4jEmbeddingStoreIT { assertThat(rowsBatched.get(0)).hasSize(1); } - @Test - void test_row_batches_empty() { - List>> rowsBatched = getListRowsBatched(0); - assertThat(rowsBatched).isEmpty(); - } - @Test void test_row_batches_10000_elements() { List>> rowsBatched = getListRowsBatched(10000); diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAwsIT.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAwsIT.java index be71aaa3b..0f5b59792 100644 --- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAwsIT.java +++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAwsIT.java @@ -2,11 +2,10 @@ package dev.langchain4j.store.embedding.opensearch; import com.jayway.jsonpath.JsonPath; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; -import lombok.SneakyThrows; import net.minidev.json.JSONArray; import org.junit.jupiter.api.BeforeAll; import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; @@ -74,11 +73,4 @@ class OpenSearchEmbeddingStoreAwsIT extends EmbeddingStoreIT { protected void ensureStoreIsEmpty() { // TODO fix } - - @Override - @SneakyThrows - protected void awaitUntilPersisted() { - Thread.sleep(1000); - } - } diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalIT.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalIT.java index ccb5fd00a..947dc2c9c 100644 --- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalIT.java +++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalIT.java @@ -1,11 +1,10 @@ package dev.langchain4j.store.embedding.opensearch; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreIT; -import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeAll; import org.opensearch.testcontainers.OpensearchContainer; import org.testcontainers.junit.jupiter.Container; @@ -51,10 +50,4 @@ class OpenSearchEmbeddingStoreLocalIT extends EmbeddingStoreIT { protected void ensureStoreIsEmpty() { // TODO fix } - - @Override - @SneakyThrows - protected void awaitUntilPersisted() { - Thread.sleep(1000); - } } diff --git a/langchain4j-pinecone/src/test/java/dev/langchain4j/store/embedding/pinecone/PineconeEmbeddingStoreIT.java b/langchain4j-pinecone/src/test/java/dev/langchain4j/store/embedding/pinecone/PineconeEmbeddingStoreIT.java index b27a7347f..e3135092d 100644 --- a/langchain4j-pinecone/src/test/java/dev/langchain4j/store/embedding/pinecone/PineconeEmbeddingStoreIT.java +++ b/langchain4j-pinecone/src/test/java/dev/langchain4j/store/embedding/pinecone/PineconeEmbeddingStoreIT.java @@ -11,7 +11,6 @@ import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; -import lombok.SneakyThrows; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; @@ -54,12 +53,6 @@ class PineconeEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { return embeddingModel; } - @Override - @SneakyThrows - protected void awaitUntilPersisted() { - Thread.sleep(6000); - } - @ParameterizedTest @MethodSource("should_filter_by_metadata") protected void should_filter_by_metadata(Filter metadataFilter, @@ -68,23 +61,23 @@ class PineconeEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { super.should_filter_by_metadata(metadataFilter, matchingMetadatas, notMatchingMetadatas); } - // in pinecone, compare filter only works with numbers + // in pinecone compare filter only works with numbers protected static Stream should_filter_by_metadata() { - return EmbeddingStoreWithFilteringIT.should_filter_by_metadata().filter( - arguments -> { - Object o = arguments.get()[0]; - if (o instanceof IsLessThan) { - return ((IsLessThan) o).comparisonValue() instanceof Number; - } else if (o instanceof IsLessThanOrEqualTo) { - return ((IsLessThanOrEqualTo) o).comparisonValue() instanceof Number; - } else if (o instanceof IsGreaterThan) { - return ((IsGreaterThan) o).comparisonValue() instanceof Number; - } else if (o instanceof IsGreaterThanOrEqualTo) { - return ((IsGreaterThanOrEqualTo) o).comparisonValue() instanceof Number; - } else { - return true; - } - } - ); + return EmbeddingStoreWithFilteringIT.should_filter_by_metadata() + .filter(arguments -> { + Filter filter = (Filter) arguments.get()[0]; + if (filter instanceof IsLessThan) { + return ((IsLessThan) filter).comparisonValue() instanceof Number; + } else if (filter instanceof IsLessThanOrEqualTo) { + return ((IsLessThanOrEqualTo) filter).comparisonValue() instanceof Number; + } else if (filter instanceof IsGreaterThan) { + return ((IsGreaterThan) filter).comparisonValue() instanceof Number; + } else if (filter instanceof IsGreaterThanOrEqualTo) { + return ((IsGreaterThanOrEqualTo) filter).comparisonValue() instanceof Number; + } else { + return true; + } + } + ); } } \ No newline at end of file diff --git a/langchain4j-qdrant/pom.xml b/langchain4j-qdrant/pom.xml index a3c8a7867..9df59f114 100644 --- a/langchain4j-qdrant/pom.xml +++ b/langchain4j-qdrant/pom.xml @@ -97,6 +97,12 @@ 1.7.1 + + org.awaitility + awaitility + test + + \ No newline at end of file diff --git a/langchain4j-redis/pom.xml b/langchain4j-redis/pom.xml index 56f08bc4c..692fe99bc 100644 --- a/langchain4j-redis/pom.xml +++ b/langchain4j-redis/pom.xml @@ -90,6 +90,12 @@ test + + org.awaitility + awaitility + test + + diff --git a/langchain4j-vearch/pom.xml b/langchain4j-vearch/pom.xml index 817f6b8e3..a081d3a09 100644 --- a/langchain4j-vearch/pom.xml +++ b/langchain4j-vearch/pom.xml @@ -106,6 +106,13 @@ slf4j-tinylog test + + + org.awaitility + awaitility + test + + diff --git a/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java b/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java index e5d99f38c..b00b837d6 100644 --- a/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java +++ b/langchain4j-vearch/src/test/java/dev/langchain4j/store/embedding/vearch/VearchEmbeddingStoreIT.java @@ -195,8 +195,6 @@ class VearchEmbeddingStoreIT extends EmbeddingStoreIT { embeddingStore().add(altEmbedding, altSegment); } - awaitUntilPersisted(); - List> relevant = embeddingStore().findRelevant(embedding, 1); assertThat(relevant).hasSize(1); diff --git a/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreSerializedTest.java b/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreSerializedTest.java index 356122110..bff256b43 100644 --- a/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreSerializedTest.java +++ b/langchain4j/src/test/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStoreSerializedTest.java @@ -1,14 +1,13 @@ package dev.langchain4j.store.embedding.inmemory; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; /** - * Tests if InMemoryEmbeddingStore works correctly after being serialized and deserialized back. - * See awaitUntilPersisted() + * Tests if {@link InMemoryEmbeddingStore} works correctly after being serialized and deserialized back. */ class InMemoryEmbeddingStoreSerializedTest extends EmbeddingStoreWithFilteringIT { @@ -17,14 +16,14 @@ class InMemoryEmbeddingStoreSerializedTest extends EmbeddingStoreWithFilteringIT EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @Override - protected void awaitUntilPersisted() { - String json = embeddingStore.serializeToJson(); - embeddingStore = InMemoryEmbeddingStore.fromJson(json); + protected EmbeddingStore embeddingStore() { + serializeAndDeserialize(); + return embeddingStore; } - @Override - protected EmbeddingStore embeddingStore() { - return embeddingStore; + private void serializeAndDeserialize() { + String json = embeddingStore.serializeToJson(); + embeddingStore = InMemoryEmbeddingStore.fromJson(json); } @Override