EmbeddingStoreIT: use awaitility (#1610)
## Change Use awaitility in `EmbeddingStoreIT` ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [x] I have manually run all the unit and integration tests in the [core]
This commit is contained in:
parent
2e47b126be
commit
3e6d50ee40
|
@ -4,25 +4,25 @@ hide_title: false
|
|||
sidebar_position: 0
|
||||
---
|
||||
|
||||
| Embedding Store | Storing Metadata | Filtering by Metadata | Removing Embeddings |
|
||||
|---------------------------------------------------------------------------------------|------------------|-----------------------|---------------------|
|
||||
| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | ✅ |
|
||||
| [Astra DB](/integrations/embedding-stores/astra-db) | ✅ | | |
|
||||
| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | ✅ | ✅ |
|
||||
| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | |
|
||||
| [Azure CosmosDB NoSQL](/integrations/embedding-stores/azure-cosmos-nosql) | ✅ | | |
|
||||
| [Cassandra](/integrations/embedding-stores/cassandra) | ✅ | | |
|
||||
| [Chroma](/integrations/embedding-stores/chroma) | ✅ | ✅ | ✅ |
|
||||
| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | ✅ |
|
||||
| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | |
|
||||
| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | ✅ |
|
||||
| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | | |
|
||||
| [Neo4j](/integrations/embedding-stores/neo4j) | ✅ | | |
|
||||
| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | |
|
||||
| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ |
|
||||
| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ |
|
||||
| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | |
|
||||
| [Redis](/integrations/embedding-stores/redis) | ✅ | | |
|
||||
| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | |
|
||||
| [Vespa](/integrations/embedding-stores/vespa) | | | |
|
||||
| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ |
|
||||
| Embedding Store | Storing Metadata | Filtering by Metadata | Removing Embeddings |
|
||||
|---------------------------------------------------------------------------------------|------------------|----------------------------|---------------------|
|
||||
| [In-memory](/integrations/embedding-stores/in-memory) | ✅ | ✅ | ✅ |
|
||||
| [Astra DB](/integrations/embedding-stores/astra-db) | ✅ | | |
|
||||
| [Azure AI Search](/integrations/embedding-stores/azure-ai-search) | ✅ | ✅ | ✅ |
|
||||
| [Azure CosmosDB Mongo vCore](/integrations/embedding-stores/azure-cosmos-mongo-vcore) | ✅ | | |
|
||||
| [Azure CosmosDB NoSQL](/integrations/embedding-stores/azure-cosmos-nosql) | ✅ | | |
|
||||
| [Cassandra](/integrations/embedding-stores/cassandra) | ✅ | | |
|
||||
| [Chroma](/integrations/embedding-stores/chroma) | ✅ | ✅ | ✅ |
|
||||
| [Elasticsearch](/integrations/embedding-stores/elasticsearch) | ✅ | ✅ | ✅ |
|
||||
| [Infinispan](/integrations/embedding-stores/infinispan) | ✅ | | |
|
||||
| [Milvus](/integrations/embedding-stores/milvus) | ✅ | ✅ | ✅ |
|
||||
| [MongoDB Atlas](/integrations/embedding-stores/mongodb-atlas) | ✅ | Only native filter support | |
|
||||
| [Neo4j](/integrations/embedding-stores/neo4j) | ✅ | | |
|
||||
| [OpenSearch](/integrations/embedding-stores/opensearch) | ✅ | | |
|
||||
| [PGVector](/integrations/embedding-stores/pgvector) | ✅ | ✅ | ✅ |
|
||||
| [Pinecone](/integrations/embedding-stores/pinecone) | ✅ | ✅ | ✅ |
|
||||
| [Qdrant](/integrations/embedding-stores/qdrant) | ✅ | | |
|
||||
| [Redis](/integrations/embedding-stores/redis) | ✅ | | |
|
||||
| [Vearch](/integrations/embedding-stores/vearch) | ✅ | | |
|
||||
| [Vespa](/integrations/embedding-stores/vespa) | | | |
|
||||
| [Weaviate](/integrations/embedding-stores/weaviate) | ✅ | | ✅ |
|
||||
|
|
|
@ -6,12 +6,13 @@ import com.azure.search.documents.indexes.SearchIndexClientBuilder;
|
|||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
import org.awaitility.core.ThrowingRunnable;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
@ -31,40 +32,37 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
|
||||
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverIT.class);
|
||||
|
||||
private EmbeddingModel embeddingModel;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
private AzureAiSearchContentRetriever contentRetrieverWithVector;
|
||||
private final AzureAiSearchContentRetriever contentRetrieverWithVector;
|
||||
|
||||
private AzureAiSearchContentRetriever contentRetrieverWithFullText;
|
||||
|
||||
private AzureAiSearchContentRetriever contentRetrieverWithHybrid;
|
||||
private final AzureAiSearchContentRetriever contentRetrieverWithHybrid;
|
||||
|
||||
private AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking;
|
||||
private final AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking;
|
||||
|
||||
private int dimensions;
|
||||
public AzureAiSearchContentRetrieverIT() {
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
public AzureAiSearchContentRetrieverIT() {
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
dimensions = embeddingModel.embed("test").content().vector().length;
|
||||
SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
|
||||
.buildClient();
|
||||
|
||||
SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
|
||||
.buildClient();
|
||||
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME);
|
||||
|
||||
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME);
|
||||
|
||||
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
|
||||
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
|
||||
contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID);
|
||||
contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING);
|
||||
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
|
||||
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
|
||||
contentRetrieverWithHybrid = createContentRetriever(AzureAiSearchQueryType.HYBRID);
|
||||
contentRetrieverWithHybridAndReranking = createContentRetriever(AzureAiSearchQueryType.HYBRID_WITH_RERANKING);
|
||||
}
|
||||
|
||||
private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) {
|
||||
return AzureAiSearchContentRetriever.builder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
|
||||
.dimensions(dimensions)
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.embeddingModel(embeddingModel)
|
||||
.queryType(azureAiSearchQueryType)
|
||||
.maxResults(3)
|
||||
|
@ -105,7 +103,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
contentRetrieverWithVector.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
|
||||
|
||||
String content = "fruit";
|
||||
Query query = Query.from(content);
|
||||
|
@ -140,7 +138,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
contentRetrieverWithVector.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
|
||||
|
||||
String content = "house";
|
||||
Query query = Query.from(content);
|
||||
|
@ -189,7 +187,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
contentRetrieverWithFullText.add(content);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(contentRetrieverWithFullText.retrieve(Query.from("a"))).hasSize(contents.size()));
|
||||
|
||||
Query query = Query.from("Alain");
|
||||
List<Content> relevant = contentRetrieverWithFullText.retrieve(query);
|
||||
|
@ -206,9 +204,9 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
|
||||
@Test
|
||||
void testFullTextSearchWithSpecificSearchIndex() {
|
||||
// This doesn't reuse the existing search index, but creates a specialized one only for full text search
|
||||
// This doesn't reuse the existing search index, but creates a specialized one only for full text search
|
||||
contentRetrieverWithVector.deleteIndex();
|
||||
contentRetrieverWithFullText = AzureAiSearchContentRetriever.builder()
|
||||
contentRetrieverWithFullText = AzureAiSearchContentRetriever.builder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
|
||||
.embeddingModel(null)
|
||||
|
@ -219,7 +217,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
.build();
|
||||
testFullTextSearch();
|
||||
clearStore();
|
||||
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
|
||||
contentRetrieverWithFullText = createFullTextSearchContentRetriever();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -247,7 +245,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
contentRetrieverWithHybrid.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
|
||||
|
||||
Query query = Query.from("Algeria");
|
||||
List<Content> relevant = contentRetrieverWithHybrid.retrieve(query);
|
||||
|
@ -287,7 +285,7 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
contentRetrieverWithHybridAndReranking.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(contents.size()));
|
||||
|
||||
Query query = Query.from("A philosopher who was in the French Resistance");
|
||||
List<Content> relevant = contentRetrieverWithHybridAndReranking.retrieve(query);
|
||||
|
@ -318,18 +316,24 @@ public class AzureAiSearchContentRetrieverIT extends EmbeddingStoreWithFiltering
|
|||
AzureAiSearchContentRetriever azureAiSearchContentRetriever = contentRetrieverWithVector;
|
||||
try {
|
||||
azureAiSearchContentRetriever.deleteIndex();
|
||||
azureAiSearchContentRetriever.createOrUpdateIndex(dimensions);
|
||||
azureAiSearchContentRetriever.createOrUpdateIndex(embeddingModel.dimension());
|
||||
} catch (RuntimeException e) {
|
||||
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilPersisted() {
|
||||
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
super.awaitUntilAsserted(assertion);
|
||||
try {
|
||||
Thread.sleep(1_000);
|
||||
Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean assertEmbedding() {
|
||||
return false; // TODO remove this hack after https://github.com/langchain4j/langchain4j/issues/1617 is closed
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,58 +1,34 @@
|
|||
package dev.langchain4j.rag.content.retriever.azure.search;
|
||||
|
||||
import com.azure.core.credential.AzureKeyCredential;
|
||||
import com.azure.search.documents.indexes.SearchIndexClient;
|
||||
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.awaitility.core.ThrowingRunnable;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME;
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchQueryType.HYBRID;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+")
|
||||
public class AzureAiSearchContentRetrieverRemovalIT extends EmbeddingStoreWithRemovalIT {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchContentRetrieverRemovalIT.class);
|
||||
|
||||
private EmbeddingModel embeddingModel;
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
private AzureAiSearchContentRetriever contentRetrieverWithVector;
|
||||
|
||||
public AzureAiSearchContentRetrieverRemovalIT() {
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
|
||||
.buildClient();
|
||||
|
||||
searchIndexClient.deleteIndex(DEFAULT_INDEX_NAME);
|
||||
|
||||
contentRetrieverWithVector = createContentRetriever(AzureAiSearchQueryType.VECTOR);
|
||||
}
|
||||
|
||||
private AzureAiSearchContentRetriever createContentRetriever(AzureAiSearchQueryType azureAiSearchQueryType) {
|
||||
return AzureAiSearchContentRetriever.builder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.embeddingModel(embeddingModel)
|
||||
.queryType(azureAiSearchQueryType)
|
||||
.maxResults(3)
|
||||
.minScore(0.0)
|
||||
.build();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
clearStore();
|
||||
}
|
||||
private final AzureAiSearchContentRetriever contentRetrieverWithVector = AzureAiSearchContentRetriever.builder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
|
||||
.indexName(randomUUID())
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.embeddingModel(embeddingModel)
|
||||
.queryType(HYBRID)
|
||||
.build();
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
|
@ -64,14 +40,22 @@ public class AzureAiSearchContentRetrieverRemovalIT extends EmbeddingStoreWithRe
|
|||
return this.embeddingModel;
|
||||
}
|
||||
|
||||
protected void clearStore() {
|
||||
log.debug("Deleting the search index");
|
||||
AzureAiSearchContentRetriever azureAiSearchContentRetriever = contentRetrieverWithVector;
|
||||
@AfterEach
|
||||
void afterEach() {
|
||||
try {
|
||||
azureAiSearchContentRetriever.deleteIndex();
|
||||
azureAiSearchContentRetriever.createOrUpdateIndex(embeddingModel.dimension());
|
||||
contentRetrieverWithVector.deleteIndex();
|
||||
} catch (RuntimeException e) {
|
||||
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
|
||||
log.error("Failed to delete the index. You should look at deleting it manually.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
super.awaitUntilAsserted(assertion);
|
||||
try {
|
||||
Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,14 +6,13 @@ import com.azure.search.documents.indexes.SearchIndexClientBuilder;
|
|||
import com.azure.search.documents.indexes.models.SearchField;
|
||||
import com.azure.search.documents.indexes.models.SearchFieldDataType;
|
||||
import com.azure.search.documents.indexes.models.SearchIndex;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.awaitility.core.ThrowingRunnable;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -22,10 +21,9 @@ import org.slf4j.LoggerFactory;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_FIELD_ID;
|
||||
import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.DEFAULT_INDEX_NAME;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
|
@ -34,31 +32,50 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT
|
|||
|
||||
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStoreIT.class);
|
||||
|
||||
private EmbeddingModel embeddingModel;
|
||||
private static final String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");
|
||||
private static final String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
|
||||
|
||||
private EmbeddingStore<TextSegment> embeddingStore;
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
private int dimensions;
|
||||
private final AzureAiSearchEmbeddingStore embeddingStore = AzureAiSearchEmbeddingStore.builder()
|
||||
.endpoint(AZURE_SEARCH_ENDPOINT)
|
||||
.apiKey(AZURE_SEARCH_KEY)
|
||||
.indexName(randomUUID())
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.build();
|
||||
|
||||
private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");
|
||||
|
||||
private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
|
||||
|
||||
public AzureAiSearchEmbeddingStoreIT() {
|
||||
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
dimensions = embeddingModel.embed("test").content().vector().length;
|
||||
|
||||
embeddingStore = AzureAiSearchEmbeddingStore.builder()
|
||||
.endpoint(AZURE_SEARCH_ENDPOINT)
|
||||
.apiKey(AZURE_SEARCH_KEY)
|
||||
.dimensions(dimensions)
|
||||
.build();
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
clearStore();
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void afterEach() {
|
||||
try {
|
||||
embeddingStore.deleteIndex();
|
||||
} catch (RuntimeException e) {
|
||||
log.error("Failed to delete the index. You should look at deleting it manually.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
super.awaitUntilAsserted(assertion);
|
||||
try {
|
||||
Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean assertEmbedding() {
|
||||
return false; // TODO remove this hack after https://github.com/langchain4j/langchain4j/issues/1617 is closed
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -89,7 +106,7 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT
|
|||
|
||||
try {
|
||||
new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT,
|
||||
new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME", null);
|
||||
new AzureKeyCredential(AZURE_SEARCH_KEY), true, providedIndex, "ANOTHER_INDEX_NAME", null);
|
||||
|
||||
fail("Expected IllegalArgumentException to be thrown");
|
||||
} catch (IllegalArgumentException e) {
|
||||
|
@ -102,71 +119,16 @@ public class AzureAiSearchEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT
|
|||
|
||||
@Test
|
||||
public void when_an_index_is_not_provided_the_default_name_is_used() {
|
||||
AzureAiSearchEmbeddingStore store =new AzureAiSearchEmbeddingStore(AZURE_SEARCH_ENDPOINT,
|
||||
new AzureKeyCredential(AZURE_SEARCH_KEY), false, null, null, null);
|
||||
|
||||
AzureAiSearchEmbeddingStore store = new AzureAiSearchEmbeddingStore(
|
||||
AZURE_SEARCH_ENDPOINT,
|
||||
new AzureKeyCredential(AZURE_SEARCH_KEY),
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
assertEquals(DEFAULT_INDEX_NAME, store.searchClient.getIndexName());
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_add_embeddings_and_find_relevant() {
|
||||
String content1 = "banana";
|
||||
String content2 = "computer";
|
||||
String content3 = "apple";
|
||||
String content4 = "pizza";
|
||||
String content5 = "strawberry";
|
||||
String content6 = "chess";
|
||||
List<String> contents = asList(content1, content2, content3, content4, content5, content6);
|
||||
|
||||
for (String content : contents) {
|
||||
TextSegment textSegment = TextSegment.from(content);
|
||||
Embedding embedding = embeddingModel.embed(content).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
|
||||
assertThat(relevant).hasSize(3);
|
||||
assertThat(relevant.get(0).embedding()).isNotNull();
|
||||
assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5);
|
||||
log.info("#1 relevant item: {}", relevant.get(0).embedded().text());
|
||||
assertThat(relevant.get(1).embedding()).isNotNull();
|
||||
assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5);
|
||||
log.info("#2 relevant item: {}", relevant.get(1).embedded().text());
|
||||
assertThat(relevant.get(2).embedding()).isNotNull();
|
||||
assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5);
|
||||
log.info("#3 relevant item: {}", relevant.get(2).embedded().text());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore = (AzureAiSearchEmbeddingStore) embeddingStore;
|
||||
try {
|
||||
azureAiSearchEmbeddingStore.deleteIndex();
|
||||
azureAiSearchEmbeddingStore.createOrUpdateIndex(dimensions);
|
||||
} catch (RuntimeException e) {
|
||||
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilPersisted() {
|
||||
try {
|
||||
Thread.sleep(1_000);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,49 +5,27 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
|
|||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.awaitility.core.ThrowingRunnable;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_SEARCH_ENDPOINT", matches = ".+")
|
||||
public class AzureAiSearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStoreRemovalIT.class);
|
||||
|
||||
private EmbeddingModel embeddingModel;
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
private EmbeddingStore<TextSegment> embeddingStore;
|
||||
|
||||
private String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");
|
||||
|
||||
private String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
|
||||
|
||||
public AzureAiSearchEmbeddingStoreRemovalIT() {
|
||||
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
embeddingStore = AzureAiSearchEmbeddingStore.builder()
|
||||
.endpoint(AZURE_SEARCH_ENDPOINT)
|
||||
.apiKey(AZURE_SEARCH_KEY)
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.build();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
clearStore();
|
||||
}
|
||||
|
||||
private void clearStore() {
|
||||
AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore = (AzureAiSearchEmbeddingStore) embeddingStore;
|
||||
try {
|
||||
azureAiSearchEmbeddingStore.deleteIndex();
|
||||
azureAiSearchEmbeddingStore.createOrUpdateIndex(embeddingModel.dimension());
|
||||
} catch (RuntimeException e) {
|
||||
log.error("Failed to clean up the index. You should look at deleting it manually.", e);
|
||||
}
|
||||
}
|
||||
private final AzureAiSearchEmbeddingStore embeddingStore = AzureAiSearchEmbeddingStore.builder()
|
||||
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
|
||||
.apiKey(System.getenv("AZURE_SEARCH_KEY"))
|
||||
.indexName(randomUUID())
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.build();
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
|
@ -58,4 +36,23 @@ public class AzureAiSearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemo
|
|||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void afterEach() {
|
||||
try {
|
||||
embeddingStore.deleteIndex();
|
||||
} catch (RuntimeException e) {
|
||||
log.error("Failed to delete the index. You should look at deleting it manually.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
super.awaitUntilAsserted(assertion);
|
||||
try {
|
||||
Thread.sleep(1000); // TODO figure out why this is needed and remove this hack
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,6 +85,13 @@
|
|||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -6,43 +6,30 @@ import com.mongodb.client.MongoClient;
|
|||
import com.mongodb.client.MongoClients;
|
||||
import com.mongodb.client.MongoCollection;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import org.bson.BsonDocument;
|
||||
import org.bson.codecs.configuration.CodecRegistry;
|
||||
import org.bson.codecs.pojo.PojoCodecProvider;
|
||||
import org.bson.conversions.Bson;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromProviders;
|
||||
import static org.bson.codecs.configuration.CodecRegistries.fromRegistries;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_ENDPOINT", matches = ".+")
|
||||
public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(AzureCosmosDBMongoVCoreEmbeddingStoreIT.class);
|
||||
private static MongoClient client;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
private final EmbeddingStore<TextSegment> embeddingStore;
|
||||
private final int dimensions;
|
||||
|
||||
|
||||
public AzureCosmosDBMongoVCoreEmbeddingStoreIT() {
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
dimensions = embeddingModel.embed("hello").content().vector().length;
|
||||
|
||||
client = MongoClients.create(
|
||||
MongoClientSettings.builder()
|
||||
|
@ -59,45 +46,13 @@ public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
.createIndex(true)
|
||||
.kind("vector-hnsw")
|
||||
.numLists(2)
|
||||
.dimensions(dimensions)
|
||||
.dimensions(embeddingModel.dimension())
|
||||
.m(16)
|
||||
.efConstruction(64)
|
||||
.efSearch(40)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddEmbeddingsAndFindRelevant() {
|
||||
String content1 = "banana";
|
||||
String content2 = "computer";
|
||||
String content3 = "apple";
|
||||
String content4 = "pizza";
|
||||
String content5 = "strawberry";
|
||||
String content6 = "chess";
|
||||
List<String> contents = asList(content1, content2, content3, content4, content5, content6);
|
||||
|
||||
for (String content : contents) {
|
||||
TextSegment textSegment = TextSegment.from(content);
|
||||
Embedding embedding = embeddingModel.embed(content).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
|
||||
assertThat(relevant).hasSize(3);
|
||||
assertThat(relevant.get(0).embedding()).isNotNull();
|
||||
assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5);
|
||||
log.info("#1 relevant item: {}", relevant.get(0).embedded().text());
|
||||
assertThat(relevant.get(1).embedding()).isNotNull();
|
||||
assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5);
|
||||
log.info("#2 relevant item: {}", relevant.get(1).embedded().text());
|
||||
assertThat(relevant.get(2).embedding()).isNotNull();
|
||||
assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5);
|
||||
log.info("#3 relevant item: {}", relevant.get(2).embedded().text());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
|
@ -108,15 +63,6 @@ public class AzureCosmosDBMongoVCoreEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilPersisted() {
|
||||
try {
|
||||
Thread.sleep(1_000);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder()
|
||||
|
|
|
@ -70,6 +70,12 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -3,50 +3,45 @@ package dev.langchain4j.store.embedding.azure.cosmos.nosql;
|
|||
import com.azure.cosmos.ConsistencyLevel;
|
||||
import com.azure.cosmos.CosmosClient;
|
||||
import com.azure.cosmos.CosmosClientBuilder;
|
||||
import com.azure.cosmos.CosmosDatabase;
|
||||
import com.azure.cosmos.models.*;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import com.azure.cosmos.models.CosmosContainerProperties;
|
||||
import com.azure.cosmos.models.CosmosVectorDataType;
|
||||
import com.azure.cosmos.models.CosmosVectorDistanceFunction;
|
||||
import com.azure.cosmos.models.CosmosVectorEmbedding;
|
||||
import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy;
|
||||
import com.azure.cosmos.models.CosmosVectorIndexSpec;
|
||||
import com.azure.cosmos.models.CosmosVectorIndexType;
|
||||
import com.azure.cosmos.models.IncludedPath;
|
||||
import com.azure.cosmos.models.IndexingMode;
|
||||
import com.azure.cosmos.models.IndexingPolicy;
|
||||
import com.azure.cosmos.models.PartitionKeyDefinition;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_HOST", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOS_MASTER_KEY", matches = ".+")
|
||||
class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
|
||||
protected static Logger logger = LoggerFactory.getLogger(AzureCosmosDbNoSqlEmbeddingStoreIT.class);
|
||||
|
||||
private static final String DATABASE_NAME = "test_database_langchain_java";
|
||||
private static final String CONTAINER_NAME = "test_container";
|
||||
private CosmosClient client;
|
||||
CosmosDatabase database;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
private final EmbeddingStore<TextSegment> embeddingStore;
|
||||
private final int dimensions;
|
||||
private final String HOST = System.getenv("AZURE_COSMOS_HOST");
|
||||
private final String KEY = System.getenv("AZURE_COSMOS_MASTER_KEY");
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
public AzureCosmosDbNoSqlEmbeddingStoreIT() {
|
||||
embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
dimensions = embeddingModel.embed("hello").content().vector().length;
|
||||
int dimensions = embeddingModel.dimension();
|
||||
|
||||
client = new CosmosClientBuilder()
|
||||
.endpoint(HOST)
|
||||
.key(KEY)
|
||||
CosmosClient client = new CosmosClientBuilder()
|
||||
.endpoint(System.getenv("AZURE_COSMOS_HOST"))
|
||||
.key(System.getenv("AZURE_COSMOS_MASTER_KEY"))
|
||||
.consistencyLevel(ConsistencyLevel.EVENTUAL)
|
||||
.contentResponseOnWriteEnabled(true)
|
||||
.buildClient();
|
||||
|
@ -59,43 +54,6 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
.cosmosVectorIndexes(populateVectorIndexSpec())
|
||||
.containerProperties(populateContainerProperties())
|
||||
.build();
|
||||
database = client.getDatabase(DATABASE_NAME);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAddEmbeddingsAndFindRelevant() {
|
||||
String content1 = "banana";
|
||||
String content2 = "computer";
|
||||
String content3 = "apple";
|
||||
String content4 = "pizza";
|
||||
String content5 = "strawberry";
|
||||
String content6 = "chess";
|
||||
|
||||
List<String> contents = asList(content1, content2, content3, content4, content5, content6);
|
||||
|
||||
for (String content : contents) {
|
||||
TextSegment textSegment = TextSegment.from(content);
|
||||
Embedding embedding = embeddingModel.embed(content).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
|
||||
assertThat(relevant).hasSize(3);
|
||||
assertThat(relevant.get(0).embedding()).isNotNull();
|
||||
assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5);
|
||||
logger.info("#1 relevant item: {}", relevant.get(0).embedded().text());
|
||||
assertThat(relevant.get(1).embedding()).isNotNull();
|
||||
assertThat(relevant.get(1).embedded().text()).isIn(content1, content3, content5);
|
||||
logger.info("#2 relevant item: {}", relevant.get(1).embedded().text());
|
||||
assertThat(relevant.get(2).embedding()).isNotNull();
|
||||
assertThat(relevant.get(2).embedded().text()).isIn(content1, content3, content5);
|
||||
logger.info("#3 relevant item: {}", relevant.get(2).embedded().text());
|
||||
|
||||
safeDeleteDatabase(database);
|
||||
safeClose(client);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -108,38 +66,6 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void awaitUntilPersisted() {
|
||||
try {
|
||||
Thread.sleep(1_000);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
}
|
||||
|
||||
private void safeDeleteDatabase(CosmosDatabase database) {
|
||||
if (database != null) {
|
||||
try {
|
||||
database.delete();
|
||||
} catch (Exception e) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void safeClose(CosmosClient client) {
|
||||
if (client != null) {
|
||||
try {
|
||||
client.close();
|
||||
} catch (Exception e) {
|
||||
logger.error("failed to close client", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private CosmosVectorEmbeddingPolicy populateVectorEmbeddingPolicy(int dimensions) {
|
||||
CosmosVectorEmbeddingPolicy vectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy();
|
||||
CosmosVectorEmbedding embedding = new CosmosVectorEmbedding();
|
||||
|
@ -174,6 +100,4 @@ class AzureCosmosDbNoSqlEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
collectionDefinition.setIndexingPolicy(indexingPolicy);
|
||||
return collectionDefinition;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -112,6 +112,12 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -3,30 +3,24 @@ package dev.langchain4j.store.embedding.astradb;
|
|||
import com.dtsx.astra.sdk.AstraDB;
|
||||
import com.dtsx.astra.sdk.AstraDBAdmin;
|
||||
import com.dtsx.astra.sdk.AstraDBCollection;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiModelName;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import io.stargate.sdk.data.domain.SimilarityMetric;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.MethodOrderer;
|
||||
import org.junit.jupiter.api.TestMethodOrder;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Disabled("AstraDB is not available in the CI")
|
||||
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
||||
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "sk.*")
|
||||
|
@ -81,25 +75,4 @@ class AstraDbEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
void testAddEmbeddingAndFindRelevant() {
|
||||
|
||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||
String id = embeddingStore.add(embedding, textSegment);
|
||||
assertThat(id != null && !id.isEmpty()).isTrue();
|
||||
|
||||
Embedding referenceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(referenceEmbedding, 1);
|
||||
assertThat(embeddingMatches).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
|||
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric.COSINE;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
|
||||
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
|
||||
|
||||
|
@ -36,12 +37,10 @@ class CassandraEmbeddingStoreAstraIT extends CassandraEmbeddingStoreIT {
|
|||
.databaseRegion(TEST_REGION)
|
||||
.keyspace(KEYSPACE)
|
||||
.table(TEST_INDEX)
|
||||
.dimension(embeddingModelDimension()) // openai model
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.dimension(embeddingModel().dimension())
|
||||
.metric(COSINE)
|
||||
.build();
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
package dev.langchain4j.store.embedding.cassandra;
|
||||
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.testcontainers.DockerClientFactory;
|
||||
import org.testcontainers.containers.CassandraContainer;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
@ -15,10 +13,11 @@ import org.testcontainers.utility.DockerImageName;
|
|||
import java.net.InetSocketAddress;
|
||||
import java.util.Collections;
|
||||
|
||||
import static com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric.COSINE;
|
||||
|
||||
/**
|
||||
* Work with Cassandra Embedding Store.
|
||||
*/
|
||||
@Disabled("No Docker in the CI")
|
||||
@Testcontainers
|
||||
class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
|
||||
|
||||
|
@ -55,7 +54,7 @@ class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
|
|||
* Stop Cassandra Node
|
||||
*/
|
||||
@AfterAll
|
||||
static void afterTests() throws Exception {
|
||||
static void afterTests() {
|
||||
cassandraContainer.stop();
|
||||
}
|
||||
|
||||
|
@ -69,11 +68,10 @@ class CassandraEmbeddingStoreDockerIT extends CassandraEmbeddingStoreIT {
|
|||
.localDataCenter(DATACENTER)
|
||||
.keyspace(KEYSPACE)
|
||||
.table(TEST_INDEX)
|
||||
.dimension(embeddingModelDimension())
|
||||
.metric(CassandraSimilarityMetric.COSINE)
|
||||
.dimension(embeddingModel().dimension())
|
||||
.metric(COSINE)
|
||||
.build();
|
||||
}
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -4,22 +4,17 @@ import dev.langchain4j.data.document.Metadata;
|
|||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiModelName;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.assertj.core.data.Percentage;
|
||||
import org.junit.jupiter.api.MethodOrderer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestMethodOrder;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
|
@ -28,26 +23,17 @@ import static org.assertj.core.data.Percentage.withPercentage;
|
|||
abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
|
||||
protected static final String KEYSPACE = "langchain4j";
|
||||
|
||||
protected static final String TEST_INDEX = "test_embedding_store";
|
||||
|
||||
CassandraEmbeddingStore embeddingStore;
|
||||
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.modelName(OpenAiModelName.TEXT_EMBEDDING_ADA_002)
|
||||
.timeout(Duration.ofSeconds(15))
|
||||
.build();
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
protected int embeddingModelDimension() {
|
||||
return 1536;
|
||||
}
|
||||
|
||||
/**
|
||||
* It is required to clean the repository in between tests
|
||||
*/
|
||||
|
@ -57,19 +43,16 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void awaitUntilPersisted() {
|
||||
try {
|
||||
Thread.sleep(1000);
|
||||
} catch(Exception e) {
|
||||
}
|
||||
protected Percentage percentage() {
|
||||
return withPercentage(6); // TODO figure out why difference is so big
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_retrieve_inserted_vector_by_ann() {
|
||||
String sourceSentence = "Testing is doubting !";
|
||||
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
|
||||
String sourceSentence = "Testing is doubting !";
|
||||
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
|
||||
TextSegment sourceTextSegment = TextSegment.from(sourceSentence);
|
||||
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
|
||||
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
|
||||
assertThat(id != null && !id.isEmpty()).isTrue();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(sourceEmbedding, 10);
|
||||
|
@ -84,12 +67,12 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
|
||||
@Test
|
||||
void should_retrieve_inserted_vector_by_ann_and_metadata() {
|
||||
String sourceSentence = "In GOD we trust, everything else we test!";
|
||||
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
|
||||
String sourceSentence = "In GOD we trust, everything else we test!";
|
||||
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
|
||||
TextSegment sourceTextSegment = TextSegment.from(sourceSentence, new Metadata()
|
||||
.put("user", "GOD")
|
||||
.put("test", "false"));
|
||||
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
|
||||
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
|
||||
assertThat(id != null && !id.isEmpty()).isTrue();
|
||||
|
||||
// Should be found with no filter
|
||||
|
@ -106,144 +89,4 @@ abstract class CassandraEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
.findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "JOHN"));
|
||||
assertThat(matchesJohn).isEmpty();
|
||||
}
|
||||
|
||||
// metrics returned are 1.95% off we updated to "withPercentage(2)"
|
||||
|
||||
@Test
|
||||
void should_return_correct_score() {
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
String id = embeddingStore().add(embedding);
|
||||
assertThat(id).isNotBlank();
|
||||
Embedding referenceEmbedding = embeddingModel().embed("hi").content();
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(referenceEmbedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_find_with_min_score() {
|
||||
String firstId = randomUUID();
|
||||
Embedding firstEmbedding = embeddingModel().embed("hello").content();
|
||||
embeddingStore().add(firstId, firstEmbedding);
|
||||
String secondId = randomUUID();
|
||||
Embedding secondEmbedding = embeddingModel().embed("hi").content();
|
||||
embeddingStore().add(secondId, secondEmbedding);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() - 0.01
|
||||
);
|
||||
assertThat(relevant2).hasSize(2);
|
||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score()
|
||||
);
|
||||
assertThat(relevant3).hasSize(2);
|
||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() + 0.01
|
||||
);
|
||||
assertThat(relevant4).hasSize(1);
|
||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings_with_segments() {
|
||||
|
||||
TextSegment firstSegment = TextSegment.from("hello");
|
||||
Embedding firstEmbedding = embeddingModel().embed(firstSegment.text()).content();
|
||||
|
||||
TextSegment secondSegment = TextSegment.from("hi");
|
||||
Embedding secondEmbedding = embeddingModel().embed(secondSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore().addAll(
|
||||
asList(firstEmbedding, secondEmbedding),
|
||||
asList(firstSegment, secondSegment)
|
||||
);
|
||||
assertThat(ids).hasSize(2);
|
||||
assertThat(ids.get(0)).isNotBlank();
|
||||
assertThat(ids.get(1)).isNotBlank();
|
||||
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings() {
|
||||
|
||||
Embedding firstEmbedding = embeddingModel().embed("hello").content();
|
||||
Embedding secondEmbedding = embeddingModel().embed("hi").content();
|
||||
|
||||
List<String> ids = embeddingStore().addAll(asList(firstEmbedding, secondEmbedding));
|
||||
assertThat(ids).hasSize(2);
|
||||
assertThat(ids.get(0)).isNotBlank();
|
||||
assertThat(ids.get(1)).isNotBlank();
|
||||
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(2));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isNull();
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(2)
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isNull();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package dev.langchain4j.store.memory.chat.cassandra;
|
|||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.testcontainers.DockerClientFactory;
|
||||
import org.testcontainers.containers.CassandraContainer;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
@ -14,7 +13,6 @@ import java.net.InetSocketAddress;
|
|||
/**
|
||||
* Test Cassandra Chat Memory Store with a Saas DB.
|
||||
*/
|
||||
@Disabled("No Docker in the CI")
|
||||
@Testcontainers
|
||||
class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSupport {
|
||||
static final String DATACENTER = "datacenter1";
|
||||
|
@ -44,8 +42,8 @@ class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSuppo
|
|||
.addContactPoint(contactPoint)
|
||||
.withLocalDatacenter(DATACENTER)
|
||||
.build().execute(
|
||||
"CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE +
|
||||
" WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};");
|
||||
"CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE +
|
||||
" WITH replication = {'class':'SimpleStrategy', 'replication_factor':'1'};");
|
||||
return new CassandraChatMemoryStore(CqlSession.builder()
|
||||
.addContactPoint(contactPoint)
|
||||
.withLocalDatacenter(DATACENTER)
|
||||
|
@ -54,8 +52,7 @@ class CassandraChatMemoryStoreDockerIT extends CassandraChatMemoryStoreTestSuppo
|
|||
}
|
||||
|
||||
@AfterAll
|
||||
static void afterTests() throws Exception {
|
||||
static void afterTests() {
|
||||
cassandraContainer.stop();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
package dev.langchain4j.store.embedding.chroma;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
@ -12,28 +7,37 @@ import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2Quantize
|
|||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Stream;
|
||||
import org.junit.Ignore;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.logical.Not;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.testcontainers.chromadb.ChromaDBContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Testcontainers
|
||||
class ChromaEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
|
||||
|
||||
@Container
|
||||
private static final ChromaDBContainer chroma = new ChromaDBContainer("chromadb/chroma:0.5.4");
|
||||
|
||||
EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore
|
||||
.builder()
|
||||
.baseUrl(chroma.getEndpoint())
|
||||
.collectionName(randomUUID())
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore.builder()
|
||||
.baseUrl(chroma.getEndpoint())
|
||||
.collectionName(randomUUID())
|
||||
.logRequests(true)
|
||||
.logResponses(true)
|
||||
.build();
|
||||
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
|
@ -47,182 +51,65 @@ class ChromaEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
|
|||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Ignore("Chroma cannot filter by greater and less than of alphanumeric metadata, only int and float are supported")
|
||||
protected void should_filter_by_greater_and_less_than_alphanumeric_metadata(
|
||||
Filter metadataFilter,
|
||||
List<Metadata> matchingMetadatas,
|
||||
List<Metadata> notMatchingMetadatas
|
||||
) {}
|
||||
// in chroma compare filter only works with numbers
|
||||
protected static Stream<Arguments> should_filter_by_metadata() {
|
||||
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata()
|
||||
.filter(arguments -> {
|
||||
Filter filter = (Filter) arguments.get()[0];
|
||||
if (filter instanceof IsLessThan) {
|
||||
return ((IsLessThan) filter).comparisonValue() instanceof Number;
|
||||
} else if (filter instanceof IsLessThanOrEqualTo) {
|
||||
return ((IsLessThanOrEqualTo) filter).comparisonValue() instanceof Number;
|
||||
} else if (filter instanceof IsGreaterThan) {
|
||||
return ((IsGreaterThan) filter).comparisonValue() instanceof Number;
|
||||
} else if (filter instanceof IsGreaterThanOrEqualTo) {
|
||||
return ((IsGreaterThanOrEqualTo) filter).comparisonValue() instanceof Number;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// Chroma filters by *not* as following:
|
||||
// If you filter by "key" not equals "a", then in fact all items with "key" != "a" value are returned, but no items
|
||||
// without "key" metadata!
|
||||
// Therefore, all default *not* tests coming from parent class have to be rewritten here.
|
||||
protected static Stream<Arguments> should_filter_by_metadata_not() {
|
||||
return Stream
|
||||
.<Arguments>builder()
|
||||
// === NotEqual ===
|
||||
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata_not()
|
||||
.map(args -> {
|
||||
Object[] arguments = args.get();
|
||||
Filter filter = (Filter) arguments[0];
|
||||
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("key").isNotEqualTo("a"),
|
||||
asList(
|
||||
new Metadata().put("key", "A"),
|
||||
new Metadata().put("key", "b"),
|
||||
new Metadata().put("key", "aa"),
|
||||
new Metadata().put("key", "a a")
|
||||
),
|
||||
asList(
|
||||
new Metadata().put("key", "a"),
|
||||
new Metadata().put("key2", "a"),
|
||||
new Metadata().put("key", "a").put("key2", "b")
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("key").isNotEqualTo(TEST_UUID),
|
||||
asList(new Metadata().put("key", UUID.randomUUID())),
|
||||
asList(
|
||||
new Metadata().put("key", TEST_UUID),
|
||||
new Metadata().put("key2", TEST_UUID),
|
||||
new Metadata().put("key", TEST_UUID).put("key2", UUID.randomUUID())
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("key").isNotEqualTo(1),
|
||||
asList(
|
||||
new Metadata().put("key", -1),
|
||||
new Metadata().put("key", 0),
|
||||
new Metadata().put("key", 2),
|
||||
new Metadata().put("key", 10)
|
||||
),
|
||||
asList(
|
||||
new Metadata().put("key", 1),
|
||||
new Metadata().put("key2", 1),
|
||||
new Metadata().put("key", 1).put("key2", 2)
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("key").isNotEqualTo(1.1f),
|
||||
asList(
|
||||
new Metadata().put("key", -1.1f),
|
||||
new Metadata().put("key", 0.0f),
|
||||
new Metadata().put("key", 1.11f),
|
||||
new Metadata().put("key", 2.2f)
|
||||
),
|
||||
asList(
|
||||
new Metadata().put("key", 1.1f),
|
||||
new Metadata().put("key2", 1.1f),
|
||||
new Metadata().put("key", 1.1f).put("key2", 2.2f)
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
// === NotIn ===
|
||||
String key = getMetadataKey(filter);
|
||||
|
||||
// NotIn: string
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("name").isNotIn("Klaus"),
|
||||
asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Alice")),
|
||||
asList(
|
||||
new Metadata().put("name", "Klaus"),
|
||||
new Metadata().put("name2", "Klaus"),
|
||||
new Metadata().put("name", "Klaus").put("age", 42)
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("name").isNotIn(singletonList("Klaus")),
|
||||
asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Alice")),
|
||||
asList(
|
||||
new Metadata().put("name", "Klaus"),
|
||||
new Metadata().put("name2", "Klaus"),
|
||||
new Metadata().put("name", "Klaus").put("age", 42)
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("name").isNotIn("Klaus", "Alice"),
|
||||
asList(new Metadata().put("name", "Klaus Heisler"), new Metadata().put("name", "Zoe")),
|
||||
asList(
|
||||
new Metadata().put("name", "Klaus"),
|
||||
new Metadata().put("name2", "Klaus"),
|
||||
new Metadata().put("name", "Klaus").put("age", 42),
|
||||
new Metadata().put("name", "Alice"),
|
||||
new Metadata().put("name", "Alice").put("age", 42)
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
// NotIn: UUID
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("name").isNotIn(TEST_UUID),
|
||||
asList(new Metadata().put("name", UUID.randomUUID())),
|
||||
asList(
|
||||
new Metadata().put("name", TEST_UUID),
|
||||
new Metadata().put("name2", TEST_UUID),
|
||||
new Metadata().put("name", TEST_UUID).put("age", 42)
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
// NotIn: int
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("age").isNotIn(42),
|
||||
asList(new Metadata().put("age", 666)),
|
||||
asList(
|
||||
new Metadata().put("age", 42),
|
||||
new Metadata().put("age2", 42),
|
||||
new Metadata().put("age", 42).put("name", "Klaus")
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("age").isNotIn(42, 18),
|
||||
asList(new Metadata().put("age", 666)),
|
||||
asList(
|
||||
new Metadata().put("age", 42),
|
||||
new Metadata().put("age", 18),
|
||||
new Metadata().put("age2", 42),
|
||||
new Metadata().put("age", 42).put("name", "Klaus"),
|
||||
new Metadata().put("age", 18).put("name", "Klaus")
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
// NotIn: float
|
||||
.add(
|
||||
Arguments.of(
|
||||
metadataKey("age").isNotIn(asList(42.0f, 18.0f)),
|
||||
asList(new Metadata().put("age", 666.0f)),
|
||||
asList(
|
||||
new Metadata().put("age", 42.0f),
|
||||
new Metadata().put("age", 18.0f),
|
||||
new Metadata().put("age2", 42.0f),
|
||||
new Metadata().put("age", 42.0f).put("name", "Klaus"),
|
||||
new Metadata().put("age", 18.0f).put("name", "Klaus")
|
||||
),
|
||||
false
|
||||
)
|
||||
)
|
||||
.build();
|
||||
List<Metadata> matchingMetadatas = (List<Metadata>) arguments[1];
|
||||
List<Metadata> newMatchingMetadatas = matchingMetadatas.stream()
|
||||
.filter(metadata -> metadata.containsKey(key))
|
||||
.collect(toList());
|
||||
|
||||
List<Metadata> notMatchingMetadatas = (List<Metadata>) arguments[2];
|
||||
List<Metadata> newNotMatchingMetadatas = new ArrayList<>(notMatchingMetadatas);
|
||||
newNotMatchingMetadatas.addAll(matchingMetadatas.stream()
|
||||
.filter(metadata -> !metadata.containsKey(key))
|
||||
.collect(toList()));
|
||||
|
||||
assertThat(Stream.concat(newMatchingMetadatas.stream(), newNotMatchingMetadatas.stream()))
|
||||
.containsExactlyInAnyOrderElementsOf(Stream.concat(matchingMetadatas.stream(), notMatchingMetadatas.stream()).collect(toList()));
|
||||
|
||||
return Arguments.of(filter, newMatchingMetadatas, newNotMatchingMetadatas);
|
||||
});
|
||||
}
|
||||
|
||||
private static String getMetadataKey(Filter filter) {
|
||||
try {
|
||||
if (filter instanceof Not) {
|
||||
filter = ((Not) filter).expression();
|
||||
}
|
||||
Method method = filter.getClass().getMethod("key");
|
||||
return (String) method.invoke(filter);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
/**
|
||||
* A minimum set of tests that each implementation of {@link EmbeddingStore} must pass.
|
||||
|
@ -20,6 +21,7 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
|
|||
|
||||
@Test
|
||||
void should_add_embedding_with_segment_with_metadata() {
|
||||
|
||||
Metadata metadata = createMetadata();
|
||||
|
||||
TextSegment segment = TextSegment.from("hello", metadata);
|
||||
|
@ -28,22 +30,19 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
|
|||
String id = embeddingStore().add(embedding, segment);
|
||||
assertThat(id).isNotBlank();
|
||||
|
||||
{
|
||||
// Not returned.
|
||||
TextSegment altSegment = TextSegment.from("hello?");
|
||||
Embedding altEmbedding = embeddingModel().embed(altSegment.text()).content();
|
||||
embeddingStore().add(altEmbedding, altSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
// then
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
if (assertEmbedding()) {
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
}
|
||||
|
||||
assertThat(match.embedded().text()).isEqualTo(segment.text());
|
||||
|
||||
|
@ -78,15 +77,14 @@ public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
|
|||
assertThat(match.embedded().metadata().getDouble("double_123")).isEqualTo(1.23456789d);
|
||||
|
||||
// new API
|
||||
assertThat(
|
||||
embeddingStore()
|
||||
.search(EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(1).build())
|
||||
.matches()
|
||||
)
|
||||
.isEqualTo(relevant);
|
||||
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embedding)
|
||||
.maxResults(1)
|
||||
.build()).matches()).isEqualTo(relevant);
|
||||
}
|
||||
|
||||
protected Metadata createMetadata() {
|
||||
|
||||
Metadata metadata = new Metadata();
|
||||
|
||||
metadata.put("string_empty", "");
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -152,10 +152,17 @@ public abstract class EmbeddingStoreWithRemovalIT {
|
|||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).isEmpty());
|
||||
}
|
||||
|
||||
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
Awaitility.await()
|
||||
.atMost(Duration.ofSeconds(60))
|
||||
.pollDelay(Duration.ofSeconds(0))
|
||||
.pollInterval(Duration.ofMillis(300))
|
||||
.untilAsserted(assertion);
|
||||
}
|
||||
|
||||
protected List<EmbeddingMatch<TextSegment>> getAllEmbeddings() {
|
||||
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest
|
||||
.builder()
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embeddingModel().embed("test").content())
|
||||
.maxResults(1000)
|
||||
.build();
|
||||
|
@ -164,11 +171,4 @@ public abstract class EmbeddingStoreWithRemovalIT {
|
|||
|
||||
return searchResult.matches();
|
||||
}
|
||||
|
||||
protected static void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
Awaitility.await()
|
||||
.pollInterval(Duration.ofMillis(500))
|
||||
.atMost(Duration.ofSeconds(15))
|
||||
.untilAsserted(assertion);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,13 @@ package dev.langchain4j.store.embedding;
|
|||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.assertj.core.data.Percentage;
|
||||
import org.awaitility.Awaitility;
|
||||
import org.awaitility.core.ThrowingRunnable;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
|
@ -29,27 +33,30 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
}
|
||||
|
||||
protected void ensureStoreIsEmpty() {
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
assertThat(embeddingStore().findRelevant(embedding, 1000)).isEmpty();
|
||||
assertThat(getAllEmbeddings()).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding() {
|
||||
|
||||
// given
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
|
||||
String id = embeddingStore().add(embedding);
|
||||
assertThat(id).isNotBlank();
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
// then
|
||||
assertThat(id).isNotBlank();
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.score()).isCloseTo(1, percentage());
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
if (assertEmbedding()) {
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
}
|
||||
assertThat(match.embedded()).isNull();
|
||||
|
||||
// new API
|
||||
|
@ -62,20 +69,24 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
@Test
|
||||
void should_add_embedding_with_id() {
|
||||
|
||||
// given
|
||||
String id = randomUUID();
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
|
||||
embeddingStore().add(id, embedding);
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
// then
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.score()).isCloseTo(1, percentage());
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
if (assertEmbedding()) {
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
}
|
||||
assertThat(match.embedded()).isNull();
|
||||
|
||||
// new API
|
||||
|
@ -88,21 +99,25 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
@Test
|
||||
void should_add_embedding_with_segment() {
|
||||
|
||||
// given
|
||||
TextSegment segment = TextSegment.from("hello");
|
||||
Embedding embedding = embeddingModel().embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore().add(embedding, segment);
|
||||
assertThat(id).isNotBlank();
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
// then
|
||||
assertThat(id).isNotBlank();
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.score()).isCloseTo(1, percentage());
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
if (assertEmbedding()) {
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
}
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
|
||||
// new API
|
||||
|
@ -115,34 +130,41 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
@Test
|
||||
void should_add_multiple_embeddings() {
|
||||
|
||||
// given
|
||||
Embedding firstEmbedding = embeddingModel().embed("hello").content();
|
||||
Embedding secondEmbedding = embeddingModel().embed("hi").content();
|
||||
|
||||
List<String> ids = embeddingStore().addAll(asList(firstEmbedding, secondEmbedding));
|
||||
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
|
||||
// then
|
||||
assertThat(ids).hasSize(2);
|
||||
assertThat(ids.get(0)).isNotBlank();
|
||||
assertThat(ids.get(1)).isNotBlank();
|
||||
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.score()).isCloseTo(1, percentage());
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
if (assertEmbedding()) {
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
}
|
||||
assertThat(firstMatch.embedded()).isNull();
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(1)
|
||||
percentage()
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
|
||||
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
|
||||
if (assertEmbedding()) {
|
||||
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
|
||||
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
|
||||
}
|
||||
assertThat(secondMatch.embedded()).isNull();
|
||||
|
||||
// new API
|
||||
|
@ -155,6 +177,7 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
@Test
|
||||
void should_add_multiple_embeddings_with_segments() {
|
||||
|
||||
// given
|
||||
TextSegment firstSegment = TextSegment.from("hello");
|
||||
Embedding firstEmbedding = embeddingModel().embed(firstSegment.text()).content();
|
||||
|
||||
|
@ -165,30 +188,37 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
asList(firstEmbedding, secondEmbedding),
|
||||
asList(firstSegment, secondSegment)
|
||||
);
|
||||
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
|
||||
// then
|
||||
assertThat(ids).hasSize(2);
|
||||
assertThat(ids.get(0)).isNotBlank();
|
||||
assertThat(ids.get(1)).isNotBlank();
|
||||
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.score()).isCloseTo(1, percentage());
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
if (assertEmbedding()) {
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
}
|
||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(1)
|
||||
percentage()
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
|
||||
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
|
||||
if (assertEmbedding()) {
|
||||
assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
|
||||
.isCloseTo(1, withPercentage(0.01)); // TODO return strict check back once Qdrant fixes it
|
||||
}
|
||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
||||
|
||||
// new API
|
||||
|
@ -201,6 +231,7 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
@Test
|
||||
void should_find_with_min_score() {
|
||||
|
||||
// given
|
||||
String firstId = randomUUID();
|
||||
Embedding firstEmbedding = embeddingModel().embed("hello").content();
|
||||
embeddingStore().add(firstId, firstEmbedding);
|
||||
|
@ -209,33 +240,41 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
Embedding secondEmbedding = embeddingModel().embed("hi").content();
|
||||
embeddingStore().add(secondId, secondEmbedding);
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
|
||||
|
||||
// then
|
||||
assertThat(relevant).hasSize(2);
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.score()).isCloseTo(1, percentage());
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(firstEmbedding, secondEmbedding)),
|
||||
withPercentage(1)
|
||||
percentage()
|
||||
);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
||||
|
||||
// new API
|
||||
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(firstEmbedding)
|
||||
.maxResults(10)
|
||||
.build()).matches()).isEqualTo(relevant);
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() - 0.01
|
||||
);
|
||||
|
||||
// then
|
||||
assertThat(relevant2).hasSize(2);
|
||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
// new API
|
||||
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(firstEmbedding)
|
||||
|
@ -243,14 +282,18 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
.minScore(secondMatch.score() - 0.01)
|
||||
.build()).matches()).isEqualTo(relevant2);
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score()
|
||||
);
|
||||
|
||||
// then
|
||||
assertThat(relevant3).hasSize(2);
|
||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
// new API
|
||||
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(firstEmbedding)
|
||||
|
@ -258,13 +301,17 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
.minScore(secondMatch.score())
|
||||
.build()).matches()).isEqualTo(relevant3);
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore().findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() + 0.01
|
||||
);
|
||||
|
||||
// then
|
||||
assertThat(relevant4).hasSize(1);
|
||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
||||
|
||||
// new API
|
||||
assertThat(embeddingStore().search(EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(firstEmbedding)
|
||||
|
@ -276,22 +323,25 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
@Test
|
||||
void should_return_correct_score() {
|
||||
|
||||
// given
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
|
||||
String id = embeddingStore().add(embedding);
|
||||
assertThat(id).isNotBlank();
|
||||
|
||||
awaitUntilPersisted();
|
||||
awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(1));
|
||||
|
||||
Embedding referenceEmbedding = embeddingModel().embed("hi").content();
|
||||
|
||||
// when
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(referenceEmbedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
// then
|
||||
assertThat(relevant).hasSize(1);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
||||
withPercentage(1)
|
||||
percentage()
|
||||
);
|
||||
|
||||
// new API
|
||||
|
@ -301,7 +351,31 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
.build()).matches()).isEqualTo(relevant);
|
||||
}
|
||||
|
||||
protected void awaitUntilPersisted() {
|
||||
// not waiting by default
|
||||
protected void awaitUntilAsserted(ThrowingRunnable assertion) {
|
||||
Awaitility.await()
|
||||
.atMost(Duration.ofSeconds(60))
|
||||
.pollDelay(Duration.ofSeconds(0))
|
||||
.pollInterval(Duration.ofMillis(300))
|
||||
.untilAsserted(assertion);
|
||||
}
|
||||
|
||||
protected List<EmbeddingMatch<TextSegment>> getAllEmbeddings() {
|
||||
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embeddingModel().embed("test").content())
|
||||
.maxResults(1000)
|
||||
.build();
|
||||
|
||||
EmbeddingSearchResult<TextSegment> searchResult = embeddingStore().search(embeddingSearchRequest);
|
||||
|
||||
return searchResult.matches();
|
||||
}
|
||||
|
||||
protected boolean assertEmbedding() {
|
||||
return true;
|
||||
}
|
||||
|
||||
protected Percentage percentage() {
|
||||
return withPercentage(1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import dev.langchain4j.model.embedding.EmbeddingModel;
|
|||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
import lombok.SneakyThrows;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
|
@ -42,6 +41,7 @@ abstract class AbstractElasticsearchEmbeddingStoreIT extends EmbeddingStoreWithF
|
|||
}
|
||||
|
||||
abstract ElasticsearchConfiguration withConfiguration();
|
||||
|
||||
void optionallyCreateIndex(String indexName) throws IOException {
|
||||
}
|
||||
|
||||
|
@ -80,10 +80,4 @@ abstract class AbstractElasticsearchEmbeddingStoreIT extends EmbeddingStoreWithF
|
|||
protected void ensureStoreIsEmpty() {
|
||||
// TODO fix
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
elasticsearchClientHelper.client.indices().refresh(rr -> rr.index(indexName).ignoreUnavailable(true));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import co.elastic.clients.transport.endpoints.BooleanResponse;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
class ElasticsearchKnnEmbeddingStoreIT extends AbstractElasticsearchEmbeddingStoreIT {
|
||||
class ElasticsearchEmbeddingStoreKnnIT extends AbstractElasticsearchEmbeddingStoreIT {
|
||||
|
||||
@Override
|
||||
ElasticsearchConfiguration withConfiguration() {
|
|
@ -1,6 +1,7 @@
|
|||
package dev.langchain4j.store.embedding.elasticsearch;
|
||||
|
||||
class ElasticsearchEmbeddingStoreIT extends AbstractElasticsearchEmbeddingStoreIT {
|
||||
class ElasticsearchEmbeddingStoreScriptIT extends AbstractElasticsearchEmbeddingStoreIT {
|
||||
|
||||
@Override
|
||||
ElasticsearchConfiguration withConfiguration() {
|
||||
return ElasticsearchConfigurationScript.builder().build();
|
|
@ -101,6 +101,13 @@
|
|||
<artifactId>infinispan-server-testdriver-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -97,6 +97,13 @@
|
|||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -6,11 +6,10 @@ import com.mongodb.client.MongoClients;
|
|||
import com.mongodb.client.MongoCollection;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import lombok.SneakyThrows;
|
||||
import org.bson.codecs.configuration.CodecRegistry;
|
||||
import org.bson.codecs.pojo.PojoCodecProvider;
|
||||
import org.bson.conversions.Bson;
|
||||
|
@ -78,10 +77,4 @@ class MongoDbEmbeddingStoreCloudIT extends EmbeddingStoreIT {
|
|||
Bson filter = Filters.exists("embedding");
|
||||
collection.deleteMany(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(3000);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import com.mongodb.*;
|
||||
import com.mongodb.ConnectionString;
|
||||
import com.mongodb.MongoClientSettings;
|
||||
import com.mongodb.MongoCredential;
|
||||
import com.mongodb.ServerApi;
|
||||
import com.mongodb.ServerApiVersion;
|
||||
import com.mongodb.client.MongoClient;
|
||||
import com.mongodb.client.MongoClients;
|
||||
import com.mongodb.client.MongoCollection;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import lombok.SneakyThrows;
|
||||
|
@ -91,10 +95,4 @@ class MongoDbEmbeddingStoreLocalIT extends EmbeddingStoreIT {
|
|||
Bson filter = Filters.exists("embedding");
|
||||
collection.deleteMany(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(2000);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
package dev.langchain4j.store.embedding.mongodb;
|
||||
|
||||
import com.mongodb.*;
|
||||
import com.mongodb.ConnectionString;
|
||||
import com.mongodb.MongoClientSettings;
|
||||
import com.mongodb.MongoCredential;
|
||||
import com.mongodb.ServerApi;
|
||||
import com.mongodb.ServerApiVersion;
|
||||
import com.mongodb.client.MongoClient;
|
||||
import com.mongodb.client.MongoClients;
|
||||
import com.mongodb.client.model.Filters;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import lombok.SneakyThrows;
|
||||
|
@ -23,7 +27,7 @@ import static java.util.Arrays.asList;
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
class MongoDbEmbeddingStoreFilterIT {
|
||||
class MongoDbEmbeddingStoreNativeFilterIT {
|
||||
|
||||
static MongoDBAtlasContainer mongodb = new MongoDBAtlasContainer();
|
||||
|
||||
|
@ -68,6 +72,8 @@ class MongoDbEmbeddingStoreFilterIT {
|
|||
|
||||
@Test
|
||||
void should_find_relevant_with_filter() {
|
||||
|
||||
// given
|
||||
TextSegment segment = TextSegment.from("this segment should be found", Metadata.from("test-key", "test-value"));
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
|
@ -75,8 +81,7 @@ class MongoDbEmbeddingStoreFilterIT {
|
|||
Embedding filterEmbedding = embeddingModel.embed(filterSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(asList(embedding, filterEmbedding), asList(segment, filterSegment));
|
||||
assertThat(ids)
|
||||
.hasSize(2);
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
TextSegment refSegment = TextSegment.from("find a segment");
|
||||
Embedding refEmbedding = embeddingModel.embed(refSegment.text()).content();
|
||||
|
@ -84,7 +89,8 @@ class MongoDbEmbeddingStoreFilterIT {
|
|||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(refEmbedding, 2);
|
||||
// Only segment should be found, filterSegment should be filtered
|
||||
|
||||
// then
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
|
@ -95,7 +101,7 @@ class MongoDbEmbeddingStoreFilterIT {
|
|||
}
|
||||
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
private void awaitUntilPersisted() {
|
||||
Thread.sleep(2000);
|
||||
}
|
||||
}
|
|
@ -104,6 +104,13 @@
|
|||
<artifactId>mockito-junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -426,12 +426,6 @@ class Neo4jEmbeddingStoreIT {
|
|||
assertThat(rowsBatched.get(0)).hasSize(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_row_batches_empty() {
|
||||
List<List<Map<String, Object>>> rowsBatched = getListRowsBatched(0);
|
||||
assertThat(rowsBatched).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void test_row_batches_10000_elements() {
|
||||
List<List<Map<String, Object>>> rowsBatched = getListRowsBatched(10000);
|
||||
|
|
|
@ -2,11 +2,10 @@ package dev.langchain4j.store.embedding.opensearch;
|
|||
|
||||
import com.jayway.jsonpath.JsonPath;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import lombok.SneakyThrows;
|
||||
import net.minidev.json.JSONArray;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
|
||||
|
@ -74,11 +73,4 @@ class OpenSearchEmbeddingStoreAwsIT extends EmbeddingStoreIT {
|
|||
protected void ensureStoreIsEmpty() {
|
||||
// TODO fix
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(1000);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
package dev.langchain4j.store.embedding.opensearch;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import lombok.SneakyThrows;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.opensearch.testcontainers.OpensearchContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
|
@ -51,10 +50,4 @@ class OpenSearchEmbeddingStoreLocalIT extends EmbeddingStoreIT {
|
|||
protected void ensureStoreIsEmpty() {
|
||||
// TODO fix
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(1000);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
|
|||
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
|
||||
import lombok.SneakyThrows;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
|
@ -54,12 +53,6 @@ class PineconeEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
|
|||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
protected void awaitUntilPersisted() {
|
||||
Thread.sleep(6000);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("should_filter_by_metadata")
|
||||
protected void should_filter_by_metadata(Filter metadataFilter,
|
||||
|
@ -68,23 +61,23 @@ class PineconeEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT {
|
|||
super.should_filter_by_metadata(metadataFilter, matchingMetadatas, notMatchingMetadatas);
|
||||
}
|
||||
|
||||
// in pinecone, compare filter only works with numbers
|
||||
// in pinecone compare filter only works with numbers
|
||||
protected static Stream<Arguments> should_filter_by_metadata() {
|
||||
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata().filter(
|
||||
arguments -> {
|
||||
Object o = arguments.get()[0];
|
||||
if (o instanceof IsLessThan) {
|
||||
return ((IsLessThan) o).comparisonValue() instanceof Number;
|
||||
} else if (o instanceof IsLessThanOrEqualTo) {
|
||||
return ((IsLessThanOrEqualTo) o).comparisonValue() instanceof Number;
|
||||
} else if (o instanceof IsGreaterThan) {
|
||||
return ((IsGreaterThan) o).comparisonValue() instanceof Number;
|
||||
} else if (o instanceof IsGreaterThanOrEqualTo) {
|
||||
return ((IsGreaterThanOrEqualTo) o).comparisonValue() instanceof Number;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
);
|
||||
return EmbeddingStoreWithFilteringIT.should_filter_by_metadata()
|
||||
.filter(arguments -> {
|
||||
Filter filter = (Filter) arguments.get()[0];
|
||||
if (filter instanceof IsLessThan) {
|
||||
return ((IsLessThan) filter).comparisonValue() instanceof Number;
|
||||
} else if (filter instanceof IsLessThanOrEqualTo) {
|
||||
return ((IsLessThanOrEqualTo) filter).comparisonValue() instanceof Number;
|
||||
} else if (filter instanceof IsGreaterThan) {
|
||||
return ((IsGreaterThan) filter).comparisonValue() instanceof Number;
|
||||
} else if (filter instanceof IsGreaterThanOrEqualTo) {
|
||||
return ((IsGreaterThanOrEqualTo) filter).comparisonValue() instanceof Number;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
|
@ -97,6 +97,12 @@
|
|||
<version>1.7.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -90,6 +90,12 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -106,6 +106,13 @@
|
|||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -195,8 +195,6 @@ class VearchEmbeddingStoreIT extends EmbeddingStoreIT {
|
|||
embeddingStore().add(altEmbedding, altSegment);
|
||||
}
|
||||
|
||||
awaitUntilPersisted();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
package dev.langchain4j.store.embedding.inmemory;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT;
|
||||
|
||||
/**
|
||||
* Tests if InMemoryEmbeddingStore works correctly after being serialized and deserialized back.
|
||||
* See awaitUntilPersisted()
|
||||
* Tests if {@link InMemoryEmbeddingStore} works correctly after being serialized and deserialized back.
|
||||
*/
|
||||
class InMemoryEmbeddingStoreSerializedTest extends EmbeddingStoreWithFilteringIT {
|
||||
|
||||
|
@ -17,14 +16,14 @@ class InMemoryEmbeddingStoreSerializedTest extends EmbeddingStoreWithFilteringIT
|
|||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@Override
|
||||
protected void awaitUntilPersisted() {
|
||||
String json = embeddingStore.serializeToJson();
|
||||
embeddingStore = InMemoryEmbeddingStore.fromJson(json);
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
serializeAndDeserialize();
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
private void serializeAndDeserialize() {
|
||||
String json = embeddingStore.serializeToJson();
|
||||
embeddingStore = InMemoryEmbeddingStore.fromJson(json);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue