EmbeddingStoreIT: use awaitility (#1610)

## Change
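Use Awaitility in `EmbeddingStoreIT`: the `awaitUntilPersisted()` hook is replaced by `awaitUntilAsserted(...)`, which receives the assertion to wait for.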

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core]
This commit is contained in:
LangChain4j 2024-08-22 16:17:53 +02:00 committed by GitHub
parent 2e47b126be
commit 3e6d50ee40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 2202 additions and 2402 deletions

View File

@ -4,25 +4,25 @@ hide_title: false
sidebar_position: 0
---
| Embedding Store | Storing Metadata | Filtering by Metadata | Removing Embeddings |
|---------------------------------------------------------------------------------------|------------------|-----------------------|---------------------|
| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | ✅ |
| [Astra DB](/integrations/embedding-stores/astra-db) | ✅ | | |
| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | ✅ | ✅ |
| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | |
| [Azure CosmosDB NoSQL](/integrations/embedding-stores/azure-cosmos-nosql) | ✅ | | |
| [Cassandra](/integrations/embedding-stores/cassandra) | ✅ | | |
| [Chroma](/integrations/embedding-stores/chroma) | ✅ | ✅ | ✅ |
| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | ✅ |
| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | |
| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | ✅ |
| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | | |
| [Neo4j](/integrations/embedding-stores/neo4j) | ✅ | | |
| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | |
| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ |
| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ |
| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | |
| [Redis](/integrations/embedding-stores/redis) | ✅ | | |
| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | |
| [Vespa](/integrations/embedding-stores/vespa) | | | |
| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ |
| Embedding Store | Storing Metadata | Filtering by Metadata | Removing Embeddings |
|---------------------------------------------------------------------------------------|------------------|----------------------------|---------------------|
| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | ✅ |
| [Astra DB](/integrations/embedding-stores/astra-db) | ✅ | | |
| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | ✅ | ✅ |
| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | |
| [Azure CosmosDB NoSQL](/integrations/embedding-stores/azure-cosmos-nosql) | ✅ | | |
| [Cassandra](/integrations/embedding-stores/cassandra) | ✅ | | |
| [Chroma](/integrations/embedding-stores/chroma) | ✅ | ✅ | ✅ |
| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | ✅ |
| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | |
| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | ✅ |
| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | Only native filter support | |
| [Neo4j](/integrations/embedding-stores/neo4j) | ✅ | | |
| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | |
| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ |
| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ |
| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | |
| [Redis](/integrations/embedding-stores/redis) | ✅ | | |
| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | |
| [Vespa](/integrations/embedding-stores/vespa) | | | |
| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ |

View File

@ -6,12 +6,13 @@ import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import org.awaitility.core.ThrowingRunnable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@ -31,40 +32,37 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverIT.class);
private EmbeddingModel embeddingModel;
private final EmbeddingModel embeddingModel;
private AzureAiSearchContentRetriever contentRetrieverWithVector;
private final AzureAiSearchContentRetriever contentRetrieverWithVector;
private AzureAiSearchContentRetriever contentRetrieverWithFullText;
private AzureAiSearchContentRetriever contentRetrieverWithHybrid;
private final AzureAiSearchContentRetriever contentRetrieverWithHybrid;
private AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking;
private final AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking;
private int dimensions;
public AzureAiSearchContentRetrieverIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
public AzureAiSearchContentRetrieverIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
dimensions = embeddingModel.embed("test").content().vector().length;
SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
.buildClient();
SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
.buildClient();
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME);
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME);
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID);
contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING);
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID);
contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING);
}
private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) {
return AzureAiSearchContentRetriever.builder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
.dimensions(dimensions)
.dimensions(embeddingModel.dimension())
.embeddingModel(embeddingModel)
.queryType(azureAiSearchQueryType)
.maxResults(3)
@ -105,7 +103,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
contentRetrieverWithVector.add(embedding, textSegment);
}
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
String content = "fruit";
Query query = Query.from(content);
@ -140,7 +138,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
contentRetrieverWithVector.add(embedding, textSegment);
}
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
String content = "house";
Query query = Query.from(content);
@ -189,7 +187,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
contentRetrieverWithFullText.add(content);
}
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(contentRetrieverWithFullText.retrieve(Query.from("a"))).hasSize(contents.size()));
Query query = Query.from("Alain");
List<Content> relevant = contentRetrieverWithFullText.retrieve(query);
@ -206,9 +204,9 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
@Test
void testFullTextSearchWithSpecificSearchIndex() {
// This doesn't reuse the existing search index, but creates a specialized one only for full text search
// This doesn't reuse the existing search index, but creates a specialized one only for full text search
contentRetrieverWithVector.deleteIndex();
contentRetrieverWithFullText = AzureAiSearchContentRetriever.builder()
contentRetrieverWithFullText = AzureAiSearchContentRetriever.builder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
.embeddingModel(null)
@ -219,7 +217,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
.build();
testFullTextSearch();
clearStore();
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
}
@Test
@ -247,7 +245,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
contentRetrieverWithHybrid.add(embedding, textSegment);
}
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
Query query = Query.from("Algeria");
List<Content> relevant = contentRetrieverWithHybrid.retrieve(query);
@ -287,7 +285,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
contentRetrieverWithHybridAndReranking.add(embedding, textSegment);
}
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
Query query = Query.from("A philosopher who was in the French Resistance");
List<Content> relevant = contentRetrieverWithHybridAndReranking.retrieve(query);
@ -318,18 +316,24 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
AzureAiSearchContentRetriever azureAiSearchContentRetriever = contentRetrieverWithVector;
try {
azureAiSearchContentRetriever.deleteIndex();
azureAiSearchContentRetriever.createOrUpdateIndex(dimensions);
azureAiSearchContentRetriever.createOrUpdateIndex(embeddingModel.dimension());
} catch (RuntimeException e) {
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
}
}
@Override
protected void awaitUntilPersisted() {
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
super.awaitUntilAsserted(assertion);
try {
Thread.sleep(1_000);
Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
/**
 * Always returns {@code false} for this store.
 * NOTE(review): presumably the base test class consults this flag to decide whether
 * returned embeddings are asserted on; confirm against the base class.
 */
@Override
protected boolean assertEmbedding() {
return false; // TODO remove this hack after https://github.com/langchain4j/langchain4j/issues/1617 is closed
}
}

View File

@ -1,58 +1,34 @@
package dev.langchain4j.rag.content.retriever.azure.search;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import org.junit.jupiter.api.BeforeEach;
import org.awaitility.core.ThrowingRunnable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchQueryType.HYBRID;
@EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+")
public class AzureAiSearchContentRetrieverRemovalIT extends EmbeddingStoreWithRemovalIT {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverRemovalIT.class);
private EmbeddingModel embeddingModel;
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
private AzureAiSearchContentRetriever contentRetrieverWithVector;
public AzureAiSearchContentRetrieverRemovalIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
.buildClient();
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME);
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
}
private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) {
return AzureAiSearchContentRetriever.builder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
.dimensions(embeddingModel.dimension())
.embeddingModel(embeddingModel)
.queryType(azureAiSearchQueryType)
.maxResults(3)
.minScore(0.0)
.build();
}
@BeforeEach
void setUp() {
clearStore();
}
private final AzureAiSearchContentRetriever contentRetrieverWithVector = AzureAiSearchContentRetriever.builder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
.indexName(randomUUID())
.dimensions(embeddingModel.dimension())
.embeddingModel(embeddingModel)
.queryType(HYBRID)
.build();
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
@ -64,14 +40,22 @@ public class AzureAiSearchContentRetrieverRemovalIT extends EmbeddingStoreWithRe
return this.embeddingModel;
}
protected void clearStore() {
log.debug("Deleting the search index");
AzureAiSearchContentRetriever azureAiSearchContentRetriever = contentRetrieverWithVector;
@AfterEach
void afterEach() {
try {
azureAiSearchContentRetriever.deleteIndex();
azureAiSearchContentRetriever.createOrUpdateIndex(embeddingModel.dimension());
contentRetrieverWithVector.deleteIndex();
} catch (RuntimeException e) {
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
log.error("Failed to delete the index. You should look at deleting it manually.", e);
}
}
/**
 * Delegates to the base class's {@code awaitUntilAsserted} and then pauses for one
 * additional second.
 *
 * @param assertion the assertion passed through to the base implementation
 * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
 *         interrupted during the extra pause; the interrupt flag is restored first
 */
@Override
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
    super.awaitUntilAsserted(assertion);
    try {
        Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}
}

View File

@ -6,14 +6,13 @@ import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.SearchField;
import com.azure.search.documents.indexes.models.SearchFieldDataType;
import com.azure.search.documents.indexes.models.SearchIndex;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import org.junit.jupiter.api.BeforeEach;
import org.awaitility.core.ThrowingRunnable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
@ -22,10 +21,9 @@ import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_FIELD_ID;
import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@ -34,31 +32,50 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStoreIT.class);
private EmbeddingModel embeddingModel;
private static final String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");
private static final String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
private EmbeddingStore<TextSegment> embeddingStore;
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
private int dimensions;
private final AzureAiSearchEmbeddingStore embeddingStore = AzureAiSearchEmbeddingStore.builder()
.endpoint(AZURE_SEARCH_ENDPOINT)
.apiKey(AZURE_SEARCH_KEY)
.indexName(randomUUID())
.dimensions(embeddingModel.dimension())
.build();
private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");
private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
public AzureAiSearchEmbeddingStoreIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
dimensions = embeddingModel.embed("test").content().vector().length;
embeddingStore = AzureAiSearchEmbeddingStore.builder()
.endpoint(AZURE_SEARCH_ENDPOINT)
.apiKey(AZURE_SEARCH_KEY)
.dimensions(dimensions)
.build();
/**
 * Returns the embedding store instance held by this test class.
 */
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}
@BeforeEach
void setUp() {
clearStore();
/**
 * Returns the embedding model instance held by this test class.
 */
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
/**
 * Best-effort cleanup executed after each test: deletes the index behind this test's
 * embedding store. A failed deletion is logged instead of being rethrown.
 */
@AfterEach
void afterEach() {
    try {
        this.embeddingStore.deleteIndex();
    } catch (RuntimeException deletionFailure) {
        // Deliberately not rethrown: cleanup problems are only reported in the log.
        log.error("Failed to delete the index. You should look at deleting it manually.", deletionFailure);
    }
}
/**
 * Delegates to the base class's {@code awaitUntilAsserted} and then pauses for one
 * additional second.
 *
 * @param assertion the assertion passed through to the base implementation
 * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
 *         interrupted during the extra pause; the interrupt flag is restored first
 */
@Override
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
    super.awaitUntilAsserted(assertion);
    try {
        Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}
/**
 * Always returns {@code false} for this store.
 * NOTE(review): presumably the base test class consults this flag to decide whether
 * returned embeddings are asserted on; confirm against the base class.
 */
@Override
protected boolean assertEmbedding() {
return false; // TODO remove this hack after https://github.com/langchain4j/langchain4j/issues/1617 is closed
}
@Test
@ -89,7 +106,7 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT
try {
new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT,
new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME", null);
new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME", null);
fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
@ -102,71 +119,16 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT
@Test
public void when_an_index_is_not_provided_the_default_name_is_used() {
AzureAiSearchEmbeddingStore store =new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT,
new AzureKeyCredential(AZURE_SEARCH_KEY), false, null, null, null);
AzureAiSearchEmbeddingStore store = new AzureAiSearchEmbeddingStore(
AZURE_SEARCH_ENDPOINT,
new AzureKeyCredential(AZURE_SEARCH_KEY),
false,
null,
null,
null
);
assertEquals(DEFAULT_INDEX_NAME, store.searchClient.getIndexName());
}
@Test
void test_add_embeddings_and_find_relevant() {
String content1 = "banana";
String content2 = "computer";
String content3 = "apple";
String content4 = "pizza";
String content5 = "strawberry";
String content6 = "chess";
List<String> contents = asList(content1, content2, content3, content4, content5, content6);
for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
embeddingStore.add(embedding, textSegment);
}
awaitUntilPersisted();
Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
assertThat(relevant).hasSize(3);
assertThat(relevant.get(0).embedding()).isNotNull();
assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5);
log.info("#1 relevant item: {}", relevant.get(0).embedded().text());
assertThat(relevant.get(1).embedding()).isNotNull();
assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5);
log.info("#2 relevant item: {}", relevant.get(1).embedded().text());
assertThat(relevant.get(2).embedding()).isNotNull();
assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5);
log.info("#3 relevant item: {}", relevant.get(2).embedded().text());
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
}
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
@Override
protected void clearStore() {
AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore = (AzureAiSearchEmbeddingStore) embeddingStore;
try {
azureAiSearchEmbeddingStore.deleteIndex();
azureAiSearchEmbeddingStore.createOrUpdateIndex(dimensions);
} catch (RuntimeException e) {
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
}
}
@Override
protected void awaitUntilPersisted() {
try {
Thread.sleep(1_000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -5,49 +5,27 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import org.junit.jupiter.api.BeforeEach;
import org.awaitility.core.ThrowingRunnable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static dev.langchain4j.internal.Utils.randomUUID;
@EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+")
public class AzureAiSearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT {
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStoreRemovalIT.class);
private EmbeddingModel embeddingModel;
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
private EmbeddingStore<TextSegment> embeddingStore;
private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");
private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
public AzureAiSearchEmbeddingStoreRemovalIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
embeddingStore = AzureAiSearchEmbeddingStore.builder()
.endpoint(AZURE_SEARCH_ENDPOINT)
.apiKey(AZURE_SEARCH_KEY)
.dimensions(embeddingModel.dimension())
.build();
}
@BeforeEach
void setUp() {
clearStore();
}
private void clearStore() {
AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore = (AzureAiSearchEmbeddingStore) embeddingStore;
try {
azureAiSearchEmbeddingStore.deleteIndex();
azureAiSearchEmbeddingStore.createOrUpdateIndex(embeddingModel.dimension());
} catch (RuntimeException e) {
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
}
}
private final AzureAiSearchEmbeddingStore embeddingStore = AzureAiSearchEmbeddingStore.builder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
.indexName(randomUUID())
.dimensions(embeddingModel.dimension())
.build();
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
@ -58,4 +36,23 @@ public class AzureAiSearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemo
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
/**
 * Best-effort cleanup executed after each test: deletes the index behind this test's
 * embedding store. A failed deletion is logged instead of being rethrown.
 */
@AfterEach
void afterEach() {
    try {
        this.embeddingStore.deleteIndex();
    } catch (RuntimeException deletionFailure) {
        // Deliberately not rethrown: cleanup problems are only reported in the log.
        log.error("Failed to delete the index. You should look at deleting it manually.", deletionFailure);
    }
}
/**
 * Delegates to the base class's {@code awaitUntilAsserted} and then pauses for one
 * additional second.
 *
 * @param assertion the assertion passed through to the base implementation
 * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
 *         interrupted during the extra pause; the interrupt flag is restored first
 */
@Override
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
    super.awaitUntilAsserted(assertion);
    try {
        Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}
}

View File

@ -85,6 +85,13 @@
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -6,43 +6,30 @@ import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import org.bson.BsonDocument;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
import static org.junit.jupiter.api.Assertions.assertThrows;
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_ENDPOINT", matches = ".+")
public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT {
private static final Logger log = LoggerFactory.getLogger(AzureCosmosDBMongoVCoreEmbeddingStoreIT.class);
private static MongoClient client;
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
private final int dimensions;
public AzureCosmosDBMongoVCoreEmbeddingStoreIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
dimensions = embeddingModel.embed("hello").content().vector().length;
client = MongoClients.create(
MongoClientSettings.builder()
@ -59,45 +46,13 @@ public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT {
.createIndex(true)
.kind("vector-hnsw")
.numLists(2)
.dimensions(dimensions)
.dimensions(embeddingModel.dimension())
.m(16)
.efConstruction(64)
.efSearch(40)
.build();
}
@Test
void testAddEmbeddingsAndFindRelevant() {
String content1 = "banana";
String content2 = "computer";
String content3 = "apple";
String content4 = "pizza";
String content5 = "strawberry";
String content6 = "chess";
List<String> contents = asList(content1, content2, content3, content4, content5, content6);
for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
embeddingStore.add(embedding, textSegment);
}
awaitUntilPersisted();
Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
assertThat(relevant).hasSize(3);
assertThat(relevant.get(0).embedding()).isNotNull();
assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5);
log.info("#1 relevant item: {}", relevant.get(0).embedded().text());
assertThat(relevant.get(1).embedding()).isNotNull();
assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5);
log.info("#2 relevant item: {}", relevant.get(1).embedded().text());
assertThat(relevant.get(2).embedding()).isNotNull();
assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5);
log.info("#3 relevant item: {}", relevant.get(2).embedded().text());
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
@ -108,15 +63,6 @@ public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT {
return embeddingModel;
}
@Override
protected void awaitUntilPersisted() {
try {
Thread.sleep(1_000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
@Override
protected void clearStore() {
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()

View File

@ -70,6 +70,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -3,50 +3,45 @@ package dev.langchain4j.store.embedding.azure.cosmos.nosql;
import com.azure.cosmos.ConsistencyLevel;
import com.azure.cosmos.CosmosClient;
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.CosmosDatabase;
import com.azure.cosmos.models.*;
import dev.langchain4j.data.embedding.Embedding;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosVectorDataType;
import com.azure.cosmos.models.CosmosVectorDistanceFunction;
import com.azure.cosmos.models.CosmosVectorEmbedding;
import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy;
import com.azure.cosmos.models.CosmosVectorIndexSpec;
import com.azure.cosmos.models.CosmosVectorIndexType;
import com.azure.cosmos.models.IncludedPath;
import com.azure.cosmos.models.IndexingMode;
import com.azure.cosmos.models.IndexingPolicy;
import com.azure.cosmos.models.PartitionKeyDefinition;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_HOST", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_MASTER_KEY", matches = ".+")
class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
protected static Logger logger = LoggerFactory.getLogger(AzureCosmosDbNoSqlEmbeddingStoreIT.class);
private static final String DATABASE_NAME = "test_database_langchain_java";
private static final String CONTAINER_NAME = "test_container";
private CosmosClient client;
CosmosDatabase database;
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
private final int dimensions;
private final String HOST = System.getenv("AZURE_COSMOS_HOST");
private final String KEY = System.getenv("AZURE_COSMOS_MASTER_KEY");
private final EmbeddingModel embeddingModel;
public AzureCosmosDbNoSqlEmbeddingStoreIT() {
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
dimensions = embeddingModel.embed("hello").content().vector().length;
int dimensions = embeddingModel.dimension();
client = new CosmosClientBuilder()
.endpoint(HOST)
.key(KEY)
CosmosClient client = new CosmosClientBuilder()
.endpoint(System.getenv("AZURE_COSMOS_HOST"))
.key(System.getenv("AZURE_COSMOS_MASTER_KEY"))
.consistencyLevel(ConsistencyLevel.EVENTUAL)
.contentResponseOnWriteEnabled(true)
.buildClient();
@ -59,43 +54,6 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
.cosmosVectorIndexes(populateVectorIndexSpec())
.containerProperties(populateContainerProperties())
.build();
database = client.getDatabase(DATABASE_NAME);
}
/**
 * Adds six single-word segments to the store, queries with the embedding of "fruit"
 * and expects the three best matches to be the three fruit words.
 * <p>
 * The database is deleted and the client closed in a {@code finally} block so that
 * this cleanup also runs when one of the assertions fails.
 */
@Test
public void testAddEmbeddingsAndFindRelevant() {
    String content1 = "banana";
    String content2 = "computer";
    String content3 = "apple";
    String content4 = "pizza";
    String content5 = "strawberry";
    String content6 = "chess";
    List<String> contents = asList(content1, content2, content3, content4, content5, content6);

    try {
        for (String content : contents) {
            TextSegment textSegment = TextSegment.from(content);
            Embedding embedding = embeddingModel.embed(content).content();
            embeddingStore.add(embedding, textSegment);
        }

        // Fixed wait before querying what was just written.
        awaitUntilPersisted();

        Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
        assertThat(relevant).hasSize(3);

        // Every one of the three matches must carry an embedding and be one of the fruit words.
        for (int rank = 0; rank < relevant.size(); rank++) {
            EmbeddingMatch<TextSegment> match = relevant.get(rank);
            assertThat(match.embedding()).isNotNull();
            assertThat(match.embedded().text()).isIn(content1, content3, content5);
            logger.info("#{} relevant item: {}", rank + 1, match.embedded().text());
        }
    } finally {
        // Best-effort cleanup; must not be skipped by a failing assertion.
        safeDeleteDatabase(database);
        safeClose(client);
    }
}
@Override
@ -108,38 +66,6 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
return embeddingModel;
}
/**
 * Blocks the calling thread for one second.
 *
 * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
 *                          interrupted while sleeping; the interrupt flag is restored first
 */
@Override
protected void awaitUntilPersisted() {
    try {
        Thread.sleep(1_000);
    } catch (InterruptedException e) {
        // Restore the interrupt status so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}
// No-op override: this test class performs no clearing in this hook.
// NOTE(review): presumably cleanup relies on the database being deleted elsewhere - confirm.
@Override
protected void clearStore() {
}
/**
 * Deletes the given database on a best-effort basis.
 * A failure is logged instead of being propagated, so it never fails the calling test.
 *
 * @param database the database to delete; {@code null} is ignored
 */
private void safeDeleteDatabase(CosmosDatabase database) {
    if (database == null) {
        return;
    }
    try {
        database.delete();
    } catch (Exception e) {
        // Still best-effort, but no longer silent: mirror the logging done by safeClose.
        logger.error("failed to delete database", e);
    }
}
/**
 * Closes the given client without letting a failure escape; a failure is only logged.
 *
 * @param client the client to close; nothing happens when it is {@code null}
 */
private void safeClose(CosmosClient client) {
    if (client == null) {
        return;
    }
    try {
        client.close();
    } catch (Exception e) {
        logger.error("failed to close client", e);
    }
}
private CosmosVectorEmbeddingPolicy populateVectorEmbeddingPolicy(int dimensions) {
CosmosVectorEmbeddingPolicy vectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy();
CosmosVectorEmbedding embedding = new CosmosVectorEmbedding();
@ -174,6 +100,4 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
collectionDefinition.setIndexingPolicy(indexingPolicy);
return collectionDefinition;
}
}

View File

@ -112,6 +112,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -3,30 +3,24 @@ package dev.langchain4j.store.embedding.astradb;
import com.dtsx.astra.sdk.AstraDB;
import com.dtsx.astra.sdk.AstraDBAdmin;
import com.dtsx.astra.sdk.AstraDBCollection;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiModelName;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import io.stargate.sdk.data.domain.SimilarityMetric;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.List;
import java.util.UUID;
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
import static org.assertj.core.api.Assertions.assertThat;
@Disabled("AstraDB is not available in the CI")
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
@ -81,25 +75,4 @@ class AstraDbEmbeddingStoreIT extends EmbeddingStoreIT {
}
return embeddingModel;
}
/**
 * Adds one hand-written embedding with a text segment carrying metadata, then searches with a
 * nearby reference embedding and expects exactly that entry back (same id, embedding and segment).
 */
void testAddEmbeddingAndFindRelevant() {
    Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
    TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
    String id = embeddingStore.add(embedding, textSegment);
    // isNotEmpty() rejects both null and "" and reports the actual value on failure,
    // unlike asserting on a pre-computed boolean.
    assertThat(id).isNotEmpty();

    Embedding referenceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
    List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(referenceEmbedding, 1);
    assertThat(embeddingMatches).hasSize(1);

    EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
    assertThat(embeddingMatch.score()).isBetween(0d, 1d);
    assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
    assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
    assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
}
}

View File

@ -9,6 +9,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.UUID;
import static com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric.COSINE;
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
@ -36,12 +37,10 @@ class CassandraEmbeddingStoreAstraIT extends CassandraEmbeddingStoreIT {
.databaseRegion(TEST_REGION)
.keyspace(KEYSPACE)
.table(TEST_INDEX)
.dimension(embeddingModelDimension()) // openai model
.metric(CassandraSimilarityMetric.COSINE)
.dimension(embeddingModel().dimension())
.metric(COSINE)
.build();
}
return embeddingStore;
}
}

View File

@ -1,12 +1,10 @@
package dev.langchain4j.store.embedding.cassandra;
import com.datastax.oss.driver.api.core.CqlSession;
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.testcontainers.DockerClientFactory;
import org.testcontainers.containers.CassandraContainer;
import org.testcontainers.junit.jupiter.Testcontainers;
@ -15,10 +13,11 @@ import org.testcontainers.utility.DockerImageName;
import java.net.InetSocketAddress;
import java.util.Collections;
import static com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric.COSINE;
/**
* Work with Cassandra Embedding Store.
*/
@Disabled("No Docker in the CI")
@Testcontainers
class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
@ -55,7 +54,7 @@ class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
* Stop Cassandra Node
*/
@AfterAll
static void afterTests() throws Exception {
static void afterTests() {
cassandraContainer.stop();
}
@ -69,11 +68,10 @@ class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
.localDataCenter(DATACENTER)
.keyspace(KEYSPACE)
.table(TEST_INDEX)
.dimension(embeddingModelDimension())
.metric(CassandraSimilarityMetric.COSINE)
.dimension(embeddingModel().dimension())
.metric(COSINE)
.build();
}
return embeddingStore;
}
}

View File

@ -4,22 +4,17 @@ import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiModelName;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import dev.langchain4j.store.embedding.RelevanceScore;
import lombok.extern.slf4j.Slf4j;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
@ -28,26 +23,17 @@ import static org.assertj.core.data.Percentage.withPercentage;
abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
protected static final String KEYSPACE = "langchain4j";
protected static final String TEST_INDEX = "test_embedding_store";
CassandraEmbeddingStore embeddingStore;
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName(OpenAiModelName.TEXT_EMBEDDING_ADA_002)
.timeout(Duration.ofSeconds(15))
.build();
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@Override
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
/**
 * Returns the fixed embedding dimension (1536) that the store builders are configured with.
 * NOTE(review): presumably the dimension of the OpenAI text-embedding-ada-002 model - confirm.
 */
protected int embeddingModelDimension() {
return 1536;
}
/**
* It is required to clean the repository in between tests
*/
@ -57,19 +43,16 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
}
@Override
public void awaitUntilPersisted() {
try {
Thread.sleep(1000);
} catch(Exception e) {
}
protected Percentage percentage() {
return withPercentage(6); // TODO figure out why difference is so big
}
@Test
void should_retrieve_inserted_vector_by_ann() {
String sourceSentence = "Testing is doubting !";
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
String sourceSentence = "Testing is doubting !";
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
TextSegment sourceTextSegment = TextSegment.from(sourceSentence);
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
assertThat(id != null && !id.isEmpty()).isTrue();
List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(sourceEmbedding, 10);
@ -84,12 +67,12 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
@Test
void should_retrieve_inserted_vector_by_ann_and_metadata() {
String sourceSentence = "In GOD we trust, everything else we test!";
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
String sourceSentence = "In GOD we trust, everything else we test!";
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
TextSegment sourceTextSegment = TextSegment.from(sourceSentence, new Metadata()
.put("user", "GOD")
.put("test", "false"));
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
assertThat(id != null && !id.isEmpty()).isTrue();
// Should be found with no filter
@ -106,144 +89,4 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
.findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "JOHN"));
assertThat(matchesJohn).isEmpty();
}
// The scores returned by the store were observed to be 1.95% off the expected value,
// hence the tolerance of withPercentage(2).
@Test
void should_return_correct_score() {
    Embedding stored = embeddingModel().embed("hello").content();
    String storedId = embeddingStore().add(stored);
    assertThat(storedId).isNotBlank();

    Embedding query = embeddingModel().embed("hi").content();
    List<EmbeddingMatch<TextSegment>> matches = embeddingStore().findRelevant(query, 1);
    assertThat(matches).hasSize(1);

    EmbeddingMatch<TextSegment> onlyMatch = matches.get(0);
    double expectedScore = RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(stored, query));
    assertThat(onlyMatch.score()).isCloseTo(expectedScore, withPercentage(2));
}
/**
 * Stores the embeddings of "hello" and "hi", then checks how the minimum-score argument of
 * findRelevant cuts the result list just below, exactly at, and just above the second match's score.
 */
@Test
void should_find_with_min_score() {
    String helloId = randomUUID();
    Embedding helloEmbedding = embeddingModel().embed("hello").content();
    embeddingStore().add(helloId, helloEmbedding);

    String hiId = randomUUID();
    Embedding hiEmbedding = embeddingModel().embed("hi").content();
    embeddingStore().add(hiId, hiEmbedding);

    // Without a threshold both entries come back, best match first.
    List<EmbeddingMatch<TextSegment>> unfiltered = embeddingStore().findRelevant(helloEmbedding, 10);
    assertThat(unfiltered).hasSize(2);

    EmbeddingMatch<TextSegment> best = unfiltered.get(0);
    assertThat(best.score()).isCloseTo(1, withPercentage(1));
    assertThat(best.embeddingId()).isEqualTo(helloId);

    EmbeddingMatch<TextSegment> runnerUp = unfiltered.get(1);
    double expectedRunnerUpScore =
            RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(helloEmbedding, hiEmbedding));
    assertThat(runnerUp.score()).isCloseTo(expectedRunnerUpScore, withPercentage(2));
    assertThat(runnerUp.embeddingId()).isEqualTo(hiId);

    // Threshold slightly below the runner-up's score: both entries are still returned.
    List<EmbeddingMatch<TextSegment>> belowThreshold =
            embeddingStore().findRelevant(helloEmbedding, 10, runnerUp.score() - 0.01);
    assertThat(belowThreshold).hasSize(2);
    assertThat(belowThreshold.get(0).embeddingId()).isEqualTo(helloId);
    assertThat(belowThreshold.get(1).embeddingId()).isEqualTo(hiId);

    // Threshold exactly at the runner-up's score: both entries are still returned.
    List<EmbeddingMatch<TextSegment>> atThreshold =
            embeddingStore().findRelevant(helloEmbedding, 10, runnerUp.score());
    assertThat(atThreshold).hasSize(2);
    assertThat(atThreshold.get(0).embeddingId()).isEqualTo(helloId);
    assertThat(atThreshold.get(1).embeddingId()).isEqualTo(hiId);

    // Threshold slightly above the runner-up's score: only the best match remains.
    List<EmbeddingMatch<TextSegment>> aboveThreshold =
            embeddingStore().findRelevant(helloEmbedding, 10, runnerUp.score() + 0.01);
    assertThat(aboveThreshold).hasSize(1);
    assertThat(aboveThreshold.get(0).embeddingId()).isEqualTo(helloId);
}
/**
 * Adds two embeddings together with their text segments in one addAll call and verifies that
 * both come back, ordered by relevance to the first one, with ids, embeddings and segments intact.
 */
@Test
void should_add_multiple_embeddings_with_segments() {
    TextSegment helloSegment = TextSegment.from("hello");
    Embedding helloEmbedding = embeddingModel().embed(helloSegment.text()).content();
    TextSegment hiSegment = TextSegment.from("hi");
    Embedding hiEmbedding = embeddingModel().embed(hiSegment.text()).content();

    List<Embedding> embeddings = asList(helloEmbedding, hiEmbedding);
    List<TextSegment> segments = asList(helloSegment, hiSegment);
    List<String> ids = embeddingStore().addAll(embeddings, segments);

    // Two distinct, non-blank ids are expected.
    assertThat(ids).hasSize(2);
    assertThat(ids.get(0)).isNotBlank();
    assertThat(ids.get(1)).isNotBlank();
    assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

    awaitUntilPersisted();

    List<EmbeddingMatch<TextSegment>> matches = embeddingStore().findRelevant(helloEmbedding, 10);
    assertThat(matches).hasSize(2);

    EmbeddingMatch<TextSegment> best = matches.get(0);
    assertThat(best.score()).isCloseTo(1, withPercentage(1));
    assertThat(best.embeddingId()).isEqualTo(ids.get(0));
    assertThat(best.embedding()).isEqualTo(helloEmbedding);
    assertThat(best.embedded()).isEqualTo(helloSegment);

    EmbeddingMatch<TextSegment> runnerUp = matches.get(1);
    double expectedRunnerUpScore =
            RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(helloEmbedding, hiEmbedding));
    assertThat(runnerUp.score()).isCloseTo(expectedRunnerUpScore, withPercentage(2));
    assertThat(runnerUp.embeddingId()).isEqualTo(ids.get(1));
    assertThat(runnerUp.embedding()).isEqualTo(hiEmbedding);
    assertThat(runnerUp.embedded()).isEqualTo(hiSegment);
}
/**
 * Adds two embeddings without text segments in one addAll call and verifies that both come back,
 * ordered by relevance to the first one, with no embedded segment attached.
 */
@Test
void should_add_multiple_embeddings() {
    Embedding helloEmbedding = embeddingModel().embed("hello").content();
    Embedding hiEmbedding = embeddingModel().embed("hi").content();

    List<Embedding> embeddings = asList(helloEmbedding, hiEmbedding);
    List<String> ids = embeddingStore().addAll(embeddings);

    // Two distinct, non-blank ids are expected.
    assertThat(ids).hasSize(2);
    assertThat(ids.get(0)).isNotBlank();
    assertThat(ids.get(1)).isNotBlank();
    assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

    awaitUntilPersisted();

    List<EmbeddingMatch<TextSegment>> matches = embeddingStore().findRelevant(helloEmbedding, 10);
    assertThat(matches).hasSize(2);

    EmbeddingMatch<TextSegment> best = matches.get(0);
    assertThat(best.score()).isCloseTo(1, withPercentage(2));
    assertThat(best.embeddingId()).isEqualTo(ids.get(0));
    assertThat(best.embedding()).isEqualTo(helloEmbedding);
    assertThat(best.embedded()).isNull();

    EmbeddingMatch<TextSegment> runnerUp = matches.get(1);
    double expectedRunnerUpScore =
            RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(helloEmbedding, hiEmbedding));
    assertThat(runnerUp.score()).isCloseTo(expectedRunnerUpScore, withPercentage(2));
    assertThat(runnerUp.embeddingId()).isEqualTo(ids.get(1));
    assertThat(runnerUp.embedding()).isEqualTo(hiEmbedding);
    assertThat(runnerUp.embedded()).isNull();
}
}

View File

@ -3,7 +3,6 @@ package dev.langchain4j.store.memory.chat.cassandra;
import com.datastax.oss.driver.api.core.CqlSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.testcontainers.DockerClientFactory;
import org.testcontainers.containers.CassandraContainer;
import org.testcontainers.junit.jupiter.Testcontainers;
@ -14,7 +13,6 @@ import java.net.InetSocketAddress;
/**
* Tests the Cassandra chat memory store against a Cassandra node running in a Docker container (Testcontainers).
*/
@Disabled("No Docker in the CI")
@Testcontainers
class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSupport {
static final String DATACENTER = "datacenter1";
@ -44,8 +42,8 @@ class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSuppo
.addContactPoint(contactPoint)
.withLocalDatacenter(DATACENTER)
.build().execute(
"CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE +
" WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};");
"CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE +
" WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};");
return new CassandraChatMemoryStore(CqlSession.builder()
.addContactPoint(contactPoint)
.withLocalDatacenter(DATACENTER)
@ -54,8 +52,7 @@ class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSuppo
}
@AfterAll
static void afterTests() throws Exception {
static void afterTests() {
cassandraContainer.stop();
}
}

View File

@ -1,10 +1,5 @@
package dev.langchain4j.store.embedding.chroma;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
@ -12,28 +7,37 @@ import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2Quantize
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.List;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.Ignore;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.logical.Not;
import org.junit.jupiter.params.provider.Arguments;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
@Testcontainers
class ChromaEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
@Container
private static final ChromaDBContainer chroma = new ChromaDBContainer("chromadb/chroma:0.5.4");
EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore
.builder()
.baseUrl(chroma.getEndpoint())
.collectionName(randomUUID())
.logRequests(true)
.logResponses(true)
.build();
EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore.builder()
.baseUrl(chroma.getEndpoint())
.collectionName(randomUUID())
.logRequests(true)
.logResponses(true)
.build();
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@ -47,182 +51,65 @@ class ChromaEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
return embeddingModel;
}
// Overrides the inherited test with an empty body, so no checks are performed here for Chroma.
// NOTE(review): org.junit.Ignore is the JUnit 4 annotation; verify that it has any effect under
// JUnit Jupiter (org.junit.jupiter.api.Disabled is the Jupiter counterpart).
@Override
@Ignore("Chroma cannot filter by greater and less than of alphanumeric metadata, only int and float are supported")
protected void should_filter_by_greater_and_less_than_alphanumeric_metadata(
Filter metadataFilter,
List<Metadata> matchingMetadatas,
List<Metadata> notMatchingMetadatas
) {}
/**
 * Supplies the parent's filter test cases, minus those Chroma cannot evaluate:
 * in Chroma, comparison filters (less/greater than, with or without "or equal") only work with numbers,
 * so such cases are kept only when their comparison value is a {@link Number}.
 * All other filter types pass through untouched.
 */
protected static Stream<Arguments> should_filter_by_metadata() {
    return EmbeddingStoreWithFilteringIT.should_filter_by_metadata()
            .filter(testCase -> {
                Filter candidate = (Filter) testCase.get()[0];

                Object comparedTo;
                if (candidate instanceof IsLessThan) {
                    comparedTo = ((IsLessThan) candidate).comparisonValue();
                } else if (candidate instanceof IsLessThanOrEqualTo) {
                    comparedTo = ((IsLessThanOrEqualTo) candidate).comparisonValue();
                } else if (candidate instanceof IsGreaterThan) {
                    comparedTo = ((IsGreaterThan) candidate).comparisonValue();
                } else if (candidate instanceof IsGreaterThanOrEqualTo) {
                    comparedTo = ((IsGreaterThanOrEqualTo) candidate).comparisonValue();
                } else {
                    // Not a comparison filter: always keep the test case.
                    return true;
                }
                return comparedTo instanceof Number;
            });
}
// Chroma evaluates *not* filters as follows:
// if you filter by "key" not equal to "a", all items whose "key" value differs from "a" are returned,
// but items without any "key" metadata are not!
// Therefore, all default *not* tests coming from the parent class have to be rewritten here.
protected static Stream<Arguments> should_filter_by_metadata_not() {
return Stream
.<Arguments>builder()
// === NotEqual ===
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata_not()
.map(args -> {
Object[] arguments = args.get();
Filter filter = (Filter) arguments[0];
.add(
Arguments.of(
metadataKey("key").isNotEqualTo("a"),
asList(
new Metadata().put("key", "A"),
new Metadata().put("key", "b"),
new Metadata().put("key", "aa"),
new Metadata().put("key", "a a")
),
asList(
new Metadata().put("key", "a"),
new Metadata().put("key2", "a"),
new Metadata().put("key", "a").put("key2", "b")
),
false
)
)
.add(
Arguments.of(
metadataKey("key").isNotEqualTo(TEST_UUID),
asList(new Metadata().put("key", UUID.randomUUID())),
asList(
new Metadata().put("key", TEST_UUID),
new Metadata().put("key2", TEST_UUID),
new Metadata().put("key", TEST_UUID).put("key2", UUID.randomUUID())
),
false
)
)
.add(
Arguments.of(
metadataKey("key").isNotEqualTo(1),
asList(
new Metadata().put("key", -1),
new Metadata().put("key", 0),
new Metadata().put("key", 2),
new Metadata().put("key", 10)
),
asList(
new Metadata().put("key", 1),
new Metadata().put("key2", 1),
new Metadata().put("key", 1).put("key2", 2)
),
false
)
)
.add(
Arguments.of(
metadataKey("key").isNotEqualTo(1.1f),
asList(
new Metadata().put("key", -1.1f),
new Metadata().put("key", 0.0f),
new Metadata().put("key", 1.11f),
new Metadata().put("key", 2.2f)
),
asList(
new Metadata().put("key", 1.1f),
new Metadata().put("key2", 1.1f),
new Metadata().put("key", 1.1f).put("key2", 2.2f)
),
false
)
)
// === NotIn ===
String key = getMetadataKey(filter);
// NotIn: string
.add(
Arguments.of(
metadataKey("name").isNotIn("Klaus"),
asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Alice")),
asList(
new Metadata().put("name", "Klaus"),
new Metadata().put("name2", "Klaus"),
new Metadata().put("name", "Klaus").put("age", 42)
),
false
)
)
.add(
Arguments.of(
metadataKey("name").isNotIn(singletonList("Klaus")),
asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Alice")),
asList(
new Metadata().put("name", "Klaus"),
new Metadata().put("name2", "Klaus"),
new Metadata().put("name", "Klaus").put("age", 42)
),
false
)
)
.add(
Arguments.of(
metadataKey("name").isNotIn("Klaus", "Alice"),
asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Zoe")),
asList(
new Metadata().put("name", "Klaus"),
new Metadata().put("name2", "Klaus"),
new Metadata().put("name", "Klaus").put("age", 42),
new Metadata().put("name", "Alice"),
new Metadata().put("name", "Alice").put("age", 42)
),
false
)
)
// NotIn: UUID
.add(
Arguments.of(
metadataKey("name").isNotIn(TEST_UUID),
asList(new Metadata().put("name", UUID.randomUUID())),
asList(
new Metadata().put("name", TEST_UUID),
new Metadata().put("name2", TEST_UUID),
new Metadata().put("name", TEST_UUID).put("age", 42)
),
false
)
)
// NotIn: int
.add(
Arguments.of(
metadataKey("age").isNotIn(42),
asList(new Metadata().put("age", 666)),
asList(
new Metadata().put("age", 42),
new Metadata().put("age2", 42),
new Metadata().put("age", 42).put("name", "Klaus")
),
false
)
)
.add(
Arguments.of(
metadataKey("age").isNotIn(42, 18),
asList(new Metadata().put("age", 666)),
asList(
new Metadata().put("age", 42),
new Metadata().put("age", 18),
new Metadata().put("age2", 42),
new Metadata().put("age", 42).put("name", "Klaus"),
new Metadata().put("age", 18).put("name", "Klaus")
),
false
)
)
// NotIn: float
.add(
Arguments.of(
metadataKey("age").isNotIn(asList(42.0f, 18.0f)),
asList(new Metadata().put("age", 666.0f)),
asList(
new Metadata().put("age", 42.0f),
new Metadata().put("age", 18.0f),
new Metadata().put("age2", 42.0f),
new Metadata().put("age", 42.0f).put("name", "Klaus"),
new Metadata().put("age", 18.0f).put("name", "Klaus")
),
false
)
)
.build();
List<Metadata> matchingMetadatas = (List<Metadata>) arguments[1];
List<Metadata> newMatchingMetadatas = matchingMetadatas.stream()
.filter(metadata -> metadata.containsKey(key))
.collect(toList());
List<Metadata> notMatchingMetadatas = (List<Metadata>) arguments[2];
List<Metadata> newNotMatchingMetadatas = new ArrayList<>(notMatchingMetadatas);
newNotMatchingMetadatas.addAll(matchingMetadatas.stream()
.filter(metadata -> !metadata.containsKey(key))
.collect(toList()));
assertThat(Stream.concat(newMatchingMetadatas.stream(), newNotMatchingMetadatas.stream()))
.containsExactlyInAnyOrderElementsOf(Stream.concat(matchingMetadatas.stream(), notMatchingMetadatas.stream()).collect(toList()));
return Arguments.of(filter, newMatchingMetadatas, newNotMatchingMetadatas);
});
}
/**
 * Reads the metadata key a filter refers to by reflectively calling its public {@code key()} method.
 * A {@link Not} filter is unwrapped first, so the key of the negated expression is returned.
 *
 * @param filter the filter to inspect
 * @return the value returned by the filter's {@code key()} method
 * @throws RuntimeException wrapping any exception raised while unwrapping or invoking {@code key()}
 */
private static String getMetadataKey(Filter filter) {
    try {
        Filter target = filter instanceof Not ? ((Not) filter).expression() : filter;
        Method keyAccessor = target.getClass().getMethod("key");
        Object key = keyAccessor.invoke(target);
        return (String) key;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
}

View File

@ -1,14 +1,15 @@
package dev.langchain4j.store.embedding;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
/**
* A minimum set of tests that each implementation of {@link EmbeddingStore} must pass.
@ -20,6 +21,7 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
@Test
void should_add_embedding_with_segment_with_metadata() {
Metadata metadata = createMetadata();
TextSegment segment = TextSegment.from("hello", metadata);
@ -28,22 +30,19 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
String id = embeddingStore().add(embedding, segment);
assertThat(id).isNotBlank();
{
// Not returned.
TextSegment altSegment = TextSegment.from("hello?");
Embedding altEmbedding = embeddingModel().embed(altSegment.text()).content();
embeddingStore().add(altEmbedding, altSegment);
}
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
assertThat(relevant).hasSize(1);
// then
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(1, withPercentage(1));
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
if (assertEmbedding()) {
assertThat(match.embedding()).isEqualTo(embedding);
}
assertThat(match.embedded().text()).isEqualTo(segment.text());
@ -78,15 +77,14 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
assertThat(match.embedded().metadata().getDouble("double_123")).isEqualTo(1.23456789d);
// new API
assertThat(
embeddingStore()
.search(EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(1).build())
.matches()
)
.isEqualTo(relevant);
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
.queryEmbedding(embedding)
.maxResults(1)
.build()).matches()).isEqualTo(relevant);
}
protected Metadata createMetadata() {
Metadata metadata = new Metadata();
metadata.put("string_empty", "");

View File

@ -152,10 +152,17 @@ public abstract class EmbeddingStoreWithRemovalIT {
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).isEmpty());
}
/**
 * Re-runs the given assertion with Awaitility until it passes: no initial poll delay,
 * one attempt every 300 ms, giving up after 60 seconds.
 *
 * @param assertion the assertion to retry
 */
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
    Duration timeout = Duration.ofSeconds(60);
    Duration initialDelay = Duration.ofSeconds(0);
    Duration retryEvery = Duration.ofMillis(300);

    Awaitility.await()
            .atMost(timeout)
            .pollDelay(initialDelay)
            .pollInterval(retryEvery)
            .untilAsserted(assertion);
}
protected List<EmbeddingMatch<TextSegment>> getAllEmbeddings() {
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest
.builder()
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddingModel().embed("test").content())
.maxResults(1000)
.build();
@ -164,11 +171,4 @@ public abstract class EmbeddingStoreWithRemovalIT {
return searchResult.matches();
}
/**
 * Re-runs the given assertion with Awaitility until it passes,
 * polling every 500 ms and giving up after 15 seconds.
 *
 * @param assertion the assertion to retry
 */
protected static void awaitUntilAsserted(ThrowingRunnable assertion) {
    Duration retryEvery = Duration.ofMillis(500);
    Duration timeout = Duration.ofSeconds(15);

    Awaitility.await()
            .pollInterval(retryEvery)
            .atMost(timeout)
            .untilAsserted(assertion);
}
}

View File

@ -3,9 +3,13 @@ package dev.langchain4j.store.embedding;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.assertj.core.data.Percentage;
import org.awaitility.Awaitility;
import org.awaitility.core.ThrowingRunnable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.util.List;
import static dev.langchain4j.internal.Utils.randomUUID;
@ -29,27 +33,30 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
}
protected void ensureStoreIsEmpty() {
Embedding embedding = embeddingModel().embed("hello").content();
assertThat(embeddingStore().findRelevant(embedding, 1000)).isEmpty();
assertThat(getAllEmbeddings()).isEmpty();
}
@Test
void should_add_embedding() {
// given
Embedding embedding = embeddingModel().embed("hello").content();
String id = embeddingStore().add(embedding);
assertThat(id).isNotBlank();
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
// then
assertThat(id).isNotBlank();
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(1, withPercentage(1));
assertThat(match.score()).isCloseTo(1, percentage());
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
if (assertEmbedding()) {
assertThat(match.embedding()).isEqualTo(embedding);
}
assertThat(match.embedded()).isNull();
// new API
@ -62,20 +69,24 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
@Test
void should_add_embedding_with_id() {
// given
String id = randomUUID();
Embedding embedding = embeddingModel().embed("hello").content();
embeddingStore().add(id, embedding);
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
// then
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(1, withPercentage(1));
assertThat(match.score()).isCloseTo(1, percentage());
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
if (assertEmbedding()) {
assertThat(match.embedding()).isEqualTo(embedding);
}
assertThat(match.embedded()).isNull();
// new API
@ -88,21 +99,25 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
@Test
void should_add_embedding_with_segment() {
// given
TextSegment segment = TextSegment.from("hello");
Embedding embedding = embeddingModel().embed(segment.text()).content();
String id = embeddingStore().add(embedding, segment);
assertThat(id).isNotBlank();
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
// then
assertThat(id).isNotBlank();
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(1, withPercentage(1));
assertThat(match.score()).isCloseTo(1, percentage());
assertThat(match.embeddingId()).isEqualTo(id);
assertThat(match.embedding()).isEqualTo(embedding);
if (assertEmbedding()) {
assertThat(match.embedding()).isEqualTo(embedding);
}
assertThat(match.embedded()).isEqualTo(segment);
// new API
@ -115,34 +130,41 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
@Test
void should_add_multiple_embeddings() {
// given
Embedding firstEmbedding = embeddingModel().embed("hello").content();
Embedding secondEmbedding = embeddingModel().embed("hi").content();
List<String> ids = embeddingStore().addAll(asList(firstEmbedding, secondEmbedding));
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
// then
assertThat(ids).hasSize(2);
assertThat(ids.get(0)).isNotBlank();
assertThat(ids.get(1)).isNotBlank();
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
awaitUntilPersisted();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
assertThat(firstMatch.score()).isCloseTo(1, percentage());
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
if (assertEmbedding()) {
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
}
assertThat(firstMatch.embedded()).isNull();
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
assertThat(secondMatch.score()).isCloseTo(
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
withPercentage(1)
percentage()
);
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
if (assertEmbedding()) {
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
}
assertThat(secondMatch.embedded()).isNull();
// new API
@ -155,6 +177,7 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
@Test
void should_add_multiple_embeddings_with_segments() {
// given
TextSegment firstSegment = TextSegment.from("hello");
Embedding firstEmbedding = embeddingModel().embed(firstSegment.text()).content();
@ -165,30 +188,37 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
asList(firstEmbedding, secondEmbedding),
asList(firstSegment, secondSegment)
);
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
// then
assertThat(ids).hasSize(2);
assertThat(ids.get(0)).isNotBlank();
assertThat(ids.get(1)).isNotBlank();
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
awaitUntilPersisted();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
assertThat(firstMatch.score()).isCloseTo(1, percentage());
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
if (assertEmbedding()) {
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
}
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
assertThat(secondMatch.score()).isCloseTo(
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
withPercentage(1)
percentage()
);
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
if (assertEmbedding()) {
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
}
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
// new API
@ -201,6 +231,7 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
@Test
void should_find_with_min_score() {
// given
String firstId = randomUUID();
Embedding firstEmbedding = embeddingModel().embed("hello").content();
embeddingStore().add(firstId, firstEmbedding);
@ -209,33 +240,41 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
Embedding secondEmbedding = embeddingModel().embed("hi").content();
embeddingStore().add(secondId, secondEmbedding);
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
// then
assertThat(relevant).hasSize(2);
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
assertThat(firstMatch.score()).isCloseTo(1, percentage());
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
assertThat(secondMatch.score()).isCloseTo(
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
withPercentage(1)
percentage()
);
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
// new API
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
.queryEmbedding(firstEmbedding)
.maxResults(10)
.build()).matches()).isEqualTo(relevant);
// when
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore().findRelevant(
firstEmbedding,
10,
secondMatch.score() - 0.01
);
// then
assertThat(relevant2).hasSize(2);
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
// new API
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
.queryEmbedding(firstEmbedding)
@ -243,14 +282,18 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
.minScore(secondMatch.score() - 0.01)
.build()).matches()).isEqualTo(relevant2);
// when
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore().findRelevant(
firstEmbedding,
10,
secondMatch.score()
);
// then
assertThat(relevant3).hasSize(2);
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
// new API
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
.queryEmbedding(firstEmbedding)
@ -258,13 +301,17 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
.minScore(secondMatch.score())
.build()).matches()).isEqualTo(relevant3);
// when
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore().findRelevant(
firstEmbedding,
10,
secondMatch.score() + 0.01
);
// then
assertThat(relevant4).hasSize(1);
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
// new API
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
.queryEmbedding(firstEmbedding)
@ -276,22 +323,25 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
@Test
void should_return_correct_score() {
// given
Embedding embedding = embeddingModel().embed("hello").content();
String id = embeddingStore().add(embedding);
assertThat(id).isNotBlank();
awaitUntilPersisted();
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
Embedding referenceEmbedding = embeddingModel().embed("hi").content();
// when
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(referenceEmbedding, 1);
assertThat(relevant).hasSize(1);
// then
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
assertThat(match.score()).isCloseTo(
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
withPercentage(1)
percentage()
);
// new API
@ -301,7 +351,31 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
.build()).matches()).isEqualTo(relevant);
}
protected void awaitUntilPersisted() {
// not waiting by default
/**
 * Keeps re-running the given assertion until it passes or 60 seconds have elapsed.
 * Configured with a zero poll delay and a 300 ms poll interval.
 *
 * @param assertion the assertion to re-run until it stops failing
 */
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
    Duration maxWait = Duration.ofSeconds(60);
    Duration delayBeforeFirstPoll = Duration.ofSeconds(0);
    Duration betweenPolls = Duration.ofMillis(300);

    Awaitility.await()
            .atMost(maxWait)
            .pollDelay(delayBeforeFirstPoll)
            .pollInterval(betweenPolls)
            .untilAsserted(assertion);
}
/**
 * Searches the store with the embedding of the text "test", asking for at most 1000 matches.
 *
 * @return the matches reported by the search
 */
protected List<EmbeddingMatch<TextSegment>> getAllEmbeddings() {
    EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
            .queryEmbedding(embeddingModel().embed("test").content())
            .maxResults(1000)
            .build();
    return embeddingStore().search(request).matches();
}
/**
 * Controls whether the tests in this class compare the embedding of a returned match
 * with the embedding that was stored.
 * NOTE(review): protected and non-final, so presumably meant to be overridden for stores
 * that do not return embeddings — confirm.
 *
 * @return {@code true} in this class, so the comparison is performed
 */
protected boolean assertEmbedding() {
return true;
}
/**
 * Tolerance passed to the {@code isCloseTo} score assertions in this class.
 *
 * @return a tolerance of 1 percent in this class
 */
protected Percentage percentage() {
return withPercentage(1);
}
}

View File

@ -5,7 +5,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
import lombok.SneakyThrows;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
@ -42,6 +41,7 @@ abstract class AbstractElasticsearchEmbeddingStoreIT extends EmbeddingStoreWithF
}
abstract ElasticsearchConfiguration withConfiguration();
// The body is empty in this class.
// NOTE(review): presumably a hook for subclasses that need to create the index themselves — confirm.
void optionallyCreateIndex(String indexName) throws IOException {
}
@ -80,10 +80,4 @@ abstract class AbstractElasticsearchEmbeddingStoreIT extends EmbeddingStoreWithF
// Empty body: no emptiness check is performed here; the TODO marks this as still to be fixed.
protected void ensureStoreIsEmpty() {
// TODO fix
}
/**
 * Issues an index refresh request for {@code indexName} through the Elasticsearch client,
 * with {@code ignoreUnavailable(true)} set on the request.
 * Lombok's {@code @SneakyThrows} lets checked exceptions from the call propagate
 * without a {@code throws} clause.
 */
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
elasticsearchClientHelper.client.indices().refresh(rr -> rr.index(indexName).ignoreUnavailable(true));
}
}

View File

@ -4,7 +4,7 @@ import co.elastic.clients.transport.endpoints.BooleanResponse;
import java.io.IOException;
class ElasticsearchKnnEmbeddingStoreIT extends AbstractElasticsearchEmbeddingStoreIT {
class ElasticsearchEmbeddingStoreKnnIT extends AbstractElasticsearchEmbeddingStoreIT {
@Override
ElasticsearchConfiguration withConfiguration() {

View File

@ -1,6 +1,7 @@
package dev.langchain4j.store.embedding.elasticsearch;
class ElasticsearchEmbeddingStoreIT extends AbstractElasticsearchEmbeddingStoreIT {
class ElasticsearchEmbeddingStoreScriptIT extends AbstractElasticsearchEmbeddingStoreIT {
@Override
ElasticsearchConfiguration withConfiguration() {
return ElasticsearchConfigurationScript.builder().build();

View File

@ -101,6 +101,13 @@
<artifactId>infinispan-server-testdriver-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -97,6 +97,13 @@
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -6,11 +6,10 @@ import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import lombok.SneakyThrows;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
@ -78,10 +77,4 @@ class MongoDbEmbeddingStoreCloudIT extends EmbeddingStoreIT {
Bson filter = Filters.exists("embedding");
collection.deleteMany(filter);
}
// Pauses the calling thread for a fixed 3 seconds; nothing is polled or checked.
// @SneakyThrows (Lombok) lets the InterruptedException from Thread.sleep propagate undeclared.
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(3000);
}
}

View File

@ -1,13 +1,17 @@
package dev.langchain4j.store.embedding.mongodb;
import com.mongodb.*;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCredential;
import com.mongodb.ServerApi;
import com.mongodb.ServerApiVersion;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import lombok.SneakyThrows;
@ -91,10 +95,4 @@ class MongoDbEmbeddingStoreLocalIT extends EmbeddingStoreIT {
Bson filter = Filters.exists("embedding");
collection.deleteMany(filter);
}
// Pauses the calling thread for a fixed 2 seconds; nothing is polled or checked.
// @SneakyThrows (Lombok) lets the InterruptedException from Thread.sleep propagate undeclared.
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(2000);
}
}

View File

@ -1,14 +1,18 @@
package dev.langchain4j.store.embedding.mongodb;
import com.mongodb.*;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCredential;
import com.mongodb.ServerApi;
import com.mongodb.ServerApiVersion;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.model.Filters;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import lombok.SneakyThrows;
@ -23,7 +27,7 @@ import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
class MongoDbEmbeddingStoreFilterIT {
class MongoDbEmbeddingStoreNativeFilterIT {
static MongoDBAtlasContainer mongodb = new MongoDBAtlasContainer();
@ -68,6 +72,8 @@ class MongoDbEmbeddingStoreFilterIT {
@Test
void should_find_relevant_with_filter() {
// given
TextSegment segment = TextSegment.from("this segment should be found", Metadata.from("test-key", "test-value"));
Embedding embedding = embeddingModel.embed(segment.text()).content();
@ -75,8 +81,7 @@ class MongoDbEmbeddingStoreFilterIT {
Embedding filterEmbedding = embeddingModel.embed(filterSegment.text()).content();
List<String> ids = embeddingStore.addAll(asList(embedding, filterEmbedding), asList(segment, filterSegment));
assertThat(ids)
.hasSize(2);
assertThat(ids).hasSize(2);
TextSegment refSegment = TextSegment.from("find a segment");
Embedding refEmbedding = embeddingModel.embed(refSegment.text()).content();
@ -84,7 +89,8 @@ class MongoDbEmbeddingStoreFilterIT {
awaitUntilPersisted();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(refEmbedding, 2);
// Only segment should be found, filterSegment should be filtered
// then
assertThat(relevant).hasSize(1);
EmbeddingMatch<TextSegment> match = relevant.get(0);
@ -95,7 +101,7 @@ class MongoDbEmbeddingStoreFilterIT {
}
// Pauses the calling thread for a fixed 2 seconds; nothing is polled or checked.
// @SneakyThrows (Lombok) lets the InterruptedException from Thread.sleep propagate undeclared.
@SneakyThrows
protected void awaitUntilPersisted() {
private void awaitUntilPersisted() {
Thread.sleep(2000);
}
}

View File

@ -104,6 +104,13 @@
<artifactId>mockito-junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -426,12 +426,6 @@ class Neo4jEmbeddingStoreIT {
assertThat(rowsBatched.get(0)).hasSize(1);
}
/** Checks that {@code getListRowsBatched(0)} returns an empty list. */
@Test
void test_row_batches_empty() {
    assertThat(getListRowsBatched(0)).isEmpty();
}
@Test
void test_row_batches_10000_elements() {
List<List<Map<String, Object>>> rowsBatched = getListRowsBatched(10000);

View File

@ -2,11 +2,10 @@ package dev.langchain4j.store.embedding.opensearch;
import com.jayway.jsonpath.JsonPath;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import lombok.SneakyThrows;
import net.minidev.json.JSONArray;
import org.junit.jupiter.api.BeforeAll;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
@ -74,11 +73,4 @@ class OpenSearchEmbeddingStoreAwsIT extends EmbeddingStoreIT {
// Empty body: no emptiness check is performed here; the TODO marks this as still to be fixed.
protected void ensureStoreIsEmpty() {
// TODO fix
}
// Pauses the calling thread for a fixed 1 second; nothing is polled or checked.
// @SneakyThrows (Lombok) lets the InterruptedException from Thread.sleep propagate undeclared.
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(1000);
}
}

View File

@ -1,11 +1,10 @@
package dev.langchain4j.store.embedding.opensearch;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import lombok.SneakyThrows;
import org.junit.jupiter.api.BeforeAll;
import org.opensearch.testcontainers.OpensearchContainer;
import org.testcontainers.junit.jupiter.Container;
@ -51,10 +50,4 @@ class OpenSearchEmbeddingStoreLocalIT extends EmbeddingStoreIT {
// Empty body: no emptiness check is performed here; the TODO marks this as still to be fixed.
protected void ensureStoreIsEmpty() {
// TODO fix
}
// Pauses the calling thread for a fixed 1 second; nothing is polled or checked.
// @SneakyThrows (Lombok) lets the InterruptedException from Thread.sleep propagate undeclared.
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(1000);
}
}

View File

@ -11,7 +11,6 @@ import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import lombok.SneakyThrows;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
@ -54,12 +53,6 @@ class PineconeEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
return embeddingModel;
}
// Pauses the calling thread for a fixed 6 seconds; nothing is polled or checked.
// @SneakyThrows (Lombok) lets the InterruptedException from Thread.sleep propagate undeclared.
@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(6000);
}
@ParameterizedTest
@MethodSource("should_filter_by_metadata")
protected void should_filter_by_metadata(Filter metadataFilter,
@ -68,23 +61,23 @@ class PineconeEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
super.should_filter_by_metadata(metadataFilter, matchingMetadatas, notMatchingMetadatas);
}
/**
 * Argument source for the parameterized {@code should_filter_by_metadata} test.
 * Starts from {@code EmbeddingStoreWithFilteringIT.should_filter_by_metadata()} and drops every
 * argument set whose first element is an IsLessThan, IsLessThanOrEqualTo, IsGreaterThan or
 * IsGreaterThanOrEqualTo filter with a comparison value that is not a {@link Number}, because
 * in Pinecone these comparison filters only work with numbers. All other argument sets are kept.
 *
 * @return the filtered stream of test arguments
 */
protected static Stream<Arguments> should_filter_by_metadata() {
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata().filter(
arguments -> {
// The first argument of each set is the metadata filter under test.
Object o = arguments.get()[0];
if (o instanceof IsLessThan) {
return ((IsLessThan) o).comparisonValue() instanceof Number;
} else if (o instanceof IsLessThanOrEqualTo) {
return ((IsLessThanOrEqualTo) o).comparisonValue() instanceof Number;
} else if (o instanceof IsGreaterThan) {
return ((IsGreaterThan) o).comparisonValue() instanceof Number;
} else if (o instanceof IsGreaterThanOrEqualTo) {
return ((IsGreaterThanOrEqualTo) o).comparisonValue() instanceof Number;
} else {
return true;
}
}
);
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata()
.filter(arguments -> {
// The first argument of each set is the metadata filter under test.
Filter filter = (Filter) arguments.get()[0];
if (filter instanceof IsLessThan) {
return ((IsLessThan) filter).comparisonValue() instanceof Number;
} else if (filter instanceof IsLessThanOrEqualTo) {
return ((IsLessThanOrEqualTo) filter).comparisonValue() instanceof Number;
} else if (filter instanceof IsGreaterThan) {
return ((IsGreaterThan) filter).comparisonValue() instanceof Number;
} else if (filter instanceof IsGreaterThanOrEqualTo) {
return ((IsGreaterThanOrEqualTo) filter).comparisonValue() instanceof Number;
} else {
return true;
}
}
);
}
}

View File

@ -97,6 +97,12 @@
<version>1.7.1</version>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -90,6 +90,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -106,6 +106,13 @@
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -195,8 +195,6 @@ class VearchEmbeddingStoreIT extends EmbeddingStoreIT {
embeddingStore().add(altEmbedding, altSegment);
}
awaitUntilPersisted();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
assertThat(relevant).hasSize(1);

View File

@ -1,14 +1,13 @@
package dev.langchain4j.store.embedding.inmemory;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
/**
* Tests if InMemoryEmbeddingStore works correctly after being serialized and deserialized back.
* See awaitUntilPersisted()
* Tests if {@link InMemoryEmbeddingStore} works correctly after being serialized and deserialized back.
*/
class InMemoryEmbeddingStoreSerializedTest extends EmbeddingStoreWithFilteringIT {
@ -17,14 +16,14 @@ class InMemoryEmbeddingStoreSerializedTest extends EmbeddingStoreWithFilteringIT
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@Override
protected void awaitUntilPersisted() {
String json = embeddingStore.serializeToJson();
embeddingStore = InMemoryEmbeddingStore.fromJson(json);
protected EmbeddingStore<TextSegment> embeddingStore() {
serializeAndDeserialize();
return embeddingStore;
}
@Override
protected EmbeddingStore<TextSegment> embeddingStore() {
return embeddingStore;
private void serializeAndDeserialize() {
String json = embeddingStore.serializeToJson();
embeddingStore = InMemoryEmbeddingStore.fromJson(json);
}
@Override